#include "ca3.h"

/* MPI rank of this process (mynode) and number of processes in
   MPI_COMM_WORLD (totalnodes); both are filled in at the start of run(). */
int mynode, totalnodes;

/*
 * int-sized modular exponentiation: hands b, r and num to ml_exp() with
 * the base and modulus widened to long long, then narrows the result
 * back to int.
 */
int m_expn(int b, int r, int num) {
	long long wideBase = (long long) b;
	long long wideModulus = (long long) num;
	long long power = ml_exp(wideBase, r, wideModulus);
	return (int) power;
}

/* Shorthand for m_expn() with the file-wide modulus n. */
int m_exp(int b, int r) {
	int value = m_expn(b, r, n);
	return value;
}

/*
long long ml_exp(long long b, int r, long long num) {
	if (r == 0) return 1;
	if(r % 2 == 0) {
		long long result = ml_exp(b,r/2,num);
		return result * result % num;
	}
	long long result = ml_exp(b,r/2,num);
	return (b * result % num) * result % num;
}
*/

/*
 * Modular exponentiation by right-to-left binary exponentiation
 * (the original loop was taken from a Wikipedia page).
 *
 *   b  base; any sign, it is reduced into [0, m) before use
 *   e  exponent; a value <= 0 selects the empty product
 *   m  modulus; expected to be positive and small enough that
 *      (m-1)*(m-1) still fits in a long long
 *
 * Returns b^e mod m as a value in [0, m).  A non-positive modulus has no
 * meaningful residue and yields 0.
 */
long long ml_exp(long long b, int e, long long m) {

   if (m <= 0) return 0;

   /* 1 % m rather than 1, so the empty product is 0 when m == 1 */
   long long result = 1 % m;

   /* reduce the base up front: squaring an unreduced base can overflow,
      and a negative base would otherwise produce negative residues */
   b %= m;
   if (b < 0) b += m;

   while (e > 0) {
      /* multiply in this bit's contribution, keeping result below m */
      if ((e & 1) == 1) result = (result * b) % m;
      e >>= 1;
      /* the square is only needed if another bit is still to come */
      if (e > 0) b = (b * b) % m;
   }
   return result;
}

/*
 * Totals the first n entries (n being the file-wide problem size) of
 * distToCycle and of cycleSize, and stores the two totals through
 * allToCycleSum and allCycleLengthSum respectively.
 *
 * NOTE(review): both totals are accumulated in plain int; confirm they
 * cannot exceed INT_MAX for the values of n in use.
 */
void computeResults(const int* distToCycle, const int* cycleSize, 
					int* allToCycleSum, int* allCycleLengthSum) {

	int tailTotal = 0;
	int cycleTotal = 0;
	int node;

	for(node = 0; node < n; node++)
		tailTotal += distToCycle[node];

	for(node = 0; node < n; node++)
		cycleTotal += cycleSize[node];

	*allToCycleSum = tailTotal;
	*allCycleLengthSum = cycleTotal;
}


/*
 * Tests whether base is a primitive root modulo the file-wide n.
 *
 * Returns false when isPrime(n) rejects n.  Otherwise base is rejected as
 * soon as base^((n-1)/p) == 1 (mod n) for some prime p dividing n-1.  The
 * prime factors are found by trial division against the primes[] table;
 * whatever cofactor is left once the table is exhausted is tested the same
 * way, so one prime factor larger than every table entry is still covered
 * (this replaces the former special case for the single value 50021).
 *
 * A message is printed when n-1 exceeds the square of the largest table
 * prime, because the leftover cofactor may then be composite and a "true"
 * answer is no longer guaranteed.
 */
