/*
 * KerLin.java - Kernighan-Lin heuristic for pairwise refinement of a k-way graph partition
 * Part of package project
 *
 * Author: Philip Bradley
 *
 */

package project;

import java.util.Random;

public class KerLin {
  /** Graph being partitioned. */
  private GeometricGraph graph;
  /** Adjacency matrix of the graph; a non-zero entry is treated as an edge. */
  private int[][] graphAdjacancyMatrix;
  /** partitionVector[v] is the partition index currently assigned to vertex v. */
  private int[] partitionVector;
  /** D values (external cost minus internal cost) used while building an exchange sequence. */
  private int[] differenceVector;
  /** Tentative exchange sequence: row t holds the two vertices swapped by tentative swap t. */
  private int[][] exchangeVector;
  /** gainList[t] is the gain of tentative swap t; Integer.MIN_VALUE marks an unused slot. */
  private int[] gainList;
  /** Vertices already used (locked) in the exchange sequence currently being built. */
  private boolean[] exchanged;
  /** Per-vertex internal/external cost relative to the partition pair under consideration. */
  private CostStructure[] vertexCostVector;
  /** Degree of each vertex, computed once from the adjacency matrix. */
  private int[] degree;
  /** Number of vertices in the graph. */
  private int graphOrder;
  /** The two vertices selected for the most recent tentative swap. */
  private int max_i;
  private int max_j;
  /** Number of partitions. */
  private int k;

  /**
   * Creates an optimiser that works directly on {@code initialSolution}.
   *
   * <p>The array is NOT copied: optimisation updates it in place.
   *
   * @param g graph to partition
   * @param k number of partitions
   * @param initialSolution partition index per vertex (length must be the graph order)
   */
  public KerLin(GeometricGraph g, int k, int[] initialSolution) {
    graph = g;
    partitionVector = initialSolution;
    graphOrder = g.getOrder();
    graphAdjacancyMatrix = g.getAdjacancyMatrix();
    differenceVector = new int[graphOrder];
    exchangeVector = new int[graphOrder][2];
    gainList = new int[graphOrder];
    exchanged = new boolean[graphOrder];
    max_i = 0;
    max_j = 0;
    this.k = k;

    // Count each non-zero upper-triangle entry once for both endpoints. The diagonal is
    // skipped. This assumes the adjacency matrix is symmetric.
    degree = new int[graphOrder];
    for (int i = 0; i < graphOrder; i++) {
      for (int j = i + 1; j < graphOrder; j++) {
        if (graphAdjacancyMatrix[i][j] != 0) {
          degree[i]++;
          degree[j]++;
        }
      }
    }

    vertexCostVector = new CostStructure[graphOrder];
    for (int i = 0; i < graphOrder; i++) {
      vertexCostVector[i] = new CostStructure();
    }
  }

  /**
   * Creates an optimiser whose partition vector starts with every vertex in partition 0.
   * Load a real starting solution with {@link #setSolution(int[])} before optimising.
   *
   * @param g graph to partition
   * @param k number of partitions
   */
  public KerLin(GeometricGraph g, int k) {
    this(g, k, new int[g.getOrder()]);
  }

  /**
   * Copies {@code newSolution} into the internal partition vector.
   *
   * @param newSolution partition index per vertex; must be at least graph-order long
   */
  public void setSolution(int[] newSolution) {
    System.arraycopy(newSolution, 0, partitionVector, 0, partitionVector.length);
  }

  /**
   * Performs a single Kernighan-Lin sweep over every pair of partitions.
   *
   * @param maxExchangeSize maximum number of tentative swaps per partition pair
   * @return a copy of the resulting partition vector. It is a copy so that later calls
   *     cannot modify a previously returned result.
   */
  public int[] kerLinQuickOptimise(int maxExchangeSize) {
    try {
      optimiseAllPairs(maxExchangeSize);
    } catch (BadPartitionException e) {
      System.err.println("Exception: " + e.getMessage());
      System.exit(-1);
    }
    return partitionVector.clone();
  }

