Skip to content

Commit

Permalink
Add FlareClientContext
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh committed Jan 17, 2025
1 parent 84afbed commit 2fc6acb
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 57 deletions.
94 changes: 48 additions & 46 deletions examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor

import nvflare.client as flare
from nvflare.client import FlareClientContext, FLModel
from nvflare.client.tracking import SummaryWriter

DATASET_PATH = "/tmp/nvflare/data"
Expand All @@ -43,53 +43,55 @@ def main():
]
)

flare.init()
sys_info = flare.system_info()
client_name = sys_info["site_name"]
with FlareClientContext() as flare:
sys_info = flare.system_info()
client_name = sys_info["site_name"]

train_dataset = CIFAR10(
root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

summary_writer = SummaryWriter()
while flare.is_running():
input_model = flare.receive()
print(f"current_round={input_model.current_round}")

model.load_state_dict(input_model.params)
model.to(device)

steps = epochs * len(train_loader)
for epoch in range(epochs):
running_loss = 0.0
for i, batch in enumerate(train_loader):
images, labels = batch[0].to(device), batch[1].to(device)
optimizer.zero_grad()

predictions = model(images)
cost = loss(predictions, labels)
cost.backward()
optimizer.step()

running_loss += cost.cpu().detach().numpy() / images.size()[0]
if i % 3000 == 0:
print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}")
global_step = input_model.current_round * steps + epoch * len(train_loader) + i
summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step)
running_loss = 0.0

print("Finished Training")

PATH = "./cifar_net.pth"
torch.save(model.state_dict(), PATH)

output_model = flare.FLModel(
params=model.cpu().state_dict(),
meta={"NUM_STEPS_CURRENT_ROUND": steps},
train_dataset = CIFAR10(
root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True
)

flare.send(output_model)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

summary_writer = SummaryWriter()
while flare.is_running():
input_model = flare.receive()
print(f"current_round={input_model.current_round}")

model.load_state_dict(input_model.params)
model.to(device)

steps = epochs * len(train_loader)
for epoch in range(epochs):
running_loss = 0.0
for i, batch in enumerate(train_loader):
images, labels = batch[0].to(device), batch[1].to(device)
optimizer.zero_grad()

predictions = model(images)
cost = loss(predictions, labels)
cost.backward()
optimizer.step()

running_loss += cost.cpu().detach().numpy() / images.size()[0]
if i % 3000 == 0:
print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}")
global_step = input_model.current_round * steps + epoch * len(train_loader) + i
summary_writer.add_scalar(
tag="loss_for_each_batch", scalar=running_loss, global_step=global_step
)
running_loss = 0.0

print("Finished Training")

PATH = "./cifar_net.pth"
torch.save(model.state_dict(), PATH)

output_model = FLModel(
params=model.cpu().state_dict(),
meta={"NUM_STEPS_CURRENT_ROUND": steps},
)

flare.send(output_model)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion nvflare/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from nvflare.app_common.abstract.fl_model import FLModel as FLModel
from nvflare.app_common.abstract.fl_model import ParamsType as ParamsType

from .api import FlareClientContext as FlareClientContext
from .api import get_config as get_config
from .api import get_job_id as get_job_id
from .api import get_site_name as get_site_name
Expand All @@ -33,4 +34,4 @@
from .api import system_info as system_info
from .decorator import evaluate as evaluate
from .decorator import train as train
from .ipc.ipc_agent import IPCAgent
from .ipc.ipc_agent import IPCAgent as IPCAgent
38 changes: 37 additions & 1 deletion nvflare/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from nvflare.apis.analytix import AnalyticsDataType
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.client.constants import CLIENT_API_CONFIG
from nvflare.fuel.data_event.data_bus import DataBus

from .api_spec import CLIENT_API_KEY, CLIENT_API_TYPE_KEY, APISpec
Expand All @@ -30,10 +31,45 @@ class ClientAPIType(Enum):
EX_PROCESS_API = "EX_PROCESS_API"


# Relative path of the client API config file used when no config file is supplied.
DEFAULT_CONFIG = f"config/{CLIENT_API_CONFIG}"

# Module-level client API instance; assigned by init() and by FlareClientContext.__enter__.
client_api: Optional[APISpec] = None
# Data bus from which the in-process client API is looked up under CLIENT_API_KEY.
data_bus = DataBus()


