#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include "bn_prime.h"
//#define STATUS "status.txt"
//#define n 100103
/* Exponent filter for the graphs studied: main keeps only exponents e with gcd(e, n-1) == M_ARY,
   and writeTotalResults uses euler((n-1)/M_ARY) as the expected number of such graphs. */
#define M_ARY 2
/* Minimal boolean type (no <stdbool.h>): bool is a char holding 0 (false) or 1 (true). */
#define bool char
#define false 0
#define true 1

/* Modulus under study. main sets it from argv[1]; most helpers below (m_exp, isPrimRoot,
   isRelPrime, setArrays, zeroList, computeResults, writeTotalResults) read it implicitly. */
int n;

/* Modular exponentiation by repeated squaring: returns b^e mod m as a value in [0, m).
 *
 * b may be any long long (negative or larger than m); it is reduced into [0, m) first so that
 * the products below stay in range. Requires m > 0 and (m-1)^2 to fit in a long long
 * (m below roughly 3.03e9). A non-positive exponent skips the loop and yields 1 % m.
 */
long long ml_exp(long long b, int e, long long m)
{
    long long result = 1 % m;   /* 1 % m makes m == 1 return 0 even when e == 0 */
    b %= m;                     /* reduce first so b * b cannot overflow for a large base */
    if (b < 0) b += m;          /* C's % keeps the sign of the dividend; map into [0, m) */
    while (e > 0)
    {
        /* Multiply in this bit's contribution, reducing mod m to keep the result small. */
        if (e & 1) result = (result * b) % m;
        e >>= 1;
        b = (b * b) % m;
    }
    return result;
}

/* int-typed convenience wrapper around ml_exp(): returns b^r mod num.
 * The work is done in long long; the value is reduced modulo num (itself an int), so it
 * fits back into an int. */
int m_expn(int b, int r, int num)
{
    long long wide = ml_exp(b, r, num);
    return (int) wide;
}

/* Shorthand for m_expn() with the global modulus n: returns b^r mod n.
 * Only valid once n has been set (main assigns it from argv[1]). */
int m_exp(int b, int r)
{
    const int modulus = n;
    return m_expn(b, r, modulus);
}

/* Totals the per-node results of one graph.
 *
 * distToCycle[i] and cycleSize[i] (i in 0..n-1) hold each node's distance to its cycle and the
 * size of that cycle. Their sums are stored through allToCycleSum and allCycleLengthSum.
 */
void computeResults(const int* distToCycle, const int* cycleSize, long* allToCycleSum, long* allCycleLengthSum)
{
    long tailTotal = 0;
    long cycleTotal = 0;
    int idx = 0;
    while (idx < n)
    {
        tailTotal += distToCycle[idx];
        cycleTotal += cycleSize[idx];
        ++idx;
    }
    *allToCycleSum = tailTotal;
    *allCycleLengthSum = cycleTotal;
}

/* One round of the Miller-Rabin strong-probable-prime test for num with witness a.
 *
 * Expects num - 1 == 2^k * q with q odd (isPrime computes k and q this way).
 * Returns true when a^q == 1 (mod num) or a^(2^i * q) == num - 1 (mod num) for some
 * 0 <= i < k, i.e. num is a strong probable prime to base a; false means a proves num composite.
 */
bool MillerRabin(int num, int k, int q, int a)
{
    int n1 = num - 1;
    long long x = m_expn(a, q, num);   /* a^q mod num */
    int i;

    if (x == 1) return true;
    /* Walk a^(2^i * q) for i = 0..k-1 by successive squaring instead of recomputing each power
       from scratch with a floating-point pow(2, i). x < num <= INT_MAX, so x * x fits in long long. */
    for (i = 0; i < k; i++)
    {
        if (x == n1) return true;
        x = (x * x) % num;
    }
    return false;
}

/* Probabilistic primality test.
 *
 * Numbers below 4 are answered directly. Otherwise num is trial-divided by the first 50 entries of
 * primes[] from bn_prime.h (assumed to be ascending primes starting at 2 with at least 50 entries -
 * TODO confirm against bn_prime.h); once a table prime exceeds num, num is reported prime.
 * Numbers that survive trial division get 10 Miller-Rabin rounds with random witnesses in
 * [2, num-2]. A false result is always correct; true means "probably prime".
 */
