Skip to content

Commit

Permalink
Merge pull request #36 from nasa/24-allow-for-dynamic-plugins
Browse files Browse the repository at this point in the history
24 allow for dynamic plugins
  • Loading branch information
Evana13G authored Sep 14, 2023
2 parents a5f81a4 + a794f99 commit 0500300
Show file tree
Hide file tree
Showing 40 changed files with 352 additions and 98 deletions.
2 changes: 1 addition & 1 deletion driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pytest
import coverage
# coverage started early to see all lines in all files (def and imports were being missed with programmatic runs)
cov = coverage.Coverage(source=['onair'], branch=True)
cov = coverage.Coverage(source=['onair','plugins'], branch=True)
cov.start()

import os
Expand Down
1 change: 1 addition & 0 deletions onair/config/default_config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ TelemetryFiles = ['700_crash_to_earth_1.csv']
ParserFileName = csv_parser
ParserName = CSV
SimName = CSV
PluginList = {'generic_plugin':'plugins/generic/generic_plugin.py'}

[RUN_FLAGS]
IO_Flag = true
Expand Down
18 changes: 11 additions & 7 deletions onair/src/data_driven_components/data_driven_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,22 @@
"""
Data driven learning class for managing all data driven AI components
"""
import importlib
import importlib.util
import importlib.util

from ..util.data_conversion import *

class DataDrivenLearning:
def __init__(self, headers, _ai_plugins:list=[]):
assert(len(headers)>0)
def __init__(self, headers, _ai_plugins={}):
assert(len(headers)>0), 'Headers are required'
self.headers = headers
self.ai_constructs = [
importlib.import_module('onair.src.data_driven_components.' + plugin_name + '.' + f'{plugin_name}_plugin').Plugin(plugin_name, headers) for plugin_name in _ai_plugins
]

self.ai_constructs = []
for module_name in list(_ai_plugins.keys()):
spec = importlib.util.spec_from_file_location(module_name, _ai_plugins[module_name])
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
self.ai_constructs.append(module.Plugin(module_name,headers))

def update(self, curr_data, status):
input_data = curr_data
output_data = status_to_oneHot(status)
Expand Down
6 changes: 3 additions & 3 deletions onair/src/reasoning/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from ..reasoning.diagnosis import Diagnosis

class Agent:
def __init__(self, vehicle):
def __init__(self, vehicle, plugin_list):
self.vehicle_rep = vehicle
self.learning_systems = DataDrivenLearning(self.vehicle_rep.get_headers())
self.learning_systems = DataDrivenLearning(self.vehicle_rep.get_headers(),plugin_list)
self.mission_status = self.vehicle_rep.get_status()
self.bayesian_status = self.vehicle_rep.get_bayesian_status()

Expand All @@ -29,7 +29,7 @@ def reason(self, frame):

