Skip to content

Commit

Permalink
Merge pull request #17 from ecrl/20250111_improvements
Browse files Browse the repository at this point in the history
Various improvements
  • Loading branch information
tjkessler authored Jan 12, 2025
2 parents ae30ce4 + 10b7c89 commit 34120c3
Show file tree
Hide file tree
Showing 20 changed files with 1,793 additions and 1,343 deletions.
111 changes: 111 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Prerequisites
*.d

# Compiled object files
*.slo
*.lo
*.o
*.obj

# Precompiled headers
*.gch
*.pch

# Compiled dynamic libraries
*.so
*.so.[0-9]*
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
.installed.cfg
MANIFEST
*.egg-info/
*.egg
*.manifest
*.spec
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.pytest_cache/

# Documentation
doc/html/
doc/latex/
doc/man/
doc/xml/
doc/_build/
doc/source
doc/modules

# Environments
.env
.venv
env/
venv/
ENV/

# Editor junk
tags
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-v][a-z]
[._]sw[a-p]
*~
\#*\#
.\#*
.ropeproject
.idea/
.spyderproject
.spyproject
.vscode/
# Mac .DS_Store
.DS_Store

# jupyter notebook checkpoints
.ipynb_checkpoints

# version file
graphchem/_version.py
287 changes: 59 additions & 228 deletions examples/comparison/train_cn.ipynb

Large diffs are not rendered by default.

195 changes: 97 additions & 98 deletions examples/comparison/train_mon.ipynb

Large diffs are not rendered by default.

195 changes: 97 additions & 98 deletions examples/comparison/train_ron.ipynb

Large diffs are not rendered by default.

241 changes: 113 additions & 128 deletions examples/predict_cn.ipynb

Large diffs are not rendered by default.

232 changes: 111 additions & 121 deletions examples/predict_lhv.ipynb

Large diffs are not rendered by default.

230 changes: 110 additions & 120 deletions examples/predict_mon.ipynb

Large diffs are not rendered by default.

232 changes: 111 additions & 121 deletions examples/predict_ron.ipynb

Large diffs are not rendered by default.

232 changes: 111 additions & 121 deletions examples/predict_ysi.ipynb

Large diffs are not rendered by default.

136 changes: 91 additions & 45 deletions graphchem/data/structs.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,117 @@
r"""Molecule graph structure, graph dataset"""
from typing import Iterable, Optional

from typing import List
import torch
import torch_geometric


class MoleculeGraph(torch_geometric.data.Data):

def __init__(self, atom_attr: 'torch.tensor', bond_attr: 'torch.tensor',
connectivity: 'torch.tensor', target: 'torch.tensor' = None):
""" MoleculeGraph object, extends torch_geometric.data.Data object; a
singular molecule graph/data point
Args:
atom_attr (torch.tensor): atom features, shape (n_atoms,
n_atom_features); dtype assumed torch.float32
bond_attr (torch.tensor): bond features, shape (n_bonds,
n_bond_features); dtype assumed torch.float32
connectivity (torch.tensor): COO graph connectivity index, size
(2, n_bonds); dtype assumed torch.long
target (torch.tensor, default=None): target value(s), shape
(1, n_targets); if not supplied (None), set to [0.0]; dtype
assumed torch.float32
from torch_geometric.data import Data, Dataset


class MoleculeGraph(Data):
"""
A custom graph class representing a molecular structure.
This class extends the `Data` class from PyTorch Geometric to represent
molecules with node attributes (atoms), edge attributes (bonds), and
connectivity information. It also includes an optional target value.
Attributes
----------
x : torch.Tensor
The node features (atom attributes).
edge_index : torch.Tensor
A 2D tensor describing the connectivity between atoms.
edge_attr : torch.Tensor
Edge features (bond attributes).
y : torch.Tensor
Target value(s) of the molecule.
"""

def __init__(self, atom_attr: torch.Tensor,
bond_attr: torch.Tensor,
connectivity: torch.Tensor,
target: Optional[torch.Tensor] = None):
"""
Initialize the MoleculeGraph object.
Parameters
----------
atom_attr : torch.Tensor
A 2D tensor of shape (num_atoms, num_atom_features) representing
the attributes of each atom in the molecule.
bond_attr : torch.Tensor
A 2D tensor of shape (num_bonds, num_bond_features) representing
the attributes of each bond in the molecule.
connectivity : torch.Tensor
A 2D tensor of shape (2, num_bonds) where each column represents an
edge (bond) between two atoms. The first row contains the source
atom indices and the second row contains the target atom indices.
target : Optional[torch.Tensor]
An optional 1D or 2D tensor representing the target value(s) of the
molecule. If not provided, it defaults to a tensor with a single
element set to 0.0.
"""

if target is None:
# Set default target to a tensor with shape (1, 1) and value 0.0
target = torch.tensor([0.0]).type(torch.float32).reshape(1, 1)
elif len(target.shape) == 1:
# Reshape target if it's a 1D tensor to (1, target.shape[0])
target = target.reshape(1, -1)
if target.shape[0] != 1:
raise ValueError("Target tensor must have shape (1, num_targets)")

super(MoleculeGraph, self).__init__(
super().__init__(
x=atom_attr,
edge_index=connectivity,
edge_attr=bond_attr,
y=torch.tensor(target).type(torch.float).reshape(1, len(target))
y=target
)


class MoleculeDataset(torch_geometric.data.Dataset):
class MoleculeDataset(Dataset):
"""
A custom dataset class for molecular graphs.
This class extends the `Dataset` class from PyTorch Geometric to create a
dataset of molecular graphs. Each graph in the dataset is an instance of
`MoleculeGraph`.
def __init__(self, graphs: List['MoleculeGraph']):
""" MoleculeDataset object, extends torch_geometric.data.Dataset
object; a torch_geometric-iterable dataset comprised of MoleculeGraph
objects
Attributes
----------
_graphs : List[MoleculeGraph]
A list containing all the molecule graphs in the dataset.
"""

Args:
graphs (List[MoleculeGraph]): list of molecule graphs
def __init__(self, graphs: Iterable[MoleculeGraph]):
"""
Initialize the MoleculeDataset object.
super(MoleculeDataset, self).__init__()
self._graphs = graphs
Parameters
----------
graphs : Iterable[MoleculeGraph]
An iterable of `MoleculeGraph` instances representing the
molecules in the dataset.
"""

def len(self) -> int:
""" torch_geometric.data.Dataset.len definition (required)
super().__init__()
self._graphs = list(graphs)

Returns:
int: number of molecule graphs
def len(self) -> int:
"""
Returns the number of molecules in the dataset.
Returns
-------
int
The number of molecule graphs in the dataset.
"""
return len(self._graphs)

def get(self, idx: int) -> 'MoleculeGraph':
""" torch_geometric.data.Dataset.get definition (required)
Args:
idx (int): index of item
Returns:
MoleculeGraph: indexed item
def get(self, idx: int) -> MoleculeGraph:
"""
Retrieves a molecule graph from the dataset by index.
Returns
-------
MoleculeGraph
The molecule graph at the specified index.
"""
return self._graphs[idx]
Loading

0 comments on commit 34120c3

Please sign in to comment.