/* Copyright 2006, 2007 Free Software Foundation, Inc.

This program is free software; you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation; either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
this program.  If not, see http://www.gnu.org/licenses/.  */


#include <stdlib.h>		/* for exit, strtoul */
#include <string.h>		/* for strlen */
#include <stdio.h>		/* for printf */
#include <math.h>		/* for log10, fmod, pow */
#include <unistd.h>		/* for isatty */

#include "gmp.h"
#include "gmp-impl.h"
#include "longlong.h"
#include "new.h"

/* Parse a size argument from the command line and return it in bits.

   ARG is parsed by strtoul with base 0, so decimal, octal (leading 0) and
   hexadecimal (leading 0x) are accepted.  If the last character of ARG is
   'b' the number is taken to be a bit count already; otherwise it is a limb
   count and is scaled by GMP_LIMB_BITS.

   An empty ARG yields 0 (the original code indexed arg[-1] in that case).
   Note that a hexadecimal argument ending in the digit 'b' is also treated
   as a bit count, exactly as before.  */
static mp_size_t
sizearg (char *arg)
{
  size_t len = strlen (arg);
  unsigned long value = strtoul (arg, 0, 0);

  /* Only look at the last character when there is one.  */
  if (len > 0 && arg[len - 1] == 'b')
    return value;

  return value * GMP_LIMB_BITS;
}


#ifdef CHECK
void
dumpy (mp_srcptr p, mp_size_t n)
{
  mp_size_t i;
  if (n > 20)
    {
      for (i = n - 1; i >= n - 4; i--)
	{
	  printf ("%0*lx", (int) (2 * sizeof (mp_limb_t)), p[i]);
	  printf (" ");
	}
      printf ("... ");
      for (i = 3; i >= 0; i--)
	{
	  printf ("%0*lx", (int) (2 * sizeof (mp_limb_t)), p[i]);
	  printf (" " + (i == 0));
	}
    }
  else
    {
      for (i = n - 1; i >= 0; i--)
	{
	  printf ("%0*lx", (int) (2 * sizeof (mp_limb_t)), p[i]);
	  printf (" " + (i == 0));
	}
    }
  puts ("");
}

/* Test case counter shown in progress and error output as "testh,,test".
   main increments `test' once per executed test case and bumps `testh'
   each time `test' wraps around to zero.  */
unsigned long test, testh;
/* Number of inconsistencies reported so far; check_one aborts the program
   when this reaches 5.  */
int err;

/* Verify one division result and print a diagnostic if it is inconsistent.

   QP        quotient, nn - dn limbs (qn).
   RP        remainder, dn limbs, or NULL for the quotient-only routines.
   RH        value returned by the routine under test; compared against the
	     carry out of the reconstruction when RP is non-NULL.
   NP, NN    the original dividend and its limb count.
   DP, DN    the divisor and its limb count.
   FNAME     name of the routine under test, used in the diagnostic.

   The product Q*D is formed in a temporary.  With a remainder, RP is added
   to the product at limb offset qn and the result must equal all NN limbs
   of N with a carry equal to RH.  Without a remainder only the low nn - dn
   limbs of the product are compared with N.

   Nothing is checked when qn is zero.  After the fifth reported
   inconsistency the program is aborted.  */