  /**
   * Performs a single Kernighan-Lin sweep starting from the partition held by {@code s}.
   *
   * @param s starting solution; it is not modified
   * @param maxExchangeSize maximum number of tentative swaps per partition pair
   * @return a clone of {@code s} carrying the optimised partition
   */
  public Solution kerLinQuickOptimise(Solution s, int maxExchangeSize) {
    Solution newSolution = (Solution) s.clone();

    System.arraycopy(s.getPartition(), 0, partitionVector, 0, partitionVector.length);
    try {
      optimiseAllPairs(maxExchangeSize);
      // Hand over a copy in case setPartition stores the reference rather than copying it;
      // otherwise later optimiser calls would silently change the returned solution.
      newSolution.setPartition(partitionVector.clone());
    } catch (BadPartitionException e) {
      System.err.println("Exception: " + e.getMessage());
      System.exit(-1);
    }
    return newSolution;
  }

  /**
   * Repeats Kernighan-Lin sweeps over every partition pair until a complete sweep yields
   * no improvement.
   *
   * @param maxExchangeSize maximum number of tentative swaps per partition pair
   * @return a copy of the resulting partition vector
   */
  public int[] kerLinOptimise(int maxExchangeSize) {
    try {
      int sweepGain;
      do {
        // Sum over ALL pairs. Looking only at the last pair's gain would stop early
        // whenever k > 2 and that last pair happened not to improve.
        sweepGain = optimiseAllPairs(maxExchangeSize);
      } while (sweepGain != 0);
    } catch (BadPartitionException e) {
      System.err.println("Caught Exception: " + e.getMessage());
      System.exit(-1);
    }
    return partitionVector.clone();
  }

  /**
   * Runs {@link #phaseOneOptimisation} once for every unordered pair of partitions.
   *
   * @return the total gain applied during this sweep
   */
  private int optimiseAllPairs(int maxExchangeSize) throws BadPartitionException {
    int totalGain = 0;
    for (int p = 0; p < k; p++) {
      for (int q = p + 1; q < k; q++) {
        totalGain += phaseOneOptimisation(maxExchangeSize, p, q);
      }
    }
    return totalGain;
  }

  /**
   * Performs one Kernighan-Lin pass between two partitions.
   *
   * <p>The pass greedily builds a sequence of tentative swaps. Each swap pairs the unlocked
   * vertex with the highest D value on each side, then locks both vertices and updates the
   * D values. The pass then applies the prefix of the sequence with the largest positive
   * cumulative gain.
   *
   * @param maxExchangeSize upper bound on tentative swaps (clamped to the graph order)
   * @param pPartition first partition index
   * @param qPartition second partition index
   * @return the gain applied, or 0 when no prefix improves the partition
   * @throws BadPartitionException if the partitions are equal or either one has no vertices
   */
  private int phaseOneOptimisation(int maxExchangeSize, int pPartition, int qPartition)
      throws BadPartitionException {
    if ((pPartition == qPartition) || !partitionsPresent(pPartition, qPartition)) {
      throw new BadPartitionException();
    }

    // Internal/external costs depend on which pair of partitions is being compared.
    calculateCostVector(pPartition, qPartition);
    for (int v = 0; v < graphOrder; v++) {
      differenceVector[v] = vertexCostVector[v].external - vertexCostVector[v].internal;
      exchanged[v] = false;
      gainList[v] = Integer.MIN_VALUE;
    }

    // Build the tentative exchange sequence. The bound keeps indexing within the arrays,
    // which are sized to the graph order.
    int exchangeLimit = Math.min(maxExchangeSize, graphOrder);
    int exchangeCount = 0;
    while (exchangeCount < exchangeLimit) {
      try {
        max_i = getMaxGainVertex(pPartition);
        max_j = getMaxGainVertex(qPartition);
      } catch (NoVerticesAvailableException e) {
        // One side has no unlocked vertices left; the sequence is complete.
        break;
      }
      exchangeVector[exchangeCount][0] = max_i;
      exchangeVector[exchangeCount][1] = max_j;
      // gain() must run before the vertices are locked: it rejects locked vertices.
      gainList[exchangeCount] = gain(max_i, max_j);
      exchanged[max_i] = true;
      exchanged[max_j] = true;
      updateDifferenceVector();
      exchangeCount++;
    }

    // Find the prefix of the sequence with the largest cumulative gain.
    int maximumGain = 0;
    int runningGain = 0;
    int bestPrefixEnd = -1;
    for (int t = 0; t < exchangeCount; t++) {
      runningGain += gainList[t];
      if (runningGain > maximumGain) {
        maximumGain = runningGain;
        bestPrefixEnd = t;
      }
    }

    if (maximumGain > 0) {
      for (int t = 0; t <= bestPrefixEnd; t++) {
        exchange(exchangeVector[t][0], exchangeVector[t][1]);
      }
    }
    return maximumGain;
  }

