diff --git a/.github/workflows/build-and-publish.yml b/.github/workflows/build-and-publish.yml
index 307ade0e..9592fcfb 100644
--- a/.github/workflows/build-and-publish.yml
+++ b/.github/workflows/build-and-publish.yml
@@ -15,6 +15,7 @@ jobs:
- "accelerated-peft"
- "fused-ops-and-kernels"
- "attention-and-distributed-packing"
+ - "accelerated-moe"
permissions:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index d2f9aea6..8f25a613 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -30,6 +30,7 @@ jobs:
- "accelerated-peft"
- "fused-ops-and-kernels"
- "attention-and-distributed-packing"
+ - "accelerated-moe"
steps:
- uses: actions/checkout@v4
diff --git a/README.md b/README.md
index 1158550c..3bdda440 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@ Plugin | Description | Depends | License | Status
[accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface<br>AutoGPTQ | Apache 2.0<br>MIT | Alpha
[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta
[attention-and-distributed-packing](./plugins/attention-and-distributed-packing/README.md) | Padding-Free Flash Attention Computation | flash-attn | Apache 2.0 | Beta
- MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton Kernels and acclerations for Mixture-of-Expert models | | Apache 2.0 | Coming Soon
+[accelerated-moe](./plugins/accelerated-moe/README.md) | Triton kernels and expert parallelism for Mixture-of-Experts models, inspired by [ScatterMoE](https://github.com/shawntan/scattermoe) and [MegaBlocks](https://github.com/databricks/megablocks) | | Apache 2.0 | Beta
## Usage with FMS HF Tuning
diff --git a/plugins/accelerated-moe/.isort.cfg b/plugins/accelerated-moe/.isort.cfg
new file mode 100644
index 00000000..7d3762ec
--- /dev/null
+++ b/plugins/accelerated-moe/.isort.cfg
@@ -0,0 +1,10 @@
+[settings]
+profile=black
+from_first=true
+import_heading_future=Future
+import_heading_stdlib=Standard
+import_heading_thirdparty=Third Party
+import_heading_firstparty=First Party
+import_heading_localfolder=Local
+known_firstparty=
+known_localfolder=tuning
\ No newline at end of file
diff --git a/plugins/accelerated-moe/.pylintrc b/plugins/accelerated-moe/.pylintrc
new file mode 100644
index 00000000..32b5dc66
--- /dev/null
+++ b/plugins/accelerated-moe/.pylintrc
@@ -0,0 +1,649 @@
+[MAIN]
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Clear in-memory caches upon conclusion of linting. Useful if running pylint
+# in a server-like mode.
+clear-cache-post-run=no
+
+# Load and enable all available extensions. Use --list-extensions to see a list
+# all available extensions.
+#enable-all-extensions=
+
+# In error mode, messages with a category besides ERROR or FATAL are
+# suppressed, and no reports are done by default. Error mode is compatible with
+# disabling specific errors.
+#errors-only=
+
+# Always return a 0 (non-error) status code, even if lint errors are found.
+# This is primarily useful in continuous integration scripts.
+#exit-zero=
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code.
+extension-pkg-allow-list=
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
+# for backward compatibility.)
+extension-pkg-whitelist=
+
+# Return non-zero exit code if any of these messages/categories are detected,
+# even if score is above --fail-under value. Syntax same as enable. Messages
+# specified are enabled, while categories only check already-enabled messages.
+fail-on=
+
+# Specify a score threshold under which the program will exit with error.
+fail-under=10
+
+# Interpret the stdin as a python script, whose filename needs to be passed as
+# the module_or_package argument.
+#from-stdin=
+
+# Files or directories to be skipped. They should be base names, not paths.
+ignore=CVS,protobufs
+
+# Add files or directories matching the regular expressions patterns to the
+# ignore-list. The regex matches against paths and can be in Posix or Windows
+# format. Because '\\' represents the directory delimiter on Windows systems,
+# it can't be used as an escape character.
+ignore-paths=.*megablocks
+
+# Files or directories matching the regular expression patterns are skipped.
+# The regex matches against base names, not paths. The default value ignores
+# Emacs file locks
+ignore-patterns=^\.#
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis). It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
+# number of processors available to use, and will cap the count on Windows to
+# avoid hangs.
+jobs=1
+
+# Control the amount of potential inferred values when inferring a single
+# object. This can help the performance when dealing with large functions or
+# complex, nested conditions.
+limit-inference-results=100
+
+# List of plugins (as comma separated values of python module names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# Minimum Python version to use for version dependent checks. Will default to
+# the version used to run pylint.
+py-version=3.9
+
+# Discover python modules and packages in the file system subtree.
+recursive=no
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages.
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+# In verbose mode, extra non-checker-related info will be displayed.
+#verbose=
+
+
+[BASIC]
+
+# Naming style matching correct argument names.
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style. If left empty, argument names will be checked with the set
+# naming style.
+#argument-rgx=
+
+# Naming style matching correct attribute names.
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style. If left empty, attribute names will be checked with the set naming
+# style.
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma.
+bad-names=foo,
+ bar,
+ baz,
+ toto,
+ tutu,
+ tata
+
+# Bad variable names regexes, separated by a comma. If names match any regex,
+# they will always be refused
+bad-names-rgxs=
+
+# Naming style matching correct class attribute names.
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style. If left empty, class attribute names will be checked
+# with the set naming style.
+#class-attribute-rgx=
+
+# Naming style matching correct class constant names.
+class-const-naming-style=UPPER_CASE
+
+# Regular expression matching correct class constant names. Overrides class-
+# const-naming-style. If left empty, class constant names will be checked with
+# the set naming style.
+#class-const-rgx=
+
+# Naming style matching correct class names.
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-
+# style. If left empty, class names will be checked with the set naming style.
+#class-rgx=
+
+# Naming style matching correct constant names.
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style. If left empty, constant names will be checked with the set naming
+# style.
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names.
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style. If left empty, function names will be checked with the set
+# naming style.
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma.
+good-names=i,
+ j,
+ k,
+ ex,
+ Run,
+ _
+
+# Good variable names regexes, separated by a comma. If names match any regex,
+# they will always be accepted
+good-names-rgxs=
+
+# Include a hint for the correct naming format with invalid-name.
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names.
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style. If left empty, inline iteration names will be checked
+# with the set naming style.
+#inlinevar-rgx=
+
+# Naming style matching correct method names.
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style. If left empty, method names will be checked with the set naming style.
+#method-rgx=
+
+# Naming style matching correct module names.
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style. If left empty, module names will be checked with the set naming style.
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+# These decorators are taken in consideration only for invalid-name.
+property-classes=abc.abstractproperty
+
+# Regular expression matching correct type variable names. If left empty, type
+# variable names will be checked with the set naming style.
+#typevar-rgx=
+
+# Naming style matching correct variable names.
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style. If left empty, variable names will be checked with the set
+# naming style.
+#variable-rgx=
+
+
+[CLASSES]
+
+# Warn about protected attribute access inside special methods
+check-protected-access-in-special-methods=no
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+ __new__,
+ setUp,
+ __post_init__
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+ _fields,
+ _replace,
+ _source,
+ _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[DESIGN]
+
+# List of regular expressions of class ancestor names to ignore when counting
+# public methods (see R0903)
+exclude-too-few-public-methods=
+
+# List of qualified class names to ignore when counting class parents (see
+# R0901)
+ignored-parents=
+
+# Maximum number of arguments for function / method.
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in an if statement (see R0916).
+max-bool-expr=5
+
+# Maximum number of branch for function / method body.
+max-branches=12
+
+# Maximum number of locals for function / method body.
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body.
+max-returns=6
+
+# Maximum number of statements in function / method body.
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when caught.
+overgeneral-exceptions=builtins.BaseException,builtins.Exception
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
+# tab).
+indent-string=' '
+
+# Maximum number of characters on a single line.
+max-line-length=100
+
+# Maximum number of lines in a module.
+max-module-lines=1100
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[IMPORTS]
+
+# List of modules that can be imported at any level, not just the top level
+# one.
+allow-any-import-level=
+
+# Allow explicit reexports by alias from a package __init__.
+allow-reexport-from-package=no
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Deprecated modules which should not be used, separated by a comma.
+deprecated-modules=
+
+# Output a graph (.gv or any supported image format) of external dependencies
+# to the given file (report RP0402 must not be disabled).
+ext-import-graph=
+
+# Output a graph (.gv or any supported image format) of all (i.e. internal and
+# external) dependencies to the given file (report RP0402 must not be
+# disabled).
+import-graph=
+
+# Output a graph (.gv or any supported image format) of internal dependencies
+# to the given file (report RP0402 must not be disabled).
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+# Couples of modules and preferred modules, separated by a comma.
+preferred-modules=
+
+
+[LOGGING]
+
+# The type of string formatting that logging methods do. `old` means using %
+# formatting, `new` is for `{}` formatting.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
+# UNDEFINED.
+confidence=HIGH,
+ CONTROL_FLOW,
+ INFERENCE,
+ INFERENCE_FAILURE,
+ UNDEFINED
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then re-enable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W".
+disable=raw-checker-failed,
+ bad-inline-option,
+ locally-disabled,
+ file-ignored,
+ suppressed-message,
+ useless-suppression,
+ deprecated-pragma,
+ # Added messages
+ use-symbolic-message-instead,
+ invalid-name,
+ missing-class-docstring,
+ missing-module-docstring,
+ missing-function-docstring,
+ consider-using-f-string,
+ inconsistent-return-statements,
+ no-member,
+ too-many-arguments,
+ too-many-locals,
+ too-many-branches,
+ too-many-statements,
+ cyclic-import,
+ too-few-public-methods,
+ protected-access,
+ fixme,
+ logging-format-interpolation,
+ logging-too-many-args,
+ attribute-defined-outside-init,
+ abstract-method,
+ pointless-statement,
+ wrong-import-order,
+ duplicate-code,
+ unbalanced-tuple-unpacking,
+ unused-argument
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=c-extension-no-member
+
+
+[METHOD_ARGS]
+
+# List of qualified names (i.e., library.method) which require a timeout
+# parameter e.g. 'requests.api.get,requests.api.post'
+timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+ XXX,
+ TODO
+
+# Regular expression of note tags to take in consideration.
+notes-rgx=
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=sys.exit,argparse.parse_error
+
+
+[REPORTS]
+
+# Python expression which should return a score less than or equal to 10. You
+# have access to the variables 'fatal', 'error', 'warning', 'refactor',
+# 'convention', and 'info' which contain the number of messages in each
+# category, as well as 'statement' which is the total number of statements
+# analyzed. This score is used by the global evaluation report (RP0004).
+evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details.
+msg-template=
+
+# Set the output format. Available formats are text, parseable, colorized, json
+# and msvs (visual studio). You can also give a reporter class, e.g.
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages.
+reports=yes
+
+# Activate the evaluation score.
+score=yes
+
+
+[SIMILARITIES]
+
+# Comments are removed from the similarity computation
+ignore-comments=yes
+
+# Docstrings are removed from the similarity computation
+ignore-docstrings=yes
+
+# Imports are removed from the similarity computation
+ignore-imports=yes
+
+# Signatures are removed from the similarity computation
+ignore-signatures=yes
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes.
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it work,
+# install the 'python-enchant' package.
+spelling-dict=
+
+# List of comma separated words that should be considered directives if they
+# appear at the beginning of a comment and should not be checked.
+spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains the private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to the private dictionary (see the
+# --spelling-private-dict-file option) instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[STRING]
+
+# This flag controls whether inconsistent-quotes generates a warning when the
+# character used as a quote delimiter is used inconsistently within a module.
+check-quote-consistency=no
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=no
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=
+
+# Tells whether to warn about missing members when the owner of the attribute
+# is inferred to be None.
+ignore-none=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of symbolic message names to ignore for Mixin members.
+ignored-checks-for-mixins=no-member,
+ not-async-context-manager,
+ not-context-manager,
+ attribute-defined-outside-init
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+# Regex pattern to define which classes are considered mixins.
+mixin-class-rgx=.*[Mm]ixin
+
+# List of decorators that change the signature of a decorated function.
+signature-mutators=
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid defining new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of names allowed to shadow builtins
+allowed-redefined-builtins=
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+ _cb
+
+# A regular expression matching the name of dummy variables (i.e. expected to
+# not be used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored.
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
\ No newline at end of file
diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md
new file mode 100644
index 00000000..ff1c31b2
--- /dev/null
+++ b/plugins/accelerated-moe/README.md
@@ -0,0 +1,95 @@
+# FMS Acceleration for Mixture-of-Experts
+
+This library contains plugins to accelerate finetuning with the following optimizations:
+1. Expert-Parallel MoE with Triton Kernels from ScatterMoE, and some extracted from [megablocks](https://github.com/databricks/megablocks).
+ - Megablocks kernels for `gather` and `scatter`
+
+## Plugins
+
+Plugin | Description | Depends | Loading | Augmentation | Callbacks
+--|--|--|--|--|--
+[scattermoe](./src/fms_acceleration_moe/framework_plugin_scattermoe.py) | MoE Expert Parallel with Triton Kernels from scattermoe (& megablocks) | ScatterMoE / extracted kernels from megablocks | ✅ | | ✅
+
+
+## Adding New Models
+
+Our `ScatterMoE` implementation is a module swap; to add new models, update the specifications in [scattermoe_constants.py](./src/fms_acceleration_moe/utils/scattermoe_constants.py).
+- See the code documentation within to understand how to add new models.
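+
+As a rough, hedged illustration of what the specification provides (based on how it is consumed in [checkpoint_utils.py](./src/fms_acceleration_moe/utils/checkpoint_utils.py) within this plugin; the authoritative field list lives in `scattermoe_constants.py`, and the model name below is just an example):
+
+```
+from transformers import PretrainedConfig
+
+from fms_acceleration_moe.utils.scattermoe_constants import (
+    get_scattermoe_conv_spec_from_archs,
+)
+
+# the spec is looked up by the architectures declared in the model config
+config = PretrainedConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
+(
+    _,                    # ScatterMoE module specification
+    router_name,          # router weight name in the original checkpoint
+    expert_name,          # expert weight name(s) in the original checkpoint
+    __,
+    sharded_expert_ckpt,  # whether expert weights are stored one shard per expert
+) = get_scattermoe_conv_spec_from_archs(config.architectures)
+```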
+
+### Using ScatterMoE Saved Checkpoints
+
+`ScatterMoE` checkpoints are saved using `torch.distributed.checkpoint` (DCP), which by default uses `StateDictType.SHARDED_STATE_DICT`:
+- `DTensor`s have limited support for full state dicts.
+- sharded state dicts are extremely efficient and require little communication overhead when saving.
+
+We provide a script to recover the original checkpoint:
+- currently the script is only tested in the case where DCP has saved the model on a single node.
+
+If the checkpoint is stored in `hf/checkpoint-10`, call the following to have the converted checkpoint written into `output_dir`:
+
+```
+python -m fms_acceleration_moe.utils.checkpoint_utils \
+ hf/checkpoint-10 output_dir \
+ mistralai/Mixtral-8x7B-Instruct-v0.1
+```
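+
+The utility writes the recovered state dict with a plain `torch.save`. As a minimal, untested sketch (reusing the model name and `output_dir` path from the example above), the converted checkpoint could then be loaded back into the original architecture:
+
+```
+import torch
+from transformers import AutoModelForCausalLM
+
+# load the original architecture, then the state dict recovered by the utility
+model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
+state_dict = torch.load("output_dir")  # file written by the conversion utility above
+model.load_state_dict(state_dict)
+```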
+
+## Code Extracted from Megablocks
+
+Notes on code extraction:
+- we have only extracted two `autograd` functions [GatherOp](https://github.com/databricks/megablocks/blob/main/megablocks/ops/gather.py) and [ScatterOp](https://github.com/databricks/megablocks/blob/main/megablocks/ops/scatter.py),
+- and the associated triton kernels from [backend/kernels.py](https://github.com/databricks/megablocks/blob/main/megablocks/backend/kernels.py); mostly the `_padded_copy`.
+
+## Running Benchmarks
+
+
+Run the below in the top-level directory of this repo:
+- the `scattermoe` dep is not included by default, so the `-x` switch installs it.
+- consider disabling the `torch` memory logging to see improved speeds.
+
+```
+tox -e run-benches \
+ -x testenv:run-benches.deps+="-r plugins/accelerated-moe/requirements-khd.txt" \
+ -x testenv:run-benches.setenv+="MEMORY_LOGGING=nvidia" \
+ -- \
+ "1 2 4" 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-scatter
+```
+or run the larger `Mixtral-8x7B` bench:
+```
+tox ... \
+ 8 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-scatter-mixtral
+```
+
+NOTE: if `FileNotFoundError` is observed on the *triton cache*, similar to issues like these:
+- https://github.com/triton-lang/triton/issues/2688
+
+then `tox` is interfering with triton and multiprocessing (there appears to be a race condition).
+The workaround is to first *activate the tox env* and then
+run the script manually in `bash`:
+```
+# if FileNotFoundError in the triton cache is observed
+# - then activate the env and run the script manually
+
+source .tox/run-benches/bin/activate
+bash scripts/run_benchmarks.sh \
+ ....
+```
+
+
+### Triton Kernel Dependencies
+
+Currently we do not copy the `scattermoe` kernels into this repository, so this is an additional manual install:
+
+```
+# this will install the kernel-hyperdrive fork with the scattermoe triton kernels
+pip install -r requirements-khd.txt
+```
+
+### Known Issues
+
+These are some known issues not yet resolved:
+- the dependency on the external `kernel-hyperdrive` repository should eventually be removed.
+- currently only *sharded* `safetensor` non-GGUF MoE checkpoints are supported for loading. This is a reasonable assumption, since MoE checkpoints are typically above the size limit that prevents them from being saved into a single checkpoint file.
+- when used together with FSDP, FSDP's `clip_grad_norm` will not compute properly for `ScatterMoE`, see [issue here](https://github.com/foundation-model-stack/fms-acceleration/issues/109).
+
+
+
diff --git a/plugins/accelerated-moe/configs/scattermoe.yaml b/plugins/accelerated-moe/configs/scattermoe.yaml
new file mode 100644
index 00000000..63623694
--- /dev/null
+++ b/plugins/accelerated-moe/configs/scattermoe.yaml
@@ -0,0 +1,16 @@
+training:
+
+ # mixture-of-experts configurations
+ moe:
+
+ # expert-parallel for MoE
+ scattermoe:
+
+ # The level of expert parallel sharding.
+ # - 1 means no sharding
+ # - if > 1, please ensure that this divides the world_size. This is because
+ # the devices will be replicated for every ep_degree devices, and
+ # the experts will be sharded within each group.
+ # - if > 1, also ensure that it divides the number of experts, as each device
+ # will then have num_of_experts / ep_degree experts.
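+      # - example: with world_size 8 and 8 experts, ep_degree: 4 forms groups
+      #   of 4 devices, and each device holds 8 / 4 = 2 experts.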
+ ep_degree: 1
\ No newline at end of file
diff --git a/plugins/accelerated-moe/pyproject.toml b/plugins/accelerated-moe/pyproject.toml
new file mode 100644
index 00000000..5d522425
--- /dev/null
+++ b/plugins/accelerated-moe/pyproject.toml
@@ -0,0 +1,29 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "fms-acceleration-moe"
+version = '0.0.1'
+description = "FMS Acceleration Plugin for Mixture-of-Experts"
+authors = [
+ {name = "Fabian Lim", email = "flim@sg.ibm.com"},
+]
+license = {text = "Apache-2.0"}
+readme = "README.md"
+requires-python = "~=3.9"
+keywords = ['fms-hf-tuning', 'acceleration', 'mixture-of-experts', 'scattermoe', 'megablocks']
+classifiers=[
+ "License :: OSI Approved :: Apache Software License",
+ "Development Status :: 4 - Beta",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+]
+
+[tool.hatch.build.targets.wheel]
+only-include = ["src/fms_acceleration_moe"]
+
+[tool.hatch.build.targets.wheel.sources]
+"src" = ""
diff --git a/plugins/accelerated-moe/requirements-khd.txt b/plugins/accelerated-moe/requirements-khd.txt
new file mode 100644
index 00000000..497bf78e
--- /dev/null
+++ b/plugins/accelerated-moe/requirements-khd.txt
@@ -0,0 +1,2 @@
+# fork of https://github.com/mayank31398/kernel-hyperdrive/
+kernel-hyperdrive @ git+https://github.com/fabianlim/kernel-hyperdrive.git
\ No newline at end of file
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py
new file mode 100644
index 00000000..a1b41417
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py
@@ -0,0 +1,17 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Local
+from .framework_plugin_scattermoe import ScatterMoEAccelerationPlugin
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
new file mode 100644
index 00000000..148a5488
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
@@ -0,0 +1,123 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from typing import Dict
+
+# Third Party
+from fms_acceleration import AccelerationPlugin
+from transformers import AutoModelForCausalLM
+import torch
+
+# Local
+from .utils import (
+ patch_huggingface_save_and_load_for_dtensors,
+ patch_torch_optim_foreach_to_not_apply_to_dtensors,
+ prepare_scattermoe,
+)
+
+
+# pylint: disable=too-many-instance-attributes
+class ScatterMoEAccelerationPlugin(AccelerationPlugin):
+
+ # NOTE: we cannot do
+ # - require_packages = {"khd"}
+ # this is because the khd fork is not properly packaged as a PyPI project, and so
+ # - "importlib.util.find_spec('khd')" returns, but
+ # - "importlib.metadata.version('kernel-hyperdrive')" does not return
+    # if we decide to extract the kernels, then this will no longer be needed,
+ # https://github.com/foundation-model-stack/fms-acceleration/issues/105
+
+ restricted_model_archs = ["GraniteMoeForCausalLM", "MixtralForCausalLM"]
+
+ def __init__(self, configurations: Dict[str, Dict]):
+ super().__init__(configurations)
+
+ # ep_degree determines the expert parallel sharding
+ # - default of 1 means experts are not sharded and operate in pure replication.
+ self._ep_degree = self._check_config_and_maybe_check_values(
+ key="training.moe.scattermoe.ep_degree",
+ default=1,
+ )
+
+ @property
+ def requires_custom_loading(self):
+ return True
+
+ def model_loader(self, model_name: str, **kwargs):
+
+ # load the model
+ model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
+
+ rank, world_size = 0, 1
+ if torch.distributed.is_initialized():
+ world_size = torch.distributed.get_world_size()
+ rank = torch.distributed.get_rank()
+
+ # shard the MOE, and store the component names, eventually needed
+ # to configure the FSDP
+ self._moe_component_module_names = prepare_scattermoe(
+ model,
+ checkpoint_name_or_path=model_name,
+ rank=rank,
+ world_size=world_size,
+ ep_degree=self._ep_degree,
+ mixed_precision=False, # Currently this is hardcoded to OFF
+ )
+
+ # NOTE: there is currently no good way to get the mixed precision
+        # flag from train_args. It will be better to handle this
+        # when we move the sharding to augmentation.
+ # https://github.com/foundation-model-stack/fms-acceleration/issues/103
+
+ return model
+
+ def get_callbacks_and_ready_for_train(
+ self, model: torch.nn.Module = None, accelerator=None
+ ):
+
+ callbacks = []
+ if (
+ accelerator is not None
+ and getattr(accelerator.state, "fsdp_plugin", None) is not None
+ ):
+
+ # - use an internal function call to get the no split
+ # module names, which are typically layers
+ _layers = model._get_no_split_modules("")
+ accelerator.state.fsdp_plugin.ignored_modules = [
+ getattr(layer, name)
+ for name in self._moe_component_module_names
+ for layer in model.modules()
+ if layer.__class__.__name__ in _layers
+ ]
+
+ # call this to patch the HF save and load functions to be able
+            # to save DTensors properly
+ patch_huggingface_save_and_load_for_dtensors()
+
+ # call this to patch torch optim to not use
+ # foreach for dtensors
+ patch_torch_optim_foreach_to_not_apply_to_dtensors()
+
+ return callbacks
+
+
+# register
+AccelerationPlugin.register_plugin(
+ ScatterMoEAccelerationPlugin,
+ configuration_and_paths=[
+ "training.moe.scattermoe",
+ ],
+)
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py
new file mode 100644
index 00000000..660d2252
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py
@@ -0,0 +1,43 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Local
+from .checkpoint_utils import patch_huggingface_save_and_load_for_dtensors
+from .scattermoe_prepare import prepare_scattermoe
+
+# this is a special patch function to disable foreach for
+# dtensors, which has been introduced since torch 2.4.
+# The reason is because this will cause problems in the optimizer
+# RuntimeError: aten._foreach_mul_.Scalar: got mixed torch.Tensor and DTensor,
+# need to convert all torch.Tensor to DTensor before calling distributed operators!
+
+
+# - this function patches torch
+def patch_torch_optim_foreach_to_not_apply_to_dtensors():
+ # guarded.
+    # _foreach_supported_types is a list of supported types; we remove
+    # DTensor from it, so the optimizer will fall back to per-parameter
+    # updates for DTensors
+ # Third Party
+ # pylint: disable=import-outside-toplevel
+ from torch.optim.optimizer import _foreach_supported_types
+
+ i = 0 # list index
+ while i < len(_foreach_supported_types):
+ x = _foreach_supported_types[i]
+ if x.__name__ == "DTensor":
+ # pop from list
+ _foreach_supported_types.pop(i)
+ else:
+ i += 1
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
new file mode 100644
index 00000000..d8d33b18
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -0,0 +1,468 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from collections import defaultdict
+from typing import List
+import json
+import os
+import re
+
+# Third Party
+from accelerate.logging import get_logger
+from accelerate.utils.constants import FSDP_MODEL_NAME, OPTIMIZER_NAME
+from torch.distributed.checkpoint.default_planner import (
+ DefaultLoadPlanner,
+ DefaultSavePlanner,
+)
+from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
+from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+from transformers import PretrainedConfig
+import torch
+import torch.distributed.checkpoint as dcp
+
+# Local
+from .scattermoe_constants import (
+ FILE_SAFETENSOR_INDEX,
+ PARAM_NAME_ROUTER_SCATTERMOE,
+ PARAM_NAME_WEIGHT_SCATTERMOE,
+ get_scattermoe_conv_spec_from_archs,
+)
+from .scattermoe_state_dict import get_checkpoint_meta_from_sharded_safetensor
+
+logger = get_logger(__name__)
+
+# - variable to capture the model variable
+# in the save/load model calls
+MODEL_INDEX = None
+KEY_MODEL = "model"
+KEY_OPTIMIZER = "optimizer"
+
+# Below are rewrites of the HF FSDP model saving functions, to be able to handle
+# that the parameters are now a mixture of regular tensors and DTensors.
+# - these functions are found in accelerate.utils.fsdp_utils.py
+# - save_fsdp_model, save_fsdp_optimizer, load_fsdp_model, load_fsdp_optimizer
+# NOTE: we will observe warnings such as
+# /torch/distributed/checkpoint/state_dict.py:520:
+# FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
+
+
+# rewrite of func from accelerate.utils.fsdp_utils.py
+# - empty function, the main logic will be in save_fsdp_optimizer (see below).
+def save_fsdp_model(
+ fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False
+):
+ # pylint: disable=global-statement
+ global MODEL_INDEX
+ MODEL_INDEX = model_index
+
+
+# rewrite of func from accelerate.utils.fsdp_utils.py
+# - saves both model and optimizer at the same time
+def save_fsdp_optimizer(
+ fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0
+):
+
+ if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
+ raise NotImplementedError(
+ "Checkpointing for megablocks only enabled for sharded state dict."
+ )
+
+    # get the state dicts for model and optimizer
+ (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)
+
+ # - save model
+ ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+ os.makedirs(ckpt_model, exist_ok=True)
+ logger.info(f"Saving model to {ckpt_model}")
+ dcp.save(
+ state_dict={KEY_MODEL: model_state_dict},
+ storage_writer=dcp.FileSystemWriter(ckpt_model),
+ planner=DefaultSavePlanner(),
+ )
+ logger.info(f"Model saved to {ckpt_model}")
+
+ # - save optimizer
+ ckpt_opt = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
+ os.makedirs(ckpt_opt, exist_ok=True)
+ logger.info(f"Saving Optimizer state to {ckpt_opt}")
+ dcp.save(
+ state_dict={KEY_OPTIMIZER: optimizer_state_dict},
+ storage_writer=dcp.FileSystemWriter(ckpt_opt),
+ planner=DefaultSavePlanner(),
+ )
+ logger.info(f"Optimizer state saved in {ckpt_opt}")
+
+
+# rewrite of func from accelerate.utils.fsdp_utils.py
+# - empty function, main logic in load_fsdp_optimizer (see below).
+def load_fsdp_model(
+ fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False
+):
+ # pylint: disable=global-statement
+ global MODEL_INDEX
+ MODEL_INDEX = model_index
+
+
+# rewrite of func from accelerate.utils.fsdp_utils.py
+# - loads both model and optimizer
+def load_fsdp_optimizer(
+ fsdp_plugin,
+ accelerator,
+ optimizer,
+ model,
+ input_dir,
+ optimizer_index=0,
+ adapter_only=False,
+):
+
+ accelerator.wait_for_everyone()
+ if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT:
+ raise NotImplementedError(
+ "Checkpointing for megablocks only enabled for sharded state dict."
+ )
+
+ # - get the state dicts
+ model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
+
+ # - load the model state dict
+ ckpt_model = os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+ dcp.load(
+ state_dict={KEY_MODEL: model_state_dict},
+ storage_reader=dcp.FileSystemReader(ckpt_model),
+ planner=DefaultLoadPlanner(),
+ )
+
+ # - load the optimizer state dict
+ ckpt_opt = os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
+ dcp.load(
+ state_dict={KEY_OPTIMIZER: optimizer_state_dict},
+ storage_reader=dcp.FileSystemReader(ckpt_opt),
+ planner=DefaultLoadPlanner(),
+ )
+
+ # - set the state dicts
+ set_state_dict(
+ model,
+ optimizer,
+ model_state_dict=model_state_dict,
+ optim_state_dict=optimizer_state_dict,
+ )
+
+ # FIXME:
+ # - We see errors that occur in optimizer.step()
+ # - torch/optim/optimizer.py", line 89, in _use_grad
+ # - torch/optim/adamw.py", line 214, in step beta1,
+ # beta2 = cast(Tuple[float, float], group["betas"])
+ # - KeyError: 'betas'
+ # - Fortunately, this seems to be limited to the empty groups case, where
+ # it seems that it is just the params are not initialized. Since we suppose
+ # these groups are never used, we simply initialize the empty groups with
+ # random values so the errors do not throw.
+ for group in optimizer.param_groups:
+ if len(group["params"]) == 0:
+ group["betas"] = (0.9, 0.999)
+ group["lr"] = 0.0
+ group["initial_lr"] = 0.0
+ group["eps"] = 1e-8
+ group["weight_decay"] = 0.0
+
+
+# function to replace various trainer functions in HF with the ones
+# above
+def patch_huggingface_save_and_load_for_dtensors():
+ # Third Party
+ # NOTE: this is really a global replacement, which we use the patcher
+ # to do
+ # pylint: disable=import-outside-toplevel
+ from fms_acceleration.model_patcher import patch_target_module
+
+ patch_target_module("transformers.trainer.save_fsdp_model", save_fsdp_model)
+ patch_target_module("transformers.trainer.save_fsdp_optimizer", save_fsdp_optimizer)
+ patch_target_module("transformers.trainer.load_fsdp_model", load_fsdp_model)
+ patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer)
+
+
+# this function implements a trick to get the resolved cache file to access the safetensors
+# - NOTE: does not work if _dict_from_json_file is not called, such as in the case of GGUF files.
+def get_resolved_checkpoint_location(model_name_or_path: str):
+
+ result = None
+ _old_func = PretrainedConfig._dict_from_json_file
+
+ def _dict_from_json_file(resolved_config_file):
+ nonlocal result
+ result = resolved_config_file
+ return _old_func(resolved_config_file)
+
+    # install the hook, retrieve the resolved path, then restore the original
+ PretrainedConfig._dict_from_json_file = _dict_from_json_file
+ PretrainedConfig.from_pretrained(model_name_or_path)
+ PretrainedConfig._dict_from_json_file = _old_func
+ return os.path.dirname(result)
+
+
+# function to get the ScatterMoE state dict from its DCP checkpoint
+# - if the original pretrained_model_name_or_path is specified, will use the checkpoint as hints
+# to map the ScatterMoE checkpoint to that of the original model. This is useful so that we
+# can restore the checkpoint to be loaded by the original architecture.
+def recover_original_state_dict_from_dcp_checkpoint(
+ dcp_checkpoint_dir: str,
+ pretrained_model_name_or_path: str = None,
+):
+ """
+ Parameters:
+        dcp_checkpoint_dir (str): the DCP checkpoint to be converted.
+        pretrained_model_name_or_path (str): Optional, if provided we will
+            use hints from the original checkpoint to remap the ScatterMoE
+            state dict back to the original model's format.
+ """
+
+ # reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
+ # - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap
+
+ # guarded, load some internal functions
+ # pylint: disable=import-outside-toplevel
+ # Third Party
+ from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
+ from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+ from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
+
+ sd: STATE_DICT_TYPE = {}
+ _load_state_dict(
+ sd,
+ storage_reader=dcp.FileSystemReader(dcp_checkpoint_dir),
+ planner=_EmptyStateDictLoadPlanner(),
+ no_dist=True,
+ )
+ sd = sd[KEY_MODEL]
+
+ # if not provided
+ if pretrained_model_name_or_path is None:
+ return sd
+
+ # now do the remap
+ loc = get_resolved_checkpoint_location(pretrained_model_name_or_path)
+ with open(os.path.join(loc, FILE_SAFETENSOR_INDEX), encoding="utf-8") as f:
+ index = json.load(f)
+
+ # config
+ config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path)
+
+ (
+ _,
+ router_name,
+ expert_name,
+ __,
+ sharded_expert_ckpt,
+ ) = get_scattermoe_conv_spec_from_archs(config.architectures)
+
+ # the sd from the module swap must have keys like
+ # 'model.layers.0.block_sparse_moe.w1.weight'
+ # 'model.layers.0.block_sparse_moe.w2.weight'
+ # 'model.layers.0.block_sparse_moe.router.weight'
+ # so we use this fact to infer that
+ # prefix = model.layers.0 and module_name = block_sparse_moe
+
+ def _infer_prefixes_and_module_names(
+ sd_keys: List[str],
+ min_count: int = 3,
+ ):
+ _name = "|".join([PARAM_NAME_ROUTER_SCATTERMOE, *PARAM_NAME_WEIGHT_SCATTERMOE])
+ # pylint: disable=anomalous-backslash-in-string
+ _reg = re.compile(f"(.*)\.({_name})\.weight")
+ found = {}
+
+ for k in sd_keys:
+ m = _reg.match(k)
+ if m is None:
+ continue
+
+ prefix, _ = m.groups()
+ found[prefix] = 1 + found.get(prefix, 0)
+
+ results = []
+ for prefix, cnt in found.items():
+ # if at least router, w1 and w2 are found, take it
+ # otherwise we delete
+ if cnt >= min_count:
+ results.append(prefix)
+
+ return results
+
+ for prefix in _infer_prefixes_and_module_names(sd.keys()):
+ prefix = prefix.split(".")
+ prefix, module_name = ".".join(prefix[:-1]), prefix[-1]
+
+        # checkpoint metadata will be a map
+        # key -> list of tuples
+        # where each entry in the list is (param_name, safetensor_file)
+ # - if the list is larger than one, it means that the
+ # actual model has a sharded checkpoint
+
+ # defaultdict(list,
+ # {'w1.weight': [('model.layers.0.block_sparse_moe.input_linear.weight',
+ # 'model-00001-of-00002.safetensors')],
+ # 'w3.weight': [('model.layers.0.block_sparse_moe.input_linear.weight',
+ # 'model-00001-of-00002.safetensors')],
+ # 'w2.weight': [('model.layers.0.block_sparse_moe.output_linear.weight',
+ # 'model-00001-of-00002.safetensors')],
+ # 'router.weight': [('model.layers.0.block_sparse_moe.router.layer.weight',
+ # 'model-00001-of-00002.safetensors')]})
+
+ checkpoint_metadata = get_checkpoint_meta_from_sharded_safetensor(
+ index["weight_map"],
+ prefix,
+ module_name,
+ router_name,
+ expert_name,
+ )
+
+ model2scatter = defaultdict(dict)
+        # construct a map of model_key -> {scatter_key: param}
+        # - if there is more than one scatter key for a model key, the scatter
+        #   params need to be cat back into the single model param
+ for scatter_key, list_of_params in checkpoint_metadata.items():
+ scatter_key_fqdn = ".".join([prefix, module_name, scatter_key])
+ scatter_param = sd[scatter_key_fqdn]
+
+ # remove from state dict
+ del sd[scatter_key_fqdn]
+
+ n = len(list_of_params)
+ if scatter_key.startswith(PARAM_NAME_ROUTER_SCATTERMOE):
+ assert n == 1, "Router parameters should not be sharded."
+ elif not sharded_expert_ckpt:
+ assert n == 1, "Expert weights expected to be non-sharded."
+ else:
+ # if sharded, we just assume that there should be 1 expert
+ # per shard
+ assert (
+ n == scatter_param.shape[0]
+ ), "Sharded expert weights should be 1 expert per shard."
+
+ if any(scatter_key.startswith(k) for k in PARAM_NAME_WEIGHT_SCATTERMOE):
+ scatter_param = scatter_param.permute(0, 2, 1)
+
+ # go through all the model keys
+
+ for i, (model_key, _) in enumerate(list_of_params):
+ if n == 1:
+ # handles routers and non-sharded experts case
+ _param = scatter_param
+ else:
+ # then it needs to be sharded
+ _param = scatter_param[i]
+
+ model2scatter[model_key][scatter_key] = _param
+
+ # replace them back in the sd
+ for model_key in list(model2scatter.keys()):
+
+ scatter_params = model2scatter[model_key]
+
+            # - there is an assumption that if there is a cat, then
+            #   it will go in order of the scatter keys
+ scatter_keys = sorted(scatter_params.keys())
+
+ assert (
+ len(scatter_keys) > 0
+ ), f"Obtained zero scatter keys for model_key '{model_key}'"
+
+ if len(scatter_keys) == 1:
+ sd[model_key] = scatter_params[scatter_keys[0]]
+ else:
+                # unfortunately, there is a split in
+                # scattermoe_state_dict._maybe_reshape_scattermoe_expert_weights
+                # on dim=1, so we cat back on that dim
+ sd[model_key] = torch.cat(
+ [scatter_params[k] for k in scatter_keys], dim=1
+ )
+
+            # remove from this intermediate mapping
+ del model2scatter[model_key]
+
+ rem_keys = ",".join(list(model2scatter))
+ assert len(rem_keys) == 0, f"Did not handle model parameters '{rem_keys}'"
+
+ return sd
+
+
+# --------------------------- SCRIPT -------------------------
+
+
+# have it serve as a conversion script
+if __name__ == "__main__":
+ # Standard
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description=(
+ "Utility for converting ScatterMoE checkpoint back to the "
+ "orginal state dict format. "
+ "The ScatterMoE checkpoint was saved after the pretrained model "
+ "had been converted by a module swap, hence the state dict will "
+ "no longer resemble the original. This utility creaes"
+ )
+ )
+
+ parser.add_argument(
+ "dcp_checkpoint_dir",
+ help="Path to the distributed checkpoint.",
+ )
+
+ parser.add_argument(
+ "output_dir", help="Path to the location to write the converted checkpoint."
+ )
+
+ parser.add_argument(
+ "pretrained_model_name_or_path",
+ help=(
+ "In order to reconstruct the state dict, we requre hints from "
+ "the original pretrained model checkpoint (from which this "
+ "checkpoint is obtained)."
+ ),
+ )
+
+ args = parser.parse_args()
+
+    # search for the checkpoint. By the code above, it must
+ # start with FSDP_MODEL_NAME
+ if args.dcp_checkpoint_dir.startswith(FSDP_MODEL_NAME):
+ checkpoint_dir = args.dcp_checkpoint_dir
+ else:
+ checkpoint_dir = [
+ x
+ for x in os.listdir(args.dcp_checkpoint_dir)
+ if os.path.isdir(os.path.join(args.dcp_checkpoint_dir, x))
+ and x.startswith(FSDP_MODEL_NAME)
+ ]
+ if len(checkpoint_dir) > 1:
+ raise ValueError(
+ f"Found > 1 dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} "
+ f"that starts with {FSDP_MODEL_NAME}. Please spectify the exact dir."
+ )
+ if len(checkpoint_dir) == 0:
+ raise ValueError(
+ f"Found no dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} "
+ f"that starts with {FSDP_MODEL_NAME}. Nothing to convert"
+ )
+ checkpoint_dir = os.path.join(args.dcp_checkpoint_dir, checkpoint_dir[0])
+
+ # get the converted statedict
+ state_dict = recover_original_state_dict_from_dcp_checkpoint(
+ checkpoint_dir, args.pretrained_model_name_or_path
+ )
+
+ # save it
+ torch.save(state_dict, args.output_dir)
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py
new file mode 100644
index 00000000..152bee1d
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py
@@ -0,0 +1,507 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from typing import Tuple
+
+# Third Party
+from peft import LoraConfig
+from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
+from torch.distributed._tensor import DTensor
+
+# pylint: disable=import-error
+from torch.distributed._tensor.device_mesh import DeviceMesh
+from transformers.activations import ACT2FN
+import torch
+import torch.nn.functional as F
+
+try:
+ # Third Party
+ from khd.kernels.scattermoe.triton_implementation.ops import (
+ padded_block_indices,
+ scattered_experts,
+ )
+except ImportError as e:
+ raise ImportError(
+ "kernel-hyperdrive PyPI package not found. Install it: "
+ "pip install -r plugins/accelerated-moe/requirements-scattermoe.txt"
+ ) from e
+
+# Local
+from .scattermoe_constants import SCATTERMOE_SPEC_HAS_GATE
+from .scattermoe_utils import all_to_all_gather_inputs, scatter_with_routing_weights
+
+
+# helper function to fetch the local tensor if its a dtensor
+def _maybe_get_local_tensor(weight: torch.Tensor):
+ if isinstance(weight, DTensor):
+ return weight.to_local()
+ return weight
+
+
+class ScatteredExperts(torch.nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ num_experts: int,
+ fan_out: int,
+ grouped_in: bool = False,
+ grouped_out: bool = False,
+ dtype: torch.dtype = torch.bfloat16,
+ device: torch.device = torch.device("cpu"),
+ lora_config: LoraConfig = None,
+ ):
+ """
+ ScatteredExperts is the module that implements a group of experts. The
+ forward function will call scattermoe triton kernels.
+
+ NOTE: in the current implementation, we do not initialize the parameters.
+ We assume this will be done outside.
+
+        Parameters:
+ in_features (int): num of input features per expert.
+ out_features (int): num of output features per expert.
+ num_experts (int): the number of experts.
+            fan_out (int): the factor by which the number of embedding inputs is
+                expected to be smaller than the bin_ids and indices at the forward.
+ grouped_in (bool): if the embedding inputs are expected to be already
+ grouped in at the forward.
+ grouped_out (bool): if the outputs are expected to be grouped
+ when they are returned from the forward.
+ dtype (torch.dtype): the dtype of the parameter tensors. Note, for now the
+ adapter takes the same dtype as base layer if LoRA is enabled.
+ device (torch.device): the cuda device in which the model should be loaded.
+ Only cuda is supported since only triton kernels are supported.
+ lora_config (peft.LoraConfig): Optional, to be passed if lora is to be used.
+ """
+ super().__init__()
+
+ # parameters of the experts (not initialized).
+ self.weight = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ in_features,
+ out_features,
+ dtype=dtype,
+ device=device,
+ ),
+ requires_grad=True,
+ )
+
+ # handle lora embeddings
+ self.lora_A, self.lora_B = None, None
+ self.lora_r = 0
+ if lora_config is not None:
+ # if LoRA, then disable gradient for the base layer.
+ self.weight.requires_grad = False
+
+ # NOTE: for now the adapter takes the same dtype as the base layer
+ self.lora_A = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ in_features,
+ lora_config.r,
+ dtype=dtype,
+ device=device,
+ ),
+ requires_grad=True,
+ )
+ self.lora_B = torch.nn.Parameter(
+ torch.empty(
+ num_experts,
+ lora_config.r,
+ out_features,
+ dtype=dtype,
+ device=device,
+ ),
+ requires_grad=True,
+ )
+ self.lora_r = lora_config.r
+
+ # store these settings
+ self.fan_out = fan_out
+ self.grouped_in = grouped_in
+ self.grouped_out = grouped_out
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ bin_ids: torch.Tensor,
+ indices: torch.Tensor,
+ padded_block_idxs: torch.Tensor,
+ expert_offsets: torch.Tensor,
+ gates: torch.Tensor = None,
+ ):
+ """
+ ScatteredExperts executes grouped forwards where each group is a single expert.
+
+ Parameters:
+ x (tensor): the embeddings fed as input.
+ bin_ids (tensor): the expert index where each embedding is to be sent.
+ Expect that these indices are sorted.
+ indices (tensor): the sorting index that brings the input embeddings to the
+ sorted order corresponding to bin_ids.
+ padded_block_idxs (tensor): the indices for passing triton block info to the
+ scattermoe kernels.
+ expert_offsets (tensor): the offsets for passing triton block info to the
+ scattermoe kernels.
+ gates (tensor): Optional. The weighting coefficients that should be applied
+ at the output of the scattermoe kernels.
+ """
+ weight = _maybe_get_local_tensor(self.weight)
+ lora_A, lora_B = None, None
+ if self.lora_r > 0:
+ lora_A, lora_B = (
+ _maybe_get_local_tensor(self.lora_A),
+ _maybe_get_local_tensor(self.lora_B),
+ )
+
+ # NOTE: x is of shape (seqlen, in_features)
+ # bin_ids is of shape (seqlen,)
+ # padded_block_idxs is a 1-dim tensor, whose length depends on
+ # triton kernel block size and input.
+ # expert_offsets is of shape (num_experts, )
+ return scattered_experts(
+ x,
+ weight,
+ self.fan_out,
+ bin_ids, # sorted_expert_idxs,
+ indices, # sorted_scattered_idxs,
+ padded_block_idxs,
+ expert_offsets,
+ gates=gates, # optional routing weights to apply at the kernel output
+ grouped_in=self.grouped_in,
+ grouped_out=self.grouped_out,
+ expert_lora_A=lora_A,
+ expert_lora_B=lora_B,
+ lora_alp=self.lora_r,
+ )
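+
+ # Shape sketch (illustrative, not part of the implementation): assuming
+ # top_k=2 and no all-to-all so fan_out=2, an input x of shape
+ # (N, in_features) is paired with bin_ids and indices of length 2*N;
+ # each input row is fanned out to its 2 assigned experts, producing
+ # one output row per (token, expert) pair.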
+
+
+# NOTE: this name should match scattermoe_constants.CLASS_NAME_SCATTERMOE
+# similar to MoE_Triton from https://github.com/mayank31398/kernel-hyperdrive
+# and ParameterizedScatteredExperts from
+# https://github.com/IBM/dolomite-engine/blob/main/dolomite_engine/hf_models/models/moe_dolomite/moe/scatter.py
+# - support expert parallel where the data is communicated via all_to_all
+# pylint: disable=too-many-instance-attributes
+class ScatterMoE(torch.nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ hidden_act: str,
+ intermediate_size: int,
+ num_experts: int,
+ has_bias: bool = False,
+ mlp_arch: str = None,
+ top_k: int = 2,
+ dtype: torch.dtype = torch.bfloat16,
+ device: torch.device = torch.device("cpu"),
+ ep_device_mesh: DeviceMesh = None,
+ lora_config: LoraConfig = None,
+ ):
+ """
+ ScatterMoE is the module swap that replaces a sparse mixture-of-experts module
+ in order to run the scatter moe kernels and the all_to_all expert parallel routines.
+
+ The submodules are i) a router (nn.Linear) and ii) w1, w2, ... (ScatteredExperts);
+ the latter hold the expert weights and run the triton kernels.
+
+ Parameters:
+
+ hidden_size (int): the hidden dimension.
+ hidden_act (str): the activation function.
+ intermediate_size (int): the intermediate dimension.
+ num_experts (int): the number of experts.
+ has_bias (bool): if the router and experts have bias.
+ mlp_arch (str): unique key that specifies the MLP architecture,
+ e.g., whether there is a gate projection.
+ top_k (int): the number of experts each token gets routed to.
+ dtype (torch.dtype): the dtype of the parameter tensors.
+ device (torch.device): the cuda device in which the model should be loaded.
+ Only cuda is supported since only triton kernels are supported.
+ ep_device_mesh (torch.distributed.DeviceMesh): Optional, to be passed if there is
+ sharding. Only pass the mesh for the experts.
+ lora_config (peft.LoraConfig): Optional, to be passed if lora is to be used.
+ """
+ assert (
+ not has_bias
+ ), "ScatterMoE currently unable to handle bias in both gates and experts."
+
+ if lora_config is not None:
+ # since this is self-implemented, we only support basic LoRA functionality
+ assert (
+ lora_config.bias == "none"
+ ), "ScatterMoE currently unable to handle bias in the lora adapters"
+ assert (
+ lora_config.target_modules == INCLUDE_LINEAR_LAYERS_SHORTHAND
+ or INCLUDE_LINEAR_LAYERS_SHORTHAND in lora_config.target_modules
+ ), "ScatterMoe currently only handles lora adapters on all linears."
+
+ assert lora_config.init_lora_weights in {
+ True,
+ "gaussian",
+ }, "ScatterMoe currently only handles gaussian initialization."
+
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_experts = num_experts
+ self.hidden_act = hidden_act
+ self.activation = ACT2FN[hidden_act]
+ self.top_k = top_k
+ self.all_to_all = (
+ ep_device_mesh.size() > 1 if ep_device_mesh is not None else False
+ )
+
+ # NOTE: we should then use this to distribute inside
+ # and not do the distribution outside
+ self.expert_parallel_group = (
+ ep_device_mesh.get_group(0) if ep_device_mesh is not None else None
+ )
+
+ # build the router
+ self.router = torch.nn.Linear(
+ in_features=hidden_size,
+ out_features=num_experts,
+ bias=has_bias,
+ dtype=dtype,
+ device=device,
+ )
+
+ # the experts. The architecture may depend on the model
+ # - w1: the up_projection.
+ # - w2: the down_projection.
+ # - w3 (optional): the gate projection.
+ self.w1 = ScatteredExperts(
+ in_features=self.hidden_size,
+ out_features=self.intermediate_size,
+ num_experts=self.num_experts,
+ fan_out=self.top_k if not self.all_to_all else 1,
+ grouped_out=True,
+ dtype=dtype,
+ device=device,
+ lora_config=lora_config,
+ )
+ self.w2 = ScatteredExperts(
+ in_features=self.intermediate_size,
+ out_features=self.hidden_size,
+ num_experts=self.num_experts,
+ fan_out=1,
+ grouped_in=True,
+ dtype=dtype,
+ device=device,
+ lora_config=lora_config,
+ )
+ if mlp_arch == SCATTERMOE_SPEC_HAS_GATE:
+ self.w3 = ScatteredExperts(
+ in_features=self.hidden_size,
+ out_features=self.intermediate_size,
+ num_experts=self.num_experts,
+ fan_out=self.top_k if not self.all_to_all else 1,
+ grouped_out=True,
+ dtype=dtype,
+ device=device,
+ lora_config=lora_config,
+ )
+
+ # referenced from dolomite-engine
+ def _compute_routing_weights(self, hidden_states: torch.Tensor):
+
+ # router_logits: (batch * sequence_length, n_experts)
+ weight = _maybe_get_local_tensor(self.router.weight)
+ bias = self.router.bias
+ if bias is not None:
+ bias = _maybe_get_local_tensor(bias)
+ # pylint: disable=not-callable
+ router_logits = F.linear(hidden_states, weight, bias)
+
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(
+ routing_weights, self.top_k, dim=-1
+ )
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
+ return router_logits, routing_weights, selected_experts
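+
+ # Worked example (illustrative): with top_k=2 and router logits
+ # [0.1, 2.0, 0.5, 3.0] for one token, softmax gives roughly
+ # [0.04, 0.24, 0.05, 0.66]; topk selects experts (3, 1) with weights
+ # (0.66, 0.24), which are renormalized to about (0.73, 0.27) so they sum to 1.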
+
+ def _get_expert_idxs_and_maybe_gather(
+ self,
+ hidden_states: torch.Tensor,
+ selected_experts: torch.Tensor,
+ ):
+ """
+ gets the expert indices, and also gathers the hidden_states if
+ all-to-all processing is required.
+
+ Parameters:
+ hidden_states (tensor): 2D batch-flattened hidden states.
+ selected_experts (tensor): indices of experts selected for each
+ hidden state.
+ """
+
+ # megablocks has a cuda kernel for computing a radix sort, but
+ # just use the torch version
+ sorted_expert_idxs, sorted_scattered_idxs = torch.sort(
+ selected_experts.flatten()
+ )
+ if not self.all_to_all:
+ # in this case, no gathering required for hidden states
+ return hidden_states, sorted_expert_idxs, sorted_scattered_idxs
+
+ # outputs will:
+ # - parallel_x: gathered version of hidden_states
+ # - parallel_bin_ids: gathered version of sorted_expert_idxs,
+ # - parallel_ind: gathered version of sorted_scattered_idxs.
+ #
+ # followed by some counting metrics:
+ # - send_counts, recv_counts, bins (local)
+ outputs = all_to_all_gather_inputs(
+ hidden_states,
+ selected_experts,
+ sorted_expert_idxs,
+ sorted_scattered_idxs,
+ self.expert_parallel_group,
+ self.top_k,
+ self.num_experts,
+ )
+
+ return outputs + (sorted_expert_idxs, sorted_scattered_idxs)
+
+ def _maybe_scatter(
+ self,
+ hidden_states: torch.Tensor,
+ routing_weights: torch.Tensor = None,
+ gather_products: Tuple[torch.Tensor, ...] = None,
+ ):
+ """
+ maybe scatter the hidden states back, undoing the earlier all-to-all gather.
+
+ Parameters:
+ hidden_states (tensor): batch-flattened hidden states.
+ routing_weights (tensor): Optional, routing weights for each expert.
+ gather_products (tuple): Optional, tuple of tensors that would have been
+ produced by the earlier gather call.
+ """
+
+ if not self.all_to_all:
+ # in this case scattering is already handled by
+ # scattermoe when computing w2
+ # - then there is nothing to do
+ return hidden_states
+
+ # expect these products to be produced by an earlier
+ # all-to-all gather call
+ (send_counts, recv_counts, bins, sorted_expert_idxs, sorted_scattered_idxs) = (
+ gather_products
+ )
+
+ # perform the scattering with the gather products,
+ hidden_states = scatter_with_routing_weights(
+ hidden_states,
+ routing_weights.flatten(),
+ send_counts,
+ recv_counts,
+ bins,
+ sorted_expert_idxs,
+ sorted_scattered_idxs,
+ self.expert_parallel_group,
+ self.top_k,
+ )
+
+ return hidden_states
+
+ def forward(self, hidden_states: torch.Tensor):
+ """
+ ScatterMoe.forward replaces the forward of the sparse
+ mixture-of-expert module.
+ """
+
+ # flatten the batch dimension
+ original_shape = hidden_states.shape # take a record
+ hidden_states = hidden_states.view(-1, self.hidden_size)
+
+ # compute the routing logits, weights, and expert assignments
+ # - router_logits: will be passed out of forward, used for computing
+ # routing loss.
+ (router_logits, routing_weights, selected_experts) = (
+ self._compute_routing_weights(hidden_states)
+ )
+
+ # get the sorted expert idxs and scattered idxs.
+ # - if a gather is required, then the hidden-states will be
+ # communicated from other ranks, and will change.
+ # - if a gather is required, then some _gather_products will be
+ # required for the scattering later, so return these out also.
+ (
+ hidden_states,
+ sorted_expert_idxs,
+ sorted_scattered_idxs,
+ *_gather_products,
+ ) = self._get_expert_idxs_and_maybe_gather(
+ hidden_states,
+ selected_experts,
+ )
+
+ # scattermoe-specific computation.
+ # - padded indices need to be computed for the scattermoe
+ # triton kernels.
+ with torch.no_grad():
+ padded_block_idxs, expert_offsets = padded_block_indices(
+ sorted_expert_idxs, self.num_experts
+ )
+
+ # compute the up projection
+ out = self.w1(
+ hidden_states,
+ sorted_expert_idxs,
+ sorted_scattered_idxs,
+ padded_block_idxs,
+ expert_offsets,
+ )
+ out = self.activation(out)
+
+ # - if the arch has a separate gate projection
+ if getattr(self, "w3", None) is not None:
+ out *= self.w3(
+ hidden_states,
+ sorted_expert_idxs,
+ sorted_scattered_idxs,
+ padded_block_idxs,
+ expert_offsets,
+ )
+
+ # compute the down projection
+ # - if there is no all-to-all processing, we depend on the
+ # scattermoe kernel to perform the final scattering
+ hidden_states = self.w2(
+ out,
+ sorted_expert_idxs,
+ sorted_scattered_idxs,
+ padded_block_idxs,
+ expert_offsets,
+ gates=(None if self.all_to_all else routing_weights),
+ )
+
+ # maybe scatter
+ hidden_states = self._maybe_scatter(
+ hidden_states,
+ routing_weights,
+ _gather_products,
+ )
+
+ # return hidden states and router logits
+ return (hidden_states.view(original_shape), router_logits)
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py
new file mode 100644
index 00000000..be3057b5
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py
@@ -0,0 +1,94 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from typing import List
+
+# to be updated so that the parsers can work properly
+PARAM_NAME_ROUTER_SCATTERMOE = "router"
+PARAM_NAME_WEIGHT_SCATTERMOE = ["w1", "w2", "w3"]
+
+FILE_SAFETENSOR_INDEX = "model.safetensors.index.json"
+KEY_REPLICATE = "replicate"
+KEY_EXPERT_PARALLEL = "expert_parallel"
+DIM_EXPERT = 0
+
+KEY_SCATTERMOE_ROUTER = PARAM_NAME_ROUTER_SCATTERMOE + ".weight"
+
+# Currently our ScatterMoE drop-in supports an up/down proj and an
+# optional gate_proj.
+# - When new architectures are supported, this list will be updated
+SCATTERMOE_SPEC_HAS_GATE = "Gated"
+
+# Each conversion spec entry is a tuple of:
+# - moe_cls
+# - router_name
+# - expert_name
+# - weight_spec
+# - sharded experts flag
+
+# NOTE: it is quite challenging to perform the module swap
+# when the incoming MoE model can have quite varied impls.
+# - hence the adopted strategy is to infer the weights from the
+# state dict of the incoming model, and map them to the ScatterMoE
+# - the SPEC is a description of some hints to help perform this state_dict
+# mapping.
+
+# NOTE: there is an expert_map logic which is currently not exposed
+# in the SPEC. The expert_map allows us to map the parameter names
+# if they are different. But so far we do not need to use it.
+
+# NOTE: the keys can be a single arch string MixtralForCausalLM
+# or a few arch strings separated by commas (no spaces)
+
+# when adding new models, follow the following convention:
+# - class name of moe module to be replaced with ScatterMoE.
+# - module_name of the router.
+# - module_name of the experts; this can be specified as a plain
+# name or a regex.
+# (str): name of the module
+# (regex): w1_name|w2_name|w3_name specify the names if they are different.
+# - boolean flag indicating if the experts are sharded in the state dict.
+# i.e., meaning the experts exist in separate 2D Linear modules
+# or all "combined" into a single 3D linear module.
+SCATTERMOE_CONVERSION_SPEC = {
+ "MixtralForCausalLM": (
+ "MixtralSparseMoeBlock",
+ "gate",
+ "experts",
+ SCATTERMOE_SPEC_HAS_GATE,
+ True,
+ ),
+ "GraniteMoeForCausalLM": (
+ "GraniteMoeMoE",
+ "router",
+ "input_linear|output_linear|input_linear",
+ SCATTERMOE_SPEC_HAS_GATE,
+ False,
+ ),
+}
+
+
+# helper function to get the spec based on architectures
+def get_scattermoe_conv_spec_from_archs(architectures: List[str]):
+ # infer the spec
+ for archs, spec in SCATTERMOE_CONVERSION_SPEC.items():
+ archs = archs.split(",")
+ if any(x in archs for x in architectures):
+ return spec
+
+ # if not found
+ raise ValueError(
+ f"In order to configure ScatterMoe for archs '{architectures}' "
+ "the conversion spect must be updated in scattermoe_constants.py"
+ )
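+
+
+# Example usage (illustrative):
+# get_scattermoe_conv_spec_from_archs(["MixtralForCausalLM"]) returns the
+# Mixtral entry above, i.e.,
+# ("MixtralSparseMoeBlock", "gate", "experts", SCATTERMOE_SPEC_HAS_GATE, True).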
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py
new file mode 100644
index 00000000..7dfa607c
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py
@@ -0,0 +1,341 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from collections import OrderedDict
+from contextlib import nullcontext
+import json
+import os
+
+# Third Party
+from accelerate import init_empty_weights
+from peft import LoraConfig
+from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
+
+# pylint: disable=import-error
+from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
+from tqdm import tqdm
+from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
+import torch
+
+# Local
+from .checkpoint_utils import get_resolved_checkpoint_location
+from .scattermoe_constants import (
+ FILE_SAFETENSOR_INDEX,
+ KEY_EXPERT_PARALLEL,
+ KEY_REPLICATE,
+ KEY_SCATTERMOE_ROUTER,
+ get_scattermoe_conv_spec_from_archs,
+)
+from .scattermoe_state_dict import (
+ convert_state_dict,
+ get_checkpoint_meta_from_sharded_safetensor,
+ get_state_dict_from_checkpoint_metadata,
+)
+
+
+# this function will load the sharded experts onto the device.
+# - the "module" passed in is expected to be the ScatterMoE module swap that
+# replaces the model's sparse mixture-of-experts implementation.
+def load_experts_onto_device(
+ module: torch.nn.Module,
+ state_dict: OrderedDict,
+ device_mesh: DeviceMesh,
+ num_experts_per_device: int,
+):
+
+ # hook for scaling the gradient
+ scaling = device_mesh[KEY_EXPERT_PARALLEL].size()
+
+ def _hook(grad):
+ if grad is not None:
+ grad.div_(scaling)
+ return grad
+
+ # required replication placements
+ reps = [Replicate() for _ in range(device_mesh.ndim - 1)]
+
+ for weight_name, param in state_dict.items():
+
+ if KEY_SCATTERMOE_ROUTER in weight_name:
+ # if it is the router, replicate
+ param = distribute_tensor(param, device_mesh, reps + [Replicate()])
+ elif param.shape[0] > num_experts_per_device:
+ # if it is a weight param and the number of experts exceeds that of
+ # the device, shard
+ param = distribute_tensor(param, device_mesh, reps + [Shard(0)])
+ else:
+ # if it is a weight that is already sharded by the number of experts
+ param = DTensor.from_local(
+ param, device_mesh=device_mesh, placements=reps + [Shard(0)]
+ )
+
+ # get the module we want to shard
+ name = weight_name.split(".")
+ path, name = ".".join(name[:-1]), name[-1]
+ mod = module.get_submodule(path)
+ requires_grad = getattr(mod, name).requires_grad
+
+ param = torch.nn.Parameter(
+ param,
+ requires_grad=requires_grad,
+ )
+
+ # install gradient scaling hook
+ if KEY_SCATTERMOE_ROUTER not in weight_name:
+ param.register_hook(_hook)
+
+ # register the sharded parameter onto the ScatterMoE module
+ mod.register_parameter(name, param)
+
+
+def prepare_scattermoe(
+ model: torch.nn.Module,
+ checkpoint_name_or_path: str = None,
+ rank: int = None,
+ world_size: int = None,
+ ep_degree: int = 1,
+ key_rep: str = KEY_REPLICATE,
+ key_ep: str = KEY_EXPERT_PARALLEL,
+ device_type: str = "cuda",
+ mixed_precision: bool = False,
+ lora_config: LoraConfig = None,
+):
+
+ # guarded because may have third party package deps
+ # Local
+ # pylint: disable=import-outside-toplevel
+ from .scattermoe import ScatterMoE
+
+ assert world_size % ep_degree == 0, (
+ f"world size ({world_size}) " f"not divisible by ep_size ({ep_degree})."
+ )
+
+ moe_num_experts: int = model.config.num_local_experts
+ num_experts_per_device = moe_num_experts // ep_degree
+ assert (
+ moe_num_experts % ep_degree == 0
+ ), f"moe num experts ({moe_num_experts}) not divisible by ep_shard_factor ({ep_degree})."
+
+ # current rank of the device
+ device = torch.device(f"{device_type}:{rank}")
+
+ # get the scattermoe conversion spec
+ (
+ moe_cls,
+ router_name,
+ expert_name,
+ expert_mlp_spec,
+ sharded_expert_ckpt,
+ ) = get_scattermoe_conv_spec_from_archs(model.config.architectures)
+
+ # split the names first
+ expert_name = expert_name.split("|")
+
+ rep_size = world_size // ep_degree
+ if ep_degree == 1 and rep_size == 1:
+ # in this case no need for sharding
+ device_mesh = None
+ elif rep_size == 1:
+ # in this case a 1D device mesh suffices
+ device_mesh = init_device_mesh(
+ device_type,
+ (ep_degree,),
+ mesh_dim_names=(key_ep,),
+ )
+ else:
+ # in this case it will distribute experts on a different dim
+ # - this will achieve the effect that the expert sharding can be
+ # hierarchical (e.g., it can be over a slower network plane since
+ # the communication overhead is less)
+ device_mesh = init_device_mesh(
+ device_type,
+ (rep_size, ep_degree),
+ mesh_dim_names=(key_rep, key_ep),
+ )
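+
+ # For example (illustrative): with world_size=8 and ep_degree=4,
+ # rep_size=2 and a (2, 4) mesh is created; experts are sharded
+ # 4 ways within each of the 2 replica groups.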
+
+ # - compute the shard indices for current expert, if sharding is
+ # indeed taking place
+ expert_shards = None
+ if device_mesh is not None:
+ _index = device_mesh[KEY_EXPERT_PARALLEL].get_local_rank()
+ expert_shards = list(
+ range(
+ _index * num_experts_per_device, (_index + 1) * num_experts_per_device
+ )
+ )
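+
+ # For example (illustrative): with 8 experts and ep_degree=4,
+ # num_experts_per_device=2, and the rank with EP-local index 1
+ # holds expert_shards=[2, 3].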
+
+ # - if mixed precision is specified then we upcast
+ dtype = model.dtype if not mixed_precision else torch.float32
+
+ # for all the MoE related params, e.g., gate, experts
+ # get a dictionary
+ # parent_mod: (child_instance_name, [list of fqdn keys])
+ found = {}
+ for name, mod in model.named_modules():
+ name = name.split(".")
+ parent, child = ".".join(name[:-1]), name[-1]
+
+ # check the module depending if moe_cls is a str or class
+ # pylint: disable=isinstance-second-argument-not-valid-type
+ if (
+ mod.__class__.__name__ == moe_cls
+ if isinstance(moe_cls, str)
+ else isinstance(mod, moe_cls)
+ ):
+ fqdn_keys = [ # all params, including childs'
+ f"{parent}.{child}.{n}" for n, _ in mod.named_parameters()
+ ]
+
+ # check if there are any biases in any of the experts
+ # Assumption: if one expert has a bias, then the others
+ # will have it too
+ has_bias = any(
+ expert_name[0] in k and k.endswith("bias") for k in fqdn_keys
+ )
+
+ found[parent] = (child, fqdn_keys, has_bias)
+
+ assert len(found) > 0, "cannot find scattermoe modules to replace"
+
+ moe_module_names = set()
+
+ # pylint: disable=too-many-nested-blocks
+ # NOTE: for now we only support sharded safetensors
+ # - most MoE models are distributed in this checkpoint format
+ try:
+ loc = get_resolved_checkpoint_location(checkpoint_name_or_path)
+ with open(os.path.join(loc, FILE_SAFETENSOR_INDEX), encoding="utf-8") as f:
+ index = json.load(f)
+
+ # e.g., prefix: 'model.layers.0',
+ # module_name: 'block_sparse_moe'
+ for prefix, (module_name, _, has_bias) in tqdm(
+ found.items(), disable=(rank > 0), desc="Converting ScatterMoE layers"
+ ):
+ checkpoint_metadata = get_checkpoint_meta_from_sharded_safetensor(
+ index["weight_map"],
+ prefix,
+ module_name,
+ router_name,
+ "|".join(expert_name),
+ )
+
+ # the parent module
+ parent = model.get_submodule(prefix)
+
+ # - handle state dict loading
+ # - NOTE: convert_state_dict does not have logic to concat sharded
+ # experts so cannot handle the case where sharded_expert_ckpt=True
+ if (
+ ep_degree == 1
+ and (not is_fsdp_enabled() or is_local_dist_rank_0())
+ and not sharded_expert_ckpt # cannot be a sharded checkpoint
+ ):
+ # - if there is no sharding, and model is not loaded on the
+ # meta device, we can simply convert the state dict
+ sd = convert_state_dict(
+ prefix + "." + module_name + ".",
+ checkpoint_metadata,
+ getattr(parent, module_name).state_dict(),
+ model.config.num_local_experts,
+ model.config.intermediate_size,
+ dtype,
+ )
+ else:
+ # if there is sharding, then we want the model to be loaded
+ # on meta in general, since the actual model may be a lot smaller
+ sd = get_state_dict_from_checkpoint_metadata(
+ loc,
+ checkpoint_metadata,
+ num_experts_per_device,
+ model.config.intermediate_size,
+ expert_shards,
+ dtype,
+ )
+
+ if device_mesh is None:
+ _init_scattermoe_context = nullcontext
+ else:
+ # in this case we need to distribute parameters, so just initialize
+ # the scattermoe module swap with empty weights,
+ # since they are going to be replaced.
+ _init_scattermoe_context = init_empty_weights
+
+ # - convert to a ScatterMoE
+ # - very hard to do patching, settle for module swap
+ with _init_scattermoe_context():
+ moe = ScatterMoE(
+ hidden_size=model.config.hidden_size,
+ hidden_act=model.config.hidden_act,
+ intermediate_size=model.config.intermediate_size,
+ num_experts=num_experts_per_device,
+ has_bias=has_bias,
+ mlp_arch=expert_mlp_spec,
+ top_k=model.config.num_experts_per_tok,
+ dtype=model.dtype,
+ device=device,
+ ep_device_mesh=(
+ device_mesh[key_ep] if device_mesh is not None else None
+ ),
+ lora_config=lora_config,
+ )
+
+ # the state dict logic below will not have lora adapters
+ # - so we need to initialize them here
+ if lora_config is not None:
+
+ # update the state_dict
+ for name, param in moe.named_parameters():
+ # NOTE: is this reliable?
+ if "lora_" in name:
+ if device_mesh is not None:
+ # this means it has been loaded with empty context above
+ # - so materialize the tensor
+ param = torch.empty(
+ *param.size(), dtype=dtype, requires_grad=True
+ )
+
+ sd[name] = param # set the param in state dict
+
+ # initialize the loras here
+ if "lora_A" in name:
+ torch.nn.init.zeros_(sd[name])
+ elif "lora_B" in name:
+ torch.nn.init.normal_(sd[name])
+
+ if device_mesh is None:
+ # - if not on meta, just load the state dict
+ # - and then put on the device
+ moe.load_state_dict(sd)
+ moe = moe.to(device)
+ else:
+ # - otherwise, we need to distribute and will
+ # replace the parameters
+ load_experts_onto_device(moe, sd, device_mesh, num_experts_per_device)
+ # module swap
+ setattr(parent, module_name, moe)
+
+ # - keep track of the name for returning
+ moe_module_names.add(module_name)
+
+ except ValueError as e:
+ raise ValueError(
+ f"Unable to load checkpoint_path '{checkpoint_name_or_path}'. "
+ "Currently only support non-GGUF safetensor checkpoints. "
+ ) from e
+
+ return moe_module_names
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py
new file mode 100644
index 00000000..e13f6ba5
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py
@@ -0,0 +1,338 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from collections import OrderedDict, defaultdict
+from contextlib import ExitStack
+from typing import Dict, List, Tuple
+import os
+import re
+
+# Third Party
+from safetensors import safe_open
+import torch
+
+# Local
+from .scattermoe_constants import (
+ DIM_EXPERT,
+ KEY_SCATTERMOE_ROUTER,
+ PARAM_NAME_WEIGHT_SCATTERMOE,
+)
+
+# This function creates a dictionary of keys and paths into the sharded
+# safetensors checkpoint file that are relevant to the "prefix" and "instance_name"
+# being passed in.
+# - the keys are the ScatterMoE parameter names (e.g., 'w1.weight', 'router.weight').
+# - the values are lists of tuples pointing to the keys within the checkpoint file.
+#
+# Example: if prefix="model.layers.0" and instance_name="block_sparse_moe", then a dictionary
+# of the following will be returned:
+# {
+# 'w1.weight': [
+# (
+# 'model.layers.0.block_sparse_moe.experts.0.w1.weight',
+# 'model-00001-of-00019.safetensors'
+# ),
+# (
+# 'model.layers.0.block_sparse_moe.experts.1.w1.weight',
+# 'model-00001-of-00019.safetensors'
+# ),
+# ...
+# ]
+# 'w2.weight': [...],
+# 'w3.weight': [...],
+# 'router.weight': [
+# (
+# 'model.layers.0.block_sparse_moe.gate.weight',
+# 'model-00001-of-00019.safetensors'
+# )
+# ]
+# }
+#
+# or the non-sharded case (and possibly fused case)
+# {
+# 'w1.weight': [
+# (
+# 'model.layers.0.block_sparse_moe.input_linear.layer.weight',
+# 'model-00001-of-00001.safetensors'
+# ),
+# ],
+# ...
+# 'w3.weight': [
+# (
+# 'model.layers.0.block_sparse_moe.input_linear.layer.weight',
+# 'model-00001-of-00001.safetensors'
+# ),
+# ]
+# }
+
+
+def get_checkpoint_meta_from_sharded_safetensor(
+ weight_map: Dict,
+ prefix: str, # e.g., 'model.layers.0',
+ instance_name: str, # e.g., block_sparse_moe
+ router_name: str = "gate", # e.g., named "gate" within block_sparse_moe
+ expert_name: str = "experts", # e.g., named "experts" within block_sparse_moe
+ expert_map: Dict = None, # map -> [w1,w2,w3]
+) -> Dict[str, List[Tuple]]:
+ """
+ utility function to infer the mapping of ScatterMoE parameters
+ from that of an incoming model, based on a weight_map from a
+ sharded safetensor.
+
+ Parameters:
+ weight_map (dict): The weight map read in from a safetensor checkpoint.
+ prefix (str): the prefix where the MoE module lives (with respect to orig model).
+ instance_name (str): the name of the MoE module in the orig model
+ router_name (str): name of the router module as it is called in the MoE module
+ in the original model.
+ expert_name (str): name of the experts as they are called in the MoE module in
+ the original model. There are two patterns to use this.
+ i) specify a single string, and map the child weight names directly,
+ e.g., experts.w1 -> w1
+ ii) specify multiple strings in the order of w1, w2, ...
+ e.g., input_linear|output_linear|input_linear
+ expert_map (dict): This is used with pattern ii) described above in expert_name.
+ If not specified, will be the identity map, e.g., w1 -> w1
+ """
+
+ # insert in order
+ def _insert(L: List, i: int, v):
+ n = len(L)
+ if i < n:
+ L[i] = v
+ return
+
+ n = i - n + 1
+ while n > 0:
+ L.append(None)
+ n -= 1
+ L[i] = v
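+
+ # e.g. (illustrative): _insert(L=[], i=2, v="a") grows L to [None, None, "a"]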
+
+ # if expert_name = input_linear|output_linear|input_linear
+ # - in this case will map
+ # - input_linear: [w1, w3], output_linear: [w2]
+ # - the fused projection mapped to both w1 and w3 is assumed to
+ # have double the size and will be split.
+ if expert_map is None:
+ if "|" in expert_name:
+ expert_map = {}
+ _names = expert_name.split("|")
+ _n, _n2 = len(_names), len(PARAM_NAME_WEIGHT_SCATTERMOE)
+ assert (
+ 2 <= _n <= _n2
+ ), f"If expert_name has |, expect between 2 and {_n2} entries, but got {_n}."
+
+ for i, n in enumerate(_names):
+ if n not in expert_map:
+ expert_map[n] = []
+ expert_map[n].append(PARAM_NAME_WEIGHT_SCATTERMOE[i])
+ else:
+ expert_map = {x: [x] for x in PARAM_NAME_WEIGHT_SCATTERMOE}
+
+ # state dict -> weights
+ # 'router.weight': [(k, file),...]
+ # `w1.weight`: [...]
+ _map = defaultdict(list)
+ prefix = f"{prefix}.{instance_name}."
+ for k, stfile in weight_map.items():
+ if not k.startswith(prefix):
+ continue
+
+ # e.g. after replacement we get
+ # - gate.weight
+ # - experts.0.w1.weight
+ rel_k = k.replace(prefix, "")
+ m = re.match(rf"({router_name}|{expert_name})\.?(\d+)?\.?(\w+)?\.weight", rel_k)
+ if m is None:
+ raise ValueError(
+ f"Unable to handle key '{k}' with provided router_name "
+ f"'{router_name}' or expert_name '{expert_name}'"
+ )
+ if m.group(1) == router_name:
+ _map[KEY_SCATTERMOE_ROUTER].append((k, stfile))
+ elif m.group(1) in expert_name:
+ index = m.group(2)
+ index = 0 if index is None else int(index)
+ mod = None
+ for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))):
+ _insert(_map[f"{mod}.weight"], index, (k, stfile))
+
+ assert mod is not None, f"cannot map '{rel_k}'"
+
+ if len(_map) == 0:
+ raise ValueError(
+ f"Could not get safetensor map for '{prefix}' and '{instance_name}'"
+ )
+
+ return _map
+
+
+# if the weight is a scattermoe expert weight, need some reshaping
+def _maybe_reshape_scattermoe_expert_weights(
+ scatter_key: str,
+ param: torch.Tensor,
+ num_experts: int,
+ intermediate_size: int,
+):
+ (_is_w1, _is_w2, _is_w3) = [
+ f"{x}.weight" in scatter_key for x in PARAM_NAME_WEIGHT_SCATTERMOE
+ ]
+
+ if _is_w1 or _is_w2 or _is_w3:
+ if len(param.shape) == 2:
+ param = param.view(num_experts, -1, param.shape[-1])
+
+ if _is_w1 or _is_w3:
+ if param.shape[-2] == (2 * intermediate_size):
+ # cut it
+ if _is_w1:
+ param = param[..., :intermediate_size, :]
+ else:
+ param = param[..., intermediate_size:, :]
+
+ # assume these are linears
+ # assert param.shape[-2] == intermediate_size, "wrong intermediate size"
+ # assert param.shape[-1] == hidden_size, "wrong hidden size"
+
+ # have to transpose the weights since scattermoe expects a
+ # different order
+ param = param.permute(0, 2, 1)
+
+ return param
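+
+# Illustrative example: a fused w1/w3 checkpoint tensor of shape
+# (num_experts, 2 * intermediate_size, hidden_size) is sliced to
+# (num_experts, intermediate_size, hidden_size) and then permuted to
+# (num_experts, hidden_size, intermediate_size), matching the
+# ScatteredExperts weight layout (num_experts, in_features, out_features).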
+
+
+def convert_state_dict(
+ prefix: str,
+ checkpoint_metadata: Dict[str, List[Tuple]],
+ state_dict: OrderedDict,
+ num_experts: int,
+ intermediate_size: int,
+ dtype: torch.dtype = None,
+):
+ """
+ utility to convert the state dict for ScatterMoE. To be used
+ if the model is already loaded with weights.
+
+ Parameters:
+ prefix (str): where the MoE is located in the incoming model.
+ checkpoint_metadata (dict): a mapping of ScatterMoE state dict
+ with respect to that of incoming model.
+ state_dict (dict): of the incoming MoE.
+ num_experts (int):
+ intermediate_size (int):
+ dtype (torch.dtype):
+ """
+ target = OrderedDict()
+
+ for scatter_key, vs in checkpoint_metadata.items():
+ for state_key, _ in vs:
+ state_key = state_key.replace(prefix, "")
+ param = state_dict[state_key]
+ param = _maybe_reshape_scattermoe_expert_weights(
+ scatter_key, param, num_experts, intermediate_size
+ )
+ if dtype is not None:
+ param = param.to(dtype)
+ target[scatter_key] = param
+
+ return target
+
+
+def get_state_dict_from_checkpoint_metadata(
+ checkpoint_directory: str,
+ checkpoint_metadata: Dict[str, List[Tuple]],
+ num_experts: int,
+ intermediate_size: int,
+ expert_shards: List[int] = None,
+ dtype: torch.dtype = None,
+):
+ """
+ utility to convert a sharded checkpoint into a state dict for
+ ScatterMoE. To be used if the model was loaded on the meta
+ device and the actual weights do not exist in it.
+
+ Parameters:
+ checkpoint_directory (str): where the checkpoint is located.
+ checkpoint_metadata (dict): a mapping of ScatterMoE state dict
+ with respect to that of incoming model.
+ num_experts (int):
+ intermediate_size (int):
+ expert_shards (list): indexing which of the shards are required
+ if only a subset of parameters are required
+ dtype (torch.dtype):
+ """
+ target = OrderedDict()
+
+ # typically these all point to the same file, but to play it safe, open the
+ # checkpoint files on cpu first since we may not need all weights in each file.
+ with ExitStack() as stack:
+ files = {}
+ for _, vs in checkpoint_metadata.items():
+ for _, fi in vs:
+ if fi not in files:
+ files[fi] = stack.enter_context(
+ safe_open(
+ os.path.join(checkpoint_directory, fi),
+ framework="pt",
+ device="cpu",
+ )
+ )
+
+ # go by one weight at a time.
+ for scatter_key, vs in checkpoint_metadata.items():
+
+ if KEY_SCATTERMOE_ROUTER in scatter_key:
+ k, fi = vs[0] # only one item
+ param = files[fi].get_tensor(k)
+
+ elif len(vs) == 1:
+ k, fi = vs[0] # only one item
+ # if it is a non-router weight and it is non-sharded
+ param = files[fi].get_tensor(k)
+ assert len(param.shape) == 3, (
+ "Expected 3D tensor for checkpoints with non-sharded experts, "
+ f"but got shape {param.shape}."
+ )
+
+ else:
+ # handle the case where the checkpoint shards the experts
+ data = []
+ if expert_shards is not None:
+ vs = [vs[i] for i in expert_shards]
+
+ for k, fi in vs:
+ T = files[fi].get_tensor(k)
+ assert len(T.shape) == 2, (
+ "Expected 2D tensor for checkpoints with sharded experts, "
+ f"but got shape {T.shape}."
+ )
+
+ T = T.unsqueeze(0)
+ data.append(T)
+
+ param = torch.concat(data, dim=DIM_EXPERT)
+
+ param = _maybe_reshape_scattermoe_expert_weights(
+ scatter_key, param, num_experts, intermediate_size
+ )
+ if dtype is not None:
+ param = param.to(dtype)
+
+ target[scatter_key] = param
+
+ return target
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/__init__.py
new file mode 100644
index 00000000..b72b46d2
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/__init__.py
@@ -0,0 +1,17 @@
+# Copyright The FMS HF Tuning Authors
+# Copyright 2023 MegaBlocks authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Local
+from .expert_parallel import all_to_all_gather_inputs, scatter_with_routing_weights
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/expert_parallel.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/expert_parallel.py
new file mode 100644
index 00000000..1704b850
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/expert_parallel.py
@@ -0,0 +1,232 @@
+# Copyright The FMS HF Tuning Authors
+# Copyright 2023 MegaBlocks authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Third Party
+import numpy as np
+import torch
+
+try:
+ # if megablocks is installed, import its ops, distributed autograd,
+ # and kernel functions
+
+ # - mixture of triton and cuda kernels
+ # Third Party
+ from megablocks import ops
+
+ # - distributed autograd
+ from megablocks.layers.all_to_all import all_to_all
+ from megablocks.ops import gather, histogram, inclusive_cumsum, scatter
+
+ # this is a radix sort for integral indices 0 .. num_bins-1
+ def sort(indices: torch.Tensor, num_bins: int):
+ bits = max(int(np.ceil(np.log2(num_bins))), 1)
+ # TODO: figure out why we need this upcast
+ bins, inds = ops.sort(indices, bits)
+ return bins, inds.to(torch.int64)
+
+ # replicate indices with bins
+ def replicate(indices: torch.Tensor, bins: torch.Tensor):
+ replicate_bins = inclusive_cumsum(bins.flatten(), 0)
+ # pylint: disable=use-implicit-booleaness-not-len
+ replicate_bins = (
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+ )
+
+ return ops.replicate(
+ indices.unsqueeze(dim=0),
+ replicate_bins,
+ replicate_bins[-1],
+ ).flatten()
+
+except ImportError:
+
+ # - distributed autograd
+ # Local
+ from .megablocks import all_to_all, gather, scatter
+
+ # take the histogram of integral indices from 0 .. num_bins-1
+ def histogram(indices: torch.Tensor, num_bins: int):
+ # - this has an Aten for the GPU backend
+ return torch.histc(indices, bins=num_bins, min=0, max=num_bins - 1)
+
+ def inclusive_cumsum(x: torch.Tensor, dim: int):
+ # - convert to int32 type as that is what is expected by the
+ # megablocks gather and scatter kernels
+ return x.cumsum(axis=dim, dtype=torch.int32)
+
+ # this is a radix sort for integral indices 0 .. num_bins-1
+ def sort(indices: torch.Tensor, num_bins: int):
+ return torch.sort(indices)
+
+ # replicate: repeats each integral index according to its bin count
+ def replicate(indices: torch.Tensor, bins: torch.Tensor):
+ return torch.repeat_interleave(indices, bins)
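+
+ # e.g. (illustrative): replicate(tensor([0, 1, 2]), tensor([2, 1, 3]))
+ # returns tensor([0, 0, 1, 2, 2, 2])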
+
+
+# from megablocks
+def no_indices_just_bins(top_expert, num_experts):
+ # Only the bin bounds and token counts are computed here; no
+ # scatter/gather indices are produced (hence the name).
+
+ # Histogram the expert ids to identify the number of
+ # tokens routed to each expert.
+ #
+ tokens_per_expert = histogram(top_expert, num_experts)
+
+ # Calculate the bin bounds for the sorted tokens.
+ bins = inclusive_cumsum(tokens_per_expert, 0)
+ # pylint: disable=use-implicit-booleaness-not-len
+ bins = bins.view(1) if not len(bins.size()) else bins
+ return bins, tokens_per_expert
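+
+
+# Worked example (illustrative): top_expert=[0, 2, 2, 1, 0] with num_experts=4
+# gives tokens_per_expert=[2, 1, 2, 0] and bins=[2, 3, 5, 5].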
+
+
+# modified from https://github.com/databricks/megablocks/blob/main/megablocks/layers/mlp.py
+# - credit to trevor-gale
+def all_to_all_gather_inputs(
+ x: torch.Tensor,
+ top_experts: torch.Tensor,
+ bin_ids: torch.Tensor,
+ indices: torch.Tensor,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ top_k: int,
+ experts_per_rank: int,
+):
+ """
+ Extracted from megablocks. This function performs all-to-all input
+ gathering for expert parallel.
+ """
+
+ # Compute the mapping of local tokens to experts.
+ # expert_weights = expert_weights.flatten()
+ top_experts = top_experts.flatten()
+ world_size = expert_parallel_group.size()
+ with torch.no_grad():
+ bins, tokens_per_expert = no_indices_just_bins(
+ top_experts, experts_per_rank * world_size
+ )
+
+ # Pass token count information to the device on which the
+ # target expert resides.
+ parallel_tokens_per_expert = torch.empty_like(
+ tokens_per_expert,
+ )
+ tpe_handle = torch.distributed.all_to_all_single(
+ parallel_tokens_per_expert,
+ tokens_per_expert,
+ group=expert_parallel_group,
+ async_op=True,
+ )
+
+ # Permute locally and without any padding so that tokens for each
+ # parallel device are stored contiguously.
+ #
+ # This view updates the shape of the tensor from [sl, bs, hs] to
+ # [sl * bs, hs] prior to the permutation.
+ x = gather(x, indices, bin_ids, bins, top_k)
+
+ # Compute the number of tokens that will be received from each
+ # device and permute the input data across the devices.
+ with torch.no_grad():
+ tpe_handle.wait()
+
+ # Reshape to [world_size, num_experts_per_rank].
+ tokens_per_expert = tokens_per_expert.view(world_size, experts_per_rank)
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+ world_size, experts_per_rank
+ )
+
+ # TODO(tgale): It might be faster to do this on the GPU and
+ # then communicate the results back to the host.
+ send_counts = tokens_per_expert.cpu().sum(dim=-1)
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
+
+ # Convert the send/recv counts to lists.
+ send_counts = send_counts.tolist()
+ recv_counts = recv_counts.tolist()
+
+ # Start the cross-device permutation asynchronously so we can
+ # overlap communication with computation.
+ parallel_x, parallel_x_handle = all_to_all(
+ x,
+ recv_counts,
+ send_counts,
+ expert_parallel_group,
+ async_op=True,
+ )
+
+ with torch.no_grad():
+ # After we do the cross-device permutation we have the tokens on the
+ # correct device but not yet grouped by expert because we received
+ # tokens from each device as contiguous chunks. To group the tokens
+ # for expert computation we'll do one more local permutation. The
+ # rest of this torch.no_grad() scope sets up the indices and bins
+ # for this permutation.
+
+ # Construct the expert indices for the permuted tokens.
+ parallel_top_expert = torch.remainder(
+ torch.arange(
+ experts_per_rank * world_size,
+ dtype=torch.int32,
+ device=indices.device,
+ ),
+ experts_per_rank,
+ )
+
+ parallel_top_expert = replicate(
+ parallel_top_expert,
+ parallel_tokens_per_expert.flatten(),
+ )
+
+ parallel_bin_ids, parallel_indices = sort(parallel_top_expert, experts_per_rank)
+
+ parallel_x_handle.wait()
+
+ return (
+ parallel_x,
+ parallel_bin_ids,
+ parallel_indices,
+ send_counts,
+ recv_counts, # for all to all
+ bins, # local
+ )
+
+
+def scatter_with_routing_weights(
+ x: torch.Tensor,
+ expert_weights: torch.Tensor,
+ send_counts: torch.Tensor,
+ recv_counts: torch.Tensor,
+ bins: torch.Tensor,
+ bin_ids: torch.Tensor,
+ indices: torch.Tensor,
+ expert_parallel_group: torch.distributed.ProcessGroup,
+ top_k: int,
+):
+ """
+ Extracted from megablocks. This function undoes the all-to-all
+ gathering for expert parallel.
+ """
+
+ # Un-permute the tokens across the devices.
+ x, _ = all_to_all(
+ x,
+ send_counts,
+ recv_counts,
+ expert_parallel_group,
+ )
+
+ # Un-permute locally to setup for the next series of operations.
+ return scatter(x, indices, bin_ids, expert_weights, bins, top_k)
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/__init__.py
new file mode 100644
index 00000000..024c575e
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/__init__.py
@@ -0,0 +1,17 @@
+# Copyright The FMS HF Tuning Authors
+# Copyright 2024 Databricks
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Local
+from .autograd import all_to_all, gather, scatter
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/autograd.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/autograd.py
new file mode 100644
index 00000000..f9e100fd
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/autograd.py
@@ -0,0 +1,183 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# Standard
+import functools
+
+# Third Party
+import torch
+
+# Local
+from .kernels import gather as _kernels_gather
+from .kernels import scatter as _kernels_scatter
+from .kernels import scatter_wgrad as _kernels_scatter_wgrad
+
+
+# ------------------------ HELPERS -----------------------------
+def _is_eligible(x):
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
+ return x.to(dtype)
+ elif isinstance(x, dict):
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+ elif isinstance(x, list) or isinstance(x, tuple):
+ return type(x)(map(lambda y: _cast(y, dtype), x))
+ return x
+
+
+def custom_fwd(fwd):
+ """Wrap a custom autograd function that always uses autocast dtype."""
+
+ @functools.wraps(fwd)
+ def decorate_fwd(*args, **kwargs):
+ if torch.is_autocast_enabled():
+ with torch.autocast(device_type="cuda", enabled=False):
+ dtype = torch.get_autocast_gpu_dtype()
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+ return fwd(*args, **kwargs)
+
+ return decorate_fwd
+
+
+def custom_bwd(bwd):
+ @functools.wraps(bwd)
+ def decorate_bwd(*args, **kwargs):
+ with torch.autocast(device_type="cuda", enabled=False):
+ return bwd(*args, **kwargs)
+
+ return decorate_bwd
+
+
+# ------------------------ AUTOGRAD -----------------------------
+
+
+class AllToAllOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+ out = torch.empty(
+ (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
+ )
+
+ ctx.input_shape = x.shape
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+ ctx.group = group
+ handle = torch.distributed.all_to_all_single(
+ out,
+ x,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op,
+ )
+ return out, handle
+
+ @staticmethod
+ def backward(ctx, grad, _):
+ if ctx.needs_input_grad[0]:
+ out = torch.empty(
+ ctx.input_shape,
+ device=grad.device,
+ dtype=grad.dtype,
+ )
+ torch.distributed.all_to_all_single(
+ out,
+ grad,
+ output_split_sizes=ctx.input_split_sizes,
+ input_split_sizes=ctx.output_split_sizes,
+ group=ctx.group,
+ )
+ return out, None, None, None, None
+ return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+ return AllToAllOp.apply(
+ x,
+ output_split_sizes,
+ input_split_sizes,
+ group,
+ async_op,
+ )
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, x, indices, bin_ids, weights, bins, top_k):
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+ ctx.top_k = top_k
+ ctx.x_shape = x.shape
+ return _kernels_scatter(x, indices, bin_ids, weights, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+ saved_tensors = ctx.saved_tensors
+
+ indices, bin_ids, weights, bins = saved_tensors[:4]
+ dgrad = None
+ if ctx.needs_input_grad[0]:
+ dgrad = _kernels_gather(
+ grad,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ ctx.top_k,
+ )
+
+ wgrad = None
+ if ctx.needs_input_grad[3]: # need wgrad
+ x = saved_tensors[-1]
+ wgrad = _kernels_scatter_wgrad(
+ x,
+ grad,
+ indices,
+ bin_ids,
+ bins,
+ ctx.top_k,
+ )
+ return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+ x: torch.Tensor,
+ indices: torch.Tensor,
+ bin_ids: torch.Tensor,
+ weights: torch.Tensor,
+ bins: torch.Tensor,
+ top_k: int,
+):
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
+
+
+# Autograd wrapper for gather kernel.
+class GatherOp(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, x, indices, bin_ids, bins, top_k):
+ ctx.save_for_backward(indices, bin_ids, bins)
+ ctx.top_k = top_k
+ return _kernels_gather(x, indices, bin_ids, None, bins, top_k)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad):
+ grad = grad.contiguous()
+
+ indices, bin_ids, bins = ctx.saved_tensors
+ out = _kernels_scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
+ return out, None, None, None, None, None
+
+
+gather = GatherOp.apply
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels/__init__.py
new file mode 100644
index 00000000..6dc4fe5b
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels/__init__.py
@@ -0,0 +1,17 @@
+# Copyright The FMS HF Tuning Authors
+# Copyright 2024 Databricks
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Local
+from .gather_scatter import gather, scatter, scatter_wgrad
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels/gather_scatter.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels/gather_scatter.py
new file mode 100644
index 00000000..4809b4fb
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels/gather_scatter.py
@@ -0,0 +1,309 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# Third Party
+import torch
+import triton
+import triton.language as tl
+
+
+def assert_is_tensor(x, ndim):
+ if x.ndim != ndim:
+ raise ValueError(f"Expected {ndim}-tensor but got {x.ndim}-tensor")
+
+
+def assert_is_matrix(x):
+ assert_is_tensor(x, 2)
+
+
+def assert_is_vector(x):
+ if x.ndim != 1:
+ raise ValueError(f"Expected 1-tensor but got {x.ndim}-tensor")
+
+
+def assert_equal(a, b):
+ if a != b:
+ raise ValueError(
+ f"Expected dimensions to be equal but got {a} and {b}.",
+ )
+
+
+# a: (tokens, hidden_size), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# weights: (tokens * top_k), real.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_X": 64}, num_warps=2),
+ triton.Config({"BLOCK_X": 128}, num_warps=2),
+ triton.Config({"BLOCK_X": 256}, num_warps=2),
+ triton.Config({"BLOCK_X": 128}, num_warps=4),
+ triton.Config({"BLOCK_X": 256}, num_warps=4),
+ ],
+ key=["NUM_COLUMNS"],
+)
+@triton.jit
+def _padded_copy(
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+ A_TO_B: tl.constexpr,
+ SCALE: tl.constexpr,
+):
+ # Our index into array 'a'.
+ index_a = tl.load(indices + tl.program_id(0))
+
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
+ # number of rows since they could be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'b'.
+ index_b = offset_in_bin
+ if bin_idx > 0:
+ index_b += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ #
+ # If we're going from A to B, divide the input index to copy
+ # the same input repeatedly. If we're going from B to A we
+ # need to reduce the result. Using atomics is slow, so we
+ # do the reduce step in a second kernel.
+ offset = index_a // TOP_K if A_TO_B else index_a
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ # Load the scale, if requested.
+ scale = tl.load(weights + index_a) if SCALE else 1
+
+ # Swap the pointers depending on the direction.
+ iptr = a if A_TO_B else b
+ optr = b if A_TO_B else a
+
+ for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)):
+ mask = offsets < NUM_COLUMNS
+ x = tl.load(iptr + offsets, mask=mask)
+ x = x.to(tl.float32) * scale.to(tl.float32)
+
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
+
+ offsets += BLOCK_X
+
+
+def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+ # NOTE: Because of the padding, the output size is dynamic.
+ # We load the final padded bin bound to get the output rows.
+ output_rows = padded_bins[-1].cpu().item()
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def gather(x, indices, bin_ids, weights, bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
+
+ if weights is not None:
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
+
+    # NOTE: There is no padding, so the number of output rows equals
+    # the number of input rows multiplied by top_k.
+ output_rows = x.shape[0] * top_k
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ x,
+ out,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=True,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+ return out
+
+
+def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ if weights is not None:
+ assert_equal(indices.shape[0], weights.shape[0])
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
+ _padded_copy[(indices.shape[0],)](
+ out,
+ x,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ A_TO_B=False,
+ TOP_K=top_k,
+ SCALE=weights is not None,
+ )
+
+ # Reduce along the top-k dimension, if needed.
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
+
+
+def scatter(x, indices, bin_ids, weights, bins, top_k):
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
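+
+# A minimal usage sketch (illustrative only): the routing metadata is assumed
+# to be produced by sorting the flattened token-to-expert assignments (see the
+# worked example above), and `expert_forward` is a hypothetical function that
+# applies each expert to its (padded) slice of the grouped rows:
+#
+#   grouped = padded_gather(x, indices, bin_ids, None, bins, padded_bins, top_k)
+#   expert_out = expert_forward(grouped)
+#   y = padded_scatter(expert_out, indices, bin_ids, weights, bins, padded_bins, top_k)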
+
+
+# x: (tokens, top_k, hidden_size), real
+# grad: (tokens, hidden_size), real.
+# wgrad: (tokens, top_k), real.
+# indices: (tokens * top_k), integer.
+# bin_ids: (tokens * top_k), integer.
+# bins: (num_experts), integer.
+# padded_bins: (num_experts), integer.
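+#
+# For each of the 'tokens * top_k' assignments, this kernel accumulates the
+# dot product of the corresponding (expert-ordered, possibly padded) row of
+# 'x' with the owning token's row of 'grad', i.e. the gradient w.r.t. that
+# assignment's routing weight.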
+@triton.autotune(
+ configs=[
+ triton.Config({"BLOCK_X": 64}, num_warps=2),
+ triton.Config({"BLOCK_X": 128}, num_warps=2),
+ triton.Config({"BLOCK_X": 256}, num_warps=2),
+ triton.Config({"BLOCK_X": 128}, num_warps=4),
+ triton.Config({"BLOCK_X": 256}, num_warps=4),
+ ],
+ key=["NUM_COLUMNS"],
+)
+@triton.jit
+def _padded_copy_wgrad(
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS: tl.constexpr,
+ TOP_K: tl.constexpr,
+ BLOCK_X: tl.constexpr,
+):
+ # Our index into 'tokens * top_k'.
+ index_out = tl.load(indices + tl.program_id(0))
+
+    # One threadblock per assignment in 'indices'. Array 'x' has an equal or
+    # greater number of rows than assignments, since its bins may be padded.
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
+
+ # Now we know what bin we're assigned to, but we need to know how
+ # many threadblocks were assigned to earlier bins so we can offset
+ # in our bin properly.
+ offset_in_bin = tl.program_id(0)
+ if bin_idx > 0:
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
+
+ # Load the starting index of our bin in array 'x'.
+ index_x = offset_in_bin
+ if bin_idx > 0:
+ index_x += tl.load(padded_bins + bin_idx - 1)
+
+ # Offset the input and output pointers.
+ wgrad += index_out
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
+
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+ for i in range(iterations):
+ mask = offsets < NUM_COLUMNS
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
+ acc += data * scale
+ offsets += BLOCK_X
+
+ # Reduce to get the final result and store.
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
+ tl.store(wgrad, out)
+
+
+def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
+ # Validate the input shapes.
+ assert_is_matrix(x)
+ assert_is_matrix(grad)
+ assert_is_vector(indices)
+ assert_is_vector(bin_ids)
+ assert_is_vector(bins)
+ assert_is_vector(padded_bins)
+ assert_equal(indices.shape[0], bin_ids.shape[0])
+ assert_equal(bins.size(), padded_bins.size())
+
+ tokens = indices.shape[0] // top_k
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
+ _padded_copy_wgrad[(indices.shape[0],)](
+ x,
+ grad,
+ out,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
+ NUM_COLUMNS=x.shape[1],
+ TOP_K=top_k,
+ )
+ return out
+
+
+def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
diff --git a/plugins/accelerated-moe/tests/__init__.py b/plugins/accelerated-moe/tests/__init__.py
new file mode 100644
index 00000000..38a9531e
--- /dev/null
+++ b/plugins/accelerated-moe/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/plugins/accelerated-moe/tests/test_scattermoe_plugin.py b/plugins/accelerated-moe/tests/test_scattermoe_plugin.py
new file mode 100644
index 00000000..ccdba6d3
--- /dev/null
+++ b/plugins/accelerated-moe/tests/test_scattermoe_plugin.py
@@ -0,0 +1,34 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+import os
+
+# Third Party
+from fms_acceleration.utils import instantiate_framework, read_configuration
+
+# First Party
+from fms_acceleration_moe import ScatterMoEAccelerationPlugin
+
+# configuration
+DIRNAME = os.path.dirname(__file__)
+CONFIG_PATH_SCATTERMOE = os.path.join(DIRNAME, "../configs/scattermoe.yaml")
+
+
+def test_framework_installs_scattermoe_plugin():
+ with instantiate_framework(
+ read_configuration(CONFIG_PATH_SCATTERMOE), require_packages_check=False
+ ) as framework:
+ for plugin in framework.active_plugins:
+ assert isinstance(plugin[1], ScatterMoEAccelerationPlugin)
diff --git a/plugins/accelerated-moe/tests/test_scattermoe_state_dict.py b/plugins/accelerated-moe/tests/test_scattermoe_state_dict.py
new file mode 100644
index 00000000..ff8965ba
--- /dev/null
+++ b/plugins/accelerated-moe/tests/test_scattermoe_state_dict.py
@@ -0,0 +1,208 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from typing import List
+
+# Third Party
+# pylint: disable=import-error
+import pytest
+
+# First Party
+from fms_acceleration_moe.utils.scattermoe_constants import (
+ PARAM_NAME_ROUTER_SCATTERMOE,
+ PARAM_NAME_WEIGHT_SCATTERMOE,
+)
+from fms_acceleration_moe.utils.scattermoe_state_dict import (
+ get_checkpoint_meta_from_sharded_safetensor,
+)
+
+# just a dummy sample value
+ST_SHARD = "model-00001-of-00001.safetensors"
+
+
+# --------------------------- HELPERS ------------------------------
+# - builds a weight map for checkpoints where the MoE is sharded (i.e.,
+#   one linear layer per expert).
+# - this mirrors the Mixtral-style layout
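+# - e.g., with the parameters used in the test below, a generated key looks
+#   like "model.layers.0.block_sparse_moe.experts.0.w1.weight" -> ST_SHARD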
+def build_dummy_weight_map_sharded_moe(
+ prefix: str,
+ module_name: str,
+ router_name: str,
+ expert_name: str,
+ num_layers: int,
+ num_experts: int,
+ expert_keys: List[str],
+):
+
+    # - the ST_SHARD values themselves are not important for the test
+ weight_map = {}
+ for i in range(num_layers):
+ layer_map = {
+ f"{prefix}.{i}.{module_name}.{router_name}.weight": ST_SHARD,
+ }
+ for j in range(num_experts):
+ expert_map = {}
+
+ for n in expert_keys:
+ expert_map.update(
+ {
+ f"{prefix}.{i}.{module_name}.{expert_name}.{j}.{n}.weight": ST_SHARD
+ }
+ )
+
+ layer_map.update(expert_map)
+
+ weight_map.update(layer_map)
+
+ return weight_map
+
+
+# - this mirrors the Granite-style layout (expert weights are not sharded per expert)
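+# - e.g., with the parameters used in the test below, a generated key looks
+#   like "model.layers.0.block_sparse_moe.input_linear.weight" -> ST_SHARD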
+def build_dummy_weight_map_non_sharded_moe(
+ prefix: str,
+ module_name: str,
+ router_name: str,
+ num_layers: int,
+ expert_keys: List[str],
+):
+    # - the ST_SHARD values themselves are not important for the test
+ weight_map = {}
+ for i in range(num_layers):
+ layer_map = {
+ f"{prefix}.{i}.{module_name}.{router_name}.weight": ST_SHARD,
+ }
+ for n in expert_keys:
+ layer_map.update({f"{prefix}.{i}.{module_name}.{n}.weight": ST_SHARD})
+
+ weight_map.update(layer_map)
+
+ return weight_map
+
+
+# --------------------------- TEST ---------------------------------
+
+PARAMETERS = [
+ (
+ True,
+ "model.layers",
+ "block_sparse_moe",
+ "gate",
+ "experts",
+ 2,
+ 8,
+ ["w1", "w2", "w3"],
+ ),
+ (
+ False,
+ "model.layers",
+ "block_sparse_moe",
+ "gate",
+ "input_linear|output_linear|input_linear",
+ 2,
+ None,
+ ["input_linear", "output_linear"],
+ ),
+]
+
+
+@pytest.mark.parametrize(
+ (
+ "sharded_ckpt,prefix,module_name,router_name,expert_name,"
+ "num_layers,num_experts,expert_keys"
+ ),
+ PARAMETERS,
+)
+def test_get_metadata_from_sharded_safetensor_correctly(
+ sharded_ckpt: bool,
+ prefix: str,
+ module_name: str,
+ router_name: str,
+ expert_name: str,
+ num_layers: int,
+ num_experts: int,
+ expert_keys: List[str],
+):
+
+ if sharded_ckpt:
+ weight_map = build_dummy_weight_map_sharded_moe(
+ prefix,
+ module_name,
+ router_name,
+ expert_name,
+ num_layers,
+ num_experts,
+ expert_keys,
+ )
+ else:
+ weight_map = build_dummy_weight_map_non_sharded_moe(
+ prefix, module_name, router_name, num_layers, expert_keys
+ )
+
+    # get the metadata for the first layer
+ ckpt_metadata = get_checkpoint_meta_from_sharded_safetensor(
+ weight_map,
+ prefix + ".0", # include layer
+ module_name,
+ router_name,
+ expert_name,
+ )
+
+ _key = f"{PARAM_NAME_ROUTER_SCATTERMOE}.weight"
+ assert _key in ckpt_metadata, "unable to map scattermoe router metadata."
+
+ _n = len(ckpt_metadata[_key])
+ assert _n == 1, f"expected only 1 router weights but got {_n}"
+
+ for n in PARAM_NAME_WEIGHT_SCATTERMOE:
+ _key = f"{n}.weight"
+ assert _key in ckpt_metadata, f"unable top map scattermoe expert weight {n}."
+
+ _n = len(ckpt_metadata[_key])
+ if sharded_ckpt:
+ assert (
+ _n == num_experts
+ ), f"missing expert weights, only mapped {_n} weights out of {num_experts}."
+ else:
+ assert (
+ _n == 1
+ ), f"missing expert weights, mapped {_n} but expected only 1 for non-sharded."
+
+
+def test_get_metadata_from_sharded_safetensor_incorrectly():
+
+ weight_map_wrong = {"prefix.moe_name.expert.weight": ST_SHARD}
+
+    # - the prefix passed in must match keys in the weight_map
+ with pytest.raises(ValueError, match="Could not get safetensor map for"):
+ get_checkpoint_meta_from_sharded_safetensor(
+ weight_map_wrong, "wrong_prefix", "moe_name", None, "expert_name"
+ )
+
+    # - if passing multiple '|'-separated expert names, their count must
+    #   stay within the allowed range
+ with pytest.raises(
+ AssertionError, match="If expert_name has |, expect between 2 and"
+ ):
+ get_checkpoint_meta_from_sharded_safetensor(
+ weight_map_wrong, "prefix", "moe_name", None, "exp1|exp2|exp3|exp4"
+ )
+
+    # - a weight_map key that matches the moe_name but no known expert name cannot be handled
+ with pytest.raises(
+ ValueError, match="Unable to handle key 'prefix.moe_name.expert.weight'"
+ ):
+ get_checkpoint_meta_from_sharded_safetensor(
+ weight_map_wrong, "prefix", "moe_name", None, "wrong_expert_name"
+ )
diff --git a/plugins/accelerated-moe/tox.ini b/plugins/accelerated-moe/tox.ini
new file mode 100644
index 00000000..811f1329
--- /dev/null
+++ b/plugins/accelerated-moe/tox.ini
@@ -0,0 +1,48 @@
+[tox]
+envlist = py, lint
+
+[testenv]
+deps =
+ pytest>=7
+ -e {toxinidir}
+skip_install = true
+commands =
+
+    # install the framework dependency here to control
+    # the install order
+ pip install -e {toxinidir}/../framework
+ pytest {posargs:tests}
+
+[testenv:lint]
+description = run linters
+skip_install = false
+deps =
+ -e {toxinidir}/../framework
+ pylint>=2.16.2,<=3.1.0
+commands =
+ pylint src tests
+allowlist_externals = pylint
+
+[testenv:fmt]
+description = format
+skip_install = true
+deps =
+ black>=22.12
+ isort>=5.11
+commands =
+ black {posargs:.}
+ isort {posargs:.}
+
+[testenv:build]
+description = build wheel
+deps =
+ build
+commands = python -m build -w
+skip_install = True
+
+[testenv:twinecheck]
+description = check wheel
+deps =
+ twine
+commands = twine check dist/*
+skip_install = True
\ No newline at end of file
diff --git a/plugins/framework/src/fms_acceleration/constants.py b/plugins/framework/src/fms_acceleration/constants.py
index 6a81d977..d568ec13 100644
--- a/plugins/framework/src/fms_acceleration/constants.py
+++ b/plugins/framework/src/fms_acceleration/constants.py
@@ -21,4 +21,4 @@
# and activated.
# - hence the plugins that have model loaders should be on top of this list
-PLUGINS = ["peft", "foak", "aadp"]
+PLUGINS = ["peft", "foak", "aadp", "moe"]
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
index 16ea64b7..906a4668 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
@@ -39,6 +39,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
from .models import ( # pylint: disable=import-outside-toplevel
gpt_bigcode,
granite,
+ granitemoe,
llama,
mistral,
mixtral,
@@ -47,6 +48,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
rules = [
*gpt_bigcode.get_mp_rules(base_type),
*granite.get_mp_rules(base_type),
+ *granitemoe.get_mp_rules(base_type),
*llama.get_mp_rules(base_type),
*mistral.get_mp_rules(base_type),
*mixtral.get_mp_rules(base_type),
@@ -76,6 +78,7 @@ class FastKernelsAccelerationPlugin(AccelerationPlugin):
# NOTE: may remove this when we have generic model rules
restricted_model_archs = [
"GraniteForCausalLM",
+ "GraniteMoeForCausalLM",
"GPTBigCodeForCausalLM",
"MixtralForCausalLM",
"LlamaForCausalLM",
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granitemoe.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granitemoe.py
new file mode 100644
index 00000000..6da14682
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granitemoe.py
@@ -0,0 +1,116 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from functools import partial
+
+# Third Party
+from fms_acceleration.model_patcher import (
+ ModelPatcherRule,
+ ModelPatcherTrigger,
+ combine_functions,
+ combine_triggers,
+)
+
+# Local
+from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
+from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
+
+
+def get_mp_rules(base_type: str):
+ """
+ Function to access all patch rules in this module.
+    If a rule is a forward_builder rule that takes `base_type`
+    in its forward_builder argument, wrap the forward_builder
+    function as a partial with the `base_type` argument bound.
+ """
+ try:
+ # Third Party
+ from transformers.models.granitemoe.modeling_granitemoe import ( # pylint: disable=import-outside-toplevel
+ GraniteMoeAttention,
+ GraniteMoeRMSNorm,
+ )
+ except ImportError:
+ return []
+
+ return [
+ # TODO: have a generic version of this rule
+ # - do regex on RMSNorm class name
+ # - check on the tensors required for fast_rms_layernorm
+ ModelPatcherRule(
+ rule_id="granitemoe-rms",
+ trigger=ModelPatcherTrigger(check=GraniteMoeRMSNorm),
+ forward=fast_rms_layernorm,
+ ),
+ # TODO: have a generic version of this rule
+ # - do regex on Attention class name
+ # - have a set of qkv / o module names and check on that
+ ModelPatcherRule(
+ rule_id="granitemoe-qkvo",
+ trigger=combine_triggers(
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=GraniteMoeAttention,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ )
+ ),
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=GraniteMoeAttention,
+ submodule_names=["o_proj"],
+ )
+ ),
+ logic="OR",
+ ),
+ forward_builder=combine_functions(
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ fused_op=KEY_QKV,
+ base_type=base_type,
+ ),
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["o_proj"],
+ fused_op=KEY_O,
+ base_type=base_type,
+ ),
+ logic="APPEND",
+ ),
+ ),
+ ModelPatcherRule(
+ rule_id="granitemoe-cross-ent",
+ import_and_maybe_reload=(
+ "torch.nn.CrossEntropyLoss",
+ FastCrossEntropyLoss,
+ "transformers.models.granitemoe.modeling_granitemoe",
+ ),
+ ),
+ # TODO: have a generic version of this rule
+ # - get the module name
+ # - check if "apply_rotary_pos_emb" exists
+ # - patch
+ ModelPatcherRule(
+ rule_id="granitemoe-rope",
+ import_and_maybe_reload=(
+ "transformers.models.granitemoe.modeling_granitemoe.apply_rotary_pos_emb",
+ fast_rope_embedding,
+ None,
+ ),
+ ),
+ ]
diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml
index 6781b3bd..b3b1deec 100644
--- a/sample-configurations/CONTENTS.yaml
+++ b/sample-configurations/CONTENTS.yaml
@@ -73,3 +73,69 @@ framework_configs:
plugins:
- fused-ops-and-kernels
filename: foak-fast-kernels-sample-configuration.yaml
+
+ # ------- MOE CONFIGS ----------
+ - shortname: moe-scattermoe-granite-ep1
+ plugins:
+ - accelerated-moe
+ filename: moe-scattermoe-granite-ep1-sample-configuration.yaml
+
+ - shortname: moe-scattermoe-granite-ep1-padding-free
+ plugins:
+ - accelerated-moe
+ - attention-and-distributed-packing
+ filename: moe-scattermoe-granite-ep1-padding-free-sample-configuration.yaml
+
+ - shortname: moe-scattermoe-granite-ep1-padding-free-foak
+ plugins:
+ - accelerated-moe
+ - attention-and-distributed-packing
+ - fused-ops-and-kernels
+ filename: moe-scattermoe-granite-ep1-padding-free-foak-sample-configuration.yaml
+
+ - shortname: moe-scattermoe-granite-ep2
+ plugins:
+ - accelerated-moe
+ filename: moe-scattermoe-granite-ep2-sample-configuration.yaml
+
+ - shortname: moe-scattermoe-granite-ep2-padding-free
+ plugins:
+ - accelerated-moe
+ - attention-and-distributed-packing
+ filename: moe-scattermoe-granite-ep2-padding-free-sample-configuration.yaml
+
+ - shortname: moe-scattermoe-granite-ep2-padding-free-foak
+ plugins:
+ - accelerated-moe
+ - attention-and-distributed-packing
+ - fused-ops-and-kernels
+ filename: moe-scattermoe-granite-ep2-padding-free-foak-sample-configuration.yaml
+
+ - shortname: moe-scattermoe-granite-ep4
+ plugins:
+ - accelerated-moe
+ filename: moe-scattermoe-granite-ep4-sample-configuration.yaml
+
+ - shortname: moe-scattermoe-granite-ep4-padding-free
+ plugins:
+ - accelerated-moe
+ - attention-and-distributed-packing
+ filename: moe-scattermoe-granite-ep4-padding-free-sample-configuration.yaml
+
+ - shortname: moe-scattermoe-granite-ep4-padding-free-foak
+ plugins:
+ - accelerated-moe
+ - attention-and-distributed-packing
+ - fused-ops-and-kernels
+ filename: moe-scattermoe-granite-ep4-padding-free-foak-sample-configuration.yaml
+
+ - shortname: moe-scattermoe-granite-ep8
+ plugins:
+ - accelerated-moe
+ filename: moe-scattermoe-granite-ep8-sample-configuration.yaml
+
+ - shortname: moe-scattermoe-granite-ep8-foak
+ plugins:
+ - accelerated-moe
+ - fused-ops-and-kernels
+ filename: moe-scattermoe-granite-ep8-foak-sample-configuration.yaml
\ No newline at end of file
diff --git a/sample-configurations/moe-scattermoe-granite-ep1-padding-free-foak-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep1-padding-free-foak-sample-configuration.yaml
new file mode 100644
index 00000000..881ef14b
--- /dev/null
+++ b/sample-configurations/moe-scattermoe-granite-ep1-padding-free-foak-sample-configuration.yaml
@@ -0,0 +1,51 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ # Configurations to accelerate data packing/padding in training
+ training:
+
+ # attention module configurations
+ # e.g. padding-free modifications to attention layer
+ attention:
+
+    # this controls the configurations for padding-free computation of flash attention
+ padding_free:
+ method: huggingface
+ fused_ops_and_kernels:
+
+ # if under training stanza, then putting
+ # base_layer and fused_lora will be a misnomer
+ # - this should be in peft.quantized
+ # However, if it is specified, it will still
+ # be read. This is useful in use cases where
+ # the yaml is system generated and not shown
+ # to a user.
+
+ # activate various unsloth optimizations
+ # there are two versions of the plugin
+ # - the FastKernel version supports individual kernels
+ # - the FastQuantized version is all-or-nothing
+
+ # fast loss triton kernels
+ fast_loss: true
+
+ # fast rms norm triton kernels
+ fast_rms_layernorm: true
+
+ # fast RoPE embedding triton kernels
+ fast_rope_embeddings: true
+ moe:
+
+ # expert-parallel for MoE
+ scattermoe:
+
+ # The level of expert parallel sharding.
+ # - 1 means no sharding
+ # - if > 1, please ensure that this divides the world_size. This is because
+ # the devices will be replicated for every ep_degree devices, and
+ # the experts will be sharded within each group.
+ # - if > 1, also ensure that it divides the number of experts, as each device
+ # will then have num_of_experts / ep_degree experts.
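+      # - e.g. with world_size = 8 and ep_degree = 2, the devices form 4
+      #   groups of 2; a model with 8 experts keeps 8 / 2 = 4 experts on
+      #   each device, and this expert sharding is replicated across the 4 groups.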
+ ep_degree: 1
diff --git a/sample-configurations/moe-scattermoe-granite-ep1-padding-free-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep1-padding-free-sample-configuration.yaml
new file mode 100644
index 00000000..b3af6f99
--- /dev/null
+++ b/sample-configurations/moe-scattermoe-granite-ep1-padding-free-sample-configuration.yaml
@@ -0,0 +1,28 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ # Configurations to accelerate data packing/padding in training
+ training:
+
+ # attention module configurations
+ # e.g. padding-free modifications to attention layer
+ attention:
+
+    # this controls the configurations for padding-free computation of flash attention
+ padding_free:
+ method: huggingface
+ moe:
+
+ # expert-parallel for MoE
+ scattermoe:
+
+ # The level of expert parallel sharding.
+ # - 1 means no sharding
+ # - if > 1, please ensure that this divides the world_size. This is because
+ # the devices will be replicated for every ep_degree devices, and
+ # the experts will be sharded within each group.
+ # - if > 1, also ensure that it divides the number of experts, as each device
+ # will then have num_of_experts / ep_degree experts.
+ ep_degree: 1
diff --git a/sample-configurations/moe-scattermoe-granite-ep1-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep1-sample-configuration.yaml
new file mode 100644
index 00000000..bb327a44
--- /dev/null
+++ b/sample-configurations/moe-scattermoe-granite-ep1-sample-configuration.yaml
@@ -0,0 +1,21 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ training:
+
+ # mixture-of-experts configurations
+ moe:
+
+ # expert-parallel for MoE
+ scattermoe:
+
+ # The level of expert parallel sharding.
+ # - 1 means no sharding
+ # - if > 1, please ensure that this divides the world_size. This is because
+ # the devices will be replicated for every ep_degree devices, and
+ # the experts will be sharded within each group.
+ # - if > 1, also ensure that it divides the number of experts, as each device
+ # will then have num_of_experts / ep_degree experts.
+ ep_degree: 1
diff --git a/sample-configurations/moe-scattermoe-granite-ep2-padding-free-foak-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep2-padding-free-foak-sample-configuration.yaml
new file mode 100644
index 00000000..b3c7712d
--- /dev/null
+++ b/sample-configurations/moe-scattermoe-granite-ep2-padding-free-foak-sample-configuration.yaml
@@ -0,0 +1,51 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ # Configurations to accelerate data packing/padding in training
+ training:
+
+ # attention module configurations
+ # e.g. padding-free modifications to attention layer
+ attention:
+
+    # this controls the configurations for padding-free computation of flash attention
+ padding_free:
+ method: huggingface
+ fused_ops_and_kernels:
+
+ # if under training stanza, then putting
+ # base_layer and fused_lora will be a misnomer
+ # - this should be in peft.quantized
+ # However, if it is specified, it will still
+ # be read. This is useful in use cases where
+ # the yaml is system generated and not shown
+ # to a user.
+
+ # activate various unsloth optimizations
+ # there are two versions of the plugin
+ # - the FastKernel version supports individual kernels
+ # - the FastQuantized version is all-or-nothing
+
+ # fast loss triton kernels
+ fast_loss: true
+
+ # fast rms norm triton kernels
+ fast_rms_layernorm: true
+
+ # fast RoPE embedding triton kernels
+ fast_rope_embeddings: true
+ moe:
+
+ # expert-parallel for MoE
+ scattermoe:
+
+ # The level of expert parallel sharding.
+ # - 1 means no sharding
+ # - if > 1, please ensure that this divides the world_size. This is because
+ # the devices will be replicated for every ep_degree devices, and
+ # the experts will be sharded within each group.
+ # - if > 1, also ensure that it divides the number of experts, as each device
+ # will then have num_of_experts / ep_degree experts.
+ ep_degree: 2
diff --git a/sample-configurations/moe-scattermoe-granite-ep2-padding-free-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep2-padding-free-sample-configuration.yaml
new file mode 100644
index 00000000..474171b3
--- /dev/null
+++ b/sample-configurations/moe-scattermoe-granite-ep2-padding-free-sample-configuration.yaml
@@ -0,0 +1,28 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ # Configurations to accelerate data packing/padding in training
+ training:
+
+ # attention module configurations
+ # e.g. padding-free modifications to attention layer
+ attention:
+
+    # this controls the configurations for padding-free computation of flash attention
+ padding_free:
+ method: huggingface
+ moe:
+
+ # expert-parallel for MoE
+ scattermoe:
+
+ # The level of expert parallel sharding.
+ # - 1 means no sharding
+ # - if > 1, please ensure that this divides the world_size. This is because
+ # the devices will be replicated for every ep_degree devices, and
+ # the experts will be sharded within each group.
+ # - if > 1, also ensure that it divides the number of experts, as each device
+ # will then have num_of_experts / ep_degree experts.
+ ep_degree: 2
diff --git a/sample-configurations/moe-scattermoe-granite-ep2-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep2-sample-configuration.yaml
new file mode 100644
index 00000000..b24d00cb
--- /dev/null
+++ b/sample-configurations/moe-scattermoe-granite-ep2-sample-configuration.yaml
@@ -0,0 +1,21 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ training:
+
+ # mixture-of-experts configurations
+ moe:
+
+ # expert-parallel for MoE
+ scattermoe:
+
+ # The level of expert parallel sharding.
+ # - 1 means no sharding
+ # - if > 1, please ensure that this divides the world_size. This is because
+ # the devices will be replicated for every ep_degree devices, and
+ # the experts will be sharded within each group.
+ # - if > 1, also ensure that it divides the number of experts, as each device
+ # will then have num_of_experts / ep_degree experts.
+ ep_degree: 2
diff --git a/sample-configurations/moe-scattermoe-granite-ep4-padding-free-foak-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep4-padding-free-foak-sample-configuration.yaml
new file mode 100644
index 00000000..c73917ce
--- /dev/null
+++ b/sample-configurations/moe-scattermoe-granite-ep4-padding-free-foak-sample-configuration.yaml
@@ -0,0 +1,51 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ # Configurations to accelerate data packing/padding in training
+ training:
+
+ # attention module configurations
+ # e.g. padding-free modifications to attention layer
+ attention:
+
+    # this controls the configurations for padding-free computation of flash attention
+ padding_free:
+ method: huggingface
+ fused_ops_and_kernels:
+
+ # if under training stanza, then putting
+ # base_layer and fused_lora will be a misnomer
+ # - this should be in peft.quantized
+ # However, if it is specified, it will still
+ # be read. This is useful in use cases where
+ # the yaml is system generated and not shown
+ # to a user.
+
+ # activate various unsloth optimizations
+ # there are two versions of the plugin
+ # - the FastKernel version supports individual kernels
+ # - the FastQuantized version is all-or-nothing
+
+ # fast loss triton kernels
+ fast_loss: true
+
+ # fast rms norm triton kernels
+ fast_rms_layernorm: true
+
+ # fast RoPE embedding triton kernels
+ fast_rope_embeddings: true
+ moe:
+
+ # expert-parallel for MoE
+ scattermoe:
+
+ # The level of expert parallel sharding.
+ # - 1 means no sharding
+ # - if > 1, please ensure that this divides the world_size. This is because
+ # the devices will be replicated for every ep_degree devices, and
+ # the experts will be sharded within each group.
+ # - if > 1, also ensure that it divides the number of experts, as each device
+ # will then have num_of_experts / ep_degree experts.
+ ep_degree: 4
diff --git a/sample-configurations/moe-scattermoe-granite-ep4-padding-free-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep4-padding-free-sample-configuration.yaml
new file mode 100644
index 00000000..8cd803cd
--- /dev/null
+++ b/sample-configurations/moe-scattermoe-granite-ep4-padding-free-sample-configuration.yaml
@@ -0,0 +1,28 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ # Configurations to accelerate data packing/padding in training
+ training:
+
+ # attention module configurations
+ # e.g. padding-free modifications to attention layer
+ attention:
+
+    # this controls the configurations for padding-free computation of flash attention
+ padding_free:
+ method: huggingface
+ moe:
+
+ # expert-parallel for MoE
+ scattermoe:
+
+ # The level of expert parallel sharding.
+ # - 1 means no sharding
+ # - if > 1, please ensure that this divides the world_size. This is because
+ # the devices will be replicated for every ep_degree devices, and
+ # the experts will be sharded within each group.
+ # - if > 1, also ensure that it divides the number of experts, as each device
+ # will then have num_of_experts / ep_degree experts.
+ ep_degree: 4
diff --git a/sample-configurations/moe-scattermoe-granite-ep4-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep4-sample-configuration.yaml
new file mode 100644
index 00000000..b48081df
--- /dev/null
+++ b/sample-configurations/moe-scattermoe-granite-ep4-sample-configuration.yaml
@@ -0,0 +1,21 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ training:
+
+ # mixture-of-experts configurations
+ moe:
+
+ # expert-parallel for MoE
+ scattermoe:
+
+ # The level of expert parallel sharding.
+ # - 1 means no sharding
+ # - if > 1, please ensure that this divides the world_size. This is because
+ # the devices will be replicated for every ep_degree devices, and
+ # the experts will be sharded within each group.
+ # - if > 1, also ensure that it divides the number of experts, as each device
+ # will then have num_of_experts / ep_degree experts.
+ ep_degree: 4
diff --git a/sample-configurations/moe-scattermoe-granite-ep8-foak-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep8-foak-sample-configuration.yaml
new file mode 100644
index 00000000..938c9024
--- /dev/null
+++ b/sample-configurations/moe-scattermoe-granite-ep8-foak-sample-configuration.yaml
@@ -0,0 +1,43 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ training:
+
+ fused_ops_and_kernels:
+
+ # if under training stanza, then putting
+ # base_layer and fused_lora will be a misnomer
+ # - this should be in peft.quantized
+ # However, if it is specified, it will still
+ # be read. This is useful in use cases where
+ # the yaml is system generated and not shown
+ # to a user.
+
+ # activate various unsloth optimizations
+ # there are two versions of the plugin
+ # - the FastKernel version supports individual kernels
+ # - the FastQuantized version is all-or-nothing
+
+ # fast loss triton kernels
+ fast_loss: true
+
+ # fast rms norm triton kernels
+ fast_rms_layernorm: true
+
+ # fast RoPE embedding triton kernels
+ fast_rope_embeddings: true
+ moe:
+
+ # expert-parallel for MoE
+ scattermoe:
+
+ # The level of expert parallel sharding.
+ # - 1 means no sharding
+ # - if > 1, please ensure that this divides the world_size. This is because
+ # the devices will be replicated for every ep_degree devices, and
+ # the experts will be sharded within each group.
+ # - if > 1, also ensure that it divides the number of experts, as each device
+ # will then have num_of_experts / ep_degree experts.
+ ep_degree: 8
diff --git a/sample-configurations/moe-scattermoe-granite-ep8-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep8-sample-configuration.yaml
new file mode 100644
index 00000000..af5500e5
--- /dev/null
+++ b/sample-configurations/moe-scattermoe-granite-ep8-sample-configuration.yaml
@@ -0,0 +1,21 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ training:
+
+ # mixture-of-experts configurations
+ moe:
+
+ # expert-parallel for MoE
+ scattermoe:
+
+ # The level of expert parallel sharding.
+ # - 1 means no sharding
+ # - if > 1, please ensure that this divides the world_size. This is because
+ # the devices will be replicated for every ep_degree devices, and
+ # the experts will be sharded within each group.
+ # - if > 1, also ensure that it divides the number of experts, as each device
+ # will then have num_of_experts / ep_degree experts.
+ ep_degree: 8
diff --git a/scripts/benchmarks/README.md b/scripts/benchmarks/README.md
index 269d3ead..4795efee 100644
--- a/scripts/benchmarks/README.md
+++ b/scripts/benchmarks/README.md
@@ -76,13 +76,14 @@ bash run_benchmarks.sh NUM_GPUS_MATRIX RESULT_DIR SCENARIOS_CONFIG SCENARIOS_FIL
```
where:
- `NUM_GPUS_MATRIX`: list of `num_gpu` settings to bench for, e.g. `"1 2"` will bench for 1 and 2 gpus.
+- `EFFECTIVE_BS_MATRIX`: list of effective batch sizes, e.g., `"4 8"` will bench for effective batch sizes 4 and 8.
- `RESULT_DIR`: where the benchmark results will be placed.
- `SCENARIOS_CONFIG`: the `scenarios.yaml` file.
- `SCENARIOS_FILTER`: optional; run only a specific `scenario` by providing its name.
The recommended way to run `benchmarks.sh` is using `tox` which handles the dependencies:
```
-tox -e run-benches -- NUM_GPUS_MATRIX RESULT_DIR SCENARIOS_CONFIG SCENARIOS_FILTER
+tox -e run-benches -- NUM_GPUS_MATRIX EFFECTIVE_BS_MATRIX RESULT_DIR SCENARIOS_CONFIG SCENARIOS_FILTER
```
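+For example, `tox -e run-benches -- "1 2" "4 8" benchmark_outputs` benches for 1 and 2 GPUs with effective batch sizes of 4 and 8, writing results to `benchmark_outputs`.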
Alternatively run [`benchmark.py`](./benchmark.py) directly. To see the help do:
diff --git a/scripts/benchmarks/accelerate.yaml b/scripts/benchmarks/accelerate.yaml
index 7923e624..a40505a7 100644
--- a/scripts/benchmarks/accelerate.yaml
+++ b/scripts/benchmarks/accelerate.yaml
@@ -31,7 +31,7 @@ fsdp_config:
# 3 is NO_SHARD, effectively disabling FSDP
# 4, 5 are HYBRID_ modes for multi-node training only.
- fsdp_state_dict_type: FULL_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3)
+ fsdp_state_dict_type: SHARDED_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3)
# 2 is LOCAL_STATE_DICT where parameters are still flattened
 # 3 is efficient, but requires know-how to use the sharded checkpoint.
diff --git a/scripts/benchmarks/accelerator-config.json b/scripts/benchmarks/accelerator-config.json
new file mode 100644
index 00000000..7f736f97
--- /dev/null
+++ b/scripts/benchmarks/accelerator-config.json
@@ -0,0 +1,5 @@
+{
+ "gradient_accumulation_kwargs": {
+ "sync_each_batch": true
+ }
+}
diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py
index 3bd7056b..38fe6679 100644
--- a/scripts/benchmarks/benchmark.py
+++ b/scripts/benchmarks/benchmark.py
@@ -305,12 +305,34 @@ def build_args_from_products(products: List[Dict], defaults: Dict):
argument_list = ConfigUtils.convert_keyvalue_arguments_to_list(
combined_args
)
- argument_list.extend(
- [
- "--per_device_train_batch_size",
- str(effective_batch_size // num_gpus),
- ]
- )
+            pdtbs = combined_args.get("per_device_train_batch_size")
+            grad_accum = combined_args.get("gradient_accumulation_steps")
+ if pdtbs is None and grad_accum is not None:
+ if grad_accum > 1:
+ warnings.warn(
+ f"Found gradient_accumulation_steps={grad_accum} and "
+ "no per_device_train_batch_size specified, but for backward "
+ "compatibility, ignoring gradient_accum in batch size "
+ "computation (this behavior may change in the future)."
+ )
+ argument_list.extend(
+ [
+ "--per_device_train_batch_size",
+ str(effective_batch_size // num_gpus),
+ ]
+ )
+ elif grad_accum is None and pdtbs is not None:
+ argument_list.extend(
+ [
+ "--gradient_accumulation_steps",
+ str(effective_batch_size // num_gpus // pdtbs),
+ ]
+ )
+ else:
+ raise ValueError(
+ "Please specify only either per_device_train_batch_size or gradient_accumulation_steps "
+ "and not both."
+ )
args.append((num_gpus, framework_config, argument_list))
return args
@@ -358,6 +380,12 @@ class ScenarioMatrix:
def __init__(self, scenario: Dict, acceleration_config_map: Dict = None) -> None:
assert "arguments" in scenario.keys(), "Missing `arguments` key in `scenario`"
+
+ # "slow" is a special key that indicates this scenario
+ # takes resources to run
+ # - "slow" scenarios are not run if not specified by a filter
+ self.slow = False
+
for key, val in scenario.items():
if key == "framework_config":
# if acceleration_config_map is None, then do not do mapping
@@ -689,7 +717,18 @@ def prepare_arguments(args, benchmark_dataset: BenchmarkDataset):
if args.run_only_scenarios and _scn_name not in args.run_only_scenarios:
print(f"Skipping scenario '{_scn_name}'")
continue
+
+ # build scenario matrix
scenario = ScenarioMatrix(scenario_config, acceleration_config_map)
+
+ if (
+ not args.run_only_scenarios
+                and scenario.slow
+ ):
+ # unfiltered runs omit all "slow" marked scenarios
+ print(f"Skipping slow scenario '{_scn_name}' beacuse run_only_scenarios=None.")
+ continue
+
scenario_matrices, scenario_constants = (
scenario.get_scenario_matrices_and_defaults()
)
diff --git a/scripts/benchmarks/refs/a100_80gb_moe.csv b/scripts/benchmarks/refs/a100_80gb_moe.csv
new file mode 100644
index 00000000..4936cd6f
--- /dev/null
+++ b/scripts/benchmarks/refs/a100_80gb_moe.csv
@@ -0,0 +1,25 @@
+epoch,framework_config,gradient_accumulation_steps,mem_nvidia_mem_reserved,model_name_or_path,num_gpus,per_device_train_batch_size,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second
+0.25,none,16.0,71199.0,ibm-granite/granite-3.0-3b-a800m-instruct,1,8,bfloat16,0.9438143467903136,2371.9316,5.396,0.042,1505.608
+0.25,none,8.0,46829.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.9437569552659988,1355.7096,9.442,0.074,1317.096
+0.25,none,4.0,37996.0,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9437739425897598,708.3914,18.069,0.141,1260.32
+0.25,moe-scattermoe-granite-ep1,16.0,71187.0,ibm-granite/granite-3.0-3b-a800m-instruct,1,8,bfloat16,0.9439476370811464,742.739,17.234,0.135,4808.149
+0.25,moe-scattermoe-granite-ep1,8.0,52503.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.9506204092502594,485.5103,26.364,0.206,3677.78
+0.25,moe-scattermoe-granite-ep1,4.0,51145.0,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9572784686088562,262.9566,48.677,0.38,3395.238
+0.25,moe-scattermoe-granite-ep2,8.0,40193.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.9437192791700364,577.2164,22.175,0.173,3093.467
+0.25,moe-scattermoe-granite-ep2,4.0,40878.5,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9509018939733506,300.285,42.626,0.333,2973.176
+0.25,moe-scattermoe-granite-ep4,4.0,31777.5,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9434539985656738,307.1264,41.677,0.326,2906.946
+0.25,moe-scattermoe-granite-ep1-padding-free,16.0,48401.0,ibm-granite/granite-3.0-3b-a800m-instruct,1,8,bfloat16,0.9437484860420228,631.9756,20.254,0.158,3924.202
+0.25,moe-scattermoe-granite-ep1-padding-free,8.0,42452.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.9506663566827774,454.3444,28.172,0.22,2729.207
+0.25,moe-scattermoe-granite-ep1-padding-free,4.0,38560.0,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.957276314496994,241.2967,53.047,0.414,2569.451
+0.25,moe-scattermoe-granite-ep2-padding-free,8.0,31012.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.943688799738884,546.507,23.421,0.183,2268.955
+0.25,moe-scattermoe-granite-ep2-padding-free,4.0,28133.0,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9505942213535308,283.5444,45.143,0.353,2186.607
+0.25,moe-scattermoe-granite-ep4-padding-free,4.0,21585.5,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9441865116357804,284.6079,44.974,0.351,2178.436
+0.25,moe-scattermoe-granite-ep1-padding-free-foak,16.0,42651.0,ibm-granite/granite-3.0-3b-a800m-instruct,1,8,bfloat16,0.9437448275089264,615.4528,20.798,0.162,4029.554
+0.25,moe-scattermoe-granite-ep1-padding-free-foak,8.0,37743.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.950773031115532,433.4811,29.528,0.231,2860.563
+0.25,moe-scattermoe-granite-ep1-padding-free-foak,4.0,35153.0,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9572476959228516,232.0428,55.162,0.431,2671.921
+0.25,moe-scattermoe-granite-ep2-padding-free-foak,8.0,26075.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.9437651455402374,524.7751,24.391,0.191,2362.917
+0.25,moe-scattermoe-granite-ep2-padding-free-foak,4.0,24665.5,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9507779973745346,274.126,46.694,0.365,2261.733
+0.25,moe-scattermoe-granite-ep4-padding-free-foak,4.0,18368.0,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.943427557349205,278.1245,46.023,0.36,2229.217
+,none,,65607.25,mistralai/Mixtral-8x7B-Instruct-v0.1,8,1,bfloat16,0.8599078696966171,4180.9544,3.062,0.024,80.364
+,moe-scattermoe-granite-ep8,,52004.75,mistralai/Mixtral-8x7B-Instruct-v0.1,8,1,bfloat16,0.8588122856616974,1071.1967,11.949,0.093,313.668
+,moe-scattermoe-granite-ep8-foak,,51961.25,mistralai/Mixtral-8x7B-Instruct-v0.1,8,1,bfloat16,0.8599798053503036,1043.6675,12.264,0.096,321.942
diff --git a/scripts/benchmarks/refs/requirements_moe.txt b/scripts/benchmarks/refs/requirements_moe.txt
new file mode 100644
index 00000000..63700ed0
--- /dev/null
+++ b/scripts/benchmarks/refs/requirements_moe.txt
@@ -0,0 +1,89 @@
+accelerate==1.0.1
+aiohappyeyeballs==2.4.3
+aiohttp==3.10.10
+aiosignal==1.3.1
+async-timeout==4.0.3
+attrs==24.2.0
+bitsandbytes==0.43.3
+certifi==2024.8.30
+charset-normalizer==3.4.0
+contourpy==1.3.0
+cycler==0.12.1
+datasets==2.21.0
+dill==0.3.8
+docstring_parser==0.16
+einops==0.8.0
+filelock==3.16.1
+flash-attn==2.6.3
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@21af5fb9f2989b3dbf443c016e4c0470b536a593#egg=fms_acceleration&subdirectory=plugins/framework
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@21af5fb9f2989b3dbf443c016e4c0470b536a593#egg=fms_acceleration_aadp&subdirectory=plugins/attention-and-distributed-packing
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@21af5fb9f2989b3dbf443c016e4c0470b536a593#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@21af5fb9f2989b3dbf443c016e4c0470b536a593#egg=fms_acceleration_moe&subdirectory=plugins/accelerated-moe
+-e git+https://github.com/foundation-model-stack/fms-acceleration.git@21af5fb9f2989b3dbf443c016e4c0470b536a593#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft
+fms-hf-tuning @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@398c2a8fe26d734344240555585d95e05299faa8
+fonttools==4.54.1
+frozenlist==1.5.0
+fsspec==2024.6.1
+huggingface-hub==0.26.2
+idna==3.10
+Jinja2==3.1.4
+kernel-hyperdrive @ git+https://github.com/fabianlim/kernel-hyperdrive.git@45036497e12444ca98a6f0072204538aee4543ba
+kiwisolver==1.4.7
+llvmlite==0.43.0
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+matplotlib==3.9.2
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.1.0
+multiprocess==0.70.16
+networkx==3.4.2
+numba==0.60.0
+numpy==1.26.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.1.105
+packaging==24.2
+pandas==2.2.3
+peft==0.13.2
+pillow==11.0.0
+propcache==0.2.0
+protobuf==5.28.3
+psutil==6.1.0
+pyarrow==18.0.0
+Pygments==2.18.0
+pyparsing==3.2.0
+python-dateutil==2.9.0.post0
+pytz==2024.2
+PyYAML==6.0.2
+regex==2024.11.6
+requests==2.32.3
+rich==13.9.4
+safetensors==0.4.5
+sentencepiece==0.2.0
+shtab==1.7.1
+simpleeval==0.9.13
+six==1.16.0
+sympy==1.13.1
+threadpoolctl==3.5.0
+tokenizers==0.20.3
+torch==2.4.1
+tqdm==4.67.0
+transformers==4.45.2
+triton==3.0.0
+trl==0.11.4
+typing_extensions==4.12.2
+tyro==0.8.14
+tzdata==2024.2
+urllib3==2.2.3
+xxhash==3.5.0
+yarl==1.17.1
diff --git a/scripts/benchmarks/scenarios-granite.yaml b/scripts/benchmarks/scenarios-granite.yaml
index 2e5d0cf9..2221797b 100644
--- a/scripts/benchmarks/scenarios-granite.yaml
+++ b/scripts/benchmarks/scenarios-granite.yaml
@@ -108,3 +108,19 @@ scenarios:
# target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
# model_name_or_path:
# - 'ibm/PowerLM-3b'
+
+ - name: accelerated-moe
+ framework_config:
+ - # without acceleration
+ - moe-scattermoe-granite
+      # TODO: add padding-free (pf) configs
+ slow: True
+ arguments:
+ learning_rate: 5e-5
+ torch_dtype: bfloat16
+ gradient_accumulation_steps: 16
+ logging_steps: 1
+ packing: False
+ adam_epsilon: 1e-8
+ model_name_or_path:
+ - 'ibm/PowerMoE-3b'
diff --git a/scripts/benchmarks/scenarios-moe.yaml b/scripts/benchmarks/scenarios-moe.yaml
new file mode 100644
index 00000000..efa2725e
--- /dev/null
+++ b/scripts/benchmarks/scenarios-moe.yaml
@@ -0,0 +1,78 @@
+# This file holds a list of scenarios that may be run.
+# - to limit the run to a subset of scenarios, use the --run-only-scenarios flag.
+# - Each scenario will be run against each acceleration framework
+#   config listed under its framework_config: key, if specified;
+#   each entry names a particular framework configuration.
+# - the arguments tag will hold arguments to be passed to sft_trainer
+# * the arguments are singular except for model_name_or_path which can handle
+# multiple arguments.
+# - So anything that is critical for the scenario MUST be specified here
+# and not in the defaults, e.g. fp16
+
+# This stanza will be used in future to replace the custom processing functions in data_processing.py
+# data_processing:
+# dataset_name: yahma/alpaca-cleaned
+# chat_template: |
+# {%- for message in messages %}
+# {% if message['input'] != '' %}
+# Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+# {% else %}
+# Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+# {% endif %}
+# ### Instruction:
+# {{ message['instruction'] }}
+
+# {% if message['input'] != '' %}
+# ### Input:
+# {{ message['input'] }}
+
+# {% endif %}
+# ### Response:
+# {{ message['output'] + eos_token }}
+# {% endfor %}
+# tokenize: True
+
+
+scenarios:
+ - name: accelerated-moe-full
+ framework_config:
+ - # without acceleration
+ - moe-scattermoe-granite-ep1
+ - moe-scattermoe-granite-ep2
+ - moe-scattermoe-granite-ep4
+ - moe-scattermoe-granite-ep1-padding-free
+ - moe-scattermoe-granite-ep1-padding-free-foak
+ - moe-scattermoe-granite-ep2-padding-free
+ - moe-scattermoe-granite-ep2-padding-free-foak
+ - moe-scattermoe-granite-ep4-padding-free
+ - moe-scattermoe-granite-ep4-padding-free-foak
+ arguments:
+ learning_rate: 5e-5
+ torch_dtype: bfloat16
+ gradient_accumulation_steps: null
+ per_device_train_batch_size: 8
+ logging_steps: 1
+ packing: False
+ adam_epsilon: 1e-8
+ model_name_or_path:
+ - 'ibm-granite/granite-3.0-3b-a800m-instruct'
+
+ - name: accelerated-moe-full-mixtral
+ framework_config:
+ - # without acceleration
+ - moe-scattermoe-granite-ep8
+ - moe-scattermoe-granite-ep8-foak
+ slow: True
+ arguments:
+ learning_rate: 5e-5
+ torch_dtype: bfloat16
+ accelerator_config: scripts/benchmarks/accelerator-config.json
+ gradient_accumulation_steps: null
+ per_device_train_batch_size: 1
+ logging_steps: 1
+ packing: False
+ adam_epsilon: 1e-8
+ model_name_or_path:
+ - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
\ No newline at end of file
diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py
index bd304959..ff775c8e 100644
--- a/scripts/generate_sample_configurations.py
+++ b/scripts/generate_sample_configurations.py
@@ -148,6 +148,10 @@ def read_configuration(path: str) -> Dict:
KEY_AADP_PADDING_FREE = "aadp-padding-free"
KEY_AADP_MULTIPACK = "aadp-multipack"
KEY_FAST_KERNELS = "foak-fast-kernels"
+KEY_SCATTERMOE_EP1 = "moe-scattermoe-ep1"
+KEY_SCATTERMOE_EP2 = "moe-scattermoe-ep2"
+KEY_SCATTERMOE_EP4 = "moe-scattermoe-ep4"
+KEY_SCATTERMOE_EP8 = "moe-scattermoe-ep8"
CONFIGURATIONS = {
KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml",
@@ -173,6 +177,19 @@ def read_configuration(path: str) -> Dict:
KEY_AADP_PADDING_FREE: "plugins/attention-and-distributed-packing/configs/padding_free.yaml",
KEY_AADP_MULTIPACK: "plugins/attention-and-distributed-packing/configs/multipack.yaml",
KEY_FAST_KERNELS: "plugins/fused-ops-and-kernels/configs/fast_kernels.yaml",
+ KEY_SCATTERMOE_EP1: "plugins/accelerated-moe/configs/scattermoe.yaml",
+ KEY_SCATTERMOE_EP2: (
+ "plugins/accelerated-moe/configs/scattermoe.yaml",
+ [("training.moe.scattermoe.ep_degree", 2)],
+ ),
+ KEY_SCATTERMOE_EP4: (
+ "plugins/accelerated-moe/configs/scattermoe.yaml",
+ [("training.moe.scattermoe.ep_degree", 4)],
+ ),
+ KEY_SCATTERMOE_EP8: (
+ "plugins/accelerated-moe/configs/scattermoe.yaml",
+ [("training.moe.scattermoe.ep_degree", 8)],
+ ),
}
# list of (tag, combi) tuples
@@ -192,7 +209,18 @@ def read_configuration(path: str) -> Dict:
("accelerated-peft-autogptq-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)),
("accelerated-peft-bnb-nf4-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_BNB_NF4, KEY_BNB_NF4_FOAK)),
("aadp-padding-free-multipack", (KEY_AADP_PADDING_FREE, KEY_AADP_MULTIPACK)),
- ("foak-fast-kernels", (KEY_FAST_KERNELS,))
+ ("foak-fast-kernels", (KEY_FAST_KERNELS,)),
+ ("moe-scattermoe-granite-ep1", (KEY_SCATTERMOE_EP1,)),
+ ("moe-scattermoe-granite-ep1-padding-free", (KEY_AADP_PADDING_FREE, KEY_SCATTERMOE_EP1,)),
+ ("moe-scattermoe-granite-ep1-padding-free-foak", (KEY_AADP_PADDING_FREE, KEY_FAST_KERNELS, KEY_SCATTERMOE_EP1,)),
+ ("moe-scattermoe-granite-ep2", (KEY_SCATTERMOE_EP2,)),
+ ("moe-scattermoe-granite-ep2-padding-free", (KEY_AADP_PADDING_FREE, KEY_SCATTERMOE_EP2,)),
+ ("moe-scattermoe-granite-ep2-padding-free-foak", (KEY_AADP_PADDING_FREE, KEY_FAST_KERNELS, KEY_SCATTERMOE_EP2,)),
+ ("moe-scattermoe-granite-ep4", (KEY_SCATTERMOE_EP4,)),
+ ("moe-scattermoe-granite-ep4-padding-free", (KEY_AADP_PADDING_FREE, KEY_SCATTERMOE_EP4,)),
+ ("moe-scattermoe-granite-ep4-padding-free-foak", (KEY_AADP_PADDING_FREE, KEY_FAST_KERNELS, KEY_SCATTERMOE_EP4,)),
+ ("moe-scattermoe-granite-ep8", (KEY_SCATTERMOE_EP8,)),
+ ("moe-scattermoe-granite-ep8-foak", (KEY_FAST_KERNELS, KEY_SCATTERMOE_EP8,)),
]
diff --git a/tox.ini b/tox.ini
index d29f1f24..a62ae961 100644
--- a/tox.ini
+++ b/tox.ini
@@ -29,7 +29,7 @@ commands =
# need a version of fms-hf-tuning that has integrated the framework
 # NOTE: have to install this first because it hasn't been merged
# - this repo has a lot of pins, so we just install it first
- pip install "fms-hf-tuning[flash-attn] @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@"{env:FHT_BRANCH:main}
+ pip install "fms-hf-tuning @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@"{env:FHT_BRANCH:main}
# some models need this for tokenizers
pip install protobuf
@@ -39,6 +39,10 @@ commands =
python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-peft
python -m fms_acceleration.cli install -e {toxinidir}/plugins/fused-ops-and-kernels
python -m fms_acceleration.cli install -e {toxinidir}/plugins/attention-and-distributed-packing
+ python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-moe
+
+    # install flash-attn last
+ pip install flash-attn
# run the benchmark script
bash scripts/run_benchmarks.sh {posargs:"1 2" "4 8" benchmark_outputs}