void
check_one (mp_ptr qp, mp_srcptr rp, mp_limb_t rh,
	   mp_srcptr np, mp_size_t nn, mp_srcptr dp, mp_size_t dn, char *fname)
{
  mp_size_t qn;
  int cmp;
  mp_ptr tp;
  mp_limb_t cy = 4711;		/* silence warnings */
  TMP_DECL;

  qn = nn - dn;

  if (qn == 0)
    return;

  TMP_MARK;

  tp = TMP_ALLOC_LIMBS (nn + 1);

  /* Put the longer operand first, as mpn_mul requires.  */
  if (dn >= qn)
    mpn_mul (tp, dp, dn, qp, qn);
  else
    mpn_mul (tp, qp, qn, dp, dn);

  if (rp != NULL)
    {
      /* Q*D + R*B^qn must reproduce N, with carry-out equal to RH.  */
      cy = mpn_add_n (tp + qn, tp + qn, rp, dn);
      cmp = cy != rh || mpn_cmp (tp, np, nn) != 0;
    }
  else
    {
      /* Quotient only: just the low qn limbs of Q*D are determined.  */
      cmp = mpn_cmp (tp, np, nn - dn) != 0;
    }

  if (cmp != 0)
    {
      printf ("\r*******************************************************************************\n");
      printf ("%s inconsistent in test %lu,,%lu\n",
	      fname, testh, test);
      printf ("N=   "); dumpy (np, nn);
      printf ("D=   "); dumpy (dp, dn);
      printf ("Q=   "); dumpy (qp, qn);
      if (rp != NULL)
	{
	  printf ("R=   "); dumpy (rp, dn);
	  /* Rb is the routine's return value, Cy the carry computed above
	     (the original printed them under each other's label).  */
	  printf ("Rb=  %d, Cy=%d\n", (int) rh, (int) cy);
	}
      printf ("T=   "); dumpy (tp, nn);
      /* mp_size_t need not be long; cast to match the %ld conversions.  */
      printf ("nn = %ld, dn = %ld, qn = %ld", (long) nn, (long) dn, (long) qn);
      printf ("\n*******************************************************************************\n");
      if (++err >= 5)
	abort ();
    }

  TMP_FREE;
}

#include <getopt.h>

/* Associates a routine name, as accepted by the --test option, with the
   flag variable that enables testing of that routine.  */
struct nameflag {char *name; int *flagp;};

/* One enable flag per routine.  main sets all flags reachable through the
   table below to 1 by default, and clears them again when the first --test
   option is seen so that only the named routines are run.  */
int flag_mpn_sb_bdiv_qr, flag_mpn_sb_bdiv_q;
int flag_mpn_dc_bdiv_qr, flag_mpn_dc_bdiv_q;
int flag_mpn_dc_bdiv_qr_n, flag_mpn_dc_bdiv_q_n;
int flag_mpn_mu_bdiv_qr, flag_mpn_mu_bdiv_q;

/* Table of selectable routines, terminated by a {NULL, NULL} entry.
   flag_mpn_mu_bdiv_qr has no entry, so it is never set by main and keeps
   its zero initial value.  */
static struct nameflag foo[] = {
  {"mpn_sb_bdiv_qr",     &flag_mpn_sb_bdiv_qr},
  {"mpn_sb_bdiv_q",      &flag_mpn_sb_bdiv_q},
  {"mpn_dc_bdiv_qr",     &flag_mpn_dc_bdiv_qr},
  {"mpn_dc_bdiv_q",      &flag_mpn_dc_bdiv_q},
  {"mpn_dc_bdiv_qr_n",   &flag_mpn_dc_bdiv_qr_n},
  {"mpn_dc_bdiv_q_n",    &flag_mpn_dc_bdiv_q_n},
//  {"mpn_mu_bdiv_qr",     &flag_mpn_mu_bdiv_qr},
  {"mpn_mu_bdiv_q",      &flag_mpn_mu_bdiv_q},
  {NULL,  NULL}
};

/* Set to 1 by getopt_long when a --test option has just been parsed; main
   resets it to 0 before every getopt_long call.  */
int testname;

/* Long options understood by the checking program: "test" takes a required
   argument (a routine name from the table above) and is reported by storing
   1 into `testname'.  The all-zero entry terminates the list.  */
static struct option longopts[] = {
  {"test",   required_argument,    &testname, 1},
  {NULL,     0,                    NULL,      0}
};

/* Random operand generator used by the checking program: mpz_urandomb when
   RAND_UNIFORM is non-zero at compile time, mpz_rrandomb otherwise.  */
#if RAND_UNIFORM
#define MPZ_XRANDOMB mpz_urandomb
#else
#define MPZ_XRANDOMB mpz_rrandomb
#endif

/* Checking driver (built when CHECK is defined).

   Usage:  prog [--test NAME]... nsize dsize
      or:  prog [--test NAME]... size        (nsize = 2 * size)

   Sizes are parsed by sizearg.  Each --test NAME restricts the run to the
   named routine(s) from the `foo' table; without it every routine in the
   table is exercised.

   The main loop never terminates on its own: it keeps generating random
   operands, runs every enabled routine on them and hands the results to
   check_one, which aborts after five inconsistencies.  A progress counter
   is printed now and then when stdout is a terminal.

   Compile-time knobs visible here: FIXED_SIZE (always use the maximum
   sizes), LEAD_D_MAX (force the top divisor limb to GMP_NUMB_MAX) and
   DEBUG (dump the operands of every test).  */
