mpz_mul core dump - attempt 2

Torbjorn Granlund tege at swox.com
Sat Sep 18 02:05:53 CEST 2004


Please try this patch for the FFT crash.  (Thanks to Paul
Zimmermann for debugging this and providing the patch.)

-- 
Torbjörn


--- mul_fft.c.orig	2004-09-16 16:10:43.000000000 +0200
+++ mul_fft.c	2004-09-17 17:05:02.339802415 +0200
@@ -72,8 +72,10 @@
 	 mp_ptr, mp_ptr, mp_size_t, mp_size_t, mp_size_t, int **, mp_ptr,int));
 
 
-/* Find the best k to use for a mod 2^(n*BITS_PER_MP_LIMB)+1 FFT.
-   sqr==0 if for a multiply, sqr==1 for a square */
+/* Find the best k to use for a mod 2^(m*BITS_PER_MP_LIMB)+1 FFT
+   with m >= n.
+   sqr==0 if for a multiply, sqr==1 for a square.
+*/
 int
 mpn_fft_best_k (mp_size_t n, int sqr)
 {
@@ -91,26 +93,17 @@
 }
 
 
-/* Returns smallest possible number of limbs >= pl for a fft of size 2^k.
-
-   FIXME: Is this N rounded up to the next multiple of (2^k)*BITS_PER_MP_LIMB
-   bits and therefore simply pl rounded up to a multiple of 2^k? */
+/* Returns smallest possible number of limbs >= pl for a fft of size 2^k,
+   i.e. smallest multiple of 2^k >= pl. */
 
 mp_size_t
 mpn_fft_next_size (mp_size_t pl, int k)
 {
-  mp_size_t N, M;
-  int K;
+  unsigned long K;
 
-  /*  if (k==0) k = mpn_fft_best_k (pl, sqr); */
-  N = pl * BITS_PER_MP_LIMB;
   K = 1 << k;
-  if (N % K)
-    N = (N / K + 1) * K;
-  M = N / K;
-  if (M % BITS_PER_MP_LIMB)
-    N = ((M / BITS_PER_MP_LIMB) + 1) * BITS_PER_MP_LIMB * K;
-  return N / BITS_PER_MP_LIMB;
+  pl = 1 + (pl - 1) / K; /* ceil(pl/K) */
+  return pl * K;
 }
 
 
@@ -205,15 +198,18 @@
 {
   if (K == 2)
     {
+      mp_limb_t cy;
 #if HAVE_NATIVE_mpn_addsub_n
-      if (mpn_addsub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1)
-	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, CNST_LIMB(1));
+      cy = mpn_addsub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
 #else
       MPN_COPY (tp, Ap[0], n + 1);
-      mpn_add_n (Ap[0], Ap[0], Ap[inc],n + 1);
-      if (mpn_sub_n (Ap[inc], tp, Ap[inc],n + 1))
-	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, CNST_LIMB(1));
+      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
+      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
 #endif
