Skip to content

Commit

Permalink
fix paths
Browse files Browse the repository at this point in the history
  • Loading branch information
Lily Wang committed Aug 8, 2023
1 parent 152a3a8 commit bb22dcd
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 66 deletions.
3 changes: 1 addition & 2 deletions devtools/conda-envs/docs_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ dependencies:
- openeye-toolkits

# database
- sqlalchemy
- sqlite
- pyarrow

# gcn
- dglteam::dgl >=0.7
Expand Down
2 changes: 1 addition & 1 deletion openff/nagl/config/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import Field

from openff.nagl.training._loss import (
from openff.nagl.training.loss import (
TargetType,

)
Expand Down
2 changes: 1 addition & 1 deletion openff/nagl/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class BaseLayer(ImmutableModel):
description="The activation function to apply for each layer"
)
dropout: float = Field(
default=None,
default=0.0,
description="The dropout to apply after each layer"
)

Expand Down
2 changes: 1 addition & 1 deletion openff/nagl/tests/training/test_loss.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch
import numpy as np
from openff.nagl.training._loss import (
from openff.nagl.training.loss import (
MultipleDipoleTarget,
SingleDipoleTarget,
HeavyAtomReadoutTarget,
Expand Down
2 changes: 1 addition & 1 deletion openff/nagl/tests/training/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from openff.nagl.training._metrics import (
from openff.nagl.training.metrics import (
RMSEMetric,
MSEMetric,
MAEMetric,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from openff.nagl._base.metaregistry import create_registry_metaclass
from openff.nagl.molecule._dgl import DGLMoleculeOrBatch
from openff.nagl.training._metrics import MetricType #MetricMeta, BaseMetric
from openff.nagl.training.metrics import MetricType #MetricMeta, BaseMetric
from openff.nagl._base.base import ImmutableModel
from openff.nagl.nn._pooling import PoolingLayer
from openff.nagl.nn._containers import ReadoutModule
Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions openff/nagl/training/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

if typing.TYPE_CHECKING:
from openff.nagl.molecule._dgl import DGLMolecule
from openff.nagl.training._metrics import MetricType
from openff.nagl.training.metrics import MetricType


def _encode_image(image):
Expand Down Expand Up @@ -128,7 +128,7 @@ def _generate_jinja_dicts_per_atom(
highlight_outliers: bool = False,
outlier_threshold: float = 1.0,
) -> typing.List[typing.Dict[str, str]]:
from openff.nagl.training._metrics import get_metric_type
from openff.nagl.training.metrics import get_metric_type

metrics = [get_metric_type(metric) for metric in metrics]
jinja_dicts = []
Expand Down Expand Up @@ -217,7 +217,7 @@ def create_atom_label_report(
highlight_outliers: bool = False,
outlier_threshold: float = 1.0,
):
from openff.nagl.training._metrics import get_metric_type
from openff.nagl.training.metrics import get_metric_type

ranker = get_metric_type(rank_by)
metrics = [get_metric_type(metric) for metric in metrics]
Expand Down
113 changes: 57 additions & 56 deletions openff/nagl/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,40 @@ def dataloader():

setattr(self, f"{stage}_dataloader", dataloader)

def _get_dgl_molecule_dataset(
    self,
    config,
    cache_dir,
    columns,
):
    """Load every source in ``config`` and return them as one concatenated dataset.

    Parameters
    ----------
    config
        Dataset configuration; ``config.lazy_loading`` selects the lazy
        (disk-cached) loader, and ``config.sources`` lists the parquet paths.
    cache_dir
        Directory handed to the lazy loader for its on-disk cache.
    columns
        Target columns to read from each parquet source.

    Returns
    -------
    torch.utils.data.ConcatDataset
        All per-source datasets concatenated in source order.
    """
    # Keyword arguments shared by both the lazy and eager factories.
    shared_kwargs = dict(
        format="parquet",
        atom_features=self.config.model.atom_features,
        bond_features=self.config.model.bond_features,
        columns=columns,
        n_processes=self.n_processes,
    )
    if config.lazy_loading:
        # Lazy path additionally caches featurized molecules on disk.
        factory = functools.partial(
            _LazyDGLMoleculeDataset.from_arrow_dataset,
            cache_directory=cache_dir,
            use_cached_data=config.use_cached_data,
            **shared_kwargs,
        )
    else:
        factory = functools.partial(
            DGLMoleculeDataset.from_arrow_dataset,
            **shared_kwargs,
        )

    per_source = [factory(path) for path in config.sources]
    return torch.utils.data.ConcatDataset(per_source)

def prepare_data(self):
for stage, config in self._dataset_configs.items():
if config is None or not config.sources:
Expand Down Expand Up @@ -188,34 +222,13 @@ def prepare_data(self):
logger.info(f"Loading cached data from {pickle_hash}")
continue

if config.lazy_loading:
loader = functools.partial(
_LazyDGLMoleculeDataset.from_arrow_dataset,
format="parquet",
atom_features=self.config.model.atom_features,
bond_features=self.config.model.bond_features,
columns=columns,
cache_directory=cache_dir,
use_cached_data=config.use_cached_data,
n_processes=self.n_processes,
)
else:
loader = functools.partial(
DGLMoleculeDataset.from_arrow_dataset,
format="parquet",
atom_features=self.config.model.atom_features,
bond_features=self.config.model.bond_features,
columns=columns,
n_processes=self.n_processes,
)

datasets = []
for path in config.sources:
ds = loader(path)
datasets.append(ds)

if not config.lazy_loading:
dataset = torch.utils.data.ConcatDataset(datasets)
dataset = self._get_dgl_molecule_dataset(
config=config,
cache_dir=cache_dir,
columns=columns,
)

if not config.lazy_loading and config.use_cached_data:
with open(pickle_hash, "wb") as f:
pickle.dump(dataset, f)
logger.info(f"Saved data to {pickle_hash}")
Expand All @@ -226,39 +239,27 @@ def _setup_stage(self, config, stage: str):

cache_dir = config.cache_directory if config.cache_directory else "."
columns = config.get_required_target_columns()
pickle_hash = self._get_hash_file(
paths=config.sources,
columns=columns,
cache_directory=cache_dir,
extension=".pkl"
)
if pickle_hash.exists():
with open(pickle_hash, "rb") as f:
ds = pickle.load(f)
return ds
if not config.lazy_loading:
raise FileNotFoundError(
f"Data not found for stage {stage}: {pickle_hash}"
if config.use_cached_data or config.lazy_loading:
pickle_hash = self._get_hash_file(
paths=config.sources,
columns=columns,
cache_directory=cache_dir,
extension=".pkl"
)
loader = functools.partial(
_LazyDGLMoleculeDataset.from_arrow_dataset,
format="parquet",
atom_features=self.config.model.atom_features,
bond_features=self.config.model.bond_features,

if pickle_hash.exists():
with open(pickle_hash, "rb") as f:
ds = pickle.load(f)
return ds

dataset = self._get_dgl_molecule_dataset(
config=config,
cache_dir=cache_dir,
columns=columns,
cache_directory=cache_dir,
use_cached_data=config.use_cached_data,
n_processes=self.n_processes,
)
datasets = []
for path in config.sources:
ds = loader(path)
datasets.append(ds)
dataset = torch.utils.data.ConcatDataset(datasets)
return dataset




def setup(self, **kwargs):
for stage, config in self._dataset_configs.items():
dataset = self._setup_stage(config, stage)
Expand Down

0 comments on commit bb22dcd

Please sign in to comment.