int
main (int argc, char **argv)
{
  gmp_randstate_t rs;
  unsigned long maxnbits, maxdbits, dbits, qbits, rbits;
  mpz_t n, d, q, r;
  mp_size_t maxnn, maxdn, nn, dn;
  mp_ptr np, dp, qp, rp;
  int work = 0;
  mp_limb_t di;
  mp_limb_t rh;
  int specific_test_opt_seen = 0;
  int i;
  char *progname = argv[0];
  mp_size_t itch;
  mp_ptr scratch;
  mp_limb_t ran;
  TMP_DECL;

  /* Default to testing all functions.  */
  for (i = 0; foo[i].name != 0; i++)
    *(foo[i].flagp) = 1;

  for (;;)
    {
      int ch;
      testname = 0;
      ch = getopt_long (argc, argv, "", longopts, NULL);
      if (ch == -1)
	break;
      if (testname)
	{
	  /* The first --test option switches from "everything" to
	     "only what is named".  */
	  if (! specific_test_opt_seen)
	    for (i = 0; foo[i].name != 0; i++)
	      *(foo[i].flagp) = 0;

	  specific_test_opt_seen = 1;

	  for (i = 0; foo[i].name != 0; i++)
	    {
	      if (strcmp (foo[i].name, optarg) == 0)
		{
		  *(foo[i].flagp) = 1;
		  break;
		}
	    }
	  /* Reached the terminator: no table entry matched.  */
	  if (foo[i].name == 0)
	    {
	      printf ("%s: no function matching argument %s to -test option\n",
		      progname, optarg);
	      exit (1);
	    }
	}
    }
  argc -= optind;
  argv += optind;

  if (argc == 1)
    {
      maxdbits = sizearg (argv[0]);
      maxnbits = 2 * maxdbits;
    }
  else if (argc == 2)
    {
      maxnbits = sizearg (argv[0]);
      maxdbits = sizearg (argv[1]);
    }
  else
    {
      printf ("usage: %s nsize dsize\n", progname);
      printf ("   or: %s size\n", progname);
      exit (1);
    }

  if (maxnbits <= maxdbits)
    {
      printf ("the dividend needs to be larger than the divisor\n");
      exit (1);
    }

  if (maxdbits <= GMP_NUMB_BITS)
    {
      printf ("the divisor needs to be at least 2 limbs (%d bits)\n",
	      1 + GMP_NUMB_BITS);
      exit (1);
    }

  gmp_randinit_default (rs);

  mpz_init (n);
  mpz_init (d);
  mpz_init (q);
  mpz_init (r);

  maxnn = maxnbits / GMP_NUMB_BITS + 1;
  maxdn = maxdbits / GMP_NUMB_BITS + 1;

  TMP_MARK;

  /* Quotient and work/remainder areas, sized for the largest dividend.  */
  qp = TMP_ALLOC_LIMBS (maxnn);
  rp = TMP_ALLOC_LIMBS (maxnn);

  testh = 0;
  for (test = 0;;)
    {
#ifndef FIXED_SIZE
      dbits = random () % maxdbits + 1;
      qbits = random () % (maxnbits - maxdbits) + 1;
      rbits = random () % dbits;
#else
      dbits = maxdbits;
      qbits = maxnbits - maxdbits - 1;
      rbits = maxdbits - 1;
#endif

      MPZ_XRANDOMB (q, rs, qbits);
      do
	{
	  MPZ_XRANDOMB (d, rs, dbits);
	}
      while (mpz_sgn (d) == 0);
      MPZ_XRANDOMB (r, rs, rbits);
      mpz_mul_2exp (r, r, qbits);

      /* Force the divisor odd; binvert_limb below is applied to dp[0].  */
      mpz_setbit (d, 0);

      /* n = d * q + r * 2^qbits  */
      mpz_mul (n, d, q);
      mpz_add (n, n, r);

      np = PTR (n);
      dp = PTR (d);
      nn = SIZ (n);
      dn = SIZ (d);

      /* Zero a random number of the most significant dividend limbs while
	 keeping the limb count nn unchanged.  */
      mp_size_t clearn;
      clearn = random () % (nn + 1);

      for (i = clearn; i < nn; i++)
	np[i] = 0;

#if LEAD_D_MAX
      dp[dn - 1] = GMP_NUMB_MAX;
#endif

      if (dn < 2)
	continue;

      /* nn - dn is used as the quotient limb count below and must be
	 positive, so skip nn < dn as well as nn == dn.  */
      if (nn <= dn)
	continue;

      test++;
      testh += test == 0;

      /* Throttle the progress output by the amount of work done.  */
      work += qbits + dbits;
      if (work >= 3123451)
	{
	  if (isatty (fileno (stdout)))
	    printf ("\r%lu,,%lu", testh, test);
#ifdef DEBUG
	  printf ("\n");
#endif
	  fflush (stdout);
	  work = 0;
	}

#ifdef DEBUG
      printf ("n="); mpn_dump (np, nn);
      printf ("d="); mpn_dump (dp, dn);
#endif

      /* The routines below are handed the negated inverse of dp[0].  */
      binvert_limb (di, dp[0]);

      /* The SB routines are only run for moderately sized operands.  */
      if ((double) (nn - dn) * dn < 1e8)
	{
	  if (flag_mpn_sb_bdiv_qr)
	    {
	      MPN_ZERO (qp, nn - dn);
	      MPN_ZERO (rp, dn);
	      MPN_COPY (rp, np, nn);
	      rh = mpn_sb_bdiv_qr (qp, rp, nn, dp, dn, -di);
	      check_one (qp, rp + nn - dn, rh, np, nn, dp, dn, "mpn_sb_bdiv_qr");
	    }

	  if (flag_mpn_sb_bdiv_q)
	    {
	      MPN_COPY (rp, np, nn);
	      MPN_ZERO (qp, nn - dn);
	      mpn_sb_bdiv_q (qp, rp, nn - dn, dp, MIN(dn,nn-dn), -di);
	      /* Check against the full divisor, as for mpn_dc_bdiv_q:
		 check_one derives the quotient length as nn - dn, so
		 passing the truncated divisor length would make it
		 compare limbs of the quotient that were never computed.  */
	      check_one (qp, NULL, 0, np, nn, dp, dn, "mpn_sb_bdiv_q");
	    }
	}

      /* Each block below allocates one limb more than the scratch size,
	 fills that limb with a random value and asserts afterwards that it
	 was not overwritten.  The block is released with the same limb
	 count it was allocated with (itch + 1).  */

      if (flag_mpn_dc_bdiv_qr)
	{
	  itch = nn;					/* ??? FIXME ??? */
	  scratch = __GMP_ALLOCATE_FUNC_LIMBS (itch + 1);
	  mpn_random (scratch + itch, 1);  ran = scratch[itch];
	  MPN_COPY (rp, np, nn);
	  MPN_ZERO (qp, nn - dn);
	  rh = mpn_dc_bdiv_qr (qp, rp, nn, dp, dn, -di /* , scratch */);
	  ASSERT_ALWAYS (ran == scratch[itch]);
	  check_one (qp, rp + nn - dn, rh, np, nn, dp, dn, "mpn_dc_bdiv_qr");
	  __GMP_FREE_FUNC_LIMBS (scratch, itch + 1);
	}

      if (flag_mpn_dc_bdiv_qr_n && nn >= 2*dn)
	{
	  itch = mpn_dc_bdiv_qr_n_itch(dn);
	  scratch = __GMP_ALLOCATE_FUNC_LIMBS (itch + 1);
	  mpn_random (scratch + itch, 1);  ran = scratch[itch];
	  MPN_COPY (rp, np, 2*dn);
	  MPN_ZERO (qp, dn);
	  rh = mpn_dc_bdiv_qr_n (qp, rp, dp, dn, -di, scratch);
	  ASSERT_ALWAYS (ran == scratch[itch]);
	  check_one (qp, rp + dn, rh, np, 2*dn, dp, dn, "mpn_dc_bdiv_qr_n");
	  __GMP_FREE_FUNC_LIMBS (scratch, itch + 1);
	}

      if (flag_mpn_dc_bdiv_q)
	{
	  itch = nn;					/* ??? FIXME ??? */
	  scratch = __GMP_ALLOCATE_FUNC_LIMBS (itch + 1);
	  mpn_random (scratch + itch, 1);  ran = scratch[itch];
	  MPN_COPY (rp, np, nn);
	  MPN_ZERO (qp, nn - dn);
	  mpn_dc_bdiv_q (qp, rp, nn - dn, dp, MIN(dn,nn-dn), -di);
	  ASSERT_ALWAYS (ran == scratch[itch]);
	  check_one (qp, NULL, 0, np, nn, dp, dn, "mpn_dc_bdiv_q");
	  __GMP_FREE_FUNC_LIMBS (scratch, itch + 1);
	}

      if (flag_mpn_dc_bdiv_q_n && nn >= 2*dn)
	{
	  itch = mpn_dc_bdiv_q_n_itch (dn);
	  scratch = __GMP_ALLOCATE_FUNC_LIMBS (itch + 1);
	  mpn_random (scratch + itch, 1);  ran = scratch[itch];
	  MPN_COPY (rp, np, dn);
	  MPN_ZERO (qp, nn - dn);
	  mpn_dc_bdiv_q_n (qp, rp, dp, dn, -di, scratch);
	  ASSERT_ALWAYS (ran == scratch[itch]);
	  check_one (qp, NULL, 0, np, 2*dn, dp, dn, "mpn_dc_bdiv_q_n");
	  __GMP_FREE_FUNC_LIMBS (scratch, itch + 1);
	}
#if 0 // mpn_mu_bdiv_qr doesn't yet exist
      if (flag_mpn_mu_bdiv_qr)
	{
	  itch = mpn_mu_bdiv_qr_itch (nn, dn);
	  scratch = __GMP_ALLOCATE_FUNC_LIMBS (itch + 1);
	  mpn_random (scratch + itch, 1);  ran = scratch[itch];
	  MPN_ZERO (qp, nn - dn);
	  MPN_ZERO (rp, dn);
	  rh = mpn_mu_bdiv_qr (qp, rp, np, nn, dp, MIN(dn,nn-dn), scratch);
	  ASSERT_ALWAYS (ran == scratch[itch]);
	  check_one (qp, rp, rh, np, nn, dp, dn, "mpn_mu_bdiv_qr");
	  __GMP_FREE_FUNC_LIMBS (scratch, itch + 1);
	}
#endif

      if (flag_mpn_mu_bdiv_q && dn >= 2 && nn - dn >= 2)
	{
	  itch = mpn_mu_bdiv_q_itch (nn, dn);
	  scratch = __GMP_ALLOCATE_FUNC_LIMBS (itch + 1);
	  mpn_random (scratch + itch, 1);  ran = scratch[itch];
	  MPN_ZERO (qp, nn - dn + 1);
	  mpn_mu_bdiv_q (qp, np, nn - dn, dp, dn, scratch);
	  ASSERT_ALWAYS (ran == scratch[itch]);
	  check_one (qp, NULL, 0, np, nn, dp, dn, "mpn_mu_bdiv_q");
	  __GMP_FREE_FUNC_LIMBS (scratch, itch + 1);
	}
    }

  /* Not reached: the loop above has no exit.  */
  TMP_FREE;
}
#endif

