/* Copyright 2006, 2007 Free Software Foundation, Inc.

This program is free software; you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation; either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
this program.  If not, see http://www.gnu.org/licenses/.  */


/* Print the n limbs {A, n} on stdout as a sum of terms in B, least
   significant limb first: "+A[0]*B^0+A[1]*B^1...".  A newline is emitted
   after every fourth term (but not after the last one), and the output is
   terminated by ":\n".  */
void
mpn_print (mp_ptr A, mp_size_t n)
{
  mp_size_t j;	/* same type as n, so the loop comparison is not mixed-width */

  for (j = 0; j < n; j++)
    {
      /* NOTE(review): %lu assumes mp_limb_t is unsigned long -- confirm for
	 the target ABI.  The index is cast so that it matches %ld exactly.  */
      printf ("+%lu*B^%ld", A[j], (long) j);
      if (j % 4 == 3 && j != n - 1)
	printf ("\n");
    }
  printf (":\n");
}


/* Check a candidate inverse.  With X = B^n + {xp, n} and A = {ap, n}, verify

     X*A < B^(2n)   and   B^(2n) - X*A <= A.

   Returns 1 when both conditions hold, 0 otherwise.  Two temporary areas of
   2*n limbs each are obtained through TMP_ALLOC_LIMBS and released before
   returning.  */
int
test_invert (mp_ptr xp, mp_srcptr ap, mp_size_t n)
{
  int res = 1;
  mp_ptr tp, up;
  mp_limb_t cy;
  TMP_DECL;

  TMP_MARK;
  tp = TMP_ALLOC_LIMBS (2 * n);
  up = TMP_ALLOC_LIMBS (2 * n);

  /* first check X*A < B^(2*n) */
  mpn_mul_n (tp, xp, ap, n);
  cy = mpn_add_n (tp + n, tp + n, ap, n); /* A * msb(X) */
  if (cy != 0)
    {
      /* X*A does not fit in 2n limbs.  This failure must be kept: tp now
	 only holds X*A mod B^(2n), so the comparison below would be
	 meaningless and must not overwrite the result.  */
      res = 0;
    }
  else
    {
      /* now check B^(2n) - X*A <= A */
      mpn_com_n (tp, tp, 2 * n);
      mpn_add_1 (tp, tp, 2 * n, 1); /* B^(2n) - X*A */

      /* up = A zero-extended to 2n limbs, so both operands of mpn_cmp
	 have the same size.  */
      MPN_ZERO (up, 2 * n);
      MPN_COPY (up, ap, n);
      if (mpn_cmp (tp, up, 2 * n) > 0)
	res = 0;
    }

  TMP_FREE;
  return res;
}

#ifdef MAIN
#include <sys/types.h>
#include <sys/resource.h>

/* Return the user CPU time consumed so far by the calling process, in
   milliseconds, as reported by getrusage.  Returns 0 if getrusage fails, so
   that an uninitialised struct rusage is never read.  */
int
cputime (void)
{
  struct rusage rus;

  if (getrusage (RUSAGE_SELF, &rus) != 0)
    return 0;

  /* Seconds and microseconds are both converted to milliseconds; the
     sub-millisecond part is truncated.  */
  return (int) (rus.ru_utime.tv_sec * 1000 + rus.ru_utime.tv_usec / 1000);
}

/* Benchmark driver.

   Usage: prog n [k]

   n is the operand size in limbs (must be positive) and k the number of
   repetitions of each timed loop (default 1).  The program times k calls of
   mpn_mul_n, mpn_invert and mpn_divrem on n-limb operands and prints the
   elapsed user CPU time of each loop in milliseconds.  When CHECK is
   defined, a fresh random divisor is drawn for every mpn_invert call and the
   result is verified with test_invert; the program exits with status 1 on
   the first failure.  */