  /** Returns true when at least one vertex is assigned to each of the two partitions. */
  private boolean partitionsPresent(int pPartition, int qPartition) {
    boolean foundp = false;
    boolean foundq = false;
    for (int v = 0; (v < partitionVector.length) && !(foundp && foundq); v++) {
      foundp |= (partitionVector[v] == pPartition);
      foundq |= (partitionVector[v] == qPartition);
    }
    return foundp && foundq;
  }

  /**
   * Returns the unlocked vertex of {@code partition} with the largest D value.
   * On ties, the lowest vertex index wins.
   *
   * @throws NoVerticesAvailableException if every vertex of the partition is locked
   */
  private int getMaxGainVertex(int partition) throws NoVerticesAvailableException {
    int best = -1;
    for (int v = 0; v < graphOrder; v++) {
      if ((partitionVector[v] == partition) && !exchanged[v]
          && ((best < 0) || (differenceVector[v] > differenceVector[best]))) {
        best = v;
      }
    }
    if (best < 0) {
      throw new NoVerticesAvailableException();
    }
    return best;
  }

  /**
   * Updates D values as if {@code max_i} and {@code max_j} had been swapped. This uses the
   * standard Kernighan-Lin update D'x = Dx + 2c(x,a) - 2c(x,b), where a is the vertex
   * leaving x's side and b is the vertex arriving on it.
   *
   * <p>Locked vertices and vertices outside the two partitions get Integer.MIN_VALUE
   * (not available).
   */
  private void updateDifferenceVector() {
    int pSide = partitionVector[max_i];
    int qSide = partitionVector[max_j];

    for (int v = 0; v < graphOrder; v++) {
      if (exchanged[v]) {
        differenceVector[v] = Integer.MIN_VALUE;
      } else if (partitionVector[v] == pSide) {
        differenceVector[v] += 2 * (graphAdjacancyMatrix[v][max_i] - graphAdjacancyMatrix[v][max_j]);
      } else if (partitionVector[v] == qSide) {
        differenceVector[v] += 2 * (graphAdjacancyMatrix[v][max_j] - graphAdjacancyMatrix[v][max_i]);
      } else {
        // The vertex belongs to neither partition under consideration.
        differenceVector[v] = Integer.MIN_VALUE;
      }
    }
  }

  /**
   * Recomputes the cost structure of every vertex in {@code pPartition} or
   * {@code qPartition}. Each vertex's external cost is measured against the other
   * partition of the pair. Entries for vertices in other partitions are left unchanged.
   */
  private void calculateCostVector(int pPartition, int qPartition) {
    for (int n = 0; n < graphOrder; n++) {
      if (partitionVector[n] == pPartition) {
        vertexCostVector[n] = getVertexCostWrt(n, qPartition);
      }
      if (partitionVector[n] == qPartition) {
        vertexCostVector[n] = getVertexCostWrt(n, pPartition);
      }
    }
  }

  /** Returns the number of neighbours (per graph.hasEdge) that share the vertex's partition. */
  private int getInternalCost(int vertex) {
    int internalCost = 0;
    for (int n = 0; n < graphOrder; n++) {
      if (graph.hasEdge(vertex, n) && (partitionVector[n] == partitionVector[vertex])) {
        internalCost++;
      }
    }
    return internalCost;
  }

  /** Returns the number of neighbours (per graph.hasEdge) that lie in a different partition. */
  private int getExternalCost(int vertex) {
    int externalCost = 0;
    for (int n = 0; n < graphOrder; n++) {
      if (graph.hasEdge(vertex, n) && (partitionVector[n] != partitionVector[vertex])) {
        externalCost++;
      }
    }
    return externalCost;
  }

