forked from databricks/megablocks
-
Notifications
You must be signed in to change notification settings - Fork 55
/
Copy pathdmoe.py
324 lines (280 loc) · 10.7 KB
/
dmoe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
from megablocks.layers import common
from megablocks.layers import moe
from megablocks.layers import dmlp_registry
from megablocks.layers import mpu
from megablocks.layers import router
from megablocks.layers.arguments import Arguments
import megablocks.ops as ops
import numpy as np
import stk
import torch
def promote_scalar(x):
    """Return ``x`` viewed as shape ``[1]`` when it is 0-d; otherwise ``x`` itself.

    Used on cumulative-sum results so that downstream code can always
    treat them as at least 1-d.
    """
    # An empty torch.Size (0-d tensor) is falsy.
    if x.size():
        return x
    return x.view(1)
class ParallelDroplessMLP(moe.ParallelMLP):
    """Expert computation for the dropless MoE layer.

    Extends ``moe.ParallelMLP`` with two expert-computation strategies,
    selected by ``args.grouped_mlp``:

    * sparse (default): each expert's token count is rounded up to a multiple
      of ``self.blocking``. Tokens are gathered into that padded layout, and
      the experts run as one block-sparse product. Its topology (an
      ``stk.Matrix``) is rebuilt from the routing on every call.
    * grouped: tokens are gathered without padding, and the expert MLP is
      driven directly by the per-expert token counts.

    No expert-capacity limit is applied here. The ``expert_capactiy``
    arguments below are accepted for interface compatibility and ignored.

    Several attributes are used but not defined in this class:
    ``num_experts``, ``sort_end_bit``, ``top_k``, ``args``,
    ``indices_and_bins``, and the ``parallel_forward_once`` / ``forward``
    entry points. These presumably come from ``moe.ParallelMLP``; confirm
    there.
    """

    def __init__(self, args: Arguments):
        super(ParallelDroplessMLP, self).__init__(args)
        self.hidden_size = args.hidden_size
        # Expert FFN width handled by this rank (presumably a per-rank shard
        # of args.ffn_hidden_size; see mpu.features_per_rank).
        self.ffn_hidden_size = mpu.features_per_rank(args)
        # Edge length of the square blocks in the sparse topology. Per-expert
        # token counts are rounded up to a multiple of this.
        self.blocking = 128
        # Expert MLP implementation. It is called as self.mlp(x, topo) on
        # the sparse path and self.mlp(x, tokens_per_expert) on the grouped
        # path.
        self.mlp = dmlp_registry.get(args)

        # Calculate the number of bits needed to represent the column indices
        # in the intermediate sparse matrix. This is the end bit handed to
        # ops.sort in sparse_transpose; at least one bit is always used.
        max_column_index = (
            (self.ffn_hidden_size * self.num_experts) // self.blocking)
        self.transpose_sort_end_bit = max(
            int(np.ceil(np.log2(max_column_index))), 1)

    def sparse_transpose(self, size, row_indices, column_indices, offsets):
        """Build the transposed-access metadata for a block-sparse matrix.

        Args:
            size: ``(rows, cols)`` of the sparse matrix. ``cols`` is expected
                to be a multiple of ``self.blocking``.
            row_indices: Block-row index of every nonzero block.
            column_indices: Block-column index of every nonzero block.
            offsets: Row offsets of the matrix (not used here).

        Returns:
            A tuple ``(column_indices_t, offsets_t, block_offsets_t)``
            describing the transposed view:

            * ``column_indices_t``: the original block-row index of each
              block.
            * ``offsets_t``: a 0 followed by the running count of nonzero
              blocks per block column.
            * ``block_offsets_t``: the position, in the original block
              order, of each transposed block.
        """
        block_columns = size[1] // self.blocking

        # Sort row indices by column indices to get the transposed matrix's
        # column indices.
        #
        # NOTE: Our sort operation uses the same width indices as the input values.
        # To avoid overflow when we have large activation matrices we cast to
        # 32-bit before sorting.
        _, gather_indices = ops.sort(
            column_indices.int(), self.transpose_sort_end_bit)

        # There are a constant number of blocks in every row of the sparse matrix.
        # A block's offset is:
        #
        # row_index * blocks_per_row + column_index % blocks_per_row
        #
        # Once we have the block offsets ordered for transposition we can divide
        # by blocks_per_row to get the transposed column indices.
        column_indices_t = row_indices.gather(0, gather_indices.long())
        block_offsets_t = gather_indices.int()

        zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
        nnz_per_column = ops.histogram(column_indices, block_columns)
        nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
        # This class passes every other ops.inclusive_cumsum result through
        # promote_scalar, which implies the op can return a 0-d tensor for
        # a single-element input. That happens here when there is only one
        # block column (e.g. ffn_hidden_size == blocking with one local
        # expert). torch.cat rejects 0-d tensors, so promote to shape [1].
        nnz_per_column = promote_scalar(nnz_per_column)
        offsets_t = torch.cat([zero, nnz_per_column])
        return column_indices_t, offsets_t, block_offsets_t

    def topology(self, x, padded_bins):
        """Build the block-sparse topology for the padded, routed tokens.

        Args:
            x: 2-D gathered token matrix ``[padded_tokens, hidden]``.
                ``padded_tokens`` must be a multiple of ``self.blocking``.
            padded_bins: Inclusive cumulative sum of the padded per-expert
                token counts, i.e. the end offset of each expert's bin.

        Returns:
            An ``stk.Matrix`` of shape
            ``(padded_tokens, ffn_hidden_size * experts_per_rank)``.
            Every block row holds ``ffn_hidden_size // blocking`` nonzero
            blocks. Its ``data`` tensor is allocated on the ``meta`` device
            and carries no values.
        """
        padded_tokens, _ = x.size()
        assert padded_tokens % self.blocking == 0
        assert self.ffn_hidden_size % self.blocking == 0

        # Offsets for the sparse matrix. All rows have the
        # same number of nonzero blocks dictated by the
        # dimensionality of a single expert.
        block_rows = padded_tokens // self.blocking
        blocks_per_row = self.ffn_hidden_size // self.blocking
        offsets = torch.arange(
            0,
            block_rows * blocks_per_row + 1,
            blocks_per_row,
            dtype=torch.int32,
            device=x.device)

        # Indices for the sparse matrix. The indices for
        # the intermediate matrix are dynamic depending
        # on the mapping of tokens to experts.
        column_indices = ops.topology(padded_bins,
                                      self.blocking,
                                      block_rows,
                                      blocks_per_row)

        # TODO(tgale): This is unused. Remove the need for this in stk.
        # For now, use meta init to save the device memory.
        data = torch.empty(
            column_indices.numel(),
            self.blocking,
            self.blocking,
            dtype=common.dtype(self.args),
            device='meta')
        shape = (
            padded_tokens,
            self.ffn_hidden_size * mpu.experts_per_rank(self.args)
        )
        row_indices = stk.ops.row_indices(
            shape, data, offsets, column_indices)
        column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
            shape, row_indices, column_indices, offsets)
        return stk.Matrix(shape, data, row_indices, column_indices, offsets,
                          column_indices_t, offsets_t, block_offsets_t)

    def indices_and_padded_bins(self, top_experts):
        """Compute the routing permutation and per-expert bin boundaries.

        Args:
            top_experts: Flattened expert id assigned to every routed slot.

        Returns:
            A tuple ``(indices, bin_ids, bins, padded_bins,
            tokens_per_expert)``:

            * ``indices``, ``bin_ids``: the permutation and sorted expert ids
              produced by ``ops.sort``.
            * ``bins``: inclusive cumulative token count per expert.
            * ``padded_bins``: the same, with each count first rounded up to
              a multiple of ``self.blocking``.
            * ``tokens_per_expert``: the raw per-expert counts.

            ``bins`` and ``padded_bins`` are always at least 1-d.
        """
        # Sort the expert ids to produce the scatter/gather
        # indices for the permutation.
        top_experts = top_experts.int()
        bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)

        # Histogram the expert ids to identify the number of
        # tokens routed to each expert.
        tokens_per_expert = ops.histogram(top_experts, self.num_experts)

        # Round the token counts up to the block size used in
        # the matrix multiplications. Calculate the starting
        # position of each bin.
        padded_tokens_per_expert = ops.round_up(
            tokens_per_expert, self.blocking)
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        padded_bins = promote_scalar(padded_bins)

        # Calculate the bin bounds for the sorted tokens.
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
        bins = promote_scalar(bins)
        return indices, bin_ids, bins, padded_bins, tokens_per_expert

    def sparse_forward_once(self, x, expert_weights, top_experts):
        """Run the experts on the local tokens using the block-sparse path.

        Returns:
            ``(output, tokens_per_expert)``. The output is the result of
            ``ops.padded_scatter`` weighted by ``expert_weights``.
        """
        # x: [sl, bs, hs]
        # expert_weights: [sl * bs, top-k]
        # top_experts: [sl * bs, top-k]
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            indices, bin_ids, bins, padded_bins, tokens_per_expert = (
                self.indices_and_padded_bins(top_experts))

        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        x = ops.padded_gather(
            x,
            indices,
            bin_ids,
            bins,
            padded_bins,
            self.top_k)

        # Create the sparse matrix topology.
        with torch.no_grad():
            topo = self.topology(x, padded_bins)

        # Perform the expert computation.
        x = self.mlp(x, topo)

        # Un-route the data for the MoE output.
        x = ops.padded_scatter(
            x,
            indices,
            bin_ids,
            expert_weights,
            bins,
            padded_bins,
            self.top_k,
            self.args.quantize_scatter_num_bits)
        return x, tokens_per_expert

    # For use in the base-class parallel_forward_once.
    def sparse_permute_and_compute(
            self,
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            expert_capactiy,  # unused
            top_k):
        """Gather, compute and scatter already-routed tokens (sparse path).

        The caller supplies the routing (``indices``, ``bin_ids``, ``bins``).
        Only the padded bins are derived here from ``tokens_per_expert``.
        """
        # Round the token counts up to the block size used in the matrix
        # multiplication. Calculate the starting position of each bin.
        padded_tokens_per_expert = ops.round_up(
            tokens_per_expert, self.blocking)
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        padded_bins = promote_scalar(padded_bins)

        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        x = ops.padded_gather(
            x,
            indices,
            bin_ids,
            bins,
            padded_bins,
            top_k)

        # Create the sparse matrix topology.
        with torch.no_grad():
            topo = self.topology(x, padded_bins)

        # Perform the expert computation.
        x = self.mlp(x, topo)

        # Un-route the data for the MoE output.
        # NOTE(review): unlike sparse_forward_once and
        # grouped_permute_and_compute, this call does not pass
        # self.args.quantize_scatter_num_bits. Scatter quantization is
        # therefore presumably not applied on this path; confirm this is
        # intended.
        return ops.padded_scatter(
            x,
            indices,
            bin_ids,
            expert_weights,
            bins,
            padded_bins,
            top_k)

    def grouped_forward_once(self, x, expert_weights, top_experts):
        """Run the experts on the local tokens using the grouped path.

        Returns:
            ``(output, tokens_per_expert)``.
        """
        # x: [sl, bs, hs]
        # expert_weights: [sl * bs, top-k]
        # top_experts: [sl * bs, top-k]
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            indices, bin_ids, bins, tokens_per_expert = (
                self.indices_and_bins(top_experts))

        out = self.grouped_permute_and_compute(
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            -1,  # unused
            self.args.moe_top_k)
        return out, tokens_per_expert

    def grouped_permute_and_compute(
            self,
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            expert_capactiy,  # unused
            top_k):
        """Gather, compute and scatter already-routed tokens (grouped path).

        Tokens are gathered without block padding. The expert MLP receives
        ``tokens_per_expert`` in place of a sparse topology.
        """
        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        x = ops.gather(
            x,
            indices,
            bin_ids,
            bins,
            top_k)

        # Perform the expert computation.
        x = self.mlp(x, tokens_per_expert)

        # Un-route the data for the MoE output.
        return ops.scatter(
            x,
            indices,
            bin_ids,
            expert_weights,
            bins,
            top_k,
            self.args.quantize_scatter_num_bits)

    def forward_once(self, x, expert_weights, top_experts):
        """Dispatch to the grouped or sparse local forward pass.

        The choice is made by ``args.grouped_mlp``.
        """
        if self.args.grouped_mlp:
            return self.grouped_forward_once(
                x, expert_weights, top_experts)
        return self.sparse_forward_once(
            x, expert_weights, top_experts)

    def permute_and_compute(
            self,
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            expert_capactiy,
            top_k):
        """Dispatch to the grouped or sparse permute-and-compute step.

        The choice is made by ``args.grouped_mlp``.
        """
        if self.args.grouped_mlp:
            return self.grouped_permute_and_compute(
                x,
                tokens_per_expert,
                indices,
                bin_ids,
                expert_weights,
                bins,
                expert_capactiy,
                top_k)
        return self.sparse_permute_and_compute(
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            expert_capactiy,
            top_k)
class dMoE(torch.nn.Module):
    """Dropless mixture-of-experts layer: a learned router feeding the experts.

    Args:
        args: Configuration shared by the router and the expert helper.
    """

    def __init__(self, args: Arguments):
        super().__init__()
        # Router producing (scores, expert_weights, top_experts) per input.
        self.router = router.LearnedRouter(args)
        # Helper that routes tokens through the experts and combines them.
        self.experts = ParallelDroplessMLP(args)

    def forward(self, x):
        # Any lower-precision cast happens before routing, so the token
        # permutation moves the smaller activations.
        x = common.cast_if_autocast_enabled(x)

        routing = self.router(x)
        scores, expert_weights, top_experts = routing
        return self.experts(x, scores, expert_weights, top_experts)