Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various improvements #17

Merged
merged 7 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Prerequisites
*.d

# Compiled object files
*.slo
*.lo
*.o
*.obj

# Precompiled headers
*.gch
*.pch

# Compiled dynamic libraries
*.so
*.so.[0-9]*
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
.installed.cfg
MANIFEST
*.egg-info/
*.egg
*.manifest
*.spec
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.pytest_cache/

# Documentation
doc/html/
doc/latex/
doc/man/
doc/xml/
doc/_build/
doc/source
doc/modules

# Environments
.env
.venv
env/
venv/
ENV/

# Editor junk
tags
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-v][a-z]
[._]sw[a-p]
*~
\#*\#
.\#*
.ropeproject
.idea/
.spyderproject
.spyproject
.vscode/
# Mac .DS_Store
.DS_Store

# jupyter notebook checkpoints
.ipynb_checkpoints

# version file
graphchem/_version.py
287 changes: 59 additions & 228 deletions examples/comparison/train_cn.ipynb

Large diffs are not rendered by default.

195 changes: 97 additions & 98 deletions examples/comparison/train_mon.ipynb

Large diffs are not rendered by default.

195 changes: 97 additions & 98 deletions examples/comparison/train_ron.ipynb

Large diffs are not rendered by default.

241 changes: 113 additions & 128 deletions examples/predict_cn.ipynb

Large diffs are not rendered by default.

232 changes: 111 additions & 121 deletions examples/predict_lhv.ipynb

Large diffs are not rendered by default.

230 changes: 110 additions & 120 deletions examples/predict_mon.ipynb

Large diffs are not rendered by default.

232 changes: 111 additions & 121 deletions examples/predict_ron.ipynb

Large diffs are not rendered by default.

232 changes: 111 additions & 121 deletions examples/predict_ysi.ipynb

Large diffs are not rendered by default.

136 changes: 91 additions & 45 deletions graphchem/data/structs.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,117 @@
r"""Molecule graph structure, graph dataset"""
from typing import Iterable, Optional

from typing import List
import torch
import torch_geometric


class MoleculeGraph(torch_geometric.data.Data):

def __init__(self, atom_attr: 'torch.tensor', bond_attr: 'torch.tensor',
connectivity: 'torch.tensor', target: 'torch.tensor' = None):
""" MoleculeGraph object, extends torch_geometric.data.Data object; a
singular molecule graph/data point

Args:
atom_attr (torch.tensor): atom features, shape (n_atoms,
n_atom_features); dtype assumed torch.float32
bond_attr (torch.tensor): bond features, shape (n_bonds,
n_bond_features); dtype assumed torch.float32
connectivity (torch.tensor): COO graph connectivity index, size
(2, n_bonds); dtype assumed torch.long
target (torch.tensor, default=None): target value(s), shape
(1, n_targets); if not supplied (None), set to [0.0]; dtype
assumed torch.float32
from torch_geometric.data import Data, Dataset


class MoleculeGraph(Data):
"""
A custom graph class representing a molecular structure.

This class extends the `Data` class from PyTorch Geometric to represent
molecules with node attributes (atoms), edge attributes (bonds), and
connectivity information. It also includes an optional target value.

Attributes
----------
x : torch.Tensor
The node features (atom attributes).
edge_index : torch.Tensor
A 2D tensor describing the connectivity between atoms.
edge_attr : torch.Tensor
Edge features (bond attributes).
y : torch.Tensor
Target value(s) of the molecule.
"""

def __init__(self, atom_attr: torch.Tensor,
bond_attr: torch.Tensor,
connectivity: torch.Tensor,
target: Optional[torch.Tensor] = None):
"""
Initialize the MoleculeGraph object.

Parameters
----------
atom_attr : torch.Tensor
A 2D tensor of shape (num_atoms, num_atom_features) representing
the attributes of each atom in the molecule.
bond_attr : torch.Tensor
A 2D tensor of shape (num_bonds, num_bond_features) representing
the attributes of each bond in the molecule.
connectivity : torch.Tensor
A 2D tensor of shape (2, num_bonds) where each column represents an
edge (bond) between two atoms. The first row contains the source
atom indices and the second row contains the target atom indices.
target : Optional[torch.Tensor]
An optional 1D or 2D tensor representing the target value(s) of the
molecule. If not provided, it defaults to a tensor with a single
element set to 0.0.
"""

if target is None:
# Set default target to a tensor with shape (1, 1) and value 0.0
target = torch.tensor([0.0]).type(torch.float32).reshape(1, 1)
elif len(target.shape) == 1:
# Reshape target if it's a 1D tensor to (1, target.shape[0])
target = target.reshape(1, -1)
if target.shape[0] != 1:
raise ValueError("Target tensor must have shape (1, num_targets)")

super(MoleculeGraph, self).__init__(
super().__init__(
x=atom_attr,
edge_index=connectivity,
edge_attr=bond_attr,
y=torch.tensor(target).type(torch.float).reshape(1, len(target))
y=target
)


class MoleculeDataset(torch_geometric.data.Dataset):
class MoleculeDataset(Dataset):
"""
A custom dataset class for molecular graphs.

This class extends the `Dataset` class from PyTorch Geometric to create a
dataset of molecular graphs. Each graph in the dataset is an instance of
`MoleculeGraph`.

def __init__(self, graphs: List['MoleculeGraph']):
""" MoleculeDataset object, extends torch_geometric.data.Dataset
object; a torch_geometric-iterable dataset comprised of MoleculeGraph
objects
Attributes
----------
_graphs : List[MoleculeGraph]
A list containing all the molecule graphs in the dataset.
"""

Args:
graphs (List[MoleculeGraph]): list of molecule graphs
def __init__(self, graphs: Iterable[MoleculeGraph]):
"""
Initialize the MoleculeDataset object.

super(MoleculeDataset, self).__init__()
self._graphs = graphs
Parameters
----------
graphs : Iterable[MoleculeGraph]
An iterable of `MoleculeGraph` instances representing the
molecules in the dataset.
"""

def len(self) -> int:
""" torch_geometric.data.Dataset.len definition (required)
super().__init__()
self._graphs = list(graphs)

Returns:
int: number of molecule graphs
def len(self) -> int:
"""
Returns the number of molecules in the dataset.

Returns
-------
int
The number of molecule graphs in the dataset.
"""
return len(self._graphs)

def get(self, idx: int) -> 'MoleculeGraph':
""" torch_geometric.data.Dataset.get definition (required)

Args:
idx (int): index of item

Returns:
MoleculeGraph: indexed item
def get(self, idx: int) -> MoleculeGraph:
"""
Retrieves a molecule graph from the dataset by index.

Returns
-------
MoleculeGraph
The molecule graph at the specified index.
"""
return self._graphs[idx]
Loading
Loading