-
Notifications
You must be signed in to change notification settings - Fork 0
/
bench.py
113 lines (87 loc) · 3.18 KB
/
bench.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import time
import jax
from jax import numpy as jnp
from flax.training.train_state import TrainState
import optax
from real_lru.real import LRU
from real_lru.naive import LRU as LRU_naive
### TEST if outputs match ###
# Sanity check before benchmarking: the optimized LRU and the naive
# reference implementation must agree (to 1e-6) on the same parameters
# and the same input batch.
rng = jax.random.PRNGKey(907654)
sample = jax.random.normal(rng, (1, 1024, 256), dtype=jnp.float32)

lru_opt = LRU(d_model=256)
lru = LRU_naive(d_model=256)

# Initialize once with the optimized module; both variants share params.
params = lru_opt.init(rng, sample, False)
res_opt = lru_opt.apply(params, sample, False)
res = lru.apply(params, sample, False)
assert jnp.allclose(res, res_opt, atol=1e-6)
def create_train_state(cls_model, batch_size, hidden_dim, ssm_dim, seq_len, dtype, key):
    """Instantiate a model and wrap it in a Flax ``TrainState`` for benchmarking.

    Args:
        cls_model: model class (e.g. ``LRU`` or ``LRU_naive``) accepting
            ``d_model``, ``ssm_size`` and ``dtype`` keyword arguments.
        batch_size: leading dimension of the dummy init batch.
        hidden_dim: model width; also the feature dimension of the batch.
        ssm_dim: state-space size passed to the model.
        seq_len: sequence length of the dummy init batch.
        dtype: parameter/activation dtype (e.g. ``jnp.float32``).
        key: PRNG key, reused for both the dummy batch and ``model.init``.

    Returns:
        A ``TrainState`` holding the model's apply function, initialized
        parameters and an Adam(1e-3) optimizer.
    """
    model = cls_model(
        d_model=hidden_dim,
        ssm_size=ssm_dim,
        dtype=dtype
    )
    # Dummy batch used only to trace shapes/dtypes during initialization.
    batch = jax.random.normal(key, (batch_size, seq_len, hidden_dim), dtype=dtype)
    params = model.init(key, batch, training=True)
    # TrainState.create sets step=0 and calls tx.init(params) internally —
    # equivalent to the manual construction, without hand-wiring opt_state.
    return TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optax.adam(0.001),
    )
import numpy as np
from tqdm import trange
def get_benchmark(train_state, batch_size, seq_len, hidden_dim, dtype):
    """Benchmark the jitted forward pass of ``train_state``.

    Runs 201 iterations and discards the first two as JIT compile/warmup,
    printing mean ± std of the remaining 199 wall-clock times.

    Args:
        train_state: object exposing ``apply_fn`` and ``params``.
        batch_size, seq_len, hidden_dim: shape of each random input batch.
        dtype: dtype of the input batch; the output is asserted to match.

    Returns:
        List of per-iteration wall-clock times in seconds (199 entries).
    """
    @jax.jit
    def forward(batch, train_state):
        # training=True to exercise the same path the train loop would use.
        return train_state.apply_fn(train_state.params, batch, True)

    def run_once(key):
        # NOTE: batch generation is inside the timed region, so measured
        # times include input creation, not just the model forward.
        batch = jax.random.normal(key, (batch_size, seq_len, hidden_dim), dtype=dtype)
        out = forward(batch, train_state)
        # Force completion of JAX's async dispatch before timing stops.
        jax.block_until_ready(out)
        assert out.dtype == dtype

    test_times = []
    key = jax.random.PRNGKey(907654)
    for i in trange(201):
        key, _ = jax.random.split(key)
        start = time.time()
        run_once(key)
        delta = time.time() - start
        if i > 1:  # skip the first two iterations (compilation + warmup)
            test_times += [delta]
    print(f"\n {np.mean(test_times):.4f} ± {np.std(test_times):.4f}")
    return test_times
if __name__ == "__main__":
    # Single source of truth for both the benchmark sweep and the plot
    # x-axis (the original repeated this list three times, and the
    # results[...] slices below silently depend on its length).
    SEQ_LENS = [512, 1024, 2048, 4096]
    results = []
    model_cls = [LRU, LRU_naive]
    hidden_dim = 256
    ssm_dim = 128
    batch_size = 64

    # fp32 sweep: results[0:4] = LRU, results[4:8] = LRU_naive.
    for model in model_cls:
        for seq_len in SEQ_LENS:
            ts = create_train_state(model, batch_size, hidden_dim, ssm_dim, seq_len, jnp.float32, jax.random.PRNGKey(42))
            results.append(get_benchmark(ts, batch_size, seq_len, hidden_dim, jnp.float32))

    # fp16 sweep for the optimized LRU only: results[8:12].
    for seq_len in SEQ_LENS:
        ts = create_train_state(LRU, batch_size, hidden_dim, ssm_dim, seq_len, jnp.float16, jax.random.PRNGKey(42))
        results.append(get_benchmark(ts, batch_size, seq_len, hidden_dim, jnp.float16))

    import seaborn as sns
    from matplotlib import pyplot as plt
    sns.set(style="whitegrid", font_scale=1.2)
    # Mean over the per-iteration axis of each group of runs.
    plt.plot(SEQ_LENS, np.mean(results[4:8], -1), label="Naive LRU")
    plt.plot(SEQ_LENS, np.mean(results[:4], -1), label="LRU Real")
    plt.plot(SEQ_LENS, np.mean(results[-4:], -1), label="LRU Real + FP16")
    plt.legend()
    plt.xlabel("Sequence Length")
    plt.ylabel("Time (seconds)")
    plt.show()