bool isPrimRoot(int base) {
	if(!isPrime(n)) return false;

	int n_1 = n-1;

	/* widen before squaring so the bound itself cannot wrap around */
	unsigned long long largest = (unsigned long long)primes[NUMPRIMES-1];
	if ((unsigned long long)n_1 > largest*largest)
		printf("Error in Primitive Root Testing, n could have prime factor too large for testing\n");

	int n1 = n_1;	/* part of n-1 that has not been factored yet */
	int index;

	for(index = 0; n1 > 1 && index < NUMPRIMES; index++) {
		int p = primes[index];

		/* only primes that divide phi(n) = n-1 matter */
		if((n1 % p) != 0) continue;

		/* divide out that prime all the way so it isn't tested again */
		while((n1 % p) == 0) n1 /= p;

		/* if base^(phi(n)/p) is 1, the order of base is too small */
		if(m_exp(base, n_1/p) == 1) return false;
	}

	/* Leftover cofactor: a prime whenever n-1 is within the table's range.
	   base^((n-1)/n1) == 1 rules base out in every case, because the order
	   of base would then divide a proper divisor of n-1. */
	if(n1 > 1 && m_exp(base, n_1/n1) == 1) return false;

	return true;
}

bool isPrime(int num) {
	int i = 0;
	for(i = 0; i < 50;i++) {
		if(primes[i] > (unsigned)num) return true;
		if((num % primes[i] == 0) && (num!=primes[i])) return false;
	}
	
	int k = 0;
	int q = num-1;
	while(q % 2 == 0) {
		k++;
		q >>= 1;
	}
	srand(time(0));
	int a;
	for(i = 0; i < 10; i++) {
		a = (rand() % (num-2)) + 1;
		if(!MillerRabin(num, k, q, a)) return false;
	}

	return true;
}


/*
 * One Miller-Rabin round for candidate num with base a.
 *
 * Intended to be called with num-1 == 2^k * q and q odd.  Returns true
 * ("num may be prime") when a^q == 1 (mod num) or when
 * a^(2^i * q) == num-1 (mod num) for some 0 <= i < k; returns false
 * ("num is composite") otherwise.
 *
 * The successive powers are obtained by squaring the previous one in
 * long long arithmetic instead of calling pow(), which avoids both the
 * floating-point truncation hazard of (int)pow(2,i) and a full modular
 * exponentiation per round.
 */
bool MillerRabin(int num, int k, int q, int a) {
	long long minusOne = (long long)num - 1;
	long long x = m_expn(a, q, num);	/* a^q mod num */
	int i;

	if(x == 1) return true;

	for(i = 0; i < k; i++) {
		/* here x == a^(2^i * q) mod num */
		if(x == minusOne) return true;
		x = (x * x) % num;
	}
	return false;
}


/* True exactly when gcd(base, n-1) equals 1, n being the file-wide modulus. */
bool isRelPrime(int base) {
	int common = gcd(base, n-1);
	if(common != 1) return false;
	return true;
}

/*
 * Greatest common divisor by the Euclidean algorithm.
 *
 * The result is never negative: gcd(a, 0) == |a|, gcd(0, b) == |b| and
 * gcd(0, 0) == 0.  Negative operands are handled through their
 * magnitudes; with signed remainders the loop would stop early and
 * return a wrong value (for example 6 for gcd(-4, 6)).
 *
 * The magnitudes are held in unsigned int so that negating INT_MIN is
 * well defined; a result of 2^31 (only possible when both operands are
 * INT_MIN or 0) does not fit the int return type.
 */
int gcd(int a, int b) {
	unsigned int x = (a < 0) ? 0u - (unsigned int)a : (unsigned int)a;
	unsigned int y = (b < 0) ? 0u - (unsigned int)b : (unsigned int)b;

	while (y != 0) {
		unsigned int r = x % y;
		x = y;
		y = r;
	}
	return (int)x;
}

/*
 * Resets the per-graph work arrays: for every index below the file-wide
 * n, cycleSize and distToCycle become 0 and visit and image become false.
 */
void setArrays(int * cycleSize, bool* visit, int* distToCycle, bool* image){
	int node = 0;
	while(node < n) {
		cycleSize[node] = 0;
		distToCycle[node] = 0;
		visit[node] = false;
		image[node] = false;
		node++;
	}
}

/* Stores 0 in the first n entries of listArray (n is the file-wide size). */
void zeroList(int * listArray) {
	int slot = 0;
	while(slot < n) {
		listArray[slot] = 0;
		slot++;
	}
}


