/* Vector superoptimizer.  Finds the shortest instruction sequences for an
   m-ary -> n-ary projection.  The algorithm is based on exhaustive search with
   backtracking and iterative deepening.

Copyright 2007 Free Software Foundation, Inc.

This program is free software; you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation; either version 2.1 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
this program; if not, write to the Free Software Foundation, Inc., 51 Franklin
Street, Fifth Floor, Boston, MA 02110-1301, USA. */

/*
   values:
     inarity random values from start  synthesized values
     |===============================|===========================....====|
				     <------------- max_ops ------------->

   goal_values:
      output_arity values always, with valid bit
     |==========================================|
   These values were computed by the goal function with values[0..inarity-1]
   as arguments

TODO:

 * Finish the testing code in possible_sequence.

 * In main, match function names, not their arity.

*/

#include <string.h>
#include <stdlib.h>
#include <stdio.h>

#include "gmp.h"
#include "gmp-impl.h"	/* for LIKELY/UNLIKELY and the HAVE_NATIVE_* macros */

#define MAX_MAX_OPS 20

/* Goal values computed by the goal function from the initial input values.
   value[k] is the k-th goal value.  valid[k] is a counter rather than a
   flag: hide_goal_value increments it and unhide_goal_value decrements it,
   and the search treats 0 as "not yet produced by the current partial
   sequence" (see reg_rec and find_last_goal_value).  */
struct
{
  mp_limb_t value[MAX_MAX_OPS];
  int valid[MAX_MAX_OPS];
} goal_values;
/* Number of entries of goal_values in use; main sets it to the output arity
   of the selected goal function.  */
long n_goal_values;

/* FIXME: Even if these are few, put them in an insn.def file, to keep enum
   tags and names sync'ed.  */
/* Operations the search can emit.  NOP (value 0) marks program slots that
   hold input values rather than generated instructions.  The numeric order
   of the tags is significant: reg_rec compares opcodes when pruning.  */
typedef enum opcode {NOP, ADD, SUB, LSH, RSH, ADDLSH1, SUBLSH1, RSBLSH1} opcode_t;
/* Printable names indexed by opcode_t; the entries must stay in the same
   order as the enum tags above.  The table ends with a 0 entry.  */
char *opnames[] = {"NOP", "ADD", "SUB", "LSH", "RSH", "ADDLSH1", "SUBLSH1", "RSBLSH1", 0};

/* One instruction of a candidate sequence, packed into bit-fields.
   d is the index of the value slot the instruction writes; s1 and s2 are
   its source operands.  s2 is either a value index or an immediate encoded
   with CNST (offset by 0x20); possible_sequence treats s2 >= 0x20 as an
   immediate when printing.  SETINSN2 always stores 0x3f in s3.  */
typedef struct
{
  opcode_t opcode:8;
  unsigned int d:6;
  unsigned int s1:6;
  unsigned int s2:6;
  unsigned int s3:6;		/* shift cnt for addlsh, sublsh, rsblsh */
} insn_t;

/* values[0..inarity-1] are the random inputs generated by main; the slots
   after them receive the values synthesized during the search.  */
mp_limb_t values[MAX_MAX_OPS];
/* program[k] describes the instruction that produced values[k].  main
   fills the input slots with NOP.  */
insn_t program[MAX_MAX_OPS];

/* Search routines: interior levels, the final level, and reporting.  */
void synth_nonleaf (long, long, long);
void synth_leaf (long);
void possible_sequence (long);

/* Helpers operating on goal_values.  */
long locate_goal_value (mp_limb_t);
mp_limb_t find_last_goal_value (void);
void hide_goal_value (int);
void unhide_goal_value (int);

/* Number of sequences reported so far by possible_sequence; main stops
   deepening once it is non-zero.  */
int success = 0;

/* A goal function: main calls it with the goal value array as first
   argument and the input values as second argument.  */
typedef void gf (mp_limb_t *, mp_limb_t *);	/* type of goal functions */
/* NOTE(review): main declares a local variable with this same name, which
   shadows this one; nothing visible in this file assigns the global.  */
gf *goal_function;