bool isPrime(int num)
{
    static bool seeded = false;
    int i;

    /* Rejects 0, 1 and negatives (the table scan below would call them prime) and accepts 2 and 3. */
    if (num < 4) return num >= 2;
    for (i = 0; i < 50; i++)
    {
        if (primes[i] > (unsigned)num) return true;
        if ((num % primes[i] == 0) && (num != primes[i])) return false;
    }
    /* Write num - 1 = 2^k * q with q odd. */
    int k = 0;
    int q = num - 1;
    while (q % 2 == 0)
    {
        k++;
        q >>= 1;
    }
    /* Seed once per process: reseeding with time() on every call restarts the same witness
       sequence for all calls made within the same second. */
    if (!seeded)
    {
        srand((unsigned) time(NULL));
        seeded = true;
    }
    for (i = 0; i < 10; i++)
    {
        /* num >= 4 here, so num - 3 >= 1 and a lies in [2, num-2]; a = 1 passes every round. */
        int a = (rand() % (num - 3)) + 2;
        if (!MillerRabin(num, k, q, a)) return false;
    }
    return true;
}

/* Returns true if base is a primitive root modulo the global n.
 *
 * Only prime n is supported; for composite n this always returns false. With phi(n) = n - 1, base is a
 * primitive root iff base^((n-1)/p) != 1 (mod n) for every prime p dividing n - 1. The prime factors
 * of n - 1 are found by trial division with primes[] from bn_prime.h; whatever cofactor remains is
 * tested as one more prime factor. That is exact when n - 1 <= (largest table prime)^2; otherwise the
 * warning below is printed and the cofactor may be composite, so the answer may be wrong.
 */
bool isPrimRoot(int base)
{
    if (!isPrime(n)) return false;
    int n_1 = n - 1;   /* phi(n) for prime n */
    if ((unsigned)n_1 > (primes[NUMPRIMES-1]*primes[NUMPRIMES-1]))
        printf("Error in Primitive Root Testing, n could have prime factor too large for testing\n");
    int n1 = n_1;      /* part of phi(n) not yet factored */
    int index;
    for (index = 0; index < NUMPRIMES && n1 > 1; index++)
    {
        int p = primes[index];
        /* No prime <= sqrt(n1) is left to try, so n1 itself is prime; it is handled below. */
        if ((long long) p * p > n1) break;
        if (n1 % p != 0) continue;
        /* Divide p out completely so it is not tested again. */
        while (n1 % p == 0) n1 /= p;
        /* base^(phi(n)/p) == 1 means base's order is a proper divisor of phi(n). */
        if (m_exp(base, n_1 / p) == 1) return false;
    }
    /* Remaining cofactor > 1 is a prime factor larger than every prime tried above
       (this replaces the former special case for the single cofactor 50021). */
    if (n1 > 1 && m_exp(base, n_1 / n1) == 1) return false;
    return true;
}

/* Greatest common divisor of a and b by Euclid's algorithm.
 *
 * Works on magnitudes, so the result is non-negative for any signs: gcd(a, 0) == |a|,
 * gcd(0, 0) == 0. (The only unrepresentable case is a magnitude of 2^31, e.g. gcd(INT_MIN, 0).)
 */
int gcd(int a, int b)
{
    /* 0u - (unsigned)x is |x| without the signed overflow of -INT_MIN. */
    unsigned int x = a < 0 ? 0u - (unsigned int) a : (unsigned int) a;
    unsigned int y = b < 0 ? 0u - (unsigned int) b : (unsigned int) b;
    while (y != 0)
    {
        unsigned int r = x % y;
        x = y;
        y = r;
    }
    return (int) x;
}

/* Returns true when base and n - 1 are coprime, i.e. gcd(base, n - 1) == 1.
 * Note: the comparison is against n - 1 (which equals phi(n) when n is prime), not against n. */
bool isRelPrime(int base)
{
    const int phi = n - 1;
    const int common = gcd(base, phi);
    return common == 1;
}

/* Euler's totient phi(numb): how many integers in [1, numb] are coprime to numb, as a double.
 *
 * writeTotalResults uses euler((n-1)/M_ARY) as the number of M_ARY graphs computed (the trials).
 * numb is factored by trial division; for each distinct prime p dividing it, result -= result / p.
 * Inputs <= 0 are returned unchanged (phi is not defined there).
 */