#ifdef TIMING

#include "cputime.h"

/* TIME (t, func): measure the average cost of one execution of FUNC and
   store it in T as a double.

   FUNC is run in batches of 1, 2, 4, ... repetitions until one batch takes
   more than 1000 units as measured by cputime (); T is then that batch's
   elapsed time divided by its repetition count.  The unit is whatever
   cputime () returns (declared in cputime.h; presumably milliseconds,
   given that callers scale T by 1000 before handing it to printres --
   confirm).  FUNC may be several statements separated by semicolons.  */
#define TIME(t,func)							\
  do { long int __t0, __times, __t, __tmp;				\
    __times = 1;							\
    for (;;)								\
      {									\
	__t0 = cputime ();						\
	for (__t = 0; __t < __times; __t++)				\
	  {func;}							\
	__tmp = cputime () - __t0;					\
	if (__tmp > 1000) break;					\
	__times <<= 1;							\
      }									\
    (t) = (double) __tmp / __times;					\
  } while (0)

/* Print the duration T, given in microseconds, to stdout with a unit
   suffix chosen from its decimal magnitude: ns below 1, µs for 1-999,
   ms for 1000-999999 and s above that.  The output is a leading space
   followed by a number field and the suffix; the field width is WIDTH
   minus 3 for the two-letter suffixes and WIDTH minus 2 for "s".  The
   number of decimals shrinks as the magnitude grows within a unit, and
   from 10^9 upwards the value is truncated to a multiple of
   10^(magnitude - 9) and printed without decimals.  */