/* Limb-level models of the operations the search can emit.  Every macro
   argument is parenthesized so that compound expressions (e.g. `a | b' or
   `a - b') can be passed without being regrouped by operator precedence.
   Note that these function-like macros share their names with the opcode
   enum tags; a tag is only macro-expanded when directly followed by `('.  */
#define ADD(x,y)	((x) + (y))
#define SUB(x,y)	((x) - (y))
#define ADDLSH1(x,y)	((x) + ((y) << 1))
#define SUBLSH1(x,y)	((x) - ((y) << 1))
#define RSBLSH1(x,y)	(((y) << 1) - (x))
#define LSH(x,y)	((x) << (y))
/* The operand is converted to the signed limb type before shifting.  */
#define RSH(x,y)	((mp_limb_signed_t) (x) >> (y))
#define ADDLSH(x,y,z)	((x) + ((y) << (z)))
#define SUBLSH(x,y,z)	((x) - ((y) << (z)))
#define RSBLSH(x,y,z)	(((y) << (z)) - (x))
#define ADDADD(x,y,z)	((x) + (y) + (z))
#define ADDSUB(x,y,z)	((x) + (y) - (z))
#define SUBSUB(x,y,z)	((x) - (y) - (z))

/* Encode the small constant c as an s2 operand.  Values from 0x20 upwards
   denote immediates rather than value indices.  */
#define CNST(c)		((c) + 0x20)

/* Fill in a two-source instruction: opcode op, destination da, sources s1a
   and s2a.  The unused s3 field is set to 0x3f.  */
#define SETINSN2(insn,op,da,s1a,s2a) \
  do {									\
    (insn).opcode = (op);						\
    (insn).d = (da);							\
    (insn).s1 = (s1a);							\
    (insn).s2 = (s2a);							\
    (insn).s3 = 0x3f;							\
  } while (0)

/* Hand a freshly computed candidate to reg_rec.  Relies on the caller's
   locals `reg', `missing_goal_values' and `allowed_cost'.  The expansion
   deliberately has no trailing semicolon; the call site supplies it, so the
   macro behaves like an ordinary statement even in an unbraced if/else.  */
#define REG_REC(op,n_values,i,j)					\
  reg_rec (op, (n_values), (i), (j), reg, missing_goal_values, allowed_cost)

/* Pruning level used by reg_rec: 0 = none, 1 = compare with the immediately
   preceding instruction only, 2 = also compare with earlier ones.  */
#ifndef PRUNE
#define PRUNE 2
#endif

/* Consider extending the current partial sequence with the instruction
   `op' applied to operands i and j, whose result the caller has already
   computed into reg.  The instruction would occupy slot n_values.

   op                   opcode of the candidate instruction
   n_values             number of values defined so far; also the slot the
                        candidate would be stored in
   i, j                 source operands; callers pass either value indices
                        or, for j, a CNST-encoded immediate
   reg                  value the candidate instruction produces
   missing_goal_values  goal values not yet produced by the sequence
   allowed_cost         number of further non-leaf levels allowed below
                        this one (0 means the next level is the leaf)

   The candidate is dropped if it is not in canonical order with respect to
   preceding independent instructions, or if more goal values are missing
   than instructions can still be added.  Otherwise it is recorded in
   values[] and program[] and the search recurses.  */