class FlareClientContext:
    """Context manager that initializes the Client API on entry and clears it on exit.

    Typical usage::

        with FlareClientContext() as flare:
            while flare.is_running():
                input_model = flare.receive()
                ...
                flare.send(output_model)

    Args:
        rank: Forwarded unchanged to the client API's ``init(rank=...)``.
        config_file: Config file handed to ``ExProcessClientAPI`` when the
            ex-process API type is selected. ``None`` or an empty string
            falls back to ``DEFAULT_CONFIG``.
    """

    def __init__(self, rank: Optional[str] = None, config_file: Optional[str] = None):
        self.rank = rank
        self.config_file = config_file if config_file else DEFAULT_CONFIG
        # Set in __enter__, released in __exit__.
        self._client_api = None

    def __enter__(self):
        """Initialize the client API in the context.

        The API type is read from the ``CLIENT_API_TYPE_KEY`` environment
        variable and defaults to the in-process API.

        Returns:
            The initialized client API object.

        Raises:
            ValueError: If the environment variable holds a value that is not
                a valid ``ClientAPIType``.
        """
        global client_api

        api_type_name = os.environ.get(CLIENT_API_TYPE_KEY, ClientAPIType.IN_PROCESS_API.value)
        api_type = ClientAPIType(api_type_name)

        if not self._client_api:
            # Publish the API at module level as well, so the module-level
            # helper functions operate on the same instance.
            client_api = self._create_client_api(api_type)
            client_api.init(rank=self.rank)
            self._client_api = client_api

        return self._client_api

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Cleanup the client API when the context ends.

        Exceptions raised inside the ``with`` body are not suppressed.
        """
        global client_api

        api = self._client_api
        if not api:
            return

        try:
            api.clear()
        finally:
            # Drop the module-level reference too (only if it is still ours),
            # otherwise a later init() would see a stale, already-cleared API
            # and silently skip initialization.
            if client_api is api:
                client_api = None
            self._client_api = None

    def _create_client_api(self, api_type: ClientAPIType) -> APISpec:
        """Creates a new client_api based on the provided API type."""
        if api_type == ClientAPIType.IN_PROCESS_API:
            # In-process: the API object is looked up on the data bus.
            return data_bus.get_data(CLIENT_API_KEY)
        return ExProcessClientAPI(config_file=self.config_file)


def init(rank: Optional[str] = None):
"""Initializes NVFlare Client API environment.
Expand All @@ -51,7 +87,7 @@ def init(rank: Optional[str] = None):
if api_type == ClientAPIType.IN_PROCESS_API:
client_api = data_bus.get_data(CLIENT_API_KEY)
else:
client_api = ExProcessClientAPI()
client_api = ExProcessClientAPI(config_file=DEFAULT_CONFIG)
client_api.init(rank=rank)
else:
logging.warning("Warning: called init() more than once. The subsequence calls are ignored")
Expand Down
17 changes: 8 additions & 9 deletions nvflare/client/ex_process/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.client.api_spec import APISpec
from nvflare.client.config import ClientConfig, ConfigKey, ExchangeFormat, from_file
from nvflare.client.constants import CLIENT_API_CONFIG
from nvflare.client.flare_agent import FlareAgentException
from nvflare.client.flare_agent_with_fl_model import FlareAgentWithFLModel
from nvflare.client.model_registry import ModelRegistry
Expand Down Expand Up @@ -65,16 +64,17 @@ def _create_pipe_using_config(client_config: ClientConfig, section: str) -> Tupl


class ExProcessClientAPI(APISpec):
def __init__(self):
self.process_model_registry = None
def __init__(self, config_file: str):
self.model_registry = None
self.logger = get_obj_logger(self)
self.receive_called = False
self.config_file = config_file

def get_model_registry(self) -> ModelRegistry:
"""Gets the ModelRegistry."""
if self.process_model_registry is None:
if self.model_registry is None:
raise RuntimeError("needs to call init method first")
return self.process_model_registry
return self.model_registry

def init(self, rank: Optional[str] = None):
"""Initializes NVFlare Client API environment.
Expand All @@ -87,12 +87,11 @@ def init(self, rank: Optional[str] = None):
if rank is None:
rank = os.environ.get("RANK", "0")

if self.process_model_registry:
if self.model_registry:
self.logger.warning("Warning: called init() more than once. The subsequence calls are ignored")
return

config_file = f"config/{CLIENT_API_CONFIG}"
client_config = _create_client_config(config=config_file)
client_config = _create_client_config(config=self.config_file)

flare_agent = None
try:
Expand Down Expand Up @@ -124,7 +123,7 @@ def init(self, rank: Optional[str] = None):
)
flare_agent.start()

self.process_model_registry = ModelRegistry(client_config, rank, flare_agent)
self.model_registry = ModelRegistry(client_config, rank, flare_agent)
except Exception as e:
self.logger.error(f"flare.init failed: {e}")
raise e
Expand Down

0 comments on commit 2fc6acb

Please sign in to comment.