int
main (int argc, char *argv[])
{
  mp_size_t n, i, k;
  mp_ptr qp, rp, dp, tp, qp2, rp2;
  pid_t pid;
  int st;

  if (argc < 2)
    {
      fprintf (stderr, "Usage: %s n [k]\n", argc > 0 ? argv[0] : "invert");
      exit (1);
    }

  /* atoi yields 0 for non-numeric input, which the range check rejects.
     n must be at least 1 because dp[n - 1] is written below.  */
  n = atoi (argv[1]);
  if (n <= 0)
    {
      fprintf (stderr, "n must be a positive integer\n");
      exit (1);
    }

  k = (argc <= 2) ? 1 : atoi (argv[2]);

  qp = malloc (n * sizeof (mp_limb_t));
  qp2 = malloc (n * sizeof (mp_limb_t));
  rp = malloc (n * sizeof (mp_limb_t));
  rp2 = malloc (2 * n * sizeof (mp_limb_t));
  dp = malloc (n * sizeof (mp_limb_t));
  tp = malloc (2 * n * sizeof (mp_limb_t));
  if (qp == NULL || qp2 == NULL || rp == NULL || rp2 == NULL
      || dp == NULL || tp == NULL)
    {
      fprintf (stderr, "cannot allocate memory for n=%lu limbs\n",
	       (unsigned long) n);
      exit (1);
    }

  /* The process id is used as the seed and printed so that a run can be
     identified afterwards.  The cast makes the argument match %lu.  */
  pid = getpid ();
  printf ("Seed=%lu\n", (unsigned long) pid);
  srand48 (pid);
  for (i = 0; i < n; i++)
    dp[i] = lrand48 ();
  dp[n - 1] |= GMP_NUMB_HIGHBIT;

  mpn_random (rp, n);
  st = cputime ();
  for (i = 0; i < k; i++)
    mpn_mul_n (tp, dp, rp, n);
  printf ("mpn_mul_n took %dms\n", cputime () - st);

  st = cputime ();
  for (i = 0; i < k; i++)
    {
#ifdef CHECK
      mp_size_t j;

      /* Draw a new divisor for every iteration.  */
      for (j = 0; j < n; j++)
	dp[j] = lrand48 ();
      dp[n - 1] |= GMP_NUMB_HIGHBIT;
#endif
      mpn_invert (qp, dp, n);
#ifdef CHECK
      if (test_invert (qp, dp, n) == 0)
	{
	  fprintf (stderr, "test_invert failed at i=%lu\n",
		   (unsigned long) i);
	  printf ("A:="); mpn_print (dp, n);
	  printf ("X:=B^%lu", (unsigned long) n); mpn_print (qp, n);
	  exit (1);
	}
#endif
    }
  printf ("mpn_invert%d took %dms", INVERT_VERSION, cputime () - st);
#ifdef WRAP_AROUND
  /* NOTE(review): WRAP_AROUND_BOUND is assumed to be an integer expression;
     the cast makes it match %lu whatever its exact type.  */
  printf (" (with wrap-around trick, WRAP_AROUND_BOUND=%lu)",
	  (unsigned long) WRAP_AROUND_BOUND);
#endif
  printf ("\n");

  /* Dividend of 2n limbs: all zero except the top limb, which is set to
     GMP_LIMB_HIGHBIT.  */
  MPN_ZERO (rp2, 2 * n);
  rp2[2 * n - 1] = GMP_LIMB_HIGHBIT;
  st = cputime ();
  for (i = 0; i < k; i++)
    {
      /* The dividend is rebuilt before each call; presumably because
	 mpn_divrem overwrites it with the remainder -- see the GMP manual.  */
      MPN_ZERO (rp2, 2 * n);
      rp2[2 * n - 1] = GMP_LIMB_HIGHBIT;
      mpn_divrem (qp2, 0, rp2, 2 * n, dp, n);
    }
  printf ("mpn_divrem took %dms\n", cputime () - st);

  free (qp);
  free (qp2);
  free (rp);
  free (rp2);
  free (dp);
  free (tp);

  return 0;
}
#endif
