Skip to content

Commit

Permalink
Update data.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Augus1999 authored Jan 20, 2025
1 parent 44b9c0c commit 6048478
Showing 1 changed file with 43 additions and 57 deletions.
100 changes: 43 additions & 57 deletions bayesianflow_for_chem/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,78 +174,64 @@ def collate(batch: List) -> Dict[str, Tensor]:
return out_dict


class BaseCSVDataClass(Dataset):
def __init__(
self,
file: str,
limit: Optional[int] = None,
label_idx: Optional[List[int]] = None,
) -> None:
class CSVData(Dataset):
def __init__(self, file: Union[str, Path]):
"""
Define dataset stored in CSV file.\n
This is the base class that should not be accessed directly.
Define dataset stored in CSV file.
:param file: dataset file name <file>
:param limit: item limit
:param label_idx: a list of indices indicating which value to be input;
use `'None'` for inputting all values
:type file: str | pathlib.Path
"""
super().__init__()
self.data = []
with open(file, "r") as db:
self.data = db.readlines()
self.smiles_idx, self.selfies_idx, self.aa_idx = [], [], []
self.geo2seq_idx: Optional[int] = None
self.value_idx = []
self.header_idx_dict: Dict[str, List[int]] = {}
for key, i in enumerate(self.data[0].replace("\n", "").split(",")):
i = i.lower()
if i == "smiles":
self.smiles_idx.append(key)
if i == "safe":
self.smiles_idx = [key]
if i == "selfies":
self.selfies_idx.append(key)
if i == "seq":
self.aa_idx.append(key)
if i == "geo2seq":
self.geo2seq_idx = key
if i == "value":
self.value_idx.append(key)
if self.value_idx and label_idx:
self.value_idx = [self.value_idx[j] for j in label_idx]
if limit:
self.data = self.data[: limit + 1]
if i in self.header_idx_dict:
self.header_idx_dict[i].append(key)
else:
self.header_idx_dict[i] = [key]
self.mapping = lambda x: x

def __len__(self) -> int:
return len(self.data) - 1

def __getitem__(self, idx: Union[int, Tensor]) -> None:
"""
You need to overwrite this method in the inherited class.
See `~bayesianflow_for_chem.data.CSVData` as an example.
"""
return super().__getitem__(idx)


class CSVData(BaseCSVDataClass):
def __getitem__(self, idx: Union[int, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def __getitem__(self, idx: Union[int, Tensor]) -> Dict[str, Tensor]:
if torch.is_tensor(idx):
idx = idx.tolist()
# valid `idx` should start from 1 instead of 0
d: List[str] = self.data[idx + 1].replace("\n", "").split(",")
values = [
float(d[i]) if d[i].strip() != "" else torch.inf for i in self.value_idx
]
if self.smiles_idx:
smiles = ".".join([d[i] for i in self.smiles_idx if d[i] != ""])
token = smiles2token(smiles)
if self.geo2seq_idx is not None:
seq = d[self.geo2seq_idx]
token = geo2token(seq)
out_dict = {"token": token}
if len(values) != 0:
out_dict["value"] = torch.tensor(values, dtype=torch.float32)
return out_dict
data: List[str] = self.data[idx + 1].replace("\n", "").split(",")
data_dict: Dict[str, List[str]] = {}
for key in self.header_idx_dict:
data_dict[key] = [data[i] for i in self.header_idx_dict[key]]
return self.mapping(data_dict)

def map(self, mapping: Callable[[Dict[str, List[str]]], Any]) -> None:
"""
Pass a customised mapping function to transform the data entities to tensors.
e.g.
```python
import torch
from bayesianflow_for_chem.data import smiles2token, CSVData
def encode(x):
return {
"token": smiles2token(".".join(x["smiles"])),
"value": torch.tensor([float(i) if i != "" else torch.inf for i in x["value"]]),
}
dataset = CSVData(...)
dataset.map(encode)
```
:param mapping: customised mapping function
:type mapping: callable
:return:
:rtype: None
"""
self.mapping = mapping


if __name__ == "__main__":
Expand Down

0 comments on commit 6048478

Please sign in to comment.