void
printres (double t, int width)
{
  const int magnitude = floor (log10 (t));

  if (magnitude < 0)
    {
      printf (" %*.*fns", width - 3, -1 - magnitude, t * 1000);
      return;
    }

  if (magnitude < 3)		/* 1-999 */
    {
      printf (" %*.*fµs", width - 3, 2 - magnitude, t);
      return;
    }

  if (magnitude < 6)		/* 1000-999999 */
    {
      printf (" %*.*fms", width - 3, 5 - magnitude, t * 0.001);
      return;
    }

  if (magnitude < 9)		/* 1000000-999999999 */
    {
      printf (" %*.*fs", width - 2, 8 - magnitude, t * 0.000001);
      return;
    }

  /* 10^9 and up: drop the digits below 10^(magnitude - 9).  */
  {
    double granule = pow (10.0, magnitude - 9);
    double truncated = t - fmod (t, granule);

    printf (" %*.*fs", width - 2, 0, truncated * 0.000001);
  }
}

/* Timing driver (built when TIMING is defined).

   Usage:  prog nsize dsize
      or:  prog size               (nsize = 2 * size)

   Sizes are parsed by sizearg.  One random dividend and one random odd
   divisor of the requested bit sizes are generated, and each routine is
   timed on them with the TIME macro.  The results are printed as a single
   table row via printres; columns that do not apply are left blank and the
   not yet implemented MU bdiv_qr column reads "nyi".

   When POWERD is defined a long empty counting loop runs before anything
   else.  */
