/* mpn_mul_toom22 -- Multiply {ap,an} and {bp,bn} where an >= bn.  Or more
   accurately, bn <= an < 2bn.

   Contributed to the GNU project by Torbjorn Granlund.

   The idea of using asymmetric operands was suggested by Marco Bodrato and
   Alberto Zanoni.

   THE FUNCTION IN THIS FILE IS INTERNAL WITH A MUTABLE INTERFACE.  IT IS ONLY
   SAFE TO REACH IT THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT IT WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 2006, 2007 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
more details.

You should have received a copy of the GNU General Public License along with
the GNU MP Library.  If not, see http://www.gnu.org/licenses/.  */

#include "gmp.h"
#include "gmp-impl.h"

/* When built with -DSTAT, the evaluation and interpolation phases below are
   timed with cputime() and accumulated into the globals declared here, and
   the WANT_STAT driver at the end of the file is compiled.  Otherwise STAT()
   expands to nothing.  */
#ifdef STAT
#undef STAT
#define WANT_STAT
int cputime (void);
#define STAT(x) x
double t_eval, t_interpol, t_mulpts, t_total;
#else
#define STAT(x)
#endif

/* Return 1 if all n limbs of {ap,n} are zero (trivially 1 when n == 0),
   otherwise 0.  Scans from the most significant limb downwards.  */
static inline int
mpn_zero_p (mp_srcptr ap, mp_size_t n)
{
  mp_size_t i;
  for (i = n - 1; i >= 0; i--)
    {
      if (ap[i] != 0)
        return 0;
    }
  return 1;
}

/* Evaluate in: -1, 0, +inf

  <-s--><--n-->
   ____ ______
  |_a1_|___a0_|
   |b1_|___b0_|
   <-t-><--n-->

  v0  =  a0     * b0       #   A(0)*B(0)
  vm1 = (a0- a1)*(b0- b1)  #  A(-1)*B(-1)
  vinf=     a1  *     b1   # A(inf)*B(inf)
*/

/* Multiply {ap,an} by {bp,bn} and store the an+bn limb product at {pp,an+bn}.

   With n = ceil(an/2), s = an - n and t = bn - n, each operand is split at
   limb n into a low part (a0, b0: n limbs) and a high part (a1: s limbs,
   b1: t limbs).  Writing x for the limb base raised to the n-th power:

     product = v0 + (v0 + vinf - vm1) * x + vinf * x^2

   where vm1 is the signed product (a0 - a1)*(b0 - b1).

   Requirements (checked by the ASSERTs): 0 < s <= n and 0 < t <= s, which
   amounts to (an + 1) / 2 < bn <= an.  Note that "bn <= an < 2bn" from the
   file header is not quite enough for odd an (e.g. an = 3, bn = 2 gives t = 0).

   scratch must provide at least 3n limbs: vm1 occupies {scratch, 2n} and the
   temporary tp occupies {scratch + 2n, n}.

   {pp, 2n} is used as temporary space for |a0 - a1| and |b0 - b1| before
   any input limbs beyond those are read for v0, so pp presumably must not
   overlap {ap,an} or {bp,bn}.  */