void
reg_rec (enum opcode op, long n_values, long i, long j, mp_limb_t reg, long missing_goal_values, long allowed_cost)
{
  long iv, k;

#if PRUNE >= 1
  /* Simple prune.  Check that the new instruction and the immediately
     preceding one are in canonical order, if they are independent.  */
  insn_t insn;
  k = n_values - 1;
  insn = program[k];
  if (insn.opcode == NOP)
    goto noprune;		/* before generated sequence */
  if (insn.d == i || insn.d == j)
    goto noprune;		/* read-after-write dependency */
  /* The condition below is always true here: the two gotos above have
     already handled the NOP case and the dependent case.  */
  if (insn.opcode != NOP && insn.d != i && insn.d != j)
    {
      /* Reject the candidate when the preceding instruction compares lower
	 on (opcode, s1, s2), checked in that order.  */
      if (insn.opcode <= op)
	{
	  if (insn.opcode < op)
	    return;
	  if (insn.s1 <= i)
	    {
	      if (insn.s1 < i)
		return;
	      if (insn.s2 < j)
		return;
	    }
	}
    }
#endif
#if PRUNE == 2
  /* Check that the new instruction and the preceding ones are in canonical
     order, if they are independent.  */
  if (allowed_cost > 0)		/* don't do costly pruning at leaf node */
    /* Walk backwards over the earlier instructions, applying the same test
       as above, until an input slot or a dependency stops the scan.  */
    for (k = n_values - 2; k > 0; k--)
      {
	insn = program[k];
	if (insn.opcode == NOP)
	  goto noprune;			/* before generated sequence */
	if (insn.d == i || insn.d == j)
	  goto noprune;			/* read-after-write dependency */
	if (insn.opcode <= op)
	  {
	    if (insn.opcode < op)
	      return;
	    if (insn.s1 <= i)
	      {
		if (insn.s1 < i)
		  return;
		if (insn.s2 < j)
		  return;
	      }
	  }
      }
#endif

 noprune:

  iv = locate_goal_value (reg);
  if (iv >= 0)
    {
      /* reg equals goal value iv.  If that goal value has not been produced
	 yet (valid count 0), one fewer goal value remains missing.  */
      long mgv2 = missing_goal_values - (goal_values.valid[iv] == 0);
      /* Continue only if the missing goal values do not exceed
	 allowed_cost + 1.  */
      if (mgv2 <= allowed_cost + 1)
	{
	  /* Mark the goal value as produced for the duration of the
	     recursion, and restore the count afterwards.  */
	  hide_goal_value (iv);
	  values[n_values] = reg;			/* save new value */
	  SETINSN2 (program[n_values],op,n_values,i,j);
	  if (allowed_cost > 0)
	    synth_nonleaf (allowed_cost - 1, mgv2, n_values + 1);
	  else
	    synth_leaf (n_values + 1);
	  unhide_goal_value (iv);
	}
    }
  /* reg is not a goal value: the number of missing goal values is
     unchanged, but the same budget test applies.  */
  else if (missing_goal_values <= allowed_cost + 1)
    {
      values[n_values] = reg;				/* save new value */
      SETINSN2 (program[n_values],op,n_values,i,j);
      if (allowed_cost > 0)
	synth_nonleaf (allowed_cost - 1, missing_goal_values, n_values + 1);
      else
	synth_leaf (n_values + 1);
    }
  else
    ;  /* cut */
}

/* Interior level of the search.  Generates every candidate instruction
   over the n_values values defined so far and passes each one, through
   REG_REC, to reg_rec, which decides whether to recurse.

   allowed_cost and missing_goal_values are not used directly here; they are
   picked up by name inside the REG_REC expansion.  */
void
synth_nonleaf (long allowed_cost, long missing_goal_values, long n_values)
{
  long a, b, cnt;
  mp_limb_t reg, lhs, rhs;

  /* ADD is commutative, so only unordered pairs with a > b are tried.  */
  a = n_values;
  while (--a > 0)
    {
      lhs = values[a];
      b = a;
      while (--b >= 0)
	{
	  rhs = values[b];
	  reg = ADD (lhs, rhs);
	  REG_REC (ADD, n_values, a, b);
	}
    }

  /* Non-commutative operations: every ordered pair is tried.  Only ADDLSH1
     is attempted with both operands equal.  */
  a = n_values;
  while (--a >= 0)
    {
      lhs = values[a];
      b = n_values;
      while (--b >= 0)
	{
	  rhs = values[b];
#if HAVE_NATIVE_addlsh1_n
	  reg = ADDLSH1 (lhs, rhs);
	  REG_REC (ADDLSH1, n_values, a, b);
#endif
	  if (a == b)
	    continue;

	  reg = SUB (lhs, rhs);
	  REG_REC (SUB, n_values, a, b);
#if HAVE_NATIVE_sublsh1_n
	  reg = SUBLSH1 (lhs, rhs);
	  REG_REC (SUBLSH1, n_values, a, b);
#endif
#if HAVE_NATIVE_rsblsh1_n
	  reg = RSBLSH1 (lhs, rhs);
	  REG_REC (RSBLSH1, n_values, a, b);
#endif
	}
    }

  /* Unary operations: left shifts by 1, 2 and 3, then a right shift by 1,
     each with the count encoded as an immediate operand.  */
  a = n_values;
  while (--a >= 0)
    {
      lhs = values[a];
      for (cnt = 1; cnt <= 3; cnt++)
	{
	  reg = LSH (lhs, cnt);
	  REG_REC (LSH, n_values, a, CNST (cnt));
	}
      reg = RSH (lhs, 1);
      REG_REC (RSH, n_values, a, CNST (1));
    }
}