  /**
   * Returns the internal cost of {@code vertex} and its external cost with respect to
   * {@code targetPartition} only.
   *
   * <p>The internal cost counts neighbours in the vertex's own partition, excluding the
   * vertex itself. The external cost counts neighbours in {@code targetPartition}. Only the
   * target partition is used because k-way refinement is done one partition pair at a time.
   */
  private CostStructure getVertexCostWrt(int vertex, int targetPartition) {
    int vertexPartition = partitionVector[vertex];
    CostStructure c = new CostStructure();

    for (int n = 0; n < graphOrder; n++) {
      if (graphAdjacancyMatrix[vertex][n] != 0) {
        if ((partitionVector[n] == vertexPartition) && (n != vertex)) {
          c.internal++;
        }
        if (partitionVector[n] == targetPartition) {
          c.external++;
        }
      }
    }
    return c;
  }

  /**
   * Counts edges (upper triangle of the adjacency matrix) that join pPartition and qPartition.
   *
   * @throws BadPartitionException if the partitions are equal or either is outside [0, k)
   */
  private int getCostBetweenPartitions(int pPartition, int qPartition) throws BadPartitionException {
    if ((pPartition == qPartition) || (pPartition >= k) || (pPartition < 0)
        || (qPartition >= k) || (qPartition < 0)) {
      throw new BadPartitionException();
    }
    int interPartitionCost = 0;
    for (int i = 0; i < graphOrder; i++) {
      for (int j = i + 1; j < graphOrder; j++) {
        if (graphAdjacancyMatrix[i][j] != 0) {
          boolean crossesPair =
              ((partitionVector[i] == pPartition) && (partitionVector[j] == qPartition))
                  || ((partitionVector[i] == qPartition) && (partitionVector[j] == pPartition));
          if (crossesPair) {
            interPartitionCost++;
          }
        }
      }
    }
    return interPartitionCost;
  }

  /**
   * Returns the gain of swapping vertices i and j, computed as Di + Dj - 2*A[i][j].
   * Returns Integer.MIN_VALUE when the swap is not allowed: the vertices are identical,
   * share a partition, or either one is already locked.
   */
  private int gain(int i, int j) {
    if ((i == j) || (partitionVector[i] == partitionVector[j]) || exchanged[i] || exchanged[j]) {
      return Integer.MIN_VALUE;
    }
    return differenceVector[i] + differenceVector[j] - (2 * graphAdjacancyMatrix[i][j]);
  }

  /**
   * Swaps the partition assignments of vertices i and j in place.
   *
   * <p>The cost vector is intentionally not recomputed here. It is only read in
   * phaseOneOptimisation, which rebuilds it (via calculateCostVector) before every read.
   */
  private void exchange(int i, int j) {
    int tmpPartition = partitionVector[i];
    partitionVector[i] = partitionVector[j];
    partitionVector[j] = tmpPartition;
  }

  /**
   * Returns the number of edges whose endpoints lie in different partitions.
   *
   * <p>This sums each vertex's external cost and halves the total, which assumes
   * graph.hasEdge is symmetric. TODO confirm.
   */
  public int getCost() {
    int totalExternalCost = 0;
    for (int n = 0; n < graphOrder; n++) {
      totalExternalCost += getExternalCost(n);
    }
    return totalExternalCost / 2;
  }

  /** Test helper: counts pairs i <= j with an edge whose endpoints lie in different partitions. */
  private int testGetCost() {
    int cost = 0;
    for (int i = 0; i < graphOrder; i++) {
      for (int j = i; j < graphOrder; j++) {
        if (graph.hasEdge(i, j) && (partitionVector[i] != partitionVector[j])) {
          cost++;
        }
      }
    }
    return cost;
  }

  /** Writes per-partition internal/external costs and the total cost to System.out (testing only). */
  protected void report() {
    CostStructure[] partitionCosts = new CostStructure[k];
    for (int n = 0; n < k; n++) {
      partitionCosts[n] = new CostStructure();
    }

    for (int n = 0; n < graphOrder; n++) {
      int vertexExternalCost = getExternalCost(n);
      int vertexInternalCost = degree[n] - vertexExternalCost;
      partitionCosts[partitionVector[n]].internal += vertexInternalCost;
      partitionCosts[partitionVector[n]].external += vertexExternalCost;
    }

    // Each internal edge was counted once from each endpoint.
    for (int n = 0; n < k; n++) {
      partitionCosts[n].internal /= 2;
    }

    for (int n = 0; n < k; n++) {
      System.out.print(n + ": " + partitionCosts[n].internal + "," + partitionCosts[n].external + "   ");
    }
    System.out.println("Cost: " + getCost() + "," + testGetCost());
  }

