From ab774052cba631c4cb9b629d7365b0de139e454e Mon Sep 17 00:00:00 2001 From: Alex Tong Date: Thu, 8 Feb 2024 16:13:25 -0500 Subject: [PATCH] fix code quality python version to 3.10, add full test set eval --- .github/workflows/code-quality-main.yaml | 4 +++- dem/energies/base_energy_function.py | 9 ++++++--- dem/models/dem_module.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.github/workflows/code-quality-main.yaml b/.github/workflows/code-quality-main.yaml index 88b7220..beaaa10 100644 --- a/.github/workflows/code-quality-main.yaml +++ b/.github/workflows/code-quality-main.yaml @@ -16,7 +16,9 @@ jobs: uses: actions/checkout@v2 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v3 + with: + python-version: "3.10" - name: Run pre-commits uses: pre-commit/action@v2.0.3 diff --git a/dem/energies/base_energy_function.py b/dem/energies/base_energy_function.py index d31af0d..79c76bc 100644 --- a/dem/energies/base_energy_function.py +++ b/dem/energies/base_energy_function.py @@ -62,12 +62,15 @@ def unnormalize(self, x: torch.Tensor) -> torch.Tensor: x = (x + 1) / 2 return x * (maxs - mins) + mins - def sample_test_set(self, num_points: int, normalize: bool = False) -> Optional[torch.Tensor]: + def sample_test_set(self, num_points: int, normalize: bool = False, full: bool=False) -> Optional[torch.Tensor]: if self.test_set is None: return None - idxs = torch.randperm(len(self.test_set))[:num_points] - outs = self.test_set[idxs] + if full: + outs = self.test_set + else: + idxs = torch.randperm(len(self.test_set))[:num_points] + outs = self.test_set[idxs] if normalize: outs = self.normalize(outs) diff --git a/dem/models/dem_module.py b/dem/models/dem_module.py index cfb257c..f307f9a 100644 --- a/dem/models/dem_module.py +++ b/dem/models/dem_module.py @@ -673,7 +673,7 @@ def eval_step(self, prefix: str, batch: torch.Tensor, batch_idx: int) -> None: :param batch_idx: The index of the current batch. """ if prefix == "test": - batch = self.energy_function.sample_test_set(self.eval_batch_size) + batch = self.energy_function.sample_test_set(self.eval_batch_size, full=True) elif prefix == "val": batch = self.energy_function.sample_val_set(self.eval_batch_size)