#include "newH2.h"

#include <stdio.h>
#include <stdlib.h>
/* Rank of this process in MPI_COMM_WORLD (set by run()). */
int mynode;
/* Number of ranks in MPI_COMM_WORLD (set by run()). */
int totalnodes;

/* Modular exponentiation by repeated squaring: returns b^e mod m.
 *
 * Preconditions: m > 0.  A negative exponent is treated like 0 (the loop
 * never runs).  The result is always in [0, m).
 *
 * The base is reduced modulo m (and a negative base mapped into [0, m))
 * before squaring, so b*b cannot overflow for a large b.  Intermediate
 * products are still held in a long long, so m must satisfy
 * (m-1)^2 <= LLONG_MAX; every caller in this file passes an int modulus,
 * which is well inside that bound.
 */
long long ml_exp(long long b, int e, long long m)
{
    /* 1 % m makes m == 1 yield 0 rather than 1 */
    long long result = 1 % m;

    b %= m;
    if (b < 0)
        b += m;

    while (e > 0)
    {
        /* Multiply in this bit's contribution; '&' is the bitwise AND. */
        if (e & 1)
            result = (result * b) % m;
        e >>= 1;
        b = (b * b) % m;
    }
    return result;
}

/* int-typed convenience wrapper around ml_exp(): returns b^r mod num.
 * The value from ml_exp() has magnitude at most num, so narrowing it back
 * to int is lossless. */
int m_expn(int b, int r, int num)
{
    long long wide = ml_exp(b, r, num);
    return (int) wide;
}

/* Modular exponentiation against the global modulus n: returns b^r mod n.
 * b is the base and r the exponent. */
int m_exp(int b, int r)
{
    const int modulus = n;
    return m_expn(b, r, modulus);
}

/* Collapse the per-node arrays of one graph into two totals.
 * distToCycle and cycleSize are read at indices 0..n-1.
 * *allToCycleSum receives the sum of distToCycle (as int) and
 * *allCycleLengthSum the sum of cycleSize (accumulated as long long). */
void computeResults(const int* distToCycle, const int* cycleSize,
    int* allToCycleSum, long long* allCycleLengthSum)
{
    int tailTotal = 0;
    long long cycleTotal = 0;
    int node = 0;

    while (node < n)
    {
        tailTotal += distToCycle[node];
        cycleTotal += cycleSize[node];
        node++;
    }

    *allToCycleSum = tailTotal;
    *allCycleLengthSum = cycleTotal;
}

/* One round of the Miller-Rabin strong-probable-prime test.
 *
 * The caller supplies num - 1 == 2^k * q with q odd, and the witness a.
 * Returns true when num is a strong probable prime to base a, i.e.
 * a^q == 1 (mod num) or a^(2^i * q) == num - 1 (mod num) for some
 * 0 <= i < k.  Returns false when a proves num composite.
 *
 * The successive powers a^(2^i * q) are obtained by squaring the previous
 * one. The previous version recomputed each of them from scratch with a
 * floating-point pow() and a cast, which cost k-1 extra exponentiations.
 */
bool MillerRabin(int num, int k, int q, int a)
{
    int n1 = num - 1;
    long long x = m_expn(a, q, num);
    int i = 0;

    if (x == 1) return true;
    for (i = 0; i < k; i++)
    {
        /* here x == a^(2^i * q) mod num */
        if (x == n1) return true;
        /* x < num <= INT_MAX, so x * x fits in a long long */
        x = (x * x) % num;
    }
    return false;
}

/* Probabilistic primality test.
 *
 * Numbers below 2 are not prime; 2 and 3 are.  Otherwise num is
 * trial-divided by the first 50 entries of primes[] (the table from the
 * header bn_prime.h; presumably ascending and starting at 2 — TODO confirm).
 * A proper divisor means composite.  Reaching a table prime larger than num
 * means prime.  Anything left is checked with 10 rounds of Miller-Rabin,
 * using random witnesses in [2, num-2] (a == 1 can never expose a composite).
 *
 * rand() is seeded from time() only on the first call.  The old code
 * reseeded on every call, so calls within the same second reused one
 * witness sequence.
 */
