mpn_mul_fft type overflow issue
Mark Sofroniou
marks at wolfram.com
Wed Sep 25 09:39:16 CEST 2013
Hi Torbjorn.
I appreciate your comments and the patch you sent. I'll be sure to
use your head branch as a starting point in the future - sorry about that.
I'm still learning about your development workflow, so thanks for your
patience.
Attached is a new patch that has been retested.
Here is a description of the changes relative to your patch, which will
hopefully save you a bit of time:
All capital K values (K, K2, K3, etc.) are computed by shifting by a small
k, so they are now consistently of type mp_size_t. I changed a few function
signatures to reflect that (e.g. mpn_mul_fft_decompose, mpn_fft_fft,
mpn_fft_mul_modF_K, mpn_fft_fftinv).
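
For illustration, here is a minimal standalone sketch (not part of the patch,
with made-up values) of why these counts want to be mp_size_t: K itself is
small, but it feeds size arithmetic such as K * l, which can exceed INT_MAX
on a 64-bit host.

#include <stdio.h>
#include <gmp.h>

int
main (void)
{
  int k = 20;                        /* small FFT order */
  mp_size_t K = (mp_size_t) 1 << k;  /* widen before the shift */
  mp_size_t l = 4096;                /* hypothetical limbs per FFT piece */
  mp_size_t Kl = K * l;              /* 2^32: would overflow an int */
  printf ("K = %lld, K*l = %lld\n", (long long) K, (long long) Kl);
  return 0;
}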
There is also a minor change to mpn_fft_fft, hoisting K >> 1 into a local K2
to make it consistent with mpn_fft_fftinv. A good compiler will probably
perform that optimization anyway, but it doesn't hurt.
Added a couple of size_t casts to prevent potential overflow in the
TMP_BALLOC_TYPE calls:
tmp = TMP_BALLOC_TYPE (2 << k, int);
is now:
tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
I didn't have enough memory to trigger this, but for a large enough
allocation I don't see why k couldn't eventually reach 31, and the cast
prevents the shift from overflowing an int.
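
As a toy illustration (assuming a 32-bit int), the point of the cast is that
a shift is performed in the type of its left operand, so the widening has to
happen before the shift:

#include <stdio.h>
#include <stddef.h>

int
main (void)
{
  int k = 31;
  /* "2 << k" would be evaluated in int and overflow (undefined
     behaviour) once k reaches 31, so it is only shown in a comment:
     size_t bad = 2 << k; */
  size_t good = (size_t) 2 << k;   /* shift done in size_t */
  printf ("%zu\n", good);          /* 4294967296 on a 64-bit host */
  return 0;
}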
I changed the variable sh in mpn_fft_mul_2exp_modF from int to
unsigned int to match the shift-count type of mpn_lshift, where it is used.
Maybe this should be mp_bitcnt_t? That depends on your plans for the
mpn_lshift API - I don't know whether its shift count will eventually
become mp_bitcnt_t as well. It's really not important: the value is tiny,
as you pointed out already.
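
For reference, here is a small standalone sketch (with a made-up shift count
d) of the d -> (m, sh) split now used in mpn_fft_mul_2exp_modF: the
whole-limb part m needs the range of mp_size_t, while sh is always below
GMP_NUMB_BITS, so an unsigned int matching mpn_lshift's count argument is
enough.

#include <stdio.h>
#include <gmp.h>

int
main (void)
{
  mp_bitcnt_t d = (mp_bitcnt_t) 5 * GMP_NUMB_BITS + 17;  /* made-up value */
  unsigned int sh = d % GMP_NUMB_BITS;  /* bit part, 0 <= sh < GMP_NUMB_BITS */
  mp_size_t m = d / GMP_NUMB_BITS;      /* whole-limb part */
  printf ("m = %lld limbs, sh = %u bits\n", (long long) m, sh);
  return 0;
}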
In mpn_mul_fft_full the variables cc, c2 and oldcc were of type int;
I changed them to mp_size_t.
The only remaining int types are related to squaring or to the parameters
of the fft table.
Regards, Mark
-------------- next part --------------
diff -r fff25440e878 mpn/generic/mul_fft.c
--- a/mpn/generic/mul_fft.c Tue Sep 24 16:15:39 2013 +0200
+++ b/mpn/generic/mul_fft.c Wed Sep 25 02:26:50 2013 -0500
@@ -67,8 +67,8 @@
static mp_limb_t mpn_mul_fft_internal (mp_ptr, mp_size_t, int, mp_ptr *,
mp_ptr *, mp_ptr, mp_ptr, mp_size_t,
mp_size_t, mp_size_t, int **, mp_ptr, int);
-static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, int, int, mp_srcptr,
- mp_size_t, int, int, mp_ptr);
+static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, mp_size_t, mp_size_t, mp_srcptr,
+ mp_size_t, mp_size_t, mp_size_t, mp_ptr);
/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
@@ -192,38 +192,39 @@
r and a must have n+1 limbs, and not overlap.
*/
static void
-mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, unsigned int d, mp_size_t n)
+mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t d, mp_size_t n)
{
- int sh;
+ unsigned int sh;
+ mp_size_t m;
mp_limb_t cc, rd;
sh = d % GMP_NUMB_BITS;
- d /= GMP_NUMB_BITS;
+ m = d / GMP_NUMB_BITS;
- if (d >= n) /* negate */
+ if (m >= n) /* negate */
{
- /* r[0..d-1] <-- lshift(a[n-d]..a[n-1], sh)
- r[d..n-1] <-- -lshift(a[0]..a[n-d-1], sh) */
+ /* r[0..m-1] <-- lshift(a[n-m]..a[n-1], sh)
+ r[m..n-1] <-- -lshift(a[0]..a[n-m-1], sh) */
- d -= n;
+ m -= n;
if (sh != 0)
{
/* no out shift below since a[n] <= 1 */
- mpn_lshift (r, a + n - d, d + 1, sh);
- rd = r[d];
- cc = mpn_lshiftc (r + d, a, n - d, sh);
+ mpn_lshift (r, a + n - m, m + 1, sh);
+ rd = r[m];
+ cc = mpn_lshiftc (r + m, a, n - m, sh);
}
else
{
- MPN_COPY (r, a + n - d, d);
+ MPN_COPY (r, a + n - m, m);
rd = a[n];
- mpn_com (r + d, a, n - d);
+ mpn_com (r + m, a, n - m);
cc = 0;
}
- /* add cc to r[0], and add rd to r[d] */
+ /* add cc to r[0], and add rd to r[m] */
- /* now add 1 in r[d], subtract 1 in r[n], i.e. add 1 in r[0] */
+ /* now add 1 in r[m], subtract 1 in r[n], i.e. add 1 in r[0] */
r[n] = 0;
/* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
@@ -233,46 +234,46 @@
rd++;
/* rd might overflow when sh=GMP_NUMB_BITS-1 */
cc = (rd == 0) ? 1 : rd;
- r = r + d + (rd == 0);
+ r = r + m + (rd == 0);
mpn_incr_u (r, cc);
}
else
{
- /* r[0..d-1] <-- -lshift(a[n-d]..a[n-1], sh)
- r[d..n-1] <-- lshift(a[0]..a[n-d-1], sh) */
+ /* r[0..m-1] <-- -lshift(a[n-m]..a[n-1], sh)
+ r[m..n-1] <-- lshift(a[0]..a[n-m-1], sh) */
if (sh != 0)
{
/* no out bits below since a[n] <= 1 */
- mpn_lshiftc (r, a + n - d, d + 1, sh);
- rd = ~r[d];
- /* {r, d+1} = {a+n-d, d+1} << sh */
- cc = mpn_lshift (r + d, a, n - d, sh); /* {r+d, n-d} = {a, n-d}<<sh */
+ mpn_lshiftc (r, a + n - m, m + 1, sh);
+ rd = ~r[m];
+ /* {r, m+1} = {a+n-m, m+1} << sh */
+ cc = mpn_lshift (r + m, a, n - m, sh); /* {r+m, n-m} = {a, n-m}<<sh */
}
else
{
- /* r[d] is not used below, but we save a test for d=0 */
- mpn_com (r, a + n - d, d + 1);
+ /* r[m] is not used below, but we save a test for m=0 */
+ mpn_com (r, a + n - m, m + 1);
rd = a[n];
- MPN_COPY (r + d, a, n - d);
+ MPN_COPY (r + m, a, n - m);
cc = 0;
}
- /* now complement {r, d}, subtract cc from r[0], subtract rd from r[d] */
+ /* now complement {r, m}, subtract cc from r[0], subtract rd from r[m] */
- /* if d=0 we just have r[0]=a[n] << sh */
- if (d != 0)
+ /* if m=0 we just have r[0]=a[n] << sh */
+ if (m != 0)
{
- /* now add 1 in r[0], subtract 1 in r[d] */
+ /* now add 1 in r[0], subtract 1 in r[m] */
if (cc-- == 0) /* then add 1 to r[0] */
cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
- cc = mpn_sub_1 (r, r, d, cc) + 1;
+ cc = mpn_sub_1 (r, r, m, cc) + 1;
/* add 1 to cc instead of rd since rd might overflow */
}
- /* now subtract cc and rd from r[d..n] */
+ /* now subtract cc and rd from r[m..n] */
- r[n] = -mpn_sub_1 (r + d, r + d, n - d, cc);
- r[n] -= mpn_sub_1 (r + d, r + d, n - d, rd);
+ r[n] = -mpn_sub_1 (r + m, r + m, n - m, cc);
+ r[n] -= mpn_sub_1 (r + m, r + m, n - m, rd);
if (r[n] & GMP_LIMB_HIGHBIT)
r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
}
@@ -283,7 +284,7 @@
Assumes a and b are semi-normalized.
*/
static inline void
-mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
+mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
mp_limb_t c, x;
@@ -314,7 +315,7 @@
Assumes a and b are semi-normalized.
*/
static inline void
-mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
+mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
mp_limb_t c, x;
@@ -366,14 +367,14 @@
}
else
{
- int j;
+ mp_size_t j, K2 = K >> 1;
int *lk = *ll;
- mpn_fft_fft (Ap, K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
- mpn_fft_fft (Ap+inc, K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
+ mpn_fft_fft (Ap, K2, ll-1, 2 * omega, n, inc * 2, tp);
+ mpn_fft_fft (Ap+inc, K2, ll-1, 2 * omega, n, inc * 2, tp);
/* A[2*j*inc] <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
- for (j = 0; j < (K >> 1); j++, lk += 2, Ap += 2 * inc)
+ for (j = 0; j < K2; j++, lk += 2, Ap += 2 * inc)
{
/* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
Ap[0] <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
@@ -418,7 +419,7 @@
/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
static void
-mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K)
+mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, mp_size_t K)
{
int i;
int sqr = (ap == bp);
@@ -428,12 +429,13 @@
if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
{
- int k, K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
+ mp_size_t K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
+ int k;
int **fft_l, *tmp;
mp_ptr *Ap, *Bp, A, B, T;
k = mpn_fft_best_k (n, sqr);
- K2 = 1 << k;
+ K2 = (mp_size_t) 1 << k;
ASSERT_ALWAYS((n & (K2 - 1)) == 0);
maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
M2 = n * GMP_NUMB_BITS >> k;
@@ -445,10 +447,10 @@
/* we should ensure that nprime2 is a multiple of the next K */
if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
{
- unsigned long K3;
+ mp_size_t K3;
for (;;)
{
- K3 = 1L << mpn_fft_best_k (nprime2, sqr);
+ K3 = (mp_size_t) 1 << mpn_fft_best_k (nprime2, sqr);
if ((nprime2 & (K3 - 1)) == 0)
break;
nprime2 = (nprime2 + K3 - 1) & -K3;
@@ -466,16 +468,16 @@
T = TMP_BALLOC_LIMBS (2 * (nprime2 + 1));
B = A + ((nprime2 + 1) << k);
fft_l = TMP_BALLOC_TYPE (k + 1, int *);
- tmp = TMP_BALLOC_TYPE (2 << k, int);
+ tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
for (i = 0; i <= k; i++)
{
fft_l[i] = tmp;
- tmp += 1 << i;
+ tmp += (mp_size_t) 1 << i;
}
mpn_fft_initl (fft_l, k);
- TRACE (printf ("recurse: %ldx%ld limbs -> %d times %dx%d (%1.2f)\n", n,
+ TRACE (printf ("recurse: %ldx%ld limbs -> %ld times %ldx%ld (%1.2f)\n", n,
n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
for (i = 0; i < K; i++, ap++, bp++)
{
@@ -497,10 +499,10 @@
{
mp_ptr a, b, tp, tpn;
mp_limb_t cc;
- int n2 = 2 * n;
+ mp_size_t n2 = 2 * n;
tp = TMP_BALLOC_LIMBS (n2);
tpn = tp + n;
- TRACE (printf (" mpn_mul_n %d of %ld limbs\n", K, n));
+ TRACE (printf (" mpn_mul_n %ld of %ld limbs\n", K, n));
for (i = 0; i < K; i++)
{
a = *ap++;
@@ -534,7 +536,7 @@
This condition is also fulfilled at exit.
*/
static void
-mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
+mpn_fft_fftinv (mp_ptr *Ap, mp_size_t K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
if (K == 2)
{
@@ -553,7 +555,7 @@
}
else
{
- int j, K2 = K >> 1;
+ mp_size_t j, K2 = K >> 1;
mpn_fft_fftinv (Ap, K2, 2 * omega, n, tp);
mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
@@ -573,12 +575,12 @@
/* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
static void
-mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, int k, mp_size_t n)
+mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t k, mp_size_t n)
{
- int i;
+ mp_bitcnt_t i;
ASSERT (r != a);
- i = 2 * n * GMP_NUMB_BITS - k;
+ i = (mp_bitcnt_t) 2 * n * GMP_NUMB_BITS - k;
mpn_fft_mul_2exp_modF (r, a, i, n);
/* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
/* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
@@ -590,13 +592,11 @@
Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
then {rp,n}=0.
*/
-static int
+static mp_size_t
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
- mp_size_t l;
- long int m;
+ mp_size_t l, m, rpn;
mp_limb_t cc;
- int rpn;
ASSERT ((n <= an) && (an <= 3 * n));
m = an - 2 * n;
@@ -630,10 +630,11 @@
We must have nl <= 2*K*l.
*/
static void
-mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, int K, int nprime, mp_srcptr n,
- mp_size_t nl, int l, int Mp, mp_ptr T)
+mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, mp_size_t K, mp_size_t nprime,
+ mp_srcptr n, mp_size_t nl, mp_size_t l, mp_size_t Mp,
+ mp_ptr T)
{
- int i, j;
+ mp_size_t i, j;
mp_ptr tmp;
mp_size_t Kl = K * l;
TMP_DECL;
@@ -717,11 +718,11 @@
mp_size_t nprime, mp_size_t l, mp_size_t Mp,
int **fft_l, mp_ptr T, int sqr)
{
- int K, i, pla, lo, sh, j;
+ mp_size_t K, i, pla, lo, sh, j;
mp_ptr p;
mp_limb_t cc;
- K = 1 << k;
+ K = (mp_size_t) 1 << k;
/* direct fft's */
mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
@@ -797,10 +798,10 @@
}
/* return the lcm of a and 2^k */
-static unsigned long int
-mpn_mul_fft_lcm (unsigned long int a, unsigned int k)
+static mp_bitcnt_t
+mpn_mul_fft_lcm (mp_bitcnt_t a, int k)
{
- unsigned long int l = k;
+ mp_bitcnt_t l = k;
while (a % 2 == 0 && k > 0)
{
@@ -817,7 +818,8 @@
mp_srcptr m, mp_size_t ml,
int k)
{
- int K, maxLK, i;
+ int i;
+ mp_size_t K, maxLK;
mp_size_t N, Nprime, nprime, M, Mp, l;
mp_ptr *Ap, *Bp, A, T, B;
int **fft_l, *tmp;
@@ -831,45 +833,45 @@
TMP_MARK;
N = pl * GMP_NUMB_BITS;
fft_l = TMP_BALLOC_TYPE (k + 1, int *);
- tmp = TMP_BALLOC_TYPE (2 << k, int);
+ tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
for (i = 0; i <= k; i++)
{
fft_l[i] = tmp;
- tmp += 1 << i;
+ tmp += (mp_size_t) 1 << i;
}
mpn_fft_initl (fft_l, k);
- K = 1 << k;
+ K = (mp_size_t) 1 << k;
M = N >> k; /* N = 2^k M */
l = 1 + (M - 1) / GMP_NUMB_BITS;
- maxLK = mpn_mul_fft_lcm ((unsigned long) GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */
+ maxLK = mpn_mul_fft_lcm (GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */
Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
/* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
nprime = Nprime / GMP_NUMB_BITS;
- TRACE (printf ("N=%ld K=%d, M=%ld, l=%ld, maxLK=%d, Np=%ld, np=%ld\n",
+ TRACE (printf ("N=%ld K=%ld, M=%ld, l=%ld, maxLK=%ld, Np=%ld, np=%ld\n",
N, K, M, l, maxLK, Nprime, nprime));
/* we should ensure that recursively, nprime is a multiple of the next K */
if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
{
- unsigned long K2;
+ mp_size_t K2;
for (;;)
{
- K2 = 1L << mpn_fft_best_k (nprime, sqr);
+ K2 = (mp_size_t) 1 << mpn_fft_best_k (nprime, sqr);
if ((nprime & (K2 - 1)) == 0)
break;
nprime = (nprime + K2 - 1) & -K2;
Nprime = nprime * GMP_LIMB_BITS;
/* warning: since nprime changed, K2 may change too! */
}
- TRACE (printf ("new maxLK=%d, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
+ TRACE (printf ("new maxLK=%ld, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
}
ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */
T = TMP_BALLOC_LIMBS (2 * (nprime + 1));
Mp = Nprime >> k;
- TRACE (printf ("%ldx%ld limbs -> %d times %ldx%ld limbs (%1.2f)\n",
+ TRACE (printf ("%ldx%ld limbs -> %ld times %ldx%ld limbs (%1.2f)\n",
pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
printf (" temp space %ld\n", 2 * K * (nprime + 1)));
@@ -904,9 +906,9 @@
{
mp_ptr pad_op;
mp_size_t pl, pl2, pl3, l;
+ mp_size_t cc, c2, oldcc;
int k2, k3;
int sqr = (n == m && nl == ml);
- int cc, c2, oldcc;
pl = nl + ml; /* total number of limbs of the result */