# include <errno.h>
# include <limits.h>
# include <stdio.h>
# include <stdlib.h>

# include <openssl/bn.h>
# include <openssl/crypto.h>
# include <openssl/opensslv.h>

# include "mpi.h"

/* Program name (argv[0]); prefixed to every fatal error message. */
char *prog;

/*
 * Verbosity selected on the command line:
 *   >= 1  debug() prints progress messages
 *   >= 2  trace_bignum() prints BIGNUM values
 *   >= 3  debug() also waits for return after each message
 */
int debug_level = 0;

/* MPI world size and this process's rank, overwritten in main() right
   after MPI_Init; the initial values describe a single-process run. */
int numcpus = 1, myid = 0;

/* INT_MAX as a BIGNUM; main() compares each prime against it before
   converting the prime to an int. */
BIGNUM *bn_maxint;

void myerror(char *errstring)
{
  fprintf(stderr, "%s, processor %d: %s\n", prog, myid, errstring);
  MPI_Finalize();
  exit(1);
}

/*
 * Emit a progress message on stderr, tagged with this process's rank,
 * when debug_level >= 1.  At debug_level >= 3, additionally block until
 * the user presses return.  Does nothing at lower levels.
 */
void debug(char *message)
{
  if ( debug_level < 1 )
    return;                     /* quiet mode: nothing to print or pause for */

  fprintf(stderr, "processor %d: %s.\n", myid, message);

  if ( debug_level < 3 )
    return;

  /* single-step mode: wait for a line on stdin before continuing */
  fprintf(stderr, "Press return to continue\n");
  getchar();
}

/*
 * When debug_level >= 2, print "name = value" for a BIGNUM on stderr,
 * tagged with this process's rank; otherwise do nothing.
 *
 *   string  label printed before the value
 *   bn      number to print in decimal
 *
 * Tracing is best-effort: if the decimal conversion fails, a placeholder
 * is printed instead of aborting the run.
 */
void trace_bignum(char *string, BIGNUM *bn)
{
  char *bnstring;

  if ( debug_level < 2 )
    return;

  /* BN_bn2dec allocates the string and returns NULL on failure;
     handing NULL to %s would be undefined behavior. */
  bnstring = BN_bn2dec(bn);
  if ( bnstring == NULL )
    {
      fprintf(stderr, "processor %d: %s = <conversion failed>\n", myid, string);
      return;
    }

  fprintf(stderr, "processor %d: %s = %s\n", myid, string, bnstring);
  OPENSSL_free(bnstring);
}


/*
 * Advance p in place to the smallest (probable) prime strictly greater
 * than its current value, and return p itself (the same object passed in).
 *
 *   p    number to advance; modified in place
 *   ctx  scratch context for the OpenSSL primality test
 *
 * Every candidate is traced via trace_bignum().  Any OpenSSL failure,
 * either while incrementing or while testing, is fatal via myerror().
 * The primality test is OpenSSL's probabilistic one.
 */
BIGNUM *nextprime(BIGNUM *p, BN_CTX *ctx)
{
  int test;

  do
    {
      if ( !BN_add_word(p, 1) ) myerror("error incrementing p");
      trace_bignum("p", p);

      /* BN_is_prime() has been deprecated since OpenSSL 0.9.8; OpenSSL 3.0
         in turn deprecates BN_is_prime_ex() in favour of BN_check_prime().
         All return 1 = prime, 0 = composite, -1 = error. */
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
      test = BN_check_prime(p, ctx, NULL);
#else
      test = BN_is_prime_ex(p, BN_prime_checks, ctx, NULL);
#endif
    }
  while ( test == 0 );

  if ( test < 0 )
    myerror("error testing for the primality of p");

  return p;
}

/*
 * Parse TEXT as a complete decimal integer and store it in *OUT.
 * Returns 1 on success.  Returns 0, leaving *OUT untouched, if TEXT is
 * empty, has trailing characters, overflows a long, or lies outside
 * [MINIMUM, INT_MAX].
 */
static int parse_int_arg(const char *text, long minimum, int *out)
{
  char *end;
  long value;

  errno = 0;
  value = strtol(text, &end, 10);
  if ( end == text || *end != '\0' || errno == ERANGE )
    return 0;
  if ( value < minimum || value > INT_MAX )
    return 0;

  *out = (int)value;
  return 1;
}

/*
 * For each of NUMP consecutive primes p, count over 1 <= g, h < p how
 * often g^h == h (mod p) ("fixed points") and g^(g^h mod p) == h (mod p)
 * ("two-cycles", which include the fixed points).  Each count is split by
 * whether gcd(h, p-1) == 1.
 *
 * The g values are dealt out cyclically across MPI ranks.  Each rank's
 * counts are summed onto rank 0, which prints the results.
 *
 * Usage: prog firstp nump [debug_level]
 *   firstp       starting value; the first prime examined is the smallest
 *                prime strictly greater than firstp
 *   nump         number of consecutive primes to examine
 *   debug_level  optional verbosity (see debug() and trace_bignum())
 */
