/* don't use this until barriers fixed! */

# include <stdio.h>
# include <stdlib.h>
# include <limits.h>
# include <openssl/bn.h>
# include <openssl/crypto.h>
# include "mpi.h"
# include <omp.h>
# include "bn_prime.h"


/* program name; main sets it from argv[0] and myerror prefixes messages with it */
char *prog;

/* verbosity: >= 1 enables debug() messages, >= 2 enables trace_bignum()/trace_int(),
   >= 3 additionally makes debug() wait for the return key after each message */
int debug_level = 0;
/* number of MPI processes; main fills it in with MPI_Comm_size */
int numnodes = 1;
/* rank of this MPI process; main fills it in with MPI_Comm_rank */
int myid = 0;
/* number of OpenMP threads; main reads it from argv[3] and passes it to omp_set_num_threads */
int numthreads = 1;

/* INT_MAX as a BIGNUM; main compares each prime against it before converting to int */
BIGNUM *bn_maxint;

/* every OpenMP thread gets its own copy of these four variables;
   numnodes and myid are not listed here and therefore stay shared */
#pragma omp threadprivate (prog, debug_level, bn_maxint, numthreads)


void myerror(char *errstring)
{
  fprintf(stderr, "%s, processor %d%c (of %d*%d): %s\n", prog, 
	  myid, omp_get_thread_num()+'A', 
	  numnodes, omp_get_num_threads(),
	  errstring);
#pragma omp master
  MPI_Finalize();

  exit(1);
}

/*
 * Print a progress message to stderr when debug_level is at least 1.
 * The message is tagged with the MPI rank, the OpenMP thread letter
 * ('A' for thread 0) and the node and thread counts.
 * When debug_level is at least 3 the caller is additionally paused until
 * a character is read from stdin.
 */
void debug(char *message)
{
	int thread_letter;
	int team_size;

	/* nothing at all happens below level 1 */
	if ( debug_level < 1 )
		return;

	thread_letter = omp_get_thread_num() + 'A';
	team_size = omp_get_num_threads();

	fprintf(stderr, "processor %d%c (of %d*%d): %s.\n",
		myid, thread_letter, numnodes, team_size, message);

	/* single-stepping is reserved for level 3 and above */
	if ( debug_level < 3 )
		return;

	fprintf(stderr, "Press return to continue\n");
	getchar();
}

/*
 * Print "<string> = <decimal value of bn>" to stderr when debug_level is
 * at least 2, tagged with the MPI rank, the OpenMP thread letter and the
 * node and thread counts.  Does nothing at lower debug levels.
 *
 * BN_bn2dec allocates the decimal string and can fail, returning NULL;
 * in that case a placeholder is printed instead of passing NULL to %s.
 */
void trace_bignum(char *string, BIGNUM *bn)
{
  char *bnstring;

  if ( debug_level < 2 )
    return;

  bnstring = BN_bn2dec(bn);
  if ( bnstring == NULL )
    {
      fprintf(stderr, "processor %d%c (of %d*%d): %s = (error converting bignum to decimal)\n",
	      myid, omp_get_thread_num()+'A',
	      numnodes, omp_get_num_threads(), string);
      return;
    }

  fprintf(stderr, "processor %d%c (of %d*%d): %s = %s\n",
	  myid, omp_get_thread_num()+'A',
	  numnodes, omp_get_num_threads(), string, bnstring);

  /* the string came from OpenSSL's allocator, so release it the same way */
  OPENSSL_free(bnstring);
}

/*
 * Print "<string> = <i>" to stderr when debug_level is at least 2,
 * tagged with the MPI rank, the OpenMP thread letter ('A' for thread 0)
 * and the node and thread counts.  Does nothing at lower debug levels.
 */
void trace_int(char *string, int i)
{
  int thread_letter;
  int team_size;

  if ( debug_level < 2 )
    return;

  thread_letter = omp_get_thread_num() + 'A';
  team_size = omp_get_num_threads();

  fprintf(stderr, "processor %d%c (of %d*%d): %s = %d\n",
	  myid, thread_letter, numnodes, team_size, string, i);
}

/*
 * Primality test for p, adapted from BN_is_prime_fasttest and using the
 * same return values as that function.
 *
 * Small factors are looked for first: even numbers are settled on the
 * spot, then p is reduced modulo primes[1] .. primes[NUMPRIMES-1].
 * Index 0 is skipped -- presumably primes[0] is 2, which the parity test
 * has already covered (confirm against bn_prime.h).  Whatever survives is
 * handed to BN_is_prime_fasttest with its own trial division switched off.
 */
