From 868a71b54733a39f8b0c7a73e75512ba516efac8 Mon Sep 17 00:00:00 2001 From: v-chen_data Date: Sat, 30 Nov 2024 18:45:05 -0500 Subject: [PATCH] revery --- tests/trainer/test_fsdp_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index b0890d26a9..332ed5b7b7 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -111,11 +111,11 @@ def get_trainer( val_metrics=val_metrics, ) model.module.to(model_init_device) - dataset = RandomClassificationDataset(shape=(num_features,), num_classes=num_classes, size=8) + dataset = RandomClassificationDataset(shape=(num_features,), num_classes=num_classes, size=128) dataloader = DataLoader( dataset, sampler=dist.get_sampler(dataset), - batch_size=2, + batch_size=8, ) if optimizer == 'adam': optim = torch.optim.Adam(params=model.parameters())