Small gcdext_1

Niels Möller nisse at lysator.liu.se
Tue Oct 8 21:02:33 UTC 2019


For small gcdext (one or two limbs), it seems likely that a branch-free
binary algorithm is a good choice. The binary algorithm in
mpn/generic/gcdext_1.c seems to work now, after a recent bugfix, but
it's not that fast. One reason is that the loop updates both cofactors,
which costs both instructions and registers.

If we do a gcdext computing only one cofactor, I'd like to think about
it as a modular inversion function. For u, v, with v odd, attempt to compute
u^{-1} (mod v). More precisely, always compute gcd(u,v). In addition,
compute a cofactor s according to these rules:

If u = 0 (mod v), equivalently, gcd (u,v) = v, set s = 0

If gcd (u,v) = 1, set s = u^{-1} (mod v)

If gcd (u,v) > 1, set s = (u/g)^{-1} (mod v/g)

Below are three implementations of such a function: one based on
Euclid's algorithm, one plain binary, and one binary using the same
masking tricks as in gcdext_1.c.

Regards,
/Niels

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

#include "gmp.h"
#include "gmp-impl.h"
#include "longlong.h"

/* Reference implementation based on Euclid's algorithm with full
   quotients.  input_v must be odd (asserted).

   Returns d0 = g = gcd (u, input_v).
   If u = 0 (mod input_v), returns d1 = 0.
   Otherwise, returns d1 = (u/g)^-1 mod (input_v/g).
*/
static mp_double_limb_t
modinvert_euclid(mp_limb_t u, mp_limb_t input_v)
{
  mp_double_limb_t res;
  mp_limb_t cof_u = 1;	/* multiplies the running v in the invariant */
  mp_limb_t cof_v = 0;	/* multiplies the running u in the invariant */
  mp_limb_t v = input_v;

  assert (v & 1);

  u %= v;

  /* With U = u mod input_v and V = input_v the original inputs, every
     step below preserves

       V = cof_v * u + cof_u * v

     Replacing v by v - q*u moves q*cof_u into cof_v, and replacing u by
     u - q*v moves q*cof_v into cof_u.  */
  while (u != 0)
    {
      mp_limb_t quot;

      /* Reduce v by u.  */
      quot = v / u;
      v -= quot * u;
      cof_v += quot * cof_u;

      if (v == 0)
        {
          /* u holds the gcd and cof_u is the cofactor, used as is.  */
          res.d0 = u;
          res.d1 = cof_u;
          return res;
        }

      /* Reduce u by v.  */
      quot = u / v;
      u -= quot * v;
      cof_u += quot * cof_v;
    }

  /* u reached zero, so v holds the gcd.  A nonzero cof_v carries the
     opposite sign; replace it by (input_v / g) - cof_v.  */
  res.d0 = v;
  res.d1 = (cof_v > 0) ? input_v / v - cof_v : cof_v;
  return res;
}

/* Plain binary (subtract-and-shift) variant.  input_v must be odd
   (asserted).

   Returns d0 = g = gcd (u, input_v).  d1 is 0 when u == 0 or when
   g == input_v; otherwise it is the cofactor accumulated in the loop,
   with the accumulated power of two divided out modulo input_v / g.  */
static mp_double_limb_t
modinvert_binary (mp_limb_t u, mp_limb_t input_v)
{
  mp_double_limb_t res;
  mp_limb_t cof_v, cof_u, v, modulus;
  int twos;

  v = input_v;
  assert (v & 1);

  if (u == 0)
    {
      res.d0 = v;
      res.d1 = 0;
      return res;
    }

  /* Make u odd, remembering how many factors of two were removed.  */
  count_trailing_zeros (twos, u);
  u >>= twos;

  cof_u = 0;	/* multiplies the running u */
  cof_v = 1;	/* multiplies the running v */

  /* Each step preserves

       input_v = cof_u * u + cof_v * v

     while twos accumulates the total number of bits shifted out.  */
  while (u != v)
    {
      int tz;

      if (u > v)
        {
          u -= v;
          count_trailing_zeros (tz, u);
          u >>= tz;
          cof_v += cof_u;
          cof_u <<= tz;
        }
      else
        {
          v -= u;
          count_trailing_zeros (tz, v);
          v >>= tz;
          cof_u += cof_v;
          cof_v <<= tz;
        }
      twos += tz;
    }

  /* Now v == g, and 2^{twos} g = cof_v U (mod input_v).  */
  assert (input_v % v == 0);
  modulus = input_v / v;

  if (modulus == 1)
    {
      /* u == 0 (mod v) */
      cof_v = 0;
    }
  else
    {
      /* Divide by 2 modulo the odd modulus, once per shifted-out bit.
         For odd cof_v, cof_v/2 + modulus/2 + 1 equals
         (cof_v + modulus) / 2 without forming the sum.  */
      while (twos > 0)
        {
          cof_v = (cof_v & 1)
            ? cof_v / 2 + (modulus / 2) + 1
            : cof_v / 2;
          twos--;
        }
    }

  res.d0 = v;
  res.d1 = cof_v;
  return res;
}