int isprime(BIGNUM *p, BN_CTX *ctx)
{
  int idx;

  /* an even number is prime only if it is 2 itself */
  if ( !BN_is_odd(p) )
    return BN_is_word(p, 2) ? 1 : 0;

  for (idx = 1; idx < NUMPRIMES; idx++)
    {
      if ( BN_mod_word(p, primes[idx]) != 0 )
	continue;

      /* primes[idx] divides p: p is prime only if it is that very prime */
      return BN_is_word(p, primes[idx]) ? 1 : 0;
    }

  /* no small factor found; the last argument disables trial division */
  return BN_is_prime_fasttest(p, BN_prime_checks, NULL, ctx, NULL, 0);
}


/*
 * Step p, in place, to the next larger value accepted by isprime().
 * Returns 1 on success and 0 on error (the reason is reported via debug()).
 *
 * An even p is first decremented so that the candidates p+2, p+4, ...
 * are all odd; consequently the result is always odd and 2 itself is
 * never produced, whatever the starting value.
 */
int nextprime(BIGNUM *p, BN_CTX *ctx)
{
  int verdict;

  /* make p odd so that stepping by 2 visits every odd number above it */
  if ( BN_mod_word(p, 2) == 0 && !BN_sub_word(p, 1) )
    {
      debug("error decrementing p in nextprime");
      return 0;
    }

  for (;;)
    {
      if ( !BN_add_word(p, 2) )
	{
	  debug("error incrementing p in nextprime");
	  return 0;
	}

      verdict = isprime(p, ctx);

      /* anything other than "composite" ends the search */
      if ( verdict != 0 )
	break;
    }

  if ( verdict == -1 )
    {
      debug("error testing for primality in nextprime");
      return 0;
    }

  return 1;
}

/*
 * Check whether g is a primitive root modulo the odd prime p.
 *
 * Returns 1 if it is, 0 if it is not, and -1 on error (the reason is
 * reported via debug()).
 *
 * p is not checked to be an odd prime, and pminusone is trusted to be
 * p-1.  None of the arguments is modified.
 *
 * Method: g is rejected if it is zero mod p.  Otherwise q runs through
 * 2 and then the values produced by nextprime() while q < p; for every q
 * that divides p-1, g is rejected if g^((p-1)/q) mod p is 1.
 *
 * All temporaries are released on every exit path, including errors.
 */
int isprimroot(BIGNUM *g, BIGNUM *p, BIGNUM *pminusone, BN_CTX *ctx)
{
  /* NULL-initialized so the cleanup code can free them unconditionally */
  BIGNUM *temp_q = NULL, *temp_r = NULL, *temp_d = NULL;

  int returnval = -1; /* assume failure until the setup has succeeded */

  temp_r = BN_new();
  if ( temp_r == NULL )
    {
      debug("error allocating temporary remainder in isprimroot");
      goto cleanup;
    }
  temp_q = BN_new();
  if ( temp_q == NULL )
    {
      debug("error allocating temporary prime in isprimroot");
      goto cleanup;
    }
  temp_d = BN_new();
  if ( temp_d == NULL )
    {
      debug("error allocating temporary quotient in isprimroot");
      goto cleanup;
    }

  if ( !BN_mod(temp_r, g, p, ctx) )
    {
      debug("error reducing g mod p in isprimroot");
      goto cleanup;
    }

  /* zero mod p is never a primitive root */
  if ( BN_is_zero(temp_r) )
    {
      returnval = 0;
      goto cleanup;
    }

  if ( !BN_set_word(temp_q, 2) )
    {
      debug("error initializing temporary prime in isprimroot");
      goto cleanup;
    }

  returnval = 1; /* from here on, assume we have a generator */

  do
    {
      /* temp_d = (p-1) div q, temp_r = (p-1) mod q */
      if ( !BN_div(temp_d, temp_r, pminusone, temp_q, ctx) )
	{
	  debug("error dividing p-1 by q in isprimroot");
	  returnval = -1;
	  goto cleanup;
	}

      if ( BN_is_zero(temp_r) )
	{
	  /* q divides p-1; temp_r is reused for g^((p-1)/q) mod p */
	  if ( !BN_mod_exp(temp_r, g, temp_d, p, ctx) )
	    {
	      debug("error assigning r = g^(p-1)/q mod p in isprimroot");
	      returnval = -1;
	      goto cleanup;
	    }

	  /* if we got one as a small power of g, then g is not a primitive root */
	  if ( BN_is_one(temp_r) )
	    {
	      returnval = 0;
	      break; /* the answer is known; no need to look for another q */
	    }
	}

      if ( !nextprime(temp_q, ctx) )
	{
	  debug("error finding next q in isprimroot");
	  returnval = -1;
	  goto cleanup;
	}
    }
  while ( BN_cmp(temp_q, p) == -1 );
  /* repeat until definitely not a generator or q is not less than p */

 cleanup:
  /* BN_free accepts NULL, so partially completed setup is handled too */
  BN_free(temp_q);
  BN_free(temp_r);
  BN_free(temp_d);

  return(returnval);
}

