# include <stdio.h>
# include <stdlib.h>
# include <limits.h>
# include <openssl/bn.h>
# include <openssl/crypto.h>
# include "mpi.h"
# include <omp.h>
# include "bn_prime.h"


/* Program name (argv[0]); used as the prefix of messages printed by myerror. */
char *prog;

/* Verbosity: >= 1 enables debug() messages, >= 2 enables trace_bignum/trace_int
   output, >= 3 makes debug() wait for input.  Set from argv[4] in main. */
int debug_level = 0;
/* MPI communicator size and this process's rank; filled in by main via
   MPI_Comm_size / MPI_Comm_rank. */
int numnodes = 1;
int myid = 0;
/* Number of OpenMP threads per node; set from argv[3] in main. */
int numthreads = 1;

/* INT_MAX as a BIGNUM; main uses it to reject primes that do not fit in an int. */
BIGNUM *bn_maxint;

/* Each OpenMP thread gets its own copy of these (main initializes them with
   copyin); myid and numnodes are NOT listed and stay shared. */
#pragma omp threadprivate (prog, debug_level, bn_maxint, numthreads)


void myerror(char *errstring)
{
  fprintf(stderr, "%s, processor %d%c (of %d*%d): %s\n", prog, 
	  myid, omp_get_thread_num()+'A', 
	  numnodes, omp_get_num_threads(),
	  errstring);
#pragma omp master
  MPI_Finalize();

  exit(1);
}

void debug(char *message)
/*
 * Debug hook.
 *   debug_level >= 1 : write message to stderr, prefixed with the MPI rank,
 *                      the OpenMP thread letter and the node/thread counts.
 *   debug_level >= 3 : additionally block until a character (normally the
 *                      return key) is read from stdin.
 */
{
  if ( debug_level > 0 )
    {
      fprintf(stderr, "processor %d%c (of %d*%d): %s.\n",
              myid, 'A' + omp_get_thread_num(),
              numnodes, omp_get_num_threads(),
              message);
    }

  if ( debug_level > 2 )
    {
      /* single-step mode */
      fprintf(stderr, "Press return to continue\n");
      (void) getchar();
    }
}

void trace_bignum(char *string, BIGNUM *bn)
/*
 * At debug_level >= 2, prints "string = <decimal value of bn>" to stderr,
 * tagged with the MPI rank and OpenMP thread.  Does nothing otherwise.
 */
{
  char *bnstring;

  if ( debug_level < 2 )
    return;

  bnstring = BN_bn2dec(bn);
  if ( bnstring == NULL )
    {
      /* BN_bn2dec returns NULL on failure; handing NULL to %s is undefined behavior */
      fprintf(stderr, "processor %d%c (of %d*%d): %s = %s\n",
              myid, omp_get_thread_num()+'A',
              numnodes, omp_get_num_threads(), string,
              "(BN_bn2dec failed)");
      return;
    }

  fprintf(stderr, "processor %d%c (of %d*%d): %s = %s\n",
          myid, omp_get_thread_num()+'A',
          numnodes, omp_get_num_threads(), string, bnstring);
  OPENSSL_free(bnstring);
}

void trace_int(char *string, int i)
/*
 * At debug_level >= 2, prints "string = i" to stderr, tagged with the MPI
 * rank and OpenMP thread.  Silent at lower debug levels.
 */
{
  if ( debug_level < 2 )
    return;

  fprintf(stderr, "processor %d%c (of %d*%d): %s = %d\n",
          myid, 'A' + omp_get_thread_num(),
          numnodes, omp_get_num_threads(), string, i);
}

int isprime(BIGNUM *p, BN_CTX *ctx)
/*
 * Primality test adapted from BN_is_prime_fasttest, with the same return
 * convention: 1 if p is (probably) prime, 0 if composite, -1 on error.
 *
 * Even numbers are settled directly.  Odd p is then trial-divided by the
 * small-prime table from bn_prime.h.  Index 0 is skipped; presumably
 * primes[0] == 2, which the parity test already covered (confirm against
 * bn_prime.h).  Only if no small factor is found is BN_is_prime_fasttest
 * called, with its own trial division disabled (same as BN_is_prime).
 */
{
  int i;
  BN_ULONG rem;

  if ( !BN_is_odd(p) )
    return BN_is_word(p, 2) ? 1 : 0;

  for (i = 1; i < NUMPRIMES; i++)
    {
      rem = BN_mod_word(p, primes[i]);
      if ( rem == (BN_ULONG)-1 )
        return -1;   /* BN_mod_word signals failure with (BN_ULONG)-1 */

      if ( rem == 0 )
        /* a small factor: p is prime only if it IS that small prime */
        return BN_is_word(p, primes[i]) ? 1 : 0;
    }

  return BN_is_prime_fasttest(p, BN_prime_checks, NULL, ctx, NULL, 0);
}


