class CSVData(Dataset):
    """
    Map-style dataset backed by a comma-separated text file.

    The first line of the file is the header. Every following line is one
    sample. Columns that share the same header name are grouped together, so
    each sample is presented as ``{header_name: [field, field, ...]}`` before
    it is handed to the (optional) user-supplied mapping function.
    """

    def __init__(self, file: Union[str, Path]) -> None:
        """
        Define dataset stored in CSV file.

        :param file: dataset file name
        :type file: str | pathlib.Path
        """
        super().__init__()
        with open(file, "r") as db:
            self.data = db.readlines()
        # Blank lines at the end of the file are not samples; keeping them
        # would inflate `len()` and make the last indices fail in `__getitem__`.
        while self.data and self.data[-1].strip() == "":
            self.data.pop()
        # header name -> positions of every column carrying that name
        self.header_idx_dict: Dict[str, List[int]] = {}
        if self.data:  # an empty file simply yields an empty dataset
            header = self.data[0].replace("\n", "").split(",")
            for column, name in enumerate(header):
                self.header_idx_dict.setdefault(name, []).append(column)
        # A named function (rather than a lambda) keeps the dataset picklable,
        # which is required when DataLoader workers are started via "spawn".
        self.mapping: Callable[[Dict[str, List[str]]], Any] = self._identity

    @staticmethod
    def _identity(x: Dict[str, List[str]]) -> Dict[str, List[str]]:
        """Default mapping: return the raw field dictionary unchanged."""
        return x

    def __len__(self) -> int:
        """
        Return the number of samples (the header line is not counted).

        :return: number of data rows
        :rtype: int
        """
        return max(len(self.data) - 1, 0)

    def __getitem__(self, idx: Union[int, Tensor]) -> Any:
        """
        Fetch one sample and pass it through the current mapping function.

        :param idx: sample index; negative values count from the end
        :type idx: int | torch.Tensor
        :return: whatever the mapping function returns; with the default
                 mapping this is ``{header_name: [field, ...]}``
        :raises IndexError: if ``idx`` is outside ``[-len(self), len(self))``
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        size = len(self)
        position = idx + size if idx < 0 else idx
        # Without this check a negative index would silently return the header
        # row. IndexError is kept for out-of-range access so that plain
        # `for sample in dataset` iteration still terminates correctly.
        if not 0 <= position < size:
            raise IndexError(
                f"index {idx} is out of range for a dataset of size {size}"
            )
        # row 0 of `self.data` is the header, so samples start at row 1
        fields: List[str] = self.data[position + 1].replace("\n", "").split(",")
        data_dict: Dict[str, List[str]] = {
            name: [fields[i] for i in columns]
            for name, columns in self.header_idx_dict.items()
        }
        return self.mapping(data_dict)

    def map(self, mapping: Callable[[Dict[str, List[str]]], Any]) -> None:
        """
        Pass a customised mapping function to transform the data entities to tensors.

        e.g.
        ```python
        import torch
        from bayesianflow_for_chem.data import smiles2token, CSVData


        def encode(x):
            return {
                "token": smiles2token(".".join(x["smiles"])),
                "value": torch.tensor([float(i) if i != "" else torch.inf for i in x["value"]]),
            }

        dataset = CSVData(...)
        dataset.map(encode)
        ```

        :param mapping: customised mapping function
        :type mapping: callable
        :return:
        :rtype: None
        """
        self.mapping = mapping