/*
 * Writes this rank's per-base results to disk and combines summary
 * statistics on rank 0.
 *
 * 1. Every rank writes "<n>_<M_ARY>_<mynode>.dat" with one line for each
 *    base i in [2, n).
 * 2. Every rank accumulates value/trials and value^2/trials over the same
 *    range of i; the accumulators are summed onto rank 0 with MPI_Reduce.
 * 3. Rank 0 writes "results_<n>.dat" with each reduced total and, where a
 *    squared total exists, (squared total) - (total)^2 as the variance.
 *
 * All seven arrays are read at indices 2 .. n-1.
 *
 * NOTE(review): the totals are means only if trials equals the number of
 * bases contributing across all ranks - confirm against where trials is
 * defined.
 *
 * A file that cannot be opened is reported on stderr and skipped; the
 * function never returns early, because every rank must still take part
 * in the MPI_Reduce calls.
 */
void writeTotalResults(
		        int* maxTAll,
		        int* maxCAll,
		        int* terminalAll,
				int* allComponents,
				int* allCyclicNodes,
				int* allToCycleSum,
				int* allCycleLengthSum) {

	int i;

	/* large enough for three full-width ints, separators and ".dat" */
	char fileStr[64];
	snprintf(fileStr, sizeof fileStr, "%d_%d_%d.dat", n, M_ARY, mynode);
	FILE * out = fopen(fileStr, "w");
	if (out == NULL) {
		fprintf(stderr, "writeTotalResults: cannot open %s for writing\n", fileStr);
	} else {
		/* columns, in order:
		//# cycles base i
		//sum of cycle size seen from nodes in base i
		//sum of distance to cycle from nodes in base i
		//terminal nodes for base i
		//max cycle for base i
		//max tail for base i
		//cyclic nodes for base i */
		for(i = 2; i < n; i++) { /*0 and 1 not considered*/
			fprintf(out, "%d %d %d %d %d %d %d\n", allComponents[i], allCycleLengthSum[i],
				allToCycleSum[i], terminalAll[i], maxCAll[i], maxTAll[i], allCyclicNodes[i]);
		}
		fclose(out);
	}

	/* per-rank partial sums of x/trials and x*x/trials */
	double cComponents = 0;
	double cComponentsSquared = 0;

	double cCyclicNodes = 0;
	double cCyclicNodesSquared = 0;

	double cImageNodes = 0;
	/*no squared sum is kept for image nodes, so no variance is reported*/

	double cMaxCycle = 0;
	double cMaxCycleSquared = 0;

	double cMaxTail = 0;
	double cMaxTailSquared = 0;

	double cWeightedCycle = 0;
	double cWeightedCycleSquared = 0;

	double cWeightedTail = 0;
	double cWeightedTailSquared = 0;

	for (i = 2; i < n; i++)
	{
		cComponents += ((double)allComponents[i]) / trials;
		cComponentsSquared += (double)allComponents[i] * (double)allComponents[i] / trials;

		cCyclicNodes += (double)allCyclicNodes[i] / trials;
		cCyclicNodesSquared += (double)allCyclicNodes[i] * (double)allCyclicNodes[i] / trials;

		/* per-node averages: the sums divided by the n-1 graph nodes */
		double cycle = (double)allCycleLengthSum[i] / (double)(n-1);
		cWeightedCycle += (double)(cycle) / trials;
		cWeightedCycleSquared += (double)(cycle*cycle) / (trials);

		double tail = (double)allToCycleSum[i] / (double)(n-1);
		cWeightedTail += tail / trials;
		cWeightedTailSquared += (double)(tail*tail) / (trials);

		/* bases whose terminal count is 0 add nothing to the image total */
		if (terminalAll[i] > 0)
			cImageNodes += ((double)(n-1) - terminalAll[i]) / trials;

		cMaxCycle += (double)maxCAll[i] / trials;
		cMaxCycleSquared += (double)maxCAll[i]*(double)maxCAll[i] / trials;

		cMaxTail += (double)maxTAll[i] / trials;
		cMaxTailSquared += (double)maxTAll[i]*(double)maxTAll[i] / trials;
	}

	/* reduction targets; only rank 0 receives meaningful values */
	double cComponentsTot = 0;
	double cComponentsSquaredTot = 0;

	double cCyclicNodesTot = 0;
	double cCyclicNodesSquaredTot = 0;

	double cImageNodesTot = 0;
	double cMaxCycleTot = 0;
	double cMaxCycleSquaredTot = 0;

	double cMaxTailTot = 0;
	double cMaxTailSquaredTot = 0;

	double cWeightedCycleTot = 0;
	double cWeightedCycleSquaredTot = 0;

	double cWeightedTailTot = 0;
	double cWeightedTailSquaredTot = 0;

	MPI_Reduce( &cComponents, &cComponentsTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
	MPI_Reduce( &cComponentsSquared, &cComponentsSquaredTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);

	MPI_Reduce( &cCyclicNodes, &cCyclicNodesTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
	MPI_Reduce( &cCyclicNodesSquared, &cCyclicNodesSquaredTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);

	MPI_Reduce( &cImageNodes, &cImageNodesTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);

	MPI_Reduce( &cMaxCycle, &cMaxCycleTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
	MPI_Reduce( &cMaxCycleSquared, &cMaxCycleSquaredTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);

	MPI_Reduce( &cMaxTail, &cMaxTailTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
	MPI_Reduce( &cMaxTailSquared, &cMaxTailSquaredTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);

	MPI_Reduce( &cWeightedCycle, &cWeightedCycleTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
	MPI_Reduce( &cWeightedCycleSquared, &cWeightedCycleSquaredTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);

	MPI_Reduce( &cWeightedTail, &cWeightedTailTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);
	MPI_Reduce( &cWeightedTailSquared, &cWeightedTailSquaredTot, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);

	if (mynode == 0)
	{
		/* "results_" + a full-width int + ".dat" needs more than 20 bytes */
		char res[32];
		snprintf(res, sizeof res, "results_%d.dat", n);
		FILE * r = fopen(res, "w");
		if (r == NULL) {
			fprintf(stderr, "writeTotalResults: cannot open %s for writing\n", res);
			return;	/* all collective calls are already done */
		}

		/* variance as (sum of squares term) - (sum term)^2 */
		double ComponentsVariance = cComponentsSquaredTot - cComponentsTot*cComponentsTot;
		double CyclicNodesVariance = cCyclicNodesSquaredTot - cCyclicNodesTot*cCyclicNodesTot;
		double WeightedCycleVariance = cWeightedCycleSquaredTot - cWeightedCycleTot*cWeightedCycleTot;
		double WeightedTailVariance = cWeightedTailSquaredTot - cWeightedTailTot*cWeightedTailTot;
		double MaxCycleVariance = cMaxCycleSquaredTot - cMaxCycleTot*cMaxCycleTot;
		double MaxTailVariance = cMaxTailSquaredTot - cMaxTailTot*cMaxTailTot;

		fprintf(r, "components: %lf \n", cComponentsTot);
		fprintf(r, "components variance: %lf \n", ComponentsVariance);

		fprintf(r, "cyclic nodes: %lf \n", cCyclicNodesTot);
		fprintf(r, "cyclic nodes variance: %lf\n", CyclicNodesVariance);

		fprintf(r, "avg cycle: %lf\n", cWeightedCycleTot);
		fprintf(r, "avg cycle variance: %lf\n", WeightedCycleVariance);

		fprintf(r, "avg tail: %lf\n", cWeightedTailTot);
		fprintf(r, "avg tail variance: %lf\n", WeightedTailVariance);

		fprintf(r, "image nodes: %lf\n", cImageNodesTot);
		fprintf(r, "max cycle: %lf\n", cMaxCycleTot);
		fprintf(r, "max cycle variance: %lf\n", MaxCycleVariance);

		fprintf(r, "max tail: %lf\n", cMaxTailTot);
		fprintf(r, "max tail variance: %lf\n", MaxTailVariance);

		fclose(r);
	}
}


/*
 * Driver for one value of the file-wide n.
 *
 * 1. Records the MPI communicator size and this process's rank in
 *    totalnodes and mynode.
 * 2. Finds the smallest root for which isPrimRoot() succeeds.
 * 3. For every exponent e in [0, n) with gcd(e, n-1) == M_ARY, dealt out
 *    round-robin over the ranks, sets base = root^e mod n and walks the
 *    graph i -> base^i mod n over the nodes 1 .. n-1.  Per base it records
 *    the number of components, the number of cyclic nodes, the longest
 *    cycle and tail, the number of terminal nodes (nodes that are nobody's
 *    image) and the sums of every node's tail length and cycle length.
 * 4. Hands the per-base tables to writeTotalResults().
 *
 * Rank 0 reports progress in the STATUS file (best effort: an fopen
 * failure is ignored).  Every rank prints its accumulated exponentiation
 * time on stdout at the end.
 *
 * All tables are automatic arrays sized by n, so n must be small enough
 * for them to fit on the stack.
 */
void run() {

	MPI_Comm_size(MPI_COMM_WORLD, &totalnodes);
	MPI_Comm_rank(MPI_COMM_WORLD, &mynode);

	FILE * s;
	if (mynode == 0)
	{
		s = fopen(STATUS, "w");
		if (s != NULL) {
			fprintf(s, "Allocating...\n");
			fclose(s);
		}
	}

	/*node already reached by some walk for the current base*/
	bool visit[n];
	/*node is the image of at least one node for the current base*/
	bool image[n];
	/*maximum tail length for base [i]*/
	int maxTAll[n];
	/*maximum cycle length for base [i]*/
	int maxCAll[n];
	/*terminal nodes for base [i]*/
	int terminalAll[n];

	/*size of cycle for the component this node is a part of*/
	int cycleSize[n];
	/*distance to cycle from this node (0 if the node is cyclic)*/
	int distToCycle[n];

	/*Number of components for base i*/
	int allComponents[n+1];
	/*Number of cyclic nodes for base i*/
	int allCyclicNodes[n+1];
	/*Sum of all nodes' distance to cycle for base i */
	int allToCycleSum[n+1];
	/*Sum of each node's cycle length for base i */
	int allCycleLengthSum[n+1];

	/*max cycle, max tail */
	int mC, mT;
	int next, baseTail, cycleLength, terminal;
	int root, exponent, base;

	/*nodes of the walk in progress, in visiting order*/
	int listArray[n];
	int listSize = 0;

	if (mynode == 0)
	{
		s = fopen(STATUS, "w");
		if (s != NULL) {
			fprintf(s, "zeroing...\n");
			fclose(s);
		}
	}

	/* Initialize every per-base table.  The four n+1 sized tables must be
	   cleared completely: they are incremented for the bases this rank
	   handles and written out for every base by writeTotalResults(). */
	int i;
	for(i = 0; i < n; i++){
		maxTAll[i] = 0;
		maxCAll[i] = 0;
		terminalAll[i] = 0;
	}
	for(i = 0; i <= n; i++){
		allComponents[i] = 0;
		allCyclicNodes[i] = 0;
		allToCycleSum[i] = 0;
		allCycleLengthSum[i] = 0;
	}

	/*wall-clock start; only rank 0 reports it*/
	double t = 0;
	if (mynode == 0)
		t = MPI_Wtime();

	/*timers for the individual phases*/
	double tt;
	double expTime = 0;
	double tailTime = 0;
	double intoCycleTime = 0;
	double cycleTime = 0;
	double resultsTime = 0;

	if (mynode == 0)
	{
		s = fopen(STATUS, "w");
		if (s != NULL) {
			fprintf(s, "Finding a PR...\n");
			fclose(s);
		}
	}

	/* Find the smallest primitive root.  The search is bounded by n so
	   that a modulus without one cannot loop forever. */
	for(root = 1; root < n && !isPrimRoot(root); root++);

	if (root >= n)
	{
		fprintf(stderr, "run: no primitive root found for n = %d\n", n);
		/* abort the whole job so no rank is left waiting in a collective */
		MPI_Abort(MPI_COMM_WORLD, 1);
		return;
	}

	if (mynode == 0)
	{
		s = fopen(STATUS, "a");
		if (s != NULL) {
			fprintf(s, "Prim root is %d...\n", root);
			fclose(s);
		}
	}

	/*index of the current qualifying exponent, used to deal out the work*/
	int count = -1;

	for(exponent = 0; exponent < n; exponent++) {
		if(exponent % 100 == 0 && mynode == 0) {
			s = fopen(STATUS, "w");
			if (s != NULL) {
				fprintf(s, "Exp is %d\n", exponent);
				fclose(s);
			}
		}

		/*discard all but the bases which will make the target M-ARY graphs*/
		if(gcd(exponent,n-1) != M_ARY) continue;
		count++;
		/*round-robin: this rank takes every totalnodes-th qualifying exponent*/
		if(count % totalnodes != mynode) continue;

		base = m_exp(root,exponent);
		mC = 0;
		mT = 0;
		/*0 out everything*/
		setArrays(cycleSize, visit, distToCycle, image);

		/*begin making graph, using gamma(i) = base^i mod n*/
		for(i = 1; i < n; i++) {
			if (visit[i])
				continue;

			next = i;
			listArray[0] = next;
			listSize = 1;

			/*follow the map until a node that was already visited is hit;
			  that node ends up as the last entry of listArray*/
			tt = MPI_Wtime();
			while(!visit[next]){
				visit[next] = true;
				next = m_exp(base,next);
				image[next] = true;
				listArray[listSize] = next;
				listSize++;
			}
			expTime += MPI_Wtime() - tt;

			int j = 0;
			if(cycleSize[next] != 0) {
				/*the walk ran into a node classified by an earlier walk*/
				if(distToCycle[next] == 0) {/*all tail into cycle*/

					tt = MPI_Wtime();
					cycleLength = cycleSize[listArray[listSize-1]];
					if(listSize - 1 > mT) mT = listSize - 1;
					for(j = 0; j < listSize-1; j++){
						distToCycle[listArray[j]] = listSize - 1 - j;
						cycleSize[listArray[j]] = cycleLength;
					}
					intoCycleTime += MPI_Wtime() - tt;

				} else {/*extension of tail*/

					tt = MPI_Wtime();
					baseTail = distToCycle[listArray[listSize-1]];
					cycleLength = cycleSize[listArray[listSize-1]];
					if(listSize-1 + baseTail > mT) mT = listSize-1 + baseTail;
					for(j = 0; j < listSize-1; j++) {
						distToCycle[listArray[j]] = baseTail + listSize - 1 - j;
						cycleSize[listArray[j]] = cycleLength;
					}
					tailTime += MPI_Wtime() - tt;
				}
			} else {/*new cycle found*/

				tt = MPI_Wtime();
				/*the walk closed on one of its own nodes: find the first
				  position of that node, which is where the cycle starts*/
				int repeat = listArray[listSize-1];
				for(j = 0; listArray[j] != repeat; j++);

				int firstCycle = j;
				cycleLength = listSize - (j+1);
				if(cycleLength > mC) mC = cycleLength;

				if(firstCycle > mT) mT = firstCycle;
				/*mark each tail node along the way with how far it is to*/
				/*the cycle*/
				for(j = 0; j < firstCycle; j++) {
					distToCycle[listArray[j]] = firstCycle - j;
					cycleSize[listArray[j]] = cycleLength;
				}
				/*mark each cycle node with how big the cycle is*/
				for(j = firstCycle; j < listSize - 1; j++)
					cycleSize[listArray[j]] = cycleLength;

				allComponents[base]++;
				allCyclicNodes[base] += cycleLength;

				cycleTime += MPI_Wtime() - tt;
			}
		}

		tt = MPI_Wtime();

		/*terminal nodes are the nodes that no node maps onto*/
		terminal=0;
		for(i = 1; i < n; i++)
			if(!image[i]) terminal++;
		maxTAll[base] = mT;
		maxCAll[base] = mC;
		terminalAll[base] = terminal;
		computeResults(distToCycle, cycleSize, &allToCycleSum[base], &allCycleLengthSum[base]);

		resultsTime += MPI_Wtime() - tt;
	}

	if (mynode == 0)
	{
		s = fopen(STATUS, "w");
		if (s != NULL) {
			fprintf(s, "Writing Results...\n");
			fclose(s);
		}
	}
	writeTotalResults(
			  maxTAll,
			  maxCAll,
			  terminalAll,
			  allComponents,
			  allCyclicNodes,
			  allToCycleSum,
			  allCycleLengthSum);

	if (mynode == 0)
	{
		s = fopen(STATUS, "w");
		if (s != NULL) {
			fprintf(s, "%lf minutes...\n Exiting...\n%lf %lf %lf %lf %lf\n", (MPI_Wtime() - t)/60, expTime, tailTime, intoCycleTime, cycleTime, resultsTime);
			fclose(s);
		}
	}

	printf("%lf\n", expTime);

}