Fixed resource leaks in file loading and 'continued' sampling
lejon committed Feb 17, 2020
1 parent 579aa49 commit ad5370f
Showing 4 changed files with 87 additions and 91 deletions.
3 changes: 2 additions & 1 deletion pom.xml
@@ -101,7 +101,8 @@
     <!-- v8.3.0: Support saving sampler from main Java application -->
     <!-- v8.4.0: Added support for continued sampling in main, controlled by 'continue' CMD line option -->
     <!-- v8.5.0: Support for compressed input files (zip, gz). Improved iteration callback support -->
-    <version>8.5.0</version>
+    <!-- v8.5.1: Fixed resource leaks in file loading and 'continued' sampling -->
+    <version>8.5.1</version>
 
 
     <name>Partially Collapsed Parallel LDA</name>
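The leak pattern behind this version bump is the classic one: a stream opened at the top of a method is never closed when the method returns or throws. A minimal before/after sketch of the idiom the commit applies throughout LDAUtils; the class and method names here are illustrative, not from the repository:

import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;

public class LeakFixSketch {

    // Before: if this method returns normally, 'in' is never closed.
    static int countBytesLeaky(String path) throws IOException {
        BufferedInputStream in = new BufferedInputStream(new FileInputStream(path));
        int n = 0;
        while (in.read() != -1) n++;
        return n; // leak: no close() on this path
    }

    // After: try-with-resources closes 'in' on every exit path,
    // including a return from inside the try block.
    static int countBytesFixed(String path) throws IOException {
        try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(path))) {
            int n = 0;
            while (in.read() != -1) n++;
            return n; // close() still runs after the return value is computed
        }
    }
}

With try-with-resources, close() runs on normal return, early return, and exception alike, which is why the rewritten loaders below can return straight from inside the try block.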
13 changes: 9 additions & 4 deletions src/main/java/cc/mallet/topics/UncollapsedParallelLDA.java
@@ -2046,11 +2046,16 @@ private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
         lu.checkAndCreateCurrentLogDir(logSuitePath);
         config.setLoggingUtil(lu);
         if(activeSubconfig==null) {
-            activeSubconfig = config.getSubConfigs()[0];
-            System.out.println("Active subconfig not set, activating first available (" + activeSubconfig + ") ...");
+            String [] subconfs = config.getSubConfigs();
+            if(subconfs!= null && subconfs.length > 0) {
+                System.out.println("Active subconfig not set, activating first available (" + activeSubconfig + ") ...");
+                activeSubconfig = subconfs[0];
+                config.activateSubconfig(activeSubconfig);
+                System.out.println("Activating subconfig: " + activeSubconfig);
+            }
+        } else {
+            config.activateSubconfig(activeSubconfig);
         }
-        System.out.println("Activating subconfig: " + activeSubconfig);
-        config.activateSubconfig(activeSubconfig);
 
         System.out.println("Done Reading config!");
     } catch (ConfigurationException e) {
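Deserialization in readObject also stops indexing config.getSubConfigs()[0] unconditionally, which fails with an ArrayIndexOutOfBoundsException when no subconfigs exist. The guard in isolation, as a runnable sketch with a stand-in for the configuration type (hypothetical names, not the repository's API):

public class SubconfigGuardSketch {
    // Stand-in for LDAConfiguration.getSubConfigs(); returns empty to show the edge case.
    static String[] getSubConfigs() { return new String[0]; }

    public static void main(String[] args) {
        String activeSubconfig = null;
        String[] subconfs = getSubConfigs();
        // Unguarded code would do: activeSubconfig = getSubConfigs()[0];
        // which throws ArrayIndexOutOfBoundsException on an empty array.
        if (subconfs != null && subconfs.length > 0) {
            activeSubconfig = subconfs[0];
        }
        System.out.println("Active subconfig: " + activeSubconfig);
    }
}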
15 changes: 9 additions & 6 deletions src/main/java/cc/mallet/topics/tui/ParallelLDA.java
@@ -208,15 +208,18 @@ void doIteration(LDACommandLineParser cp, LDAConfiguration config, LoggingUtils ...
         }
         System.out.println("Scheme: " + whichModel);
 
-        InstanceList instances = LDAUtils.loadDataset(config, dataset_fn);
-        instances.getAlphabet().stopGrowth();
-
         boolean continueSampling = isContinuation(cp);
         LDAGibbsSampler model = createModel(config, whichModel);
+        InstanceList instances = null;
         if(continueSampling) {
             System.out.println("Continuing sampling from previously stored model...");
-            initSamplerFromSaved(config, instances, model);
-        }
+            initSamplerFromSaved(config, model);
+            instances = model.getDataset();
+        } else {
+            instances = LDAUtils.loadDataset(config, dataset_fn);
+            instances.getAlphabet().stopGrowth();
+        }
 
         if(model==null) {
             System.out.println("No valid model selected ('" + whichModel + "' is not a recognized model), please select a valid model...");

@@ -414,9 +417,9 @@ void doIteration(LDACommandLineParser cp, LDAConfiguration config, LoggingUtils ...
         System.out.println(new Date() + ": I am done!");
     }
 
-    private void initSamplerFromSaved(LDAConfiguration config, InstanceList instances, LDAGibbsSampler model) {
+    private void initSamplerFromSaved(LDAConfiguration config, LDAGibbsSampler model) {
         String storedDir = config.getSavedSamplerDirectory(LDAConfiguration.STORED_SAMPLER_DIR_DEFAULT);
-        LDASamplerWithPhi newModel = LDAUtils.loadStoredSampler(instances, config, storedDir);
+        LDASamplerWithPhi newModel = LDAUtils.loadStoredSampler(config, storedDir);
         // Since the user asked us to continue using this sampler, we assume it is "initiable"
         LDASamplerInitiable toInit = (LDASamplerInitiable) model;
         toInit.initFrom(newModel);
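The doIteration change makes a 'continued' run reuse the dataset that was deserialized along with the stored sampler, where the old code loaded the input file unconditionally and then ignored that work when continuing. The branching, reduced to a runnable sketch; the Sampler interface and both helpers are stand-ins, not the repository's types:

import java.util.Arrays;
import java.util.List;

public class ContinueFlowSketch {
    interface Sampler { List<String> getDataset(); }

    static List<String> loadFromDisk() { return Arrays.asList("doc1", "doc2"); }
    static Sampler restoreStoredSampler() { return () -> Arrays.asList("doc1", "doc2"); }

    public static void main(String[] args) {
        boolean continueSampling = args.length > 0 && "--continue".equals(args[0]);
        List<String> instances;
        if (continueSampling) {
            // Reuse the dataset the restored sampler already holds.
            Sampler model = restoreStoredSampler();
            instances = model.getDataset();
        } else {
            // Fresh run: read the corpus from disk as before.
            instances = loadFromDisk();
        }
        System.out.println(instances.size() + " instances ready");
    }
}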
147 changes: 67 additions & 80 deletions src/main/java/cc/mallet/util/LDAUtils.java
@@ -292,19 +292,17 @@ public static InstanceList loadInstancesPrune(String inputFile, String stoplistFile, ...
      */
     public static InstanceList loadInstancesPrune(String inputFile, String stoplistFile, int pruneCount, boolean keepNumbers,
             int maxBufSize, boolean keepConnectors, Alphabet dataAlphabet, LabelAlphabet targetAlphabet) throws FileNotFoundException {
-        BufferedInputStream in;
-        try {
-            in = new BufferedInputStream(streamFromFile(inputFile));
-
+        try (BufferedInputStream in = new BufferedInputStream(streamFromFile(inputFile))) {
             in.mark(Integer.MAX_VALUE);
+            return loadInstancesPrune(in, stoplistFile, pruneCount, keepNumbers,
+                    maxBufSize, keepConnectors, dataAlphabet, targetAlphabet);
         } catch (IOException e) {
             throw new IllegalArgumentException(e);
         }
-
-        return loadInstancesPrune(in, stoplistFile, pruneCount, keepNumbers,
-                maxBufSize, keepConnectors, dataAlphabet, targetAlphabet);
     }
 
 
     /**
      * Loads instances and prunes away low occurring words
      *

@@ -324,7 +322,7 @@ public static InstanceList loadInstancesPrune(BufferedInputStream in, String stoplistFile, ...
         int dataGroup = 3;
         int labelGroup = 2;
         int nameGroup = 1; // data, label, name fields
-
+        in.mark(Integer.MAX_VALUE);
 
         tokenizer = initTokenizer(stoplistFile, keepNumbers, maxBufSize, keepConnectors);
@@ -475,22 +473,18 @@ public static InputStream streamFromFile(String inputFile) throws FileNotFoundException {
             return sout;
         }
     }
 
     public static InstanceList loadInstancesRaw(String inputFile, String stoplistFile, int keepCount, int maxBufSize,
-            Alphabet dataAlphabet, LabelAlphabet targetAlphabet) throws FileNotFoundException {
-
-        BufferedInputStream in;
-        try {
-            in = new BufferedInputStream(streamFromFile(inputFile));
+            Alphabet dataAlphabet, LabelAlphabet targetAlphabet) throws FileNotFoundException {
+        try (BufferedInputStream in = new BufferedInputStream(streamFromFile(inputFile))){
             in.mark(Integer.MAX_VALUE);
+            return loadInstancesRaw(in, stoplistFile, keepCount, maxBufSize, dataAlphabet, targetAlphabet);
         } catch (IOException e) {
             throw new IllegalArgumentException(e);
         }
-
-        return loadInstancesRaw(in, stoplistFile, keepCount, maxBufSize, dataAlphabet, targetAlphabet);
     }
 
 
     /**
      * Loads instances and keeps the <code>keepCount</code> number of words with
      * the highest TF-IDF. Does no preprocessing of the input other than splitting
@@ -512,11 +506,11 @@ public static InstanceList loadInstancesRaw(BufferedInputStream in, String stoplistFile, ...
         int dataGroup = 3;
         int labelGroup = 2;
         int nameGroup = 1; // data, label, name fields
-
+        in.mark(Integer.MAX_VALUE);
 
         tokenizer = initRawTokenizer(stoplistFile, maxBufSize);
 
         if (keepCount > 0) {
             CsvIterator reader = new CsvIterator(
                     new InputStreamReader(in),
@@ -609,22 +603,18 @@ public static InstanceList loadInstancesRaw(BufferedInputStream in, String stoplistFile, ...
 
         return instances;
     }
 
 
     public static InstanceList loadInstancesKeep(String inputFile, String stoplistFile, int keepCount, boolean keepNumbers,
-            int maxBufSize, boolean keepConnectors, Alphabet dataAlphabet, LabelAlphabet targetAlphabet) throws FileNotFoundException {
-
-        BufferedInputStream in;
-        try {
-            in = new BufferedInputStream(streamFromFile(inputFile));
+            int maxBufSize, boolean keepConnectors, Alphabet dataAlphabet, LabelAlphabet targetAlphabet) throws FileNotFoundException {
+        try (BufferedInputStream in = new BufferedInputStream(streamFromFile(inputFile))){
             in.mark(Integer.MAX_VALUE);
+            return loadInstancesKeep(in, stoplistFile, keepCount, keepNumbers,
+                    maxBufSize, keepConnectors, dataAlphabet, targetAlphabet);
+
         } catch (IOException e) {
             throw new IllegalArgumentException(e);
         }
-
-        return loadInstancesKeep(in, stoplistFile, keepCount, keepNumbers,
-                maxBufSize, keepConnectors, dataAlphabet, targetAlphabet);
-
     }
 
     /**
@@ -750,7 +740,7 @@ public static InstanceList loadInstancesKeep(BufferedInputStream in, String stoplistFile, ...
 
     /**
      * Re-creates the pipe that is used if loading with TF-IDF
-     * This is ugly as hell, but I wanted ti to be as similar as
+     * This is ugly as hell, but I wanted it to be as similar as
      * possible as when using loadDataset
      *
      * @param inputFile Input file to load
@@ -770,63 +760,60 @@ public static TfIdfPipe getTfIdfPipe(String inputFile, String stoplistFile, int ...
         int labelGroup = 2;
         int nameGroup = 1; // data, label, name fields
 
-        tokenizer = initTokenizer(stoplistFile, keepNumbers, maxBufSize, keepConnectors);
-
-        BufferedInputStream in;
-        try {
-            in = new BufferedInputStream(streamFromFile(inputFile));
-        } catch (IOException e) {
-            throw new IllegalArgumentException(e);
-        }
-
-        if (keepCount > 0) {
-            CsvIterator reader = new CsvIterator(
-                    new InputStreamReader(in),
-                    lineRegex,
-                    dataGroup,
-                    labelGroup,
-                    nameGroup);
+        tokenizer = initTokenizer(stoplistFile, keepNumbers, maxBufSize, keepConnectors);
+        try (BufferedInputStream in = new BufferedInputStream(streamFromFile(inputFile))) {
+            if (keepCount > 0) {
+                CsvIterator reader = new CsvIterator(
+                        new InputStreamReader(in),
+                        lineRegex,
+                        dataGroup,
+                        labelGroup,
+                        nameGroup);
 
-            ArrayList<Pipe> pipes = new ArrayList<Pipe>();
-            Alphabet alphabet = null;
-            if(dataAlphabet==null) {
-                alphabet = new Alphabet();
-            } else {
-                alphabet = dataAlphabet;
-            }
+                ArrayList<Pipe> pipes = new ArrayList<Pipe>();
+                Alphabet alphabet = null;
+                if(dataAlphabet==null) {
+                    alphabet = new Alphabet();
+                } else {
+                    alphabet = dataAlphabet;
+                }
 
-            CharSequenceLowercase csl = new CharSequenceLowercase();
-            SimpleTokenizer st = tokenizer.deepClone();
-            StringList2FeatureSequence sl2fs = new StringList2FeatureSequence(alphabet);
-            TfIdfPipe tfIdfPipe = new TfIdfPipe(alphabet, null);
+                CharSequenceLowercase csl = new CharSequenceLowercase();
+                SimpleTokenizer st = tokenizer.deepClone();
+                StringList2FeatureSequence sl2fs = new StringList2FeatureSequence(alphabet);
+                TfIdfPipe tfIdfPipe = new TfIdfPipe(alphabet, null);
 
-            pipes.add(csl);
-            pipes.add(st);
-            pipes.add(sl2fs);
-            if (keepCount > 0) {
-                pipes.add(tfIdfPipe);
-            }
+                pipes.add(csl);
+                pipes.add(st);
+                pipes.add(sl2fs);
+                if (keepCount > 0) {
+                    pipes.add(tfIdfPipe);
+                }
 
-            Pipe serialPipe = new SerialPipes(pipes);
+                Pipe serialPipe = new SerialPipes(pipes);
 
-            Iterator<Instance> iterator = serialPipe.newIteratorFrom(reader);
+                Iterator<Instance> iterator = serialPipe.newIteratorFrom(reader);
 
-            int count = 0;
+                int count = 0;
 
-            // We aren't really interested in the instance itself,
-            // just the total feature counts.
-            while (iterator.hasNext()) {
-                count++;
-                if (count % 100000 == 0) {
-                    System.out.println(count);
-                }
-                iterator.next();
-            }
+                // We aren't really interested in the instance itself,
+                // just the total feature counts.
+                while (iterator.hasNext()) {
+                    count++;
+                    if (count % 100000 == 0) {
+                        System.out.println(count);
+                    }
+                    iterator.next();
+                }
 
-            if (keepCount > 0) {
-                tfIdfPipe.addPrunedWordsToStoplist(tokenizer, keepCount);
-                return tfIdfPipe;
-            }
-        } else {
-            return null;
-        }
+                if (keepCount > 0) {
+                    tfIdfPipe.addPrunedWordsToStoplist(tokenizer, keepCount);
+                    return tfIdfPipe;
+                }
+            } else {
+                return null;
+            }
+        } catch (IOException e) {
+            throw new IllegalArgumentException(e);
+        }
@@ -2411,7 +2398,7 @@ public static InstanceList loadInstancesStrings(String [] doclines, String class...
         return instances;
     }
 
-    public static LDASamplerWithPhi loadStoredSampler(InstanceList trainingset, LDAConfiguration config, String saveDir) {
+    public static LDASamplerWithPhi loadStoredSampler(LDAConfiguration config, String saveDir) {
         String configHash = getConfigSetHash(config);
         if(!saveDir.endsWith(File.separator)) saveDir = saveDir + File.separator;
         String samplerFn = saveDir + buildSamplerSaveFilename(configHash);
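One detail worth noting: each loader now calls in.mark(Integer.MAX_VALUE) on the stream it owns. Marking lets a BufferedInputStream be rewound with reset(), presumably so a loader can take a counting pass and then re-read the same input. A self-contained sketch of that mechanic, with byte counting standing in for whatever each loader actually does per pass:

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;

public class MarkResetSketch {
    public static void main(String[] args) throws IOException {
        byte[] data = "alpha beta gamma".getBytes();
        try (BufferedInputStream in = new BufferedInputStream(new ByteArrayInputStream(data))) {
            in.mark(Integer.MAX_VALUE);      // remember this position; buffer up to MAX_VALUE bytes before it expires
            int pass1 = 0;
            while (in.read() != -1) pass1++; // first pass, e.g. counting for pruning
            in.reset();                      // rewind to the mark
            int pass2 = 0;
            while (in.read() != -1) pass2++; // second pass over the same bytes
            System.out.println(pass1 + " == " + pass2); // prints 16 == 16
        }
    }
}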
