Skip to content

Commit

Permalink
Move writer/reader formats into base class (#1833)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattwthompson authored Mar 21, 2024
1 parent 5744778 commit b739698
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 75 deletions.
3 changes: 0 additions & 3 deletions openff/toolkit/utils/ambertools_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ class AmberToolsToolkitWrapper(base_wrapper.ToolkitWrapper):
def __init__(self):
super().__init__()

self._toolkit_file_read_formats = []
self._toolkit_file_write_formats = []

if not self.is_available():
raise ToolkitUnavailableException(
f"The required toolkit {self._toolkit_name} is not "
Expand Down
2 changes: 2 additions & 0 deletions openff/toolkit/utils/base_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class ToolkitWrapper:
_toolkit_installation_instructions: Optional[str] = (
None # Installation instructions for the toolkit
)
_toolkit_file_read_formats: list[str] = list()
_toolkit_file_write_formats: list[str] = list()

# @staticmethod
# TODO: Right now, to access the class definition, I have to make this a classmethod
Expand Down
3 changes: 0 additions & 3 deletions openff/toolkit/utils/builtin_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ class BuiltInToolkitWrapper(base_wrapper.ToolkitWrapper):
def __init__(self):
super().__init__()

self._toolkit_file_read_formats = []
self._toolkit_file_write_formats = []

def assign_partial_charges(
self,
molecule,
Expand Down
3 changes: 0 additions & 3 deletions openff/toolkit/utils/nagl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ class NAGLToolkitWrapper(ToolkitWrapper):
def __init__(self):
super().__init__()

self._toolkit_file_read_formats = []
self._toolkit_file_write_formats = []

if not self.is_available():
raise ToolkitUnavailableException(
f"The required toolkit {self._toolkit_name} is not "
Expand Down
97 changes: 48 additions & 49 deletions openff/toolkit/utils/openeye_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,57 +101,56 @@ class OpenEyeToolkitWrapper(base_wrapper.ToolkitWrapper):
"oeiupac": "OEIUPACIsLicensed",
"oeomega": "OEOmegaIsLicensed",
}
_toolkit_file_read_formats = [
"CAN",
"CDX",
"CSV",
"FASTA",
"INCHI",
"INCHIKEY",
"ISM",
"MDL",
"MF",
"MMOD",
"MOL2",
"MOL2H",
"MOPAC",
"OEB",
"PDB",
"RDF",
"SDF",
"SKC",
"SLN",
"SMI",
"USM",
"XYC",
]
_toolkit_file_write_formats = [
"CAN",
"CDX",
"CSV",
"FASTA",
"INCHI",
"INCHIKEY",
"ISM",
"MDL",
"MF",
"MMOD",
"MOL2",
"MOL2H",
"MOPAC",
"OEB",
"PDB",
"RDF",
"SDF",
"SKC",
"SLN",
"SMI",
"USM",
"XYC",
]

def __init__(self):
self._toolkit_file_read_formats = [
"CAN",
"CDX",
"CSV",
"FASTA",
"INCHI",
"INCHIKEY",
"ISM",
"MDL",
"MF",
"MMOD",
"MOL2",
"MOL2H",
"MOPAC",
"OEB",
"PDB",
"RDF",
"SDF",
"SKC",
"SLN",
"SMI",
"USM",
"XYC",
]
self._toolkit_file_write_formats = [
"CAN",
"CDX",
"CSV",
"FASTA",
"INCHI",
"INCHIKEY",
"ISM",
"MDL",
"MF",
"MMOD",
"MOL2",
"MOL2H",
"MOPAC",
"OEB",
"PDB",
"RDF",
"SDF",
"SKC",
"SLN",
"SMI",
"USM",
"XYC",
]

# check if the toolkit can be loaded
if not self.is_available():
if self._is_installed is False:
Expand Down
38 changes: 21 additions & 17 deletions openff/toolkit/utils/rdkit_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import tempfile
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional

import numpy as np
from cachetools import LRUCache, cached
Expand Down Expand Up @@ -78,12 +78,19 @@ class RDKitToolkitWrapper(base_wrapper.ToolkitWrapper):
"A conda-installable version of the free and open source RDKit cheminformatics "
"toolkit can be found at: https://anaconda.org/conda-forge/rdkit"
)
# TODO: Add TDT support
_toolkit_file_read_formats = ["SDF", "MOL", "SMI"]
_toolkit_file_write_formats = [
"SDF",
"MOL",
"SMI",
"PDB",
"TDT",
]

def __init__(self):
super().__init__()

self._toolkit_file_read_formats = ["SDF", "MOL", "SMI"] # TODO: Add TDT support

if not self.is_available():
raise ToolkitUnavailableException(
f"The required toolkit {self._toolkit_name} is not "
Expand All @@ -94,24 +101,12 @@ def __init__(self):

self._toolkit_version = rdkit_version

from rdkit import Chem

# we have to make sure the toolkit can be loaded before formatting this dict
# Note any new file write formats should be added here only
self._toolkit_file_write_formats = {
"SDF": Chem.SDWriter,
"MOL": Chem.SDWriter,
"SMI": None, # Special support to use to_smiles() instead of RDKit's SmilesWriter
"PDB": Chem.PDBWriter,
"TDT": Chem.TDTWriter,
}

@property
def toolkit_file_write_formats(self) -> list[str]:
"""
List of file formats that this toolkit can write.
"""
return list(self._toolkit_file_write_formats.keys())
return self._toolkit_file_write_formats

@classmethod
def is_available(cls) -> bool:
Expand Down Expand Up @@ -1225,6 +1220,15 @@ def to_file_obj(self, molecule: "Molecule", file_obj, file_format: str):
-------
"""
from rdkit import Chem

_TOOLKIT_WRITERS: dict[str, Any] = {
"SDF": Chem.SDWriter,
"MOL": Chem.SDWriter,
"PDB": Chem.PDBWriter,
"TDT": Chem.TDTWriter,
}

file_format = normalize_file_format(file_format)
_require_text_file_obj(file_obj)

Expand All @@ -1239,7 +1243,7 @@ def to_file_obj(self, molecule: "Molecule", file_obj, file_format: str):
file_obj.write(output_line)
else:
try:
writer_func = self._toolkit_file_write_formats[file_format]
writer_func = _TOOLKIT_WRITERS[file_format]
except KeyError:
raise ValueError(f"Unsupported file format: {file_format})") from None
rdmol = self.to_rdkit(molecule)
Expand Down

0 comments on commit b739698

Please sign in to comment.