[Paul Zimmermann <Paul.Zimmermann@loria.fr>] mpz_cbrtrem

Torbjorn Granlund tege@swox.com
04 Nov 2002 10:39:36 +0100


This is of interest to gmp-devel I think.

From: Paul Zimmermann <Paul.Zimmermann@loria.fr>
Subject: mpz_cbrtrem

Here is a first implementation of mpz_cbrtrem. It is asymptotically
optimal, and for large operands it is even faster than mpz_sqrtrem:

[zimmerma@ecrouves ~/gmp]$ ./cbrtrem 50000 1
mpz_root took 29670ms
mpz_sqrtrem took 2270ms
mpz_cbrtrem took 1950ms

while being reasonably fast for one-limb operands:

[zimmerma@ecrouves ~/gmp]$ ./cbrtrem 1 100000
mpz_root took 1830ms
mpz_sqrtrem took 110ms
mpz_cbrtrem took 2400ms

The key function is mpz_add_bits: set the low bits of the destination
to bits n0 to n1-1 from the operand.

Paul

/* mpz_cbrtrem -- cube root and remainder

Copyright 2002 Free Software Foundation, Inc.
Contributed by Paul Zimmermann.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 2.1 of the License, or (at your
option) any later version.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
License for more details.

You should have received a copy of the GNU Lesser General Public License
along with the GNU MP Library; see the file COPYING.LIB.  If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
MA 02111-1307, USA.
*/

#include <stdio.h>
#include <stdlib.h>
#include "cputime.h"
#include "gmp.h"
#include "gmp-impl.h"

/* Merge bits [n0,n1[ of n into the low n1-n0 bits of r.

   The low n1-n0 bits of r are expected to be zero on entry (the callers in
   this file arrange that by shifting r left, or by setting it to 0): every
   limb below the top one touched is overwritten, and only the top touched
   limb is OR-ed with its previous contents.  The limbs of n are read
   directly, so the sign of n is ignored; r is treated as non-negative.
   r and n must be distinct objects.  An empty range (n1 <= n0) leaves r
   unchanged.  */
void
mpz_add_bits (mpz_ptr r, mpz_srcptr n, mp_size_t n0, mp_size_t n1)
{
  mp_size_t a0, a1, ln, lr;
  unsigned int sh;
  mp_limb_t rh, rh2 = 0, *rp, mask;

  if (n1 <= n0)
    return;

  a0 = n0 / BITS_PER_MP_LIMB; /* limb where bit n0 is */
  a1 = (n1 - 1) / BITS_PER_MP_LIMB; /* limb where bit n1-1 is */
  ln = a1 + 1 - a0; /* limbs of n spanned by the range (lr or lr+1) */
  lr = 1 + (n1 - n0 - 1) / BITS_PER_MP_LIMB; /* limbs holding n1-n0 bits */
  sh = n0 % BITS_PER_MP_LIMB;

  if (ALLOC(r) < ln)
    _mpz_realloc (r, ln);
  /* Fetch the limb pointer only now: _mpz_realloc may move the limbs, and
     a pointer taken before the call would be stale.  */
  rp = PTR(r);

  if (SIZ(r) < lr)
    {
      SIZ(r) = lr;
      rh = 0;
    }
  else
    rh = rp[lr-1]; /* high part of r sharing the top touched limb */

  /* When the source range straddles one limb more than the result needs,
     the shift/copy below clobbers rp[lr]; save it to restore afterwards.  */
  if (ln > lr)
    rh2 = rp[lr];

  if (sh)
    mpn_rshift (rp, PTR(n) + a0, ln, sh);
  else
    MPN_COPY (rp, PTR(n) + a0, ln);

  /* Drop the bits of n at positions >= n1 that came along in the top limb. */
  sh = (n1 - n0) % BITS_PER_MP_LIMB;
  if (sh)
    {
      mask = ((mp_limb_t) 1 << sh) - 1;
      rp[lr-1] &= mask;
    }

  if (ln > lr)
    rp[lr] = rh2;
  rp[lr-1] |= rh;

  lr = SIZ(r);
  MPN_NORMALIZE (rp, lr);
  SIZ(r) = lr;
}