bool isPrime(int num)
{
    static bool seeded = false;
    int i = 0;
    int k = 0;
    int q = 0;

    if (num < 2) return false;
    if (num < 4) return true;

    for (i = 0; i < 50; i++)
    {
        if (primes[i] > (unsigned)num) return true;
        if ((num % primes[i] == 0) && (num != primes[i])) return false;
    }

    /* write num - 1 as 2^k * q with q odd */
    q = num - 1;
    while (q % 2 == 0)
    {
        k++;
        q >>= 1;
    }

    if (!seeded)
    {
        srand((unsigned) time(NULL));
        seeded = true;
    }
    for (i = 0; i < 10; i++)
    {
        /* num >= 4 here, so num - 3 >= 1 */
        int a = (rand() % (num - 3)) + 2;
        if (!MillerRabin(num, k, q, a)) return false;
    }
    return true;
}

/* Method to test whether or not an element is a Primitive Root of an inputted modulus. */
bool isPrimRoot(int base) 
{ 
    if(!isPrime(n)) return false;
    int n_1 = n-1;
    if ((unsigned)n_1 > (primes[NUMPRIMES-1]*primes[NUMPRIMES-1]))
        printf("Error in Primitive Root Testing, n could have prime factor too large for testing\n");
    int n1 = n_1;
    int index = 0;
    int p;
    while(n1 > 1 && index < NUMPRIMES)
    {
        /*find the primes that divide phi(n)*/
        if((n1 % primes[index]) == 0) 
        {
            p = primes[index];
            /* divide out that prime all the way so it isn’t tested again*/
            while((n1 % primes[index] == 0)) n1/=primes[index];
            /*if base^phi(n)/p is 1, not a prim root*/
            if(m_exp(base,n_1/p) == 1) return false;
            /*if(isPrime(n1)) return m_exp(base,n_1/n1) == 1;*/
            if(n1 == 50021) return (m_exp(base,n_1/50021) != 1);
        }
        index++;
    }
    return true;
}

/* Greatest common divisor of a and b (Euclid's algorithm).
 *
 * gcd(a, 0) == |a|, gcd(0, b) == |b| and gcd(0, 0) == 0.  Negative inputs
 * are handled by working on magnitudes, so the result is never negative.
 * The previous loop stopped on a negative remainder, e.g. gcd(-4, 6) gave 6.
 * Magnitudes are taken in unsigned arithmetic so INT_MIN does not overflow.
 * The only unrepresentable result is 2^31, from gcd(INT_MIN, 0) or
 * gcd(INT_MIN, INT_MIN).
 */
int gcd(int a, int b)
{
    unsigned int x = a < 0 ? 0u - (unsigned int)a : (unsigned int)a;
    unsigned int y = b < 0 ? 0u - (unsigned int)b : (unsigned int)b;

    while (y != 0)
    {
        unsigned int r = x % y;
        x = y;
        y = r;
    }
    return (int)x;
}

/* Returns true when gcd(base, n-1) == 1.
 * Note: despite the name, this tests coprimality with n-1 (which equals
 * phi(n) when n is prime), not with the modulus n itself. */
bool isRelPrime(int base)
{
    const int common = gcd(base, n - 1);
    return common == 1;
}

/* Euler's totient phi(numb): the count of integers in [1, numb] coprime to
 * numb.  It is computed from the trial-division factorisation as
 * numb * prod(1 - 1/p) over the distinct primes p dividing numb.
 *
 * The result is returned as a double because writeTotalResults() divides by
 * it (it is used as the number of m-ary graphs, i.e. the number of trials).
 * Inputs <= 1 are returned unchanged (phi(1) == 1).
 *
 * The loop bound is i <= numb / i rather than i * i <= numb, so that
 * i * i cannot overflow int when numb is close to INT_MAX.
 */
double euler(int numb)
{
    int result = numb;
    int i = 0;

    for (i = 2; i <= numb / i; i++)
    {
        if (numb % i != 0) continue;
        result -= result / i;
        while (numb % i == 0) numb /= i;
    }
    /* a remaining factor > 1 is a prime above the square root of the input */
    if (numb > 1) result -= result / numb;
    return (double) result;
}



/* Reset the per-node work arrays before a new graph is built.
 * Entries 0..n-1 of every array are cleared (false / 0), so each array
 * must hold at least n elements. */
