# include <stdio.h>
# include <stdlib.h>
# include <limits.h>
# include <openssl/bn.h>
# include <openssl/crypto.h>
# include "mpi.h"
# include <omp.h>

/* Program name, set from argv[0] in main; prefixes the messages printed by myerror. */
char *prog;

/* Verbosity: >= 1 prints debug() messages, >= 2 also prints trace_bignum()
   values, >= 3 also waits for return after each debug() message. */
int debug_level = 0;
/* Number of MPI processes in MPI_COMM_WORLD (filled in by main). */
int numnodes = 1;
/* MPI rank of this process (filled in by main). */
int myid = 0;
/* OpenMP threads requested per MPI process (optional argv[3]). */
int numthreads = 1;

/* INT_MAX as a BIGNUM, allocated in main; primes above it are rejected. */
BIGNUM *bn_maxint;

/* Every OpenMP thread has its own copy of these four; a parallel region must
   name them in copyin for the threads to see the master thread's values.
   numnodes and myid are not listed, so all threads share them. */
#pragma omp threadprivate (prog, debug_level, bn_maxint, numthreads)


void myerror(char *errstring)
{
  fprintf(stderr, "%s, processor %d%c (of %d*%d): %s\n", prog, 
	  myid, omp_get_thread_num()+'A', 
	  numnodes, omp_get_num_threads(),
	  errstring);
#pragma omp master
  MPI_Finalize();

  exit(1);
}

/*
 * Progress reporting. When debug_level is 1 or more, write message to stderr
 * tagged with the MPI rank, the OpenMP thread (as a letter A, B, ...) and the
 * node/thread counts. When debug_level is 3 or more, additionally block until
 * the user presses return.
 */
void debug(char *message)
{
  const int verbose = (debug_level >= 1);
  const int interactive = (debug_level >= 3);

  if (verbose)
    {
      int thread_letter = omp_get_thread_num() + 'A';

      fprintf(stderr, "processor %d%c (of %d*%d): %s.\n",
              myid, thread_letter, numnodes, omp_get_num_threads(), message);
    }

  if (!interactive)
    return;

  fprintf(stderr, "Press return to continue\n");
  getchar();
}

/*
 * At debug_level 2 or above, print "string = <decimal value of bn>" to stderr,
 * tagged with the MPI rank and OpenMP thread. Does nothing at lower levels.
 */
void trace_bignum(char *string, BIGNUM *bn)
{
  char *bnstring;

  if (debug_level < 2)
    return;

  /* BN_bn2dec allocates the string and returns NULL on failure; passing
     NULL to %s is undefined behaviour, so substitute a placeholder. */
  bnstring = BN_bn2dec(bn);
  fprintf(stderr, "processor %d%c (of %d*%d): %s = %s\n",
          myid, omp_get_thread_num() + 'A',
          numnodes, omp_get_num_threads(), string,
          bnstring != NULL ? bnstring : "(BN_bn2dec failed)");

  if (bnstring != NULL)
    OPENSSL_free(bnstring);
}


/*
 * Advance p in place to the next odd prime and return p.
 *
 * An even p is first decremented; p is then increased by 2 until it tests
 * prime. The result is therefore the smallest odd prime strictly greater
 * than the incoming p. In particular 2 is never produced (p = 0, 1 or 2 all
 * yield 3). ctx is scratch space for the primality test. Any OpenSSL failure
 * is fatal (reported through myerror).
 */