double euler(int numb)
{
    int result = numb;
    int i;
    /* i <= numb / i is i*i <= numb without the signed overflow i*i hits for numb near INT_MAX. */
    for (i = 2; i <= numb / i; i++)
    {
        if (numb % i == 0)
        {
            result -= result / i;
            while (numb % i == 0) numb /= i;
        }
    }
    /* Any remaining numb > 1 is a prime factor larger than sqrt of the original value. */
    if (numb > 1) result -= result / numb;
    return (double) result;
}



/* Resets the per-node work arrays (each at least n elements long) before a new graph is built:
 * the visit and image flags become false, cycle sizes and distances to cycle become 0. */
void setArrays(int * cycleSize, bool* visit, int* distToCycle, bool* image)
{
    int node;
    for (node = n - 1; node >= 0; --node)
    {
        cycleSize[node] = 0;
        distToCycle[node] = 0;
        visit[node] = false;
        image[node] = false;
    }
}

/* Sets the first n entries of listArray to 0 (nothing happens when n <= 0). */
void zeroList(int * listArray)
{
    int k = n;
    while (k > 0)
    {
        --k;
        listArray[k] = 0;
    }
}

/* Open 'path' with 'mode', write 'text' verbatim and close it again.
 * Nothing is written when the file cannot be opened, so progress logging never crashes the run. */
static void logResultsStatus(const char *path, const char *mode, const char *text)
{
	FILE *f = fopen(path, mode);
	if (f == NULL) return;
	fputs(text, f);
	fclose(f);
}

/* Summarizes the per-base sums produced by main and writes the statistics to the file named str.
 *
 * sumCycleLengths, sumTailLengths: n entries indexed by base; index i is used only when
 *   sumCycleLengths[i] != 0 (main leaves bases without a graph at 0).
 * s:   unused; kept for interface compatibility. Every write reopens str.
 * str: status/result file. Progress lines are appended, then the file is truncated and the
 *   averages and variances are written, then the tau and covariance lines are appended.
 *
 * For each used base i the per-node averages (sum / (n-1)) of tail, cycle and rho length are
 * formed; their mean and sample variance across bases are computed with Welford's online update.
 * Sample variances are reported as 0 when fewer than two bases were used.
 */
