diff --git a/_static/img/compiled_autograd/call_hook_node.png b/_static/img/compiled_autograd/call_hook_node.png
new file mode 100644
index 0000000000..3e094cf6f7
Binary files /dev/null and b/_static/img/compiled_autograd/call_hook_node.png differ
diff --git a/_static/img/compiled_autograd/entire_verbose_log.png b/_static/img/compiled_autograd/entire_verbose_log.png
new file mode 100644
index 0000000000..4ce2b8538e
Binary files /dev/null and b/_static/img/compiled_autograd/entire_verbose_log.png differ
diff --git a/_static/img/compiled_autograd/recompile_due_to_dynamic.png b/_static/img/compiled_autograd/recompile_due_to_dynamic.png
new file mode 100644
index 0000000000..41ae56acf2
Binary files /dev/null and b/_static/img/compiled_autograd/recompile_due_to_dynamic.png differ
diff --git a/_static/img/compiled_autograd/recompile_due_to_node.png b/_static/img/compiled_autograd/recompile_due_to_node.png
new file mode 100644
index 0000000000..800a178458
Binary files /dev/null and b/_static/img/compiled_autograd/recompile_due_to_node.png differ
diff --git a/intermediate_source/compiled_autograd_tutorial.rst b/intermediate_source/compiled_autograd_tutorial.rst
index bcae7e63da..1091b19a49 100644
--- a/intermediate_source/compiled_autograd_tutorial.rst
+++ b/intermediate_source/compiled_autograd_tutorial.rst
@@ -97,62 +97,10 @@ Run the script with the ``TORCH_LOGS`` environment variables:
 
 Rerun the snippet above, the compiled autograd graph should now be logged to ``stderr``. Certain graph nodes will have names that are prefixed by ``aot0_``,
 these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0, for example, ``aot0_view_2`` corresponds to ``view_2``
 of the AOT backward graph with id=0.
+In the image below, the red box encapsulates the AOT backward graph that is captured by ``torch.compile`` without Compiled Autograd.
 
-.. code:: python
-
-   stderr_output = """
-   DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
-   DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
-   ===== Compiled autograd graph =====
-   .4 class CompiledAutograd(torch.nn.Module):
-      def forward(self, inputs, sizes, scalars, hooks):
-         # No stacktrace found for following nodes
-         aot0_tangents_1: "f32[][]cpu" = inputs[0]
-         aot0_primals_3: "f32[10][1]cpu" = inputs[1]
-         getitem_2: "f32[10][1]cpu" = inputs[2]
-         getitem_3: "f32[10, 10][10, 1]cpu" = inputs[3]; inputs = None
-
-         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1)
-         aot0_expand: "f32[10][0]cpu" = torch.ops.aten.expand.default(aot0_tangents_1, [10]); aot0_tangents_1 = None
-         aot0_view_2: "f32[1, 10][0, 0]cpu" = torch.ops.aten.view.default(aot0_expand, [1, 10]); aot0_expand = None
-         aot0_permute_2: "f32[10, 1][0, 0]cpu" = torch.ops.aten.permute.default(aot0_view_2, [1, 0])
-         aot0_select: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 0)
-         aot0_view: "f32[1, 10][10, 1]cpu" = torch.ops.aten.view.default(aot0_primals_3, [1, 10]); aot0_primals_3 = None
-         aot0_mul_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select, aot0_view); aot0_select = None
-         aot0_select_1: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 1)
-         aot0_mul_4: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_1, aot0_view); aot0_select_1 = None
-         aot0_select_2: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 2)
-         aot0_mul_5: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_2, aot0_view); aot0_select_2 = None
-         aot0_select_3: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 3)
-         aot0_mul_6: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_3, aot0_view); aot0_select_3 = None
-         aot0_select_4: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 4)
-         aot0_mul_7: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_4, aot0_view); aot0_select_4 = None
-         aot0_select_5: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 5)
-         aot0_mul_8: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_5, aot0_view); aot0_select_5 = None
-         aot0_select_6: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 6)
-         aot0_mul_9: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_6, aot0_view); aot0_select_6 = None
-         aot0_select_7: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 7)
-         aot0_mul_10: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_7, aot0_view); aot0_select_7 = None
-         aot0_select_8: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 8)
-         aot0_mul_11: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_8, aot0_view); aot0_select_8 = None
-         aot0_select_9: "f32[1][0]cpu" = torch.ops.aten.select.int(aot0_permute_2, 0, 9); aot0_permute_2 = None
-         aot0_mul_12: "f32[1, 10][10, 1]cpu" = torch.ops.aten.mul.Tensor(aot0_select_9, aot0_view); aot0_select_9 = aot0_view = None
-         aot0_cat: "f32[10, 10][10, 1]cpu" = torch.ops.aten.cat.default([aot0_mul_3, aot0_mul_4, aot0_mul_5, aot0_mul_6, aot0_mul_7, aot0_mul_8, aot0_mul_9, aot0_mul_10, aot0_mul_11, aot0_mul_12]); aot0_mul_3 = aot0_mul_4 = aot0_mul_5 = aot0_mul_6 = aot0_mul_7 = aot0_mul_8 = aot0_mul_9 = aot0_mul_10 = aot0_mul_11 = aot0_mul_12 = None
-         aot0_permute_3: "f32[10, 10][1, 10]cpu" = torch.ops.aten.permute.default(aot0_cat, [1, 0]); aot0_cat = None
-         aot0_sum_3: "f32[1, 10][10, 1]cpu" = torch.ops.aten.sum.dim_IntList(aot0_view_2, [0], True); aot0_view_2 = None
-         aot0_view_3: "f32[10][1]cpu" = torch.ops.aten.view.default(aot0_sum_3, [10]); aot0_sum_3 = None
-
-         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 2)
-         accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_2, aot0_view_3); getitem_2 = aot0_view_3 = accumulate_grad_ = None
-
-         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 1)
-         aot0_permute_4: "f32[10, 10][10, 1]cpu" = torch.ops.aten.permute.default(aot0_permute_3, [1, 0]); aot0_permute_3 = None
-
-         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:483 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 3)
-         accumulate_grad__1 = torch.ops.inductor.accumulate_grad_.default(getitem_3, aot0_permute_4); getitem_3 = aot0_permute_4 = accumulate_grad__1 = None
-         _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None
-         return []
-   """
+.. image:: ../_static/img/compiled_autograd/entire_verbose_log.png
 
 .. note:: This is the graph on which we will call ``torch.compile``, **NOT** the optimized graph. Compiled Autograd essentially generates some unoptimized Python code to represent the entire C++ autograd execution.
@@ -181,7 +129,7 @@ Or you can use the context manager, which will apply to all autograd calls withi
 Compiled Autograd addresses certain limitations of AOTAutograd
 --------------------------------------------------------------
 
-1. Graph breaks in the forward pass lead to graph breaks in the backward pass:
+1. Graph breaks in the forward pass no longer necessarily lead to graph breaks in the backward pass:
 
 .. code:: python
 
@@ -216,7 +164,10 @@ Compiled Autograd addresses certain limitations of AOTAutograd
 In the first ``torch.compile`` case, we see that 3 backward graphs were produced due to the 2 graph breaks in the compiled function ``fn``.
 Whereas in the second ``torch.compile`` with compiled autograd case, we see that a full backward graph was traced despite the graph breaks.
 
-2. Backward hooks are not captured
+.. note:: It is still possible for Dynamo to graph break when tracing backward hooks captured by Compiled Autograd.
+
+
+2. Backward hooks can now be captured
 
 .. code:: python
 
@@ -233,19 +184,7 @@ Whereas in the second ``torch.compile`` with compiled autograd case, we see that
 
 There should be a ``call_hook`` node in the graph, which dynamo will later inline into the following:
 
-.. code:: python
-
-   stderr_output = """
-   DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
-   DEBUG:torch._dynamo.compiled_autograd.__compiled_autograd_verbose:TRACED GRAPH
-   ===== Compiled autograd graph =====
-   .2 class CompiledAutograd(torch.nn.Module):
-      def forward(self, inputs, sizes, scalars, hooks):
-         ...
-         getitem_2 = hooks[0]; hooks = None
-         call_hook: "f32[10, 10][0, 0]cpu" = torch__dynamo_external_utils_call_hook(getitem_2, aot0_expand, hook_type = 'tensor_pre_hook'); getitem_2 = aot0_expand = None
-         ...
-   """
+.. image:: ../_static/img/compiled_autograd/call_hook_node.png
 
 Common recompilation reasons for Compiled Autograd
 --------------------------------------------------
@@ -261,18 +200,7 @@ Common recompilation reasons for Compiled Autograd
 In the example above, we call a different operator on each iteration, leading to ``loss`` tracking a different autograd history each time. You should see some recompile messages: **Cache miss due to new autograd node**.
 
-.. code:: python
-
-   stderr_output = """
-   Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0) with key size 39, previous key sizes=[]
-   ...
-   Cache miss due to new autograd node: SubBackward0 (NodeCall 2) with key size 56, previous key sizes=[]
-   ...
-   Cache miss due to new autograd node: MulBackward0 (NodeCall 2) with key size 71, previous key sizes=[]
-   ...
-   Cache miss due to new autograd node: DivBackward0 (NodeCall 2) with key size 70, previous key sizes=[]
-   ...
-   """
+.. image:: ../_static/img/compiled_autograd/recompile_due_to_node.png
 
 2. Due to tensors changing shapes:
 
 .. code:: python
@@ -286,16 +214,7 @@ In the example above, we call a different operator on each iteration, leading to
 In the example above, ``x`` changes shapes, and compiled autograd will mark ``x`` as a dynamic shape tensor after the first change. You should see recompiles messages: **Cache miss due to changed shapes**.
 
-.. code:: python
-
-   stderr_output = """
-   ...
-   Cache miss due to changed shapes: marking size idx 0 of torch::autograd::GraphRoot (NodeCall 0) as dynamic
-   Cache miss due to changed shapes: marking size idx 1 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
-   Cache miss due to changed shapes: marking size idx 2 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
-   Cache miss due to changed shapes: marking size idx 3 of torch::autograd::AccumulateGrad (NodeCall 2) as dynamic
-   ...
-   """
+.. image:: ../_static/img/compiled_autograd/recompile_due_to_dynamic.png
 
 Conclusion
 ----------
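
Since the screenshots above replace the literal ``stderr_output`` listings, reviewers who want to regenerate the verbose logs themselves can use the sketch below. It is a minimal example assembled from the tutorial's own snippets (the model, the tensor hook, and the ``compiled_autograd`` config flag), and it assumes a PyTorch build recent enough to ship ``torch._dynamo.config.compiled_autograd``. Run it with the ``TORCH_LOGS="compiled_autograd_verbose"`` environment variable set; the captured graph should be printed to ``stderr`` on the first backward, with the registered hook appearing as a ``call_hook`` node as in ``call_hook_node.png``.

.. code:: python

   # Minimal sketch; run with TORCH_LOGS="compiled_autograd_verbose" set in the
   # environment to print the captured Compiled Autograd graph to stderr.
   import torch

   model = torch.nn.Linear(10, 10)
   x = torch.randn(10, requires_grad=True)
   x.register_hook(lambda grad: grad + 10)  # surfaces as a ``call_hook`` node in the graph

   # Capture the backward pass with Compiled Autograd instead of only the forward.
   torch._dynamo.config.compiled_autograd = True

   @torch.compile(backend="aot_eager")
   def train(model, x):
       loss = model(x).sum()
       loss.backward()

   train(model, x)  # the verbose log is emitted when this first backward runs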