forked from meta-llama/llama-cookbook
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_train_utils.py
51 lines (44 loc) · 1.42 KB
/
test_train_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import torch
from llama_recipes.utils.train_utils import train
def test_gradient_accumulation(mocker):
    """Check how often ``train`` zeroes gradients for different accumulation steps.

    ``train`` is run twice over the same five-batch dataloader with fully
    mocked collaborators: once with ``gradient_accumulation_steps=1`` (expects
    five ``optimizer.zero_grad`` calls) and once with
    ``gradient_accumulation_steps=2`` (expects three calls).

    Args:
        mocker: Fixture object exposing ``MagicMock`` (pytest-mock style).
    """
    model = mocker.MagicMock(name="model")
    # Make the mocked chain loss / <divisor> -> .detach() yield a real tensor.
    # NOTE(review): presumably so ``train`` can accumulate it into a running
    # loss -- confirm against the ``train`` implementation.
    model().loss.__truediv__().detach.return_value = torch.tensor(1)

    sample = {"input": torch.zeros(1)}
    train_dataloader = [sample] * 5  # five references to the same batch
    eval_dataloader = None

    tokenizer = mocker.MagicMock()
    optimizer = mocker.MagicMock()
    lr_scheduler = mocker.MagicMock()

    # Only these three flags are pinned; every other config attribute stays an
    # auto-generated mock.
    train_config = mocker.MagicMock(
        enable_fsdp=False,
        use_fp16=False,
        run_validation=False,
    )

    def run_train(accumulation_steps):
        # Same positional call for both scenarios; only the step count varies.
        train(
            model,
            train_dataloader,
            eval_dataloader,
            tokenizer,
            optimizer,
            lr_scheduler,
            accumulation_steps,
            train_config,
        )

    # No accumulation: one zero_grad per batch.
    run_train(1)
    assert optimizer.zero_grad.call_count == 5

    # Clear the recorded calls so the second scenario is counted from zero.
    optimizer.zero_grad.reset_mock()

    # Accumulating over two batches: three zero_grad calls for five batches.
    # NOTE(review): presumably after batches 2, 4 and the final batch --
    # confirm against ``train``.
    run_train(2)
    assert optimizer.zero_grad.call_count == 3