int nextprime(BIGNUM *p, BN_CTX *ctx)
     /* steps p to the next larger prime (the smallest prime > p).
        returns 1 on success, 0 on error. */
{
  int test;

  /* Every value below 2 has 2 as its next prime.  The odd-only search
     below would skip 2 and return 3, so handle this range directly. */
  if ( BN_is_negative(p) || BN_is_zero(p) || BN_is_one(p) )
    {
      if ( !BN_set_word(p, 2) )
        {
          debug("error setting p to 2 in nextprime");
          return(0);
        }
      return(1);
    }

  /* make p odd (stepping down if even) so that adding 2 only visits odd candidates */
  if ( !BN_is_odd(p) )
    {
      if ( !BN_sub_word(p, 1) )
        {
          debug("error decrementing p in nextprime");
          return(0);
        }
    }

  do
    {
      if ( !BN_add_word(p, 2) )
        {
          debug("error incrementing p in nextprime");
          return(0);
        }
    }
  while ( (test=isprime(p, ctx) ) == 0 );

  if ( test == -1 )
    {
      debug("error testing for primality in nextprime");
      return(0);
    }

  return(1);
}

int isprimroot(BIGNUM *g, BIGNUM *p, BIGNUM *pminusone, BN_CTX *ctx)
     /* checks if g is a primitive root modulo the odd prime p */
     /* criterion used: g is a primitive root iff g != 0 mod p and
        g^((p-1)/q) != 1 mod p for every prime q dividing p-1;
        q runs over all primes 2, 3, 5, ... below p */
     /* returns 0 if not a prim root, 1 if is, -1 on error */
     /* temporaries are released on every return path */
     /* we do not check that p is an odd prime */
     /* pminusone should be p-1, we don't check */
{
  BIGNUM *temp_q = NULL, *temp_r = NULL, *temp_d = NULL;
  int returnval = -1;  /* stays -1 unless the test gets far enough to decide */

  temp_r=BN_new();
  if (temp_r==NULL)
    {
      debug("error allocating temporary remainder in isprimroot");
      goto cleanup;
    }
  temp_q=BN_new();
  if (temp_q==NULL)
    {
      debug("error allocating temporary prime in isprimroot");
      goto cleanup;
    }
  temp_d=BN_new();
  if (temp_d==NULL)
    {
      debug("error allocating temporary quotient in isprimroot");
      goto cleanup;
    }

  if ( !BN_mod(temp_r, g, p, ctx) )
    {
      debug("error reducing g mod p in isprimroot");
      goto cleanup;
    }

  /* g == 0 mod p is never a primitive root */
  if ( BN_is_zero(temp_r) )
    {
      returnval = 0;
      goto cleanup;
    }

  if ( !BN_set_word(temp_q, 2) )
    {
      debug("error initializing temporary prime in isprimroot");
      goto cleanup;
    }

  returnval = 1;  /* assume we have a generator until some q disproves it */

  do
    {
      /* temp_d = (p-1) div q, temp_r = (p-1) mod q */
      if ( !BN_div(temp_d, temp_r, pminusone, temp_q, ctx) )
        {
          debug("error dividing p-1 by q in isprimroot");
          returnval = -1;
          goto cleanup;
        }

      if ( BN_is_zero(temp_r) )
        {
          /* q divides p-1; temp_r is reused to hold g^((p-1)/q) mod p */
          if ( !BN_mod_exp(temp_r, g, temp_d, p, ctx) )
            {
              debug("error assigning r = g^(p-1)/q mod p in isprimroot");
              returnval = -1;
              goto cleanup;
            }

          /* if we got one as a small power of g, then g is not a primitive root */
          if ( BN_is_one(temp_r) ) returnval=0;
        }

      if ( !nextprime(temp_q, ctx) )
        {
          debug("error finding next q in isprimroot");
          returnval = -1;
          goto cleanup;
        }
    }
  while ( (returnval==1) && (BN_cmp(temp_q, p) < 0) );
  /* repeat until definitely not a generator or q is not less than p */

 cleanup:
  /* BN_free ignores NULL, so this is safe whichever allocations succeeded */
  BN_free(temp_q);
  BN_free(temp_r);
  BN_free(temp_d);

  return(returnval);
}

