Skip to content

Commit

Permalink
one epoch only
Browse files Browse the repository at this point in the history
  • Loading branch information
v-chen_data committed Nov 30, 2024
1 parent 69f13ba commit 7b69fc3
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ def test_event_values(event: Event):

class TestEventCalls:

eval_subset_num_batches = 2
train_subset_num_batches = 2
eval_subset_num_batches = 1
train_subset_num_batches = 1

def get_trainer(self, precision='fp32', **kwargs):
def get_trainer(self, precision='fp32', max_duration='1ep', **kwargs):
model = SimpleModel()
optimizer = torch.optim.Adam(model.parameters())

train_dataset = RandomClassificationDataset()
eval_dataset = RandomClassificationDataset()
train_batch_size = 2
train_batch_size = 2

evaluator1 = DataLoader(
dataset=eval_dataset,
Expand All @@ -57,7 +57,7 @@ def get_trainer(self, precision='fp32', **kwargs):
precision=precision,
train_subset_num_batches=self.train_subset_num_batches,
eval_subset_num_batches=self.eval_subset_num_batches,
max_duration='1ep',
max_duration=max_duration,
optimizers=optimizer,
callbacks=[EventCounterCallback()],
**kwargs,
Expand Down Expand Up @@ -99,7 +99,7 @@ def get_trainer(self, precision='fp32', **kwargs):
),
],
)
@pytest.mark.parametrize('save_interval', ['1ep', '1ba'])
@pytest.mark.parametrize('save_interval', ['1ep'])
def test_event_calls(self, world_size, device, deepspeed_zero_stage, use_fsdp, precision, save_interval):
save_interval = Time.from_timestring(save_interval)

Expand All @@ -124,6 +124,7 @@ def test_event_calls(self, world_size, device, deepspeed_zero_stage, use_fsdp, p
parallelism_config=parallelism_config,
save_interval=save_interval,
eval_interval=save_interval,
max_duration='1ep',
)
trainer.fit()

Expand Down

0 comments on commit 7b69fc3

Please sign in to comment.