
Unflatten traced module #954

Merged: 28 commits merged into main on Apr 1, 2024

Conversation

kwen2501 (Contributor) commented on Feb 28, 2024

Description

  • Move tracer from _export_to_torch_ir to the official torch.export.export
  • Add unflatten utils (from torch/export/unflatten.py) to unflatten each stage module

The purpose of this PR is to:

  • be composable with FSDP and TP, which require structured FQNs like a.b.c for submodules in order to specify per-submodule policies;
  • be friendly to DCP, which expects FQNs to remain unchanged from the original model;
  • retire the use of _export_to_torch_ir, per the Export Team's plan.
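
For context, a minimal sketch of the export-then-unflatten flow this PR moves to, using the upstream torch.export APIs that the vendored utils mirror (the Tiny module and shapes are illustrative, not from the PR):

import torch
from torch.export import export, unflatten

# Toy module for illustration only (not part of the PR).
class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)

ep = export(Tiny(), (torch.randn(2, 4),))  # official exporter, replacing _export_to_torch_ir
um = unflatten(ep)                         # rebuilds the original module hierarchy
print(um)                                  # the submodule keeps its FQN, e.g. "lin"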

Test

Added test_transformer.py.

class TransformerLike(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            *[
                MLPModule(d_hid)
                for _ in range(n_layers)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
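
For reference, MLPModule, d_hid, and n_layers come from the test helpers. Judging from the (net1)/(relu)/(net2) submodules in the stage printouts below, MLPModule looks roughly like this sketch (the actual definition lives in the test utilities):

class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # net1 -> relu -> net2, matching the submodule names printed below
        return self.net2(self.relu(self.net1(x)))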

We split the model into two stages. Each stage preserves the layers.<i> structure of the original model:

Stage 0: 
 GraphModule(
  (layers): InterpreterModule(
    (0): InterpreterModule(
      (net1): InterpreterModule()
      (relu): InterpreterModule()
      (net2): InterpreterModule()
    )
    (1): InterpreterModule(
      (net1): InterpreterModule()
      (relu): InterpreterModule()
      (net2): InterpreterModule()
    )
    (2): InterpreterModule(
      (net1): InterpreterModule()
      (relu): InterpreterModule()
      (net2): InterpreterModule()
    )
    (3): InterpreterModule(
      (net1): InterpreterModule()
      (relu): InterpreterModule()
      (net2): InterpreterModule()
    )
  )
)
Stage 1: 
 GraphModule(
  (layers): InterpreterModule(
    (4): InterpreterModule(
      (net1): InterpreterModule()
      (relu): InterpreterModule()
      (net2): InterpreterModule()
    )
    (5): InterpreterModule(
      (net1): InterpreterModule()
      (relu): InterpreterModule()
      (net2): InterpreterModule()
    )
    (6): InterpreterModule(
      (net1): InterpreterModule()
      (relu): InterpreterModule()
      (net2): InterpreterModule()
    )
    (7): InterpreterModule(
      (net1): InterpreterModule()
      (relu): InterpreterModule()
      (net2): InterpreterModule()
    )
  )
)
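
Because the hierarchy is preserved, frameworks that key off FQNs (FSDP, TP, DCP) can address submodules of a stage exactly as in the original model. A quick hypothetical check (stage0_mod stands for the Stage 0 module printed above):

stage0_fqns = {name for name, _ in stage0_mod.named_modules()}
assert "layers.0.net1" in stage0_fqns  # structured FQN survives the split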

Caveat:
I temporarily disabled multi-use parameter support (a.k.a. shared or tied parameters), so some real examples may break. I will add the support back in the next PR.
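
To illustrate what tied parameters means here, a hypothetical module in which one tensor serves two FQNs (illustrative only, not from the PR):

import torch

class TiedEmbedding(torch.nn.Module):
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, d_model)
        self.proj = torch.nn.Linear(d_model, vocab_size, bias=False)
        self.proj.weight = self.emb.weight  # one parameter, two FQNs

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.proj(self.emb(tokens))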

@@ -23,19 +23,21 @@
class ExampleCode(torch.nn.Module):
wconstab (Contributor) commented:
Can you add another test case inspired by torchtrain, where the module has a ModuleList like this?

class Transformer(torch.nn.Module):
    def __init__(self, model_args):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for layer_id in range(model_args.n_layers):
            self.layers.append(TransformerBlock(layer_id, model_args))

What I am wondering is: how will you keep the FQNs the same after splitting?

If you put layers[0] on one stage and layers[1] on the next stage, will the second stage still have layers[1] as its FQN, or will it drop to layers[0]?

kwen2501 (Contributor, Author) replied on Mar 4, 2024:

Thanks, good idea. The second stage will have "layers.1" as the FQN: "layers" comes from the self.layers = torch.nn.ModuleList() tier, and "1" corresponds to the attribute within the ModuleList. Both are preserved from the original model.
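
A quick way to verify (stage1_mod is a hypothetical handle to the second stage; get_submodule is standard torch.nn.Module API):

block = stage1_mod.get_submodule("layers.1")  # resolves because the FQN is preserved
print(list(stage1_mod.state_dict().keys()))   # keys still begin with "layers.1."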

kwen2501 (Contributor, Author) replied:

@wconstab I added test_transformer.py.
You can see the updated PR description for the split structure.



# Add an alias for convenience
aten_pipe_split_alias = torch.ops.pippy._pipe_split.default
A Member commented:
Curious, what's the purpose of the additional .default?

kwen2501 (Contributor, Author) replied:

@tugsbayasgalan wondering if you know the answer?
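
For background, a likely answer: torch.ops.pippy._pipe_split is an OpOverloadPacket, while .default selects the concrete default OpOverload, which is what graphs produced by torch.export record as call_function targets. A small demonstration with a core aten op:

import torch

packet = torch.ops.aten.add            # OpOverloadPacket: picks an overload at call time
overload = torch.ops.aten.add.default  # OpOverload: one concrete schema
print(type(packet))                    # <class 'torch._ops.OpOverloadPacket'>
print(type(overload))                  # <class 'torch._ops.OpOverload'>
# Aliasing .default lets split-point logic compare node targets directly.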

kwen2501 (Contributor, Author) commented on Apr 1, 2024:

Status of this branch:

Good:

  • All unit tests passed.
  • The GPT2 and LLaMA models worked.

Bad:

I consider the "Bad" items to be non-blocking, and since this branch is needed by torchtrain (pytorch/torchtitan#161), I am merging it into main as is.

kwen2501 merged commit 77be55d into main on Apr 1, 2024
6 checks passed