Multiplication of unbalanced operands

Paul Zimmermann Paul.Zimmermann at loria.fr
Wed Nov 22 22:34:00 CET 2006


I have implemented at the mpn level the Toom-Cook 2.5 algorithm suggested by
Marco in http://gmplib.org/list-archives/gmp-devel/2006-November/000675.html
(except one should exchange vinf and v0 in the last two lines).

The code is included below. On a Pentium M with gmp-4.2.1, it is faster than
mpn_mul starting from n=11 (i.e. a 33 x 22 product) and up to the FFT threshold
(i.e. 4992 x 3328). Over this large range, it saves up to about 20% with respect
to mpn_mul.

Paul

#include <stdio.h>
#include <stdlib.h>
#include "gmp.h"
#include "gmp-impl.h"

/* Print the n limbs {A, n} on stdout as a sum of terms "+A[j]*B^j", least
   significant limb first.  A newline is emitted after every fourth term
   (except after the very last one), and the output ends with ":\n".  */
void
mpn_print (mp_ptr A, mp_size_t n)
{
  mp_size_t j;  /* same type as n, so the loop comparison is not mixed-width */

  for (j = 0; j < n; j++)
    {
      /* Cast explicitly so that the arguments match the conversions
         whatever the actual widths of mp_limb_t and mp_size_t are.  */
      printf ("+%llu*B^%ld", (unsigned long long) A[j], (long) j);
      if (j % 4 == 3 && j != n - 1)
        printf ("\n");
    }
  printf (":\n");
}

/* {cp, 5n} <- {ap, 3n} * {bp, 2n} using Bodrato-Zanoni algorithm.

   The operands are split into n-limb pieces, A = a0 + a1*x + a2*x^2 and
   B = b0 + b1*x with x = B^n, and the product C = c0 + c1*x + c2*x^2 + c3*x^3
   is recovered from its values at four points:
     v0   = a0 * b0                     (C at 0)
     v1   = (a0 + a1 + a2) * (b0 + b1)  (C at 1)
     vm1  = (a0 - a1 + a2) * (b0 - b1)  (C at -1)
     vinf = a2 * b1                     (leading coefficient c3)

   cp must have room for 5n limbs; it is also used as scratch space while
   the evaluation points are computed.  A further 4n limbs (v1 and vm1) are
   obtained with TMP_ALLOC_LIMBS and released before returning.

   NOTE(review): the carry words below are mp_limb_t, which is an unsigned
   type in GMP, so the "negative" values mentioned in the range comments are
   presumably held modulo the limb base -- confirm before relying on the
   ASSERT lower bounds.  */