void
mpn_mul_toom22 (mp_ptr pp,
                mp_srcptr ap, mp_size_t an,
                mp_srcptr bp, mp_size_t bn,
                mp_ptr scratch)
{
  mp_size_t n, s, t;
  int vm1_neg;
  mp_limb_t cy, cy2;
  mp_ptr asm1;
  mp_ptr bsm1;

#define a0  ap
#define a1  (ap + n)
#define b0  bp
#define b1  (bp + n)

  n = (an + 1) >> 1;
  s = an - n;
  t = bn - n;

  ASSERT (0 < s && s <= n);
  ASSERT (0 < t && t <= s);

  /* |a0 - a1| and |b0 - b1| (n limbs each) are parked in the low 2n limbs
     of pp; they are consumed by the vm1 product before v0 overwrites them.  */
  asm1 = pp;
  bsm1 = pp + n;

  STAT (t_eval -= cputime ());

  /* vm1_neg tracks the sign of (a0 - a1)*(b0 - b1); each negative factor
     flips it.  */
  vm1_neg = 0;

  /* Compute asm1 = |a0 - a1|.  */
  if (s == n)
    {
      if (mpn_cmp (a0, a1, n) < 0)
        {
          mpn_sub_n (asm1, a1, a0, n);
          vm1_neg = 1;
        }
      else
        {
          mpn_sub_n (asm1, a0, a1, n);
        }
    }
  else
    {
      /* a0 has n limbs but a1 only s, so a0 < a1 is possible only when the
         top n - s limbs of a0 are zero and its low s limbs compare less.  */
      if (mpn_zero_p (a0 + s, n - s) && mpn_cmp (a0, a1, s) < 0)
        {
          mpn_sub_n (asm1, a1, a0, s);
          MPN_ZERO (asm1 + s, n - s);
          vm1_neg = 1;
        }
      else
        {
          mpn_sub (asm1, a0, n, a1, s);
        }
    }

  /* Compute bsm1 = |b0 - b1|, same scheme with t in place of s.  */
  if (t == n)
    {
      if (mpn_cmp (b0, b1, n) < 0)
        {
          mpn_sub_n (bsm1, b1, b0, n);
          vm1_neg ^= 1;
        }
      else
        {
          mpn_sub_n (bsm1, b0, b1, n);
        }
    }
  else
    {
      if (mpn_zero_p (b0 + t, n - t) && mpn_cmp (b0, b1, t) < 0)
        {
          mpn_sub_n (bsm1, b1, b0, t);
          MPN_ZERO (bsm1 + t, n - t);
          vm1_neg ^= 1;
        }
      else
        {
          mpn_sub (bsm1, b0, n, b1, t);
        }
    }

  STAT (t_eval += cputime ());

#define v0    pp                        /* 2n */
#define vinf  (pp + 2 * n)              /* s+t */
#define vm1   scratch                   /* 2n */
#define tp    (scratch + 2*n)           /* n */

  /* The order of the three products matters: vm1 must be formed while
     asm1/bsm1 are still intact in {pp, 2n}.  */

  /* vm1, 2n limbs */
  mpn_mul_n (vm1, asm1, bsm1, n);

  /* vinf, s+t limbs */
  mpn_mul (vinf, a1, s, b1, t);

  /* v0, 2n limbs */
  mpn_mul_n (v0, ap, bp, n);

  STAT (t_interpol -= cputime ());

  /* Add v0 + vinf into pp at limb offset n.  Since H(v0) + L(vinf) is
     needed at both offset n and offset 2n, it is computed once into tp.  */

  /* H(v0) + L(vinf) */
  cy = mpn_add_n (tp, v0 + n, vinf, n);

  /* L(v0) + H(v0) */
  cy2 = cy + mpn_add_n (pp + n, tp, v0, n);

  /* L(vinf) + H(vinf) */
  cy += mpn_add (pp + 2 * n, tp, n, vinf + n, s + t - n);

  /* Subtract the signed vm1, i.e. add |vm1| when it is negative.  cy is an
     unsigned limb and may wrap to the all-ones value (-1) here.  */
  if (vm1_neg)
    cy += mpn_add_n (pp + n, pp + n, vm1, 2 * n);
  else
    cy -= mpn_sub_n (pp + n, pp + n, vm1, 2 * n);

  ASSERT (cy + 1  <= 3);
  ASSERT (cy2 <= 2);

  /* Propagate cy2 from limb 2n and cy (possibly -1) from limb 3n.  */
  mpn_incr_u (pp + 2 * n, cy2);
  if (LIKELY (cy <= 2))
    mpn_incr_u (pp + 3 * n, cy);
  else
    mpn_decr_u (pp + 3 * n, 1);

  STAT (t_interpol += cputime ());

/* Keep the helper names local to this function so they cannot capture
   identifiers in the test drivers below.  */
#undef a0
#undef a1
#undef b0
#undef b1
#undef v0
#undef vinf
#undef vm1
#undef tp
}

#define CONCAT(name,M,N) name ## M ## N
#define M 2
#define N 2
#define mpn_mul_toomMN CONCAT(mpn_mul_toom,2,2)

