3-prime FFT

Torbjörn Granlund tege@swox.com
13 Dec 2002 01:35:43 +0100


I'd like to evaluate the 3-prime FFT algorithm for GMP 4.2.
I think that with current GMP implementation techniques, it
could be competitive with the current mul_fft code.

The idea behind the 3-prime algorithm is to compute three
separate convolutions, each with coefficients reduced mod a
distinct limb-sized prime.  At the end, CRT is used to
reconstruct the true coefficients.
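
Just to make the reconstruction step concrete, here is a toy
sketch of the mixed-radix CRT combination, with small
word-sized primes standing in for the limb-sized ones (the
function `crr' in the attached code does the same thing with
double-limb arithmetic; the names and constants below are for
illustration only):

#include <stdio.h>

int
main (void)
{
  const long m1 = 97, m2 = 89, m3 = 83;
  long x = 123456;		/* coefficient to recover, < m1*m2*m3 */
  long r1 = x % m1, r2 = x % m2, r3 = x % m3;
  long i12 = 0, i123 = 0, a, b, u;

  /* i12 = m1^(-1) mod m2 and i123 = (m1*m2)^(-1) mod m3, found by
     brute force here; the real code precomputes them.  */
  while (m1 * ++i12 % m2 != 1)
    ;
  while (m1 * m2 % m3 * ++i123 % m3 != 1)
    ;

  a = (r2 - r1 % m2 + m2) * i12 % m2;	/* x == r1 + m1*a (mod m1*m2) */
  u = (r1 + m1 * a) % m3;
  b = (r3 - u + m3) * i123 % m3;	/* x == r1 + m1*a + m1*m2*b */

  printf ("%ld\n", r1 + m1 * a + m1 * m2 * b);	/* prints 123456 */
  return 0;
}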

The inner loops could be made to run very fast on modern
computers.  A nice property that allows high speed is that
the inner loops carry no recurrences, i.e., no loop-carried
dependencies between iterations.

I have a basic implementation, originally written back in
the GMP 1.0 days.  An updated version is included below.

This implementation uses a umul_ppmm and a udiv_qrnnd in the
inner loop (see the function `fft').  This is silly.  At the
very least, udiv_qrnnd_preinv should be used.  But of course,
the real trick is to use Montgomery multiplication!
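
For concreteness, here is a minimal sketch of what a
single-limb Montgomery multiplication step might look like.
It is only an illustration, not a finished design: mulredc_1
is a made-up name, ninv stands for a precomputed
-n^(-1) mod 2^GMP_LIMB_BITS, and the transform elements would
have to be kept in Montgomery representation throughout.

/* Sketch: return a * b * 2^(-GMP_LIMB_BITS) mod n, for odd n, given
   ninv = -n^(-1) mod 2^GMP_LIMB_BITS.  No division needed.  */
static mp_limb_t
mulredc_1 (mp_limb_t a, mp_limb_t b, mp_limb_t n, mp_limb_t ninv)
{
  mp_limb_t th, tl, mh, ml, m, r;

  umul_ppmm (th, tl, a, b);	/* t = a * b, two limbs */
  m = tl * ninv;		/* m = t * (-1/n) mod 2^GMP_LIMB_BITS */
  umul_ppmm (mh, ml, m, n);	/* t + m*n is divisible by 2^GMP_LIMB_BITS */
  r = th + mh + (tl != 0);	/* (t + m*n) >> GMP_LIMB_BITS; low limbs cancel */
  if (r < th || r >= n)		/* at most one adjustment */
    r -= n;
  return r;
}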

And then we should write the inner loop in assembly for
important processors.  Since the inner loop is simpler than
any of the current GMP assembly loops, that isn't much work.
I expect it to run at about the same speed as mpn_mul_1.

(My code also uses a large precomputed table.  It would be
great to avoid that, and I think it could be done without
much speed loss.  I also think the inverse transform, now
computed with `ffi', could instead use `fft'.)

Any takers?  Kevin?  Paul?


[Attachment: mpn_p3mul.c]

/* 3-prime convolution/multiplication algorithm.  */

#include <stdio.h>
#include <stdlib.h>

#include "gmp.h"
#include "gmp-impl.h"
#include "longlong.h"

#ifndef GMP_LIMB_BYTES
#define GMP_LIMB_BYTES BYTES_PER_MP_LIMB
#endif

#if GMP_LIMB_BITS == 32
#define MOD_1		0xC0000001
#define P_ELEM_1	5
#define MOD_2		0xD0000001
#define P_ELEM_2	3
#define MOD_3		0xE8000001
#define P_ELEM_3	3
#define M1_I_M2		13
#define M1M2_0		0x90000001
#define M1M2_1		0x9C000001
#define M1M2_I_M3	0xC911114A
#define MAX_TRANSFORM_SIZE (1<<27)
#endif

#if GMP_LIMB_BITS == 64
#define MOD_1		CNST_LIMB(0xD800000000000001)
#define P_ELEM_1	5
#define MOD_2		CNST_LIMB(0xBE00000000000001)
#define P_ELEM_2	3
#define MOD_3		CNST_LIMB(0xF600000000000001)
#define P_ELEM_3	5
#define M1_I_M2		CNST_LIMB(0x664ec4ec4ec4ec48)
#define M1M2_0		CNST_LIMB(0x9600000000000001)
#define M1M2_1		CNST_LIMB(0xa050000000000001)
#define M1M2_I_M3	CNST_LIMB(0x6b2f8af8af8af8d4)
#define MAX_TRANSFORM_SIZE (1<<57)
#endif


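/* Compute B^E mod M by repeated squaring.  */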
static mp_limb_t
upowm (mp_limb_t b, unsigned int e, mp_limb_t m)
{
  mp_limb_t yh, yl = 1;
  mp_limb_t bh, bl = b;
  mp_limb_t dummy;

  while (e != 0)
    {
      while ((e & 1) == 0)
	{
	  /* bl = bl * bl % m; */
	  umul_ppmm (bh, bl, bl, bl);
	  udiv_qrnnd (dummy, bl, bh, bl, m);

	  e >>= 1;
	}
      e -= 1;

      /* y = y * b % m; */
      umul_ppmm (yh, yl, yl, bl);
      udiv_qrnnd (dummy, yl, yh, yl, m);
    }

  return yl;
}

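/* Increment (back_add) respectively decrement (back_sub) the counter A
   by one in bit-reversed order; POS is the counter's most significant
   bit.  */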
static inline mp_size_t
back_add (mp_size_t a, mp_size_t pos)
{
  a ^= pos;
  while ((a & pos) == 0)
    {
      pos >>= 1;
      a ^= pos;
    }
  return a;
}

static inline mp_size_t
back_sub (mp_size_t a, mp_size_t pos)
{
  a ^= pos;
  while ((a & pos) != 0)
    {
      pos >>= 1;
      a ^= pos;
    }
  return a;
}

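/* Compute A^(-1) mod N with the extended Euclidean algorithm; aborts
   if A is not invertible.  */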
static mp_limb_t
invert_mod (mp_limb_t a, mp_limb_t N)
{
  mp_limb_t b, q;
  long long int x;
  long long int s0 = 1;
  long long int s1 = 0;

  b = N;
  while (b != 0)
    {
      q = a / b;
      x = b;  b =  a  - b  * q; a = x;
      x = s1; s1 = s0 - s1 * q; s0 = x;
    }

  if (a != 1)
    abort ();

  if (s0 < 0)
    return s0 + N;
  else
    return s0;
}

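/* Fill W and W_N with successive powers of a SIZE-th root of unity
   mod N (derived from PRIMITIVE_ELEMENT) and of its inverse, stored
   in the bit-reversed order used by fft and ffi.  */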
static void
omega_init (mp_ptr w, mp_ptr w_n, mp_size_t size, mp_limb_t N, mp_limb_t PRIMITIVE_ELEMENT)
{
  mp_size_t i, b;
  mp_limb_t Nth_1_root, Nth_1_root_inv, we;
  mp_limb_t xh, xl;
  mp_limb_t dummy;

  Nth_1_root = upowm (PRIMITIVE_ELEMENT, N/size, N);
  Nth_1_root_inv = invert_mod (Nth_1_root, N);

  we = 1;
  w[0] = we;
  for (i = 1, b = 0; i < size/2; i++)
    {
      b = back_add (b, size/4);

      umul_ppmm (xh, xl, we, Nth_1_root);
      udiv_qrnnd (dummy, we, xh, xl, N);
      w[b] = we;
    }

  we = 1;
  w_n[0] = we;
  for (i = 1, b = 0; i < size/2; i++)
    {
      b = back_add (b, size/4);

      umul_ppmm (xh, xl, we, Nth_1_root_inv);
      udiv_qrnnd (dummy, we, xh, xl, N);
      w_n[b] = we;
    }
}

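/* Reduce the IN input limbs mod N into DP, zero-padding up to the
   transform size DN.  */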
static void
modularize (mp_ptr dp, mp_size_t dn, mp_srcptr ip, mp_size_t in, mp_limb_t N)
{
  mp_size_t i;

  for (i = 0; i < in; i++)
    dp[i] = ip[i] % N;
  for (i = in; i < dn; i++)
    dp[i] = 0;
}

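/* In-place radix-2 transform of the SIZE elements of P, mod N.  WP
   holds the needed roots of unity in the bit-reversed order produced
   by omega_init.  The final loop applies the bit-reversal permutation
   to the result.  */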
static void
fft (mp_ptr p, mp_size_t size, mp_srcptr wp, mp_limb_t N)
{
  mp_ptr p1, p2;
  mp_size_t t, s;
  mp_size_t tl = size >> 1;
  mp_size_t sl = 1;

  do
    {
      p1 = p;
      p2 = p + tl;
      for (t = 0; t < tl; t++)
	{
	  mp_limb_t a, x;
	  mp_limb_t z;

	  a = p1[t];
	  x = p2[t];

	  z = a + x;
	  if (z >= N || z < a)
	    z -= N;
	  p1[t] = z;

	  z = a - x;
	  if (z > a)
	    z += N;
	  p2[t] = z;
	}

      p1 += 2 * tl;
      p2 += 2 * tl;
      for (s = 1; s < sl; s++)
	{
	  mp_limb_t w = wp[s];
	  for (t = 0; t < tl; t++)
	    {
	      mp_limb_t a, x;
	      mp_limb_t z;
	      mp_limb_t xh, xl;
	      mp_limb_t dummy;

	      umul_ppmm (xh, xl, w, p2[t]);
	      udiv_qrnnd (dummy, x, xh, xl, N);
	      a = p1[t];

	      z = a + x;
	      if (z >= N || z < a)
		z -= N;
	      p1[t] = z;

	      z = a - x;
	      if (z > a)
		z += N;
	      p2[t] = z;
	    }

	  p1 += 2 * tl;
	  p2 += 2 * tl;
	}
      sl <<= 1;
      tl >>= 1;
    }
  while (tl > 0);

  for (s = 1, t = 0; s < size; s++)
    {
      t = back_add (t, size/2);
      if (s < t)
	{
	  mp_limb_t x;
	  x = p[s],  p[s] = p[t],  p[t] = x;
	}
    }
}

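/* Inverse counterpart of fft.  The final block applies the
   bit-reversal permutation and at the same time multiplies each
   element by the inverse of the transform size mod N.  */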
static void
ffi (mp_ptr p, mp_size_t size, mp_srcptr wp, mp_limb_t N)
{
  mp_ptr p1, p2;
  mp_size_t t, s;
  mp_size_t tl = size >> 1;
  mp_size_t sl = 1;
  int lbn = 0;

  do
    {
      p1 = p;
      p2 = p + tl;
      for (s = 0; s < sl; s++)
	{
	  mp_limb_t w = wp[s];
	  for (t = 0; t < tl; t++)
	    {
	      mp_limb_t a, x;
	      mp_limb_t z;
	      mp_limb_t xh, xl;
	      mp_limb_t dummy;

	      umul_ppmm (xh, xl, w, p2[t]);
	      udiv_qrnnd (dummy, x, xh, xl, N);
	      a = p1[t];

	      z = a + x;
	      if (z >= N || z < a)
		z -= N;
	      p1[t] = z;

	      z = a - x;
	      if (z > a)
		z += N;
	      p2[t] = z;
	    }

	  p1 += 2 * tl;
	  p2 += 2 * tl;
	}
      sl <<= 1;
      tl >>= 1;

      lbn++;
    }
  while (tl > 0);

  {
    mp_limb_t xh, xl;
    mp_limb_t dummy;
    mp_limb_t w = invert_mod (CNST_LIMB (1) << lbn, N);
    umul_ppmm (xh, xl, w, p[0]);
    udiv_qrnnd (dummy, p[0], xh, xl, N);
    for (s = 1, t = 0; s < size; s++)
      {
	mp_limb_t x;

	t = back_add (t, size/2);
	x = p[s];
	if (s <= t)
	  {
	    if (s < t)
	      {
		umul_ppmm (xh, xl, w, p[t]);
		udiv_qrnnd (dummy, p[s], xh, xl, N);
	      }
	    umul_ppmm (xh, xl, w, x);
	    udiv_qrnnd (dummy, p[t], xh, xl, N);
	  }
      }
  }
}

/* Pointwise multiplication mod N: p[i] = p[i] * m[i] mod N.  */
static void
emul (mp_ptr p, mp_srcptr m, mp_size_t size, mp_limb_t N)
{
  mp_size_t i;
  mp_limb_t xh, xl, dummy;

  for (i = size - 1; i >= 0; i--)
    {
      umul_ppmm (xh, xl, p[i], m[i]);
      udiv_qrnnd (dummy, p[i], xh, xl, N);
    }
}

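/* Chinese remainder reconstruction: combine the residues R1, R2, R3
   (mod MOD_1, MOD_2, MOD_3) into the unique value mod
   MOD_1*MOD_2*MOD_3, returned as three limbs in UPTR, using a
   mixed-radix (Garner-style) combination.  */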
static void
crr (mp_ptr Uptr, mp_limb_t r1, mp_limb_t r2, mp_limb_t r3)
{
  mp_limb_t u, o;
  mp_limb_t U0, U1, U2;
  mp_limb_t dummy;
  mp_limb_t t0, t1, t2;

  if (MOD_3 < MOD_2 || MOD_3 < MOD_1)
    abort ();

  U0 = r1;

  u = U0 % MOD_2;
  r2 = r2 > u ? r2 - u : r2 + (MOD_2 - u);

  /* o = r2 * M1_I_M2 mod MOD_2; */
  umul_ppmm (t1, t0, r2, M1_I_M2);
  udiv_qrnnd (dummy, o, t1, t0, MOD_2);

  /* U += o * MOD_1; */
  umul_ppmm (t1, t0, o, MOD_1);
  add_ssaaaa (U1, U0, 0, U0, t1, t0);


  /* u = U mod MOD_3; */
  udiv_qrnnd (dummy, u, U1, U0, MOD_3);

  r3 = r3 > u ? r3 - u : r3 + (MOD_3 - u);

  /* o = r3 * M1M2_I_M3 mod MOD_3; */
  umul_ppmm (t1, t0, r3, M1M2_I_M3);
  udiv_qrnnd (dummy, o, t1, t0, MOD_3);

  /* U += o * M1M2; */
  umul_ppmm (t1, t0, o, M1M2_0);
  add_ssaaaa (U1, U0, U1, U0, t1, t0);
  umul_ppmm (t2, t1, o, M1M2_1);
  add_ssaaaa (U2, U1, 0, U1, t2, t1);

  Uptr[0] = U0;
  Uptr[1] = U1;
  Uptr[2] = U2;
}

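/* Reconstruct each convolution coefficient with crr and accumulate
   the up-to-3-limb coefficients, which overlap at successive limb
   positions, into RP with carry propagation.  */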
static void
un_modularize (mp_ptr rp,
	       mp_srcptr m1p, mp_srcptr m2p, mp_srcptr m3p,
	       mp_size_t size)
{
  mp_limb_t part_res[3];
  mp_limb_t cy0, cy1;
  mp_size_t i;

  cy0 = cy1 = 0;
  for (i = 0; i < size; i++)
    {
      crr (part_res, m1p[i], m2p[i], m3p[i]);
      add_ssaaaa (cy0, rp[i], part_res[1], part_res[0], cy1, cy0);
      if (cy0 < cy1)
	cy1 = part_res[2] + 1;
      else
	cy1 = part_res[2];
    }
}

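/* Root-of-unity tables, cached between calls and grown on demand.  */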
mp_ptr omega_1, omega_2, omega_3;
mp_ptr omega_inv_1, omega_inv_2, omega_inv_3;
mp_size_t omega_tab_size = 0;

mp_size_t
mpn_p3mul (mp_ptr prodp,
	   mp_srcptr up, mp_size_t un,
	   mp_srcptr vp, mp_size_t vn)
{
  mp_size_t n;
  mp_size_t prod_size;

#define FFT_THRESHOLD 1024
  if (vn < FFT_THRESHOLD)
    {
      /* Handle simple cases with traditional multiplication.

	 This is the most critical code of the entire function.  All
	 multiplies rely on this, both small and huge.  Small ones arrive
	 here immediately, and huge ones arrive here as the base case of
	 the recursive operand splitting below.  */
      mp_size_t i;
      mp_limb_t cy_limb;
      mp_limb_t v_limb;

      if (vn == 0)
	return 0;

      /* Multiply by the first limb in V separately, as the result can be
	 stored (not added) to PROD.  We also avoid a loop for zeroing.  */
      v_limb = vp[0];
      if (v_limb <= 1)
	{
	  if (v_limb == 1)
	    MPN_COPY (prodp, up, un);
	  else
	    MPN_ZERO (prodp, un);
	  cy_limb = 0;
	}
      else
	cy_limb = mpn_mul_1 (prodp, up, un, v_limb);

      prodp[un] = cy_limb;
      prodp++;

      /* For each iteration in the outer loop, multiply one limb from
	 U with one limb from V, and add it to PROD.  */
      for (i = 1; i < vn; i++)
	{
	  v_limb = vp[i];
	  if (v_limb <= 1)
	    {
	      cy_limb = 0;
	      if (v_limb == 1)
		cy_limb = mpn_add_n (prodp, prodp, up, un);
	    }
	  else
	    cy_limb = mpn_addmul_1 (prodp, up, un, v_limb);

	  prodp[un] = cy_limb;
	  prodp++;
	}
      return un + vn - (cy_limb == 0);
    }

  n = (un + 1) / 2;

  if (vn <= n)
    {
      /* U has at least twice as many limbs as V.  Split U into two
	 pieces, U1 and U0, such that U = U0 + U1*(2**GMP_LIMB_BITS)**N,
	 and recursively multiply the two pieces separately by V.  */

      mp_ptr tmp;
      mp_size_t tmp_size_t;
      mp_limb_t cy;

      /* V1, the part of V above limb N, is zero, so V need not be split.  */

      /* Perform (U0 * V).  */
      prod_size = mpn_p3mul (prodp, up, n, vp, vn);

      tmp = alloca ((un + vn - n) * GMP_LIMB_BYTES);

      /* Perform (U1 * V).  Make sure the first source argument to mpn_p3mul
	 is not less than the second source argument.  */
      if (vn <= un - n)
	tmp_size_t = mpn_p3mul (tmp, up + n, un - n, vp, vn);
      else
	tmp_size_t = mpn_p3mul (tmp, vp, vn, up + n, un - n);

      /* Hidden in this addition is a potentially large copy of TMP.  */
      if (prod_size - n >= tmp_size_t)
	cy = mpn_add (prodp + n, prodp + n, prod_size - n, tmp, tmp_size_t);
      else
	cy = mpn_add (prodp + n, tmp, tmp_size_t, prodp + n, prod_size - n);
      if (cy)
	abort (); /* prodp[prod_size] = 1; */

      return tmp_size_t + n;
    }
  else
    {
      /* The sizes of U and V are close.  Use the 3-prime FFT algorithm.  */

      mp_size_t size;
      int cnt;
      mp_ptr U_1, U_2, U_3;
      mp_ptr V_1, V_2, V_3;

      count_leading_zeros (cnt, un - 1);
      size = 2 << (GMP_LIMB_BITS - cnt);

      U_1 = alloca (size * GMP_LIMB_BYTES);
      U_2 = alloca (size * GMP_LIMB_BYTES);
      U_3 = alloca (size * GMP_LIMB_BYTES);
      V_1 = alloca (size * GMP_LIMB_BYTES);
      V_2 = alloca (size * GMP_LIMB_BYTES);
      V_3 = alloca (size * GMP_LIMB_BYTES);

      if (size > omega_tab_size)
	{
	  if (omega_tab_size != 0)
	    {
	      __GMP_FREE_FUNC_LIMBS (omega_1, omega_tab_size / 2);
	      __GMP_FREE_FUNC_LIMBS (omega_2, omega_tab_size / 2);
	      __GMP_FREE_FUNC_LIMBS (omega_3, omega_tab_size / 2);
	      __GMP_FREE_FUNC_LIMBS (omega_inv_1, omega_tab_size / 2);
	      __GMP_FREE_FUNC_LIMBS (omega_inv_2, omega_tab_size / 2);
	      __GMP_FREE_FUNC_LIMBS (omega_inv_3, omega_tab_size / 2);
	    }

	  omega_1 = __GMP_ALLOCATE_FUNC_LIMBS (size/2);
	  omega_2 = __GMP_ALLOCATE_FUNC_LIMBS (size/2);
	  omega_3 = __GMP_ALLOCATE_FUNC_LIMBS (size/2);
	  omega_inv_1 = __GMP_ALLOCATE_FUNC_LIMBS (size/2);
	  omega_inv_2 = __GMP_ALLOCATE_FUNC_LIMBS (size/2);
	  omega_inv_3 = __GMP_ALLOCATE_FUNC_LIMBS (size/2);

	  omega_tab_size = size;
	  omega_init (omega_1, omega_inv_1, omega_tab_size, MOD_1, P_ELEM_1);
	  omega_init (omega_2, omega_inv_2, omega_tab_size, MOD_2, P_ELEM_2);
	  omega_init (omega_3, omega_inv_3, omega_tab_size, MOD_3, P_ELEM_3);
	}

      modularize (U_1, size, up, un, MOD_1);
      fft (U_1, size, omega_1, MOD_1);

      modularize (U_2, size, up, un, MOD_2);
      fft (U_2, size, omega_2, MOD_2);

      modularize (U_3, size, up, un, MOD_3);
      fft (U_3, size, omega_3, MOD_3);

      modularize (V_1, size, vp, vn, MOD_1);
      fft (V_1, size, omega_1, MOD_1);

      modularize (V_2, size, vp, vn, MOD_2);
      fft (V_2, size, omega_2, MOD_2);

      modularize (V_3, size, vp, vn, MOD_3);
      fft (V_3, size, omega_3, MOD_3);

      emul (U_1, V_1, size, MOD_1);
      ffi (U_1, size, omega_inv_1, MOD_1);

      emul (U_2, V_2, size, MOD_2);
      ffi (U_2, size, omega_inv_2, MOD_2);

      emul (U_3, V_3, size, MOD_3);
      ffi (U_3, size, omega_inv_3, MOD_3);

      size = un + vn;
      un_modularize (prodp, U_1, U_2, U_3, size);

      return (prodp[size - 1] != 0) ? size : size - 1;
    }
}
#if 0
mp_limb_t
mpn_mul (wp, up, un, vp, vn)
     mp_ptr wp;
     mp_srcptr up;
     mp_size_t un;
     mp_srcptr vp;
     mp_size_t vn;
{
  mpn_p3mul (wp, up, un, vp, vn);
  return wp[un + vn - 1];
}
#endif



--
Torbjörn