int primroot(BIGNUM *g, BIGNUM *p, BIGNUM *pminusone, BN_CTX *ctx)
     /* Advances g by one repeatedly until isprimroot reports that it is a
        primitive root modulo the odd prime p (the starting value of g is
        never tested itself).  The search is not bounded by p.
        Returns 1 on success, 0 on error.
        p is not checked to be an odd prime, and pminusone is assumed to be
        p-1 without checking.  This is a simple linear search, not an
        efficient primitive-root algorithm. */
{
  int status;

  for (;;)
    {
      if ( !BN_add_word(g, 1) )
        {
          debug("error incrementing g in primroot");
          return(0);
        }

      status = isprimroot(g, p, pminusone, ctx);
      if ( status != 0 )
        break;            /* found a root, or hit an error */
    }

  if ( status == -1 )
    {
      debug("error testing for primitivity in primroot");
      return(0);
    }

  return(1);
}

int globtree(int glob[3][3][3][3], int i1, int i2, int i3, int i4)
/*
 * Returns glob[i1][i2][i3][i4], computing it on demand.
 *
 * A negative cell means "not calculated yet".  An index equal to 2 means
 * "doesn't matter": such a cell is the sum of the two cells with that index
 * set to 0 and 1.  The first index equal to 2, scanning i1 to i4, is the one
 * expanded.  Results are stored back into glob.  A negative cell with no
 * index equal to 2 means the table was not initialized, which is reported
 * through myerror.
 */
{
  int *cell = &glob[i1][i2][i3][i4];

  if ( *cell >= 0 )
    return *cell;   /* already known */

  if ( i1 == 2 )
    *cell = globtree(glob, 0, i2, i3, i4) + globtree(glob, 1, i2, i3, i4);
  else if ( i2 == 2 )
    *cell = globtree(glob, i1, 0, i3, i4) + globtree(glob, i1, 1, i3, i4);
  else if ( i3 == 2 )
    *cell = globtree(glob, i1, i2, 0, i4) + globtree(glob, i1, i2, 1, i4);
  else if ( i4 == 2 )
    *cell = globtree(glob, i1, i2, i3, 0) + globtree(glob, i1, i2, i3, 1);
  else
    myerror("error in globtree --- array wasn't set up correctly!");

  return *cell;
}

/* Reads one cell of a 4-index count table (e.g. int [3][3][3][3]).
   Every argument is parenthesized so expression arguments expand safely. */
#define globeval(glob, i1, i2, i3, i4) ((glob)[(i1)][(i2)][(i3)][(i4)])

