Skip to content

Commit

Permalink
Minor change to reduce GC overhead
Browse files Browse the repository at this point in the history
  • Loading branch information
EdwardRaff committed May 29, 2017
1 parent 0d0e748 commit 823f28a
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions JSAT/src/jsat/classifiers/svm/extended/CPM.java
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ else if(i == old_owner)
* @param sign_mul Either positive or negative 1. Controls whether or not
* the positive or negative class is to be enveloped by the polytype
*/
private void sgdTrain(ClassificationDataSet D, MatrixOfVecs W, Vec b, int sign_mul)
private void sgdTrain(ClassificationDataSet D, MatrixOfVecs W, Vec b, int sign_mul, ExecutorService ex)
{
IntList order = new IntList(D.getSampleSize());
ListUtils.addRange(order, 0, D.getSampleSize(), 1);
Expand All @@ -417,6 +417,8 @@ private void sgdTrain(ClassificationDataSet D, MatrixOfVecs W, Vec b, int sign_m
int[] assignments = new int[D.getSampleSize()];//who owns each data point
Arrays.fill(assignments, -1);//Starts out that no one is assigned!

Vec dots = new DenseVector(W.rows());

long t = 0;
for(int epoch = 0; epoch < epochs; epoch++)
{
Expand All @@ -428,8 +430,10 @@ private void sgdTrain(ClassificationDataSet D, MatrixOfVecs W, Vec b, int sign_m
Vec x_i = D.getDataPoint(i).getNumericalValues();
int y_i = (D.getDataPointCategory(i)*2-1)*sign_mul;

Vec dots = W.multiply(x_i);
dots.mutableAdd(b);
//this sets dots = bias, which we then add to with matrix-vector product
//result is the same as dots = W x_i + b
b.copyTo(dots);
W.multiply(x_i, 1.0, dots);

if(y_i == -1)
{
Expand Down Expand Up @@ -490,8 +494,8 @@ public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool)
MatrixOfVecs W_p = new MatrixOfVecs(Wv_p);
MatrixOfVecs W_n = new MatrixOfVecs(Wv_n);

sgdTrain(dataSet, W_p, bp, +1);
sgdTrain(dataSet, W_n, bn, -1);
sgdTrain(dataSet, W_p, bp, +1, threadPool);
sgdTrain(dataSet, W_n, bn, -1, threadPool);

this.Wp = new DenseMatrix(W_p);
this.Wn = new DenseMatrix(W_n);
Expand Down

0 comments on commit 823f28a

Please sign in to comment.