/* Cube root with remainder of a bit window of n.

   Let N be the integer formed by bits [n0,n1[ of n.  Sets s and r such
   that N = s^3 + r, where k is the number of bits expected in s.  q and t
   are caller-supplied scratch variables; their contents are destroyed.
   s and r must not alias n, since n is read after s and r are written.

   For k > 1 the root of the top part of the window (everything above the
   low 3*k2 bits, k2 = k/2) is computed recursively, extended by k2 bits
   with one division by 3*s^2, and then corrected downwards until r >= 0.
   There is no upward correction step: the code relies on the quotient
   never being too small.  */
void
mpz_cbrtrem_aux (mpz_ptr s, mp_size_t k, mpz_ptr r, mpz_srcptr n,
                 mp_size_t n0, mp_size_t n1, mpz_ptr q, mpz_ptr t)
{
  mp_size_t k2;

  if (k == 1)
    {
      /* Base case: the root is taken to be 1, so r = N - 1.  */
      mpz_set_ui (r, 0);
      mpz_add_bits (r, n, n0, n1);
      mpz_set_ui (s, 1);
      mpz_sub_ui (r, r, 1);
      return;
    }

  k2 = k / 2;

  /* Root s and remainder r of the window with its low 3*k2 bits removed. */
  mpz_cbrtrem_aux (s, k - k2, r, n, n0 + 3 * k2, n1, q, t);

  /* Append the next k2 bits of n to r and divide by 3*s^2.  */
  mpz_mul_2exp (r, r, k2);
  mpz_add_bits (r, n, n0 + 2 * k2, n0 + 3 * k2);
  mpz_mul (t, s, s);
  mpz_mul_ui (t, t, 3);
  mpz_tdiv_qr (q, r, r, t);

  /* Append the remaining low 2*k2 bits of the window to r.  */
  mpz_mul_2exp (r, r, 2 * k2);
  mpz_add_bits (r, n, n0, n0 + 2 * k2);

  /* With S = s*2^k2, the new root is S + q and the remainder loses
     q^2*(3*S + q), the terms of (S+q)^3 not removed by the division.  */
  mpz_mul_2exp (s, s, k2);
  mpz_mul_ui (t, s, 3); /* 3*S */
  mpz_add (s, s, q);    /* S+q */
  mpz_add (t, t, q);    /* 3*S+q */
  mpz_mul (q, q, q);
  mpz_mul (t, t, q);    /* q^2*(3*S + q) */
  mpz_sub (r, r, t);

  /* q may be too large: step s down by one while r is negative, using
     s^3 - (s-1)^3 = 3*s*(s-1) + 1.  */
  while (mpz_cmp_ui (r, 0) < 0)
    {
      mpz_sub_ui (t, s, 1);
      mpz_mul_ui (t, t, 3);
      mpz_mul (t, t, s);
      mpz_add (r, r, t);
      mpz_add_ui (r, r, 1);
      mpz_sub_ui (s, s, 1);
    }
}

/* Sets s and r such that s^3 + r = n with r >= 0 and n < (s+1)^3.

   n may be the same variable as s or r; in that case the computation
   runs on a private copy of n, because the outputs are written while
   the bits of n are still being read.  s and r must be distinct.

   NOTE(review): only n == 0 and n > 0 are handled as such.  The bit
   length and the limbs of n are used without looking at its sign, so a
   negative n appears to be processed as |n| -- confirm before relying
   on it.  */
