-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23 from ECRL/dev
Bug fixes, GA improvements, data sorting options, optimizations
- Loading branch information
Showing
18 changed files
with
649 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
File renamed without changes
Binary file added
BIN
+697 KB
docs/tutorials/Getting Started/Getting Started with Predicting Fuel Properties.pdf
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
25 changes: 25 additions & 0 deletions
25
docs/tutorials/Getting Started/scripts/limit_input_descriptors.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from ecnet import Server | ||
|
||
|
||
def main(): | ||
|
||
sv = Server(log_level='debug', num_processes=4) | ||
sv.import_data( | ||
'../kv_model_v1.0_full.csv', | ||
sort_type='random', | ||
data_split=[0.6, 0.2, 0.2] | ||
) | ||
sv.limit_input_parameters( | ||
limit_num=15, | ||
output_filename='../kv_model_v1.0.csv', | ||
use_genetic=True, | ||
population_size=30, | ||
num_generations=10, | ||
mut_rate=0.2, | ||
max_mut_amt=1 | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
main() |
29 changes: 29 additions & 0 deletions
29
docs/tutorials/Getting Started/scripts/train_select_candidates.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from ecnet import Server | ||
|
||
|
||
def main(): | ||
|
||
sv = Server(log_level='debug', num_processes=4) | ||
sv.import_data( | ||
'../kv_model_v1.0.csv', | ||
sort_type='random', | ||
data_split=[0.6, 0.2, 0.2] | ||
) | ||
sv.create_project( | ||
'kinetic_viscosity', | ||
num_builds=1, | ||
num_nodes=5, | ||
num_candidates=25 | ||
) | ||
sv.train_model( | ||
validate=True, | ||
shuffle='train', | ||
data_split=[0.6, 0.2, 0.2] | ||
) | ||
sv.select_best(dset='test') | ||
sv.save_project(clean_up=True) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
main() |
20 changes: 20 additions & 0 deletions
20
docs/tutorials/Getting Started/scripts/tune_hyperparameters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from ecnet import Server | ||
|
||
|
||
def main(): | ||
|
||
sv = Server(log_level='debug', num_processes=4) | ||
sv.import_data( | ||
'../kv_model_v1.0.csv', | ||
sort_type='random', | ||
data_split=[0.6, 0.2, 0.2] | ||
) | ||
sv.tune_hyperparameters( | ||
num_employers=30, | ||
num_iterations=10 | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from ecnet import Server | ||
|
||
|
||
def main(): | ||
|
||
sv = Server( | ||
project_file='kinetic_viscosity.prj', | ||
log_level='debug' | ||
) | ||
sv.use_model( | ||
dset='test', | ||
output_filename='../kv_test_results.csv' | ||
) | ||
sv.calc_error( | ||
'rmse', | ||
'mean_abs_error', | ||
'med_abs_error', | ||
'r2', | ||
dset='test' | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/data_utils.py | ||
# v.2.1.0 | ||
# v.2.1.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the "DataFrame" class, and functions for processing/importing/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,47 +2,30 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/error_utils.py | ||
# v.2.1.0 | ||
# v.2.1.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions for error calculations | ||
# | ||
|
||
# 3rd party imports | ||
from numpy import absolute, asarray, median, sqrt as nsqrt, sum as nsum | ||
from numpy import absolute, asarray, float64, isinf, isnan, median,\ | ||
nan_to_num, square, sqrt as nsqrt, sum as nsum | ||
|
||
|
||
def calc_rmse(y_hat, y): | ||
|
||
try: | ||
return(nsqrt(((y_hat-y)**2).mean())) | ||
except: | ||
try: | ||
return(nsqrt(((asarray(y_hat)-asarray(y))**2).mean())) | ||
except: | ||
raise ValueError('Check input data format.') | ||
return nsqrt((square(_get_diff(y_hat, y)).mean())) | ||
|
||
|
||
def calc_mean_abs_error(y_hat, y): | ||
|
||
try: | ||
return(abs(y_hat-y).mean()) | ||
except: | ||
try: | ||
return(abs(asarray(y_hat)-asarray(y)).mean()) | ||
except: | ||
raise ValueError('Check input data format.') | ||
return absolute(_get_diff(y_hat, y)).mean() | ||
|
||
|
||
def calc_med_abs_error(y_hat, y): | ||
|
||
try: | ||
return(median(absolute(y_hat-y))) | ||
except: | ||
try: | ||
return(median(absolute(asarray(y_hat)-asarray(y)))) | ||
except: | ||
raise ValueError('Check input data format.') | ||
return median(absolute(_get_diff(y_hat, y))) | ||
|
||
|
||
def calc_r2(y_hat, y): | ||
|
@@ -57,14 +40,12 @@ def calc_r2(y_hat, y): | |
y_mean = sum(y_form)/len(y_form) | ||
except: | ||
raise ValueError('Check input data format.') | ||
try: | ||
s_res = nsum((y_hat-y)**2) | ||
s_tot = nsum((y-y_mean)**2) | ||
return(1 - (s_res/s_tot)) | ||
except: | ||
try: | ||
s_res = nsum((asarray(y_hat)-asarray(y))**2) | ||
s_tot = nsum((asarray(y)-y_mean)**2) | ||
return(1 - (s_res/s_tot)) | ||
except: | ||
raise ValueError('Check input data format.') | ||
|
||
s_res = nsum(square(_get_diff(y_hat, y))) | ||
s_tot = nsum(square(_get_diff(y, y_mean))) | ||
return(1 - (s_res / s_tot)) | ||
|
||
|
||
def _get_diff(y_hat, y): | ||
|
||
return asarray(y_hat, dtype=float64) - asarray(y, dtype=float64) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/fitness_functions.py | ||
# v.2.1.0 | ||
# v.2.1.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains fitness functions used for input dimensionality reduction, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/limit_parameters.py | ||
# v.2.1.0 | ||
# v.2.1.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains the functions necessary for reducing the input dimensionality of a | ||
|
@@ -158,7 +158,7 @@ def limit_iterative_include(DataFrame, limit_num, vars, logger=None): | |
|
||
def limit_genetic(DataFrame, limit_num, vars, population_size, num_generations, | ||
num_processes, shuffle=False, data_split=[0.65, 0.25, 0.1], | ||
logger=None): | ||
mut_rate=0, max_mut_amt=0, logger=None): | ||
'''Limits the dimensionality of input data using a genetic algorithm | ||
Args: | ||
|
@@ -171,6 +171,10 @@ def limit_genetic(DataFrame, limit_num, vars, population_size, num_generations, | |
shuffle (bool): whether to shuffle the data sets for each population | ||
member | ||
data_split (list): [learn%, valid%, test%] if shuffle == True | ||
mut_rate (float): probability that a population member is subject | ||
to mutation | ||
max_mut_amt (float): if mutating, how much a parameter can mutate | ||
(proportionally) | ||
logger (ColorLogger): ColorLogger object; if not supplied, does not log | ||
Returns: | ||
|
@@ -206,38 +210,47 @@ def limit_genetic(DataFrame, limit_num, vars, population_size, num_generations, | |
|
||
population.generate_population() | ||
if logger is not None: | ||
logger.log('debug', 'Generation: 0 - Population fitness: {}'.format( | ||
sum(p.fitness_score for p in population.members) / len(population), | ||
logger.log( | ||
'debug', | ||
'Generation: 0 - Population fitness: {:.5f}'.format( | ||
float(population.med_cost_fn_val) | ||
), call_loc='LIMIT' | ||
) | ||
logger.log('debug', '\tBest fitness: {}'.format( | ||
population.best_cost_fn_val | ||
), call_loc='LIMIT') | ||
logger.log('debug', '\tBest parameters: {}'.format( | ||
[DataFrame.input_names[val] for val in | ||
population.best_parameters.values()] | ||
), call_loc='LIMIT') | ||
|
||
for gen in range(num_generations): | ||
population.next_generation() | ||
population.next_generation(mut_rate, max_mut_amt) | ||
if logger is not None: | ||
logger.log( | ||
'debug', | ||
'Generation: {} - Population fitness: {}'.format( | ||
'Generation: {} - Population fitness: {:.5f}'.format( | ||
gen + 1, | ||
sum( | ||
p.fitness_score for p in population.members | ||
) / len(population) | ||
float(population.med_cost_fn_val) | ||
), | ||
call_loc='LIMIT' | ||
) | ||
logger.log('debug', '\tBest fitness: {}'.format( | ||
population.best_cost_fn_val | ||
), call_loc='LIMIT') | ||
logger.log('debug', '\tBest parameters: {}'.format( | ||
[DataFrame.input_names[val] for val in | ||
population.best_parameters.values()] | ||
), call_loc='LIMIT') | ||
|
||
min_idx = 0 | ||
for new_idx, member in enumerate(population.members): | ||
if member.fitness_score < population.members[min_idx].fitness_score: | ||
min_idx = new_idx | ||
|
||
input_list = [] | ||
for val in population.members[min_idx].parameters.values(): | ||
input_list.append(DataFrame.input_names[val]) | ||
input_list = [DataFrame.input_names[val] for val in | ||
population.best_parameters.values()] | ||
|
||
if logger is not None: | ||
logger.log( | ||
'debug', | ||
'Best member fitness score: {}'.format( | ||
population.members[min_idx].fitness_score | ||
population.best_cost_fn_val | ||
), | ||
call_loc='LIMIT' | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# | ||
# ecnet/error_utils.py | ||
# v.2.1.0 | ||
# v.2.1.1 | ||
# Developed in 2019 by Travis Kessler <[email protected]> | ||
# | ||
# Contains functions necessary creating, training, saving, and reusing neural | ||
|
@@ -17,14 +17,19 @@ | |
from random import uniform | ||
from multiprocessing import current_process | ||
from copy import deepcopy | ||
from warnings import filterwarnings | ||
|
||
# 3rd party imports | ||
from tensorflow import add, global_variables_initializer, matmul, nn | ||
from tensorflow import placeholder, random_normal, reset_default_graph | ||
from tensorflow import Session, square, train, Variable | ||
from numpy import asarray, sqrt as nsqrt | ||
from numpy import asarray, isinf, isnan, nan_to_num, seterr, sqrt as nsqrt | ||
|
||
# ECNet imports | ||
from ecnet.error_utils import calc_rmse | ||
|
||
environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | ||
# filterwarnings('ignore', category=RuntimeWarning) | ||
|
||
|
||
def __linear_fn(n): | ||
|
@@ -214,9 +219,10 @@ def fit_validation(self, x_l, y_l, x_v, y_v, learning_rate=0.1, | |
sess.run(optimizer, feed_dict={x: x_l, y: y_l}) | ||
current_epoch += 1 | ||
if current_epoch % 250 == 0: | ||
valid_rmse = self.__calc_rmse( | ||
sess.run(pred, feed_dict={x: x_v}), y_v | ||
) | ||
valid_preds = sess.run(pred, feed_dict={x: x_v}) | ||
if isnan(valid_preds).any() or isinf(valid_preds).any(): | ||
valid_preds = nan_to_num(valid_preds) | ||
valid_rmse = calc_rmse(valid_preds, y_v) | ||
if valid_rmse < valid_rmse_lowest: | ||
valid_rmse_lowest = valid_rmse | ||
elif valid_rmse > valid_rmse_lowest + ( | ||
|
@@ -243,6 +249,8 @@ def use(self, x): | |
saver.restore(sess, self._filename) | ||
results = self.__feed_forward(x).eval() | ||
sess.close() | ||
if isnan(results).any() or isinf(results).any(): | ||
results = nan_to_num(results) | ||
return results | ||
|
||
def save(self, filepath=None): | ||
|
@@ -328,9 +336,13 @@ def __calc_rmse(self, y_hat, y): | |
''' | ||
|
||
try: | ||
return(nsqrt(((y_hat - y)**2).mean())) | ||
diff = (y_hat - y) | ||
except: | ||
return(nsqrt(((asarray(y_hat) - asarray(y))**2).mean())) | ||
diff = (asarray(y_hat) - asarray(y)) | ||
for i, d in enumerate(diff): | ||
if isnan(d): | ||
diff[i] = nan_to_num(d) | ||
return(nsqrt((diff**2).mean())) | ||
|
||
|
||
def train_model(validate, sets, vars, save_path=None, id=None): | ||
|
Oops, something went wrong.