/* Copyright 2006, 2007 Free Software Foundation, Inc.

This program is free software; you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation; either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
this program.  If not, see http://www.gnu.org/licenses/.  */


#include <stdlib.h>		/* for exit, strtoul */
#include <string.h>		/* for strlen */
#include <stdio.h>		/* for printf */
#include <math.h>		/* for log10, fmod, pow */
#include <unistd.h>		/* for isatty */

#include "gmp.h"
#include "gmp-impl.h"
#include "longlong.h"

#include "new.h"


/* Parse a size given on the command line.

   The number is read with strtoul in base 0, so decimal, octal ("0...") and
   hexadecimal ("0x...") spellings are accepted.  If the last character of
   ARG is 'b' the parsed value is returned unchanged (a size in bits);
   otherwise it is multiplied by GMP_LIMB_BITS (a size in limbs converted to
   bits).  Note that a hexadecimal argument whose final digit is 'b' is
   therefore also taken as a bit count.

   An empty string yields 0 instead of indexing before the start of ARG.  */
static mp_size_t
sizearg (char *arg)
{
  size_t len = strlen (arg);
  unsigned long val = strtoul (arg, 0, 0);

  /* Only look at the last character when there is one.  */
  if (len > 0 && arg[len - 1] == 'b')
    return val;

  return val * GMP_LIMB_BITS;
}


#ifdef CHECK
/* Print the N limbs at P to stdout in hexadecimal, most significant limb
   first.  Each limb is zero-padded to two hex digits per byte of mp_limb_t,
   limbs are separated by a single space, and the line is terminated with a
   newline.  Nothing but the newline is printed when N <= 0.  */
void
dumpy (mp_srcptr p, mp_size_t n)
{
  const int width = (int) (2 * sizeof (mp_limb_t));
  mp_size_t left = n;

  while (left-- > 0)
    {
      printf ("%0*lx", width, p[left]);
      /* Separator after every limb except the least significant one.  */
      if (left != 0)
	printf (" ");
    }
  puts ("");
}