  /** Debug helper: prints edge and cost consistency information to System.out. */
  private void test() {
    System.out.println("Graph has size: " + graph.getSize());
    int totalDegree = 0;
    for (int i = 0; i < graphOrder; i++) {
      totalDegree += degree[i];
    }
    totalDegree = totalDegree / 2;
    System.out.println("Calculate edges: " + totalDegree);
    System.out.println("Cost: " + getCost());
    int xcost = 0;
    for (int i = 0; i < graphOrder; i++) {
      for (int j = i; j < graphOrder; j++) {
        if (graph.hasEdge(i, j) && (partitionVector[i] != partitionVector[j])) {
          xcost++;
        }
      }
    }
    System.out.println("Calculate Cost: " + xcost);

    for (int i = 0; i < graphOrder; i++) {
      System.out.println(getInternalCost(i) + "    " + getExternalCost(i) + "    " + degree[i]);
    }
  }

  /**
   * Returns args[index]. If the value is missing, prints an error naming the flag that
   * expected it and exits with status 1.
   */
  private static String argumentValue(String[] args, int index) {
    if (index >= args.length) {
      System.err.println("Error: missing value for argument " + args[index - 1]);
      System.exit(1);
    }
    return args[index];
  }

  /**
   * Command-line entry point:
   * {@code KerLin -g GraphFile -k PartitionCount [-sol SolutionFile] -f|-q}.
   * {@code -f} runs full Kernighan-Lin until no sweep improves; {@code -q} runs one sweep.
   */
  public static void main(String[] args) {
    int k = 0;
    String optSwitch = "";
    String solutionFileName = "";
    GeometricGraph g = null;

    if ((args.length != 5) && (args.length != 7)) {
      System.err.println("Error: KerLin -g GraphFile -k Partition Size [-sol Solution] -f|-q");
      System.exit(1);
    }

    for (int i = 0; i < args.length; i++) {
      String arg = args[i];
      if (arg.equals("-g")) {
        g = GeometricGraph.readFromFile(argumentValue(args, ++i));
      } else if (arg.equals("-k")) {
        String value = argumentValue(args, ++i);
        try {
          k = Integer.parseInt(value);
        } catch (NumberFormatException e) {
          System.err.println("Error: -k expects an integer, got: " + value);
          System.exit(1);
        }
      } else if (arg.equals("-q") || arg.equals("-f")) {
        optSwitch = arg;
      } else if (arg.equals("-sol")) {
        solutionFileName = argumentValue(args, ++i);
      } else {
        System.err.println("Unrecognised Argument: " + arg);
        System.exit(1);
      }
    }

    if (g == null) {
      System.err.println("Error: no graph given; use -g GraphFile");
      System.exit(1);
    }
    // Validate before doing any work, so that no "Solution" is reported for a run that never happened.
    if (!optSwitch.equals("-f") && !optSwitch.equals("-q")) {
      System.err.println("Error: Illegal option: " + optSwitch);
      System.exit(1);
    }

    System.out.println("# Using graph of order: " + g.getOrder() + " and size: " + g.getSize() + ", k = " + k);
    Solution initialSolution;
    if (solutionFileName.equals("")) {
      // Generate random solution
      initialSolution = new Solution(g, k);
    } else {
      initialSolution = Solution.readFromFile(solutionFileName);
      initialSolution.setGraph(g);
    }

    KerLin kl = new KerLin(g, k, initialSolution.getPartition());
    int graphOrder = g.getOrder();

    System.out.println("# Initial Cost: " + kl.getCost());
    long startTime = System.currentTimeMillis();
    if (optSwitch.equals("-f")) {
      kl.kerLinOptimise(graphOrder);
    } else {
      kl.kerLinQuickOptimise(graphOrder);
    }
    long finishTime = System.currentTimeMillis();
    System.out.println("# Total running time in milliseconds: " + (finishTime - startTime));
    System.out.println("# Solution: " + kl.getCost());
  }
}