#ifdef CHECK
/* Randomized self-test comparing mpn_mul_toomMN against mpn_mul.
   Usage: prog MAXN      -- random sizes up to MAXN
          prog AN BN     -- fixed sizes AN, BN (random data)
   Runs until interrupted; aborts after more than 5 mismatches.  */

#include <stdlib.h>
#include <stdio.h>

#ifndef SIZE
#define SIZE 10
#endif

/* Print {p,n} in hex, most significant limb first, limbs separated by
   spaces.  */
void
dumpy (mp_srcptr p, mp_size_t n)
{
  mp_size_t i;

  for (i = n - 1; i >= 0; i--)
    {
      printf ("%0*lx", (int) (2 * sizeof (mp_limb_t)), p[i]);
      if (i != 0)
        putchar (' ');
    }
  puts ("");
}

int
main (int argc, char **argv)
{
  mp_size_t n, s, t, an, bn, clearn;
  mp_ptr ap, bp, refp, pp, scratch;
  mp_limb_t keep;
  int test;
  int maxn;
  int norandom;
  int err = 0;
  TMP_DECL;

  an = M * SIZE;
  bn = N * SIZE;
  norandom = 0;

  if (argc >= 2)
    {
      maxn = strtol (argv[1], 0, 0);
      an = M * maxn;
      bn = N * maxn;
      if (argc == 3)
        {
          an = maxn;
          bn = strtol (argv[2], 0, 0);
          norandom = 1;
        }
    }
  else
    return 1;

  /* random () % maxn below needs maxn > 0.  */
  if (maxn <= 0)
    {
      fprintf (stderr, "Invalid size\n");
      return 1;
    }
  /* Fixed sizes must satisfy the toom22 precondition (an+1)/2 < bn <= an.  */
  if (norandom && !(an >= bn && bn > (an + 1) / 2))
    {
      fprintf (stderr, "Invalid value combination of an,bn\n");
      return 1;
    }

  TMP_MARK;

  ap = TMP_ALLOC_LIMBS (an);
  bp = TMP_ALLOC_LIMBS (bn);
  refp = TMP_ALLOC_LIMBS (an + bn);
  /* One extra limb to detect writes past the an+bn limb product.  */
  pp = TMP_ALLOC_LIMBS (an + bn + 1);
  scratch = TMP_SALLOC_LIMBS (an + bn);

  for (test = 0;; test++)
    {
      if (err == 0 && test % 0x100 == 0)
        {
          printf ("\r%d", test);
          fflush (stdout);
        }

      if (! norandom)
        {
          n = random () % maxn + 1;
          s = random () % n + 1;
#if M == N
          t = random () % s + 1;
#else
          t = random () % n + 1;
#endif
          an = (M - 1) * n + s;
          bn = (N - 1) * n + t;
        }

      mpn_random2 (ap, an);
      clearn = random () % (an + 1);
      MPN_ZERO (ap + clearn, an - clearn);
      mpn_random2 (bp, bn);
      clearn = random () % (bn + 1);
      MPN_ZERO (bp + clearn, bn - clearn);

      mpn_random2 (pp, an + bn + 1);
      keep = pp[an + bn];

      mpn_mul_toomMN (pp, ap, an, bp, bn, scratch);
      mpn_mul (refp, ap, an, bp, bn);
      if (pp[an + bn] != keep || mpn_cmp (refp, pp, an + bn) != 0)
        {
          printf ("ERROR in test %d\n", test);
          if (pp[an + bn] != keep)
            {
              printf ("pp high:");
              dumpy (pp + an + bn, 1);
              printf ("keep: ");
              dumpy (&keep, 1);
            }
          dumpy (ap, an);
          dumpy (bp, bn);
          dumpy (pp, an + bn);
          dumpy (refp, an + bn);

          if (++err > 5)
            abort ();
        }
    }

  TMP_FREE;
  return 0;
}
#endif

#ifdef TIMING
/* Time mpn_mul_toomMN against mpn_mul and mpn_mul_basecase.
   Usage: prog AN [BN]  (BN defaults to AN).  */

#include <stdlib.h>
#include <stdio.h>
#include "timing.h"

#ifndef SIZE
#define SIZE 10
#endif

