-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #17 from ecrl/20250111_improvements
Various improvements
- Loading branch information
Showing
20 changed files
with
1,793 additions
and
1,343 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Prerequisites | ||
*.d | ||
|
||
# Compiled object files | ||
*.slo | ||
*.lo | ||
*.o | ||
*.obj | ||
|
||
# Precompiled headers | ||
*.gch | ||
*.pch | ||
|
||
# Compiled dynamic libraries | ||
*.so | ||
*.so.[0-9]* | ||
*.dylib | ||
*.dll | ||
|
||
# Fortran module files | ||
*.mod | ||
*.smod | ||
|
||
# Compiled static libraries | ||
*.lai | ||
*.la | ||
*.a | ||
*.lib | ||
|
||
# Executables | ||
*.exe | ||
*.out | ||
*.app | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
.installed.cfg | ||
MANIFEST | ||
*.egg-info/ | ||
*.egg | ||
*.manifest | ||
*.spec | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*,cover | ||
.pytest_cache/ | ||
|
||
# Documentation | ||
doc/html/ | ||
doc/latex/ | ||
doc/man/ | ||
doc/xml/ | ||
doc/_build/ | ||
doc/source | ||
doc/modules | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
|
||
# Editor junk | ||
tags | ||
[._]*.s[a-v][a-z] | ||
[._]*.sw[a-p] | ||
[._]s[a-v][a-z] | ||
[._]sw[a-p] | ||
*~ | ||
\#*\# | ||
.\#* | ||
.ropeproject | ||
.idea/ | ||
.spyderproject | ||
.spyproject | ||
.vscode/ | ||
# Mac .DS_Store | ||
.DS_Store | ||
|
||
# jupyter notebook checkpoints | ||
.ipynb_checkpoints | ||
|
||
# version file | ||
graphchem/_version.py |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,117 @@ | ||
r"""Molecule graph structure, graph dataset""" | ||
from typing import Iterable, Optional | ||
|
||
from typing import List | ||
import torch | ||
import torch_geometric | ||
|
||
|
||
class MoleculeGraph(torch_geometric.data.Data): | ||
|
||
def __init__(self, atom_attr: 'torch.tensor', bond_attr: 'torch.tensor', | ||
connectivity: 'torch.tensor', target: 'torch.tensor' = None): | ||
""" MoleculeGraph object, extends torch_geometric.data.Data object; a | ||
singular molecule graph/data point | ||
Args: | ||
atom_attr (torch.tensor): atom features, shape (n_atoms, | ||
n_atom_features); dtype assumed torch.float32 | ||
bond_attr (torch.tensor): bond features, shape (n_bonds, | ||
n_bond_features); dtype assumed torch.float32 | ||
connectivity (torch.tensor): COO graph connectivity index, size | ||
(2, n_bonds); dtype assumed torch.long | ||
target (torch.tensor, default=None): target value(s), shape | ||
(1, n_targets); if not supplied (None), set to [0.0]; dtype | ||
assumed torch.float32 | ||
from torch_geometric.data import Data, Dataset | ||
|
||
|
||
class MoleculeGraph(Data): | ||
""" | ||
A custom graph class representing a molecular structure. | ||
This class extends the `Data` class from PyTorch Geometric to represent | ||
molecules with node attributes (atoms), edge attributes (bonds), and | ||
connectivity information. It also includes an optional target value. | ||
Attributes | ||
---------- | ||
x : torch.Tensor | ||
The node features (atom attributes). | ||
edge_index : torch.Tensor | ||
A 2D tensor describing the connectivity between atoms. | ||
edge_attr : torch.Tensor | ||
Edge features (bond attributes). | ||
y : torch.Tensor | ||
Target value(s) of the molecule. | ||
""" | ||
|
||
def __init__(self, atom_attr: torch.Tensor, | ||
bond_attr: torch.Tensor, | ||
connectivity: torch.Tensor, | ||
target: Optional[torch.Tensor] = None): | ||
""" | ||
Initialize the MoleculeGraph object. | ||
Parameters | ||
---------- | ||
atom_attr : torch.Tensor | ||
A 2D tensor of shape (num_atoms, num_atom_features) representing | ||
the attributes of each atom in the molecule. | ||
bond_attr : torch.Tensor | ||
A 2D tensor of shape (num_bonds, num_bond_features) representing | ||
the attributes of each bond in the molecule. | ||
connectivity : torch.Tensor | ||
A 2D tensor of shape (2, num_bonds) where each column represents an | ||
edge (bond) between two atoms. The first row contains the source | ||
atom indices and the second row contains the target atom indices. | ||
target : Optional[torch.Tensor] | ||
An optional 1D or 2D tensor representing the target value(s) of the | ||
molecule. If not provided, it defaults to a tensor with a single | ||
element set to 0.0. | ||
""" | ||
|
||
if target is None: | ||
# Set default target to a tensor with shape (1, 1) and value 0.0 | ||
target = torch.tensor([0.0]).type(torch.float32).reshape(1, 1) | ||
elif len(target.shape) == 1: | ||
# Reshape target if it's a 1D tensor to (1, target.shape[0]) | ||
target = target.reshape(1, -1) | ||
if target.shape[0] != 1: | ||
raise ValueError("Target tensor must have shape (1, num_targets)") | ||
|
||
super(MoleculeGraph, self).__init__( | ||
super().__init__( | ||
x=atom_attr, | ||
edge_index=connectivity, | ||
edge_attr=bond_attr, | ||
y=torch.tensor(target).type(torch.float).reshape(1, len(target)) | ||
y=target | ||
) | ||
|
||
|
||
class MoleculeDataset(torch_geometric.data.Dataset): | ||
class MoleculeDataset(Dataset): | ||
""" | ||
A custom dataset class for molecular graphs. | ||
This class extends the `Dataset` class from PyTorch Geometric to create a | ||
dataset of molecular graphs. Each graph in the dataset is an instance of | ||
`MoleculeGraph`. | ||
def __init__(self, graphs: List['MoleculeGraph']): | ||
""" MoleculeDataset object, extends torch_geometric.data.Dataset | ||
object; a torch_geometric-iterable dataset comprised of MoleculeGraph | ||
objects | ||
Attributes | ||
---------- | ||
_graphs : List[MoleculeGraph] | ||
A list containing all the molecule graphs in the dataset. | ||
""" | ||
|
||
Args: | ||
graphs (List[MoleculeGraph]): list of molecule graphs | ||
def __init__(self, graphs: Iterable[MoleculeGraph]): | ||
""" | ||
Initialize the MoleculeDataset object. | ||
super(MoleculeDataset, self).__init__() | ||
self._graphs = graphs | ||
Parameters | ||
---------- | ||
graphs : Iterable[MoleculeGraph] | ||
An iterable of `MoleculeGraph` instances representing the | ||
molecules in the dataset. | ||
""" | ||
|
||
def len(self) -> int: | ||
""" torch_geometric.data.Dataset.len definition (required) | ||
super().__init__() | ||
self._graphs = list(graphs) | ||
|
||
Returns: | ||
int: number of molecule graphs | ||
def len(self) -> int: | ||
""" | ||
Returns the number of molecules in the dataset. | ||
Returns | ||
------- | ||
int | ||
The number of molecule graphs in the dataset. | ||
""" | ||
return len(self._graphs) | ||
|
||
def get(self, idx: int) -> 'MoleculeGraph': | ||
""" torch_geometric.data.Dataset.get definition (required) | ||
Args: | ||
idx (int): index of item | ||
Returns: | ||
MoleculeGraph: indexed item | ||
def get(self, idx: int) -> MoleculeGraph: | ||
""" | ||
Retrieves a molecule graph from the dataset by index. | ||
Returns | ||
------- | ||
MoleculeGraph | ||
The molecule graph at the specified index. | ||
""" | ||
return self._graphs[idx] |
Oops, something went wrong.