Skip to content

Commit

Permalink
fix code quality python version to 3.10, add full test set eval
Browse files Browse the repository at this point in the history
  • Loading branch information
atong01 committed Feb 8, 2024
1 parent 04bdb57 commit ab77405
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/code-quality-main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ jobs:
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v3
with:
python-version: "3.10"

- name: Run pre-commits
uses: pre-commit/[email protected]
9 changes: 6 additions & 3 deletions dem/energies/base_energy_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,15 @@ def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
x = (x + 1) / 2
return x * (maxs - mins) + mins

def sample_test_set(self, num_points: int, normalize: bool = False) -> Optional[torch.Tensor]:
def sample_test_set(self, num_points: int, normalize: bool = False, full: bool=False) -> Optional[torch.Tensor]:
if self.test_set is None:
return None

idxs = torch.randperm(len(self.test_set))[:num_points]
outs = self.test_set[idxs]
if full:
outs = self.test_set
else:
idxs = torch.randperm(len(self.test_set))[:num_points]
outs = self.test_set[idxs]
if normalize:
outs = self.normalize(outs)

Expand Down
2 changes: 1 addition & 1 deletion dem/models/dem_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def eval_step(self, prefix: str, batch: torch.Tensor, batch_idx: int) -> None:
:param batch_idx: The index of the current batch.
"""
if prefix == "test":
batch = self.energy_function.sample_test_set(self.eval_batch_size)
batch = self.energy_function.sample_test_set(self.eval_batch_size, full=True)
elif prefix == "val":
batch = self.energy_function.sample_val_set(self.eval_batch_size)

Expand Down

0 comments on commit ab77405

Please sign in to comment.