main(int argc, char *argv[])
{
  int firstp, nump;

  int j;
  /* loop counter for number of primes */

  int fxcount[2][2][2][2], tccount[2][2][2][2];
  int fxglob[3][3][3][3], tcglob[3][3][3][3];
  /* fx = number of fixed points, tc = number of two-cycles */
  /* g r.p. to p-1, g p.r. mod p, h r.p. to p-1, h p.r. mod p */
  /* index is 0 if false, 1 if true, 2 if doesn't matter (only used at the end) */

  int gisrp, gispr, hisrp, hispr;
  /* 0 if false, 1 if true */

  BIGNUM *p, *x, *y, *g, *h, *a, *b, *d, *pminusone, *gen;
  /* p is the prime, x is the log of g, y is the log of h */
  /* a=g^h, b=g^g^h */
  /* d holds various gcds */
  /* g is a generator = p.r. mod p */

  int xint, yint, pint;
  /* integer versions of the bignums */

  BN_CTX *ctx;

  char *pstring;

  debug("Initializing MPI");

  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &numnodes);
  MPI_Comm_rank(MPI_COMM_WORLD, &myid);

  if ( argc >= 4 ) numthreads=atoi(argv[3]);
  if ( argc >= 5 ) debug_level=atoi(argv[4]);

  omp_set_num_threads(numthreads);
  
  #pragma omp parallel default(private) shared(argc, argv, numnodes, myid, fxcount, tccount, fxglob, tcglob) copyin(prog, debug_level, bn_maxint, numthreads)
  {


  prog = argv[0];

 #pragma omp master
  if ( myid == 0 ) 
    {
    printf("Program %s running on %d nodes with %d processors each.\n\n", 
	   prog, numnodes, numthreads);
    fflush(stdout);
    }

  debug("Checking usage");
  
  if ( argc < 3 ) myerror("Need two arguments: firstp and nump");

  
  firstp=atoi(argv[1]);
  nump=atoi(argv[2]);
  
  debug("initializing");

  ctx=BN_CTX_new();
  if (ctx==NULL) myerror("error allocating ctx");
  /* temporary structure */

  bn_maxint=BN_new();
  if (bn_maxint==NULL) myerror("error allocating bn_maxint");
  /* the value of maxint as a BIGNUM */
  if ( !BN_set_word(bn_maxint, INT_MAX) ) myerror("error initializing bn_maxint to maxint");
         
  p=BN_new();
  if (p==NULL) myerror("error allocating p");
  /* the prime we are considering */

  g=BN_new();
  if (g==NULL) myerror("error allocating g");
  /* the base of g^h=h or g^g^h=h */

  x=BN_new();
  if (x==NULL) myerror("error allocating x");
  /* the log of g base gen */ 

  h=BN_new();
  if (h==NULL) myerror("error allocating h");
  /* the exponent */

  y=BN_new();
  if (y==NULL) myerror("error allocating y");
  /* log of h base gen */

  a=BN_new();
  if (a==NULL) myerror("error allocating a");
  /* a = g^h */

  b=BN_new();
  if (b==NULL) myerror("error allocating b");
  /* b = g^g^h */
  
  d=BN_new();
  if (d==NULL) myerror("error allocating d");
  /* holds various gcds */

  gen=BN_new();
  if (gen==NULL) myerror("error allocating generator");
  /* a generator = p.r. mod p */

  pminusone=BN_new();
  if (pminusone==NULL) myerror("error allocating pminusone");
  /* p-1 */

  if ( !BN_set_word(p, firstp) ) myerror("error initializing p to firstp");

  for( j=1; j<=nump; j++ )
   {
     
     debug("finding a prime");

     if ( !nextprime(p, ctx) ) myerror("error finding next prime p");
     
     if ( BN_cmp(p, bn_maxint) == 1 ) myerror("p larger than maxint");
       /* p greater than maxint */

     pint = (int)BN_get_word(p);
     if ( pint <= 0 ) myerror("error converting p to int (too large?)");
     
     if ( !BN_sub(pminusone, p, BN_value_one()) ) 
       myerror("error initializing pminusone to p-1");

     /*   	trace_bignum("p-1", pminusone); */
   
     if ( !BN_one(gen) ) myerror("error initializing generator to 1");
     debug("finding a generator using primroot");
     if ( !primroot(gen, p, pminusone, ctx) ) myerror("error in primroot");
     debug("primroot successful");
     
     /* shared resources */
#pragma omp barrier
#pragma omp single
     {
     debug("setting counts to zero");
     for(gisrp=0; gisrp<2; gisrp++)
       for(gispr=0; gispr<2; gispr++)
	 for(hisrp=0; hisrp<2; hisrp++)
	   for(hispr=0; hispr<2; hispr++)
	     {
	       fxcount[gisrp][gispr][hisrp][hispr]=0;
	       tccount[gisrp][gispr][hisrp][hispr]=0;
	     }


     debug("setting global counts to -1 (no value)");
     for(gisrp=0; gisrp<3; gisrp++)
       for(gispr=0; gispr<3; gispr++)
	 for(hisrp=0; hisrp<3; hisrp++)
	   for(hispr=0; hispr<3; hispr++)
	     {
	       fxglob[gisrp][gispr][hisrp][hispr]=-1;
	       tcglob[gisrp][gispr][hisrp][hispr]=-1;
	     }
     } /* end single */
#pragma omp barrier
#pragma omp flush

     debug("starting g and h loops");

     /* distribute this among all nodes */
     	for (xint = myid + 1; xint < pint; xint = xint + numnodes)	
       {
	   if ( !BN_set_word(x, xint) ) myerror("error converting x to bignum");

	   if ( !BN_mod_exp(g, gen, x, p, ctx) )
	     myerror("error assigning g = gen^x mod p");
	   
	   /* distribute this among all threads */
	   #pragma omp for
	   for (yint=1; yint < pint; yint++)	
	   {
	        if ( !BN_set_word(y, yint) ) myerror("error converting y to bignum");
  
		if ( !BN_mod_exp(h, gen, y, p, ctx) )
		  myerror("error assigning h = gen^y mod p");

	   trace_bignum("x", x);
	   trace_bignum("g", g);
	   trace_bignum("y", y);
	   trace_bignum("h", h);

	   if ( !BN_gcd(d, g, pminusone, ctx) ) 
	       myerror("error assigning d =  gcd(g, p-1)"); 
	   /* trace_bignum("gcd(g, p-1)", d); */
	   gisrp = BN_is_one(d);

	   if ( !BN_gcd(d, h, pminusone, ctx) ) 
	       myerror("error assigning d =  gcd(h, p-1)"); 
	   /* trace_bignum("gcd(h, p-1)", d); */
	   hisrp = BN_is_one(d);

	   if ( !BN_gcd(d, x, pminusone, ctx) ) 
	       myerror("error assigning d =  gcd(x, p-1)"); 
	   /* trace_bignum("gcd(x, p-1)", d); */
	   gispr = BN_is_one(d);

	   if ( !BN_gcd(d, y, pminusone, ctx) ) 
	       myerror("error assigning d =  gcd(y, p-1)"); 
	   /* trace_bignum("gcd(y, p-1)", d); */
	   hispr = BN_is_one(d);


          
	   if ( !BN_mod_exp(a, g, h, p, ctx) ) 
	       myerror("error assigning a = g^h mod p");
	   /* trace_bignum("g^h mod p", a); */
	   
	     if ( !BN_mod_exp(b, g, a, p, ctx) ) 
	       myerror("error assigning b = g^a mod p");
	     /*  trace_bignum("g^g^h mod p mod p", b); */
	
	     if ( BN_cmp(a, h) == 0 )
	       /* if (a == h) */ 
	       {
		 fxcount[gisrp][gispr][hisrp][hispr]++;
		 /* debug("g^h=h"); */
                 #pragma omp flush 
	       }
	      
	     if ( BN_cmp(b, h) == 0 )
	       /* if (b == h) */ 
	       {
		 tccount[gisrp][gispr][hisrp][hispr]++;
		 /* debug("g^g^h=h"); */
                 #pragma omp flush	       		
	       }

	   } /* end yint loop */
       } /* end xint loop */
     
/* gather the data, only on master for each node, collect at root node */

#pragma omp barrier
#pragma omp master
	{
	  for(gisrp=0; gisrp<2; gisrp++)
	    for(gispr=0; gispr<2; gispr++)
	      for(hisrp=0; hisrp<2; hisrp++)
		for(hispr=0; hispr<2; hispr++)
		  {  
		    MPI_Reduce(&(fxcount[gisrp][gispr][hisrp][hispr]),
			       &(fxglob[gisrp][gispr][hisrp][hispr]),
			       1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
		    MPI_Reduce(&(tccount[gisrp][gispr][hisrp][hispr]),
			       &(tcglob[gisrp][gispr][hisrp][hispr]),
			       1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
		  }
	}

#pragma omp flush

     /* report data, only on master */
     #pragma omp master
     if (myid == 0)
       {
	 pstring=BN_bn2dec(p);
	 printf("prime = %s\n", pstring);
	 OPENSSL_free(pstring);

	 /* the better way to report these? */

	 /* rearrange the data */

	 for(gisrp=0; gisrp<2; gisrp++)
	   for(gispr=0; gispr<2; gispr++)
	     for(hisrp=0; hisrp<2; hisrp++)
	       {
		 fxglob[gisrp][gispr][hisrp][2] = 
		   fxglob[gisrp][gispr][hisrp][0] + 
		   fxglob[gisrp][gispr][hisrp][1];
	       }
         for(gisrp=0; gisrp<2; gisrp++)
           for(gispr=0; gispr<2; gispr++)
             for(hispr=1; hispr<3; hispr++)
               {
                 fxglob[gisrp][gispr][2][hispr] = 
                   fxglob[gisrp][gispr][0][hispr] +
                   fxglob[gisrp][gispr][1][hispr];
               }
         for(gisrp=0; gisrp<2; gisrp++)
             for(hisrp=1; hisrp<3; hisrp++)
	       for(hispr=1; hispr<3; hispr++)
               {
                 fxglob[gisrp][2][hisrp][hispr] = 
                   fxglob[gisrp][0][hisrp][hispr] +
                   fxglob[gisrp][1][hisrp][hispr];
               }
         for(gispr=1; gispr<3; gispr++)
	   for(hisrp=1; hisrp<3; hisrp++)
	     for(hispr=1; hispr<3; hispr++)
               {
                 fxglob[2][gispr][hisrp][hispr] =
                   fxglob[0][gispr][hisrp][hispr] +
                   fxglob[1][gispr][hisrp][hispr];
               }

	 for(gisrp=0; gisrp<2; gisrp++)
	   for(gispr=0; gispr<2; gispr++)
	     for(hisrp=0; hisrp<2; hisrp++)
	       {
		 tcglob[gisrp][gispr][hisrp][2] = 
		   tcglob[gisrp][gispr][hisrp][0] + 
		   tcglob[gisrp][gispr][hisrp][1];
	       }
         for(gisrp=0; gisrp<2; gisrp++)
           for(gispr=0; gispr<2; gispr++)
             for(hispr=1; hispr<3; hispr++)
               {
                 tcglob[gisrp][gispr][2][hispr] = 
                   tcglob[gisrp][gispr][0][hispr] +
                   tcglob[gisrp][gispr][1][hispr];
               }
         for(gisrp=0; gisrp<2; gisrp++)
             for(hisrp=1; hisrp<3; hisrp++)
	       for(hispr=1; hispr<3; hispr++)
               {
                 tcglob[gisrp][2][hisrp][hispr] = 
                   tcglob[gisrp][0][hisrp][hispr] +
                   tcglob[gisrp][1][hisrp][hispr];
               }
         for(gispr=1; gispr<3; gispr++)
	   for(hisrp=1; hisrp<3; hisrp++)
	     for(hispr=1; hispr<3; hispr++)
               {
                 tcglob[2][gispr][hisrp][hispr] =
                   tcglob[0][gispr][hisrp][hispr] +
                   tcglob[1][gispr][hisrp][hispr];
               }


	 printf("Fixed points\n");
	 printf("g\\h\t h ANY\t h PR\t h RP\t h RPPR\n");
	 printf("g ANY\t %d\t %d\t %d\t %d\n",
		globeval(fxglob,2,2,2,2),
		globeval(fxglob,2,2,2,1),
		globeval(fxglob,2,2,1,2),
		globeval(fxglob,2,2,1,1));
	 printf("g PR\t %d\t %d\t %d\t %d\n",
                globeval(fxglob,2,1,2,2),
                globeval(fxglob,2,1,2,1),
                globeval(fxglob,2,1,1,2),
                globeval(fxglob,2,1,1,1));
	 printf("g RP\t %d\t %d\t %d\t %d\n",
                globeval(fxglob,1,2,2,2),
                globeval(fxglob,1,2,2,1),
                globeval(fxglob,1,2,1,2),
                globeval(fxglob,1,2,1,1));
	 printf("g RPPR\t %d\t %d\t %d\t %d\n",
                globeval(fxglob,1,1,2,2),
                globeval(fxglob,1,1,2,1),
                globeval(fxglob,1,1,1,2),
                globeval(fxglob,1,1,1,1));
	 printf("\n");


	 printf("Two-cycles\n");
	 printf("g\\h\t h ANY\t h PR\t h RP\t h RPPR\n");
	 printf("g ANY\t %d\t %d\t %d\t %d\n",
		globeval(tcglob,2,2,2,2),
		globeval(tcglob,2,2,2,1),
		globeval(tcglob,2,2,1,2),
		globeval(tcglob,2,2,1,1));
	 printf("g PR\t %d\t %d\t %d\t %d\n",
                globeval(tcglob,2,1,2,2),
                globeval(tcglob,2,1,2,1),
                globeval(tcglob,2,1,1,2),
                globeval(tcglob,2,1,1,1));
	 printf("g RP\t %d\t %d\t %d\t %d\n",
                globeval(tcglob,1,2,2,2),
                globeval(tcglob,1,2,2,1),
                globeval(tcglob,1,2,1,2),
                globeval(tcglob,1,2,1,1));
	 printf("g RPPR\t %d\t %d\t %d\t %d\n",
                globeval(tcglob,1,1,2,2),
                globeval(tcglob,1,1,2,1),
                globeval(tcglob,1,1,1,2),
                globeval(tcglob,1,1,1,1));
	 printf("\n");
	 
	 fflush(stdout);

       } /* end master processor only */

#pragma omp barrier

   } /* end j loop */
  }  /* end parallel structure */

     MPI_Finalize();

} /* end main */