void
mpz_cbrtrem (mpz_ptr s, mpz_ptr r, mpz_srcptr n)
{
  mp_size_t nbits; /* number of bits in n */
  mp_size_t sbits; /* number of bits in s */
  mpz_t q, t, ncopy;
  int aliased;

  if (mpz_cmp_ui (n, 0) == 0)
    {
      mpz_set_ui (s, 0);
      mpz_set_ui (r, 0);
      return;
    }

  /* Protect the input when it shares storage with one of the outputs.  */
  aliased = (s == n || r == n);
  if (aliased)
    {
      mpz_init_set (ncopy, n);
      n = ncopy;
    }

  nbits = mpz_sizeinbase (n, 2); /* 2^(nbits-1) <= |n| < 2^nbits */
  sbits = (nbits + 2) / 3;       /* ceil(nbits/3) */

  /* Scratch space shared by all recursion levels.  */
  mpz_init (q);
  mpz_init (t);

  mpz_cbrtrem_aux (s, sbits, r, n, 0, nbits, q, t);

  mpz_clear (q);
  mpz_clear (t);
  if (aliased)
    mpz_clear (ncopy);

  return;
}

/* Timing and sanity-check driver.

   Usage: cbrtrem LIMBS ITERATIONS

   Generates an operand n with mpz_random (n, LIMBS), times ITERATIONS
   calls each of mpz_root (cube root), mpz_sqrtrem and mpz_cbrtrem on it,
   and then checks the last mpz_cbrtrem result: n = s^3 + r and
   0 <= r <= 3*s^2 + 3*s.  Exits with status 1 on a usage error or a
   failed check.  */
int
main(int argc, char *argv[])
{
  int i, l, k, st;
  mpz_t n, s, r, t;

  if (argc < 3)
    {
      fprintf (stderr, "Usage: %s limbs iterations\n",
               argc > 0 ? argv[0] : "cbrtrem");
      exit (1);
    }

  l = atoi(argv[1]);
  k = atoi(argv[2]);

  mpz_init (n);
  mpz_init (s);
  mpz_init (r);
  mpz_init (t);

  mpz_random (n, l);

  st = cputime ();
  for (i=0; i<k; i++)
    mpz_root (s, n, 3);
  printf ("mpz_root took %dms\n", cputime () - st);

  st = cputime ();
  for (i=0; i<k; i++)
    mpz_sqrtrem (s, r, n);
  printf ("mpz_sqrtrem took %dms\n", cputime () - st);

  st = cputime ();
  for (i=0; i<k; i++)
    mpz_cbrtrem (s, r, n);
  printf ("mpz_cbrtrem took %dms\n", cputime () - st);

  /* s and r hold a cube root result only if mpz_cbrtrem ran at least
     once; otherwise there is nothing to verify.  */
  if (k > 0)
    {
      /* n must equal s^3 + r */
      mpz_mul (t, s, s);
      mpz_mul (t, t, s);
      mpz_add (t, t, r);
      if (mpz_cmp (t, n) != 0)
        {
          fprintf (stderr, "n <> s^3+r\n");
          fprintf (stderr, "n=");
          mpz_out_str (stderr, 10, n);
          fprintf (stderr, "\ns=");
          mpz_out_str (stderr, 10, s);
          fprintf (stderr, "\nr=");
          mpz_out_str (stderr, 10, r);
          fprintf (stderr, "\n");
          exit (1);
        }

      if (mpz_cmp_ui (r, 0) < 0)
        {
          fprintf (stderr, "r < 0 for n=");
          mpz_out_str (stderr, 10, n);
          fprintf (stderr, "\n");
          exit (1);
        }

      /* r <= (s+1)^3 - 1 - s^3 = 3*s*(s+1) */
      mpz_add_ui (t, s, 1);
      mpz_mul_ui (t, t, 3);
      mpz_mul (t, t, s);
      if (mpz_cmp (r, t) > 0)
        {
          fprintf (stderr, "r > 3*s^2+3*s for n=");
          mpz_out_str (stderr, 10, n);
          fprintf (stderr, "\n");
          exit (1);
        }
    }

  mpz_clear (n);
  mpz_clear (s);
  mpz_clear (r);
  mpz_clear (t);

  return 0;
}

-- 
Torbjörn