diff --git a/ecnet/server.py b/ecnet/server.py index f07c537..87741b7 100644 --- a/ecnet/server.py +++ b/ecnet/server.py @@ -24,10 +24,6 @@ resave_df, resave_model, save_config, save_df, save_project, train_model,\ use_model, use_project -# Stdlib imports -from multiprocessing import set_start_method -from os import name - class Server: @@ -52,9 +48,6 @@ def __init__(self, model_config: str='config.yml', prj_file: str=None, self._num_processes = num_processes - if name != 'nt': - set_start_method('spawn', force=True) - if prj_file is not None: self._prj_name, self._num_pools, self._num_candidates, self._df,\ self._cf_file, self._vars = open_project(prj_file) diff --git a/ecnet/tasks/training.py b/ecnet/tasks/training.py index fc3a891..89e6ca8 100644 --- a/ecnet/tasks/training.py +++ b/ecnet/tasks/training.py @@ -13,8 +13,9 @@ # # stdlib. imports -from multiprocessing import Pool +from multiprocessing import Pool, set_start_method from operator import itemgetter +from os import name # ECNet imports from ecnet.utils.logging import logger @@ -48,6 +49,9 @@ def train_project(prj_name: str, num_pools: int, num_candidates: int, num_processes (int): number of concurrent processes used to train ''' + if name != 'nt': + set_start_method('spawn', force=True) + logger.log('info', 'Training {}x{} models'.format( num_pools, num_candidates ), call_loc='TRAIN') diff --git a/ecnet/tasks/tuning.py b/ecnet/tasks/tuning.py index e3bc266..5f3af00 100644 --- a/ecnet/tasks/tuning.py +++ b/ecnet/tasks/tuning.py @@ -8,6 +8,10 @@ # Contains functions/fitness functions for tuning hyperparameters # +# stdlib. imports +from multiprocessing import set_start_method +from os import name + # 3rd party imports from ecabc.abc import ABC @@ -42,6 +46,9 @@ def tune_hyperparameters(df: DataFrame, vars: dict, num_employers: int, dict: tuned hyperparameters ''' + if name != 'nt': + set_start_method('spawn', force=True) + logger.log('info', 'Tuning architecture/learning hyperparameters', call_loc='TUNE') logger.log('debug', 'Arguments:\n\t| num_employers:\t{}\n\t| '