Skip to content

Commit

Permalink
Parallelised the greedy polarisation algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
leifeld committed Jan 11, 2025
1 parent f0bda2f commit 0f41fd2
Showing 1 changed file with 108 additions and 94 deletions.
202 changes: 108 additions & 94 deletions dna/src/main/java/dna/export/Polarisation.java
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,10 @@ public PolarisationResultTimeSeries getResults() {
* @param congruenceNetwork A 2D array representing the congruence network.
* @param conflictNetwork A 2D array representing the conflict network.
* @param normalise Should the result be divided by its theoretical maximum (the sum of the two matrix norms)?
* @param numClusters The number of clusters.
* @return The quality of polarization as a double value.
*/
private double qualityAbsdiff(int[] memberships, double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise) {
private double qualityAbsdiff(int[] memberships, double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise, int numClusters) {
double congruenceNorm = calculateMatrixNorm(congruenceNetwork);
double conflictNorm = calculateMatrixNorm(conflictNetwork);

Expand Down Expand Up @@ -586,10 +587,11 @@ private class GeneticIteration {
* @param congruenceNetwork The congruence matrix.
* @param conflictNetwork The conflict matrix.
* @param normalise Should the quality/fitness scores be normalised?
* @param numClusters The number of clusters.
* @param rng The random number generator to use.
* @return A list of children cluster solutions.
*/
GeneticIteration(ArrayList<ClusterSolution> clusterSolutions, double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise, Random rng) {
GeneticIteration(ArrayList<ClusterSolution> clusterSolutions, double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise, int numClusters, Random rng) {
this.clusterSolutions = new ArrayList<>(clusterSolutions);
this.normalise = normalise;
this.congruenceNetwork = congruenceNetwork.clone();
Expand All @@ -608,7 +610,7 @@ private class GeneticIteration {
"Number of mutations based on the mutation percentage.");
Dna.logger.log(log);

this.q = evaluateQuality(this.congruenceNetwork, this.conflictNetwork, this.normalise);
this.q = evaluateQuality(this.congruenceNetwork, this.conflictNetwork, normalise, numClusters);
this.children = eliteRetentionStep(this.clusterSolutions, this.q, this.numElites);
this.children = crossoverStep(this.clusterSolutions, this.q, this.children, rng);
this.children = mutationStep(this.children, this.numMutations, this.n, rng);
Expand All @@ -622,13 +624,14 @@ private class GeneticIteration {
* @param congruenceNetwork The congruence network matrix.
* @param conflictNetwork The conflict network matrix.
* @param normalise Normalise the results?
* @param numClusters The number of clusters.
* @return An array of quality scores for each cluster solution.
*/
private double[] evaluateQuality(double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise) {
private double[] evaluateQuality(double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise, int numClusters) {
double[] q = new double[clusterSolutions.size()];
for (int i = 0; i < clusterSolutions.size(); i++) {
int[] mem = clusterSolutions.get(i).getMemberships();
q[i] = qualityAbsdiff(mem, congruenceNetwork, conflictNetwork, normalise);
q[i] = qualityAbsdiff(mem, congruenceNetwork, conflictNetwork, normalise, numClusters);
}
return q;
}
Expand Down Expand Up @@ -846,7 +849,7 @@ public PolarisationResultTimeSeries geneticAlgorithm () {
// Run through iterations and do the breeding, then collect results and stats
lastIndex = numIterations - 1; // choose last possible value here as a default if early convergence does not happen
for (int i = 0; i < numIterations; i++) {
GeneticIteration geneticIteration = new GeneticIteration(cs, this.congruence.get(t).getMatrix(), this.conflict.get(t).getMatrix(), this.normaliseScores, rng);
GeneticIteration geneticIteration = new GeneticIteration(cs, this.congruence.get(t).getMatrix(), this.conflict.get(t).getMatrix(), this.normaliseScores, this.numClusters, rng);
cs = geneticIteration.getChildren();

// compute summary statistics based on iteration step and retain them
Expand Down Expand Up @@ -1279,105 +1282,116 @@ private ArrayList<ExportStatement>[][][] create3dArray(String[] var1Values, Stri
/**
 * Prepare the greedy membership swapping algorithm and run it for every time
 * window. The time windows are processed in parallel; each one is handled by
 * {@code greedyTimeStep} with its own random seed. The per-window results are
 * collected into an object that stores the polarisation results.
 *
 * @return The polarisation results for all time windows.
 */
private PolarisationResultTimeSeries greedyAlgorithm () {
	Random rng = (this.randomSeed == 0) ? new Random() : new Random(this.randomSeed); // Initialize random number generator

	// Draw one seed per time window sequentially, before the parallel section
	// starts. Calling rng.nextLong() from inside the parallel lambda would hand
	// out seeds in whatever order the worker threads happen to run, so the same
	// randomSeed would not yield the same seed for a given time window.
	final int numTimeSteps = this.congruence.size();
	final long[] seeds = new long[numTimeSteps];
	for (int t = 0; t < numTimeSteps; t++) {
		seeds[t] = rng.nextLong();
	}

	ArrayList<PolarisationResult> polarisationResults = ProgressBar
			.wrap(IntStream.range(0, numTimeSteps).parallel(), "Greedy algorithm")
			.map(t -> greedyTimeStep(Polarisation.this.congruence.get(t),
					Polarisation.this.conflict.get(t),
					Polarisation.this.normaliseScores,
					Polarisation.this.numClusters,
					seeds[t]))
			.collect(Collectors.toCollection(ArrayList::new));

	return new PolarisationResultTimeSeries(polarisationResults);
}
/**
* A single run of the greedy algorithm, for one pair of congruence and conflict
* network, i.e., for one time slice.
*
* @param congruence A Matrix object containing the 2D congruence array.
* @param conflict A Matrix object containing the 2D conflict array.
* @param normaliseScores Normalise the absdiff quality/fitness scores to 1.0?
* @param numClusters The number of clusters.
* @param seed A random seed, which is used to create a new random number generator for this algorithm run. The seed should have been itself generated by a random number generator to ensure variability across time steps and reproducibility.
* @return a PolarisationResult object
*/
private PolarisationResult greedyTimeStep(Matrix congruence, Matrix conflict, boolean normaliseScores, int numClusters, long seed) {

// for each time step, run the algorithm over the cluster solutions; retain quality and memberships
double[][] congruenceMatrix, conflictMatrix;
int t, oldI, oldJ;
double[][] congruenceMatrix = congruence.getMatrix();
double[][] conflictMatrix = conflict.getMatrix();
ArrayList<Double> maxQArray = new ArrayList<Double>();
int[] bestMemberships, mem, mem2;
double maxQ, q1, q2;
boolean noChanges;

try (ProgressBar pb = new ProgressBar("Greedy algorithm", this.congruence.size())) {
for (t = 0; t < congruence.size(); t++) { // go through all time steps of the time window networks
maxQArray.clear();
congruenceMatrix = congruence.get(t).getMatrix();
conflictMatrix = conflict.get(t).getMatrix();
double combinedNorm = calculateMatrixNorm(congruenceMatrix) + calculateMatrixNorm(congruenceMatrix);

if (congruenceMatrix.length > 0 || combinedNorm == 0.0) { // if the network has no nodes or edges, skip this step and return 0 directly

// Create initially random cluster solution to update
ClusterSolution cs = new ClusterSolution(congruence.get(t).getMatrix().length, numClusters, rng);
mem = cs.getMemberships();

// evaluate quality of initial solution
maxQArray.add(qualityAbsdiff(mem, congruenceMatrix, conflictMatrix, this.normaliseScores));
bestMemberships = mem.clone();
maxQ = maxQArray.get(0);

boolean convergence = false;
while (!convergence) { // run the two nested for-loops repeatedly until there are no more swaps
noChanges = true;
for (int i = 0; i < mem.length; i++) {
for (int j = 1; j < mem.length; j++) { // swap positions i and j in the membership vector and see if leads to higher fitness
if (i < j && mem[i] != mem[j]) {
mem2 = mem.clone();
oldI = mem2[i];
oldJ = mem2[j];
mem2[i] = oldJ;
mem2[j] = oldI;
q1 = qualityAbsdiff(mem, congruenceMatrix, conflictMatrix, this.normaliseScores);
q2 = qualityAbsdiff(mem2, congruenceMatrix, conflictMatrix, this.normaliseScores);
if (q2 > q1) { // candidate solution has higher fitness -> keep it
mem = mem2.clone(); // accept the new solution if it was better than the previous
maxQArray.add(q2);
maxQ = q2;
bestMemberships = mem.clone();
noChanges = false;
}
}
double combinedNorm = calculateMatrixNorm(congruenceMatrix) + calculateMatrixNorm(congruenceMatrix);

if (congruenceMatrix.length > 0 || combinedNorm == 0.0) { // if the network has no nodes or edges, skip this step and return 0 directly

// Create initially random cluster solution to update
Random random = new Random(seed);
ClusterSolution cs = new ClusterSolution(congruenceMatrix.length, numClusters, random);
int[] mem = cs.getMemberships();

// evaluate quality of initial solution
maxQArray.add(qualityAbsdiff(mem, congruenceMatrix, conflictMatrix, normaliseScores, numClusters));
int[] bestMemberships = mem.clone();
double maxQ = maxQArray.get(0);

boolean convergence = false;
while (!convergence) { // run the two nested for-loops repeatedly until there are no more swaps
boolean noChanges = true;
for (int i = 0; i < mem.length; i++) {
for (int j = 1; j < mem.length; j++) { // swap positions i and j in the membership vector and see if leads to higher fitness
if (i < j && mem[i] != mem[j]) {
int[] mem2 = mem.clone();
int oldI = mem2[i];
int oldJ = mem2[j];
mem2[i] = oldJ;
mem2[j] = oldI;
double q1 = qualityAbsdiff(mem, congruenceMatrix, conflictMatrix, normaliseScores, numClusters);
double q2 = qualityAbsdiff(mem2, congruenceMatrix, conflictMatrix, normaliseScores, numClusters);
if (q2 > q1) { // candidate solution has higher fitness -> keep it
mem = mem2.clone(); // accept the new solution if it was better than the previous
maxQArray.add(q2);
maxQ = q2;
bestMemberships = mem.clone();
noChanges = false;
}
}
if (noChanges) {
convergence = true;
}
}

double[] maxQArray2 = new double[maxQArray.size()];
for (int i = 0; i < maxQArray.size(); i++) {
maxQArray2[i] = maxQArray.get(i);
}

// save results in array as a complex object
double[] avgQArray = maxQArray2;
double[] sdQArray = new double[maxQArray.size()];
PolarisationResult pr = new PolarisationResult(
maxQArray2,
avgQArray,
sdQArray,
maxQ,
bestMemberships,
congruence.get(t).getRowNames(),
true,
congruence.get(t).getStart(),
congruence.get(t).getStop(),
congruence.get(t).getDateTime());
polarisationResults.add(pr);
} else { // zero result because network is empty
PolarisationResult pr = new PolarisationResult(
new double[] { 0 },
new double[] { 0 },
new double[] { 0 },
0.0,
new int[0],
new String[0],
true,
congruence.get(t).getStart(),
congruence.get(t).getStop(),
congruence.get(t).getDateTime());
polarisationResults.add(pr);
}
pb.step();
if (noChanges) {
convergence = true;
}
}
}

PolarisationResultTimeSeries polarisationResultTimeSeries = new PolarisationResultTimeSeries(polarisationResults);
return polarisationResultTimeSeries;
double[] maxQArray2 = new double[maxQArray.size()];
for (int i = 0; i < maxQArray.size(); i++) {
maxQArray2[i] = maxQArray.get(i);
}

// save results in array as a complex object
double[] avgQArray = maxQArray2;
double[] sdQArray = new double[maxQArray.size()];
PolarisationResult pr = new PolarisationResult(
maxQArray2,
avgQArray,
sdQArray,
maxQ,
bestMemberships,
congruence.getRowNames(),
true,
congruence.getStart(),
congruence.getStop(),
congruence.getDateTime());
return pr;
} else { // zero result because network is empty
PolarisationResult pr = new PolarisationResult(
new double[] { 0 },
new double[] { 0 },
new double[] { 0 },
0.0,
new int[0],
new String[0],
true,
congruence.getStart(),
congruence.getStop(),
congruence.getDateTime());
return pr;
}
}
}

0 comments on commit 0f41fd2

Please sign in to comment.