diff --git a/.github/workflows/build-and-publish.yml b/.github/workflows/build-and-publish.yml index 307ade0e..9592fcfb 100644 --- a/.github/workflows/build-and-publish.yml +++ b/.github/workflows/build-and-publish.yml @@ -15,6 +15,7 @@ jobs: - "accelerated-peft" - "fused-ops-and-kernels" - "attention-and-distributed-packing" + - "accelerated-moe" permissions: id-token: write # IMPORTANT: this permission is mandatory for trusted publishing diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index d2f9aea6..8f25a613 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -30,6 +30,7 @@ jobs: - "accelerated-peft" - "fused-ops-and-kernels" - "attention-and-distributed-packing" + - "accelerated-moe" steps: - uses: actions/checkout@v4 diff --git a/README.md b/README.md index 1158550c..3bdda440 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Plugin | Description | Depends | License | Status [accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface
AutoGPTQ | Apache 2.0
MIT | Alpha [fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta [attention-and-distributed-packing](./plugins/attention-and-distributed-packing/README.md) | Padding-Free Flash Attention Computation | flash-attn | Apache 2.0 | Beta - MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton Kernels and acclerations for Mixture-of-Expert models | | Apache 2.0 | Coming Soon +[accelerated-moe](./plugins/accelerated-moe/README.md) | Triton Kernels for Mixture-of-Expert parallel, inspired by [ScatterMoe](https://github.com/shawntan/scattermoe) and [MegaBlocks](https://github.com/databricks/megablocks) | | Apache 2.0 | Beta ## Usage with FMS HF Tuning diff --git a/plugins/accelerated-moe/.isort.cfg b/plugins/accelerated-moe/.isort.cfg new file mode 100644 index 00000000..7d3762ec --- /dev/null +++ b/plugins/accelerated-moe/.isort.cfg @@ -0,0 +1,10 @@ +[settings] +profile=black +from_first=true +import_heading_future=Future +import_heading_stdlib=Standard +import_heading_thirdparty=Third Party +import_heading_firstparty=First Party +import_heading_localfolder=Local +known_firstparty= +known_localfolder=tuning \ No newline at end of file diff --git a/plugins/accelerated-moe/.pylintrc b/plugins/accelerated-moe/.pylintrc new file mode 100644 index 00000000..32b5dc66 --- /dev/null +++ b/plugins/accelerated-moe/.pylintrc @@ -0,0 +1,649 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. 
They should be base names, not paths. +ignore=CVS,protobufs + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths=.*megablocks + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.9 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. 
If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. 
+exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1100 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. 
+preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + # Added messages + use-symbolic-message-instead, + invalid-name, + missing-class-docstring, + missing-module-docstring, + missing-function-docstring, + consider-using-f-string, + inconsistent-return-statements, + no-member, + too-many-arguments, + too-many-locals, + too-many-branches, + too-many-statements, + cyclic-import, + too-few-public-methods, + protected-access, + fixme, + logging-format-interpolation, + logging-too-many-args, + attribute-defined-outside-init, + abstract-method, + pointless-statement, + wrong-import-order, + duplicate-code, + unbalanced-tuple-unpacking, + unused-argument + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. 
You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=yes + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. 
In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io \ No newline at end of file diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md new file mode 100644 index 00000000..ff1c31b2 --- /dev/null +++ b/plugins/accelerated-moe/README.md @@ -0,0 +1,95 @@ +# FMS Acceleration for Mixture-of-Experts + +This library contains plugins to accelerate finetuning with the following optimizations: +1. Expert-Parallel MoE with Triton Kernels from ScatterMoE, and some extracted from [megablocks](https://github.com/databricks/megablocks). + - Megablocks kernels for `gather` and `scatter` + +## Plugins + +Plugin | Description | Depends | Loading | Augmentation | Callbacks +--|--|--|--|--|-- +[scattermoe](./src/fms_acceleration_moe/framework_plugin_scattermoe.py) | MoE Expert Parallel with Triton Kernels from scattermoe (& megablocks) | ScatterMoE / extracted kernels from megablocks | ✅ | | ✅ + + +## Adding New Models + +Our `ScatterMoe` implementation is a module-swap; to add new models we need to update the specifications in [scattermoe_constants.py](./src/fms_acceleration_moe/utils/scattermoe_constants.py). +- See the code documentation within to understand how to add new models. 
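+
+As a rough illustration of what the specification needs to provide, the sketch below shows how the conversion spec is consumed elsewhere in this plugin (adapted from `checkpoint_utils.py`); the exact spec format and the meaning of each field are documented in `scattermoe_constants.py` itself:
+
+```
+# sketch: the spec is looked up by the architecture names in the model config,
+# and yields (among other things) the router / expert module names and whether
+# the expert weights are stored one expert per checkpoint shard.
+from transformers import PretrainedConfig
+
+from fms_acceleration_moe.utils.scattermoe_constants import (
+    get_scattermoe_conv_spec_from_archs,
+)
+
+config = PretrainedConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
+(
+    _,                    # remaining spec fields (not used in this sketch)
+    router_name,          # name of the router module in the original MoE block
+    expert_name,          # name(s) of the expert weights in the original MoE block
+    __,
+    sharded_expert_ckpt,  # whether experts are saved one expert per shard
+) = get_scattermoe_conv_spec_from_archs(config.architectures)
+```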
+
+### Using ScatterMoE Saved Checkpoints
+
+`ScatterMoE` checkpoints are saved using `torch.distributed.checkpoint` (DCP), which defaults to `StateDictType.SHARDED_STATE_DICT`:
+- `DTensors` have limited support for full state dicts.
+- sharded state dicts are extremely efficient and require little communication overhead when saving.
+
+We provide a script to recover the original checkpoint:
+- currently the script has only been tested in the case where DCP saved the model on a single node.
+
+If the checkpoint is stored in `hf/checkpoint-10`, call the following to have the converted checkpoint written into `output_dir`:
+
+```
+python -m fms_acceleration_moe.utils.checkpoint_utils \
+    hf/checkpoint-10 output_dir \
+    mistralai/Mixtral-8x7B-Instruct-v0.1
+```
+
+A sketch of loading the converted checkpoint is given at the end of this README.
+
+## Code Extracted from Megablocks
+
+Notes on code extraction:
+- we have only extracted two `autograd` functions, [GatherOp](https://github.com/databricks/megablocks/blob/main/megablocks/ops/gather.py) and [ScatterOp](https://github.com/databricks/megablocks/blob/main/megablocks/ops/scatter.py),
+- and the associated triton kernels from [backend/kernels.py](https://github.com/databricks/megablocks/blob/main/megablocks/backend/kernels.py); mostly `_padded_copy`.
+
+## Running Benchmarks
+
+Run the following from the top-level directory of this repo:
+- the `scattermoe` dependency is not included by default, so the `-x` switch installs it.
+- consider disabling the `torch` memory logging to see improved speeds.
+
+```
+tox -e run-benches \
+    -x testenv:run-benches.deps+="-r plugins/accelerated-moe/requirements-khd.txt" \
+    -x testenv:run-benches.setenv+="MEMORY_LOGGING=nvidia" \
+    -- \
+    "1 2 4" 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-scatter
+```
+or run the larger `Mixtral-8x7B` bench:
+```
+tox ... \
+    8 128 benchmark_outputs scenarios-moe.yaml accelerated-moe-scatter-mixtral
+```
+
+NOTE: if a `FileNotFoundError` is observed on the *triton cache*, similar to issues like these:
+- https://github.com/triton-lang/triton/issues/2688
+
+then `tox` is causing problems with triton and multiprocessing (there is some race condition).
+The workaround is to first *activate the tox env* and run the script manually in `bash`:
+```
+# if FileNotFoundError in the triton cache is observed
+# - then activate the env and run the script manually
+
+source .tox/run-benches/bin/activate
+bash scripts/run_benchmarks.sh \
+    ....
+```
+
+### Triton Kernel Dependencies
+
+Currently we do not copy the `scattermoe` kernels into this repository, so this is an additional manual install:
+
+```
+# this will install the kernel-hyperdrive fork with the scattermoe triton kernels
+pip install -r requirements-khd.txt
+```
+
+### Known Issues
+
+These are some known issues not yet resolved:
+- the dependency on an external `kernel-hyperdrive` repository should eventually be removed.
+- currently only loading of *sharded* `safetensor` non-GGUF MoE checkpoints is supported. This is a reasonable assumption since MoE checkpoints are typically above the size limit that prevents them from being saved into a single checkpoint file.
+- when used together with FSDP, FSDP's `clip_grad_norm` will not compute properly for `ScatterMoE`, see [issue here](https://github.com/foundation-model-stack/fms-acceleration/issues/109).
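+
+## Loading a Converted Checkpoint
+
+The conversion script above writes the recovered state dict with a single `torch.save`. A minimal sketch of how it might then be loaded back into the original architecture (paths are illustrative, and memory considerations for large MoE models are ignored here):
+
+```
+# minimal sketch: load the state dict written by the conversion script back
+# into the original (un-swapped) architecture. "output_dir" is the same path
+# that was passed as the second argument to the conversion script.
+import torch
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained(
+    "mistralai/Mixtral-8x7B-Instruct-v0.1", torch_dtype=torch.bfloat16
+)
+state_dict = torch.load("output_dir", map_location="cpu")
+model.load_state_dict(state_dict)
+model.save_pretrained("converted-checkpoint")
+```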
+ + + diff --git a/plugins/accelerated-moe/configs/scattermoe.yaml b/plugins/accelerated-moe/configs/scattermoe.yaml new file mode 100644 index 00000000..63623694 --- /dev/null +++ b/plugins/accelerated-moe/configs/scattermoe.yaml @@ -0,0 +1,16 @@ +training: + + # mixture-of-experts configurations + moe: + + # expert-parallel for MoE + scattermoe: + + # The level of expert parallel sharding. + # - 1 means no sharding + # - if > 1, please ensure that this divides the world_size. This is because + # the devices will be replicated for every ep_degree devices, and + # the experts will be sharded within each group. + # - if > 1, also ensure that it divides the number of experts, as each device + # will then have num_of_experts / ep_degree experts. + ep_degree: 1 \ No newline at end of file diff --git a/plugins/accelerated-moe/pyproject.toml b/plugins/accelerated-moe/pyproject.toml new file mode 100644 index 00000000..5d522425 --- /dev/null +++ b/plugins/accelerated-moe/pyproject.toml @@ -0,0 +1,29 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "fms-acceleration-moe" +version = '0.0.1' +description = "FMS Acceleration Plugin for Mixture-of-Experts" +authors = [ + {name = "Fabian Lim", email = "flim@sg.ibm.com"}, +] +license = {text = "Apache-2.0"} +readme = "README.md" +requires-python = "~=3.9" +keywords = ['fms-hf-tuning', 'acceleration', 'mixture-of-experts', 'scattermoe', 'megablocks'] +classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] + +[tool.hatch.build.targets.wheel] +only-include = ["src/fms_acceleration_moe"] + +[tool.hatch.build.targets.wheel.sources] +"src" = "" diff --git a/plugins/accelerated-moe/requirements-khd.txt b/plugins/accelerated-moe/requirements-khd.txt new file mode 100644 index 00000000..497bf78e --- /dev/null +++ b/plugins/accelerated-moe/requirements-khd.txt @@ -0,0 +1,2 @@ +# fork of https://github.com/mayank31398/kernel-hyperdrive/ +kernel-hyperdrive @ git+https://github.com/fabianlim/kernel-hyperdrive.git \ No newline at end of file diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py new file mode 100644 index 00000000..a1b41417 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/__init__.py @@ -0,0 +1,17 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# Local +from .framework_plugin_scattermoe import ScatterMoEAccelerationPlugin diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py new file mode 100644 index 00000000..148a5488 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py @@ -0,0 +1,123 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from typing import Dict + +# Third Party +from fms_acceleration import AccelerationPlugin +from transformers import AutoModelForCausalLM +import torch + +# Local +from .utils import ( + patch_huggingface_save_and_load_for_dtensors, + patch_torch_optim_foreach_to_not_apply_to_dtensors, + prepare_scattermoe, +) + + +# pylint: disable=too-many-instance-attributes +class ScatterMoEAccelerationPlugin(AccelerationPlugin): + + # NOTE: we cannot do + # - require_packages = {"khd"} + # this is because the khd fork is not properly packaged as a PyPI project, and so + # - "importlib.util.find_spec('khd')" returns, but + # - "importlib.metadata.version('kernel-hyperdrive')" does not return + # if we decide to extract the kernels, then we do not need to anymore, + # https://github.com/foundation-model-stack/fms-acceleration/issues/105 + + restricted_model_archs = ["GraniteMoeForCausalLM", "MixtralForCausalLM"] + + def __init__(self, configurations: Dict[str, Dict]): + super().__init__(configurations) + + # ep_degree determines the expert parallel sharding + # - default of 1 means experts are not sharded and operate in pure replication. + self._ep_degree = self._check_config_and_maybe_check_values( + key="training.moe.scattermoe.ep_degree", + default=1, + ) + + @property + def requires_custom_loading(self): + return True + + def model_loader(self, model_name: str, **kwargs): + + # load the model + model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) + + rank, world_size = 0, 1 + if torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # shard the MOE, and store the component names, eventually needed + # to configure the FSDP + self._moe_component_module_names = prepare_scattermoe( + model, + checkpoint_name_or_path=model_name, + rank=rank, + world_size=world_size, + ep_degree=self._ep_degree, + mixed_precision=False, # Currently this is hardcoded to OFF + ) + + # NOTE: there is currently no good way to get the mixed precision + # flag from train_args. It will be better to handle this if + # when we move the sharding to augmentation. 
+ # https://github.com/foundation-model-stack/fms-acceleration/issues/103 + + return model + + def get_callbacks_and_ready_for_train( + self, model: torch.nn.Module = None, accelerator=None + ): + + callbacks = [] + if ( + accelerator is not None + and getattr(accelerator.state, "fsdp_plugin", None) is not None + ): + + # - use an internal function call to get the no split + # module names, which are typically layers + _layers = model._get_no_split_modules("") + accelerator.state.fsdp_plugin.ignored_modules = [ + getattr(layer, name) + for name in self._moe_component_module_names + for layer in model.modules() + if layer.__class__.__name__ in _layers + ] + + # call this to patch the HF save and load functions to be able + # to save DTensors propery + patch_huggingface_save_and_load_for_dtensors() + + # call this to patch torch optim to not use + # foreach for dtensors + patch_torch_optim_foreach_to_not_apply_to_dtensors() + + return callbacks + + +# register +AccelerationPlugin.register_plugin( + ScatterMoEAccelerationPlugin, + configuration_and_paths=[ + "training.moe.scattermoe", + ], +) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py new file mode 100644 index 00000000..660d2252 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py @@ -0,0 +1,43 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Local +from .checkpoint_utils import patch_huggingface_save_and_load_for_dtensors +from .scattermoe_prepare import prepare_scattermoe + +# this is a special patch function to disable foreach for +# dtensors, which has been introduced since torch 2.4. +# The reason is because this will cause problems in the optimizer +# RuntimeError: aten._foreach_mul_.Scalar: got mixed torch.Tensor and DTensor, +# need to convert all torch.Tensor to DTensor before calling distributed operators! + + +# - this function patches torch +def patch_torch_optim_foreach_to_not_apply_to_dtensors(): + # guarded. + # this is an array of supported types, we will remove + # dtensor from it, so the optimizer will faillback to per + # parameter + # Third Party + # pylint: disable=import-outside-toplevel + from torch.optim.optimizer import _foreach_supported_types + + i = 0 # list index + while i < len(_foreach_supported_types): + x = _foreach_supported_types[i] + if x.__name__ == "DTensor": + # pop from list + _foreach_supported_types.pop(i) + else: + i += 1 diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py new file mode 100644 index 00000000..d8d33b18 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -0,0 +1,468 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from collections import defaultdict +from typing import List +import json +import os +import re + +# Third Party +from accelerate.logging import get_logger +from accelerate.utils.constants import FSDP_MODEL_NAME, OPTIMIZER_NAME +from torch.distributed.checkpoint.default_planner import ( + DefaultLoadPlanner, + DefaultSavePlanner, +) +from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict +from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType +from transformers import PretrainedConfig +import torch +import torch.distributed.checkpoint as dcp + +# Local +from .scattermoe_constants import ( + FILE_SAFETENSOR_INDEX, + PARAM_NAME_ROUTER_SCATTERMOE, + PARAM_NAME_WEIGHT_SCATTERMOE, + get_scattermoe_conv_spec_from_archs, +) +from .scattermoe_state_dict import get_checkpoint_meta_from_sharded_safetensor + +logger = get_logger(__name__) + +# - variable to capture the model variable +# in the save/load model calls +MODEL_INDEX = None +KEY_MODEL = "model" +KEY_OPTIMIZER = "optimizer" + +# Below are rewrite of HF FSDP model saving functions to be able to handle +# that the parameters are now a mixture of regular and Dtensors. +# - these functions are found in accelerate.utils.fsdp_utils.py +# - save_fsdp_model, save_fsdp_optimizer, load_fsdp_model, load_fsdp_optimizer +# NOTE: we will observe warnings such as +# /torch/distributed/checkpoint/state_dict.py:520: +# FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor. + + +# rewrite of func from accelerate.utils.fsdp_utils.py +# - empty function, the main logic will be in save_fsdp_optimizer (see below). +def save_fsdp_model( + fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False +): + # pylint: disable=global-statement + global MODEL_INDEX + MODEL_INDEX = model_index + + +# rewrite of func from accelerate.utils.fsdp_utils.py +# - saves both model and optimizer at the same time +def save_fsdp_optimizer( + fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0 +): + + if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT: + raise NotImplementedError( + "Checkpointing for megablocks only enabled for sharded state dict." 
+ ) + + # get the state dicts for model and optimize + (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer) + + # - save model + ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}") + os.makedirs(ckpt_model, exist_ok=True) + logger.info(f"Saving model to {ckpt_model}") + dcp.save( + state_dict={KEY_MODEL: model_state_dict}, + storage_writer=dcp.FileSystemWriter(ckpt_model), + planner=DefaultSavePlanner(), + ) + logger.info(f"Model saved to {ckpt_model}") + + # - save optimizer + ckpt_opt = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}") + os.makedirs(ckpt_opt, exist_ok=True) + logger.info(f"Saving Optimizer state to {ckpt_opt}") + dcp.save( + state_dict={KEY_OPTIMIZER: optimizer_state_dict}, + storage_writer=dcp.FileSystemWriter(ckpt_opt), + planner=DefaultSavePlanner(), + ) + logger.info(f"Optimizer state saved in {ckpt_opt}") + + +# rewrite of func from accelerate.utils.fsdp_utils.py +# - empty function, main logic in load_fsdp_optimizer (see below). +def load_fsdp_model( + fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False +): + # pylint: disable=global-statement + global MODEL_INDEX + MODEL_INDEX = model_index + + +# rewrite of func from accelerate.utils.fsdp_utils.py +# - loads both model and optimizer +def load_fsdp_optimizer( + fsdp_plugin, + accelerator, + optimizer, + model, + input_dir, + optimizer_index=0, + adapter_only=False, +): + + accelerator.wait_for_everyone() + if fsdp_plugin.state_dict_type != StateDictType.SHARDED_STATE_DICT: + raise NotImplementedError( + "Checkpointing for megablocks only enabled for sharded state dict." + ) + + # - get the state dicts + model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) + + # - load the model state dict + ckpt_model = os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}") + dcp.load( + state_dict={KEY_MODEL: model_state_dict}, + storage_reader=dcp.FileSystemReader(ckpt_model), + planner=DefaultLoadPlanner(), + ) + + # - load the optimizer state dict + ckpt_opt = os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}") + dcp.load( + state_dict={KEY_OPTIMIZER: optimizer_state_dict}, + storage_reader=dcp.FileSystemReader(ckpt_opt), + planner=DefaultLoadPlanner(), + ) + + # - set the state dicts + set_state_dict( + model, + optimizer, + model_state_dict=model_state_dict, + optim_state_dict=optimizer_state_dict, + ) + + # FIXME: + # - We see errors that occur in optimizer.step() + # - torch/optim/optimizer.py", line 89, in _use_grad + # - torch/optim/adamw.py", line 214, in step beta1, + # beta2 = cast(Tuple[float, float], group["betas"]) + # - KeyError: 'betas' + # - Fortunately, this seems to be limited to the empty groups case, where + # it seems that it is just the params are not initialized. Since we suppose + # these groups are never used, we simply initialize the empty groups with + # random values so the errors do not throw. 
+ for group in optimizer.param_groups: + if len(group["params"]) == 0: + group["betas"] = (0.9, 0.999) + group["lr"] = 0.0 + group["initial_lr"] = 0.0 + group["eps"] = 1e-8 + group["weight_decay"] = 0.0 + + +# function to replace various trainer functions in HF with the ones +# above +def patch_huggingface_save_and_load_for_dtensors(): + # Third Party + # NOTE: this is really a global replacement, which we use the patcher + # to do + # pylint: disable=import-outside-toplevel + from fms_acceleration.model_patcher import patch_target_module + + patch_target_module("transformers.trainer.save_fsdp_model", save_fsdp_model) + patch_target_module("transformers.trainer.save_fsdp_optimizer", save_fsdp_optimizer) + patch_target_module("transformers.trainer.load_fsdp_model", load_fsdp_model) + patch_target_module("transformers.trainer.load_fsdp_optimizer", load_fsdp_optimizer) + + +# this function implements a trick to get the resolved cache file to acccess the safetensor +# - NOTE: does not work if _dict_from_json_file is not called, such as in the case of GGUF files. +def get_resolved_checkpoint_location(model_name_or_path: str): + + result = None + _old_func = PretrainedConfig._dict_from_json_file + + def _dict_from_json_file(resolved_config_file): + nonlocal result + result = resolved_config_file + return _old_func(resolved_config_file) + + # make a hook and restrive + PretrainedConfig._dict_from_json_file = _dict_from_json_file + PretrainedConfig.from_pretrained(model_name_or_path) + PretrainedConfig._dict_from_json_file = _old_func + return os.path.dirname(result) + + +# function to get the ScatterMoE state dict from its DCP checkpoint +# - if the original pretrained_model_name_or_path is specified, will use the checkpoint as hints +# to map the ScatterMoE checkpoint to that of the original model. This is useful so that we +# can restore the checkpoint to be loaded by the original architecture. +def recover_original_state_dict_from_dcp_checkpoint( + dcp_checkpoint_dir: str, + pretrained_model_name_or_path: str = None, +): + """ + Parameters: + dcp_checkpoint_dir (str): the DCP to be converted. 
+        pretrained_model_name_or_path (str): Optional. If provided, hints from
+            the original checkpoint are used to remap the ScatterMoE state dict
+            back to the original model's parameter names.
+    """
+
+    # reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
+    # - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap
+
+    # guarded, load some internal functions
+    # pylint: disable=import-outside-toplevel
+    # Third Party
+    from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
+    from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
+
+    sd: STATE_DICT_TYPE = {}
+    _load_state_dict(
+        sd,
+        storage_reader=dcp.FileSystemReader(dcp_checkpoint_dir),
+        planner=_EmptyStateDictLoadPlanner(),
+        no_dist=True,
+    )
+    sd = sd[KEY_MODEL]
+
+    # if not provided, return the state dict as-is
+    if pretrained_model_name_or_path is None:
+        return sd
+
+    # now do the remap
+    loc = get_resolved_checkpoint_location(pretrained_model_name_or_path)
+    with open(os.path.join(loc, FILE_SAFETENSOR_INDEX), encoding="utf-8") as f:
+        index = json.load(f)
+
+    # config
+    config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path)
+
+    (
+        _,
+        router_name,
+        expert_name,
+        __,
+        sharded_expert_ckpt,
+    ) = get_scattermoe_conv_spec_from_archs(config.architectures)
+
+    # the sd from the module swap must have keys like
+    # 'model.layers.0.block_sparse_moe.w1.weight'
+    # 'model.layers.0.block_sparse_moe.w2.weight'
+    # 'model.layers.0.block_sparse_moe.router.weight'
+    # so we use this fact to infer that
+    # prefix = model.layers.0 and module_name = block_sparse_moe
+
+    def _infer_prefixes_and_module_names(
+        sd_keys: List[str],
+        min_count: int = 3,
+    ):
+        _name = "|".join([PARAM_NAME_ROUTER_SCATTERMOE, *PARAM_NAME_WEIGHT_SCATTERMOE])
+        # pylint: disable=anomalous-backslash-in-string
+        _reg = re.compile(f"(.*)\.({_name})\.weight")
+        found = {}
+
+        for k in sd_keys:
+            m = _reg.match(k)
+            if m is None:
+                continue
+
+            prefix, _ = m.groups()
+            found[prefix] = 1 + found.get(prefix, 0)
+
+        results = []
+        for prefix, cnt in found.items():
+            # if at least the router, w1 and w2 are found, take the prefix,
+            # otherwise we skip it
+            if cnt >= min_count:
+                results.append(prefix)
+
+        return results
+
+    for prefix in _infer_prefixes_and_module_names(sd.keys()):
+        prefix = prefix.split(".")
+        prefix, module_name = ".".join(prefix[:-1]), prefix[-1]
+
+        # the checkpoint metadata will be a map
+        # key -> list of tuples
+        # where each entry in the list is (param_name, safetensor file)
+        # - if the list is larger than one, it means that the
+        # actual model has a sharded checkpoint
+
+        # defaultdict(list,
+        #     {'w1.weight': [('model.layers.0.block_sparse_moe.input_linear.weight',
+        #          'model-00001-of-00002.safetensors')],
+        #      'w3.weight': [('model.layers.0.block_sparse_moe.input_linear.weight',
+        #          'model-00001-of-00002.safetensors')],
+        #      'w2.weight': [('model.layers.0.block_sparse_moe.output_linear.weight',
+        #          'model-00001-of-00002.safetensors')],
+        #      'router.weight': [('model.layers.0.block_sparse_moe.router.layer.weight',
+        #          'model-00001-of-00002.safetensors')]})
+
+        checkpoint_metadata = get_checkpoint_meta_from_sharded_safetensor(
+            index["weight_map"],
+            prefix,
+            module_name,
+            router_name,
+            expert_name,
+        )
+
+        model2scatter = defaultdict(dict)
+        # construct a map of model_key -> {scatter_key: param}
+        # - if a model_key collects more than one scatter param, the params
+        # need to be concatenated back into the single model param
+        for scatter_key, list_of_params in checkpoint_metadata.items():
+            scatter_key_fqdn = ".".join([prefix, module_name, scatter_key])
+            scatter_param = sd[scatter_key_fqdn]
+
+            # remove from state dict
+            del sd[scatter_key_fqdn]
+
+            n = len(list_of_params)
+            if scatter_key.startswith(PARAM_NAME_ROUTER_SCATTERMOE):
+                assert n == 1, "Router parameters should not be sharded."
+            elif not sharded_expert_ckpt:
+                assert n == 1, "Expert weights expected to be non-sharded."
+            else:
+                # if sharded, we just assume that there should be 1 expert
+                # per shard
+                assert (
+                    n == scatter_param.shape[0]
+                ), "Sharded expert weights should be 1 expert per shard."
+
+            if any(scatter_key.startswith(k) for k in PARAM_NAME_WEIGHT_SCATTERMOE):
+                scatter_param = scatter_param.permute(0, 2, 1)
+
+            # go through all the model keys
+
+            for i, (model_key, _) in enumerate(list_of_params):
+                if n == 1:
+                    # handles routers and non-sharded experts case
+                    _param = scatter_param
+                else:
+                    # then it needs to be sharded
+                    _param = scatter_param[i]
+
+                model2scatter[model_key][scatter_key] = _param
+
+        # replace them back in the sd
+        for model_key in list(model2scatter.keys()):
+
+            scatter_params = model2scatter[model_key]
+
+            # - there is an assumption that if there is a cat, then
+            # it will go in order of the scatter keys
+            scatter_keys = sorted(scatter_params.keys())
+
+            assert (
+                len(scatter_keys) > 0
+            ), f"Obtained zero scatter keys for model_key '{model_key}'"
+
+            if len(scatter_keys) == 1:
+                sd[model_key] = scatter_params[scatter_keys[0]]
+            else:
+                # unfortunately,
+                # scattermoe_state_dict._maybe_reshape_scattermoe_expert_weights
+                # splits on dim=1, so we cat back on that dim
+                sd[model_key] = torch.cat(
+                    [scatter_params[k] for k in scatter_keys], dim=1
+                )
+
+            # remove from this intermediate mapping
+            del model2scatter[model_key]
+
+        rem_keys = ",".join(list(model2scatter))
+        assert len(rem_keys) == 0, f"Did not handle model parameters '{rem_keys}'"
+
+    return sd
+
+
+# --------------------------- SCRIPT -------------------------
+
+
+# have it serve as a conversion script
+if __name__ == "__main__":
+    # Standard
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description=(
+            "Utility for converting a ScatterMoE checkpoint back to the "
+            "original state dict format. "
+            "The ScatterMoE checkpoint was saved after the pretrained model "
+            "had been converted by a module swap, hence the state dict will "
+            "no longer resemble the original. This utility recreates the "
+            "original state dict."
+        )
+    )
+
+    parser.add_argument(
+        "dcp_checkpoint_dir",
+        help="Path to the distributed checkpoint.",
+    )
+
+    parser.add_argument(
+        "output_dir", help="Path to the location to write the converted checkpoint."
+    )
+
+    parser.add_argument(
+        "pretrained_model_name_or_path",
+        help=(
+            "In order to reconstruct the state dict, we require hints from "
+            "the original pretrained model checkpoint (from which this "
+            "checkpoint is obtained)."
+        ),
+    )
+
+    args = parser.parse_args()
+
+    # search for the checkpoint. By the code above, it must
+    # start with FSDP_MODEL_NAME
+    if args.dcp_checkpoint_dir.startswith(FSDP_MODEL_NAME):
+        checkpoint_dir = args.dcp_checkpoint_dir
+    else:
+        checkpoint_dir = [
+            x
+            for x in os.listdir(args.dcp_checkpoint_dir)
+            if os.path.isdir(os.path.join(args.dcp_checkpoint_dir, x))
+            and x.startswith(FSDP_MODEL_NAME)
+        ]
+        if len(checkpoint_dir) > 1:
+            raise ValueError(
+                f"Found > 1 dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} "
+                f"that starts with {FSDP_MODEL_NAME}. Please specify the exact dir."
+ ) + if len(checkpoint_dir) == 0: + raise ValueError( + f"Found no dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} " + f"that starts with {FSDP_MODEL_NAME}. Nothing to convert" + ) + checkpoint_dir = os.path.join(args.dcp_checkpoint_dir, checkpoint_dir[0]) + + # get the converted statedict + state_dict = recover_original_state_dict_from_dcp_checkpoint( + checkpoint_dir, args.pretrained_model_name_or_path + ) + + # save it + torch.save(state_dict, args.output_dir) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py new file mode 100644 index 00000000..152bee1d --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -0,0 +1,507 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from typing import Tuple + +# Third Party +from peft import LoraConfig +from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND +from torch.distributed._tensor import DTensor + +# pylint: disable=import-error +from torch.distributed._tensor.device_mesh import DeviceMesh +from transformers.activations import ACT2FN +import torch +import torch.nn.functional as F + +try: + # Third Party + from khd.kernels.scattermoe.triton_implementation.ops import ( + padded_block_indices, + scattered_experts, + ) +except ImportError as e: + raise ImportError( + "kernel-hyperdrive PyPI package not found. Install it: " + "pip install -r plugins/accelerated-moe/requirements-scattermoe.txt" + ) from e + +# Local +from .scattermoe_constants import SCATTERMOE_SPEC_HAS_GATE +from .scattermoe_utils import all_to_all_gather_inputs, scatter_with_routing_weights + + +# helper function to fetch the local tensor if its a dtensor +def _maybe_get_local_tensor(weight: torch.Tensor): + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +class ScatteredExperts(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + num_experts: int, + fan_out: int, + grouped_in: bool = False, + grouped_out: bool = False, + dtype: torch.dtype = torch.bfloat16, + device: torch.device = torch.device("cpu"), + lora_config: LoraConfig = None, + ): + """ + ScatteredExperts is the module that implements a group of experts. The + forward function will call scattermoe triton kernels. + + NOTE: in the current implementation, we do not initialize the parameters. + We assume this will be done outside. + + Paramters: + in_features (int): num of input features per expert. + out_features (int): num of output features per expert. + num_experts (int): the number of experts. + fan_out (int): if the number of embedding inputs are expected to be + a factor less than the bind_ids and indices at the forward. + grouped_in (bool): if the embedding inputs are expected to be already + grouped in at the forward. + grouped_out (bool): if the outputs are expected to be grouped + when they are returned from the forward. 
+ dtype (torch.dtype): the dtype of the parameter tensors. Note, for now the + adapter takes the same dtype as base layer if LoRA is enabled. + device (torch.device): the cuda device in which the model should be loaded. + Only cuda is supported since only triton kernels are supported. + lora_config (peft.LoraConfig): Optional, to be passed if lora is to be used. + """ + super().__init__() + + # parameters the experts (not initialized). + self.weight = torch.nn.Parameter( + torch.empty( + num_experts, + in_features, + out_features, + dtype=dtype, + device=device, + ), + requires_grad=True, + ) + + # handle lora embeddings + self.lora_A, self.lora_B = None, None + self.lora_r = 0 + if lora_config is not None: + # if LoRA, then disable gradient for the base layer. + self.weight.requires_grad = False + + # NOTE : - for now adapter takes same dtype as base + self.lora_A = torch.nn.Parameter( + torch.empty( + num_experts, + in_features, + lora_config.r, + dtype=dtype, + device=device, + ), + requires_grad=True, + ) + self.lora_B = torch.nn.Parameter( + torch.empty( + num_experts, + lora_config.r, + out_features, + dtype=dtype, + device=device, + ), + requires_grad=True, + ) + self.lora_r = lora_config.r + + # store these settings + self.fan_out = fan_out + self.grouped_in = grouped_in + self.grouped_out = grouped_out + + def forward( + self, + x: torch.Tensor, + bin_ids: torch.Tensor, + indices: torch.Tensor, + padded_block_idxs: torch.Tensor, + expert_offsets: torch.Tensor, + gates: torch.Tensor = None, + ): + """ + ScatteredExperts executes grouped forwards where each group is a single expert. + + Parameters: + x (tensor): the emebeddings fed as input. + bin_ids (tensor): the expert index where each embedding is to be sent. + Expect that these indices are sorted. + indices (tensor): the sorting index that brings the input embeddings to the + sorted order corresponding to bin_ids. + padded_block_idxs (tensor): the indices for passing triton block info to the + scattermoe kernels. + expert_offsets (tensor): the offsets for passing triton block info to the + scattermoe kernels. + gates (tensor): Optional. the weighting coefficients that should be applied + at the output of the scattermoe kernels. + """ + weight = _maybe_get_local_tensor(self.weight) + lora_A, lora_B = None, None + if self.lora_r > 0: + lora_A, lora_B = ( + _maybe_get_local_tensor(self.lora_A), + _maybe_get_local_tensor(self.lora_B), + ) + + # NOTE: x is of shape (seqlen, in_features) + # bin_ids is of shape (seqlen,) + # padded_block_idxs is a 1-dim tensor, whose length depends on + # triton kernel block size and input. 
+ # expert_offsets is of shape (num_experts, ) + return scattered_experts( + x, + weight, + self.fan_out, + bin_ids, # sorted_expert_idxs, + indices, # sorted_scattered_idxs, + padded_block_idxs, + expert_offsets, + gates=gates, # we dont have router weights + grouped_in=self.grouped_in, + grouped_out=self.grouped_out, + expert_lora_A=lora_A, + expert_lora_B=lora_B, + lora_alp=self.lora_r, + ) + + +# NOTE: this name should match scattermoe_constants.CLASS_NAME_SCATTERMOE +# similar to of MoE_Triton from https://github.com/mayank31398/kernel-hyperdrive +# and ParameterizedScatteredExperts from +# https://github.com/IBM/dolomite-engine/blob/main/dolomite_engine/hf_models/models/moe_dolomite/moe/scatter.py +# - support expert parallel where the data is communicated via all_to_all +# pylint: disable=too-many-instance-attributes +class ScatterMoE(torch.nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_act: str, + intermediate_size: int, + num_experts: int, + has_bias: bool = False, + mlp_arch: str = None, + top_k: int = 2, + dtype: torch.dtype = torch.bfloat16, + device: str = torch.device("cpu"), + ep_device_mesh: DeviceMesh = None, + lora_config: LoraConfig = None, + ): + """ + ScatterMoE is the module swap that replaces a sparse mixture-of-experts module + in order to run the scatter moe kernels and the all_to_all expert parallel routines. + + The submodules are a i) router (nn.Linear) and ii) w1, w2, ... (ScatteredExperts); + the latter hold the expert weights and run the triton kernels. + + Parameters: + + hidden_size (int): the hidden dimension. + hidden_act (str): the activation fucntion. + intermediate_size (int): the intermediate dimension. + num_experts (int): the number of experts. + has_bias (bool): if the router and experts have bias. + mlp_arch (str): unique key that specifies the MLP architecture, + e.g., if there is a gate forward. + top_k (int): the number of experts each token gets routed to. + dtype (torch.dtype): the dtype of the parameter tensors. + device (torch.device): the cuda device in which the model should be loaded. + Only cuda is supported since only triton kernels are supported. + ep_device_mesh (torch.distributed.DeviceMesh): Optional, to be passed if there is + sharding. Only pass the mesh for the experts. + lora_config (peft.LoraConfig): Optional, to be passed if lora is to be used. + """ + assert ( + not has_bias + ), "ScatterMoE currently unable to handle bias in both gates and experts." + + if lora_config is not None: + # since this is self implemented, we really only support basic lora funcs + assert ( + lora_config.bias == "none" + ), "ScatterMoE currently unable to handle bias in the lora adapters" + assert ( + lora_config.target_modules == INCLUDE_LINEAR_LAYERS_SHORTHAND + or INCLUDE_LINEAR_LAYERS_SHORTHAND in lora_config.target_modules + ), "ScatterMoe currently only handles lora adapters on all linears." + + assert lora_config.init_lora_weights in { + True, + "gaussian", + }, "ScatterMoe currently only handles gaussian initialization." 
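+            # Illustrative note (hypothetical values): a config such as
+            #   LoraConfig(r=16, lora_alpha=32, bias="none",
+            #              target_modules="all-linear",
+            #              init_lora_weights="gaussian")
+            # is expected to satisfy the checks above, since "all-linear" is
+            # the peft shorthand referenced by INCLUDE_LINEAR_LAYERS_SHORTHAND.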
+ + super().__init__() + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts = num_experts + self.hidden_act = hidden_act + self.activation = ACT2FN[hidden_act] + self.top_k = top_k + self.all_to_all = ( + ep_device_mesh.size() > 1 if ep_device_mesh is not None else False + ) + + # NOTE: we should then use this to distribute inside + # and not do the distribution outside + self.expert_parallel_group = ( + ep_device_mesh.get_group(0) if ep_device_mesh is not None else None + ) + + # build the router + self.router = torch.nn.Linear( + in_features=hidden_size, + out_features=num_experts, + bias=has_bias, + dtype=dtype, + device=device, + ) + + # the experts. The architecture may depend on the model + # - w1: the up_projection. + # - w2: the down_projection. + # - w3 (optional): the gate projection. + self.w1 = ScatteredExperts( + in_features=self.hidden_size, + out_features=self.intermediate_size, + num_experts=self.num_experts, + fan_out=self.top_k if not self.all_to_all else 1, + grouped_out=True, + dtype=dtype, + device=device, + lora_config=lora_config, + ) + self.w2 = ScatteredExperts( + in_features=self.intermediate_size, + out_features=self.hidden_size, + num_experts=self.num_experts, + fan_out=1, + grouped_in=True, + dtype=dtype, + device=device, + lora_config=lora_config, + ) + if mlp_arch == SCATTERMOE_SPEC_HAS_GATE: + self.w3 = ScatteredExperts( + in_features=self.hidden_size, + out_features=self.intermediate_size, + num_experts=self.num_experts, + fan_out=self.top_k if not self.all_to_all else 1, + grouped_out=True, + dtype=dtype, + device=device, + lora_config=lora_config, + ) + + # referenced from dolomite-engine + def _compute_routing_weights(self, hidden_states: torch.Tensor): + + # router_logits: (batch * sequence_length, n_experts) + weight = _maybe_get_local_tensor(self.router.weight) + bias = self.router.bias + if bias: + bias = _maybe_get_local_tensor(bias) + # pylint: disable=not-callable + router_logits = F.linear(hidden_states, weight, bias) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + return router_logits, routing_weights, selected_experts + + def _get_expert_idxs_and_maybe_gather( + self, + hidden_states: torch.Tensor, + selected_experts: torch.Tensor, + ): + """ + gets the expert indices, and also gather the hidden_states if + all-to-all processing is required. + + Parameters: + hidden_states (tensor): 2D batch-flattened hidden states. + selected_experts (tensor): indices of experts selected for each + hidden state. + """ + + # megablocks has a cuda kernel for computing a radix sort, but + # just use the torch version + sorted_expert_idxs, sorted_scattered_idxs = torch.sort( + selected_experts.flatten() + ) + if not self.all_to_all: + # in this case, no gathering required for hidden states + return hidden_states, sorted_expert_idxs, sorted_scattered_idxs + + # outputs will: + # - parallel_x: gathered version of hidden_states + # - parallel_bin_ids: gathered version of sorted_expert_idxs, + # - parallel_ind: gathered version of sorted_scattered_idxs. 
+ # + # followed by some counting metrics: + # - send_counts, recv_counts, bins (local) + outputs = all_to_all_gather_inputs( + hidden_states, + selected_experts, + sorted_expert_idxs, + sorted_scattered_idxs, + self.expert_parallel_group, + self.top_k, + self.num_experts, + ) + + return outputs + (sorted_expert_idxs, sorted_scattered_idxs) + + def _maybe_scatter( + self, + hidden_states: torch.Tensor, + routing_weights: torch.Tensor = None, + gather_products: Tuple[torch.Tensor, ...] = None, + ): + """ + maybe undo the earlier scatter operation during all-to-all. + + Parameters: + hidden_states (tensor): batch-flattened hidden states. + routing_weights (tensor): Optional, routing weights for each expert. + gather_products (tensor): Optional, tuple of tensors that would have been + produced by the earlier gather call. + """ + + if not self.all_to_all: + # in this case scattering is already handled by + # scattermoe when computing w2 + # - then there is nothing to do + return hidden_states + + # expect these products to be produced by an earlier + # all-to-all gather call + (send_counts, recv_counts, bins, sorted_expert_idxs, sorted_scattered_idxs) = ( + gather_products + ) + + # perform the scattering with the gather products, + hidden_states = scatter_with_routing_weights( + hidden_states, + routing_weights.flatten(), + send_counts, + recv_counts, + bins, + sorted_expert_idxs, + sorted_scattered_idxs, + self.expert_parallel_group, + self.top_k, + ) + + return hidden_states + + def forward(self, hidden_states: torch.Tensor): + """ + ScatterMoe.forward replaces the forward of the sparse + mixture-of-expert module. + """ + + # flatten the batch dimension + original_shape = hidden_states.shape # take a record + hidden_states = hidden_states.view(-1, self.hidden_size) + + # compute the routing logits, weights, and expert assigments + # - router_logits: will be passed out of forward, used for computing + # routing loss. + (router_logits, routing_weights, selected_experts) = ( + self._compute_routing_weights(hidden_states) + ) + + # get the sorted expert idxs and scattered idxs. + # - if a gather is required, then the hidden-states will be + # communicated from other ranks, and will change. + # - in gather is required, then some _gather_products will be + # required for the scattering later, so return these out also. + ( + hidden_states, + sorted_expert_idxs, + sorted_scattered_idxs, + *_gather_products, + ) = self._get_expert_idxs_and_maybe_gather( + hidden_states, + selected_experts, + ) + + # scattemoe specific computation. + # - padded indicies need to be computed for the scattermoe + # triton kernels. 
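+        # - this block info is data-dependent bookkeeping only, so it is
+        #   computed once under no_grad here and then reused by the w1/w2
+        #   (and w3) kernel calls below.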
+        with torch.no_grad():
+            padded_block_idxs, expert_offsets = padded_block_indices(
+                sorted_expert_idxs, self.num_experts
+            )
+
+        # compute the up projection
+        out = self.w1(
+            hidden_states,
+            sorted_expert_idxs,
+            sorted_scattered_idxs,
+            padded_block_idxs,
+            expert_offsets,
+        )
+        out = self.activation(out)
+
+        # - if the arch has a separate gate projection (guard with getattr in
+        #   case this instance was built without a w3)
+        if getattr(self, "w3", None) is not None:
+            out *= self.w3(
+                hidden_states,
+                sorted_expert_idxs,
+                sorted_scattered_idxs,
+                padded_block_idxs,
+                expert_offsets,
+            )
+
+        # compute the down projection
+        # - if no all-to-all processing, then depend on the
+        #   scattermoe kernel to perform the final scattering
+        hidden_states = self.w2(
+            out,
+            sorted_expert_idxs,
+            sorted_scattered_idxs,
+            padded_block_idxs,
+            expert_offsets,
+            gates=(None if self.all_to_all else routing_weights),
+        )
+
+        # maybe scatter
+        hidden_states = self._maybe_scatter(
+            hidden_states,
+            routing_weights,
+            _gather_products,
+        )
+
+        # return hidden states and router logits
+        return (hidden_states.view(original_shape), router_logits)
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py
new file mode 100644
index 00000000..be3057b5
--- /dev/null
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py
@@ -0,0 +1,94 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from typing import List
+
+# to be updated so that the parsers can work properly
+PARAM_NAME_ROUTER_SCATTERMOE = "router"
+PARAM_NAME_WEIGHT_SCATTERMOE = ["w1", "w2", "w3"]
+
+FILE_SAFETENSOR_INDEX = "model.safetensors.index.json"
+KEY_REPLICATE = "replicate"
+KEY_EXPERT_PARALLEL = "expert_parallel"
+DIM_EXPERT = 0
+
+KEY_SCATTERMOE_ROUTER = PARAM_NAME_ROUTER_SCATTERMOE + ".weight"
+
+# Currently our ScatterMoE drop-in supports an up/down projection and
+# an optional gate projection.
+# - when new architectures are supported, this list will be updated
+SCATTERMOE_SPEC_HAS_GATE = "Gated"
+
+# - moe_cls
+# - router_name
+# - expert_name
+# - weight_spec
+# - sharded experts
+
+# NOTE: it is quite challenging to perform the module swap because
+# incoming MoE models can have quite varied implementations.
+# - hence the adopted strategy is to infer the weights from the
+#   state dict of the incoming model, and map them to the ScatterMoE
+# - the SPEC is a description of some hints to help perform this
+#   state_dict mapping.
+
+# NOTE: there is an expert_map logic which is currently not exposed
+# in the SPEC. The expert_map allows us to map the parameter names
+# if they are different. But so far we do not need to use it.
+
+# NOTE: the keys can be a single arch string, e.g., MixtralForCausalLM,
+# or a few arch strings separated by commas (no spaces).
+
+# when adding new models, follow this convention:
+# - class name of the moe module to be replaced with ScatterMoE.
+# - module_name of the router.
+# - module_name of the experts; this can be specified as a plain +# name or a regex. +# (str): name of the module +# (regex): w1_name|w2_name|w3_name specificy the names if they are different. +# - boolean flag indicating if the experts are sharded in the state dict. +# i.e., meaning the experts exist in seperate 2D Linear modules +# or all "combined" into a single 3D linear module. +SCATTERMOE_CONVERSION_SPEC = { + "MixtralForCausalLM": ( + "MixtralSparseMoeBlock", + "gate", + "experts", + SCATTERMOE_SPEC_HAS_GATE, + True, + ), + "GraniteMoeForCausalLM": ( + "GraniteMoeMoE", + "router", + "input_linear|output_linear|input_linear", + SCATTERMOE_SPEC_HAS_GATE, + False, + ), +} + + +# helper function to get the spec based on architectures +def get_scattermoe_conv_spec_from_archs(architectures: List[str]): + # infer the spec + for archs, spec in SCATTERMOE_CONVERSION_SPEC.items(): + archs = archs.split(",") + if any(x in archs for x in architectures): + return spec + + # if not found + raise ValueError( + f"In order to configure ScatterMoe for archs '{architectures}' " + "the conversion spect must be updated in scattermoe_constants.py" + ) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py new file mode 100644 index 00000000..7dfa607c --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py @@ -0,0 +1,341 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from collections import OrderedDict +from contextlib import nullcontext +import json +import os + +# Third Party +from accelerate import init_empty_weights +from peft import LoraConfig +from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor + +# pylint: disable=import-error +from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh +from tqdm import tqdm +from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0 +import torch + +# Local +from .checkpoint_utils import get_resolved_checkpoint_location +from .scattermoe_constants import ( + FILE_SAFETENSOR_INDEX, + KEY_EXPERT_PARALLEL, + KEY_REPLICATE, + KEY_SCATTERMOE_ROUTER, + get_scattermoe_conv_spec_from_archs, +) +from .scattermoe_state_dict import ( + convert_state_dict, + get_checkpoint_meta_from_sharded_safetensor, + get_state_dict_from_checkpoint_metadata, +) + + +# this function will load the sharded experts onto the device. +# - this assumes that the "dmoe" module is the megablocks.layers.dmoe.dMoE distributed +# implementation of the mixture of experts. 
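+# - concretely, it replicates the router weights across the mesh, shards the
+#   expert weights (or wraps already-sharded parameters as DTensors) along the
+#   expert dimension, and installs a hook that divides expert gradients by the
+#   expert-parallel size.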
+def load_experts_onto_device( + module: torch.nn.Module, + state_dict: OrderedDict, + device_mesh: DeviceMesh, + num_experts_per_device: int, +): + + # hook for scaling the gradient + scaling = device_mesh[KEY_EXPERT_PARALLEL].size() + + def _hook(grad): + if grad is not None: + grad.div_(scaling) + return grad + + # required replication placements + reps = [Replicate() for _ in range(device_mesh.ndim - 1)] + + for weight_name, param in state_dict.items(): + + if KEY_SCATTERMOE_ROUTER in weight_name: + # if its the router, replicate + param = distribute_tensor(param, device_mesh, reps + [Replicate()]) + elif param.shape[0] > num_experts_per_device: + # if its a weight param and the number of experts exceed that of + # the device, shard + param = distribute_tensor(param, device_mesh, reps + [Shard(0)]) + else: + # if its a weight and the already sharded by number of experts + param = DTensor.from_local( + param, device_mesh=device_mesh, placements=reps + [Shard(0)] + ) + + # get the module we want to shard + name = weight_name.split(".") + path, name = ".".join(name[:-1]), name[-1] + mod = module.get_submodule(path) + requires_grad = getattr(mod, name).requires_grad + + param = torch.nn.Parameter( + param, + requires_grad=requires_grad, + ) + + # install gradient scaling hook + if KEY_SCATTERMOE_ROUTER not in weight_name: + param.register_hook(_hook) + + # register the sharded parameter onto the megablocks.dmoe + mod.register_parameter(name, param) + + +def prepare_scattermoe( + model: torch.nn.Module, + checkpoint_name_or_path: str = None, + rank: int = None, + world_size: int = None, + ep_degree: int = 1, + key_rep: str = KEY_REPLICATE, + key_ep: str = KEY_EXPERT_PARALLEL, + device_type: str = "cuda", + mixed_precision: bool = False, + lora_config: LoraConfig = None, +): + + # guarded because may have third party package deps + # Local + # pylint: disable=import-outside-toplevel + from .scattermoe import ScatterMoE + + assert world_size % ep_degree == 0, ( + f"world size ({world_size}) " f"not divisible by ep_size ({ep_degree})." + ) + + moe_num_experts: int = model.config.num_local_experts + num_experts_per_device = moe_num_experts // ep_degree + assert ( + moe_num_experts % ep_degree == 0 + ), f"moe num experts ({moe_num_experts}) not divisible by ep_shard_factor ({ep_degree})." 
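+
+    # Worked example (hypothetical numbers): with world_size=8, ep_degree=4
+    # and moe_num_experts=32, each expert-parallel rank holds
+    # 32 // 4 = 8 experts, and the remaining factor world_size // ep_degree = 2
+    # becomes the replication dimension of the device mesh set up below.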
+ + # current rank of the device + device = torch.device(f"{device_type}:{rank}") + + # get the scattermoe conversion spec + ( + moe_cls, + router_name, + expert_name, + expert_mlp_spec, + sharded_expert_ckpt, + ) = get_scattermoe_conv_spec_from_archs(model.config.architectures) + + # split the names first + expert_name = expert_name.split("|") + + rep_size = world_size // ep_degree + if ep_degree == 1 and rep_size == 1: + # in this case no need for sharding + device_mesh = None + elif rep_size == 1: + # in this case a 1D device mesh suffices + device_mesh = init_device_mesh( + device_type, + (ep_degree,), + mesh_dim_names=(key_ep,), + ) + else: + # in this case it will distribute experts on a different dim + # - this will achieve the effect that the expert sharding can be + # hierachical (e.g., can be over a slower network plane since + # the communication overhead is less + device_mesh = init_device_mesh( + device_type, + (rep_size, ep_degree), + mesh_dim_names=(key_rep, key_ep), + ) + + # - compute the shard indices for current expert, if sharding is + # indeed taking place + expert_shards = None + if device_mesh is not None: + _index = device_mesh[KEY_EXPERT_PARALLEL].get_local_rank() + expert_shards = list( + range( + _index * num_experts_per_device, (_index + 1) * num_experts_per_device + ) + ) + + # - if mixed precision is specified then we upcast + dtype = model.dtype if not mixed_precision else torch.float32 + + # for all the MoE related params, e.g., gate, experts + # get a dictionary + # parent_mod: (child_instance_name, [list of fqdn keys]) + found = {} + for name, mod in model.named_modules(): + name = name.split(".") + parent, child = ".".join(name[:-1]), name[-1] + + # check the module depending if moe_cls is a str or class + # pylint: disable=isinstance-second-argument-not-valid-type + if ( + mod.__class__.__name__ == moe_cls + if isinstance(moe_cls, str) + else isinstance(mod, moe_cls) + ): + fqdn_keys = [ # all params, including childs' + f"{parent}.{child}.{n}" for n, _ in mod.named_parameters() + ] + + # check if there are any biases in any of the experts + # if there are biases + # Assumption: assume that if one expert has bias,then the others + # will have it to + has_bias = any( + expert_name[0] in k and k.endswith("bias") for k in fqdn_keys + ) + + found[parent] = (child, fqdn_keys, has_bias) + + assert len(found) > 0, "cannot find scattermoe modules to replace" + + moe_module_names = set() + + # pylint: disable=too-many-nested-blocks + # NOTE: for now we only support sharded safetensors + # - most MOE models should be used using this checkpoint format + try: + loc = get_resolved_checkpoint_location(checkpoint_name_or_path) + with open(os.path.join(loc, FILE_SAFETENSOR_INDEX), encoding="utf-8") as f: + index = json.load(f) + + # e.g., prefix: 'model.layers.0', + # module_name: 'block_sparse_moe' + for prefix, (module_name, _, has_bias) in tqdm( + found.items(), disable=(rank > 0), desc="Converting ScatterMoE layers" + ): + checkpoint_metadata = get_checkpoint_meta_from_sharded_safetensor( + index["weight_map"], + prefix, + module_name, + router_name, + "|".join(expert_name), + ) + + # the parent module + parent = model.get_submodule(prefix) + + # - handle state dict loading + # - NOTE: convert_state_dict does not have logic to concat sharded + # experts so cannot handle the case where sharded_expert_ckpt=True + if ( + ep_degree == 1 + and (not is_fsdp_enabled() or is_local_dist_rank_0()) + and not sharded_expert_ckpt # cannot be a sharded checkpoint + ): + # - if 
there is no sharding, and model is not loaded on the + # meta device, we can simply convert the state dict + sd = convert_state_dict( + prefix + "." + module_name + ".", + checkpoint_metadata, + getattr(parent, module_name).state_dict(), + model.config.num_local_experts, + model.config.intermediate_size, + dtype, + ) + else: + # if there is sharding, then we want the model to be loaded + # on meta in general, since the actual model may be alot smaller + sd = get_state_dict_from_checkpoint_metadata( + loc, + checkpoint_metadata, + num_experts_per_device, + model.config.intermediate_size, + expert_shards, + dtype, + ) + + if device_mesh is None: + _init_scattermoe_context = nullcontext + else: + # in this case we need to distribute parameters, so just initialize + # the scattermoe module swap with empty weights, + # since they are going to replaced. + _init_scattermoe_context = init_empty_weights + + # - conver to a scatter moe + # - very hard to do patching, settle for module swap + with _init_scattermoe_context(): + moe = ScatterMoE( + hidden_size=model.config.hidden_size, + hidden_act=model.config.hidden_act, + intermediate_size=model.config.intermediate_size, + num_experts=num_experts_per_device, + has_bias=has_bias, + mlp_arch=expert_mlp_spec, + top_k=model.config.num_experts_per_tok, + dtype=model.dtype, + device=device, + ep_device_mesh=( + device_mesh[key_ep] if device_mesh is not None else None + ), + lora_config=lora_config, + ) # + + # the state dict logic below will not have lora adapters + # - so we need to initialize them + # - initialize them + if lora_config is not None: + + # update the state_dict + for name, param in moe.named_parameters(): + # NOTE: is his reliable? + if "lora_" in name: + if device_mesh is not None: + # this means it has been loaded with empty context above + # - so materialize the tensor + param = torch.empty( + *param.size(), dtype=dtype, requires_grad=True + ) + + sd[name] = param # set the param in state dict + + # initialize the loras here + if "lora_A" in name: + torch.nn.init.zeros_(sd[name]) + elif "lora_B" in name: + torch.nn.init.normal_(sd[name]) + + if device_mesh is None: + # - if not on meta, just load the state dict + # - and then put on the device + moe.load_state_dict(sd) + moe = moe.to(device) + else: + # - otherwise, we need to distribtue and will + # replace the parameters + load_experts_onto_device(moe, sd, device_mesh, num_experts_per_device) + # module swap + setattr(parent, module_name, moe) + + # - keep track of the name for returning + moe_module_names.add(module_name) + + except ValueError as e: + raise ValueError( + f"Unable to load checkpoint_path '{checkpoint_name_or_path}'. " + "Currently only support non-GGUF safetensor checkpoints. " + ) from e + + return moe_module_names diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py new file mode 100644 index 00000000..e13f6ba5 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -0,0 +1,338 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from collections import OrderedDict, defaultdict
+from contextlib import ExitStack
+from typing import Dict, List, Tuple
+import os
+import re
+
+# Third Party
+from safetensors import safe_open
+import torch
+
+# Local
+from .scattermoe_constants import (
+    DIM_EXPERT,
+    KEY_SCATTERMOE_ROUTER,
+    PARAM_NAME_WEIGHT_SCATTERMOE,
+)
+
+# This function creates a dictionary of keys and paths into the sharded
+# safetensors checkpoint file that are relevant to the "prefix" and "instance_name"
+# being passed in.
+# - the keys correspond to the ScatterMoE parameter names (e.g., w1, w2, w3, router).
+# - the values are tuples pointing to the keys within the checkpoint file.
+#
+# Example: if prefix="model.layers.0" and instance_name="block_sparse_moe", then a dictionary
+# of the following will be returned:
+# {
+#     'w1.weight': [
+#         (
+#             'model.layers.0.block_sparse_moe.experts.0.w1.weight',
+#             'model-00001-of-00019.safetensors'
+#         ),
+#         (
+#             'model.layers.0.block_sparse_moe.experts.1.w1.weight',
+#             'model-00001-of-00019.safetensors'
+#         ),
+#         ...
+#     ]
+#     'w2.weight': [...],
+#     'w3.weight': [...],
+#     'router.weight': [
+#         (
+#             'model.layers.0.block_sparse_moe.gate.weight',
+#             'model-00001-of-00019.safetensors'
+#         )
+#     ]
+# }
+#
+# or the non-sharded case (and possibly fused case)
+# {
+#     'w1.weight': [
+#         (
+#             'model.layers.0.block_sparse_moe.input_linear.layer.weight',
+#             'model-00001-of-00001.safetensors'
+#         ),
+#     ],
+#     ...
+#     'w3.weight': [
+#         (
+#             'model.layers.0.block_sparse_moe.input_linear.layer.weight',
+#             'model-00001-of-00001.safetensors'
+#         ),
+#     ]
+# }
+
+
+def get_checkpoint_meta_from_sharded_safetensor(
+    weight_map: Dict,
+    prefix: str,  # e.g., 'model.layers.0'
+    instance_name: str,  # e.g., block_sparse_moe
+    router_name: str = "gate",  # e.g., named "gate" within block_sparse_moe
+    expert_name: str = "experts",  # e.g., named "experts" within block_sparse_moe
+    expert_map: Dict = None,  # map -> [w1,w2,w3]
+) -> Dict[str, List[Tuple]]:
+    """
+    Utility function to infer the mapping of ScatterMoE parameters
+    from those of an incoming model, based on a weight_map from a
+    sharded safetensor checkpoint.
+
+    Parameters:
+        weight_map (dict): The weight map read in from a safetensor checkpoint.
+        prefix (str): the prefix where the MoE module lives (with respect to the original model).
+        instance_name (str): the name of the MoE module in the original model.
+        router_name (str): name of the router module as it is called in the MoE module
+            in the original model.
+        expert_name (str): name of the experts as they are called in the MoE module in
+            the original model. There are two usage patterns:
+            i) specify a single string, and map the weights by name,
+               e.g., experts.w1 -> w1
+            ii) specify multiple strings in the order of w1, w2, ...
+               e.g., input_linear|output_linear|input_linear
+        expert_map (dict): This is used with pattern ii) described above in expert_name.
+ If not specified, will be the identity map, e.g., w1 -> w1 + """ + + # insert in order + def _insert(L: List, i: int, v): + n = len(L) + if i < n: + L[i] = v + return + + n = i - n + 1 + while n > 0: + L.append(None) + n -= 1 + L[i] = v + + # if expert_name = input_linear|output_linear|input_linear + # - in this case will map + # - input_linear: [w1, w3], output_linear: {w2} + # - will assume the latter has double the size and can + # be split. + if expert_map is None: + if "|" in expert_name: + expert_map = {} + _names = expert_name.split("|") + _n, _n2 = len(_names), len(PARAM_NAME_WEIGHT_SCATTERMOE) + assert ( + 2 <= _n <= _n2 + ), f"If expert_name has |, expect between 2 and {_n2} entries, but got {_n}." + + for i, n in enumerate(_names): + if n not in expert_map: + expert_map[n] = [] + expert_map[n].append(PARAM_NAME_WEIGHT_SCATTERMOE[i]) + else: + expert_map = {x: [x] for x in PARAM_NAME_WEIGHT_SCATTERMOE} + + # state dict -> weights + # 'router.weight': [(k, file),...] + # `w1.weight`: [...] + _map = defaultdict(list) + prefix = f"{prefix}.{instance_name}." + for k, stfile in weight_map.items(): + if not k.startswith(prefix): + continue + + # e.g. after replacement we get + # - gate.weight + # - experts.0.w1.weight + rel_k = k.replace(prefix, "") + # pylint: disable=anomalous-backslash-in-string + m = re.match(f"({router_name}|{expert_name})\.?(\d+)?\.?(\w+)?\.weight", rel_k) + if m is None: + raise ValueError( + f"Unable to handle key '{k}' with provided router_name " + f"'{router_name}' or expert_name '{expert_name}'" + ) + if m.group(1) == router_name: + _map[KEY_SCATTERMOE_ROUTER].append((k, stfile)) + elif m.group(1) in expert_name: + index = m.group(2) + index = 0 if index is None else int(index) + mod = None + for mod in expert_map.get(m.group(1), expert_map.get(m.group(3))): + _insert(_map[f"{mod}.weight"], index, (k, stfile)) + + assert mod is not None, f"cannot map '{rel_k}'" + + if len(_map) == 0: + raise ValueError( + f"Could not get safetensor map for '{prefix}' and '{instance_name}'" + ) + + return _map + + +# if the weight is a scattermoe expert weight, need some reshaping +def _maybe_reshape_scattermoe_expert_weights( + scatter_key: str, + param: torch.Tensor, + num_experts: int, + intermediate_size: int, +): + (_is_w1, _is_w2, _is_w3) = [ + f"{x}.weight" in scatter_key for x in PARAM_NAME_WEIGHT_SCATTERMOE + ] + + if _is_w1 or _is_w2 or _is_w3: + if len(param.shape) == 2: + param = param.view(num_experts, -1, param.shape[-1]) + + if _is_w1 or _is_w3: + if param.shape[-2] == (2 * intermediate_size): + # cut it + if _is_w1: + param = param[..., :intermediate_size, :] + else: + param = param[..., intermediate_size:, :] + + # asumme these are linears + # assert param.shape[-2] == intermediate_size, "wrong intermediate size" + # assert param.shape[-1] == hidden_size, "wrong hidden size" + + # have to transpose for weights since scattermoe accepts the differen + # order + param = param.permute(0, 2, 1) + + return param + + +def convert_state_dict( + prefix: str, + checkpoint_metadata: Dict[str, List[Tuple]], + state_dict: OrderedDict, + num_experts: int, + intermediate_size: int, + dtype: torch.dtype = None, +): + """ + utility to convert the state dict for ScatterMoE. To be used + if the model is already loaded with weights. + + Parameters: + prefix (str): where the MoE is located in the incoming model. + checkpoint_metadata (dict): a mapping of ScatterMoE state dict + with respect to that of incoming model. + state_dict (dict): of the incoming MoE. 
+ num_experts (int): + intermediate_size (int): + dtype (torch.dtype): + """ + target = OrderedDict() + + for scatter_key, vs in checkpoint_metadata.items(): + for state_key, _ in vs: + state_key = state_key.replace(prefix, "") + param = state_dict[state_key] + param = _maybe_reshape_scattermoe_expert_weights( + scatter_key, param, num_experts, intermediate_size + ) + if dtype is not None: + param = param.to(dtype) + target[scatter_key] = param + + return target + + +def get_state_dict_from_checkpoint_metadata( + checkpoint_directory: str, + checkpoint_metadata: Dict[str, List[Tuple]], + num_experts: int, + intermediate_size: int, + expert_shards: List[int] = None, + dtype: torch.dtype = None, +): + """ + utility to convert a sharded checkpoint into a state dict for + ScatterMoe. To be used if the model was loaded on the meta + device and actual weights does not exist in it. + + Parameters: + checkpoint_directory (str): where the checkpoint is located. + checkpoint_metadata (dict): a mapping of ScatterMoE state dict + with respect to that of incoming model. + num_experts (int): + intermediate_size (int): + expert_shards (list): indexing which of the shards are required + if only a subset of parameters are required + dtype (torch.dtype): + """ + target = OrderedDict() + + # typically they all should be same file, but to play safe, load the checkpoint file onto + # cpu first since we may not need all weights in that file. + with ExitStack() as stack: + files = {} + for _, vs in checkpoint_metadata.items(): + for _, fi in vs: + if fi not in files: + files[fi] = stack.enter_context( + safe_open( + os.path.join(checkpoint_directory, fi), + framework="pt", + device="cpu", + ) + ) + + # go by one weight at a time. + for scatter_key, vs in checkpoint_metadata.items(): + + if KEY_SCATTERMOE_ROUTER in scatter_key: + k, fi = vs[0] # only one item + param = files[fi].get_tensor(k) + + elif len(vs) == 1: + k, fi = vs[0] # only one item + # if its a non-router weight and its non-sharded + param = files[fi].get_tensor(k) + assert len(param.shape) == 3, ( + "Expected 3D tensor for checkpoints with non-sharded experts, ", + f"but got shape {param.shape}.", + ) + + else: + # handle sharding if the checkpoint shards experts + # - + data = [] + if expert_shards is not None: + vs = [vs[i] for i in expert_shards] + + for k, fi in vs: + T = files[fi].get_tensor(k) + assert len(T.shape) == 2, ( + "Expected 2D tensor for checkpoints with sharded experts, " + f"but got shape {T.shape}." + ) + + T = T.unsqueeze(0) + data.append(T) + + param = torch.concat(data, dim=DIM_EXPERT) + + param = _maybe_reshape_scattermoe_expert_weights( + scatter_key, param, num_experts, intermediate_size + ) + if dtype is not None: + param = param.to(dtype) + + target[scatter_key] = param + + return target diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/__init__.py new file mode 100644 index 00000000..b72b46d2 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/__init__.py @@ -0,0 +1,17 @@ +# Copyright The FMS HF Tuning Authors +# Copyright 2023 MegaBlocks authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Local +from .expert_parallel import all_to_all_gather_inputs, scatter_with_routing_weights diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/expert_parallel.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/expert_parallel.py new file mode 100644 index 00000000..1704b850 --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/expert_parallel.py @@ -0,0 +1,232 @@ +# Copyright The FMS HF Tuning Authors +# Copyright 2023 MegaBlocks authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Third Party +import numpy as np +import torch + +try: + # if megablocks is installed, import the kernels, distributed + # and kernel functions + + # - mixture of triton and cuda kernels + # Third Party + from megablocks import ops + + # - distributed autograd + from megablocks.layers.all_to_all import all_to_all + from megablocks.ops import gather, histogram, inclusive_cumsum, scatter + + # this is a radix sort for integral indices 0 .. num_bins-1 + def sort(indices: torch.Tensor, num_bins: int): + bits = max(int(np.ceil(np.log2(num_bins))), 1) + # TODO: figure out why we need this upcast + bins, inds = ops.sort(indices, bits) + return bins, inds.to(torch.int64) + + # replicate indices with bins + def replicate(indices: torch.Tensor, bins: torch.Tensor): + replicate_bins = inclusive_cumsum(bins.flatten(), 0) + # pylint: disable=use-implicit-booleaness-not-len + replicate_bins = ( + replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins + ) + + return ops.replicate( + indices.unsqueeze(dim=0), + replicate_bins, + replicate_bins[-1], + ).flatten() + +except ImportError: + + # - distributed autograd + # Local + from .megablocks import all_to_all, gather, scatter + + # take the histogram of integral indices from 0 .. num_bins-1 + def histogram(indices: torch.Tensor, num_bins: int): + # - this has an Aten for the GPU backend + return torch.histc(indices, bins=num_bins, min=0, max=num_bins - 1) + + def inclusive_cumsum(x: torch.Tensor, dim: int): + # - convert to int332 type as that is what is expected by the + # megablocks gather and scatter kernels + return x.cumsum(axis=dim, dtype=torch.int32) + + # this is a radix sort for integral indices 0 .. 
num_bins-1 + def sort(indices: torch.Tensor, num_bins: int): + return torch.sort(indices) + + # replicate, this replicates an integral indices according to bin times + def replicate(indices: torch.Tensor, bins: torch.Tensor): + return torch.repeat_interleave(indices, bins) + + +# from megablocks +def no_indices_just_bins(top_expert, num_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + tokens_per_expert = histogram(top_expert, num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = inclusive_cumsum(tokens_per_expert, 0) + # pylint: disable=use-implicit-booleaness-not-len + bins = bins.view(1) if not len(bins.size()) else bins + return bins, tokens_per_expert + + +# modified from https://github.com/databricks/megablocks/blob/main/megablocks/layers/mlp.py +# - credit to trevor-gale +def all_to_all_gather_inputs( + x: torch.Tensor, + top_experts: torch.Tensor, + bin_ids: torch.Tensor, + indices: torch.Tensor, + expert_parallel_group: torch.distributed.ProcessGroup, + top_k: int, + experts_per_rank: int, +): + """ + Extracted from megablocks. This function performs all-to-all input + gathering for expert parallel. + """ + + # Compute the mapping of local tokens to experts. + # expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + world_size = expert_parallel_group.size() + with torch.no_grad(): + bins, tokens_per_expert = no_indices_just_bins( + top_experts, experts_per_rank * world_size + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like( + tokens_per_expert, + ) + tpe_handle = torch.distributed.all_to_all_single( + parallel_tokens_per_expert, + tokens_per_expert, + group=expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = gather(x, indices, bin_ids, bins, top_k) + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + + # Reshape to [world_size, num_experts_per_rank]. + tokens_per_expert = tokens_per_expert.view(world_size, experts_per_rank) + parallel_tokens_per_expert = parallel_tokens_per_expert.view( + world_size, experts_per_rank + ) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. + send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. 
The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + experts_per_rank * world_size, + dtype=torch.int32, + device=indices.device, + ), + experts_per_rank, + ) + + parallel_top_expert = replicate( + parallel_top_expert, + parallel_tokens_per_expert.flatten(), + ) + + parallel_bin_ids, parallel_indices = sort(parallel_top_expert, experts_per_rank) + + parallel_x_handle.wait() + + return ( + parallel_x, + parallel_bin_ids, + parallel_indices, + send_counts, + recv_counts, # for all to all + bins, # local + ) + + +def scatter_with_routing_weights( + x: torch.Tensor, + expert_weights: torch.Tensor, + send_counts: torch.Tensor, + recv_counts: torch.Tensor, + bins: torch.Tensor, + bin_ids: torch.Tensor, + indices: torch.Tensor, + expert_parallel_group: torch.distributed.ProcessGroup, + top_k: int, +): + """ + Extracted from megablocks. This function undoes the all-to-all + gathering for expert parallel. + """ + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + x, + send_counts, + recv_counts, + expert_parallel_group, + ) + + # Un-permute locally to setup for the next series of operations. + return scatter(x, indices, bin_ids, expert_weights, bins, top_k) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/__init__.py new file mode 100644 index 00000000..024c575e --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/__init__.py @@ -0,0 +1,17 @@ +# Copyright The FMS HF Tuning Authors +# Copyright 2024 Databricks +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Local +from .autograd import all_to_all, gather, scatter diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/autograd.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/autograd.py new file mode 100644 index 00000000..f9e100fd --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/autograd.py @@ -0,0 +1,183 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import functools + +# Third Party +import torch + +# Local +from .kernels import gather as _kernels_gather +from .kernels import scatter as _kernels_scatter +from .kernels import scatter_wgrad as _kernels_scatter_wgrad + + +# ------------------------ HELPERS ----------------------------- +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + + return decorate_bwd + + +# ------------------------ AUTOGRAD ----------------------------- + + +class AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty( + (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype + ) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = torch.distributed.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + torch.distributed.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) + + +# Autograd wrapper for scatter kernel. 
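+# - forward un-permutes the gathered tokens back into (tokens, hidden_size)
+#   order, applying the routing weights and reducing over top_k; backward
+#   re-gathers the incoming gradient and, if the routing weights require a
+#   gradient, also computes it via scatter_wgrad.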
+class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x, indices, bin_ids, weights, bins, top_k): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return _kernels_scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx, grad): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = _kernels_gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = _kernels_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +): + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) + + +# Autograd wrapper for gather kernel. +class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x, indices, bin_ids, bins, top_k): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return _kernels_gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx, grad): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = _kernels_scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels/__init__.py new file mode 100644 index 00000000..6dc4fe5b --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels/__init__.py @@ -0,0 +1,17 @@ +# Copyright The FMS HF Tuning Authors +# Copyright 2024 Databricks +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Local +from .gather_scatter import gather, scatter, scatter_wgrad diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels/gather_scatter.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels/gather_scatter.py new file mode 100644 index 00000000..4809b4fb --- /dev/null +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_utils/megablocks/kernels/gather_scatter.py @@ -0,0 +1,309 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# Third Party +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f"Expected {ndim}-tensor but got {x.ndim}-tensor") + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f"Expected 1-tensor but got {x.ndim}-tensor") + + +def assert_equal(a, b): + if a != b: + raise ValueError( + f"Expected dimensions to be equal but got {a} and {b}.", + ) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({"BLOCK_X": 64}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=2), + triton.Config({"BLOCK_X": 256}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=4), + triton.Config({"BLOCK_X": 256}, num_warps=4), + ], + key=["NUM_COLUMNS"], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. 
+ assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({"BLOCK_X": 64}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=2), + triton.Config({"BLOCK_X": 256}, num_warps=2), + triton.Config({"BLOCK_X": 128}, num_warps=4), + triton.Config({"BLOCK_X": 256}, num_warps=4), + ], + key=["NUM_COLUMNS"], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. 
Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for i in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) diff --git a/plugins/accelerated-moe/tests/__init__.py b/plugins/accelerated-moe/tests/__init__.py new file mode 100644 index 00000000..38a9531e --- /dev/null +++ b/plugins/accelerated-moe/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/plugins/accelerated-moe/tests/test_scattermoe_plugin.py b/plugins/accelerated-moe/tests/test_scattermoe_plugin.py new file mode 100644 index 00000000..ccdba6d3 --- /dev/null +++ b/plugins/accelerated-moe/tests/test_scattermoe_plugin.py @@ -0,0 +1,34 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import os + +# Third Party +from fms_acceleration.utils import instantiate_framework, read_configuration + +# First Party +from fms_acceleration_moe import ScatterMoEAccelerationPlugin + +# configuration +DIRNAME = os.path.dirname(__file__) +CONFIG_PATH_SCATTERMOE = os.path.join(DIRNAME, "../configs/scattermoe.yaml") + + +def test_framework_installs_scattermoe_plugin(): + with instantiate_framework( + read_configuration(CONFIG_PATH_SCATTERMOE), require_packages_check=False + ) as framework: + for plugin in framework.active_plugins: + assert isinstance(plugin[1], ScatterMoEAccelerationPlugin) diff --git a/plugins/accelerated-moe/tests/test_scattermoe_state_dict.py b/plugins/accelerated-moe/tests/test_scattermoe_state_dict.py new file mode 100644 index 00000000..ff8965ba --- /dev/null +++ b/plugins/accelerated-moe/tests/test_scattermoe_state_dict.py @@ -0,0 +1,208 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from typing import List + +# Third Party +# pylint: disable=import-error +import pytest + +# First Party +from fms_acceleration_moe.utils.scattermoe_constants import ( + PARAM_NAME_ROUTER_SCATTERMOE, + PARAM_NAME_WEIGHT_SCATTERMOE, +) +from fms_acceleration_moe.utils.scattermoe_state_dict import ( + get_checkpoint_meta_from_sharded_safetensor, +) + +# just a dummy sample value +ST_SHARD = "model-00001-of-00001.safetensors" + + +# --------------------------- HELPERS ------------------------------ +# - builds a weight dict for checkpoints where MoE is sharded (i.e., +# one linear by expert). 
+# - this is like Mixtral style
+def build_dummy_weight_map_sharded_moe(
+    prefix: str,
+    module_name: str,
+    router_name: str,
+    expert_name: str,
+    num_layers: int,
+    num_experts: int,
+    expert_keys: List[str],
+):
+
+    # - ST_SHARD entries are not important for the test
+    weight_map = {}
+    for i in range(num_layers):
+        layer_map = {
+            f"{prefix}.{i}.{module_name}.{router_name}.weight": ST_SHARD,
+        }
+        for j in range(num_experts):
+            expert_map = {}
+
+            for n in expert_keys:
+                expert_map.update(
+                    {
+                        f"{prefix}.{i}.{module_name}.{expert_name}.{j}.{n}.weight": ST_SHARD
+                    }
+                )
+
+            layer_map.update(expert_map)
+
+        weight_map.update(layer_map)
+
+    return weight_map
+
+
+# - this is like granite style
+def build_dummy_weight_map_non_sharded_moe(
+    prefix: str,
+    module_name: str,
+    router_name: str,
+    num_layers: int,
+    expert_keys: List[str],
+):
+    # - ST_SHARD entries are not important for the test
+    weight_map = {}
+    for i in range(num_layers):
+        layer_map = {
+            f"{prefix}.{i}.{module_name}.{router_name}.weight": ST_SHARD,
+        }
+        for n in expert_keys:
+            layer_map.update({f"{prefix}.{i}.{module_name}.{n}.weight": ST_SHARD})
+
+        weight_map.update(layer_map)
+
+    return weight_map
+
+
+# --------------------------- TEST ---------------------------------
+
+PARAMETERS = [
+    (
+        True,
+        "model.layers",
+        "block_sparse_moe",
+        "gate",
+        "experts",
+        2,
+        8,
+        ["w1", "w2", "w3"],
+    ),
+    (
+        False,
+        "model.layers",
+        "block_sparse_moe",
+        "gate",
+        "input_linear|output_linear|input_linear",
+        2,
+        None,
+        ["input_linear", "output_linear"],
+    ),
+]
+
+
+@pytest.mark.parametrize(
+    (
+        "sharded_ckpt,prefix,module_name,router_name,expert_name,"
+        "num_layers,num_experts,expert_keys"
+    ),
+    PARAMETERS,
+)
+def test_get_metadata_from_sharded_safetensor_correctly(
+    sharded_ckpt: bool,
+    prefix: str,
+    module_name: str,
+    router_name: str,
+    expert_name: str,
+    num_layers: int,
+    num_experts: int,
+    expert_keys: List[str],
+):
+
+    if sharded_ckpt:
+        weight_map = build_dummy_weight_map_sharded_moe(
+            prefix,
+            module_name,
+            router_name,
+            expert_name,
+            num_layers,
+            num_experts,
+            expert_keys,
+        )
+    else:
+        weight_map = build_dummy_weight_map_non_sharded_moe(
+            prefix, module_name, router_name, num_layers, expert_keys
+        )
+
+    # get the metadata for a single layer
+    ckpt_metadata = get_checkpoint_meta_from_sharded_safetensor(
+        weight_map,
+        prefix + ".0",  # include layer
+        module_name,
+        router_name,
+        expert_name,
+    )
+
+    _key = f"{PARAM_NAME_ROUTER_SCATTERMOE}.weight"
+    assert _key in ckpt_metadata, "unable to map scattermoe router metadata."
+
+    _n = len(ckpt_metadata[_key])
+    assert _n == 1, f"expected only 1 router weight but got {_n}"
+
+    for n in PARAM_NAME_WEIGHT_SCATTERMOE:
+        _key = f"{n}.weight"
+        assert _key in ckpt_metadata, f"unable to map scattermoe expert weight {n}."
+
+        _n = len(ckpt_metadata[_key])
+        if sharded_ckpt:
+            assert (
+                _n == num_experts
+            ), f"missing expert weights, only mapped {_n} weights out of {num_experts}."
+        else:
+            assert (
+                _n == 1
+            ), f"missing expert weights, mapped {_n} but expected only 1 for non-sharded."
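[Editor's note: the snippet below is an illustrative sketch, not additional test code in this patch. It spells out the kind of weight map the sharded (Mixtral-style) helper above builds, and how the function under test consumes it; the literal keys are derived from the f-strings in `build_dummy_weight_map_sharded_moe`, and `get_checkpoint_meta_from_sharded_safetensor` is the import already used by this test module.]

```python
# sharded (Mixtral-style) weight map for prefix="model.layers",
# module_name="block_sparse_moe", router_name="gate", expert_name="experts",
# 1 layer, 2 experts, expert_keys=["w1"]
weight_map = {
    "model.layers.0.block_sparse_moe.gate.weight": "model-00001-of-00001.safetensors",
    "model.layers.0.block_sparse_moe.experts.0.w1.weight": "model-00001-of-00001.safetensors",
    "model.layers.0.block_sparse_moe.experts.1.w1.weight": "model-00001-of-00001.safetensors",
}

# the test then requests the scattermoe checkpoint metadata for layer 0 only
ckpt_metadata = get_checkpoint_meta_from_sharded_safetensor(
    weight_map, "model.layers.0", "block_sparse_moe", "gate", "experts"
)
```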
+
+
+def test_get_metadata_from_sharded_safetensor_incorrectly():
+
+    weight_map_wrong = {"prefix.moe_name.expert.weight": ST_SHARD}
+
+    # - the prefix passed in has to match entries in the weight_map
+    with pytest.raises(ValueError, match="Could not get safetensor map for"):
+        get_checkpoint_meta_from_sharded_safetensor(
+            weight_map_wrong, "wrong_prefix", "moe_name", None, "expert_name"
+        )
+
+    # - if passing multiple expert names, they cannot exceed the number of
+    # possible expert gates
+    with pytest.raises(
+        AssertionError, match="If expert_name has |, expect between 2 and"
+    ):
+        get_checkpoint_meta_from_sharded_safetensor(
+            weight_map_wrong, "prefix", "moe_name", None, "exp1|exp2|exp3|exp4"
+        )
+
+    # - a weight_map key matching moe_name but not the expert_name cannot be handled
+    with pytest.raises(
+        ValueError, match="Unable to handle key 'prefix.moe_name.expert.weight'"
+    ):
+        get_checkpoint_meta_from_sharded_safetensor(
+            weight_map_wrong, "prefix", "moe_name", None, "wrong_expert_name"
+        )
diff --git a/plugins/accelerated-moe/tox.ini b/plugins/accelerated-moe/tox.ini
new file mode 100644
index 00000000..811f1329
--- /dev/null
+++ b/plugins/accelerated-moe/tox.ini
@@ -0,0 +1,48 @@
+[tox]
+envlist = py, lint
+
+[testenv]
+deps =
+    pytest>=7
+    -e {toxinidir}
+skip_install = true
+commands =
+
+    # install the dependencies here to ensure
+    # the install order
+    pip install -e {toxinidir}/../framework
+    pytest {posargs:tests}
+
+[testenv:lint]
+description = run linters
+skip_install = false
+deps =
+    -e {toxinidir}/../framework
+    pylint>=2.16.2,<=3.1.0
+commands =
+    pylint src tests
+allowlist_externals = pylint
+
+[testenv:fmt]
+description = format
+skip_install = true
+deps =
+    black>=22.12
+    isort>=5.11
+commands =
+    black {posargs:.}
+    isort {posargs:.}
+
+[testenv:build]
+description = build wheel
+deps =
+    build
+commands = python -m build -w
+skip_install = True
+
+[testenv:twinecheck]
+description = check wheel
+deps =
+    twine
+commands = twine check dist/*
+skip_install = True
\ No newline at end of file
diff --git a/plugins/framework/src/fms_acceleration/constants.py b/plugins/framework/src/fms_acceleration/constants.py
index 6a81d977..d568ec13 100644
--- a/plugins/framework/src/fms_acceleration/constants.py
+++ b/plugins/framework/src/fms_acceleration/constants.py
@@ -21,4 +21,4 @@
 # and activated.
# - hence the plugins that have model loaders should be on top of this list -PLUGINS = ["peft", "foak", "aadp"] +PLUGINS = ["peft", "foak", "aadp", "moe"] diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py index 16ea64b7..906a4668 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py @@ -39,6 +39,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] = from .models import ( # pylint: disable=import-outside-toplevel gpt_bigcode, granite, + granitemoe, llama, mistral, mixtral, @@ -47,6 +48,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] = rules = [ *gpt_bigcode.get_mp_rules(base_type), *granite.get_mp_rules(base_type), + *granitemoe.get_mp_rules(base_type), *llama.get_mp_rules(base_type), *mistral.get_mp_rules(base_type), *mixtral.get_mp_rules(base_type), @@ -76,6 +78,7 @@ class FastKernelsAccelerationPlugin(AccelerationPlugin): # NOTE: may remove this when we have generic model rules restricted_model_archs = [ "GraniteForCausalLM", + "GraniteMoeForCausalLM", "GPTBigCodeForCausalLM", "MixtralForCausalLM", "LlamaForCausalLM", diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granitemoe.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granitemoe.py new file mode 100644 index 00000000..6da14682 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granitemoe.py @@ -0,0 +1,116 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from functools import partial + +# Third Party +from fms_acceleration.model_patcher import ( + ModelPatcherRule, + ModelPatcherTrigger, + combine_functions, + combine_triggers, +) + +# Local +from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss +from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm +from ..kernels.unsloth.rope_embedding import fast_rope_embedding +from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops + + +def get_mp_rules(base_type: str): + """ + Function to access all patch rules in this module. 
+ If it is a forward_builder rule with `base_type` in + its forward builder argument, wrap the forward_builder + function as a partial function with the base_type argument + """ + try: + # Third Party + from transformers.models.granitemoe.modeling_granitemoe import ( # pylint: disable=import-outside-toplevel + GraniteMoeAttention, + GraniteMoeRMSNorm, + ) + except ImportError: + return [] + + return [ + # TODO: have a generic version of this rule + # - do regex on RMSNorm class name + # - check on the tensors required for fast_rms_layernorm + ModelPatcherRule( + rule_id="granitemoe-rms", + trigger=ModelPatcherTrigger(check=GraniteMoeRMSNorm), + forward=fast_rms_layernorm, + ), + # TODO: have a generic version of this rule + # - do regex on Attention class name + # - have a set of qkv / o module names and check on that + ModelPatcherRule( + rule_id="granitemoe-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=GraniteMoeAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=GraniteMoeAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", + ), + forward_builder=combine_functions( + partial( + build_lora_fused_ops, + submodule_names=["q_proj", "k_proj", "v_proj"], + fused_op=KEY_QKV, + base_type=base_type, + ), + partial( + build_lora_fused_ops, + submodule_names=["o_proj"], + fused_op=KEY_O, + base_type=base_type, + ), + logic="APPEND", + ), + ), + ModelPatcherRule( + rule_id="granitemoe-cross-ent", + import_and_maybe_reload=( + "torch.nn.CrossEntropyLoss", + FastCrossEntropyLoss, + "transformers.models.granitemoe.modeling_granitemoe", + ), + ), + # TODO: have a generic version of this rule + # - get the module name + # - check if "apply_rotary_pos_emb" exists + # - patch + ModelPatcherRule( + rule_id="granitemoe-rope", + import_and_maybe_reload=( + "transformers.models.granitemoe.modeling_granitemoe.apply_rotary_pos_emb", + fast_rope_embedding, + None, + ), + ), + ] diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml index 6781b3bd..b3b1deec 100644 --- a/sample-configurations/CONTENTS.yaml +++ b/sample-configurations/CONTENTS.yaml @@ -73,3 +73,69 @@ framework_configs: plugins: - fused-ops-and-kernels filename: foak-fast-kernels-sample-configuration.yaml + + # ------- MOE CONFIGS ---------- + - shortname: moe-scattermoe-granite-ep1 + plugins: + - accelerated-moe + filename: moe-scattermoe-granite-ep1-sample-configuration.yaml + + - shortname: moe-scattermoe-granite-ep1-padding-free + plugins: + - accelerated-moe + - attention-and-distributed-packing + filename: moe-scattermoe-granite-ep1-padding-free-sample-configuration.yaml + + - shortname: moe-scattermoe-granite-ep1-padding-free-foak + plugins: + - accelerated-moe + - attention-and-distributed-packing + - fused-ops-and-kernels + filename: moe-scattermoe-granite-ep1-padding-free-foak-sample-configuration.yaml + + - shortname: moe-scattermoe-granite-ep2 + plugins: + - accelerated-moe + filename: moe-scattermoe-granite-ep2-sample-configuration.yaml + + - shortname: moe-scattermoe-granite-ep2-padding-free + plugins: + - accelerated-moe + - attention-and-distributed-packing + filename: moe-scattermoe-granite-ep2-padding-free-sample-configuration.yaml + + - shortname: moe-scattermoe-granite-ep2-padding-free-foak + plugins: + - accelerated-moe + - attention-and-distributed-packing + - fused-ops-and-kernels + filename: 
moe-scattermoe-granite-ep2-padding-free-foak-sample-configuration.yaml + + - shortname: moe-scattermoe-granite-ep4 + plugins: + - accelerated-moe + filename: moe-scattermoe-granite-ep4-sample-configuration.yaml + + - shortname: moe-scattermoe-granite-ep4-padding-free + plugins: + - accelerated-moe + - attention-and-distributed-packing + filename: moe-scattermoe-granite-ep4-padding-free-sample-configuration.yaml + + - shortname: moe-scattermoe-granite-ep4-padding-free-foak + plugins: + - accelerated-moe + - attention-and-distributed-packing + - fused-ops-and-kernels + filename: moe-scattermoe-granite-ep4-padding-free-foak-sample-configuration.yaml + + - shortname: moe-scattermoe-granite-ep8 + plugins: + - accelerated-moe + filename: moe-scattermoe-granite-ep8-sample-configuration.yaml + + - shortname: moe-scattermoe-granite-ep8-foak + plugins: + - accelerated-moe + - fused-ops-and-kernels + filename: moe-scattermoe-granite-ep8-foak-sample-configuration.yaml \ No newline at end of file diff --git a/sample-configurations/moe-scattermoe-granite-ep1-padding-free-foak-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep1-padding-free-foak-sample-configuration.yaml new file mode 100644 index 00000000..881ef14b --- /dev/null +++ b/sample-configurations/moe-scattermoe-granite-ep1-padding-free-foak-sample-configuration.yaml @@ -0,0 +1,51 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # Configurations to accelerate data packing/padding in training + training: + + # attention module configurations + # e.g. padding-free modifications to attention layer + attention: + + # this controls the confgurations for padding free computation of flash attention + padding_free: + method: huggingface + fused_ops_and_kernels: + + # if under training stanza, then putting + # base_layer and fused_lora will be a misnomer + # - this should be in peft.quantized + # However, if it is specified, it will still + # be read. This is useful in use cases where + # the yaml is system generated and not shown + # to a user. + + # activate various unsloth optimizations + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + + # fast loss triton kernels + fast_loss: true + + # fast rms norm triton kernels + fast_rms_layernorm: true + + # fast RoPE embedding triton kernels + fast_rope_embeddings: true + moe: + + # expert-parallel for MoE + scattermoe: + + # The level of expert parallel sharding. + # - 1 means no sharding + # - if > 1, please ensure that this divides the world_size. This is because + # the devices will be replicated for every ep_degree devices, and + # the experts will be sharded within each group. + # - if > 1, also ensure that it divides the number of experts, as each device + # will then have num_of_experts / ep_degree experts. + ep_degree: 1 diff --git a/sample-configurations/moe-scattermoe-granite-ep1-padding-free-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep1-padding-free-sample-configuration.yaml new file mode 100644 index 00000000..b3af6f99 --- /dev/null +++ b/sample-configurations/moe-scattermoe-granite-ep1-padding-free-sample-configuration.yaml @@ -0,0 +1,28 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. 
+plugins: + # Configurations to accelerate data packing/padding in training + training: + + # attention module configurations + # e.g. padding-free modifications to attention layer + attention: + + # this controls the confgurations for padding free computation of flash attention + padding_free: + method: huggingface + moe: + + # expert-parallel for MoE + scattermoe: + + # The level of expert parallel sharding. + # - 1 means no sharding + # - if > 1, please ensure that this divides the world_size. This is because + # the devices will be replicated for every ep_degree devices, and + # the experts will be sharded within each group. + # - if > 1, also ensure that it divides the number of experts, as each device + # will then have num_of_experts / ep_degree experts. + ep_degree: 1 diff --git a/sample-configurations/moe-scattermoe-granite-ep1-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep1-sample-configuration.yaml new file mode 100644 index 00000000..bb327a44 --- /dev/null +++ b/sample-configurations/moe-scattermoe-granite-ep1-sample-configuration.yaml @@ -0,0 +1,21 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + training: + + # mixture-of-experts configurations + moe: + + # expert-parallel for MoE + scattermoe: + + # The level of expert parallel sharding. + # - 1 means no sharding + # - if > 1, please ensure that this divides the world_size. This is because + # the devices will be replicated for every ep_degree devices, and + # the experts will be sharded within each group. + # - if > 1, also ensure that it divides the number of experts, as each device + # will then have num_of_experts / ep_degree experts. + ep_degree: 1 diff --git a/sample-configurations/moe-scattermoe-granite-ep2-padding-free-foak-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep2-padding-free-foak-sample-configuration.yaml new file mode 100644 index 00000000..b3c7712d --- /dev/null +++ b/sample-configurations/moe-scattermoe-granite-ep2-padding-free-foak-sample-configuration.yaml @@ -0,0 +1,51 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # Configurations to accelerate data packing/padding in training + training: + + # attention module configurations + # e.g. padding-free modifications to attention layer + attention: + + # this controls the confgurations for padding free computation of flash attention + padding_free: + method: huggingface + fused_ops_and_kernels: + + # if under training stanza, then putting + # base_layer and fused_lora will be a misnomer + # - this should be in peft.quantized + # However, if it is specified, it will still + # be read. This is useful in use cases where + # the yaml is system generated and not shown + # to a user. + + # activate various unsloth optimizations + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + + # fast loss triton kernels + fast_loss: true + + # fast rms norm triton kernels + fast_rms_layernorm: true + + # fast RoPE embedding triton kernels + fast_rope_embeddings: true + moe: + + # expert-parallel for MoE + scattermoe: + + # The level of expert parallel sharding. + # - 1 means no sharding + # - if > 1, please ensure that this divides the world_size. 
This is because + # the devices will be replicated for every ep_degree devices, and + # the experts will be sharded within each group. + # - if > 1, also ensure that it divides the number of experts, as each device + # will then have num_of_experts / ep_degree experts. + ep_degree: 2 diff --git a/sample-configurations/moe-scattermoe-granite-ep2-padding-free-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep2-padding-free-sample-configuration.yaml new file mode 100644 index 00000000..474171b3 --- /dev/null +++ b/sample-configurations/moe-scattermoe-granite-ep2-padding-free-sample-configuration.yaml @@ -0,0 +1,28 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # Configurations to accelerate data packing/padding in training + training: + + # attention module configurations + # e.g. padding-free modifications to attention layer + attention: + + # this controls the confgurations for padding free computation of flash attention + padding_free: + method: huggingface + moe: + + # expert-parallel for MoE + scattermoe: + + # The level of expert parallel sharding. + # - 1 means no sharding + # - if > 1, please ensure that this divides the world_size. This is because + # the devices will be replicated for every ep_degree devices, and + # the experts will be sharded within each group. + # - if > 1, also ensure that it divides the number of experts, as each device + # will then have num_of_experts / ep_degree experts. + ep_degree: 2 diff --git a/sample-configurations/moe-scattermoe-granite-ep2-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep2-sample-configuration.yaml new file mode 100644 index 00000000..b24d00cb --- /dev/null +++ b/sample-configurations/moe-scattermoe-granite-ep2-sample-configuration.yaml @@ -0,0 +1,21 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + training: + + # mixture-of-experts configurations + moe: + + # expert-parallel for MoE + scattermoe: + + # The level of expert parallel sharding. + # - 1 means no sharding + # - if > 1, please ensure that this divides the world_size. This is because + # the devices will be replicated for every ep_degree devices, and + # the experts will be sharded within each group. + # - if > 1, also ensure that it divides the number of experts, as each device + # will then have num_of_experts / ep_degree experts. + ep_degree: 2 diff --git a/sample-configurations/moe-scattermoe-granite-ep4-padding-free-foak-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep4-padding-free-foak-sample-configuration.yaml new file mode 100644 index 00000000..c73917ce --- /dev/null +++ b/sample-configurations/moe-scattermoe-granite-ep4-padding-free-foak-sample-configuration.yaml @@ -0,0 +1,51 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # Configurations to accelerate data packing/padding in training + training: + + # attention module configurations + # e.g. 
padding-free modifications to attention layer + attention: + + # this controls the confgurations for padding free computation of flash attention + padding_free: + method: huggingface + fused_ops_and_kernels: + + # if under training stanza, then putting + # base_layer and fused_lora will be a misnomer + # - this should be in peft.quantized + # However, if it is specified, it will still + # be read. This is useful in use cases where + # the yaml is system generated and not shown + # to a user. + + # activate various unsloth optimizations + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + + # fast loss triton kernels + fast_loss: true + + # fast rms norm triton kernels + fast_rms_layernorm: true + + # fast RoPE embedding triton kernels + fast_rope_embeddings: true + moe: + + # expert-parallel for MoE + scattermoe: + + # The level of expert parallel sharding. + # - 1 means no sharding + # - if > 1, please ensure that this divides the world_size. This is because + # the devices will be replicated for every ep_degree devices, and + # the experts will be sharded within each group. + # - if > 1, also ensure that it divides the number of experts, as each device + # will then have num_of_experts / ep_degree experts. + ep_degree: 4 diff --git a/sample-configurations/moe-scattermoe-granite-ep4-padding-free-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep4-padding-free-sample-configuration.yaml new file mode 100644 index 00000000..8cd803cd --- /dev/null +++ b/sample-configurations/moe-scattermoe-granite-ep4-padding-free-sample-configuration.yaml @@ -0,0 +1,28 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # Configurations to accelerate data packing/padding in training + training: + + # attention module configurations + # e.g. padding-free modifications to attention layer + attention: + + # this controls the confgurations for padding free computation of flash attention + padding_free: + method: huggingface + moe: + + # expert-parallel for MoE + scattermoe: + + # The level of expert parallel sharding. + # - 1 means no sharding + # - if > 1, please ensure that this divides the world_size. This is because + # the devices will be replicated for every ep_degree devices, and + # the experts will be sharded within each group. + # - if > 1, also ensure that it divides the number of experts, as each device + # will then have num_of_experts / ep_degree experts. + ep_degree: 4 diff --git a/sample-configurations/moe-scattermoe-granite-ep4-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep4-sample-configuration.yaml new file mode 100644 index 00000000..b48081df --- /dev/null +++ b/sample-configurations/moe-scattermoe-granite-ep4-sample-configuration.yaml @@ -0,0 +1,21 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + training: + + # mixture-of-experts configurations + moe: + + # expert-parallel for MoE + scattermoe: + + # The level of expert parallel sharding. + # - 1 means no sharding + # - if > 1, please ensure that this divides the world_size. This is because + # the devices will be replicated for every ep_degree devices, and + # the experts will be sharded within each group. 
+ # - if > 1, also ensure that it divides the number of experts, as each device + # will then have num_of_experts / ep_degree experts. + ep_degree: 4 diff --git a/sample-configurations/moe-scattermoe-granite-ep8-foak-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep8-foak-sample-configuration.yaml new file mode 100644 index 00000000..938c9024 --- /dev/null +++ b/sample-configurations/moe-scattermoe-granite-ep8-foak-sample-configuration.yaml @@ -0,0 +1,43 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + training: + + fused_ops_and_kernels: + + # if under training stanza, then putting + # base_layer and fused_lora will be a misnomer + # - this should be in peft.quantized + # However, if it is specified, it will still + # be read. This is useful in use cases where + # the yaml is system generated and not shown + # to a user. + + # activate various unsloth optimizations + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + + # fast loss triton kernels + fast_loss: true + + # fast rms norm triton kernels + fast_rms_layernorm: true + + # fast RoPE embedding triton kernels + fast_rope_embeddings: true + moe: + + # expert-parallel for MoE + scattermoe: + + # The level of expert parallel sharding. + # - 1 means no sharding + # - if > 1, please ensure that this divides the world_size. This is because + # the devices will be replicated for every ep_degree devices, and + # the experts will be sharded within each group. + # - if > 1, also ensure that it divides the number of experts, as each device + # will then have num_of_experts / ep_degree experts. + ep_degree: 8 diff --git a/sample-configurations/moe-scattermoe-granite-ep8-sample-configuration.yaml b/sample-configurations/moe-scattermoe-granite-ep8-sample-configuration.yaml new file mode 100644 index 00000000..af5500e5 --- /dev/null +++ b/sample-configurations/moe-scattermoe-granite-ep8-sample-configuration.yaml @@ -0,0 +1,21 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + training: + + # mixture-of-experts configurations + moe: + + # expert-parallel for MoE + scattermoe: + + # The level of expert parallel sharding. + # - 1 means no sharding + # - if > 1, please ensure that this divides the world_size. This is because + # the devices will be replicated for every ep_degree devices, and + # the experts will be sharded within each group. + # - if > 1, also ensure that it divides the number of experts, as each device + # will then have num_of_experts / ep_degree experts. + ep_degree: 8 diff --git a/scripts/benchmarks/README.md b/scripts/benchmarks/README.md index 269d3ead..4795efee 100644 --- a/scripts/benchmarks/README.md +++ b/scripts/benchmarks/README.md @@ -76,13 +76,14 @@ bash run_benchmarks.sh NUM_GPUS_MATRIX RESULT_DIR SCENARIOS_CONFIG SCENARIOS_FIL ``` where: - `NUM_GPUS_MATRIX`: list of `num_gpu` settings to bench for, e.g. `"1 2"` will bench for 1 and 2 gpus. +- `EFFECTIVE_BS_MATRIX`: list of effective batch sizes, e.g., `"4 8"` will bench for effective batch sizes 4 and 8. - `RESULT_DIR`: where the benchmark results will be placed. - `SCENARIOS_CONFIG`: the `scenarios.yaml` file. - `SCENARIOS_CONFIG`: specify to run only a specific `scenario` by providing the specific `scenario` name. 
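[Editor's note: a hedged example of a direct `run_benchmarks.sh` invocation with the new effective batch size matrix, using the argument order listed above; the values mirror the tox posargs further down in this patch, and the scenarios file and filter are illustrative picks from the files added here.]

```
# bench for 1 and 2 GPUs at effective batch sizes 4 and 8,
# writing results to benchmark_outputs, restricted to the MoE scenario
bash scripts/run_benchmarks.sh "1 2" "4 8" benchmark_outputs \
    scripts/benchmarks/scenarios-moe.yaml accelerated-moe-full
```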
The recommended way to run `benchmarks.sh` is using `tox` which handles the dependencies: ``` -tox -e run-benches -- NUM_GPUS_MATRIX RESULT_DIR SCENARIOS_CONFIG SCENARIOS_FILTER +tox -e run-benches -- NUM_GPUS_MATRIX EFFECTIVE_BS_MATRIX RESULT_DIR SCENARIOS_CONFIG SCENARIOS_FILTER ``` Alternatively run [`benchmark.py`](./benchmark.py) directly. To see the help do: diff --git a/scripts/benchmarks/accelerate.yaml b/scripts/benchmarks/accelerate.yaml index 7923e624..a40505a7 100644 --- a/scripts/benchmarks/accelerate.yaml +++ b/scripts/benchmarks/accelerate.yaml @@ -31,7 +31,7 @@ fsdp_config: # 3 is NO_SHARD, effectively disabling FSDP # 4, 5 are HYBRID_ modes for multi-node training only. - fsdp_state_dict_type: FULL_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3) + fsdp_state_dict_type: SHARDED_STATE_DICT # set to FULL_STATE_DICT (1), SHARDED_STATE_DICT (3) # 2 is LOCAL_STATE_DICT where parameters are still flattened # 3 is efficient, but requires know-how to use the shared checkpoint. diff --git a/scripts/benchmarks/accelerator-config.json b/scripts/benchmarks/accelerator-config.json new file mode 100644 index 00000000..7f736f97 --- /dev/null +++ b/scripts/benchmarks/accelerator-config.json @@ -0,0 +1,5 @@ +{ + "gradient_accumulation_kwargs": { + "sync_each_batch": true + } +} diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index 3bd7056b..38fe6679 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -305,12 +305,34 @@ def build_args_from_products(products: List[Dict], defaults: Dict): argument_list = ConfigUtils.convert_keyvalue_arguments_to_list( combined_args ) - argument_list.extend( - [ - "--per_device_train_batch_size", - str(effective_batch_size // num_gpus), - ] - ) + pdtbs = combined_args.get('per_device_train_batch_size') + grad_accum = combined_args.get('gradient_accumulation_steps') + if pdtbs is None and grad_accum is not None: + if grad_accum > 1: + warnings.warn( + f"Found gradient_accumulation_steps={grad_accum} and " + "no per_device_train_batch_size specified, but for backward " + "compatibility, ignoring gradient_accum in batch size " + "computation (this behavior may change in the future)." + ) + argument_list.extend( + [ + "--per_device_train_batch_size", + str(effective_batch_size // num_gpus), + ] + ) + elif grad_accum is None and pdtbs is not None: + argument_list.extend( + [ + "--gradient_accumulation_steps", + str(effective_batch_size // num_gpus // pdtbs), + ] + ) + else: + raise ValueError( + "Please specify only either per_device_train_batch_size or gradient_accumulation_steps " + "and not both." 
+                )
             args.append((num_gpus, framework_config, argument_list))
     return args
@@ -358,6 +380,12 @@ class ScenarioMatrix:
     def __init__(self, scenario: Dict, acceleration_config_map: Dict = None) -> None:
         assert "arguments" in scenario.keys(), "Missing `arguments` key in `scenario`"
+
+        # "slow" is a special key that indicates this scenario
+        # takes significant resources to run
+        # - "slow" scenarios are not run if not specified by a filter
+        self.slow = False
+
         for key, val in scenario.items():
             if key == "framework_config":
                 # if acceleration_config_map is None, then do not do mapping
@@ -689,7 +717,18 @@ def prepare_arguments(args, benchmark_dataset: BenchmarkDataset):
         if args.run_only_scenarios and _scn_name not in args.run_only_scenarios:
             print(f"Skipping scenario '{_scn_name}'")
             continue
+
+        # build scenario matrix
         scenario = ScenarioMatrix(scenario_config, acceleration_config_map)
+
+        if (
+            not args.run_only_scenarios
+            and scenario.slow
+        ):
+            # unfiltered runs omit all "slow" marked scenarios
+            print(f"Skipping slow scenario '{_scn_name}' because run_only_scenarios=None.")
+            continue
+
         scenario_matrices, scenario_constants = (
             scenario.get_scenario_matrices_and_defaults()
         )
diff --git a/scripts/benchmarks/refs/a100_80gb_moe.csv b/scripts/benchmarks/refs/a100_80gb_moe.csv
new file mode 100644
index 00000000..4936cd6f
--- /dev/null
+++ b/scripts/benchmarks/refs/a100_80gb_moe.csv
@@ -0,0 +1,25 @@
+epoch,framework_config,gradient_accumulation_steps,mem_nvidia_mem_reserved,model_name_or_path,num_gpus,per_device_train_batch_size,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second
+0.25,none,16.0,71199.0,ibm-granite/granite-3.0-3b-a800m-instruct,1,8,bfloat16,0.9438143467903136,2371.9316,5.396,0.042,1505.608
+0.25,none,8.0,46829.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.9437569552659988,1355.7096,9.442,0.074,1317.096
+0.25,none,4.0,37996.0,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9437739425897598,708.3914,18.069,0.141,1260.32
+0.25,moe-scattermoe-granite-ep1,16.0,71187.0,ibm-granite/granite-3.0-3b-a800m-instruct,1,8,bfloat16,0.9439476370811464,742.739,17.234,0.135,4808.149
+0.25,moe-scattermoe-granite-ep1,8.0,52503.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.9506204092502594,485.5103,26.364,0.206,3677.78
+0.25,moe-scattermoe-granite-ep1,4.0,51145.0,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9572784686088562,262.9566,48.677,0.38,3395.238
+0.25,moe-scattermoe-granite-ep2,8.0,40193.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.9437192791700364,577.2164,22.175,0.173,3093.467
+0.25,moe-scattermoe-granite-ep2,4.0,40878.5,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9509018939733506,300.285,42.626,0.333,2973.176
+0.25,moe-scattermoe-granite-ep4,4.0,31777.5,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9434539985656738,307.1264,41.677,0.326,2906.946
+0.25,moe-scattermoe-granite-ep1-padding-free,16.0,48401.0,ibm-granite/granite-3.0-3b-a800m-instruct,1,8,bfloat16,0.9437484860420228,631.9756,20.254,0.158,3924.202
+0.25,moe-scattermoe-granite-ep1-padding-free,8.0,42452.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.9506663566827774,454.3444,28.172,0.22,2729.207
+0.25,moe-scattermoe-granite-ep1-padding-free,4.0,38560.0,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.957276314496994,241.2967,53.047,0.414,2569.451
+0.25,moe-scattermoe-granite-ep2-padding-free,8.0,31012.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.943688799738884,546.507,23.421,0.183,2268.955 +0.25,moe-scattermoe-granite-ep2-padding-free,4.0,28133.0,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9505942213535308,283.5444,45.143,0.353,2186.607 +0.25,moe-scattermoe-granite-ep4-padding-free,4.0,21585.5,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9441865116357804,284.6079,44.974,0.351,2178.436 +0.25,moe-scattermoe-granite-ep1-padding-free-foak,16.0,42651.0,ibm-granite/granite-3.0-3b-a800m-instruct,1,8,bfloat16,0.9437448275089264,615.4528,20.798,0.162,4029.554 +0.25,moe-scattermoe-granite-ep1-padding-free-foak,8.0,37743.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.950773031115532,433.4811,29.528,0.231,2860.563 +0.25,moe-scattermoe-granite-ep1-padding-free-foak,4.0,35153.0,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9572476959228516,232.0428,55.162,0.431,2671.921 +0.25,moe-scattermoe-granite-ep2-padding-free-foak,8.0,26075.0,ibm-granite/granite-3.0-3b-a800m-instruct,2,8,bfloat16,0.9437651455402374,524.7751,24.391,0.191,2362.917 +0.25,moe-scattermoe-granite-ep2-padding-free-foak,4.0,24665.5,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.9507779973745346,274.126,46.694,0.365,2261.733 +0.25,moe-scattermoe-granite-ep4-padding-free-foak,4.0,18368.0,ibm-granite/granite-3.0-3b-a800m-instruct,4,8,bfloat16,0.943427557349205,278.1245,46.023,0.36,2229.217 +,none,,65607.25,mistralai/Mixtral-8x7B-Instruct-v0.1,8,1,bfloat16,0.8599078696966171,4180.9544,3.062,0.024,80.364 +,moe-scattermoe-granite-ep8,,52004.75,mistralai/Mixtral-8x7B-Instruct-v0.1,8,1,bfloat16,0.8588122856616974,1071.1967,11.949,0.093,313.668 +,moe-scattermoe-granite-ep8-foak,,51961.25,mistralai/Mixtral-8x7B-Instruct-v0.1,8,1,bfloat16,0.8599798053503036,1043.6675,12.264,0.096,321.942 diff --git a/scripts/benchmarks/refs/requirements_moe.txt b/scripts/benchmarks/refs/requirements_moe.txt new file mode 100644 index 00000000..63700ed0 --- /dev/null +++ b/scripts/benchmarks/refs/requirements_moe.txt @@ -0,0 +1,89 @@ +accelerate==1.0.1 +aiohappyeyeballs==2.4.3 +aiohttp==3.10.10 +aiosignal==1.3.1 +async-timeout==4.0.3 +attrs==24.2.0 +bitsandbytes==0.43.3 +certifi==2024.8.30 +charset-normalizer==3.4.0 +contourpy==1.3.0 +cycler==0.12.1 +datasets==2.21.0 +dill==0.3.8 +docstring_parser==0.16 +einops==0.8.0 +filelock==3.16.1 +flash-attn==2.6.3 +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@21af5fb9f2989b3dbf443c016e4c0470b536a593#egg=fms_acceleration&subdirectory=plugins/framework +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@21af5fb9f2989b3dbf443c016e4c0470b536a593#egg=fms_acceleration_aadp&subdirectory=plugins/attention-and-distributed-packing +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@21af5fb9f2989b3dbf443c016e4c0470b536a593#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@21af5fb9f2989b3dbf443c016e4c0470b536a593#egg=fms_acceleration_moe&subdirectory=plugins/accelerated-moe +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@21af5fb9f2989b3dbf443c016e4c0470b536a593#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft +fms-hf-tuning @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@398c2a8fe26d734344240555585d95e05299faa8 +fonttools==4.54.1 +frozenlist==1.5.0 +fsspec==2024.6.1 +huggingface-hub==0.26.2 +idna==3.10 
+Jinja2==3.1.4 +kernel-hyperdrive @ git+https://github.com/fabianlim/kernel-hyperdrive.git@45036497e12444ca98a6f0072204538aee4543ba +kiwisolver==1.4.7 +llvmlite==0.43.0 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +matplotlib==3.9.2 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.1.0 +multiprocess==0.70.16 +networkx==3.4.2 +numba==0.60.0 +numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.1.105 +packaging==24.2 +pandas==2.2.3 +peft==0.13.2 +pillow==11.0.0 +propcache==0.2.0 +protobuf==5.28.3 +psutil==6.1.0 +pyarrow==18.0.0 +Pygments==2.18.0 +pyparsing==3.2.0 +python-dateutil==2.9.0.post0 +pytz==2024.2 +PyYAML==6.0.2 +regex==2024.11.6 +requests==2.32.3 +rich==13.9.4 +safetensors==0.4.5 +sentencepiece==0.2.0 +shtab==1.7.1 +simpleeval==0.9.13 +six==1.16.0 +sympy==1.13.1 +threadpoolctl==3.5.0 +tokenizers==0.20.3 +torch==2.4.1 +tqdm==4.67.0 +transformers==4.45.2 +triton==3.0.0 +trl==0.11.4 +typing_extensions==4.12.2 +tyro==0.8.14 +tzdata==2024.2 +urllib3==2.2.3 +xxhash==3.5.0 +yarl==1.17.1 diff --git a/scripts/benchmarks/scenarios-granite.yaml b/scripts/benchmarks/scenarios-granite.yaml index 2e5d0cf9..2221797b 100644 --- a/scripts/benchmarks/scenarios-granite.yaml +++ b/scripts/benchmarks/scenarios-granite.yaml @@ -108,3 +108,19 @@ scenarios: # target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] # model_name_or_path: # - 'ibm/PowerLM-3b' + + - name: accelerated-moe + framework_config: + - # without acceleration + - moe-scattermoe-granite + # add pf + slow: True + arguments: + learning_rate: 5e-5 + torch_dtype: bfloat16 + gradient_accumulation_steps: 16 + logging_steps: 1 + packing: False + adam_epsilon: 1e-8 + model_name_or_path: + - 'ibm/PowerMoE-3b' diff --git a/scripts/benchmarks/scenarios-moe.yaml b/scripts/benchmarks/scenarios-moe.yaml new file mode 100644 index 00000000..efa2725e --- /dev/null +++ b/scripts/benchmarks/scenarios-moe.yaml @@ -0,0 +1,78 @@ +# This file holds a list of scenarios to may be run. +# - to limit to a number of scenarios, use the --run-only-scenarios flag. +# - Each scenario will be run against a particular acceleration framework +# config, if the framework_config: key is specified. +# * a particular framework configuration +# - the arguments tag will hold arguments to be passed to sft_trainer +# * the arguments are singular except for model_name_or_path which can handle +# multiple arguments. +# - So anything that is critical for the scenario MUST be specified here +# and not in the defaults, e.g. fp16 + +# This stanza will be used in future to replace the custom processing functions in data_processing.py +# data_processing: +# dataset_name: yahma/alpaca-cleaned +# chat_template: | +# {%- for message in messages %} +# {% if message['input'] != '' %} +# Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + +# {% else %} +# Below is an instruction that describes a task. Write a response that appropriately completes the request. 
+ +# {% endif %} +# ### Instruction: +# {{ message['instruction'] }} + +# {% if message['input'] != '' %} +# ### Input: +# {{ message['input'] }} + +# {% endif %} +# ### Response: +# {{ message['output'] + eos_token }} +# {% endfor %} +# tokenize: True + + +scenarios: + - name: accelerated-moe-full + framework_config: + - # without acceleration + - moe-scattermoe-granite-ep1 + - moe-scattermoe-granite-ep2 + - moe-scattermoe-granite-ep4 + - moe-scattermoe-granite-ep1-padding-free + - moe-scattermoe-granite-ep1-padding-free-foak + - moe-scattermoe-granite-ep2-padding-free + - moe-scattermoe-granite-ep2-padding-free-foak + - moe-scattermoe-granite-ep4-padding-free + - moe-scattermoe-granite-ep4-padding-free-foak + arguments: + learning_rate: 5e-5 + torch_dtype: bfloat16 + gradient_accumulation_steps: null + per_device_train_batch_size: 8 + logging_steps: 1 + packing: False + adam_epsilon: 1e-8 + model_name_or_path: + - 'ibm-granite/granite-3.0-3b-a800m-instruct' + + - name: accelerated-moe-full-mixtral + framework_config: + - # without acceleration + - moe-scattermoe-granite-ep8 + - moe-scattermoe-granite-ep8-foak + slow: True + arguments: + learning_rate: 5e-5 + torch_dtype: bfloat16 + accelerator_config: scripts/benchmarks/accelerator-config.json + gradient_accumulation_steps: null + per_device_train_batch_size: 1 + logging_steps: 1 + packing: False + adam_epsilon: 1e-8 + model_name_or_path: + - 'mistralai/Mixtral-8x7B-Instruct-v0.1' \ No newline at end of file diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py index bd304959..ff775c8e 100644 --- a/scripts/generate_sample_configurations.py +++ b/scripts/generate_sample_configurations.py @@ -148,6 +148,10 @@ def read_configuration(path: str) -> Dict: KEY_AADP_PADDING_FREE = "aadp-padding-free" KEY_AADP_MULTIPACK = "aadp-multipack" KEY_FAST_KERNELS = "foak-fast-kernels" +KEY_SCATTERMOE_EP1 = "moe-scattermoe-ep1" +KEY_SCATTERMOE_EP2 = 'moe-scattermoe-ep2' +KEY_SCATTERMOE_EP4 = 'moe-scattermoe-ep4' +KEY_SCATTERMOE_EP8 = 'moe-scattermoe-ep8' CONFIGURATIONS = { KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml", @@ -173,6 +177,19 @@ def read_configuration(path: str) -> Dict: KEY_AADP_PADDING_FREE: "plugins/attention-and-distributed-packing/configs/padding_free.yaml", KEY_AADP_MULTIPACK: "plugins/attention-and-distributed-packing/configs/multipack.yaml", KEY_FAST_KERNELS: "plugins/fused-ops-and-kernels/configs/fast_kernels.yaml", + KEY_SCATTERMOE_EP1: "plugins/accelerated-moe/configs/scattermoe.yaml", + KEY_SCATTERMOE_EP2: ( + "plugins/accelerated-moe/configs/scattermoe.yaml", + [("training.moe.scattermoe.ep_degree", 2)], + ), + KEY_SCATTERMOE_EP4: ( + "plugins/accelerated-moe/configs/scattermoe.yaml", + [("training.moe.scattermoe.ep_degree", 4)], + ), + KEY_SCATTERMOE_EP8: ( + "plugins/accelerated-moe/configs/scattermoe.yaml", + [("training.moe.scattermoe.ep_degree", 8)], + ), } # list of (tag, combi) tuples @@ -192,7 +209,18 @@ def read_configuration(path: str) -> Dict: ("accelerated-peft-autogptq-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)), ("accelerated-peft-bnb-nf4-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_BNB_NF4, KEY_BNB_NF4_FOAK)), ("aadp-padding-free-multipack", (KEY_AADP_PADDING_FREE, KEY_AADP_MULTIPACK)), - ("foak-fast-kernels", (KEY_FAST_KERNELS,)) + ("foak-fast-kernels", (KEY_FAST_KERNELS,)), + ("moe-scattermoe-granite-ep1", (KEY_SCATTERMOE_EP1,)), + ("moe-scattermoe-granite-ep1-padding-free", (KEY_AADP_PADDING_FREE, 
KEY_SCATTERMOE_EP1,)), + ("moe-scattermoe-granite-ep1-padding-free-foak", (KEY_AADP_PADDING_FREE, KEY_FAST_KERNELS, KEY_SCATTERMOE_EP1,)), + ("moe-scattermoe-granite-ep2", (KEY_SCATTERMOE_EP2,)), + ("moe-scattermoe-granite-ep2-padding-free", (KEY_AADP_PADDING_FREE, KEY_SCATTERMOE_EP2,)), + ("moe-scattermoe-granite-ep2-padding-free-foak", (KEY_AADP_PADDING_FREE, KEY_FAST_KERNELS, KEY_SCATTERMOE_EP2,)), + ("moe-scattermoe-granite-ep4", (KEY_SCATTERMOE_EP4,)), + ("moe-scattermoe-granite-ep4-padding-free", (KEY_AADP_PADDING_FREE, KEY_SCATTERMOE_EP4,)), + ("moe-scattermoe-granite-ep4-padding-free-foak", (KEY_AADP_PADDING_FREE, KEY_FAST_KERNELS, KEY_SCATTERMOE_EP4,)), + ("moe-scattermoe-granite-ep8", (KEY_SCATTERMOE_EP8,)), + ("moe-scattermoe-granite-ep8-foak", (KEY_FAST_KERNELS, KEY_SCATTERMOE_EP8,)), ] diff --git a/tox.ini b/tox.ini index d29f1f24..a62ae961 100644 --- a/tox.ini +++ b/tox.ini @@ -29,7 +29,7 @@ commands = # need a version of fms-hf-tuning that has integrated the framework # NOTE: have to install this first coz havnt merged # - this repo has a lot of pins, so we just install it first - pip install "fms-hf-tuning[flash-attn] @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@"{env:FHT_BRANCH:main} + pip install "fms-hf-tuning @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@"{env:FHT_BRANCH:main} # some models need this for tokenizers pip install protobuf @@ -39,6 +39,10 @@ commands = python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-peft python -m fms_acceleration.cli install -e {toxinidir}/plugins/fused-ops-and-kernels python -m fms_acceleration.cli install -e {toxinidir}/plugins/attention-and-distributed-packing + python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-moe + + # install the flash attn at the last + pip install flash-attn # run the benchmark script bash scripts/run_benchmarks.sh {posargs:"1 2" "4 8" benchmark_outputs}
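[Editor's note: as a closing pointer, a minimal sketch of how the sample configurations generated by this patch tie back to the new plugin; this is an illustration, not part of the patch. The config path is one of the files added above, and the helper signatures (`read_configuration`, `instantiate_framework`, `active_plugins`) are taken verbatim from `tests/test_scattermoe_plugin.py` in this diff.]

```python
# load one of the generated sample configurations and confirm that the
# ScatterMoE plugin is the one the framework activates
from fms_acceleration.utils import instantiate_framework, read_configuration
from fms_acceleration_moe import ScatterMoEAccelerationPlugin

config = read_configuration(
    "sample-configurations/moe-scattermoe-granite-ep2-sample-configuration.yaml"
)
with instantiate_framework(config, require_packages_check=False) as framework:
    # active_plugins yields (name, plugin) pairs, as in the plugin test above
    for plugin in framework.active_plugins:
        assert isinstance(plugin[1], ScatterMoEAccelerationPlugin)
```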