/* Leaf-level check used by synth_leaf.  If the value just computed into the
   caller's `reg' equals the caller's `last_goal_value', record the
   instruction in program[n_values] and report the sequence.  */
#undef TEST
#define TEST(op,n_values,i,j)						\
  do {									\
    if (UNLIKELY (reg == last_goal_value))				\
      {									\
	SETINSN2 (program[n_values],op,n_values,i,j);			\
	possible_sequence (n_values);					\
      }									\
  } while (0)

/* Final level of the search.  Tries every candidate instruction over the
   n_values values defined so far and reports each one whose result equals
   the goal value returned by find_last_goal_value.  */
void
synth_leaf (long n_values)
{
  long a, b, cnt;
  mp_limb_t reg, lhs, rhs;
  mp_limb_t last_goal_value = find_last_goal_value ();

  /* ADD is commutative, so only unordered pairs with a > b are tried.  */
  a = n_values;
  while (--a > 0)
    {
      lhs = values[a];
      b = a;
      while (--b >= 0)
	{
	  rhs = values[b];
	  reg = ADD (lhs, rhs);
	  TEST (ADD, n_values, a, b);
	}
    }

  /* Non-commutative operations: every ordered pair, including a == b.  */
  a = n_values;
  while (--a >= 0)
    {
      lhs = values[a];
      b = n_values;
      while (--b >= 0)
	{
	  rhs = values[b];
#if HAVE_NATIVE_addlsh1_n
	  reg = ADDLSH1 (lhs, rhs);
	  TEST (ADDLSH1, n_values, a, b);
#endif
	  reg = SUB (lhs, rhs);
	  TEST (SUB, n_values, a, b);
#if HAVE_NATIVE_sublsh1_n
	  reg = SUBLSH1 (lhs, rhs);
	  TEST (SUBLSH1, n_values, a, b);
#endif
#if HAVE_NATIVE_rsblsh1_n
	  reg = RSBLSH1 (lhs, rhs);
	  TEST (RSBLSH1, n_values, a, b);
#endif
	}
    }

  /* Unary operations: left shifts by 1, 2 and 3, then a right shift by 1,
     each with the count encoded as an immediate operand.  */
  a = n_values;
  while (--a >= 0)
    {
      lhs = values[a];
      for (cnt = 1; cnt <= 3; cnt++)
	{
	  reg = LSH (lhs, cnt);
	  TEST (LSH, n_values, a, CNST (cnt));
	}
      reg = RSH (lhs, 1);
      TEST (RSH, n_values, a, CNST (1));
    }
}

/* Report a candidate sequence.  The sequence is meant to be re-tested here
   before being reported, but that code is still disabled, so the sequence
   is printed as "untested".  Prints program[0..n_values], skipping slots
   whose opcode is 0 (NOP), and bumps the global success counter.  */
void
possible_sequence (long n_values)
{
  long pc;

#if 0
  for (test = 0; test < n_test; test++)
    {
      /* FIXME: If cheap, we should have a permutation array of which
	 values[i] corresponds to which goal_values[j].  Note however
	 that we'll probably almost always reject a sequence after 1
	 test here!  */

      run_sequence (...);
      for (i = outarity - 1; i >= 0; i--)
	{
	}
    }
#endif

  printf ("untested sequence (%d)\n", success);

  for (pc = 0; pc <= n_values; pc++)
    {
      const insn_t *ip = &program[pc];

      if (ip->opcode == 0)
	continue;		/* input slot, nothing to print */

      if (ip->s2 < 0x20)
	/* second source is a value index */
	printf ("r%-2d = %s(r%d,r%d)\n", ip->d, opnames[ip->opcode], ip->s1, ip->s2);
      else
	/* second source is an immediate; strip the CNST offset */
	printf ("r%-2d = %s(r%d,%d)\n", ip->d, opnames[ip->opcode], ip->s1, ip->s2 - 0x20);
    }

  fflush (stdout);
  success++;
}

