Skip to content

Commit

Permalink
Added engine.add_component() function. (#2621)
Browse files Browse the repository at this point in the history
Co-authored-by: Zhihong Zhang <[email protected]>
  • Loading branch information
yhwen and nvidianz authored Jun 7, 2024
1 parent 96995c3 commit 265c21c
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 0 deletions.
22 changes: 22 additions & 0 deletions nvflare/apis/client_engine_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,28 @@ def new_context(self) -> FLContext:
# the engine must use FLContextManager to create a new context!
pass

@abstractmethod
def add_component(self, component_id: str, component):
"""Add a component into the system.
Args:
component_id: component ID
component: component object
Returns:
"""
pass

@abstractmethod
def get_component(self, component_id: str) -> object:
"""Retrieve the system component from the engine.
Args:
component_id: component ID
Returns:
component object
"""
pass
22 changes: 22 additions & 0 deletions nvflare/apis/server_engine_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,30 @@ def new_context(self) -> FLContext:
def get_workspace(self) -> Workspace:
pass

@abstractmethod
def add_component(self, component_id: str, component):
"""Add a component into the system.
Args:
component_id: component ID
component: component object
Returns:
"""
pass

@abstractmethod
def get_component(self, component_id: str) -> object:
"""Retrieve the system component from the engine.
Args:
component_id: component ID
Returns:
component object
"""
pass

@abstractmethod
Expand Down
11 changes: 11 additions & 0 deletions nvflare/private/fed/client/client_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ def set_agent(self, admin_agent):
def new_context(self) -> FLContext:
return self.fl_ctx_mgr.new_context()

def add_component(self, component_id: str, component):
if not isinstance(component_id, str):
raise TypeError(f"component id must be str but got {type(component_id)}")

if component_id in self.client.components:
raise ValueError(f"duplicate component id {component_id}")

self.client.components[component_id] = component
if isinstance(component, FLComponent):
self.fl_components.append(component)

def get_component(self, component_id: str) -> object:
return self.client.components.get(component_id)

Expand Down
3 changes: 3 additions & 0 deletions nvflare/private/fed/client/client_run_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ def reset_errors(self) -> ClientRunInfo:
def dispatch(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
return self.aux_runner.dispatch(topic=topic, request=request, fl_ctx=fl_ctx)

def add_component(self, component_id: str, component):
self.client.runner_config.add_component(component_id, component)

def get_component(self, component_id: str) -> object:
return self.components.get(component_id)

Expand Down
3 changes: 3 additions & 0 deletions nvflare/private/fed/server/server_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,9 @@ def new_context(self) -> FLContext:
engine=self, identity_name=self.server.project_name, job_id="", public_stickers={}, private_stickers={}
).new_context()

def add_component(self, component_id: str, component):
self.server.runner_config.add_component(component_id, component)

def get_component(self, component_id: str) -> object:
return self.run_manager.get_component(component_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def new_context(self):
def get_workspace(self):
pass

def add_component(self, component_id: str, component):
pass

def get_component(self, component_id: str) -> object:
pass

Expand Down

0 comments on commit 265c21c

Please sign in to comment.