int
main (int argc, char **argv)
{
  gmp_randstate_t rs;
  unsigned long nbits, dbits;
  mpz_t n, d, q;
  mp_size_t nn, dn;
  mp_ptr np, dp, qp, rp;
  double t;
  mp_limb_t di;
  mp_size_t itch;
  mp_ptr scratch;
  TMP_DECL;

  TMP_MARK;

#ifdef POWERD
  unsigned long i;
  for (i = 4000000000u; i != 0; i--)
    ;
#endif

  if (argc == 2)
    {
      dbits = sizearg (argv[1]);
      nbits = 2 * dbits;
    }
  else if (argc == 3)
    {
      nbits = sizearg (argv[1]);
      dbits = sizearg (argv[2]);
    }
  else
    {
      printf ("usage: %s nsize dsize\n", argv[0]);
      printf ("   or: %s size\n", argv[0]);
      exit (1);
    }

  if (nbits <= dbits)
    {
      printf ("the dividend needs to be larger than the divisor\n");
      exit (1);
    }

  if (dbits <= GMP_NUMB_BITS)
    {
      printf ("the divisor needs to be at least 2 limbs (%d bits)\n",
	      1 + GMP_NUMB_BITS);
      exit (1);
    }

  printf ("DC_BDIV_QR_THRESHOLD=%u\n", DC_BDIV_QR_THRESHOLD);
  printf ("DC_BDIV_Q_THRESHOLD=%u\n", DC_BDIV_Q_THRESHOLD);
  printf ("MU_BDIV_Q_THRESHOLD=%u\n", MU_BDIV_Q_THRESHOLD);
  printf ("BINV_NEWTON_THRESHOLD=%u\n", BINV_NEWTON_THRESHOLD);

  gmp_randinit_default (rs);

  mpz_init (n);
  mpz_init (d);
  mpz_init (q);

  /* Upper bounds on the limb counts, used only to size qp and rp; the
     actual counts are taken from the generated operands below.  */
  nn = nbits / GMP_NUMB_BITS + 1;
  dn = dbits / GMP_NUMB_BITS + 1;

  qp = TMP_ALLOC_LIMBS (nn);
  rp = TMP_ALLOC_LIMBS (nn);

  mpz_urandomb (n, rs, nbits);

  do
    mpz_urandomb (d, rs, dbits);
  while (mpz_sgn (d) == 0);

  np = PTR (n);
  dp = PTR (d);
  nn = SIZ (n);
  dn = SIZ (d);

  /* nn - dn is used as the quotient limb count by every call below, so the
     dividend must have more limbs than the divisor.  The bit-size check
     above does not guarantee that when both sizes round to the same number
     of limbs.  */
  if (nn <= dn)
    {
      printf ("the dividend needs to have more limbs than the divisor\n");
      exit (1);
    }

  /* Force the divisor odd; binvert_limb below is applied to dp[0].  */
  dp[0] |= 1;

  printf ("                SB              DC                                MU\n");
  printf ("%8ld  bdiv_qr bdiv_q  bdiv_qr bdiv_q bdiv_qr_n bdiv_q_n bdiv_qr bdiv_q\n", (long) nn);
  printf ("%8ld", (long) dn);

  /* The routines are handed the negated inverse of dp[0], matching the
     calls made by the checking driver in this file.  */
  binvert_limb (di, dp[0]);

  /* SB functions */
  if ((double) (nn - dn) * dn < 1e10)
    {
      TIME (t, MPN_COPY (rp, np, nn); mpn_sb_bdiv_qr (qp, rp, nn, dp, dn, -di));
      printres (t * 1000, 8);
      fflush (stdout);

      TIME (t, MPN_COPY (rp, np, nn);
	       mpn_sb_bdiv_q (qp, rp, nn - dn, dp, MIN(dn,nn-dn), -di));
      printres (t * 1000, 8);
      fflush (stdout);
    }
  else printf ("%16s", "");

  /* DC functions */
  itch = nn;		/* ??? FIXME ??? */
  scratch = __GMP_ALLOCATE_FUNC_LIMBS (itch);
  TIME (t, MPN_COPY (rp, np, nn); mpn_dc_bdiv_qr (qp, rp, nn, dp, dn, -di /*,scratch*/));
  printres (t * 1000, 8);
  fflush (stdout);
  __GMP_FREE_FUNC_LIMBS (scratch, itch);

  itch = nn;		/* ??? FIXME ??? */
  scratch = __GMP_ALLOCATE_FUNC_LIMBS (itch);
  TIME (t, MPN_COPY (rp, np, nn);
	   mpn_dc_bdiv_q (qp, rp, nn - dn, dp, MIN(dn,nn-dn), -di));
  printres (t * 1000, 8);
  fflush (stdout);
  __GMP_FREE_FUNC_LIMBS (scratch, itch);

  /* The _n variants are only timed for a dividend of exactly 2*dn limbs.  */
  if (nn == 2*dn)
    {
      itch = mpn_dc_bdiv_qr_n_itch (dn);
      scratch = __GMP_ALLOCATE_FUNC_LIMBS (itch);
      TIME (t, MPN_COPY (rp, np, nn); mpn_dc_bdiv_qr_n (qp, rp, dp, dn, -di, scratch));
      printres (t * 1000, 8);
      fflush (stdout);
      __GMP_FREE_FUNC_LIMBS (scratch, itch);

      itch = mpn_dc_bdiv_q_n_itch (dn);
      scratch = __GMP_ALLOCATE_FUNC_LIMBS (itch);
      TIME (t, MPN_COPY (rp, np, dn);
	    mpn_dc_bdiv_q_n (qp, rp, dp, dn, -di, scratch));
      printres (t * 1000, 8);
      fflush (stdout);
      __GMP_FREE_FUNC_LIMBS (scratch, itch);
    }
  else printf ("%16s", "");

  /* MU functions */
  printf ("%8s", "nyi");

  itch = mpn_mu_bdiv_q_itch (nn, dn);
  scratch = __GMP_ALLOCATE_FUNC_LIMBS (itch);
  TIME (t, mpn_mu_bdiv_q (qp, np, nn - dn, dp, MIN(dn,nn-dn), scratch));
  printres (t * 1000, 8);
  fflush (stdout);
  __GMP_FREE_FUNC_LIMBS (scratch, itch);

  puts ("");

  TMP_FREE;
  exit (0);
}
#endif