void
mpn_toom_3_2 (mp_ptr cp, mp_srcptr ap, mp_srcptr bp, mp_size_t n)
{
  mp_ptr tp;
  mp_limb_t ca, cb, ba, bb, cv1, cvm1;
  TMP_DECL;

  TMP_MARK;
  tp = TMP_ALLOC_LIMBS (4 * n);
  /* v0 and vinf live directly in the result area, v1 and vm1 in scratch.  */
#define v0 cp
#define v1 tp
#define vm1 (tp + 2 * n)
#define vinf (cp + 3 * n)

  /************************** evaluation *************************************/

  /* Point 1.  {cp+n, n} gets a0+a1+a2 with carry-out ca, {cp+2n, n} gets
     b0+b1 with carry-out cb; {cp, n} keeps a0+a2 (carry ba) for point -1.  */
  ba = mpn_add_n (cp, ap, ap + 2 * n, n); /* a0 + a2 */
  ca = ba + mpn_add_n (cp + n, cp, ap + n, n); /* a0 + a1 + a2 */
  cb = mpn_add_n (cp + 2 * n, bp, bp + n, n); /* b0 + b1 */
  mpn_mul_n (v1, cp + n, cp + 2 * n, n); /* v1 */
  /* 0 <= ca <= 2, 0 <= cb <= 1 */
  /* Add the cross terms of the carries into the high half of v1:
     cb * (a0+a1+a2) and ca * (b0+b1); ca * cb goes into the carry limb.  */
  cv1 = (cb == 0) ? 0 : mpn_add_n (v1 + n, v1 + n, cp + n, n);
  if (ca != 0)
    {
      cv1 += (ca == 1) ? mpn_add_n (v1 + n, v1 + n, cp + 2 * n, n)
        : mpn_addmul_1 (v1 + n, cp + 2 * n, n, ca);
      cv1 += ca * cb;
    }
  /* 0 <= cv1 <= 5 */

  /* Point -1.  {cp, n} gets a0-a1+a2 with signed high word ba, {cp+n, n}
     gets b0-b1 with high word bb (0 or -1).  */
  ba -= mpn_sub_n (cp, cp, ap + n, n); /* a0 - a1 + a2 */
  bb = -mpn_sub_n (cp + n, bp, bp + n, n); /* b0 - b1 */
  mpn_mul_n (vm1, cp, cp + n, n); /* vm1 */
  /* -1 <= ba <= 1, -1 <= bb <= 0 */
  /* Same cross-term correction as for v1, now with signed high words:
     bb * (a0-a1+a2) and ba * (b0-b1) into the high half, ba * bb into the
     carry limb.  */
  cvm1 = (bb == 0) ? 0 : -mpn_sub_n (vm1 + n, vm1 + n, cp, n);
  if (ba == 1)
    {
      cvm1 += mpn_add_n (vm1 + n, vm1 + n, cp + n, n);
      cvm1 -= bb != 0;
    }
  else if (ba != 0) /* ba = -1 */
    {
      cvm1 -= mpn_sub_n (vm1 + n, vm1 + n, cp + n, n);
      cvm1 += bb != 0;
    }
  /* -2 <= cvm1 <= 1 */

  /* Points 0 and infinity.  These overwrite the scratch use of cp above:
     v0 fills {cp, 2n} and vinf fills {cp+3n, 2n}.  */
  mpn_mul_n (v0, ap, bp, n); /* v0 */

  mpn_mul_n (vinf, ap + 2 * n, bp + n, n); /* vinf */

  /************************** interpolation **********************************/

  /* vm1 <- (v1 - vm1) / 2 */
  /* This is c1 + c3.  cvm1 becomes the carry limb of the difference; the
     ASSERT checks that the bit shifted out is zero, and the low bit of the
     carry limb is moved into the top bit of the shifted value.  */
  cvm1 = cv1 - cvm1 - mpn_sub_n (vm1, v1, vm1, 2 * n);
  ca = mpn_rshift (vm1, vm1, 2 * n, 1);
  ASSERT(ca == 0);
  vm1[2 * n - 1] |= (cvm1 & 1) << (GMP_NUMB_BITS - 1);
  cvm1 /= 2;
  ASSERT(0 <= cvm1 && cvm1 <= 2);

  /* v1 <- v1 - vm1 - v0 */
  /* This is c2, with carry limb cv1.  */
  cv1 = cv1 - cvm1 - mpn_sub_n (v1, v1, vm1, 2 * n);
  cv1 = cv1 - mpn_sub_n (v1, v1, v0, 2 * n);
  ASSERT(cv1 == 0 || cv1 == 1);

  /* vm1 <- vm1 - vinf */
  /* This is c1, with carry limb cvm1.  The low half stays in vm1.  */
  ca = mpn_sub_n (vm1, vm1, vinf, n);
  /* put directly the high part in {cp+2n, n} */
  cvm1 = cvm1 - mpn_sub_nc (cp + 2 * n, vm1 + n, vinf + n, n, ca);
  ASSERT(cvm1 == 0 || cvm1 == 1);

  /************************** recomposition **********************************/

  /* Add c2 (v1) at offset 2n and propagate its carry limb from offset 4n.  */
  cv1 += mpn_add_n (cp + 2 * n, cp + 2 * n, v1, 2 * n);
  cb = mpn_add_1 (cp + 4 * n, cp + 4 * n, n, cv1);

  /* Add the low half of c1 at offset n, then propagate its carry and the
     carry limb of c1 upwards.  */
  ca = mpn_add_n (cp + n, cp + n, vm1, n);
  cvm1 += mpn_add_1 (cp + 2 * n, cp + 2 * n, n, ca);
  cb += mpn_add_1 (cp + 3 * n, cp + 3 * n, 2 * n, cvm1);

  /* The product fits in 5n limbs, so nothing may carry out of the top.  */
  ASSERT (cb == 0);
#undef v0
#undef v1
#undef vm1
#undef vinf
  TMP_FREE;
}

