3-prime FFT

Kevin Ryde user42@zip.com.au
Fri, 13 Dec 2002 10:47:24 +1000


--=-=-=

Torbjorn Granlund <tege@swox.com> writes:
>
> This implementation used a umul_ppmm and a udiv_qrnnd in the
> inner loop (see the function `fft').  This is silly.  At the
> very least udiv_qrnnd_preinv should be used.  But of course,
> the real trick is to use Montgomery multiplication!

Some code below I made a while ago for 1-limb redc.  Comes out quite
nice on i386.  I think it works, or it did whenever I last ran it :-).


--=-=-=
Content-Type: text/x-chdr
Content-Disposition: attachment; filename=redc1.h

/* Usage:

   Start with an odd mp_limb_t modulus N, then use the various INIT macros to
   set up values dependent on the modulus and required for later macros.  Eg.

        mp_limb_t  Ninv, RmodN, R2modN;
        REDC1_INIT_Ninv (Ninv, N);
        REDC1_INIT_RmodN (RmodN, N);
        REDC1_INIT_R2modN (R2modN, N, RmodN);

   Convert input values to residues with either REDC1_IN or REDC1_IN_SINGLE.
   The latter doesn't need R2modN and is probably faster if only one or two
   inputs are to be converted.

        mp_limb_t  xR;
        REDC1_IN (xR, x, N, Ninv, R2modN);

   Special macros REDC1_IN_1 and REDC1_IN_NEG1 give residues 1 and -1.

   Arithmetic can be done,

        mp_limb_t  zR;
        REDC1_ADD (zR, xR, yR, N);
        REDC1_SUB (zR, xR, yR, N);
        REDC1_MUL (zR, xR, yR, N, Ninv);

   And results are converted back out of redc format with REDC1_OUT,

        mp_limb_t  x;
        REDC1_OUT (x, xR, N, Ninv);

   If the modulus N fits in GMP_LIMB_BITS-1 bits (N < GMP_LIMB_HIGHBIT), ie.
   has at least one unused bit at the high end, then #define
   MODULUS_HAS_ONE_BIT_SPARE before including this file to run very slightly
   faster.


   Possible enhancements:

   Make a REDC1_IN_Npreinv using udiv_qrnnd_preinv.

   Setup a normalized copy of N for use when UDIV_NEEDS_NORMALIZATION, and
   arrange that the code for that can go dead when not needed.

*/

#include "gmp.h"
#include "gmp-impl.h"
#include "longlong.h"


/* r = a+b mod m, for a and b already reduced (a,b < m). */
#ifdef MODULUS_HAS_ONE_BIT_SPARE
/* With m < GMP_LIMB_HIGHBIT the sum a+b cannot carry out of the limb, so a
   single comparison against m decides whether to subtract it. */
#define addmod_saam(r, a, b, m)                                 \
  do {                                                          \
    mp_limb_t  __addmod1__sum;                                  \
    ASSERT ((a) < (m));                                         \
    ASSERT ((b) < (m));                                         \
    ASSERT ((m) < GMP_LIMB_HIGHBIT);                            \
    __addmod1__sum = (a) + (b);                                 \
    if (__addmod1__sum >= (m))                                  \
      __addmod1__sum -= (m);                                    \
    (r) = __addmod1__sum;                                       \
    ASSERT ((r) < (m));                                         \
  } while (0)
#endif
#ifndef addmod_saam
/* General case: a+b may wrap past the top of the limb.  A wrapped sum comes
   out smaller than a; either that, or a sum >= m, means exactly one m must
   come off (unsigned wraparound makes the wrapped case come out right). */
#define addmod_saam(r, a, b, m)                                         \
  do {                                                                  \
    mp_limb_t  __addmod1__x = (a);                                      \
    mp_limb_t  __addmod1__mod = (m);                                    \
    mp_limb_t  __addmod1__sum;                                          \
    ASSERT ((a) < (m));                                                 \
    ASSERT ((b) < (m));                                                 \
    __addmod1__sum = __addmod1__x + (b);                                \
    if (!(__addmod1__sum >= __addmod1__x                                \
          && __addmod1__sum < __addmod1__mod))                          \
      __addmod1__sum -= __addmod1__mod;                                 \
    (r) = __addmod1__sum;                                               \
    ASSERT ((r) < (m));                                                 \
  } while (0)
#endif

/* r = a-b mod m, for a and b already reduced (a,b < m). */
#ifdef MODULUS_HAS_ONE_BIT_SPARE
/* With m < GMP_LIMB_HIGHBIT, a-b lies strictly between -m and m, so the sign
   bit of the difference says whether m has to be added back. */
#define submod_dmsm(r, a, b, m)                                               \
  do {                                                                        \
    mp_limb_t  __submod1__diff;                                               \
    ASSERT ((a) < (m));                                                       \
    ASSERT ((b) < (m));                                                       \
    ASSERT ((m) < GMP_LIMB_HIGHBIT);                                          \
    __submod1__diff = (a) - (b);                                              \
    if ((mp_limb_signed_t) __submod1__diff < 0)                               \
      __submod1__diff += (m);                                                 \
    (r) = __submod1__diff;                                                    \
    ASSERT ((r) < (m));                                                       \
  } while (0)
#endif
#ifndef submod_dmsm
/* General case: a borrow happened exactly when the wrapped difference comes
   out larger than a; adding m back then yields the residue. */
#define submod_dmsm(r, a, b, m)                                         \
  do {                                                                  \
    mp_limb_t  __submod1__x = (a);                                      \
    mp_limb_t  __submod1__diff;                                         \
    ASSERT ((a) < (m));                                                 \
    ASSERT ((b) < (m));                                                 \
    __submod1__diff = __submod1__x - (b);                               \
    if (__submod1__diff > __submod1__x)                                 \
      __submod1__diff += (m);                                           \
    (r) = __submod1__diff;                                              \
    ASSERT ((r) < (m));                                                 \
  } while (0)
#endif

/* r = -a mod m, for a already reduced (a < m).  Zero is its own negative;
   every other a maps to m-a. */
#define negmod_nam(r, a, m)                                             \
  do {                                                                  \
    mp_limb_t  __negmod1__x = (a);                                      \
    ASSERT ((a) < (m));                                                 \
    if (LIKELY (__negmod1__x != 0))                                     \
      (r) = (m) - __negmod1__x;                                         \
    else                                                                \
      (r) = 0;                                                          \
  } while (0)

/* udiv_qrnnd, but always allowing unnormalized divisors.

   q = floor((nh*R + nl) / d) and r = (nh*R + nl) mod d, where
   R = 2^GMP_LIMB_BITS.  Requires d != 0 and nh < d so the quotient fits in
   one limb.

   When udiv_qrnnd needs a normalized divisor (high bit set), d and the
   two-limb numerator are both shifted left by the leading-zero count of d.
   That leaves the quotient unchanged and scales the remainder by 2^shift,
   which is shifted back out at the end.  When udiv_qrnnd already accepts any
   divisor it is used directly, so the shifting code goes dead.  If
   UDIV_NEEDS_NORMALIZATION is not defined at all, normalize, since that is
   always correct.  */
#if ! defined (UDIV_NEEDS_NORMALIZATION) || UDIV_NEEDS_NORMALIZATION
#define udiv_qrnnd_unnorm(q, r, nh, nl, d)                      \
  do {                                                          \
    mp_limb_t  __udu__nh, __udu__nl, __udu__d, __udu__r;        \
    unsigned   __shift;                                         \
    ASSERT ((d) != 0);                                          \
    ASSERT ((nh) < (d));                                        \
    __udu__d = (d);                                             \
    count_leading_zeros (__shift, __udu__d);                    \
    __udu__d <<= __shift;                                       \
    __udu__nh = (nh);                                           \
    __udu__nl = (nl);                                           \
    __udu__nh <<= __shift;                                      \
    if (__shift != 0)                                           \
      __udu__nh |=  (__udu__nl >> (GMP_LIMB_BITS - __shift));   \
    __udu__nl <<= __shift;                                      \
    udiv_qrnnd (q, __udu__r, __udu__nh, __udu__nl, __udu__d);   \
    (r) = __udu__r >> __shift;                                  \
  } while (0)
#else
#define udiv_qrnnd_unnorm(q, r, nh, nl, d)                      \
  do {                                                          \
    ASSERT ((d) != 0);                                          \
    ASSERT ((nh) < (d));                                        \
    udiv_qrnnd (q, r, nh, nl, d);                               \
  } while (0)
#endif

/* r = x*y mod m.  The two-limb product hi:lo comes from umul_ppmm and is
   reduced with udiv_qrnnd_unnorm, which requires hi < m; that holds when x
   and y are both below m.

   The final cross-check recomputes x*y in unsigned long long.  That type is
   only guaranteed to hold the full product when a limb is at most 32 bits.
   With 64-bit limbs the product would wrap and the check would fail
   spuriously, so it is skipped there.  */
#define umulmod_pmmd(r, x, y, m)                                        \
  do {                                                                  \
    mp_limb_t  __umm__hi, __umm__lo, __umm__dummy_q;                    \
    umul_ppmm (__umm__hi, __umm__lo, (x), (y));                         \
    udiv_qrnnd_unnorm (__umm__dummy_q, r, __umm__hi, __umm__lo, m);     \
    ASSERT ((r) < (m));                                                 \
    ASSERT (GMP_LIMB_BITS > 32                                          \
            || (r) == ((unsigned long long) (x) * (y)) % (m));          \
  } while (0)

/* Calculate Ninv satisfying N*Ninv == 1 mod R, where R = 2^GMP_LIMB_BITS.
   REDC1_REDUCE and REDC1_OUT assert exactly this relation.  Such an inverse
   exists only for odd N.  The traced value is cast to unsigned long so it
   matches the %lX format whatever integer type mp_limb_t is.  */
#define REDC1_INIT_Ninv(Ninv, N)                                        \
  do {                                                                  \
    modlimb_invert (Ninv, N);                                           \
    TRACE (printf ("   Ninv=%#lX\n", (unsigned long) (Ninv)));          \
  } while (0)



/* R mod N, where R = 2^GMP_LIMB_BITS, calculated as (R-N) % N.  R itself
   does not fit in a limb, but R-N does (it is the wrapped value -N), and it
   is congruent to R modulo N.  N is read once into a local, so an
   expression argument such as `a + b` is handled correctly.  The traced
   value is cast to match the %lX format.  */
#define REDC1_INIT_RmodN(RmodN, N)                                      \
  do {                                                                  \
    mp_limb_t  __redc1rmodn__N = (N);                                   \
    (RmodN) = (- __redc1rmodn__N) % __redc1rmodn__N;                    \
    TRACE (printf ("   RmodN=%#lX\n", (unsigned long) (RmodN)));        \
  } while (0)

/* R^2 mod N, computed as (R mod N)^2 mod N from the RmodN value given.
   REDC1_IN multiplies by this to bring values into redc form.  The traced
   value is cast to unsigned long to match the %lX format whatever integer
   type mp_limb_t is.  */
#define REDC1_INIT_R2modN(R2modN, N, RmodN)                             \
  do {                                                                  \
    umulmod_pmmd (R2modN, RmodN, RmodN, N);                             \
    TRACE (printf ("   R^2modN=%#lX\n", (unsigned long) (R2modN)));     \
  } while (0)


/* t = x*R mod N, i.e. x converted into redc form.  REDC1_MUL(a, b) yields
   a*b/R mod N, so multiplying by R^2 mod N leaves x*R mod N.  x should be
   below N, since REDC1_REDUCE asserts the high limb of the product is
   below N.  */
#define REDC1_IN(t, x, N, Ninv, R2modN)  REDC1_MUL (t, x, R2modN, N, Ninv)

/* xR = x*R mod N: x converted into redc form using only N (no R2modN).
   x is first reduced below N if necessary.  Then the two-limb value x:0,
   which is x*R, is divided by N and only the remainder is kept.  */
#define REDC1_IN_SINGLE(xR, x, N)                                       \
  do {                                                                  \
    mp_limb_t  __redc1in__val = (x);                                    \
    mp_limb_t  __redc1in__mod = (N);                                    \
    mp_limb_t  __redc1in__unused_q;                                     \
    if (UNLIKELY (__redc1in__val >= __redc1in__mod))                    \
      __redc1in__val %= __redc1in__mod;                                 \
    udiv_qrnnd_unnorm (__redc1in__unused_q, xR,                         \
                       __redc1in__val, CNST_LIMB(0), __redc1in__mod);   \
  } while (0)

/* t = 1 in redc form: 1*R mod N is exactly RmodN.  */
#define REDC1_IN_1(t, RmodN)    do { (t) = (RmodN); } while (0)

/* t = -1 (that is, N-1) in redc form: -R mod N, which is N - RmodN.
   RmodN is nonzero for any odd N > 1, so the result stays below N.  */
#define REDC1_IN_NEG1(t, N, RmodN)                      \
  do {                                                  \
    ASSERT ((RmodN) != 0);                              \
    (t) = (N) - (RmodN);                                \
  } while (0)


/* REDC1_REDUCE sets t = T/R mod N for the two-limb value T = Th:Tl, where
   R = 2^GMP_LIMB_BITS.  It needs Ninv with N*Ninv == 1 mod R (a positive
   inverse, not the -1/N of textbook REDC), and Th < N.

   m = Tl*Ninv mod R gives m*N == Tl mod R, so the low limb of m*N equals Tl
   (checked by the ASSERT).  Hence T - m*N is an exact multiple of R, and
   (T - m*N)/R = Th - hi(m*N).  Both terms are below N, so the difference
   lies in (-N, N).  submod_dmsm adds N back when it is negative, which
   gives t < N.  */
#define REDC1_REDUCE(t, Th, Tl, N, Ninv)                        \
  do {                                                          \
    mp_limb_t  __redc1__hi, __redc1__lo;                        \
    ASSERT ((N) * (Ninv) == 1);                                 \
    ASSERT ((Th) < (N));                                        \
    umul_ppmm (__redc1__hi,__redc1__lo, (Tl) * (Ninv), N);      \
    ASSERT (__redc1__lo == (Tl));                               \
    submod_dmsm (t, Th, __redc1__hi, N);                        \
    ASSERT ((t) < (N));                                         \
  } while (0)

/* z = x*y/R mod N.  With x and y both in redc form (xR, yR) this is
   (x*y)R mod N, the redc form of the product.  x and y should be below N so
   the high limb of the product stays below N, as REDC1_REDUCE asserts.  */
#define REDC1_MUL(z, x, y, N, Ninv)                                     \
  do {                                                                  \
    mp_limb_t  __redc1mul__hi, __redc1mul__lo;                          \
    umul_ppmm (__redc1mul__hi, __redc1mul__lo, x, y);                   \
    REDC1_REDUCE (z, __redc1mul__hi, __redc1mul__lo, N, Ninv);          \
  } while (0)

/* z = x+y mod N, all in redc form.  Redc form is linear (xR + yR == (x+y)R),
   so plain modular addition works directly on the residues.  */
#define REDC1_ADD(z, x, y, N)   addmod_saam (z, x, y, N)

/* z = x-y mod N, all in redc form.  Redc form is linear (xR - yR == (x-y)R),
   so plain modular subtraction works directly on the residues.  */
#define REDC1_SUB(z, x, y, N)   submod_dmsm (z, x, y, N)

/* z = -x mod N, all in redc form (negation commutes with the factor R, so
   negating the residue negates the value).  The arguments are (z, x, N),
   matching negmod_nam, which takes exactly three.  */
#define REDC1_NEG(z, x, N)      \
  do {                          \
    negmod_nam (z, x, N);       \
  } while (0)


/* x = xR/R mod N, taking a value back out of redc form.  This is
   REDC1_REDUCE with a zero high limb: m = xR*Ninv makes the low limb of m*N
   equal xR, and the result is 0 - hi(m*N) mod N, which is just negmod_nam
   of the high limb.  */
#define REDC1_OUT(x, xR, N, Ninv)                                       \
  do {                                                                  \
    mp_limb_t  __redc1out__mhi, __redc1out__mlo;                        \
    ASSERT ((N) * (Ninv) == 1);                                         \
    ASSERT ((xR) < (N));                                                \
    umul_ppmm (__redc1out__mhi, __redc1out__mlo, (xR) * (Ninv), N);     \
    ASSERT (__redc1out__mlo == (xR));                                   \
    negmod_nam (x, __redc1out__mhi, N);                                 \
  } while (0)

--=-=-=--