void writeTotalResults(long* sumCycleLengths, long* sumTailLengths, FILE* s, char* str) {
	double aveTL = 0, aveCL = 0, aveRL = 0;
	double varTL = 0, varCL = 0, varRL = 0, varSqDifRL = 0, covTC = 0;
	double trials = euler((n - 1) / M_ARY);
	FILE *out;
	int i, count;

	(void) s;

	logResultsStatus(str, "a", "Allocating memory...\n");
	/* calloc zero-fills (no indeterminate reads) and checks n * size for overflow. */
	double *aveTailLengths = calloc((size_t) n, sizeof *aveTailLengths);
	double *aveCycleLengths = calloc((size_t) n, sizeof *aveCycleLengths);
	double *aveRhoLengths = calloc((size_t) n, sizeof *aveRhoLengths);
	if (aveTailLengths == NULL || aveCycleLengths == NULL || aveRhoLengths == NULL) {
		FILE *f = fopen(str, "a");
		if (f != NULL) {
			if (aveTailLengths == NULL) fprintf(f, "malloc failed for aveTailLengths array\n");
			if (aveCycleLengths == NULL) fprintf(f, "malloc failed for aveCycleLengths array\n");
			if (aveRhoLengths == NULL) fprintf(f, "malloc failed for aveRhoLengths array\n");
			fclose(f);
		}
		free(aveTailLengths);
		free(aveCycleLengths);
		free(aveRhoLengths);
		return;
	}
	logResultsStatus(str, "a", "Finished allocating...\n");

	/* Per-base averages over the n - 1 nodes. Starts at 1 like the statistics loops below;
	   starting at 2 left index 1 (base 1, produced when n == 3) unset but still read. */
	logResultsStatus(str, "a", "Begin calculating averages...\n");
	for (i = 1; i < n; ++i) {
		if (sumCycleLengths[i] == 0) continue;
		aveTailLengths[i] = ((long double) sumTailLengths[i]) / (n - 1);
		aveCycleLengths[i] = ((long double) sumCycleLengths[i]) / (n - 1);
		aveRhoLengths[i] = ((long double) sumTailLengths[i]) / (n - 1) + ((long double) sumCycleLengths[i]) / (n - 1);
	}
	logResultsStatus(str, "a", "Finished calculating averages...\n");

	/* Welford: running mean, and running sum of squared deviations (divided out below). */
	count = 0;
	for (i = 1; i < n; ++i) {
		if (sumCycleLengths[i] == 0) continue;
		++count;
		double deltaTL = aveTailLengths[i] - aveTL;
		aveTL += deltaTL / count;
		varTL += deltaTL * (aveTailLengths[i] - aveTL);

		double deltaCL = aveCycleLengths[i] - aveCL;
		aveCL += deltaCL / count;
		varCL += deltaCL * (aveCycleLengths[i] - aveCL);

		double deltaRL = aveRhoLengths[i] - aveRL;
		aveRL += deltaRL / count;
		varRL += deltaRL * (aveRhoLengths[i] - aveRL);
	}
	logResultsStatus(str, "a", "Finished getting mean and variances...\nDoing final variance calc...\n");
	/* Sample variance needs at least two bases; otherwise it stays 0 instead of becoming NaN or -0. */
	if (count > 1) {
		varTL = varTL / (count - 1);
		varCL = varCL / (count - 1);
		varRL = varRL / (count - 1);
	}

	/* "w": the result file replaces the progress log. */
	out = fopen(str, "w");
	if (out != NULL) {
		fprintf(out, "Average Tail Length\n");
		fprintf(out, "%f\n", aveTL);
		fprintf(out, "Average Cycle Length\n");
		fprintf(out, "%f\n", aveCL);
		fprintf(out, "Average Rho Length\n");
		fprintf(out, "%f\n", aveRL);

		fprintf(out, "Variance of Average Tail Length\n");
		fprintf(out, "%f\n", varTL);
		fprintf(out, "Variance of Average Cycle Length\n");
		fprintf(out, "%f\n", varCL);
		fprintf(out, "Variance of Average Rho Length\n");
		fprintf(out, "%f\n", varRL);
		fclose(out);
	}

	count = 0;
	for (i = 1; i < n; ++i) {
		if (sumCycleLengths[i] == 0) continue;
		++count;
		double sqDifRL = (aveRhoLengths[i] - aveRL) * (aveRhoLengths[i] - aveRL);
		/* NOTE(review): this accumulates sum((rho_i - mean)^4), so "Tau" below is that sum / (count - 1).
		   An earlier (removed) version computed the variance of the squared deviations,
		   sum((sqDif_i - mean(sqDif))^2) / (count - 1). Confirm which quantity is intended. */
		varSqDifRL += sqDifRL * sqDifRL;
		covTC += (aveTailLengths[i] - aveTL) * (aveCycleLengths[i] - aveCL) / trials;
	}
	if (count > 1) varSqDifRL = varSqDifRL / (count - 1);

	out = fopen(str, "a");
	if (out != NULL) {
		fprintf(out, "Tau of Average Rho Length\n");
		fprintf(out, "%f\n", varSqDifRL);
		fprintf(out, "Covariance between Average Tail Length and Average Cycle Length\n");
		fprintf(out, "%f\n", covTC);
		fclose(out);
	}

	free(aveTailLengths);
	free(aveCycleLengths);
	free(aveRhoLengths);
}

/* Write 'text' to the status file 'path', opened with 'mode' ("w" replaces, "a" appends), then close it.
 * Does nothing if the file cannot be opened, so status reporting never crashes a run. */
static void writeStatusFile(const char *path, const char *mode, const char *text)
{
	FILE *f = fopen(path, mode);
	if (f == NULL) return;
	fputs(text, f);
	fclose(f);
}