void setArrays(int * cycleSize, bool* visit, int* distToCycle, bool* image)
{
    int node = n;
    while (node-- > 0)
    {
        image[node] = false;
        distToCycle[node] = 0;
        cycleSize[node] = 0;
        visit[node] = false;
    }
}

/* Set entries 0..n-1 of listArray to 0. */
void zeroList(int * listArray)
{
    int *cursor = listArray;
    int remaining = n;
    while (remaining > 0)
    {
        *cursor++ = 0;
        remaining--;
    }
}

/* Write this rank's per-base results, reduce the summary statistics across
 * all MPI ranks, and (on rank 0) write the totals.
 *
 * Files produced:
 *   "<n>_<M_ARY>_<rank>.dat"  one line per base i in [2, n):
 *                             components, sum of cycle sizes, sum of
 *                             distances to cycle, terminal nodes, max cycle,
 *                             max tail, cyclic nodes
 *   "distrib_<n>_<M_ARY>.dat" one line per base i in [2, n): marker, then
 *                             oneCycles ... twentyfiveCycles
 *   "results_<n>.dat"         rank 0 only: means and variances, each
 *                             per-graph value weighted by 1/trials
 *
 * trials = euler((n-1)/M_ARY) is used as the number of graphs.  Every array
 * is indexed by base and read at indices 2..n-1.  Entries for bases this
 * rank did not process should be zero, so that summing over every base and
 * then over every rank counts each graph once.
 *
 * MPI_Reduce is collective, so every rank reaches it even when one of its
 * output files could not be opened.
 */
