major improvements and fixes for ce_calib:
- create file directory before writing
- put proper termination of processes in finally
- clean up unnecessary exception handling
- create separate functions for subtasks
- track progress of digestion, filtering, and writing
- handle failed prediction batches and continue writing the remaining ones
- store failed batches and append only the missing batches to the existing file (see the sketch after this list)
- exit with code 1 if the file exists without missing-batch info, to prevent corruption / undesired overwrites
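The resume behavior in the last two points boils down to a small persistence pattern, sketched below. This is a minimal, hypothetical illustration (the function name resume_batches is made up; the commit's actual helper is _get_batches_and_mode in the diff):

import math
import pickle
from pathlib import Path


def resume_batches(out_file: Path, failed_batch_file: Path, n_spectra: int, batchsize: int):
    # Return the batch indices that still need predictions and the mode to open the output with.
    if out_file.is_file():
        if failed_batch_file.is_file():
            # a previous run recorded failed batches: append only those
            with open(failed_batch_file, "rb") as fh:
                return pickle.load(fh), "a"
        # an existing output without failure info is ambiguous: refuse to touch it
        raise SystemExit(1)
    # fresh run: write every batch
    return range(math.ceil(n_spectra / batchsize)), "w"

After a failed run, the indices of the missing batches are pickled to failed_batch_file, so the next run opens the library in append mode and reprocesses only those.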
picciama committed Jan 5, 2024
1 parent 7f7bfd6 commit b8d19ba
Showing 1 changed file with 166 additions and 98 deletions: oktoberfest/runner.py
@@ -1,12 +1,13 @@
import datetime
import json
import logging
import pickle
import sys
import time
import traceback
from multiprocessing import Manager, Pool, Process
from functools import partial
from multiprocessing import Manager, Process, pool
from pathlib import Path
from typing import List, Type, Union
from typing import Dict, List, Tuple, Union # Type, Union

import numpy as np
import pandas as pd
@@ -26,12 +27,22 @@
logger = logging.getLogger(__name__)


def _make_predictions(batch_df, int_model, irt_model, server_kwargs, queue_out, progress, lock):
def _make_predictions_error_callback(failure_progress_tracker, failure_lock, error):
logger.error(
f"Prediction failed due to: {error} Batch will be missing from output file. "
"DO NOT STOP THIS RUN: The index of the batch is stored and your output file will be appended "
"by the missing batches if you rerun without changing your config file after this run is completed."
)
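# count the failure in the shared tracker so the progress loop and the failed-batch bookkeeping see it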
with failure_lock:
failure_progress_tracker.value += 1


def _make_predictions(int_model, irt_model, server_kwargs, queue_out, progress, lock, batch_df):
predictions = pr.predict(batch_df, model_name=int_model, disable_progress_bar=True, **server_kwargs)
predictions |= pr.predict(batch_df, model_name=irt_model, disable_progress_bar=True, **server_kwargs)
queue_out.put((predictions, batch_df))
with lock:
progress.value += 1
queue_out.put((predictions, batch_df))


def _preprocess(spectra_files: List[Path], config: Config) -> List[Path]:
@@ -99,7 +110,7 @@ def _annotate_and_get_library(spectra_file: Path, config: Config) -> Spectra:
search = pp.load_search(config.output / "msms" / spectra_file.with_suffix(".rescore").name)
library = pp.merge_spectra_and_peptides(spectra, search)
pp.annotate_spectral_library(library, mass_tol=config.mass_tolerance, unit_mass_tol=config.unit_mass_tolerance)
library.write_as_hdf5(hdf5_path) # write_metadata_annotation
library.write_as_hdf5(hdf5_path).join() # write_metadata_annotation

return library

@@ -224,19 +235,23 @@ def _speclib_from_digestion(config: Config) -> Spectra:

pp_and_filter_step = ProcessStep(config.output, "speclib_filtered")

data_dir = config.output / "data"
if not pp_and_filter_step.is_done():
data_dir.mkdir(exist_ok=True)
spec_library = pp.process_and_filter_spectra_data(
library=spec_library, model=config.models["intensity"], tmt_label=config.tag
)
spec_library.write_as_hdf5(config.output / "data" / f"{library_file.stem}_filtered.hdf5")
spec_library.write_as_hdf5(data_dir / f"{library_file.stem}_filtered.hdf5").join()
pp_and_filter_step.mark_done()
else:
spec_library = Spectra.from_hdf5(config.output / "data" / f"{library_file.stem}_filtered.hdf5")
spec_library = Spectra.from_hdf5(data_dir / f"{library_file.stem}_filtered.hdf5")

return spec_library


def _get_writer_and_output(results_path: Path, output_format: str):
def _get_writer_and_output(results_path: Path, output_format: str) -> Tuple[SpectralLibrary, Path]:

# spectral_library: Type[SpectralLibrary]

if output_format == "msp":
spectral_library = MSP
@@ -253,6 +268,54 @@ def _get_writer_and_output(results_path: Path, output_format: str):
return spectral_library, out_file


def _get_batches_and_mode(out_file: Path, failed_batch_file: Path, no_of_spectra: int, batchsize: int):
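# decide between a fresh write ("w") of every batch and appending ("a") only the batches a previous run left missing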
if out_file.is_file():
if failed_batch_file.is_file():
with open(failed_batch_file, "rb") as fh:
batches = pickle.load(fh)
mode = "a"
logger.warning(
f"Found existing spectral library {out_file}. "
"Attempting to append missing batches from previous run..."
)
else:
logger.error(
f"A file {out_file} already exists but no information about missing batches "
"from a previous run could be found. Stopping to prevent corruption / data loss. "
"If this is intended, delete the file and rerun."
)
sys.exit(1)
else:
batches = range(np.math.ceil(no_of_spectra / batchsize))
mode = "w"

return batches, mode


def _update(pbar: tqdm, postfix_values: Dict[str, int], delay: float = 0.5):
time.sleep(delay)
pbar.set_postfix(**postfix_values)
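# n is the sum of all postfix counters, so the bar still completes when some batches fail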
pbar.n = sum(postfix_values.values())
pbar.refresh()


def _check_write_failed_batch_file(failed_batch_file: Path, n_failed: int, results: List[pool.AsyncResult]) -> bool:
if n_failed > 0:
failed_batches = []
for i, result in enumerate(results):
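# get() re-raises the worker's exception here, identifying batch i as failed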
try:
result.get()
except Exception:
failed_batches.append(i)
logger.error(
f"Prediction for {n_failed} / {i+1} batches failed. Check the log to find out why. "
"Then rerun without changing the config file to append only the missing batches to your output file."
)
with open(failed_batch_file, "wb") as fh:
pickle.dump(failed_batches, fh)
sys.exit(1)


def generate_spectral_lib(config_path: Union[str, Path]):
"""
Create a SpectralLibrary object and generate the spectral library.
@@ -265,105 +328,110 @@ def generate_spectral_lib(config_path: Union[str, Path]):

spec_library = _speclib_from_digestion(config)

no_of_spectra = len(spec_library.spectra_data)
batchsize = config.batch_size
no_of_sections = np.math.ceil(no_of_spectra / batchsize)

server_kwargs = {
"server_url": config.prediction_server,
"ssl": config.ssl,
}

spectral_library: Type[SpectralLibrary]
results_path = config.output / "results"
results_path.mkdir(exist_ok=True)

writer, out_file = _get_writer_and_output(results_path, config.output_format)
speclib_written_step = ProcessStep(config.output, "speclib_written")
if not speclib_written_step.is_done():

if out_file.is_file():
out_file.unlink()
results_path = config.output / "results"
results_path.mkdir(exist_ok=True)

speclib = writer(out_file, mode="w")
batchsize = config.batch_size
failed_batch_file = config.output / "data" / "speclib_failed_batches.pkl"
writer, out_file = _get_writer_and_output(results_path, config.output_format)
batches, mode = _get_batches_and_mode(out_file, failed_batch_file, len(spec_library.spectra_data), batchsize)
speclib = writer(out_file, mode=mode)

spec_library.spectra_data.rename(
columns={
"MODIFIED_SEQUENCE": "peptide_sequences",
"PRECURSOR_CHARGE": "precursor_charges",
"COLLISION_ENERGY": "collision_energies",
"FRAGMENTATION": "fragmentation_types",
},
inplace=True,
)

with Manager() as manager:

# setup
shared_queue = manager.Queue(maxsize=config.num_threads)
prediction_progress = manager.Value("i", 0)
writing_progress = manager.Value("i", 0)

lock = manager.Lock()

# Create a pool for producer processes
pool = Pool(config.num_threads)

try:
for i in range(no_of_sections):
pool.apply_async(
_make_predictions,
(
spec_library.spectra_data.iloc[i * batchsize : (i + 1) * batchsize],
config.models["intensity"],
config.models["irt"],
server_kwargs,
shared_queue,
prediction_progress,
lock,
),
)
spec_library.spectra_data.rename(
columns={
"MODIFIED_SEQUENCE": "peptide_sequences",
"PRECURSOR_CHARGE": "precursor_charges",
"COLLISION_ENERGY": "collision_energies",
"FRAGMENTATION": "fragmentation_types",
},
inplace=True,
)

# Start the consumer process
with tqdm(total=no_of_sections, desc="Getting predictions", disable=False) as predictor_pbar:
with tqdm(total=no_of_sections, desc="Writing library", disable=False) as writer_pbar:
args = shared_queue, writing_progress
consumer_process = Process(target=speclib.async_write, args=args)
n_batches = len(batches)

with Manager() as manager:

# setup
shared_queue = manager.Queue(maxsize=config.num_threads)
prediction_progress = manager.Value("i", 0)
prediction_failure_progress = manager.Value("i", 0)
writing_progress = manager.Value("i", 0)

lock = manager.Lock()
lock_failure = manager.Lock()

# Create a pool for producer processes
predictor_pool = pool.Pool(config.num_threads)

try:
results = []
for i in batches:
result = predictor_pool.apply_async(
_make_predictions,
(
config.models["intensity"],
config.models["irt"],
server_kwargs,
shared_queue,
prediction_progress,
lock,
spec_library.spectra_data.iloc[i * batchsize : (i + 1) * batchsize],
),
error_callback=partial(
_make_predictions_error_callback, prediction_failure_progress, lock_failure
),
)
results.append(result)
predictor_pool.close()

with tqdm(
total=n_batches, desc="Writing library", postfix={"successful": 0, "missing": 0}
) as writer_pbar:
# Start the consumer process
consumer_process = Process(
target=speclib.async_write,
args=(
shared_queue,
writing_progress,
),
)
consumer_process.start()
while prediction_progress.value < no_of_sections:
time.sleep(0.5)
predictor_pbar.n = prediction_progress.value
predictor_pbar.refresh()
writer_pbar.n = writing_progress.value
writer_pbar.refresh()
while writing_progress.value < no_of_sections:
time.sleep(0.5)
writer_pbar.n = writing_progress.value
writer_pbar.refresh()

# Wait for all producer processes to finish
pool.close()
pool.join()

# Signal the consumer process that producers have finished
shared_queue.put(None)

# Wait for the consumer process to finish
consumer_process.join()

except (KeyboardInterrupt, SystemExit):
logger.error("Caught KeyboardInterrupt, terminating workers")
pool.terminate()
pool.join()
sys.exit(1)

except Exception as e:
logger.error("Caught Unknown exception, terminating workers")
logger.error(traceback.format_exc())
logger.error(e)
pool.terminate()
pool.join()
sys.exit(1)

logger.info("Finished writing the library to disk")
with tqdm(
total=n_batches, desc="Getting predictions", postfix={"successful": 0, "failed": 0}
) as predictor_pbar:
while predictor_pbar.n < n_batches:
print(shared_queue.qsize())
pr_fail_val = prediction_failure_progress.value
_update(predictor_pbar, {"failed": pr_fail_val, "successful": prediction_progress.value})
_update(writer_pbar, {"successful": writing_progress.value, "missing": pr_fail_val})
shared_queue.put(None)  # signal the writer process that all batches have been enqueued
predictor_pool.join() # properly await and terminate the pool

while writer_pbar.n < n_batches: # have to keep updating the writer pbar
_update(
writer_pbar,
{"successful": writing_progress.value, "missing": prediction_failure_progress.value},
)
consumer_process.join() # properly await the termination of the writer process

_check_write_failed_batch_file(failed_batch_file, prediction_failure_progress.value, results)

finally:
predictor_pool.terminate()
predictor_pool.join()
consumer_process.terminate()
consumer_process.join()

logger.info("Finished writing the library to disk")
speclib_written_step.mark_done()


def _ce_calib(spectra_file: Path, config: Config) -> Spectra:
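For reference, the orchestration in generate_spectral_lib is a standard multiprocessing producer/consumer setup: a pool of predictor workers feeds a managed queue, a single writer process drains it, and the parent polls shared counters to drive the progress bars. The following is a minimal, self-contained sketch of that pattern using only the standard library; all names are illustrative, not Oktoberfest's API:

import time
from multiprocessing import Manager, Pool, Process


def produce(progress, lock, queue, item):
    queue.put(item * 2)  # stand-in for a prediction call
    with lock:
        progress.value += 1


def consume(queue, progress):
    while queue.get() is not None:  # None is the shutdown signal
        progress.value += 1  # stand-in for writing one batch


if __name__ == "__main__":
    n_batches = 8
    with Manager() as manager:
        queue = manager.Queue(maxsize=4)
        produced = manager.Value("i", 0)
        written = manager.Value("i", 0)
        lock = manager.Lock()
        workers = Pool(4)
        consumer = Process(target=consume, args=(queue, written))
        consumer.start()
        try:
            for i in range(n_batches):
                workers.apply_async(produce, (produced, lock, queue, i))
            workers.close()
            while produced.value < n_batches:  # poll shared counters, as _update does
                time.sleep(0.1)
            workers.join()
            queue.put(None)  # tell the consumer no more items are coming
            consumer.join()
        finally:
            workers.terminate()  # safe even after a clean close()/join()
            workers.join()
            consumer.terminate()
            consumer.join()

Note the sketch assumes every batch succeeds; the commit additionally counts failures through an error_callback so the polling loop terminates even when some batches never reach the queue.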