-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathjax_benchmarks.py
73 lines (50 loc) · 1.36 KB
/
jax_benchmarks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import time
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# --- Benchmark configuration -------------------------------------------
size = 4096        # square matrix dimension (N for an N x N matmul)
iterations = 60    # timed repetitions per benchmark variant
# Fixed PRNG seed so every run times the same random operand matrix.
key = random.PRNGKey(0)

# Operand: a dense N x N Gaussian matrix; its transpose is what each
# benchmark below multiplies against `mat`.
mat = random.normal(key, (size, size))
mat_T = mat.T
def apply_matrix(v):
    """Return the matrix product of the global benchmark matrix `mat` and `v`."""
    return mat @ v
# --- Baseline: un-jitted matmul ----------------------------------------
# Warm-up call so the timed loop measures steady state, not first-call
# dispatch overhead; block_until_ready() forces JAX's async dispatch to
# finish before/inside the timed region.
apply_matrix(mat_T).block_until_ready()
time.sleep(1)  # settle between benchmark phases
# perf_counter() is monotonic and high-resolution; time.time() is wall
# clock and can jump with system clock adjustments mid-benchmark.
st = time.perf_counter()
for i in range(iterations):
    apply_matrix(mat_T).block_until_ready()
et = time.perf_counter()
duration = et - st
fps = iterations / duration      # matmuls per second
matmul_flops = 2 * (size**3)     # FLOPs for one N x N @ N x N product
TFLOPS = fps * matmul_flops / 1e12
print("| MatMul |", TFLOPS, "|")
time.sleep(1)
# --- JIT-compiled matmul -----------------------------------------------
apply_matrix_jit = jit(apply_matrix)
# First call triggers XLA compilation; keep it out of the timed region.
apply_matrix_jit(mat_T).block_until_ready()
time.sleep(1)  # settle between benchmark phases
# Monotonic high-resolution clock (see baseline benchmark rationale).
st = time.perf_counter()
for i in range(iterations):
    apply_matrix_jit(mat_T).block_until_ready()
et = time.perf_counter()
duration = et - st
fps = iterations / duration      # matmuls per second
matmul_flops = 2 * (size**3)     # FLOPs for one N x N @ N x N product
TFLOPS = fps * matmul_flops / 1e12
print("| JIT MatMul |", TFLOPS, "|")
time.sleep(1)
# --- JIT + vmap: whole batch of matmuls in a single call ---------------
batched_x = random.normal(key, (iterations, size, size))

@jit
def vmap_batched_apply_matrix(v_batched):
    """Apply `apply_matrix` to each (size, size) slice of `v_batched`."""
    return vmap(apply_matrix)(v_batched)

# Warm-up: compile the batched kernel before timing.
vmap_batched_apply_matrix(batched_x).block_until_ready()
time.sleep(1)  # settle between benchmark phases
# Monotonic high-resolution clock (see baseline benchmark rationale).
st = time.perf_counter()
# One call performs `iterations` matmuls, so iterations/duration below is
# still matmuls-per-second, comparable to the looped variants above.
vmap_batched_apply_matrix(batched_x).block_until_ready()
et = time.perf_counter()
duration = et - st
fps = iterations / duration      # matmuls per second
matmul_flops = 2 * (size**3)     # FLOPs for one N x N @ N x N product
TFLOPS = fps * matmul_flops / 1e12
print("| JIT+VMAP MatMul |", TFLOPS, "|")