/* Binary variant in which the comparison and swap inside the main loop
   are done with masks instead of an if/else.  input_v must be odd
   (asserted).

   Returns d0 = g = gcd (u, input_v).  d1 is 0 when u == 0 or when
   g == input_v; otherwise it is the cofactor s0 with the factor
   2^{shift} divided out modulo input_v / g, followed by the sign
   correction tracked in `sign`.  */
static mp_double_limb_t
modinvert_binary_mask (mp_limb_t u, mp_limb_t input_v)
{
  mp_limb_t s0, s1, v, v_div_g, sign;
  mp_double_limb_t r;
  int shift;

  v = input_v;

  assert (v & 1);
  if (u == 0)
    {
      mp_double_limb_t r;
      r.d0 = v;
      r.d1 = 0;
      return r;
    }
  /* Make u odd; shift counts the factors of two removed so far.  */
  count_trailing_zeros (shift, u);
  u >>= shift;
  s1 = 0;
  s0 = 1;

  /* Maintain

     U = t1 u + t0 v
     V = s1 u + s0 v = [2^k s1 (u-v)/ 2^k + (s0 + s1) v]

     where U, V are the inputs and the matrix has determinant 2^{shift}.
  */
  /* Both u and v are odd here, so the low bit is dropped and kept
     implicit; v is rebuilt as (v<<1) | 1 after the loop.  */
  u >>= 1;
  v >>= 1;
  sign = 0;

  while (u != v)
    {
      int count;
      /* Difference of the halved values, i.e. half the true difference.  */
      mp_limb_t d =  u - v;
      /* NOTE(review): presumably all ones when the high bit of d is set
	 (u < v) and zero otherwise -- confirm against the definition of
	 LIMB_HIGHBIT_TO_MASK in gmp-impl.h.  */
      mp_limb_t vgtu = LIMB_HIGHBIT_TO_MASK (d);
      mp_limb_t sx;

      /* When v < u (vgtu == 0), the updates are:

	   (u; v)   <-- ( (u - v) >> count; v)    (det = +(1<<count) for corr. M factor)
	   (s1, s0) <-- (s1 << count, s0 + s1)

	 and when v > u, the updates are

	   (u; v)   <-- ( (v - u) >> count; u)    (det = -(1<<count))
	   (s1, s0) <-- (s0 << count, s0 + s1)
      */

      /* v <-- min (u, v) */
      v += (vgtu & d);

      /* u <-- |u - v| */
      u = (d ^ vgtu) - vgtu;

      count_trailing_zeros (count, d);

      /* Toggle the sign each time the roles of u and v are exchanged.  */
      sign ^= vgtu;

      /* sx is s0 - s1 on a swap and 0 otherwise, so that s1 becomes the
	 old s0 on a swap and is left alone otherwise.  */
      sx = vgtu & (s0 - s1);
      s0 += s1;
      s1 += sx;

      /* One extra bit: d is half the true difference, and u is stored
	 without its low bit.  */
      count++;
      u >>= count;
      s1 <<= count;
      shift += count;
    }
  /* Restore the implicit low bit; v is now the gcd.  */
  v = (v<<1) | 1;
  /* 2^{shift} g = s0 U */
  assert (input_v % v == 0);
  v_div_g = input_v / v;

  if (v_div_g == 1)
    /* u == 0 (mod v) */
    s0 = 0;
  else
    {
      /* FIXME: Use binvert instead of looping. */
      /* Divide s0 by 2 modulo the odd v_div_g, once per shifted-out bit.
	 For odd s0, s0/2 + v_div_g/2 + 1 equals (s0 + v_div_g) / 2
	 without forming the sum.  */
      for (; shift > 0; shift--)
	{
	  if (s0 & 1)
	    s0 = s0 / 2 + (v_div_g / 2) + 1;
	  else
	    s0 /= 2;
	}
      /* When sign is all ones this evaluates to v_div_g - s0; when sign
	 is zero it leaves s0 unchanged.  */
      s0 = (sign & (v_div_g + 1)) + (sign ^ s0);
    }

  r.d0 = v;
  r.d1 = s0;

  return r;
}