int
main (int argc, char **argv)
{
  mp_size_t an, bn;
  mp_ptr ap, bp, refp, pp, scratch;
  double t;
  TMP_DECL;

  if (argc >= 2)
    {
      an = bn = strtol (argv[1], 0, 0);
      if (argc == 3)
        bn = strtol (argv[2], 0, 0);
    }
  else
    return 1;

  /* toom22 needs (an+1)/2 < bn <= an; "an < 2*bn" alone admits e.g.
     an = 3, bn = 2 (t = 0) or an = bn = 1 (s = 0).  */
  if (!(an >= bn && bn > (an + 1) / 2))
    {
      fprintf (stderr, "Invalid value combination of an,bn\n");
      return 1;
    }

  TMP_MARK;

  ap = TMP_ALLOC_LIMBS (an);
  bp = TMP_ALLOC_LIMBS (bn);
  refp = TMP_ALLOC_LIMBS (an + bn);
  pp = TMP_ALLOC_LIMBS (an + bn);
  /* an + bn >= 3 * ceil(an/2) under the check above, enough for toom22.  */
  scratch = TMP_SALLOC_LIMBS (an + bn);

  mpn_random (ap, an);
  mpn_random (bp, bn);

  TIME (t, mpn_mul_toomMN (pp, ap, an, bp, bn, scratch));
  printf ("mpn_mul_toom%d%d: %f\n", M, N, t);

  TIME (t, mpn_mul (refp, ap, an, bp, bn));
  printf ("mpn_mul: %f\n", t);

  TIME (t, mpn_mul_basecase (refp, ap, an, bp, bn));
  printf ("mpn_mul_basecase: %f\n", t);

  TMP_FREE;

#ifdef WANT_STAT
  printf ("time in eval : %f\n", t_eval);
  printf ("time in interpolate: %f\n", t_interpol);
#endif

  return 0;
}
#endif

#ifdef WANT_STAT
/* Break down the time spent in each toom22 phase over REPS calls.
   Usage: prog REPS AN [BN]  (BN defaults to AN).
   NOTE(review): this defines main, as does the TIMING driver; building with
   both -DTIMING and -DSTAT presumably fails to link.  */

#include <stdlib.h>
#include <stdio.h>
#include "timing.h"

int
main (int argc, char **argv)
{
  mp_size_t an, bn;
  mp_ptr ap, bp, pp, scratch;
  long reps, i;
  double t0;
  TMP_DECL;

  if (argc < 3)
    return 1;

  reps = strtol (argv[1], 0, 0);
  an = bn = strtol (argv[2], 0, 0);
  if (argc == 4)
    bn = strtol (argv[3], 0, 0);

  /* toom22 needs (an+1)/2 < bn <= an.  */
  if (!(an >= bn && bn > (an + 1) / 2))
    {
      fprintf (stderr, "Invalid value combination of an,bn\n");
      return 1;
    }

  TMP_MARK;

  ap = TMP_ALLOC_LIMBS (an);
  bp = TMP_ALLOC_LIMBS (bn);
  pp = TMP_ALLOC_LIMBS (an + bn);
  scratch = TMP_SALLOC_LIMBS (an + bn);

  mpn_random (ap, an);
  mpn_random (bp, bn);

  t_eval = 0;
  t_interpol = 0;
  t0 = cputime ();
  for (i = reps; i > 0; i--)
    mpn_mul_toomMN (pp, ap, an, bp, bn, scratch);
  t_total = cputime () - t0;
  t_mulpts = t_total - t_eval - t_interpol;

  printf ("time in eval : %.10f\n", t_eval);
  printf ("time in interpolate : %.10f\n", t_interpol);
  printf ("time in mul points : %.10f\n", t_mulpts);
  printf ("time total mul_toom22 : %.10f\n", t_total);

  t0 = cputime ();
  for (i = reps; i > 0; i--)
    mpn_mul (pp, ap, an, bp, bn);
  t_total = cputime () - t0;
  printf ("time in mpn_mul : %.10f\n", t_total);

  /* Release the TMP area only after the last use of ap, bp and pp.  */
  TMP_FREE;

  return 0;
}
#endif