#ifdef MAIN
#include "cputime.h"

/* Timing and self-check driver.

   Usage: prog n [k]
     n  size parameter: multiplies a random 3n-limb operand by a random
        2n-limb operand (must be positive);
     k  number of repetitions of each multiplication (default 1).

   Runs mpn_mul and mpn_toom_3_2 k times each on the same operands, prints
   the time reported by cputime () for each and their ratio, then compares
   the two 5n-limb products.  Exits with status 1 on bad arguments, on
   allocation failure, or if the products differ (after dumping the operands
   and both results with mpn_print); returns 0 otherwise.  */
int
main (int argc, char *argv[])
{
  mp_size_t n;
  long reps;
  unsigned long k, i;
  mp_ptr ap, bp, cp, dp;
  int st1, st2;

  /* argv[1] is mandatory: do not dereference it when it is absent.  */
  if (argc < 2)
    {
      fprintf (stderr, "Usage: %s n [k]\n", argc > 0 ? argv[0] : "toom_3_2");
      exit (1);
    }

  n = atoi (argv[1]);
  if (n <= 0)
    {
      fprintf (stderr, "n must be a positive integer\n");
      exit (1);
    }

  /* Reject a negative count instead of letting it wrap to a huge
     unsigned value.  */
  reps = (argc > 2) ? atoi (argv[2]) : 1;
  if (reps < 0)
    {
      fprintf (stderr, "k must be a non-negative integer\n");
      exit (1);
    }
  k = (unsigned long) reps;

  ap = malloc (3 * n * sizeof (mp_limb_t));
  bp = malloc (2 * n * sizeof (mp_limb_t));
  cp = malloc (5 * n * sizeof (mp_limb_t));
  dp = malloc (5 * n * sizeof (mp_limb_t));
  if (ap == NULL || bp == NULL || cp == NULL || dp == NULL)
    {
      fprintf (stderr, "out of memory\n");
      exit (1);
    }

  mpn_random2 (ap, 3 * n);
  mpn_random2 (bp, 2 * n);

  /* Reference product.  */
  st1 = cputime ();
  for (i = 0; i < k; i++)
    mpn_mul (cp, ap, 3 * n, bp, 2 * n);
  st1 = cputime () - st1;
  printf ("mpn_mul      took %dms\n", st1);

  /* Product under test; the ratio is relative to the mpn_mul time.  */
  st2 = cputime ();
  for (i = 0; i < k; i++)
    mpn_toom_3_2 (dp, ap, bp, n);
  st2 = cputime () - st2;
  printf ("mpn_toom_3_2 took %dms (%1.2f)\n", st2, (double) st2 / st1);

  if (mpn_cmp (cp, dp, 5 * n) != 0)
    {
      printf ("resuls differ\n");
      printf ("ap:="); mpn_print (ap, 3 * n);
      printf ("bp:="); mpn_print (bp, 2 * n);
      printf ("cp:="); mpn_print (cp, 5 * n);
      printf ("dp:="); mpn_print (dp, 5 * n);
      exit (1);
    }

  free (ap);
  free (bp);
  free (cp);
  free (dp);

  return 0;
}
#endif


More information about the gmp-devel mailing list