/*
 * Advance g, in place, to the next larger value that isprimroot()
 * accepts as a primitive root modulo the odd prime p.
 * Returns 1 on success and 0 on error (the reason is reported via debug()).
 *
 * g is incremented before the first test, so the starting value itself
 * is never returned, and the search does not stop when g exceeds p.
 * p is not checked to be an odd prime and pminusone is trusted to be p-1.
 * This is a plain linear search, probably not the most efficient way to
 * find a primitive root.
 */
int primroot(BIGNUM *g, BIGNUM *p, BIGNUM *pminusone, BN_CTX *ctx)
{
  int status;

  for (;;)
    {
      if ( !BN_add_word(g, 1) )
	{
	  debug("error incrementing g in primroot");
	  return 0;
	}

      status = isprimroot(g, p, pminusone, ctx);

      /* stop on a primitive root (1) as well as on an error (-1) */
      if ( status != 0 )
	break;
    }

  if ( status == -1 )
    {
      debug("error testing for primitivity in primroot");
      return 0;
    }

  return 1;
}



/*
 * Usage: prog firstp nump [numthreads [debug_level]]
 *
 * For each of the nump primes p produced by repeatedly applying
 * nextprime() starting from firstp, counts the pairs (g, h) with
 * 1 <= g, h < p for which
 *     g^h mod p == h              ("fixed points"), and
 *     g^(g^h mod p) mod p == h    ("two-cycles", fixed points included),
 * each split by whether gcd(h, p-1) is 1.  Rank 0 prints the totals.
 *
 * Work distribution: the values of g are dealt round-robin to the MPI
 * ranks, and within a rank the values of h are shared among the OpenMP
 * threads.  Prime stepping, the MPI reductions and the printing are done
 * by the initial thread only; only the g/h loop runs in a parallel
 * region.  Each thread accumulates into its own copy of the counters
 * (OpenMP reduction), so there are no concurrent updates of shared
 * counters and no explicit barriers or flushes are needed.
 *
 * Returns 0 on success; every error goes through myerror(), which
 * terminates the program.
 */