+      if (Ap[0][n] > CNST_LIMB(1)) /* can be 2 or 3 */
+        Ap[0][n] = CNST_LIMB(1) - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - CNST_LIMB(1));
+      if (cy) /* Ap[inc][n] can be -1 or -2 */
+        Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + CNST_LIMB(1));
     }
   else
     {
@@ -252,24 +248,26 @@
 {
   if (K == 2)
     {
+      mp_limb_t ca, cb;
 #if HAVE_NATIVE_mpn_addsub_n
-      if (mpn_addsub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1)
-	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, CNST_LIMB(1));
+      ca = mpn_addsub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
+      cb = mpn_addsub_n (Bp[0], Bp[inc], Bp[0], Bp[inc], n + 1) & 1;
 #else
       MPN_COPY (tp, Ap[0], n + 1);
-      mpn_add_n (Ap[0], Ap[0], Ap[inc],n + 1);
-      if (mpn_sub_n (Ap[inc], tp, Ap[inc],n + 1))
-	Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, CNST_LIMB(1));
-#endif
-#if HAVE_NATIVE_mpn_addsub_n
-      if (mpn_addsub_n (Bp[0], Bp[inc], Bp[0], Bp[inc], n + 1) & 1)
-	Bp[inc][n] = mpn_add_1 (Bp[inc], Bp[inc], n, CNST_LIMB(1));
-#else
+      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
+      ca = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
       MPN_COPY (tp, Bp[0], n + 1);
-      mpn_add_n (Bp[0], Bp[0], Bp[inc],n + 1);
-      if (mpn_sub_n (Bp[inc], tp, Bp[inc],n + 1))
-	Bp[inc][n] = mpn_add_1 (Bp[inc], Bp[inc], n, CNST_LIMB(1));
+      mpn_add_n (Bp[0], Bp[0], Bp[inc], n + 1);
+      cb = mpn_sub_n (Bp[inc], tp, Bp[inc], n + 1);
 #endif
+      if (Ap[0][n] > CNST_LIMB(1)) /* can be 2 or 3 */
+        Ap[0][n] = CNST_LIMB(1) - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - CNST_LIMB(1));
+      if (ca) /* Ap[inc][n] can be -1 or -2 */
+        Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + CNST_LIMB(1));
+      if (Bp[0][n] > CNST_LIMB(1)) /* can be 2 or 3 */
+        Bp[0][n] = CNST_LIMB(1) - mpn_sub_1 (Bp[0], Bp[0], n, Bp[0][n] - CNST_LIMB(1));
+      if (cb) /* Bp[inc][n] can be -1 or -2 */
+        Bp[inc][n] = mpn_add_1 (Bp[inc], Bp[inc], n, ~Bp[inc][n] + CNST_LIMB(1));
     }
   else
     {
@@ -339,13 +337,29 @@
       mp_ptr *Ap,*Bp,A,B,T;
 
       k = mpn_fft_best_k (n, sqr);
-      K2 = 1<<k;
+      K2 = 1 << k;
+      ASSERT_ALWAYS(n % K2 == 0);
       maxLK = (K2>BITS_PER_MP_LIMB) ? K2 : BITS_PER_MP_LIMB;
       M2 = n*BITS_PER_MP_LIMB/K2;
-      l = n/K2;
-      Nprime2 = ((2 * M2+k+2+maxLK)/maxLK)*maxLK; /* ceil()(2*M2+k+3)/maxLK)*maxLK*/
-      nprime2 = Nprime2/BITS_PER_MP_LIMB;
-      Mp2 = Nprime2/K2;
+      l = n / K2;
+      Nprime2 = ((2 * M2+k+2+maxLK)/maxLK)*maxLK; /* ceil((2*M2+k+3)/maxLK)*maxLK*/
+      nprime2 = Nprime2 / BITS_PER_MP_LIMB;
+
+      /* we should ensure that nprime2 is a multiple of the next K */
+      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
+        {
+          unsigned long K3;
+          while (nprime2 % (K3 = 1 << mpn_fft_best_k (nprime2, sqr)))
+            {
+              nprime2 = ((nprime2 + K3 - 1) / K3) * K3;
+              Nprime2 = nprime2 * BITS_PER_MP_LIMB;
+              /* warning: since nprime2 changed, K3 may change too! */
+            }
+          ASSERT(nprime2 % K3 == 0);
+        }
+      ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */
+
+      Mp2 = Nprime2 / K2;
 
       Ap = TMP_ALLOC_MP_PTRS (K2);
       Bp = TMP_ALLOC_MP_PTRS (K2);
@@ -402,22 +416,28 @@
 
 
 /* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
-   output: K*A[0] K*A[K-1] ... K*A[1] */
+   output: K*A[0] K*A[K-1] ... K*A[1].
+   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
+   This condition is also fulfilled at exit.
+*/
 
 static void
 mpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp)
 {
   if (K == 2)
     {
+      mp_limb_t cy;
 #if HAVE_NATIVE_mpn_addsub_n
-      if (mpn_addsub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1)
-	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, CNST_LIMB(1));
+      cy = mpn_addsub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
 #else
       MPN_COPY (tp, Ap[0], n + 1);
       mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
-      if (mpn_sub_n (Ap[1], tp, Ap[1], n + 1))
-	Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, CNST_LIMB(1));
+      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
 #endif