int
main (int argc, char **argv)
{
  gmp_randstate_t rs;
  unsigned long maxnbits, maxdbits, nbits, dbits;
  mpz_t n, d, q;
  mp_size_t maxnn, maxdn, nn, dn, qn;
  mp_ptr qp;
  mp_ptr scratch;
  unsigned long test, testh, err = 0, work = 0;
  mp_size_t itch;
  mp_limb_t ran;
  TMP_DECL;

  TMP_MARK;

  if (argc == 2)
    {
      maxdbits = sizearg (argv[1]);
      maxnbits = 2 * maxdbits;
    }
  else if (argc == 3)
    {
      maxnbits = sizearg (argv[1]);
      maxdbits = sizearg (argv[2]);
    }
  else
    {
      printf ("usage: %s nbits dbits\n", argv[0]);
      printf ("   or: %s bits\n", argv[0]);
      exit (1);
    }

  if (maxnbits <= maxdbits)
    {
      printf ("the dividend needs to be larger than the divisor\n");
      exit (1);
    }

  gmp_randinit_default (rs);

  mpz_init (n);
  mpz_init (d);
  mpz_init (q);

  maxnn = maxnbits / GMP_NUMB_BITS + 1;
  maxdn = maxdbits / GMP_NUMB_BITS + 1;

  qp = TMP_ALLOC_LIMBS (maxnn);

  testh = ~0;
  for (test = 0;; test++)
    {
      testh += test == 0;

#ifndef FIXED_SIZE
      nbits = random () % maxnbits + 1;
      if (maxdbits > nbits)
	dbits = random () % nbits;
      else
	dbits = random () % maxdbits;
#else
      nbits = maxnbits;
      dbits = maxdbits;
#endif

#if MAKE_Q_MAXIMAL
      mpz_set_ui (n, 1);
      mpz_mul_2exp (n, n, nbits);
      mpz_sub_ui (n, n, 1);
      mpz_set_ui (d, 1);
#else
      do
	{
#if RAND_UNIFORM
	  mpz_urandomb (q, rs, nbits - dbits);
#else
	  mpz_rrandomb (q, rs, nbits - dbits);
#endif
	}
      while (mpz_sgn (q) == 0);
#if RAND_UNIFORM
      mpz_urandomb (d, rs, dbits);
#else
      mpz_rrandomb (d, rs, dbits);
#endif
      mpz_setbit (d, 0);	/* mpn_divexact only handles odd divisors */
#ifdef ODDQ
      mpz_setbit (q, 0);
#endif
      mpz_mul (n, q, d);
#endif /* MAKE_Q_MAXIMAL */

      nn = SIZ (n);
      dn = SIZ (d);

      work += nbits - dbits;
      if (work >= 7123451)
	{
	  if (isatty (fileno (stdout)))
	    printf ("\r%lu,,%lu", testh, test);
#ifdef DEBUG
	  printf ("\n");
#endif
	  fflush (stdout);
	  work = 0;
	}

#ifdef DEBUG
      printf ("n="); mpn_dump (PTR(n), nn);
      printf ("d="); mpn_dump (PTR(d), dn);
#endif

      itch = mpn_divexact_itch (nn, dn);
      scratch = __GMP_ALLOCATE_FUNC_LIMBS (itch + 1);

      mpn_random (&ran, 1);
      scratch[itch] = ran;

      mpn_divexact (qp, PTR(n), nn, PTR(d), dn, scratch);

      if (scratch[itch] != ran)
	{
	  printf ("clobbered end of scratch for nn=%ld dn=%ld itch=%ld\n", nn, dn, itch);
	}

      __GMP_FREE_FUNC_LIMBS (scratch, itch + 1);

      qn = SIZ(q);
      if (mpn_cmp (qp, PTR(q), qn) != 0)
	{
	  mp_size_t lo, hi;
	  printf ("**********************************************************\n");
	  printf ("mpn_divexact and mpn_tdiv_qr disagree in test %lu,,%lu\n", testh, test);
	  printf ("n=    "); dumpy (PTR(n), nn);
	  printf ("d=    "); dumpy (PTR(d), dn);
	  printf ("qp=   "); dumpy (qp, qn);
	  printf ("refqp="); dumpy (PTR(q), qn);
	  for (lo = 0; qp[lo] == PTR(q)[lo]; lo++);
	  for (hi = qn - 1; qp[hi] == PTR(q)[hi]; hi--);
	  printf ("nn = %ld, dn = %ld, qn = %ld, bad %ld--%ld\n", nn, dn, qn, hi, lo);
	  printf ("**********************************************************\n");
	  if (++err >= 5)
	    abort ();
	}
    }

  TMP_FREE;
}
#endif

#ifdef TIMING

#include "cputime.h"

/* Measure the statement(s) FUNC and store the average cost of one execution
   in T (as a double).  FUNC is run 1, 2, 4, ... times in a row, doubling the
   repetition count until a single batch spans more than 1000 units of
   cputime (); T is then that batch's elapsed cputime divided by its
   repetition count.  */
#define TIME(t,func)							\
  do { long int time_start_, time_reps_, time_iter_, time_elapsed_;	\
    for (time_reps_ = 1;; time_reps_ <<= 1)				\
      {									\
	time_start_ = cputime ();					\
	for (time_iter_ = 0; time_iter_ < time_reps_; time_iter_++)	\
	  {func;}							\
	time_elapsed_ = cputime () - time_start_;			\
	if (time_elapsed_ > 1000) break;				\
      }									\
    (t) = (double) time_elapsed_ / time_reps_;				\
  } while (0)

/* Print a timing result.

   T is multiplied by 1000 and printed through FMT, which must consume an
   int precision followed by a double (e.g. "%.*f").  The precision is chosen
   so that three significant digits are shown; for values of 1000 or more the
   digits below the third significant one are zeroed and no decimals are
   printed.

   A zero, negative, NaN or infinite value is printed with precision 0,
   since log10 of such a value does not give a usable digit count.  */
void
printres (char *fmt, double t)
{
  int prc = 0;

  t = 1000.0 * t;

  /* Both comparisons are false for NaN, and the second excludes infinity.  */
  if (t > 0.0 && t < HUGE_VAL)
    {
      prc = (int) (2 - floor (log10 (t)));
      if (prc < 0)
	{
	  /* Drop everything below the third significant digit.  */
	  t = t - fmod (t, pow (10.0, -prc));
	  prc = 0;
	}
    }

  printf (fmt, prc, t);
}

