[frontend] Sync DynamoCompiler to torch2.5 (#431)
R-Tars authored Dec 13, 2024
1 parent 73b5ab5 commit 5acfb02
Showing 16 changed files with 789 additions and 51 deletions.
2 changes: 1 addition & 1 deletion examples/BuddyBert/bert-main.cpp
@@ -93,7 +93,7 @@ int main() {

/// Execute forward inference of the model.
  _mlir_ciface_forward(&result, &arg0, &arg1, &pureStrContainer,
-                       &attention_mask, &token_type_ids);
+                       &token_type_ids, &attention_mask);

const auto inferenceEnd = std::chrono::high_resolution_clock::now();
const std::chrono::duration<double, std::milli> inferenceTime =
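Background for the argument swap (not stated in the diff itself): under the torch 2.5 importer, the trailing tensor arguments of the generated _mlir_ciface_forward follow the order of the non-parameter placeholders in the traced graph, which now puts token_type_ids before attention_mask. A hedged way to check that order before editing a C++ call site (print_input_order and the invocation below are illustrative, not part of this repo):

    import torch

    def print_input_order(gm, example_inputs):
        # Placeholders whose names do not start with "l_self" are the data arguments,
        # in the order the generated forward expects them.
        print([n.name for n in gm.graph.nodes if n.op == "placeholder"])
        return gm

    # model and inputs as prepared in the corresponding BuddyBert import script.
    compiled = torch.compile(model, backend=print_input_order)
    compiled(**inputs)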
4 changes: 1 addition & 3 deletions examples/BuddyLeNet/buddy-lenet-import.py
@@ -23,7 +23,6 @@

import numpy as np
import torch
-from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.graph import GraphDriver
@@ -39,13 +38,12 @@
)

model = LeNet()
-model = torch.load(model_path + "/lenet-model.pth")
+model = torch.load(model_path + "/lenet-model.pth", weights_only=False)
model = model.eval()

# Initialize Dynamo Compiler with specific configurations as an importer.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
-    aot_autograd_decomposition=inductor_decomp,
)

data = torch.randn([1, 1, 28, 28])
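Context for the weights_only change: lenet-model.pth stores a pickled nn.Module rather than a bare state_dict, and newer torch releases tighten torch.load toward weights_only=True, which rejects arbitrary pickled objects. A minimal sketch, assuming the LeNet class definition is importable when the checkpoint is loaded (the import path is an assumption):

    import torch
    from model import LeNet  # assumed location of the LeNet definition; unpickling a full module needs it

    model = torch.load("lenet-model.pth", weights_only=False)  # full-module pickle, so weights_only must stay False
    model.eval()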
1 change: 1 addition & 0 deletions examples/BuddyLlama/CMakeLists.txt
@@ -53,6 +53,7 @@ add_custom_command(
COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLlama/subgraph0.mlir
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" |
${BUDDY_BINARY_DIR}/buddy-opt
+  -convert-elementwise-to-linalg
-arith-expand
-eliminate-empty-tensors
-empty-tensor-to-alloc-tensor
2 changes: 1 addition & 1 deletion examples/BuddyLlama/import-llama2.py
@@ -38,7 +38,7 @@
)

# Initialize the tokenizer and model from the specified model path.
-tokenizer = LlamaTokenizer.from_pretrained(model_path)
+tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True)
model = LlamaForCausalLM.from_pretrained(model_path, torchscript=True)
model.config.use_cache = False

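On legacy=True: transformers changed how LlamaTokenizer handles text around special tokens and warns when the choice is left implicit; passing legacy=True pins the original SentencePiece behaviour that this pipeline's token handling assumes. A small hedged sketch (the prompt is illustrative):

    from transformers import LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True)  # model_path as defined earlier in this script
    ids = tokenizer.encode("Hey, are you conscious? Can you talk to me?")
    print(len(ids), ids[:5])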
2 changes: 1 addition & 1 deletion examples/BuddyLlama/llama-main.cpp
@@ -24,7 +24,7 @@

using namespace buddy;

-constexpr size_t ParamsSize = 6755192832;
+constexpr size_t ParamsSize = 6738415680;
constexpr size_t MaxVocabSize = 32000;
constexpr size_t MaxTokenLength = 40;
constexpr size_t HiddenSize = 4096;
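The updated ParamsSize tracks what the torch 2.5 importer now packs: model parameters plus the buffers it exports, rather than the previous flattened list. A hedged sanity check (assumed snippet, not part of the commit; model_path as in import-llama2.py):

    import torch
    from transformers import LlamaForCausalLM

    model = LlamaForCausalLM.from_pretrained(model_path, torchscript=True)
    n = sum(p.numel() for p in model.parameters()) + sum(b.numel() for b in model.buffers())
    print(n)  # should match ParamsSize in llama-main.cpp up to which buffers the importer exports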
4 changes: 2 additions & 2 deletions examples/BuddyWhisper/whisper-main.cpp
@@ -33,7 +33,7 @@ using namespace std;
using namespace buddy;
using namespace dap;

-constexpr size_t ParamsSize = 99148800;
+constexpr size_t ParamsSize = 72593920;
constexpr size_t MaxVocabSize = 51865;
constexpr size_t MaxTokenLength = 448;

@@ -180,4 +180,4 @@ int main() {
<< std::endl;

return 0;
-}
+}
51 changes: 44 additions & 7 deletions frontend/Python/frontend.py
@@ -124,6 +124,7 @@ def __init__(
"mean.dim": MeanOp,
"rsqrt.default": RsqrtOp,
"mul.Tensor": MulOp,
"mul.Scalar": MulOp,
"t.default": TOp,
"mm.default": MatmulOp,
"transpose.int": TransposeOp,
@@ -167,6 +168,10 @@ def __init__(
"split.Tensor":SplitOp,
"max.default":MaxOp,
"gt.Scalar":GtOp,
"_scaled_dot_product_flash_attention_for_cpu.default": ScaledDotProductFlashAttentionForCpuOp,
"ge.Scalar": GeOp,
"gt.Tensor": GreaterThanOp,
"_unsafe_index.Tensor": UnsafeIndexOp,
}

@property
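The new registry entries cover Aten overloads that torch 2.5 graphs emit more often; for instance, multiplying a tensor by a Python scalar can be traced as mul.Scalar rather than mul.Tensor, and both map to the same MulOp. A tiny hedged illustration:

    import torch

    x = torch.randn(4)
    y = x * 2.0  # may appear as aten::mul.Scalar (not mul.Tensor) in a torch 2.5 traced graph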
@@ -257,11 +262,26 @@ def _compile_fx(
        return for torchdynamo's call.
        """

-        params = {
-            **dict(gm.named_parameters(remove_duplicate=False)),
-            **dict(gm.named_buffers(remove_duplicate=False)),
-        }
-        params_flat, _ = pytree.tree_flatten(params)
+        # params = {
+        #     # **dict(gm.named_parameters(remove_duplicate=False)),
+        #     **dict(gm.named_buffers(remove_duplicate=False)),
+        # }
+        # print(len(params))
+        # params_flat, _ = pytree.tree_flatten(params)
+        inputs_pos = []
+        params_pos = []
+        buffers_pos = []
+        for i, node in enumerate(gm.graph.nodes):
+            if i >= len(inputs):
+                break
+            if not str(node).startswith("l_self"):
+                inputs_pos.append(i)
+            elif "buffer" in str(node):
+                buffers_pos.append(i)
+            else:
+                params_pos.append(i)
+
+        params_flat = [inputs[i] for i in params_pos + buffers_pos]

        if self._verbose:
            print("Graph in tabular form:")
@@ -271,7 +291,9 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
"""Compile a FX graph in Aten/Prims IR to MLIR."""
nonlocal params_flat
func_inputs = []
for inp in _inputs[len(params_flat) :]:
for i in inputs_pos:
# for inp in _inputs[len(params_flat) :]:
inp = _inputs[i]
inp_shape = inp.shape
inp_dtype = self._torch_dtype_translate(str(inp.dtype))
func_inputs.append(TensorMeta(inp_shape, inp_dtype))
@@ -286,7 +308,22 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
                self._func_name,
                self._verbose
            )
-            for gm_node in _gm.graph.nodes:
+            param_nodes = []
+            buffers_nodes = []
+            input_nodes = []
+            other_nodes = []
+            for i, node in enumerate(_gm.graph.nodes):
+                if i in params_pos:
+                    param_nodes.append(node)
+                elif i in buffers_pos:
+                    buffers_nodes.append(node)
+                elif i in inputs_pos:
+                    input_nodes.append(node)
+                else:
+                    other_nodes.append(node)
+            gm_nodes = param_nodes + buffers_nodes + input_nodes + other_nodes
+
+            for gm_node in gm_nodes:
                node_users = []
                for user in gm_node.users.keys():
                    node_users.append(str(user))
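The regrouping above (parameters, then buffers, then user inputs, then all remaining nodes) fixes the argument layout that the generated MLIR function and the packed parameter data agree on. A stripped-down hedged sketch of the same reordering with toy values (names are invented):

    nodes = ["l_self_weight", "l_self_buffer_mean", "l_x_", "matmul", "output"]
    params_pos, buffers_pos, inputs_pos = [0], [1], [2]

    reordered = (
        [nodes[i] for i in params_pos]
        + [nodes[i] for i in buffers_pos]
        + [nodes[i] for i in inputs_pos]
        + [n for i, n in enumerate(nodes) if i not in params_pos + buffers_pos + inputs_pos]
    )
    print(reordered)  # ['l_self_weight', 'l_self_buffer_mean', 'l_x_', 'matmul', 'output']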
24 changes: 24 additions & 0 deletions frontend/Python/graph/operation.py
@@ -534,3 +534,27 @@ class GtOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType
+
+
+class ScaledDotProductFlashAttentionForCpuOp(Op):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op_type = OpType.ElementwiseType
+
+
+class GeOp(Op):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op_type = OpType.ElementwiseType
+
+
+class GreaterThanOp(Op):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op_type = OpType.BroadcastType
+
+
+class UnsafeIndexOp(Op):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op_type = OpType.ReshapeType
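ScaledDotProductFlashAttentionForCpuOp backs the matching registry entry in frontend.py: on CPU, torch 2.5 can lower F.scaled_dot_product_attention to aten::_scaled_dot_product_flash_attention_for_cpu, so the importer needs an Op class for it. A hedged reproduction at the PyTorch level (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
    out = F.scaled_dot_product_attention(q, k, v)  # may trace to _scaled_dot_product_flash_attention_for_cpu on CPU
    print(out.shape)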