int main(int argc, char *argv[])
{
  int firstp, nump;

  int pint;

  int j;

  int  rpfxcount, nrpfxcount, rptccount, nrptccount;
  int  rpfxglob, nrpfxglob, rptcglob, nrptcglob;

  BIGNUM *p, *pminusone;

  BN_CTX *ctx;

  char *pstring;

  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &numnodes);
  MPI_Comm_rank(MPI_COMM_WORLD, &myid);

  /* set first, so that every later error message carries the program name */
  prog = argv[0];

  if ( argc >= 4 ) numthreads=atoi(argv[3]);
  if ( argc >= 5 ) debug_level=atoi(argv[4]);

  /* omp_set_num_threads requires a positive value; atoi yields 0 on garbage */
  if ( numthreads < 1 ) myerror("number of threads must be at least 1");

  omp_set_num_threads(numthreads);

  if ( myid == 0 )
    printf("Program %s running on %d nodes with %d processors each.\n\n",
	   prog, numnodes, numthreads);

  debug("Checking usage");

  if ( argc < 3 ) myerror("Need two arguments: firstp and nump");

  firstp=atoi(argv[1]);
  nump=atoi(argv[2]);

  debug("initializing");

  rpfxglob=0;
  nrpfxglob=0;
  rptcglob=0;
  nrptcglob=0;

  ctx=BN_CTX_new();
  if (ctx==NULL) myerror("error allocating ctx");
  /* temporary structure */

  bn_maxint=BN_new();
  if (bn_maxint==NULL) myerror("error allocating bn_maxint");
  /* the value of maxint as a BIGNUM */
  if ( !BN_set_word(bn_maxint, INT_MAX) ) myerror("error initializing bn_maxint to maxint");

  p=BN_new();
  if (p==NULL) myerror("error allocating p");
  /* the prime we are considering */

  pminusone=BN_new();
  if (pminusone==NULL) myerror("error allocating pminusone");
  /* p-1 */

  if ( !BN_set_word(p, firstp) ) myerror("error initializing p to firstp");

  for( j=1; j<=nump; j++ )
    {
      debug("finding a prime");

      if ( !nextprime(p, ctx) ) myerror("error finding next prime p");

      if ( BN_cmp(p, bn_maxint) == 1 ) myerror("p larger than maxint");
      /* p greater than maxint */

      pint = (int)BN_get_word(p);
      if ( pint <= 0 ) myerror("error converting p to int (too large?)");

      if ( !BN_sub(pminusone, p, BN_value_one()) )
	myerror("error initializing pminusone to p-1");

      rpfxcount=0;
      /* number of fixed points with h r.p. to p-1 */
      nrpfxcount=0;
      /* number of fixed points with h not r.p. to p-1 */
      rptccount=0;
      /* number of two-cycles (including f.p.s) with h r.p. to p-1 */
      nrptccount=0;
      /* number of two-cycles (including f.p.s) with h not r.p. to p-1 */

      debug("starting loop");

      /* copyin gives every thread the initial thread's prog and debug_level,
	 which debug(), trace_bignum() and myerror() read */
#pragma omp parallel copyin(prog, debug_level) shared(p, pminusone, pint) \
        reduction(+:rpfxcount, nrpfxcount, rptccount, nrptccount)
      {
	int gint, hint;

	/* every thread works on its own BIGNUMs and BN_CTX, including its
	   own copies of p and p-1, so no OpenSSL object is shared */
	BN_CTX *tctx = BN_CTX_new();
	BIGNUM *tp = BN_dup(p);
	BIGNUM *tpminusone = BN_dup(pminusone);
	BIGNUM *g = BN_new();	/* the base of g^h=h or g^g^h=h */
	BIGNUM *h = BN_new();	/* the exponent */
	BIGNUM *a = BN_new();	/* a = g^h */
	BIGNUM *b = BN_new();	/* b = g^g^h */
	BIGNUM *d = BN_new();	/* the gcd of h and p-1 */

	if (tctx==NULL) myerror("error allocating ctx");
	if (tp==NULL) myerror("error allocating p");
	if (tpminusone==NULL) myerror("error allocating pminusone");
	if (g==NULL) myerror("error allocating g");
	if (h==NULL) myerror("error allocating h");
	if (a==NULL) myerror("error allocating a");
	if (b==NULL) myerror("error allocating b");
	if (d==NULL) myerror("error allocating d");

	/* g is dealt round-robin to the MPI ranks; every thread of a rank
	   walks the same g values so that all of them reach each "omp for" */
	for (gint = myid + 1; gint < pint; gint = gint + numnodes)
	  {
	    if ( !BN_set_word(g, gint) ) myerror("error converting g to bignum");

	    /* h is shared out among the threads of this rank */
#pragma omp for
	    for (hint=1; hint < pint; hint++)
	      {
		if ( !BN_set_word(h, hint) ) myerror("error converting h to bignum");

		trace_bignum("g", g);
		trace_bignum("h", h);

		if ( !BN_gcd(d, h, tpminusone, tctx) )
		  myerror("error assigning d =  gcd(h, p-1)");

		if ( !BN_mod_exp(a, g, h, tp, tctx) )
		  myerror("error assigning a = g^h mod p");

		if ( !BN_mod_exp(b, g, a, tp, tctx) )
		  myerror("error assigning b = g^a mod p");

		if ( BN_cmp(a, h) == 0 )
		  /* if (a == h) */
		  {
		    if ( BN_is_one(d) )
		      rpfxcount++;	/* g^h=h, r.p. */
		    else
		      nrpfxcount++;	/* g^h=h, n.r.p. */
		  }

		if ( BN_cmp(b, h) == 0 )
		  /* if (b == h) */
		  {
		    if ( BN_is_one(d) )
		      rptccount++;	/* g^g^h=h, r.p. */
		    else
		      nrptccount++;	/* g^g^h=h, n.r.p. */
		  }
	      } /* end hint loop */
	  } /* end gint loop */

	BN_free(d);
	BN_free(b);
	BN_free(a);
	BN_free(h);
	BN_free(g);
	BN_free(tpminusone);
	BN_free(tp);
	BN_CTX_free(tctx);
      } /* end parallel region; the counters now hold this rank's totals */

      /* gather the data on rank 0 */
      MPI_Reduce(&rpfxcount, &rpfxglob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
      MPI_Reduce(&nrpfxcount, &nrpfxglob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
      MPI_Reduce(&rptccount, &rptcglob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
      MPI_Reduce(&nrptccount, &nrptcglob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);

      /* report data, only on rank 0 */
      if (myid == 0)
	{
	  pstring=BN_bn2dec(p);
	  if (pstring==NULL) myerror("error converting p to decimal string");
	  printf("prime = %s\n", pstring);
	  OPENSSL_free(pstring);

	  printf("number of fixed points with h an r.p. = %d\n", rpfxglob);
	  printf("number of fixed points total = %d\n", rpfxglob+nrpfxglob);
	  printf("\n");

	  printf("number of two-cycles with h an r.p. = %d\n", rptcglob);
	  printf("number of two-cycles total = %d\n", rptcglob+nrptcglob);
	  printf("\n");
	}

    } /* end j loop */

  BN_free(pminusone);
  BN_free(p);
  BN_free(bn_maxint);
  BN_CTX_free(ctx);

  MPI_Finalize();

  return 0;
} /* end main */