/* Timing driver for exact division.

   Usage: PROG nbits dbits   (dividend and divisor sizes)
      or: PROG bits          (divisor size; dividend gets twice that)

   Sizes are parsed by sizearg.  One random problem n = q * d with odd d is
   generated, and mpn_divexact, mpn_sb_bdiv_q, mpz_divexact and mpn_tdiv_qr
   are timed on it with the TIME macro.  The two potentially slow routines
   are skipped, printing "n/a", when the quotient is large.  */
int
main (int argc, char **argv)
{
  gmp_randstate_t rs;
  unsigned long nbits, dbits;
  mpz_t n, d, q;
  mp_size_t nn, dn;
  mp_ptr qp, rp;
  mp_ptr scratch;
  double t;
  mp_size_t itch;
  TMP_DECL;

  TMP_MARK;

#ifdef POWERD
  /* Busy loop before measuring.  NOTE(review): presumably meant to let the
     CPU ramp up its clock -- confirm; an optimizing compiler may remove
     it.  */
  unsigned long i;
  for (i = 4000000000u; i != 0; i--)
    ;
#endif

  if (argc == 2)
    {
      dbits = sizearg (argv[1]);
      nbits = 2 * dbits;
    }
  else if (argc == 3)
    {
      nbits = sizearg (argv[1]);
      dbits = sizearg (argv[2]);
    }
  else
    {
      printf ("usage: %s nbits dbits\n", argv[0]);
      printf ("   or: %s bits\n", argv[0]);
      exit (1);
    }

  if (nbits <= dbits)
    {
      printf ("the dividend needs to be larger than the divisor\n");
      exit (1);
    }

  gmp_randinit_default (rs);

  mpz_init (n);
  mpz_init (d);
  mpz_init (q);

  nn = nbits / GMP_NUMB_BITS + 1;
  dn = dbits / GMP_NUMB_BITS + 1;

  qp = TMP_ALLOC_LIMBS (nn);
  rp = TMP_ALLOC_LIMBS (dn);

  /* Build the problem: nonzero q, odd d, n = q * d.  */
  do
    mpz_urandomb (q, rs, nbits - dbits);
  while (mpz_sgn (q) == 0);
  mpz_urandomb (d, rs, dbits);
  mpz_setbit (d, 0);		/* FIXME: mpn_divexact only handles odd divisors */
  mpz_mul (n, q, d);

  itch = mpn_divexact_itch (SIZ(n), SIZ(d));
  scratch = TMP_ALLOC_LIMBS (itch);

  printf ("%d / %d ------------------------------------------------\n", SIZ(n), SIZ(d));

  TIME (t, mpn_divexact (qp, PTR(n), SIZ(n), PTR(d), SIZ(d), scratch));
  printres ("mpn_divexact      took %.*fµs\n", t);

  if (SIZ(q) < 100000)
    {
      /* The timed statements copy the low SIZ(d) limbs of n into foo on
	 every repetition and pass that copy to mpn_sb_bdiv_q.
	 NOTE(review): presumably because mpn_sb_bdiv_q overwrites its
	 dividend operand -- confirm.  */
      mp_ptr foo = __GMP_ALLOCATE_FUNC_LIMBS (SIZ(n));
      mp_limb_t di;
      TIME (t, binvert_limb (di, PTR(d)[0]); di = -di;
	       MPN_COPY (foo, PTR(n), SIZ(d));
	       mpn_sb_bdiv_q (qp, foo, SIZ(d), PTR(d), SIZ(d), di));
      __GMP_FREE_FUNC_LIMBS (foo, SIZ(n));
      printres ("mpn_sb_bdiv_q     took %.*fµs\n", t);
    }
  else puts ("mpn_sb_bdiv_q     took n/a");

  TIME (t, mpz_divexact (q, n, d));
  printres ("mpz_divexact      took %.*fµs\n", t);

  if (SIZ(q) < 1500)
    {
      TIME (t, mpn_tdiv_qr (qp, rp, 0L, PTR(n), SIZ(n), PTR(d), SIZ(d)));
      printres ("mpn_tdiv_qr       took %.*fµs\n", t);
    }
  else
    /* Nothing was measured here; do not reprint the mpz_divexact time.  */
    puts ("mpn_tdiv_qr       took n/a");

  TMP_FREE;
  exit (0);
}
#endif
