mpn_mul_fft type overflow issue

Mark Sofroniou marks at wolfram.com
Wed Sep 25 09:39:16 CEST 2013


Hi Torbjorn.

I appreciate your comments and the patch you sent. I'll be sure to
use your head branch as a starting point in the future - sorry about
that. I'm still learning about your development workflow, so thanks
for your patience.

Attached is a new patch that has been retested.

Here is a description of the changes relative to your patch, which
will hopefully save you a bit of time.

All capital K values (K, K2, K3, etc.) are computed by a shift with a
small k, so they are now consistently of type mp_size_t. I changed a
few function APIs to reflect that (e.g. mpn_mul_fft_decompose,
mpn_fft_fft, mpn_fft_mul_modF_K, mpn_fft_fftinv).
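
For context, the failure mode this guards against looks roughly like
the following standalone sketch (my_mp_size_t is just a stand-in for
GMP's mp_size_t on a typical 64-bit build, not the real typedef):

    #include <stdio.h>
    #include <stdint.h>

    /* Stand-in for GMP's mp_size_t; the real typedef comes from gmp.h. */
    typedef int64_t my_mp_size_t;

    int main (void)
    {
      int k = 31;

      /* "1 << k" is evaluated in int, so at k = 31 it overflows
         (undefined behaviour).  Casting the constant first performs
         the shift in the wider type, so K stays a valid size/index.  */
      my_mp_size_t K = (my_mp_size_t) 1 << k;

      printf ("k = %d, K = %lld\n", k, (long long) K);
      return 0;
    }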

Minor change to mpn_fft_fft to make it consistent with mpn_fft_fftinv
(hoisting K >> 1 into a local K2). A good compiler will probably make
this optimization anyway, but it doesn't hurt.

Added a couple of size_t casts to prevent potential overflow in 
TMP_BALLOC_TYPE:

       tmp = TMP_BALLOC_TYPE (2 << k, int);

is now:

       tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);

I didn't have enough memory to trigger this, but for a large enough
allocation I don't see why k couldn't eventually reach 31, and the
cast prevents the int overflow in that case.
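
To illustrate (a hypothetical allocator taking an element count, not
the real TMP_BALLOC_TYPE; assumes a 64-bit size_t):

    #include <stdio.h>
    #include <stddef.h>

    /* Hypothetical stand-in for an allocation helper that, like
       TMP_BALLOC_TYPE, works from an element count and element size.  */
    static size_t bytes_needed (size_t nelems, size_t elem_size)
    {
      return nelems * elem_size;
    }

    int main (void)
    {
      int k = 31;

      /* A plain "2 << k" overflows int at k = 31 before it ever
         reaches the size_t parameter; doing the shift in size_t
         keeps the element count exact.  */
      size_t count = (size_t) 2 << k;

      printf ("elements: %zu, bytes: %zu\n",
              count, bytes_needed (count, sizeof (int)));
      return 0;
    }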

I changed the variable sh in mpn_fft_mul_2exp_modF from int to
unsigned int to match the shift type of mpn_lshift, in which it is
used. Maybe this should be mp_bitcnt_t? That depends on your plans
for the mpn_lshift API - I don't know whether it will eventually also
take an mp_bitcnt_t for the shift. It's really not important - the
value is tiny, as you pointed out already.
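
For reference, the patched routine now splits the bit count like this
(standalone sketch with stand-in types; the real mp_bitcnt_t,
mp_size_t and GMP_NUMB_BITS come from gmp.h):

    #include <stdio.h>

    typedef unsigned long my_mp_bitcnt_t;   /* stand-in for mp_bitcnt_t */
    typedef long          my_mp_size_t;     /* stand-in for mp_size_t   */
    #define MY_GMP_NUMB_BITS 64

    int main (void)
    {
      my_mp_bitcnt_t d = 1000;              /* total shift in bits */

      /* As in mpn_fft_mul_2exp_modF after the patch: the in-limb part
         is always < GMP_NUMB_BITS, so unsigned int is wide enough for
         sh, while the limb offset m gets the wider mp_size_t.  */
      unsigned int sh = d % MY_GMP_NUMB_BITS;
      my_mp_size_t m  = d / MY_GMP_NUMB_BITS;

      printf ("d = %lu bits -> m = %ld limbs + sh = %u bits\n",
              (unsigned long) d, (long) m, sh);
      return 0;
    }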

In mpn_mul_fft_full the variables cc, c2 and oldcc were of type int;
I changed them to mp_size_t.

The only remaining int types are related to squaring or to parameters
for the fft table.

Regards, Mark

-------------- next part --------------
diff -r fff25440e878 mpn/generic/mul_fft.c
--- a/mpn/generic/mul_fft.c	Tue Sep 24 16:15:39 2013 +0200
+++ b/mpn/generic/mul_fft.c	Wed Sep 25 02:26:50 2013 -0500
@@ -67,8 +67,8 @@
 static mp_limb_t mpn_mul_fft_internal (mp_ptr, mp_size_t, int, mp_ptr *,
 				       mp_ptr *, mp_ptr, mp_ptr, mp_size_t,
 				       mp_size_t, mp_size_t, int **, mp_ptr, int);
-static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, int, int, mp_srcptr,
-				   mp_size_t, int, int, mp_ptr);
+static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, mp_size_t, mp_size_t, mp_srcptr,
+				   mp_size_t, mp_size_t, mp_size_t, mp_ptr);
 
 
 /* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
@@ -192,38 +192,39 @@
    r and a must have n+1 limbs, and not overlap.
 */
 static void
-mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, unsigned int d, mp_size_t n)
+mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t d, mp_size_t n)
 {
-  int sh;
+  unsigned int sh;
+  mp_size_t m;
   mp_limb_t cc, rd;
 
   sh = d % GMP_NUMB_BITS;
-  d /= GMP_NUMB_BITS;
+  m = d / GMP_NUMB_BITS;
 
-  if (d >= n)			/* negate */
+  if (m >= n)			/* negate */
     {
-      /* r[0..d-1]  <-- lshift(a[n-d]..a[n-1], sh)
-	 r[d..n-1]  <-- -lshift(a[0]..a[n-d-1],  sh) */
+      /* r[0..m-1]  <-- lshift(a[n-m]..a[n-1], sh)
+	 r[m..n-1]  <-- -lshift(a[0]..a[n-m-1],  sh) */
 
-      d -= n;
+      m -= n;
       if (sh != 0)
 	{
 	  /* no out shift below since a[n] <= 1 */
-	  mpn_lshift (r, a + n - d, d + 1, sh);
-	  rd = r[d];
-	  cc = mpn_lshiftc (r + d, a, n - d, sh);
+	  mpn_lshift (r, a + n - m, m + 1, sh);
+	  rd = r[m];
+	  cc = mpn_lshiftc (r + m, a, n - m, sh);
 	}
       else
 	{
-	  MPN_COPY (r, a + n - d, d);
+	  MPN_COPY (r, a + n - m, m);
 	  rd = a[n];
-	  mpn_com (r + d, a, n - d);
+	  mpn_com (r + m, a, n - m);
 	  cc = 0;
 	}
 
-      /* add cc to r[0], and add rd to r[d] */
+      /* add cc to r[0], and add rd to r[m] */
 
-      /* now add 1 in r[d], subtract 1 in r[n], i.e. add 1 in r[0] */
+      /* now add 1 in r[m], subtract 1 in r[n], i.e. add 1 in r[0] */
 
       r[n] = 0;
       /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
@@ -233,46 +234,46 @@
       rd++;
       /* rd might overflow when sh=GMP_NUMB_BITS-1 */
       cc = (rd == 0) ? 1 : rd;
-      r = r + d + (rd == 0);
+      r = r + m + (rd == 0);
       mpn_incr_u (r, cc);
     }
   else
     {
-      /* r[0..d-1]  <-- -lshift(a[n-d]..a[n-1], sh)
-	 r[d..n-1]  <-- lshift(a[0]..a[n-d-1],  sh)  */
+      /* r[0..m-1]  <-- -lshift(a[n-m]..a[n-1], sh)
+	 r[m..n-1]  <-- lshift(a[0]..a[n-m-1],  sh)  */
       if (sh != 0)
 	{
 	  /* no out bits below since a[n] <= 1 */
-	  mpn_lshiftc (r, a + n - d, d + 1, sh);
-	  rd = ~r[d];
-	  /* {r, d+1} = {a+n-d, d+1} << sh */
-	  cc = mpn_lshift (r + d, a, n - d, sh); /* {r+d, n-d} = {a, n-d}<<sh */
+	  mpn_lshiftc (r, a + n - m, m + 1, sh);
+	  rd = ~r[m];
+	  /* {r, m+1} = {a+n-m, m+1} << sh */
+	  cc = mpn_lshift (r + m, a, n - m, sh); /* {r+m, n-m} = {a, n-m}<<sh */
 	}
       else
 	{
-	  /* r[d] is not used below, but we save a test for d=0 */
-	  mpn_com (r, a + n - d, d + 1);
+	  /* r[m] is not used below, but we save a test for m=0 */
+	  mpn_com (r, a + n - m, m + 1);
 	  rd = a[n];
-	  MPN_COPY (r + d, a, n - d);
+	  MPN_COPY (r + m, a, n - m);
 	  cc = 0;
 	}
 
-      /* now complement {r, d}, subtract cc from r[0], subtract rd from r[d] */
+      /* now complement {r, m}, subtract cc from r[0], subtract rd from r[m] */
 
-      /* if d=0 we just have r[0]=a[n] << sh */
-      if (d != 0)
+      /* if m=0 we just have r[0]=a[n] << sh */
+      if (m != 0)
 	{
-	  /* now add 1 in r[0], subtract 1 in r[d] */
+	  /* now add 1 in r[0], subtract 1 in r[m] */
 	  if (cc-- == 0) /* then add 1 to r[0] */
 	    cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
-	  cc = mpn_sub_1 (r, r, d, cc) + 1;
+	  cc = mpn_sub_1 (r, r, m, cc) + 1;
 	  /* add 1 to cc instead of rd since rd might overflow */
 	}
 
-      /* now subtract cc and rd from r[d..n] */
+      /* now subtract cc and rd from r[m..n] */
 
-      r[n] = -mpn_sub_1 (r + d, r + d, n - d, cc);
-      r[n] -= mpn_sub_1 (r + d, r + d, n - d, rd);
+      r[n] = -mpn_sub_1 (r + m, r + m, n - m, cc);
+      r[n] -= mpn_sub_1 (r + m, r + m, n - m, rd);
       if (r[n] & GMP_LIMB_HIGHBIT)
 	r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
     }
@@ -283,7 +284,7 @@
    Assumes a and b are semi-normalized.
 */
 static inline void
-mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
+mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
 {
   mp_limb_t c, x;
 
@@ -314,7 +315,7 @@
    Assumes a and b are semi-normalized.
 */
 static inline void
-mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, int n)
+mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
 {
   mp_limb_t c, x;
 
@@ -366,14 +367,14 @@
     }
   else
     {
-      int j;
+      mp_size_t j, K2 = K >> 1;
       int *lk = *ll;
 
-      mpn_fft_fft (Ap,     K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
-      mpn_fft_fft (Ap+inc, K >> 1, ll-1, 2 * omega, n, inc * 2, tp);
+      mpn_fft_fft (Ap,     K2, ll-1, 2 * omega, n, inc * 2, tp);
+      mpn_fft_fft (Ap+inc, K2, ll-1, 2 * omega, n, inc * 2, tp);
       /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
 	 A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
-      for (j = 0; j < (K >> 1); j++, lk += 2, Ap += 2 * inc)
+      for (j = 0; j < K2; j++, lk += 2, Ap += 2 * inc)
 	{
 	  /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
 	     Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
@@ -418,7 +419,7 @@
 
 /* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
 static void
-mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K)
+mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, mp_size_t K)
 {
   int i;
   int sqr = (ap == bp);
@@ -428,12 +429,13 @@
 
   if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
     {
-      int k, K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
+      mp_size_t K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
+      int k;
       int **fft_l, *tmp;
       mp_ptr *Ap, *Bp, A, B, T;
 
       k = mpn_fft_best_k (n, sqr);
-      K2 = 1 << k;
+      K2 = (mp_size_t) 1 << k;
       ASSERT_ALWAYS((n & (K2 - 1)) == 0);
       maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
       M2 = n * GMP_NUMB_BITS >> k;
@@ -445,10 +447,10 @@
       /* we should ensure that nprime2 is a multiple of the next K */
       if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
 	{
-	  unsigned long K3;
+	  mp_size_t K3;
 	  for (;;)
 	    {
-	      K3 = 1L << mpn_fft_best_k (nprime2, sqr);
+	      K3 = (mp_size_t) 1 << mpn_fft_best_k (nprime2, sqr);
 	      if ((nprime2 & (K3 - 1)) == 0)
 		break;
 	      nprime2 = (nprime2 + K3 - 1) & -K3;
@@ -466,16 +468,16 @@
       T = TMP_BALLOC_LIMBS (2 * (nprime2 + 1));
       B = A + ((nprime2 + 1) << k);
       fft_l = TMP_BALLOC_TYPE (k + 1, int *);
-      tmp = TMP_BALLOC_TYPE (2 << k, int);
+      tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
       for (i = 0; i <= k; i++)
 	{
 	  fft_l[i] = tmp;
-	  tmp += 1 << i;
+	  tmp += (mp_size_t) 1 << i;
 	}
 
       mpn_fft_initl (fft_l, k);
 
-      TRACE (printf ("recurse: %ldx%ld limbs -> %d times %dx%d (%1.2f)\n", n,
+      TRACE (printf ("recurse: %ldx%ld limbs -> %ld times %ldx%ld (%1.2f)\n", n,
 		    n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
       for (i = 0; i < K; i++, ap++, bp++)
 	{
@@ -497,10 +499,10 @@
     {
       mp_ptr a, b, tp, tpn;
       mp_limb_t cc;
-      int n2 = 2 * n;
+      mp_size_t n2 = 2 * n;
       tp = TMP_BALLOC_LIMBS (n2);
       tpn = tp + n;
-      TRACE (printf ("  mpn_mul_n %d of %ld limbs\n", K, n));
+      TRACE (printf ("  mpn_mul_n %ld of %ld limbs\n", K, n));
       for (i = 0; i < K; i++)
 	{
 	  a = *ap++;
@@ -534,7 +536,7 @@
    This condition is also fulfilled at exit.
 */
 static void
-mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
+mpn_fft_fftinv (mp_ptr *Ap, mp_size_t K, mp_size_t omega, mp_size_t n, mp_ptr tp)
 {
   if (K == 2)
     {
@@ -553,7 +555,7 @@
     }
   else
     {
-      int j, K2 = K >> 1;
+      mp_size_t j, K2 = K >> 1;
 
       mpn_fft_fftinv (Ap,      K2, 2 * omega, n, tp);
       mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
@@ -573,12 +575,12 @@
 
 /* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
 static void
-mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, int k, mp_size_t n)
+mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t k, mp_size_t n)
 {
-  int i;
+  mp_bitcnt_t i;
 
   ASSERT (r != a);
-  i = 2 * n * GMP_NUMB_BITS - k;
+  i = (mp_bitcnt_t) 2 * n * GMP_NUMB_BITS - k;
   mpn_fft_mul_2exp_modF (r, a, i, n);
   /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
   /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
@@ -590,13 +592,11 @@
    Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
    then {rp,n}=0.
 */
-static int
+static mp_size_t
 mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
 {
-  mp_size_t l;
-  long int m;
+  mp_size_t l, m, rpn;
   mp_limb_t cc;
-  int rpn;
 
   ASSERT ((n <= an) && (an <= 3 * n));
   m = an - 2 * n;
@@ -630,10 +630,11 @@
    We must have nl <= 2*K*l.
 */
 static void
-mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, int K, int nprime, mp_srcptr n,
-		       mp_size_t nl, int l, int Mp, mp_ptr T)
+mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, mp_size_t K, mp_size_t nprime,
+		       mp_srcptr n, mp_size_t nl, mp_size_t l, mp_size_t Mp,
+		       mp_ptr T)
 {
-  int i, j;
+  mp_size_t i, j;
   mp_ptr tmp;
   mp_size_t Kl = K * l;
   TMP_DECL;
@@ -717,11 +718,11 @@
 		      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
 		      int **fft_l, mp_ptr T, int sqr)
 {
-  int K, i, pla, lo, sh, j;
+  mp_size_t K, i, pla, lo, sh, j;
   mp_ptr p;
   mp_limb_t cc;
 
-  K = 1 << k;
+  K = (mp_size_t) 1 << k;
 
   /* direct fft's */
   mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
@@ -797,10 +798,10 @@
 }
 
 /* return the lcm of a and 2^k */
-static unsigned long int
-mpn_mul_fft_lcm (unsigned long int a, unsigned int k)
+static mp_bitcnt_t
+mpn_mul_fft_lcm (mp_bitcnt_t a, int k)
 {
-  unsigned long int l = k;
+  mp_bitcnt_t l = k;
 
   while (a % 2 == 0 && k > 0)
     {
@@ -817,7 +818,8 @@
 	     mp_srcptr m, mp_size_t ml,
 	     int k)
 {
-  int K, maxLK, i;
+  int i;
+  mp_size_t K, maxLK;
   mp_size_t N, Nprime, nprime, M, Mp, l;
   mp_ptr *Ap, *Bp, A, T, B;
   int **fft_l, *tmp;
@@ -831,45 +833,45 @@
   TMP_MARK;
   N = pl * GMP_NUMB_BITS;
   fft_l = TMP_BALLOC_TYPE (k + 1, int *);
-  tmp = TMP_BALLOC_TYPE (2 << k, int);
+  tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
   for (i = 0; i <= k; i++)
     {
       fft_l[i] = tmp;
-      tmp += 1 << i;
+      tmp += (mp_size_t) 1 << i;
     }
 
   mpn_fft_initl (fft_l, k);
-  K = 1 << k;
+  K = (mp_size_t) 1 << k;
   M = N >> k;	/* N = 2^k M */
   l = 1 + (M - 1) / GMP_NUMB_BITS;
-  maxLK = mpn_mul_fft_lcm ((unsigned long) GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */
+  maxLK = mpn_mul_fft_lcm (GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */
 
   Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
   /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
   nprime = Nprime / GMP_NUMB_BITS;
-  TRACE (printf ("N=%ld K=%d, M=%ld, l=%ld, maxLK=%d, Np=%ld, np=%ld\n",
+  TRACE (printf ("N=%ld K=%ld, M=%ld, l=%ld, maxLK=%ld, Np=%ld, np=%ld\n",
 		 N, K, M, l, maxLK, Nprime, nprime));
   /* we should ensure that recursively, nprime is a multiple of the next K */
   if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
     {
-      unsigned long K2;
+      mp_size_t K2;
       for (;;)
 	{
-	  K2 = 1L << mpn_fft_best_k (nprime, sqr);
+	  K2 = (mp_size_t) 1 << mpn_fft_best_k (nprime, sqr);
 	  if ((nprime & (K2 - 1)) == 0)
 	    break;
 	  nprime = (nprime + K2 - 1) & -K2;
 	  Nprime = nprime * GMP_LIMB_BITS;
 	  /* warning: since nprime changed, K2 may change too! */
 	}
-      TRACE (printf ("new maxLK=%d, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
+      TRACE (printf ("new maxLK=%ld, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
     }
   ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */
 
   T = TMP_BALLOC_LIMBS (2 * (nprime + 1));
   Mp = Nprime >> k;
 
-  TRACE (printf ("%ldx%ld limbs -> %d times %ldx%ld limbs (%1.2f)\n",
+  TRACE (printf ("%ldx%ld limbs -> %ld times %ldx%ld limbs (%1.2f)\n",
 		pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
 	 printf ("   temp space %ld\n", 2 * K * (nprime + 1)));
 
@@ -904,9 +906,9 @@
 {
   mp_ptr pad_op;
   mp_size_t pl, pl2, pl3, l;
+  mp_size_t cc, c2, oldcc;
   int k2, k3;
   int sqr = (n == m && nl == ml);
-  int cc, c2, oldcc;
 
   pl = nl + ml; /* total number of limbs of the result */
 