def diagnose(self, time_step):
""" Grab the mnemonics from the """
learning_system_results = self.learning_systems.render_reasoning()
learning_system_results = self.learning_systems.render_reasoning()
diagnosis = Diagnosis(time_step,
learning_system_results,
self.bayesian_status,
Expand Down
30 changes: 23 additions & 7 deletions onair/src/run_scripts/execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import shutil
from distutils.dir_util import copy_tree
from time import gmtime, strftime
import ast

from ...data_handling.time_synchronizer import TimeSynchronizer
from ..run_scripts.sim import Simulator
Expand Down Expand Up @@ -49,6 +50,9 @@ def __init__(self, config_file='', run_name='', save_flag=False):
self.sim_name = ''
self.processedSimData = None
self.sim = None

# Init plugins
self.plugin_list = ['']

self.save_flag = save_flag
self.save_name = run_name
Expand All @@ -60,40 +64,52 @@ def __init__(self, config_file='', run_name='', save_flag=False):
self.setup_sim()

def parse_configs(self, config_filepath):
# print("Using config file: {}".format(config_filepath))

config = configparser.ConfigParser()

if len(config.read(config_filepath)) == 0:
raise FileNotFoundError(f"Config file at '{config_filepath}' could not be read.")

try:
## Sort Required Data: Telementry Data & Configuration
## Parse Required Data: Telementry Data & Configuration
self.dataFilePath = config['DEFAULT']['TelemetryDataFilePath']
self.metadataFilePath = config['DEFAULT']['TelemetryMetadataFilePath']
self.metaFiles = config['DEFAULT']['MetaFiles'] # Config for vehicle telemetry
self.telemetryFiles = config['DEFAULT']['TelemetryFiles'] # Vehicle telemetry data

## Sort Required Data: Names
## Parse Required Data: Names
self.parser_file_name = config['DEFAULT']['ParserFileName']
self.parser_name = config['DEFAULT']['ParserName']
self.sim_name = config['DEFAULT']['SimName']

## Parse Required Data: Plugin name to path dict
config_plugin_list = config['DEFAULT']['PluginList']
ast_plugin_list = ast.parse(config_plugin_list, mode='eval')
if isinstance(ast_plugin_list.body, ast.Dict) and len(ast_plugin_list.body.keys) > 0:
temp_plugin_list = ast.literal_eval(config_plugin_list)
else:
raise ValueError(f"{config_plugin_list} is an invalid PluginList. It must be a dict of at least 1 key/value pair.")
for plugin_name in temp_plugin_list.values():
if not(os.path.exists(plugin_name)):
raise FileNotFoundError(f"In config file '{config_filepath}', path '{plugin_name}' does not exist or is formatted incorrectly.")
self.plugin_list = temp_plugin_list
except KeyError as e:
new_message = f"Config file: '{config_filepath}', missing key: {e.args[0]}"
raise KeyError(new_message) from e

## Sort Optional Data: Flags
## Parse Optional Data: Flags
self.IO_Flag = config['RUN_FLAGS'].getboolean('IO_Flag')
self.Dev_Flag = config['RUN_FLAGS'].getboolean('Dev_Flag')
self.SBN_Flag = config['RUN_FLAGS'].getboolean('SBN_Flag')
self.Viz_Flag = config['RUN_FLAGS'].getboolean('Viz_Flag')

## Sort Optional Data: Benchmarks
## Parse Optional Data: Benchmarks
try:
self.benchmarkFilePath = config['DEFAULT']['BenchmarkFilePath']
self.benchmarkFiles = config['DEFAULT']['BenchmarkFiles'] # Vehicle telemetry data
self.benchmarkIndices = config['DEFAULT']['BenchmarkIndices']
except:
pass


def parse_data(self, parser_name, parser_file_name, dataFilePath, metadataFilePath, subsystems_breakdown=False):
parser = importlib.import_module('onair.data_handling.parsers.' + parser_file_name)
Expand All @@ -104,7 +120,7 @@ def parse_data(self, parser_name, parser_file_name, dataFilePath, metadataFilePa
self.processedSimData = TimeSynchronizer(*parsed_data.get_sim_data())

def setup_sim(self):
self.sim = Simulator(self.sim_name, self.processedSimData, self.SBN_Flag)
self.sim = Simulator(self.sim_name, self.processedSimData, self.plugin_list, self.SBN_Flag)
try:
fls = ast.literal_eval(self.benchmarkFiles)
fp = os.path.dirname(os.path.realpath(__file__)) + '/../..' + self.benchmarkFilePath
Expand Down
4 changes: 2 additions & 2 deletions onair/src/run_scripts/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
DIAGNOSIS_INTERVAL = 100

class Simulator:
def __init__(self, simType, parsedData, SBN_Flag):
def __init__(self, simType, parsedData, plugin_list, SBN_Flag):
self.simulator = simType
vehicle = VehicleRepresentation(*parsedData.get_vehicle_metadata())

Expand All @@ -39,7 +39,7 @@ def __init__(self, simType, parsedData, SBN_Flag):

else:
self.simData = DataSource(parsedData.get_sim_data())
self.agent = Agent(vehicle)
self.agent = Agent(vehicle, plugin_list)

def run_sim(self, IO_Flag=False, dev_flag=False, viz_flag = True):
if IO_Flag == True: print_sim_header()
Expand Down
1 change: 1 addition & 0 deletions plugins/generic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .generic_plugin import Plugin
30 changes: 30 additions & 0 deletions plugins/generic/generic_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# GSC-19165-1, "The On-Board Artificial Intelligence Research (OnAIR) Platform"
#
# Copyright © 2023 United States Government as represented by the Administrator of
# the National Aeronautics and Space Administration. No copyright is claimed in the
# United States under Title 17, U.S. Code. All Other Rights Reserved.
#
# Licensed under the NASA Open Source Agreement version 1.3
# See "NOSA GSC-19165-1 OnAIR.pdf"

import numpy as np
from onair.src.data_driven_components.ai_plugin_abstract.core import AIPlugIn

class Plugin(AIPlugIn):
def apriori_training(self,batch_data=[]):
"""
Given data, system should learn any priors necessary for realtime diagnosis.
"""
pass

def update(self,frame=[]):
"""
Given streamed data point, system should update internally
"""
pass

def render_reasoning(self):
"""
System should return its diagnosis
"""
pass
1 change: 1 addition & 0 deletions plugins/kalman_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .kalman_plugin import Plugin
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import simdkalman
import numpy as np
from ...data_driven_components.generic_component import AIPlugIn
from onair.src.data_driven_components.ai_plugin_abstract.core import AIPlugIn

class Plugin(AIPlugIn):
def __init__(self, name, headers, window_size=3):
Expand Down
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import pytest
from mock import MagicMock

import onair.src.data_driven_components.generic_component.core as core
from onair.src.data_driven_components.generic_component.core import AIPlugIn
import onair.src.data_driven_components.ai_plugin_abstract.core as core
from onair.src.data_driven_components.ai_plugin_abstract.core import AIPlugIn

class FakeAIPlugIn(AIPlugIn):
def __init__(self, _name, _headers):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from onair.src.data_driven_components.data_driven_learning import DataDrivenLearning

import importlib
from typing import Dict

# __init__ tests
def test_DataDrivenLearning__init__sets_instance_headers_to_given_headers_and_does_nothing_else_when_given__ai_plugins_is_empty(mocker):
# Arrange
arg_headers = []
arg__ai_plugins = []
arg__ai_plugins = {}

num_fake_headers = pytest.gen.randint(1, 10) # arbitrary, from 1 to 10 headers (0 has own test)
for i in range(num_fake_headers):
Expand All @@ -34,64 +35,67 @@ def test_DataDrivenLearning__init__sets_instance_headers_to_given_headers_and_do
# Assert
assert cut.headers == arg_headers

def test_DataDrivenLearning__init__sets_instance_ai_constructs_to_a_list_of_the_calls_AIPlugIn_with_plugin_and_given_headers_for_each_item_in_given__ai_plugins(mocker):
def test_DataDrivenLearning__init__throws_AttributeError_when_given_module_file_has_no_attribute_Plugin(mocker):
# Arrange
fake_module_name = MagicMock()
arg_headers = []
arg__ai_plugins = []

num_fake_headers = pytest.gen.randint(1, 10) # arbitrary, from 1 to 10 headers (0 has own test)
for i in range(num_fake_headers):
arg_headers.append(MagicMock())
fake_imported_module = MagicMock()
num_fake_ai_plugins = pytest.gen.randint(1, 10) # arbitrary, from 1 to 10 (0 has own test)
for i in range(num_fake_ai_plugins):
arg__ai_plugins.append(str(MagicMock()))

mocker.patch('importlib.import_module', return_value=fake_imported_module)
arg__ai_plugins = {MagicMock()}

cut = DataDrivenLearning.__new__(DataDrivenLearning)


# Act
cut.__init__(arg_headers, arg__ai_plugins)

# Assert
assert importlib.import_module.call_count == num_fake_ai_plugins
for i in range(num_fake_ai_plugins):
assert importlib.import_module.call_args_list[i].args == ('onair.src.data_driven_components.' + arg__ai_plugins[i] + '.' + arg__ai_plugins[i] + '_plugin',)

def test_DataDrivenLearning__init__sets_instance_ai_constructs_to_a_list_of_the_calls_AIPlugIn_with_plugin_and_given_headers_for_each_item_in_given__ai_plugins_when_given__ai_plugins_is_occupied(mocker):
# Arrange
arg_headers = []
arg__ai_plugins = []
arg__ai_plugins = {}
fake_spec_list = []
fake_module_list = []

num_fake_headers = pytest.gen.randint(1, 10) # arbitrary, from 1 to 10 headers (0 has own test)
for i in range(num_fake_headers):
arg_headers.append(MagicMock())
fake_imported_module = MagicMock()
num_fake_ai_plugins = pytest.gen.randint(1, 10) # arbitrary, from 1 to 10 (0 has own test)
for i in range(num_fake_ai_plugins):
arg__ai_plugins.append(str(MagicMock()))
arg__ai_plugins[str(MagicMock())] = str(MagicMock())
fake_spec_list.append(MagicMock())
fake_module_list.append(MagicMock())


expected_ai_constructs = []
for i in range(num_fake_ai_plugins):
expected_ai_constructs.append(MagicMock())

mocker.patch('importlib.import_module', return_value=fake_imported_module)
mocker.patch.object(fake_imported_module, 'Plugin', side_effect=expected_ai_constructs)
# mocker.patch('importlib.import_module', return_value=fake_imported_module)
mocker.patch('importlib.util.spec_from_file_location',side_effect=fake_spec_list)
mocker.patch('importlib.util.module_from_spec',side_effect=fake_module_list)
for spec in fake_spec_list:
mocker.patch.object(spec,'loader.exec_module')
for i, module in enumerate(fake_module_list):
mocker.patch.object(module,'Plugin',return_value=expected_ai_constructs[i])

cut = DataDrivenLearning.__new__(DataDrivenLearning)

# Act
cut.__init__(arg_headers, arg__ai_plugins)

# Assert
assert importlib.import_module.call_count == num_fake_ai_plugins
for i in range(num_fake_ai_plugins):
assert importlib.import_module.call_args_list[i].args == (f'onair.src.data_driven_components.{arg__ai_plugins[i]}.{arg__ai_plugins[i]}_plugin',)
assert fake_imported_module.Plugin.call_count == num_fake_ai_plugins
assert importlib.util.spec_from_file_location.call_count == len(arg__ai_plugins)
assert importlib.util.module_from_spec.call_count == len(fake_spec_list)

for i in range(num_fake_ai_plugins):
assert fake_imported_module.Plugin.call_args_list[i].args == (arg__ai_plugins[i], arg_headers)

fake_name = list(arg__ai_plugins.keys())[i]
fake_path = arg__ai_plugins[fake_name]
assert importlib.util.spec_from_file_location.call_args_list[i].args == (fake_name,fake_path)
assert importlib.util.module_from_spec.call_args_list[i].args == (fake_spec_list[i],)
assert fake_spec_list[i].loader.exec_module.call_count == 1
assert fake_spec_list[i].loader.exec_module.call_args_list[0].args == (fake_module_list[i],)
assert fake_module_list[i].Plugin.call_count == 1
assert fake_module_list[i].Plugin.call_args_list[0].args == (fake_name,arg_headers)

assert cut.ai_constructs == expected_ai_constructs

# update tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_Agent__init__sets_vehicle_rep_to_given_vehicle_and_learning_systems_and
fake_learning_systems = MagicMock()
fake_mission_status = MagicMock()
fake_bayesian_status = MagicMock()
fake_plugin_list = MagicMock()

mocker.patch.object(arg_vehicle, 'get_headers', return_value=fake_headers)
mocker.patch(agent.__name__ + '.DataDrivenLearning', return_value=fake_learning_systems)
Expand All @@ -32,14 +33,14 @@ def test_Agent__init__sets_vehicle_rep_to_given_vehicle_and_learning_systems_and
cut = Agent.__new__(Agent)

# Act
result = cut.__init__(arg_vehicle)
result = cut.__init__(arg_vehicle, fake_plugin_list)

# Assert
assert cut.vehicle_rep == arg_vehicle
assert arg_vehicle.get_headers.call_count == 1
assert arg_vehicle.get_headers.call_args_list[0].args == ()
assert agent.DataDrivenLearning.call_count == 1
assert agent.DataDrivenLearning.call_args_list[0].args == (fake_headers, )
assert agent.DataDrivenLearning.call_args_list[0].args == (fake_headers, fake_plugin_list)
assert cut.learning_systems == fake_learning_systems
assert arg_vehicle.get_status.call_count == 1
assert arg_vehicle.get_status.call_args_list[0].args == ()
Expand Down
File renamed without changes.
Loading

0 comments on commit 0500300

Please sign in to comment.