Skip to content

Commit

Permalink
Moved set_start_method to multiprocessed tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
tjkessler committed Jun 22, 2019
1 parent 370abcd commit ea1e422
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
7 changes: 0 additions & 7 deletions ecnet/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@
resave_df, resave_model, save_config, save_df, save_project, train_model,\
use_model, use_project

# Stdlib imports
from multiprocessing import set_start_method
from os import name


class Server:

Expand All @@ -52,9 +48,6 @@ def __init__(self, model_config: str='config.yml', prj_file: str=None,

self._num_processes = num_processes

if name != 'nt':
set_start_method('spawn', force=True)

if prj_file is not None:
self._prj_name, self._num_pools, self._num_candidates, self._df,\
self._cf_file, self._vars = open_project(prj_file)
Expand Down
6 changes: 5 additions & 1 deletion ecnet/tasks/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
#

# stdlib. imports
from multiprocessing import Pool
from multiprocessing import Pool, set_start_method
from operator import itemgetter
from os import name

# ECNet imports
from ecnet.utils.logging import logger
Expand Down Expand Up @@ -48,6 +49,9 @@ def train_project(prj_name: str, num_pools: int, num_candidates: int,
num_processes (int): number of concurrent processes used to train
'''

if name != 'nt':
set_start_method('spawn', force=True)

logger.log('info', 'Training {}x{} models'.format(
num_pools, num_candidates
), call_loc='TRAIN')
Expand Down
7 changes: 7 additions & 0 deletions ecnet/tasks/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
# Contains functions/fitness functions for tuning hyperparameters
#

# stdlib. imports
from multiprocessing import set_start_method
from os import name

# 3rd party imports
from ecabc.abc import ABC

Expand Down Expand Up @@ -42,6 +46,9 @@ def tune_hyperparameters(df: DataFrame, vars: dict, num_employers: int,
dict: tuned hyperparameters
'''

if name != 'nt':
set_start_method('spawn', force=True)

logger.log('info', 'Tuning architecture/learning hyperparameters',
call_loc='TUNE')
logger.log('debug', 'Arguments:\n\t| num_employers:\t{}\n\t| '
Expand Down

0 comments on commit ea1e422

Please sign in to comment.