BIGNUM *nextprime(BIGNUM *p, BN_CTX *ctx)
{
  int test;
  BN_ULONG parity;

  parity = BN_mod_word(p, 2);
  /* BN_mod_word reports failure by returning (BN_ULONG)-1. */
  if ( parity == (BN_ULONG)-1 )
    myerror("error computing p mod 2");
  if ( parity == 0 && !BN_sub_word(p, 1) )
    myerror("error decrementing p");

  do
    {
      if ( !BN_add_word(p, 2) ) myerror("error incrementing p");
      trace_bignum("p", p);
      /* 1 = probably prime, 0 = composite, -1 = error; same contract as
         the deprecated BN_is_prime interface it replaces. */
      test = BN_is_prime_ex(p, BN_prime_checks, ctx, NULL);
    }
  while ( test == 0 );

  if ( test == -1 )
    {
      myerror("error testing for the primality of p");
      return NULL;   /* not reached: myerror terminates the program */
    }

  return p;
}

main(int argc, char *argv[])
{
  int firstp, nump;

  int gint, hint, pint;
  
  int j;

  int  rpfxcount, nrpfxcount, rptccount, nrptccount;
  int  rpfxglob, nrpfxglob, rptcglob, nrptcglob;
  
  BIGNUM *p, *g, *h, *a, *b, *d, *pminusone;

  BN_CTX *ctx;

  char *pstring;

  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &numnodes);
  MPI_Comm_rank(MPI_COMM_WORLD, &myid);

  if ( argc >= 4 ) numthreads=atoi(argv[3]);
  if ( argc >= 5 ) debug_level=atoi(argv[4]);

  omp_set_num_threads(numthreads);
  
  #pragma omp parallel default(private) shared(argc, argv, numnodes, myid, rpfxcount, nrpfxcount, rptccount, nrptccount,  rpfxglob, nrpfxglob, rptcglob, nrptcglob) copyin(prog, debug_level, bn_maxint, numthreads)
  {


  prog = argv[0];

 #pragma omp master
  if ( myid == 0 ) 
    printf("Program %s running on %d nodes with %d processors each.\n\n", 
	   prog, numnodes, numthreads);



  debug("Checking usage");
  
  if ( argc < 3 ) myerror("Need two arguments: firstp and nump");

  
  firstp=atoi(argv[1]);
  nump=atoi(argv[2]);
  
  debug("initializing");

  rpfxglob=0;
  nrpfxglob=0;
  rptcglob=0;
  nrptcglob=0;


  ctx=BN_CTX_new();
  if (ctx==NULL) myerror("error allocating ctx");
  /* temporary structure */

  bn_maxint=BN_new();
  if (bn_maxint==NULL) myerror("error allocating bn_maxint");
  /* the value of maxint as a BIGNUM */
  if ( !BN_set_word(bn_maxint, INT_MAX) ) myerror("error initializing bn_maxint to maxint");
         
  p=BN_new();
  if (p==NULL) myerror("error allocating p");
  /* the prime we are considering */

  g=BN_new();
  if (g==NULL) myerror("error allocating g");
  /* the base of g^h=h or g^g^h=h */

  h=BN_new();
  if (h==NULL) myerror("error allocating h");
  /* the exponent */

  a=BN_new();
  if (a==NULL) myerror("error allocating a");
  /* a = g^h */

  b=BN_new();
  if (b==NULL) myerror("error allocating b");
  /* b = g^g^h */
  
  d=BN_new();
  if (d==NULL) myerror("error allocating d");
  /* the gcd of h and p-1 */

  pminusone=BN_new();
  if (pminusone==NULL) myerror("error allocating pminusone");
  /* p-1 */

  if ( !BN_set_word(p, firstp) ) myerror("error initializing p to firstp");

  for( j=1; j<=nump; j++ )
   {
     
     debug("finding a prime");

     p=nextprime(p, ctx);
     
     if ( BN_cmp(p, bn_maxint) == 1 ) myerror("p larger than maxint");
       /* p greater than maxint */

     pint = (int)BN_get_word(p);
     if ( pint <= 0 ) myerror("error converting p to int (too large?)");
     
     if ( !BN_sub(pminusone, p, BN_value_one()) ) 
       myerror("error initializing pminusone to p-1");

     /*   	trace_bignum("p-1", pminusone); */
   
     /* shared resources */
#pragma omp barrier
#pragma omp single
     {
     rpfxcount=0;
     /* number of fixed points with h r.p. to p-1 */
     nrpfxcount=0;
     /* number of fixed points with h not r.p. to p-1 */
     rptccount=0;
     /* number of two-cycles (including f.p.s) with h r.p. to p-1 */
     nrptccount=0;
     /* number of two-cycles (including f.p.s) with h not r.p. to p-1 */
     } /* end single */
#pragma omp barrier
#pragma omp flush


     debug("starting loop");

     /* distribute this among all cpus */
     	for (gint = myid + 1; gint < pint; gint = gint + numnodes)	
       {
	   if ( !BN_set_word(g, gint) ) myerror("error converting g to bignum");
	   #pragma omp for
	   for (hint=1; hint < pint; hint++)	
	   {
	        if ( !BN_set_word(h, hint) ) myerror("error converting h to bignum");
  
	   trace_bignum("g", g);
	   trace_bignum("h", h);
	   if ( !BN_gcd(d, h, pminusone, ctx) ) 
	       myerror("error assigning d =  gcd(h, p-1)"); 
	   /* trace_bignum("gcd(h, p-1)", d); */
	   
	   if ( !BN_mod_exp(a, g, h, p, ctx) ) 
	       myerror("error assigning a = g^h mod p");
	   /* trace_bignum("g^h mod p", a); */
	   
	     if ( !BN_mod_exp(b, g, a, p, ctx) ) 
	       myerror("error assigning b = g^a mod p");
	     /*  trace_bignum("g^g^h mod p mod p", b); */
	
	     if ( BN_cmp(a, h) == 0 )
	       /* if (a == h) */ 
	       {
		 if ( BN_cmp(d, BN_value_one()) == 0 )
		   /* if (d == 1) */
	       	   {	
		     /*	           debug("g^h=h, r.p."); */
		   rpfxcount = rpfxcount+1;
                   #pragma omp flush 
	       	   }
	       	  else
	      	   {
		     /*    debug("g^h=h, n.r.p."); */
		     nrpfxcount = nrpfxcount+1;
                     #pragma omp flush
}
	       }	     

	     if ( BN_cmp(b, h) == 0 )
	       /* if (b == h) */ 
	       {
		 if ( BN_cmp(d, BN_value_one()) == 0 )
                   /* if (d == 1) */
	       		{	
			  /* debug("g^g^h=h, r.p."); */
			  rptccount = rptccount+1;
                          #pragma omp flush	       		
			}
	       	 else
		     {
		       /* debug("g^g^h=h, n.r.p."); */
		       nrptccount = nrptccount+1;
                       #pragma omp flush
		     }
		}
	   } /* end hint loop */
       } /* end gint loop */
     
     /* gather the data, only on master */

#pragma omp barrier
#pragma omp master
	{
	  MPI_Reduce(&rpfxcount, &rpfxglob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
	  MPI_Reduce(&nrpfxcount, &nrpfxglob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
	  MPI_Reduce(&rptccount, &rptcglob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
	  MPI_Reduce(&nrptccount, &nrptcglob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
	}

#pragma omp flush



     /* report data, only on master */
     #pragma omp master
     if (myid == 0)
       {
	 pstring=BN_bn2dec(p);
	 printf("prime = %s\n", pstring);
	 OPENSSL_free(pstring);

	 printf("number of fixed points with h an r.p. = %d\n", rpfxglob);
	 printf("number of fixed points total = %d\n", rpfxglob+nrpfxglob);
	 printf("\n");

	 printf("number of two-cycles with h an r.p. = %d\n", rptcglob);
	 printf("number of two-cycles total = %d\n", rptcglob+nrptcglob);
	 printf("\n");
       } /* end master processor only */

#pragma omp barrier

   } /* end j loop */
  }  /* end parallel structure */

     MPI_Finalize();

} /* end main */