/* Builds the M_ARY functional graphs x -> base^x mod n for a prime modulus n and reports rho statistics.
 *
 * Usage: prog <n>   (n: a prime that fits in an int)
 *
 * The smallest primitive root 'root' of n is found first. For every exponent e in [0, n) with
 * gcd(e, n-1) == M_ARY, base = root^e mod n and the graph on nodes 1..n-1 is walked. Each node gets
 * its tail length (steps until it reaches a cycle) and the length of the cycle it ends in. The per-base
 * sums of those values are passed to writeTotalResults. Progress is written to "status_<n>.txt".
 *
 * Returns 0 on success, -1 on a bad argument or an allocation failure.
 */
int main(int argc, char** argv) {
	char str[25];
	int rc = -1;
	FILE *s = NULL;
	const char *failed = NULL;
	char *end = NULL;
	long parsed;
	bool *visit = NULL, *image = NULL;
	int *maxTail = NULL, *maxCycle = NULL, *terminalNodes = NULL;
	int *cycleLength = NULL, *tailLength = NULL, *marker = NULL, *listArray = NULL;
	long *sumTailLengths = NULL, *sumCycleLengths = NULL;
	/* MCL: max cycle length, MTL: max tail length of the current graph */
	int MCL, MTL;
	int i, root, power, base, next, baseTail, cycleSize, terminal, listSize;

	if (argc < 2) {
		fprintf(stderr, "usage: %s <prime modulus>\n", argc > 0 ? argv[0] : "prog");
		return -1;
	}
	errno = 0;
	parsed = strtol(argv[1], &end, 10);
	if (end == argv[1] || *end != '\0' || errno == ERANGE || parsed < 2 || parsed > INT_MAX) {
		fprintf(stderr, "invalid modulus '%s': expected an integer in [2, %d]\n", argv[1], INT_MAX);
		return -1;
	}
	n = (int) parsed;
	/* isPrimRoot() rejects every base when n is not prime, so the root search would never end. */
	if (!isPrime(n)) {
		fprintf(stderr, "modulus %d is not prime\n", n);
		return -1;
	}

	snprintf(str, sizeof str, "status_%d.txt", n);
	writeStatusFile(str, "w", "Starting process...\nAllocating memory...\n");

	/* calloc zero-fills and checks n * size for overflow. */
	visit = calloc((size_t) n, sizeof *visit);
	image = calloc((size_t) n, sizeof *image);
	maxTail = calloc((size_t) n, sizeof *maxTail);                 /* maximum tail length per base */
	maxCycle = calloc((size_t) n, sizeof *maxCycle);               /* maximum cycle length per base */
	terminalNodes = calloc((size_t) n, sizeof *terminalNodes);     /* nodes without preimage, per base */
	cycleLength = calloc((size_t) n, sizeof *cycleLength);         /* per node: size of the cycle it reaches */
	tailLength = calloc((size_t) n, sizeof *tailLength);           /* per node: distance to cycle (0 on it) */
	sumTailLengths = calloc((size_t) n, sizeof *sumTailLengths);   /* per base: sum of all tail lengths */
	sumCycleLengths = calloc((size_t) n, sizeof *sumCycleLengths); /* per base: sum of all cycle lengths */
	marker = calloc((size_t) n, sizeof *marker);                   /* marker[base] = 1 for each graph built */
	listArray = calloc((size_t) n, sizeof *listArray);             /* one walk's path: at most n entries */

	if (visit == NULL) failed = "visit";
	else if (image == NULL) failed = "image";
	else if (maxTail == NULL) failed = "maxTail";
	else if (maxCycle == NULL) failed = "maxCycle";
	else if (terminalNodes == NULL) failed = "terminalNodes";
	else if (cycleLength == NULL) failed = "cycleLength";
	else if (tailLength == NULL) failed = "tailLength";
	else if (sumTailLengths == NULL) failed = "sumTailLengths";
	else if (sumCycleLengths == NULL) failed = "sumCycleLengths";
	else if (marker == NULL) failed = "marker";
	else if (listArray == NULL) failed = "listArray";
	if (failed != NULL) {
		/* Reopen the status file: the stream used for the earlier messages is already closed. */
		s = fopen(str, "a");
		if (s != NULL) {
			fprintf(s, "malloc failed for %s array\n", failed);
			fclose(s);
		}
		goto cleanup;
	}

	/* calloc already zeroed the arrays; the message is kept for anyone polling the status file. */
	writeStatusFile(str, "w", "Finished allocating...\nZeroing arrays...\n");

	writeStatusFile(str, "w", "Finding a primitive root...\n");
	for (root = 1; root < n && !isPrimRoot(root); ++root)
		;
	if (root >= n) {
		fprintf(stderr, "no primitive root found for %d\n", n);
		goto cleanup;
	}
	s = fopen(str, "a");
	if (s != NULL) {
		fprintf(s, "Primitive root is %d...\n", root);
		fprintf(s, "Begin making graphs...\n");
		fclose(s);
	}

	for (power = 0; power < n; ++power) {
		/* Discard all but the bases which will make the target M_ARY graphs */
		if (gcd(power, n - 1) != M_ARY) continue;
		base = m_exp(root, power);
		/* Mark each base that will produce target graph */
		marker[base] = 1;
		MCL = 0;
		MTL = 0;
		/* 0 out everything */
		setArrays(cycleLength, visit, tailLength, image);
		/* Build the graph gamma(x) = base^x mod n, one unvisited starting node at a time */
		for (i = 1; i < n; ++i) {
			int j;
			if (visit[i]) continue;
			/* Follow i, gamma(i), gamma(gamma(i)), ... until a visited node repeats; listArray holds
			   the path with that repeated node as its last entry. */
			next = i;
			listArray[0] = next;
			listSize = 1;
			while (!visit[next]) {
				visit[next] = true;
				next = m_exp(base, next);
				image[next] = true;
				listArray[listSize] = next;
				listSize++;
			}
			if (cycleLength[next] != 0) {
				/* The path ran into a node classified by an earlier walk (on its cycle when
				   baseTail == 0, on its tail otherwise): the new nodes extend that tail. */
				baseTail = tailLength[next];
				cycleSize = cycleLength[next];
				if (listSize - 1 + baseTail > MTL) MTL = listSize - 1 + baseTail;
				for (j = 0; j < listSize - 1; ++j) {
					tailLength[listArray[j]] = baseTail + listSize - 1 - j;
					cycleLength[listArray[j]] = cycleSize;
				}
			} else {
				/* New cycle found: it starts at the first occurrence of the repeated node. */
				int repeat = listArray[listSize - 1];
				int firstCycle;
				for (j = 0; listArray[j] != repeat; ++j)
					;
				firstCycle = j;
				cycleSize = listSize - (j + 1);
				if (cycleSize > MCL) MCL = cycleSize;
				if (firstCycle > MTL) MTL = firstCycle;
				/* Tail nodes store their (positive) distance to the cycle; cycle nodes keep tail length 0. */
				for (j = 0; j < firstCycle; ++j) {
					tailLength[listArray[j]] = firstCycle - j;
					cycleLength[listArray[j]] = cycleSize;
				}
				for (j = firstCycle; j < listSize - 1; ++j) {
					cycleLength[listArray[j]] = cycleSize;
				}
			}
		}
		/* Terminal nodes have no preimage under gamma. */
		terminal = 0;
		for (i = 1; i < n; ++i) {
			if (!image[i]) terminal++;
		}
		maxTail[base] = MTL;
		maxCycle[base] = MCL;
		terminalNodes[base] = terminal;
		computeResults(tailLength, cycleLength, &sumTailLengths[base], &sumCycleLengths[base]);
	}

	/* Release the graph work arrays before the statistics pass allocates its own. */
	free(visit);         visit = NULL;
	free(image);         image = NULL;
	free(tailLength);    tailLength = NULL;
	free(cycleLength);   cycleLength = NULL;
	free(listArray);     listArray = NULL;
	free(maxTail);       maxTail = NULL;
	free(maxCycle);      maxCycle = NULL;
	free(terminalNodes); terminalNodes = NULL;
	free(marker);        marker = NULL;

	writeStatusFile(str, "w", "Writing results...\n");
	/* writeTotalResults reopens str itself, so no stream is handed over. */
	writeTotalResults(sumCycleLengths, sumTailLengths, NULL, str);
	rc = 0;

cleanup:
	free(visit);
	free(image);
	free(tailLength);
	free(cycleLength);
	free(listArray);
	free(maxTail);
	free(maxCycle);
	free(terminalNodes);
	free(marker);
	free(sumCycleLengths);
	free(sumTailLengths);
	return rc;
}