void writeTotalResults(
    int* maxTAll,
    int* maxCAll,
    int* terminalAll,
    int* allComponents,
    int* allCyclicNodes,
    int* allToCycleSum,
    long long* allCycleLengthSum,
    int* oneCycles,
    int* twoCycles,
    int* threeCycles,
    int* fourCycles,
    int* fiveCycles,
    int* sevenCycles,
    int* tenCycles,
    int* twentyfiveCycles,
    int* marker)
{
    /* Slots of the statistic vectors that are summed across ranks below. */
    enum {
        S_COMPONENTS, S_COMPONENTS_SQ,
        S_CYCLIC, S_CYCLIC_SQ,
        S_IMAGE,
        S_MAX_CYCLE, S_MAX_CYCLE_SQ,
        S_MAX_TAIL, S_MAX_TAIL_SQ,
        S_W_CYCLE, S_W_CYCLE_SQ,
        S_W_TAIL, S_W_TAIL_SQ,
        S_COUNT
    };
    double local[S_COUNT] = {0};
    double total[S_COUNT] = {0};
    double trials = euler((n-1)/M_ARY);
    /* Big enough for any of the names below; the old char[20] buffers
       overflowed for six-digit n ("distrib_100003_2.dat" is 20 chars + NUL). */
    char path[64];
    FILE *out = NULL;
    int i = 0;

    /* 1. Per-rank raw results (bases 0 and 1 are not considered). */
    snprintf(path, sizeof path, "%d_%d_%d.dat", n, M_ARY, mynode);
    out = fopen(path, "w");
    if (out == NULL)
    {
        fprintf(stderr, "writeTotalResults: cannot open %s\n", path);
    }
    else
    {
        for (i = 2; i < n; i++)
        {
            /* %lld: allCycleLengthSum is long long (was %ld, which is UB) */
            fprintf(out, "%d %lld %d %d %d %d %d\n", allComponents[i], allCycleLengthSum[i],
                    allToCycleSum[i], terminalAll[i], maxCAll[i], maxTAll[i], allCyclicNodes[i]);
        }
        fclose(out);
    }

    /* 2. Cycle-length distribution.
       NOTE(review): this name does not contain the rank, so with several
       ranks each one overwrites the same file and only one rank's rows
       survive — confirm whether that is intended. */
    snprintf(path, sizeof path, "distrib_%d_%d.dat", n, M_ARY);
    out = fopen(path, "w");
    if (out == NULL)
    {
        fprintf(stderr, "writeTotalResults: cannot open %s\n", path);
    }
    else
    {
        for (i = 2; i < n; i++)
        {
            fprintf(out, "%d %d %d %d %d %d %d %d %d\n", marker[i], oneCycles[i], twoCycles[i],
                    threeCycles[i], fourCycles[i], fiveCycles[i], sevenCycles[i], tenCycles[i],
                    twentyfiveCycles[i]);
        }
        fclose(out);
    }

    /* 3. Local contributions to the means and second moments. */
    for (i = 2; i < n; i++)
    {
        double components = (double)allComponents[i];
        double cyclic = (double)allCyclicNodes[i];
        double maxCycle = (double)maxCAll[i];
        double maxTail = (double)maxTAll[i];
        /* average cycle size / distance to cycle over the n-1 nodes */
        long double cycle = (long double)allCycleLengthSum[i] / (double)(n-1);
        double tail = (double)allToCycleSum[i] / (double)(n-1);

        local[S_COMPONENTS] += components / trials;
        local[S_COMPONENTS_SQ] += components * components / trials;
        local[S_CYCLIC] += cyclic / trials;
        local[S_CYCLIC_SQ] += cyclic * cyclic / trials;
        local[S_W_CYCLE] += cycle / trials;
        local[S_W_CYCLE_SQ] += (cycle * cycle) / trials;
        local[S_W_TAIL] += tail / trials;
        local[S_W_TAIL_SQ] += (tail * tail) / trials;
        /* presumably skips bases this rank never processed (terminalAll
           still 0) — TODO confirm no processed base can have 0 terminals */
        if (terminalAll[i] > 0)
            local[S_IMAGE] += ((double)(n-1) - terminalAll[i]) / trials;
        local[S_MAX_CYCLE] += maxCycle / trials;
        local[S_MAX_CYCLE_SQ] += maxCycle * maxCycle / trials;
        local[S_MAX_TAIL] += maxTail / trials;
        local[S_MAX_TAIL_SQ] += maxTail * maxTail / trials;
    }

    /* One collective instead of thirteen: MPI_SUM is applied element-wise. */
    MPI_Reduce(local, total, S_COUNT, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD);

    /* 4. Rank 0 writes the global summary; variance = E[X^2] - E[X]^2. */
    if (mynode == 0)
    {
        double ComponentsVariance = total[S_COMPONENTS_SQ] - total[S_COMPONENTS]*total[S_COMPONENTS];
        double CyclicNodesVariance = total[S_CYCLIC_SQ] - total[S_CYCLIC]*total[S_CYCLIC];
        double WeightedCycleVariance = total[S_W_CYCLE_SQ] - total[S_W_CYCLE]*total[S_W_CYCLE];
        double WeightedTailVariance = total[S_W_TAIL_SQ] - total[S_W_TAIL]*total[S_W_TAIL];
        double MaxCycleVariance = total[S_MAX_CYCLE_SQ] - total[S_MAX_CYCLE]*total[S_MAX_CYCLE];
        double MaxTailVariance = total[S_MAX_TAIL_SQ] - total[S_MAX_TAIL]*total[S_MAX_TAIL];

        snprintf(path, sizeof path, "results_%d.dat", n);
        out = fopen(path, "w");
        if (out == NULL)
        {
            fprintf(stderr, "writeTotalResults: cannot open %s\n", path);
            return;
        }
        fprintf(out, "components: %lf \n", total[S_COMPONENTS]);
        fprintf(out, "components variance: %lf \n", ComponentsVariance);
        fprintf(out, "cyclic nodes: %lf \n", total[S_CYCLIC]);
        fprintf(out, "cyclic nodes variance: %lf\n", CyclicNodesVariance);
        fprintf(out, "avg cycle: %lf\n", total[S_W_CYCLE]);
        fprintf(out, "avg cycle variance: %lf\n", WeightedCycleVariance);
        fprintf(out, "avg tail: %lf\n", total[S_W_TAIL]);
        fprintf(out, "avg tail variance: %lf\n", WeightedTailVariance);
        fprintf(out, "image nodes: %lf\n", total[S_IMAGE]);
        fprintf(out, "max cycle: %lf\n", total[S_MAX_CYCLE]);
        fprintf(out, "max cycle variance: %lf\n", MaxCycleVariance);
        fprintf(out, "max tail: %lf\n", total[S_MAX_TAIL]);
        fprintf(out, "max tail variance: %lf\n", MaxTailVariance);
        fclose(out);
    }
}
void run() 
{
    MPI_Comm_size(MPI_COMM_WORLD, &totalnodes);
    MPI_Comm_rank(MPI_COMM_WORLD, &mynode);
    FILE * s;
    if (mynode == 0)
    {
        s = fopen(STATUS, "w");
        fprintf(s, "Allocating...\n");
        fclose(s);                                      
    }	
    bool visit[n];
    bool image[n];
    /*Maximum tail length for base [i]*/
    int maxTAll[n];
    /*Maximum cycle lenghth for base [i]*/
    int maxCAll[n];
    /*Terminal nodes for base [i]*/
    int terminalAll[n];
    /*Size of cycle for the component this node is a part of*/
    int cycleSize[n];
    /*Distance to cycle from node n (0 if node n is cyclic)*/
    int distToCycle[n];
    /*Number of components for base i*/
    int allComponents[n+1];
    /*Number of image nodes for base i*/
    int allCyclicNodes[n+1];
    /*Sum of all nodes’ distance to cycle for base i*/
    int allToCycleSum[n+1];
    /*Sum of each node’s cycle length for base i*/
    long long allCycleLengthSum[n+1];
    /*Record the number of cycles of arbitrary lengths.*/
    int oneCycles[n];
    int twoCycles[n];
    int threeCycles[n];
    int fourCycles[n];
    int fiveCycles[n];
    int sevenCycles[n];
    int tenCycles[n];
    int twentyfiveCycles[n];
    int marker[n];
    /*Initialize variables that will store max cycle length (mC) and max tail length (mT).*/
    int mC, mT;
    int next, loc, baseTail, cycleLength, terminal;
    int root, exp, base;
    int listArray[n];
    int listSize = 0;
    if (mynode == 0)
    {
        s = fopen(STATUS, "w");
        fprintf(s, "zeroing...\n");
        fclose(s);
    }
    /*initialize arrays to 0 */
    int i = 0;
    for(i = 0; i < n; i++)
    {
        if(i < n) 
        {
            maxTAll[i] = 0;
            maxCAll[i] = 0;
            terminalAll[i] = 0;
        }
    }
    allComponents[n] = 0;
    allCyclicNodes[n] = 0;
    allToCycleSum[n] = 0;
    allCycleLengthSum[n] = 0;
    oneCycles[n] = 0;
    twoCycles[n] = 0;
    threeCycles[n];
    fourCycles[n];
    fiveCycles[n] = 0;
    sevenCycles[n]=0;
    tenCycles[n] = 0;
    twentyfiveCycles[n] = 0;
    marker[n]=0;
    double t;
    if (mynode == 0)
        t = MPI_Wtime();
    double tt;
    double expTime = 0;
    double tailTime = 0;
    double intoCycleTime = 0;
    double cycleTime = 0;
    double resultsTime = 0;
    if (mynode == 0)
    {
        s = fopen(STATUS, "w");
        fprintf(s, "Finding a PR...\n");
        fclose(s);
    }
    /*find the smallest primitive root*/
    for(root = 1; !isPrimRoot(root); root++);
        if (mynode == 0)
        {
            s = fopen(STATUS, "a");
            fprintf(s, "Prim root is %d...\n", root);
            fclose(s);
        }
    int count = -1;
    for(exp = 0; exp < n; exp ++) 
    {
        if(exp % 100 == 0 && mynode == 0) 
        {
            s = fopen(STATUS, "w");
            fprintf(s, "Exp is %d\n", exp);
            fclose(s);
        }
        /*discard all but the bases which will make the target M-ARY graphs*/
        if(gcd(exp,n-1) != M_ARY) continue;
        count++;
        if(count % totalnodes != mynode) continue;
        base = m_exp(root,exp);
        /*Mark each base that will produce target graph.*/
        marker[base]=1;
        /*Recall mC is max cycle, mT is max tail.*/
        mC = 0;
        mT = 0;
        /*0 out everything*/
        setArrays(cycleSize, visit, distToCycle, image);
        /*begin making graph, using gamma(i) = base^i mod n*/
        for(i = 1; i < n; i++) 
        {
            if(visit[i])
                continue;
            next = i;
            listArray[0] = next;
            listSize = 1;
            tt = MPI_Wtime();
            while(!visit[next])
            {
                visit[next] = true;
                next = m_exp(base,next);
                image[next] = true;
                listArray[listSize] = next;
                listSize++;
            }
            expTime += MPI_Wtime() - tt;
            int j = 0;
            if(cycleSize[next] != 0) 
            {
                if(distToCycle[next] == 0) 
                {
                    /*all tail into cycle*/
                    tt = MPI_Wtime();
                    cycleLength = cycleSize[listArray[listSize-1]];
                    if(listSize - 1 > mT) mT = listSize - 1;
                        for(j = 0; j < listSize-1; j++)
                        {
                            distToCycle[listArray[j]] = listSize - 1 - j;
                            cycleSize[listArray[j]] = cycleLength;
                        }
                    intoCycleTime += MPI_Wtime() - tt;
                } 
                else 
                {
                    /*extension of tail*/
                    tt = MPI_Wtime();
                    baseTail = distToCycle[listArray[listSize-1]];
                    cycleLength = cycleSize[listArray[listSize-1]];
                    if(listSize-1 + baseTail > mT) mT = listSize-1 + baseTail;
                    for(j = 0; j < listSize-1; j++) 
                    {
                        distToCycle[listArray[j]] = baseTail + listSize - 1 - j;
                        cycleSize[listArray[j]] = cycleLength;
                    }
                    tailTime += MPI_Wtime() - tt;
                }
            } 
            else 
            {
                /*new cycle found*/
                tt = MPI_Wtime();
                /*loc will be the first node in the cycle we ran in to*/
                int repeat = listArray[listSize-1];
                for(j = 0; listArray[j] != repeat; j++);
                    int firstCycle = j;
                cycleLength = listSize - (j+1);
                if(cycleLength==1) oneCycles[base]++;
                if(cycleLength==2) twoCycles[base]++;
                if(cycleLength==3) threeCycles[base]++;
                if(cycleLength==3) fourCycles[base]++;
                if(cycleLength==5) fiveCycles[base]++;
                if(cycleLength==7) sevenCycles[base]++;
                if(cycleLength==10) tenCycles[base]++;
                if(cycleLength==20) twentyfiveCycles[base]++;
                if(cycleLength > mC) mC = cycleLength;
                if(firstCycle > mT) mT = firstCycle;
                /*mark each tail node along the way with how far it is to
                the cycle (marked as a negative number)*/
                for(j = 0; j < firstCycle; j++) 
                {
                    distToCycle[listArray[j]] = firstCycle - j;
                    cycleSize[listArray[j]] = cycleLength;
                }
                    /*mark each cycle node with how big the cycle is*/
                for(j = firstCycle; j < listSize - 1; j++)
                    cycleSize[listArray[j]] = cycleLength;
                allComponents[base]++;
                allCyclicNodes[base] += cycleLength;
                cycleTime += MPI_Wtime() - tt;
            }
        }
        tt = MPI_Wtime();
        terminal=0;
        for(i = 1; i < n; i++)
            if(!image[i]) terminal++;
        maxTAll[base] = mT;
        maxCAll[base] = mC;
        terminalAll[base] = terminal;
        computeResults(distToCycle, cycleSize, &allToCycleSum[base], &allCycleLengthSum[base]);
        resultsTime += MPI_Wtime() - tt;
    }
    if(mynode == 0)
    {
        s = fopen(STATUS, "w");
        fprintf(s, "Writing Results...\n");
        fclose(s);
    }
    writeTotalResults(
    maxTAll,
    maxCAll,
    terminalAll,
    allComponents,
    allCyclicNodes,
    allToCycleSum,
    allCycleLengthSum,
    oneCycles,
    twoCycles,
    threeCycles,
    fourCycles, 
    fiveCycles,
    sevenCycles,
    tenCycles,
    twentyfiveCycles,
    marker);
    if(mynode == 0)
    {
        s = fopen(STATUS, "w");
        fprintf(s, "%lf minutes...\n Exiting...\n%lf %lf %lf %lf %lf\n", (MPI_Wtime() - t)/60, expTime, tailTime);
        fclose(s);
    }
    printf("%lf\n", expTime);
}