/* Check a candidate result (g, s) for the inputs (u, v).

   Returns nonzero if the pair is acceptable, zero otherwise:
     - if u = 0 (mod v), g must equal v and s must be 0;
     - otherwise g must be nonzero and divide both u mod v and v, s must
       be less than v/g, and s * (u/g) must be 1 modulo v/g.

   v must be nonzero; the callers in this file always pass an odd v.  */
static int
modinvert_validate (mp_limb_t u, mp_limb_t v, mp_limb_t g, mp_limb_t s)
{
  mp_limb_t p1, p0, q, r;
  u %= v;
  if (u == 0)
    return g == v && s == 0;

  /* g comes from the function under test.  Zero is never a valid gcd,
     and the divisions below would divide by it, so reject it here.  */
  if (g == 0)
    return 0;
  if (u % g || v % g)
    return 0;

  u /= g;
  v /= g;
  if (s >= v)
    return 0;

  /* Double-limb product s * u, reduced modulo v.  Since s < v, the high
     limb p1 is less than v, as udiv_qrnnd requires.
     NOTE(review): udiv_qrnnd may also need a normalized divisor on
     targets where UDIV_NEEDS_NORMALIZATION is set -- confirm for the
     platforms this test is run on.  */
  umul_ppmm (p1, p0, s, u);
  udiv_qrnnd (q, r, p1, p0, v);
  return r == 1;
}

/* Number of random (u, v) pairs to test.  */
#define COUNT 1000000

/* Randomized self-test.  For each of COUNT random pairs (u, v), with v
   forced odd, check modinvert_euclid with modinvert_validate, then
   require modinvert_binary and modinvert_binary_mask to return exactly
   the same (g, s) as modinvert_euclid.  On the first mismatch, print
   the inputs and the offending result and exit with EXIT_FAILURE.  */
int
main (int argc, char **argv)
{
  gmp_randstate_t rands;
  int i;

  /* Command-line arguments are not used.  */
  (void) argc;
  (void) argv;

  gmp_randinit_default (rands);
  for (i = 0; i < COUNT; i++)
    {
      /* Random bit sizes in 1 .. GMP_NUMB_BITS so that small and large
	 operands are both exercised.
	 NOTE(review): gmp_urandomb_ui is documented to accept at most
	 as many bits as an unsigned long holds -- confirm that
	 GMP_NUMB_BITS does not exceed that on the target.  */
      unsigned u_bits = 1 + gmp_urandomm_ui (rands, GMP_NUMB_BITS);
      unsigned v_bits = 1 + gmp_urandomm_ui (rands, GMP_NUMB_BITS);
      mp_limb_t u = gmp_urandomb_ui(rands, u_bits);
      mp_limb_t v = gmp_urandomb_ui(rands, v_bits) | 1;
      mp_double_limb_t ref, r;
      ref = modinvert_euclid (u, v);
      if (!modinvert_validate (u, v, ref.d0, ref.d1))
	{
	  gmp_printf ("modinvert_euclid failed: u = %Mx v = %Mx\n"
		      "  g = %Mx, s = %Mx\n", u, v, ref.d0, ref.d1);
	  exit(EXIT_FAILURE);
	}
      r = modinvert_binary (u,v);
      if (r.d0 != ref.d0 || r.d1 != ref.d1)
	{
	  gmp_printf ("modinvert_binary failed: u = %Mx v = %Mx\n"
		      "  g = %Mx, s = %Mx\n", u, v, r.d0, r.d1);
	  exit(EXIT_FAILURE);
	}
      r = modinvert_binary_mask (u,v);
      if (r.d0 != ref.d0 || r.d1 != ref.d1)
	{
	  gmp_printf ("modinvert_binary_mask failed: u = %Mx v = %Mx\n"
		      "  g = %Mx, s = %Mx\n", u, v, r.d0, r.d1);
	  exit(EXIT_FAILURE);
	}
    }

  /* Release the random state initialized above.  */
  gmp_randclear (rands);
  return EXIT_SUCCESS;
}

-- 
Niels Möller. PGP-encrypted email is preferred. Keyid 368C6677.
Internet email is subject to wholesale government surveillance.



More information about the gmp-devel mailing list