int main(int argc, char *argv[])
{
  int firstp, nump;
  int gint, hint, pint;
  int j;

  int rpfxcount, nrpfxcount, rptccount, nrptccount;
  int rpfxglob, nrpfxglob, rptcglob, nrptcglob;

  BIGNUM *p, *g, *h, *a, *b, *d, *pminusone;
  BN_CTX *ctx;
  char *pstring;

  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &numcpus);
  MPI_Comm_rank(MPI_COMM_WORLD, &myid);

  prog = argv[0];

  debug("Checking usage");

  if ( argc < 3 ) myerror("Need two arguments: firstp and nump");

  /* atoi() silently turns garbage into 0 and is undefined on overflow;
     a negative firstp would also wrap to a huge BN_ULONG in BN_set_word. */
  if ( !parse_int_arg(argv[1], 0, &firstp) )
    myerror("firstp must be an integer between 0 and INT_MAX");
  if ( !parse_int_arg(argv[2], 0, &nump) )
    myerror("nump must be a non-negative integer");
  if ( argc >= 4 && !parse_int_arg(argv[3], 0, &debug_level) )
    myerror("debug_level must be a non-negative integer");

  debug("initializing");

  /* MPI_Reduce only fills these on rank 0; keep them defined elsewhere. */
  rpfxglob = 0;
  nrpfxglob = 0;
  rptcglob = 0;
  nrptcglob = 0;

  ctx = BN_CTX_new();               /* temporary structure */
  if ( ctx == NULL ) myerror("error allocating ctx");

  bn_maxint = BN_new();             /* the value of maxint as a BIGNUM */
  if ( bn_maxint == NULL ) myerror("error allocating bn_maxint");
  if ( !BN_set_word(bn_maxint, INT_MAX) ) myerror("error initializing bn_maxint to maxint");

  p = BN_new();                     /* the prime we are considering */
  if ( p == NULL ) myerror("error allocating p");

  g = BN_new();                     /* the base of g^h=h or g^g^h=h */
  if ( g == NULL ) myerror("error allocating g");

  h = BN_new();                     /* the exponent */
  if ( h == NULL ) myerror("error allocating h");

  a = BN_new();                     /* a = g^h */
  if ( a == NULL ) myerror("error allocating a");

  b = BN_new();                     /* b = g^g^h */
  if ( b == NULL ) myerror("error allocating b");

  d = BN_new();                     /* the gcd of h and p-1 */
  if ( d == NULL ) myerror("error allocating d");

  pminusone = BN_new();             /* p-1 */
  if ( pminusone == NULL ) myerror("error allocating pminusone");

  if ( !BN_set_word(p, firstp) ) myerror("error initializing p to firstp");

  for ( j = 1; j <= nump; j++ )
    {
      debug("finding a prime");

      p = nextprime(p, ctx);

      /* p must fit in an int for the loops below.  BN_cmp only promises
         a positive/zero/negative result, so test the sign. */
      if ( BN_cmp(p, bn_maxint) > 0 ) myerror("p larger than maxint");

      pint = (int)BN_get_word(p);
      if ( pint <= 0 ) myerror("error converting p to int (too large?)");

      if ( !BN_sub(pminusone, p, BN_value_one()) )
        myerror("error initializing pminusone to p-1");

      rpfxcount = 0;   /* number of fixed points with h r.p. to p-1 */
      nrpfxcount = 0;  /* number of fixed points with h not r.p. to p-1 */
      rptccount = 0;   /* number of two-cycles (including f.p.s) with h r.p. to p-1 */
      nrptccount = 0;  /* number of two-cycles (including f.p.s) with h not r.p. to p-1 */

      debug("starting loop");

      /* distribute g among all cpus: rank r takes r+1, r+1+numcpus, ... */
      for ( gint = myid + 1; gint < pint; gint = gint + numcpus )
        {
          if ( !BN_set_word(g, gint) ) myerror("error converting g to bignum");

          for ( hint = 1; hint < pint; hint++ )
            {
              int h_is_rp;   /* gcd(h, p-1) == 1 */

              if ( !BN_set_word(h, hint) ) myerror("error converting h to bignum");

              trace_bignum("g", g);
              trace_bignum("h", h);

              if ( !BN_gcd(d, h, pminusone, ctx) )
                myerror("error assigning d =  gcd(h, p-1)");
              h_is_rp = ( BN_cmp(d, BN_value_one()) == 0 );

              if ( !BN_mod_exp(a, g, h, p, ctx) )
                myerror("error assigning a = g^h mod p");

              if ( !BN_mod_exp(b, g, a, p, ctx) )
                myerror("error assigning b = g^a mod p");

              if ( BN_cmp(a, h) == 0 )          /* g^h == h: fixed point */
                {
                  if ( h_is_rp )
                    rpfxcount++;
                  else
                    nrpfxcount++;
                }

              if ( BN_cmp(b, h) == 0 )          /* g^g^h == h: two-cycle */
                {
                  if ( h_is_rp )
                    rptccount++;
                  else
                    nrptccount++;
                }
            } /* end hint loop */
        } /* end gint loop */

      /* gather the data */
      MPI_Reduce(&rpfxcount, &rpfxglob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
      MPI_Reduce(&nrpfxcount, &nrpfxglob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
      MPI_Reduce(&rptccount, &rptcglob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
      MPI_Reduce(&nrptccount, &nrptcglob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);

      /* report data, only on master */
      if ( myid == 0 )
        {
          pstring = BN_bn2dec(p);
          if ( pstring == NULL ) myerror("error converting p to decimal");
          printf("prime = %s\n", pstring);
          OPENSSL_free(pstring);

          printf("number of fixed points with h an r.p. = %d\n", rpfxglob);
          printf("number of fixed points total = %d\n", rpfxglob+nrpfxglob);
          printf("\n");

          printf("number of two-cycles with h an r.p. = %d\n", rptcglob);
          printf("number of two-cycles total = %d\n", rptcglob+nrptcglob);
          printf("\n");
        }
    }

  /* release everything allocated above */
  BN_free(pminusone);
  BN_free(d);
  BN_free(b);
  BN_free(a);
  BN_free(h);
  BN_free(g);
  BN_free(p);
  BN_free(bn_maxint);
  bn_maxint = NULL;
  BN_CTX_free(ctx);

  MPI_Finalize();
  return 0;
}