+      if (Ap[0][n] > CNST_LIMB(1)) /* can be 2 or 3 */
+        Ap[0][n] = CNST_LIMB(1) - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - CNST_LIMB(1));
+      if (cy) /* Ap[1][n] can be -1 or -2 */
+        Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + CNST_LIMB(1));
     }
   else
     {
@@ -459,12 +479,13 @@
 }
 
 
-/* R <- A mod 2^(n*BITS_PER_MP_LIMB)+1, n<=an<=3*n */
+/* R <- A mod 2^(n*BITS_PER_MP_LIMB)+1, n <= an <= 3*n */
 static void
 mpn_fft_norm_modF (mp_ptr rp, mp_ptr ap, mp_size_t n, mp_size_t an)
 {
   mp_size_t l;
 
+  ASSERT (n <= an && an <= 3 * n);
   if (an > 2 * n)
     {
       l = n;
@@ -595,7 +616,7 @@
 /* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*BITS_PER_MP_LIMB
    n and m have respectively nl and ml limbs
    op must have space for pl+1 limbs
-   One must have pl = mpn_fft_next_size (pl, k).
+   Assumes pl is multiple of 2^k.
 */
 
 void
@@ -612,7 +633,6 @@
   TMP_DECL(marker);
 
   TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
-  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);
 
   TMP_MARK(marker);
   N = pl * BITS_PER_MP_LIMB;
@@ -620,25 +640,31 @@
   for (i = 0; i <= k; i++)
     _fft_l[i] = TMP_ALLOC_TYPE (1<<i, int);
   mpn_fft_initl (_fft_l, k);
-  K = 1<<k;
-  M = N/K;	/* N = 2^k M */
-  l = M/BITS_PER_MP_LIMB;
+  K = 1 << k;
+  ASSERT_ALWAYS (pl % K == 0);
+  M = N/K;	/* exact: N = 2^k M */
+  l = M / BITS_PER_MP_LIMB; /* l = pl / K also */
   maxLK = (K>BITS_PER_MP_LIMB) ? K : BITS_PER_MP_LIMB;
 
   Nprime = ((2 * M + k + 2 + maxLK) / maxLK) * maxLK; /* ceil((2*M+k+3)/maxLK)*maxLK; */
   nprime = Nprime / BITS_PER_MP_LIMB;
+  /* with B := BITS_PER_MP_LIMB, nprime >= 2*M/B = 2*N/(K*B) = 2*pl/K = 2*l */
   TRACE (printf ("N=%d K=%d, M=%d, l=%d, maxLK=%d, Np=%d, np=%d\n",
 		 N, K, M, l, maxLK, Nprime, nprime));
+  /* we should ensure that recursively, nprime is a multiple of the next K */
   if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
     {
-      maxLK = (1 << mpn_fft_best_k (nprime,n == m)) * BITS_PER_MP_LIMB;
-      if (Nprime % maxLK)
-	{
-	  Nprime = ((Nprime / maxLK) + 1) * maxLK;
-	  nprime = Nprime / BITS_PER_MP_LIMB;
-	}
+      unsigned long K2;
+      while (nprime % (K2 = 1 << mpn_fft_best_k (nprime, sqr)))
+        {
+          nprime = ((nprime + K2 - 1) / K2) * K2;
+          Nprime = nprime * BITS_PER_MP_LIMB;
+          /* warning: since nprime changed, K2 may change too! */
+        }
       TRACE (printf ("new maxLK=%d, Np=%d, np=%d\n", maxLK, Nprime, nprime));
+      ASSERT(nprime % K2 == 0);
     }
+  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */
 
   T = TMP_ALLOC_LIMBS (nprime + 1);
   Mp = Nprime/K;


More information about the gmp-bugs mailing list