/* Return the highest index k for which goal_values.value[k] equals r, or
   -1 if r is not among the n_goal_values goal values.  The valid counts
   are not consulted.  */
long
locate_goal_value (mp_limb_t r)
{
  long k = n_goal_values;

  /* Scan downwards; when nothing matches the loop leaves k at -1.  */
  while (--k >= 0)
    if (goal_values.value[k] == r)
      break;

  return k;
}

/* Return a goal value whose valid count is still 0.  Aborts if there is
   none.  When DEBUG is defined, it also aborts if more than one such goal
   value exists; otherwise the one with the highest index is returned.  */
mp_limb_t
find_last_goal_value (void)
{
  long k = n_goal_values;
#ifdef DEBUG
  long pending = -1;

  while (--k >= 0)
    {
      if (goal_values.valid[k] != 0)
	continue;
      if (pending != -1)
	abort ();		/* a second unproduced goal value */
      pending = k;
    }
  if (pending == -1)
    abort ();			/* no unproduced goal value at all */
  return goal_values.value[pending];
#else
  while (--k >= 0)
    if (goal_values.valid[k] == 0)
      return goal_values.value[k];
  abort ();
#endif
}

/* Increment the valid count of goal value i.  A non-zero count makes
   reg_rec and find_last_goal_value treat the goal value as produced.  */
void
hide_goal_value (int i)
{
  goal_values.valid[i] += 1;
}

/* Decrement the valid count of goal value i, undoing one earlier
   hide_goal_value call for the same index.  */
void
unhide_goal_value (int i)
{
  goal_values.valid[i] -= 1;
}

/* The goal function tables are generated by including "goal-va.def"
   several times, each time with a different definition of DEF.  Each DEF
   entry supplies (ia, oa, fname, fn).  All tables below are therefore
   parallel: index k in each refers to the k-th DEF entry.  */

/* ia of each entry, followed by a 0 terminator.  main uses it as the number
   of input values.  */
#undef DEF
#define DEF(ia,oa,fname,fn) ia,
unsigned char input_arity[] = {
#include "goal-va.def"
0
};

/* oa of each entry, followed by a 0 terminator.  main uses it as the number
   of goal values.  */
#undef DEF
#define DEF(ia,oa,fname,fn) oa,
unsigned char output_arity[] = {
#include "goal-va.def"
0
};

/* Define one function per entry: fname is its name and fn is pasted in as
   its body.  gv and v are the parameter names available to that body.  */
#undef DEF
#define DEF(ia,oa,fname,fn) void fname (mp_limb_t *gv, mp_limb_t *v) fn
#include "goal-va.def"

/* Pointers to the functions defined above, NULL-terminated.  */
#undef DEF
#define DEF(ia,oa,fname,fn) fname,
gf *goal_functions[] = {
#include "goal-va.def"
NULL
};

/* The function names as strings, NULL-terminated; main matches the text
   following `-f' against these.  */
#undef DEF
#define DEF(ia,oa,fname,fn) #fname,
char *goal_function_names[] = {
#include "goal-va.def"
NULL
};


