Skip to content

Commit

Permalink
UPDATE logger
Browse files Browse the repository at this point in the history
  • Loading branch information
matbun committed Jun 20, 2024
1 parent 6c66371 commit f45b13f
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 21 deletions.
36 changes: 28 additions & 8 deletions src/itwinai/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ class MLFlowLogger(Logger):
#: Supported kinds in the ``log`` method
supported_kinds: Tuple[str] = (
'metric', 'figure', 'image', 'artifact', 'torch', 'dict', 'param',
'text')
'text', 'model', 'dataset')

#: Current MLFLow experiment's run.
active_run: mlflow.ActiveRun
Expand All @@ -376,7 +376,7 @@ def __init__(
# TODO: for pytorch lightning:
# mlflow.pytorch.autolog()

def create_logger_context(self):
def create_logger_context(self) -> mlflow.ActiveRun:
"""Initialize logger. Start MLFLow run."""
active_run = mlflow.active_run()
if active_run:
Expand All @@ -388,6 +388,7 @@ def create_logger_context(self):
self.active_run: mlflow.ActiveRun = mlflow.start_run(
description=self.run_description
)
return self.active_run

def destroy_logger_context(self):
"""Destroy logger. End current MLFlow run."""
Expand Down Expand Up @@ -446,6 +447,23 @@ def log(
local_path=item,
artifact_path=identifier
)
if kind == 'model':
import torch
if isinstance(item, torch.nn.Module):
mlflow.pytorch.log_model(item, identifier)
else:
print("WARNING: unrecognized model type")
if kind == 'dataset':
# Log mlflow dataset
# https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.log_input
# It may be needed to convert item into a mlflow dataset, e.g.:
# https://mlflow.org/docs/latest/python_api/mlflow.data.html#mlflow.data.from_pandas
# ATM delegated to the user
if isinstance(item, mlflow.data.Dataset):
mlflow.log_input(item)
else:
print("WARNING: unrecognized dataset type. "
"Must be an MLFlow dataset")
if kind == 'torch':
import torch
# Save the object locally and then log it
Expand Down Expand Up @@ -788,8 +806,8 @@ class Prov4MLLogger(Logger):
#: Supported kinds in the ``log`` method
supported_kinds: Tuple[str] = (
'metric', 'flops_pb', 'flops_pe', 'system', 'carbon',
'execution_time', 'model_version', 'model_version_final',
'param')
'execution_time', 'model', 'best_model',
'torch')

def __init__(
self,
Expand Down Expand Up @@ -888,10 +906,12 @@ def log(
prov4ml.log_model(item, identifier, log_model_info=True,
log_as_artifact=True)
elif kind == 'torch': # LoggingItemKind.PARAMETER.value:
prov4ml.log_param(identifier, item)
elif kind == 'dataset':
# Log torch DataLoader
prov4ml.log_dataset(item, identifier)
from torch.utils.data import DataLoader
if isinstance(item, DataLoader):
prov4ml.log_dataset(item, identifier)
else:
# log_param name is misleading and should be renamed...
prov4ml.log_param(identifier, item)


class EpochTimeTracker:
Expand Down
15 changes: 10 additions & 5 deletions use-cases/3dgan/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import gdown

from itwinai.components import DataGetter, monitor_exec
from itwinai.loggers import Logger as BaseItwinaiLogger


class Lightning3DGANDownloader(DataGetter):
Expand Down Expand Up @@ -158,18 +159,22 @@ def GetDataAngleParallel(

class ParticlesDataModule(pl.LightningDataModule):
def __init__(
self,
datapath: str,
batch_size: int,
num_workers: int = 4,
max_samples: Optional[int] = None
self,
datapath: str,
batch_size: int,
num_workers: int = 4,
max_samples: Optional[int] = None
) -> None:
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.datapath = datapath
self.max_samples = max_samples

@property
def itwinai_logger(self) -> BaseItwinaiLogger:
return self.trainer.itwinai_logger

def setup(self, stage: str = None):
# make assignments here (val/train/test split)
# called on every process in DDP
Expand Down
10 changes: 9 additions & 1 deletion use-cases/3dgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,14 @@ def __init__(
def itwinai_logger(self) -> BaseItwinaiLogger:
return self.trainer.itwinai_logger

def setup(self, stage: str):
if self.itwinai_logger:
self.itwinai_logger.create_logger_context()

def teardown(self, stage: str):
if self.itwinai_logger:
self.itwinai_logger.destroy_logger_context()

def BitFlip(self, x, prob=0.05):
"""
Flips a single bit according to a certain probability.
Expand Down Expand Up @@ -604,7 +612,7 @@ def training_step(self, batch, batch_idx):

# return avg_disc_loss + avg_generator_loss

def on_train_epoch_end(self): # outputs
def on_train_epoch_end(self):

self._log_provenance(context='training')

Expand Down
17 changes: 10 additions & 7 deletions use-cases/3dgan/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,6 @@ class GANTrainer(LightningTrainer):
def __init__(self, itwinai_logger: Optional[Logger] = None, **kwargs):
super().__init__(**kwargs)
self.itwinai_logger = itwinai_logger
if self.itwinai_logger:
self.itwinai_logger.create_logger_context()

def fit(self, *args, **kwargs):
super().fit(*args, **kwargs)
if self.itwinai_logger:
self.itwinai_logger.destroy_logger_context()


class Lightning3DGANTrainer(Trainer):
Expand Down Expand Up @@ -73,6 +66,16 @@ def execute(self) -> Any:
)
sys.argv = old_argv
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
cli.trainer.itwinai_logger.log(
cli.trainer.train_dataloader,
"train_dataloader",
kind='torch'
)
cli.trainer.itwinai_logger.log(
cli.trainer.val_dataloaders,
"val_dataloader",
kind='torch'
)
teardown_lightning_mlflow()


Expand Down

0 comments on commit f45b13f

Please sign in to comment.