int
main (int argc, char **argv)
{
  int inarity, outarity, i, maxmax_cost, allowed_extra_cost, max_ops;
  char *progname, *goal_function_name;
  gf *goal_function;

  progname = strrchr (argv[0], '/');
  if (progname == NULL)
    progname = argv[0];
  else
    progname++;

  maxmax_cost = -1;
  allowed_extra_cost = 0;
  inarity = outarity = 0;
  goal_function = NULL;

  argv++;
  argc--;

  while (argc > 0)
    {
      char *arg = argv[0];
      int arglen = strlen (arg);

      if (arglen < 2)
	arglen = 2;

      if (!strncmp (arg, "-version", arglen))
	{
	  printf ("%s version %s\n", progname, "0.1");
	  if (argc == 1)
	    exit (0);
	}
      else if (!strncmp (arg, "-max-cost", arglen))
	{
	  argv++;
	  argc--;
	  if (argc == 0)
	    {
	      fprintf (stderr, "superoptimizer: argument to `-max-cost' expected\n");
	      exit(-1);
	    }
	  maxmax_cost = atoi (argv[0]);
	}
#if 0
      /* It seems hard to make this work, since find_last_goal_value will not
	 always find a suitable last value.  */
      else if (!strncmp(arg, "-extra-cost", arglen))
	{
	  argv++;
	  argc--;
	  if (argc == 0)
	    {
	      fprintf(stderr, "superoptimizer: argument `-extra-cost' expected\n");
	      exit(-1);
	    }
	  allowed_extra_cost = atoi (argv[0]);
	}
#endif
      else if (!strncmp (arg, "-f", 2))
	{
	  int i;
	  for (i = 0;; i++)
	    {
	      goal_function_name = goal_function_names[i];

	      if (goal_function_name == NULL)
		break;

	      if (!strcmp (arg + 2, goal_function_name))
		{
		  goal_function = goal_functions[i];
		  inarity = input_arity[i];
		  outarity = output_arity[i];
		  break;
		}
	    }
	  if (goal_function_name == NULL)
	    {
	      fprintf (stderr, "%s: unknown goal function\n", progname);
	      exit (-1);
	    }
	}
      else
	{
	  int i, len, maxlen, cols, maxcols;
	  char *prefix;
	  fprintf(stderr, "usage:  %s -f<goal-function> [-max-cost n]\\\n", progname);
	  fprintf(stderr, "Supported goal functions:\n\n");

	  maxlen = 0;
	  for (i = 0;; i++)
	    {
	      goal_function_name = goal_function_names[i];

	      if (goal_function_name == NULL)
		break;

	      len = strlen (goal_function_name);
	      if (len > maxlen)
		maxlen = len;
	    }

	  maxcols = 79 / (maxlen + 2);
	  if (maxcols < 1)
	    maxcols = 1;

	  cols = 1;
	  prefix = "";
	  for (i = 0;; i++)
	    {
	      goal_function_name = goal_function_names[i];

	      if (goal_function_name == NULL)
		break;

	      fprintf(stderr, "%s  %-*s", prefix, maxlen, goal_function_name);

	      cols++;
	      if (cols > maxcols)
		{
		  cols = 1;
		  prefix = "\n";
		}
	      else
		prefix = "";
	    }

	  fprintf(stderr, "\n");
	  exit(-1);
	}

      argv++;
      argc--;
    }

  if (sizeof (mp_limb_t) == 4)
    printf ("WARNING: The likelyhood of false sequences is high on 32-bit machines.\n"
	    "         Please run this program on a 64-bit machine\n");

  /* Generate a set of random arguments, then evaluate the goal function.
     FIXME: We might want to iterate until the goal function gives unique
     values for all 'outarity' goal values.  */
  for (i = 0; i < inarity; i++)
    {
      mp_limb_t x;
      mpn_random (&x, 1);
      values[i] = x >> 16;
    }
  (goal_function) (goal_values.value, values);
  for (i = 0; i < outarity; i++)
    goal_values.valid[i] = 0;

  n_goal_values = outarity;

  if (maxmax_cost == -1 || maxmax_cost > MAX_MAX_OPS - inarity)
    maxmax_cost = MAX_MAX_OPS - inarity;

  for (i = 0; i < inarity; i++)
    SETINSN2 (program[i], NOP, 0, 0, 0);

  for (max_ops = 1; max_ops < maxmax_cost; max_ops++)
    {
      printf ("cost %d+%d\r", inarity, max_ops);  fflush (stdout);
      synth_nonleaf (max_ops - 1, outarity, inarity);
      if (success)
	break;
    }
#if 0
  maxmax_cost = max_ops + allowed_extra_cost;
  for ( ; max_ops < maxmax_cost; max_ops++)
    {
      printf ("cost %d+%d\r", inarity, max_ops);  fflush (stdout);
      synth_nonleaf (max_ops - 1, outarity, inarity);
    }
#endif
  if (! success)
    puts ("");

  return 0;
}
