From 8b1a9b2c8fc228ae01d52e953bbb842eefa97913 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 27 Nov 2023 23:05:59 -0500 Subject: [PATCH 1/5] modify configs/125M.yml to run without axonn --- configs/125M.yml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/configs/125M.yml b/configs/125M.yml index 15a4b3b01..9f499248e 100644 --- a/configs/125M.yml +++ b/configs/125M.yml @@ -39,7 +39,7 @@ # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training "zero_optimization": { - "stage": 1, + "stage": 0, "allgather_partitions": True, "allgather_bucket_size": 500000000, "overlap_comm": True, @@ -84,11 +84,15 @@ "eval_iters": 10, # logging - "log_interval": 100, + "log_interval": 1, "steps_per_print": 10, "keep_last_n_checkpoints": 4, "wall_clock_breakdown": true, - # networking - "hostfile": "/mock_path" + + "data-path": "./data/enwik8/enwik8_text_document", + "vocab-file": "./data/gpt2-vocab.json", + "merge-file": "./data/gpt2-merges.txt" + # networking + #"hostfile": "/mock_path" } From 20d42289a44b1851af5b5f16c11620bcc4442aa6 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 28 Nov 2023 00:17:16 -0500 Subject: [PATCH 2/5] test each tp dim individually set to 2 --- configs/125M.yml | 8 +- megatron/initialize.py | 22 +++- megatron/model/transformer.py | 165 ++++++++++++++++++++------- megatron/neox_arguments/neox_args.py | 5 + 4 files changed, 156 insertions(+), 44 deletions(-) diff --git a/configs/125M.yml b/configs/125M.yml index 9f499248e..944bb88ae 100644 --- a/configs/125M.yml +++ b/configs/125M.yml @@ -3,7 +3,13 @@ # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages # across the node boundaries ) "pipe_parallel_size": 1, - "model_parallel_size": 1, + "model_parallel_size": 2, + + "use_axonn_model_parallelism": true, + ## these are the 3 dimensions of AxoNN's TP + "depth_model_parallel_size": 1, + "row_model_parallel_size": 1, + "column_model_parallel_size": 2, # model settings "num_layers": 12, diff --git a/megatron/initialize.py b/megatron/initialize.py index bc4032649..6ab7278e7 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -29,7 +29,7 @@ import deepspeed import inspect - +from axonn import axonn as ax def initialize_megatron(neox_args, allow_no_cuda=False): """Set initialize distributed and set autoresume and random seeds. 
@@ -188,6 +188,26 @@ def _initialize_distributed(neox_args):
            fp32_allreduce=neox_args.fp32_allreduce,
        )

+
+
+    if neox_args.use_axonn_model_parallelism:
+        row_mp = neox_args.row_model_parallel_size
+        column_mp = neox_args.column_model_parallel_size
+        depth_mp = neox_args.depth_model_parallel_size
+        assert row_mp * column_mp * depth_mp == neox_args.model_parallel_size, "product of row-model-parallel-size, column-model-parallel-size, and depth-model-parallel-size should equal model-parallel-size"
+        ax.init(
+            G_inter=pp,
+            G_data=dp,
+            G_intra_r=neox_args.row_model_parallel_size,
+            G_intra_c=neox_args.column_model_parallel_size,
+            G_intra_d=neox_args.depth_model_parallel_size,
+        )
+        print(
+            f"> initialized AxoNN with G_intra_r={neox_args.row_model_parallel_size}, "
+            f"G_intra_c={neox_args.column_model_parallel_size}, "
+            f"G_intra_d={neox_args.depth_model_parallel_size}"
+        )
+
    # Init DeepSpeed Activation Checkpointing Features
    setup_deepspeed_random_and_activation_checkpointing(neox_args=neox_args)

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 63f4122e2..7cde00dd9 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -40,6 +40,8 @@
 )
 from megatron.model.utils import configure_sparse_attention

+from axonn.intra_layer import Linear, drop, gather
+
 # flags required to enable jit fusion kernels
 torch._C._jit_set_profiling_mode(False)
 torch._C._jit_set_profiling_executor(False)
@@ -93,30 +95,57 @@ def __init__(
            if self.activation_type == "geglu"
            else ff_mult * neox_args.hidden_size
        )
-        self.dense_h_to_4h = mpu.ColumnParallelLinear(
-            neox_args=neox_args,
-            input_size=neox_args.hidden_size,
-            output_size=ff_dim,
-            gather_output=False,
-            init_method=init_method,
-            skip_bias_add=True,
-        )
+        if neox_args.use_axonn_model_parallelism:
+            self.dense_h_to_4h = Linear(
+                in_features=neox_args.hidden_size,
+                out_features=ff_dim,
+                init_method=init_method,
+                skip_bias_add=True,
+            )
+        else:
+            self.dense_h_to_4h = mpu.ColumnParallelLinear(
+                neox_args=neox_args,
+                input_size=neox_args.hidden_size,
+                output_size=ff_dim,
+                gather_output=False,
+                init_method=init_method,
+                skip_bias_add=True,
+            )
        ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim
        # Project back to h.
- self.dense_4h_to_h = mpu.RowParallelLinear( - neox_args=neox_args, - input_size=ff_dim_in, - output_size=neox_args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - skip_bias_add=True, - parallel_output=parallel_output, - ) + + if neox_args.use_axonn_model_parallelism: + self.dense_4h_to_h = Linear( + in_features = ff_dim_in, + out_features = neox_args.hidden_size, + init_method = output_layer_init_method, + skip_bias_add = True, + transpose=True + ) + assert not parallel_output, "ToDO: Implement axonn support for parallel_output=True (gpt j residual)" + + else: + self.dense_4h_to_h = mpu.RowParallelLinear( + neox_args=neox_args, + input_size=ff_dim_in, + output_size=neox_args.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True, + parallel_output=parallel_output, + ) + + + self.use_axonn_model_parallelism = neox_args.use_axonn_model_parallelism def forward(self, hidden_states): # [s, b, 4hp] - intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) + if self.use_axonn_model_parallelism: + intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states, + scatter_input=False, gather_output=False) + else: + intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) if ( self.activation_type == "gelu" and self.bias_gelu_fusion @@ -130,7 +159,11 @@ def forward(self, hidden_states): ) # [s, b, h] - output, output_bias = self.dense_4h_to_h(intermediate_parallel) + if self.use_axonn_model_parallelism: + output, output_bias = self.dense_4h_to_h(intermediate_parallel, + scatter_input=False, gather_output=False) + else: + output, output_bias = self.dense_4h_to_h(intermediate_parallel) return output, output_bias @@ -162,6 +195,9 @@ def __init__( ff_dim = int(2 * neox_args.hidden_size * 4 / 3) ff_dim = self.multiple_of * ((ff_dim + multiple_of - 1) // multiple_of) + + assert not neox_args.use_axonn_model_parallelism, "ToDo: Implement AxoNN TP for LLaMAParallelMLP" + self.w1 = mpu.ColumnParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -275,7 +311,10 @@ def __init__( self.attention_softmax_in_fp32 = True self.layer_number = layer_number # Per attention head and per partition values. - world_size = mpu.get_model_parallel_world_size() + if neox_args.use_axonn_model_parallelism: + world_size = neox_args.row_model_parallel_size + else: + world_size = mpu.get_model_parallel_world_size() self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) self.hidden_size_per_attention_head = mpu.divide( neox_args.hidden_size, neox_args.num_attention_heads @@ -286,14 +325,24 @@ def __init__( self.pos_emb = neox_args.pos_emb # Strided linear layer. 
- self.query_key_value = mpu.ColumnParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=3 * neox_args.hidden_size, - gather_output=False, - init_method=init_method, - bias=neox_args.use_bias_in_attn_linear, - ) + self.use_axonn_model_parallelism = neox_args.use_axonn_model_parallelism + if neox_args.use_axonn_model_parallelism: + self.query_key_value = Linear( + in_features=neox_args.hidden_size, + out_features=3 * neox_args.hidden_size, + init_method=init_method, + bias=neox_args.use_bias_in_attn_linear, + skip_bias_add=True + ) + else: + self.query_key_value = mpu.ColumnParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=3 * neox_args.hidden_size, + gather_output=False, + init_method=init_method, + bias=neox_args.use_bias_in_attn_linear, + ) coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -377,16 +426,27 @@ def __init__( self.attention_dropout = nn.Dropout(self.dropout_p) # Output. - self.dense = mpu.RowParallelLinear( - neox_args=neox_args, - input_size=neox_args.hidden_size, - output_size=neox_args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - skip_bias_add=True, - parallel_output=parallel_output, - bias=neox_args.use_bias_in_attn_linear, - ) + if neox_args.use_axonn_model_parallelism: + self.dense = Linear( + in_features=neox_args.hidden_size, + out_features=neox_args.hidden_size, + init_method=output_layer_init_method, + skip_bias_add=True, + bias=neox_args.use_bias_in_attn_linear, + transpose=True + ) + assert not parallel_output, "ToDO: Implement axonn support for parallel_output=True (gpt j residual)" + else: + self.dense = mpu.RowParallelLinear( + neox_args=neox_args, + input_size=neox_args.hidden_size, + output_size=neox_args.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True, + parallel_output=parallel_output, + bias=neox_args.use_bias_in_attn_linear, + ) def attention( self, query_layer, key_layer, value_layer, layer_past, attention_mask @@ -625,7 +685,10 @@ def forward(self, hidden_states, attention_mask, layer_past=None): # ===================== # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer, _ = self.query_key_value(hidden_states) + if self.use_axonn_model_parallelism: + mixed_x_layer, _ = self.query_key_value(hidden_states, scatter_input=False, gather_output=False) + else: + mixed_x_layer, _ = self.query_key_value(hidden_states) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + ( @@ -710,7 +773,10 @@ def forward(self, hidden_states, attention_mask, layer_past=None): # Output. [sq, b, h] # ================= - output, bias = self.dense(context_layer) + if self.use_axonn_model_parallelism: + output, bias = self.dense(context_layer, scatter_input=False, gather_output=False) + else: + output, bias = self.dense(context_layer) if self.use_cache: output = [output, present] @@ -739,11 +805,17 @@ def __init__( super().__init__() self.layer_number = layer_number + self.is_first_layer = ( layer_number == 0 ) + self.is_last_layer = ( layer_number == neox_args.num_layers - 1 ) norm, eps = get_norm(neox_args) # Layernorm on the input data. 
- self.input_layernorm = norm(neox_args.hidden_size, eps=eps) + if neox_args.use_axonn_model_parallelism: + self.input_layernorm = norm(mpu.divide(neox_args.hidden_size, + neox_args.column_model_parallel_size), eps=eps) + else: + self.input_layernorm = norm(neox_args.hidden_size, eps=eps) self.use_cache = use_cache self.hidden_dropout = neox_args.hidden_dropout @@ -771,7 +843,11 @@ def __init__( # Layernorm on the output of the attention layer. # If GPT-J residuals are used, this is surpurfulous but leaving it in # leads to cleaner code - self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps) + if neox_args.use_axonn_model_parallelism: + self.post_attention_layernorm = norm(mpu.divide(neox_args.hidden_size, + neox_args.column_model_parallel_size), eps=eps) + else: + self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps) # MLP if neox_args.mlp_type == "regular": @@ -807,6 +883,9 @@ def _get_bias_dropout(self): def forward(self, x, attention_mask, layer_past=None): layer_past = layer_past if layer_past is not None else self.layer_past bias_dropout_fn = self._get_bias_dropout() + + if self.is_first_layer: + x = drop(x, batch_dim=1) # x: [b, s, h] if self.gpt_j_residual: # pseudocode: @@ -904,6 +983,8 @@ def forward(self, x, attention_mask, layer_past=None): prob=self.hidden_dropout, ) + if self.is_last_layer: + output = gather(output, batch_dim=1) return output diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 957960832..7935c9e6c 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -61,6 +61,11 @@ class NeoXArgsParallelism(NeoXArgsTemplate): """ model_parallel_size: int = 1 + use_axonn_model_parallelism: bool = False + row_model_parallel_size: int = 1 + column_model_parallel_size: int = 1 + depth_model_parallel_size: int = 1 + """ Size of the model parallelism. """ From f1c40e27288c88667544ad65ad0cc2ed1d31c93c Mon Sep 17 00:00:00 2001 From: github-actions Date: Tue, 28 Nov 2023 05:19:23 +0000 Subject: [PATCH 3/5] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index bc2e8fc57..8d0c64cb9 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 2da1083 + Default = 20d4228 current git hash of repository @@ -858,6 +858,38 @@ Parallelism Arguments Default = 1 + + + + +- **use_axonn_model_parallelism**: bool + + Default = False + + + + + +- **row_model_parallel_size**: int + + Default = 1 + + + + + +- **column_model_parallel_size**: int + + Default = 1 + + + + + +- **depth_model_parallel_size**: int + + Default = 1 + Size of the model parallelism. 
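A note on the layer-norm sizing introduced above: with AxoNN enabled, the per-layer norms are built over mpu.divide(hidden_size, column_model_parallel_size), and activations are only drop()-ed at the first transformer layer and gather()-ed at the last, which suggests hidden states stay feature-sharded across the column tensor-parallel group in between. The single-process sketch below is not AxoNN itself; it only emulates a Megatron-style column split of one linear layer to show where the hidden_size / column_mp width comes from and why concatenating (gathering) the shards recovers the unsharded result. All names in it are illustrative.

```python
# Single-process emulation of a column-parallel linear split (not AxoNN itself).
# Each of the `column_mp` shards produces hidden_size // column_mp output
# features -- the reduced width the per-layer norms in the patch are built over.
import torch

torch.manual_seed(0)

hidden_size = 768   # matches configs/125M.yml
column_mp = 2       # "column_model_parallel_size": 2 from patch 2
seq, batch = 4, 2

full = torch.nn.Linear(hidden_size, hidden_size, bias=False)
x = torch.randn(seq, batch, hidden_size)

with torch.no_grad():
    # Column parallelism shards the output dimension of the weight matrix.
    weight_shards = torch.chunk(full.weight, column_mp, dim=0)

    shard_outputs = []
    for rank, w in enumerate(weight_shards):
        y = x @ w.t()  # each "rank" computes its slice of the output features
        assert y.shape[-1] == hidden_size // column_mp
        # A per-rank LayerNorm would be built over this reduced width,
        # analogous to norm(mpu.divide(hidden_size, column_mp)) in the patch.
        shard_outputs.append(y)

    # Concatenating the slices (what a gather across ranks does) recovers
    # the unsharded result.
    recombined = torch.cat(shard_outputs, dim=-1)
    reference = full(x)
    print(torch.allclose(recombined, reference, atol=1e-6))  # True
```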
From 23709b643a60389c3f3cbc3f050792b213ef5010 Mon Sep 17 00:00:00 2001 From: github-actions Date: Tue, 19 Dec 2023 19:02:28 +0000 Subject: [PATCH 4/5] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 89e73b218..12e442c6e 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = bb1b145 + Default = 7438b33 current git hash of repository From 28443c6dd15b7e5e34f3177ba66ff4f291ae1868 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 18 Jan 2024 05:53:26 -0500 Subject: [PATCH 5/5] add communication optimizations part 1 --- configs/125M.yml | 8 +++-- megatron/initialize.py | 1 + megatron/neox_arguments/neox_args.py | 1 + megatron/training.py | 52 ++++++++++++++++++++-------- 4 files changed, 45 insertions(+), 17 deletions(-) diff --git a/configs/125M.yml b/configs/125M.yml index 944bb88ae..26ca08155 100644 --- a/configs/125M.yml +++ b/configs/125M.yml @@ -2,14 +2,16 @@ { # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages # across the node boundaries ) - "pipe_parallel_size": 1, + "pipe_parallel_size": 0, "model_parallel_size": 2, + ## axonn's arguments "use_axonn_model_parallelism": true, ## these are the 3 dimensions of AxoNN's TP - "depth_model_parallel_size": 1, + "depth_model_parallel_size": 2, "row_model_parallel_size": 1, - "column_model_parallel_size": 2, + "column_model_parallel_size": 1, + "optimize_axonn_communication": true, # model settings "num_layers": 12, diff --git a/megatron/initialize.py b/megatron/initialize.py index 6ab7278e7..a4196350e 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -195,6 +195,7 @@ def _initialize_distributed(neox_args): column_mp = neox_args.column_model_parallel_size depth_mp = neox_args.depth_model_parallel_size assert row_mp * column_mp * depth_mp == neox_args.model_parallel_size, "product of row-model-parallel-size, column-model-parallel-sizem and depth-model-parallel-size should equal model-parallel-size" + assert neox_args.pipe_parallel_size == 0, "AxoNN's tensor parallelism has not been tested with pipeline parallelism" ax.init( G_inter= pp, G_data = dp, diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 001d2ac2c..a507662b2 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -65,6 +65,7 @@ class NeoXArgsParallelism(NeoXArgsTemplate): row_model_parallel_size: int = 1 column_model_parallel_size: int = 1 depth_model_parallel_size: int = 1 + optimize_axonn_communication: bool = False """ Size of the model parallelism. 
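The two assertions added to megatron/initialize.py encode the configuration contract for these patches: the three AxoNN dimensions must multiply to model_parallel_size, and pipeline parallelism must be off. Below is a small, hypothetical stand-alone checker (not part of the patch) that mirrors those constraints against the values now in configs/125M.yml.

```python
# Hypothetical helper mirroring the asserts added to megatron/initialize.py;
# the example values match configs/125M.yml after patch 5.
def validate_axonn_config(cfg: dict) -> None:
    if not cfg.get("use_axonn_model_parallelism", False):
        return
    product = (
        cfg["row_model_parallel_size"]
        * cfg["column_model_parallel_size"]
        * cfg["depth_model_parallel_size"]
    )
    if product != cfg["model_parallel_size"]:
        raise ValueError(
            f"row * column * depth model parallel sizes ({product}) "
            f"must equal model_parallel_size ({cfg['model_parallel_size']})"
        )
    if cfg.get("pipe_parallel_size", 0) != 0:
        raise ValueError(
            "AxoNN tensor parallelism has not been tested with pipeline "
            "parallelism; set pipe_parallel_size to 0"
        )


if __name__ == "__main__":
    validate_axonn_config(
        {
            "pipe_parallel_size": 0,
            "model_parallel_size": 2,
            "use_axonn_model_parallelism": True,
            "depth_model_parallel_size": 2,
            "row_model_parallel_size": 1,
            "column_model_parallel_size": 1,
        }
    )
    print("config ok")
```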
diff --git a/megatron/training.py b/megatron/training.py index 20c168ed3..ff1936a93 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -55,8 +55,13 @@ CharCounter, ) from megatron.model.gpt2_model import cross_entropy +from megatron.model.transformer import ParallelTransformerLayer from eval_tasks import run_eval_harness +from axonn.intra_layer import optimize_communication, sync_gradients +from contextlib import nullcontext +from functools import partial + def mup_weights_reinit(neox_args, model): def has_method(o, name): @@ -730,27 +735,46 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) ) else: losses = [] + if neox_args.use_axonn_model_parallelism and neox_args.optimize_axonn_communication: + ctx = partial( + optimize_communication, + overlap_all_reduce=True, + overlap_reduce_scatter=True, + overlap_all_gather=True, + model_object_for_overlapping_allgathers=model + ) + else: + ctx = nullcontext + for _ in range(neox_args.gradient_accumulation_steps): # Forward model for one step. timers("forward").start() - loss = forward_step( - neox_args=neox_args, - timers=timers, - data_iterator=data_iterator, - model=model, - is_train=True, - ) + with ctx(): + loss = forward_step( + neox_args=neox_args, + timers=timers, + data_iterator=data_iterator, + model=model, + is_train=True, + ) timers("forward").stop() losses.append(loss) # Calculate gradients, reduce across processes, and clip. timers("backward").start() - backward_step( - neox_args=neox_args, - timers=timers, - optimizer=optimizer, - model=model, - loss=loss, - ) + with ctx(): + backward_step( + neox_args=neox_args, + timers=timers, + optimizer=optimizer, + model=model, + loss=loss, + ) + if neox_args.use_axonn_model_parallelism: + modules_to_sync = [] + for module in model.modules(): + if isinstance(module, ParallelTransformerLayer): + sync_gradients(module) + #sync_gradients(model) timers("backward").stop() # Update parameters. timers("optimizer").start()
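The train_step change above relies on a small idiom: when the communication optimizations are disabled, `ctx` must still be callable the same way, so the patch binds either `functools.partial(optimize_communication, ...)` or `contextlib.nullcontext` to the same name and wraps both forward and backward with `with ctx():`. The self-contained sketch below reproduces only that toggle pattern; `optimize_communication` and `sync_gradients` are the real imports from `axonn.intra_layer` in the patch, while `fake_optimize_communication` here is a stand-in for illustration only.

```python
# Minimal illustration of the "partial(...) or nullcontext" toggle used in
# train_step; `fake_optimize_communication` is a stand-in, not the AxoNN API.
from contextlib import contextmanager, nullcontext
from functools import partial


@contextmanager
def fake_optimize_communication(overlap_all_reduce=True, overlap_all_gather=True):
    print("communication overlap enabled:", overlap_all_reduce, overlap_all_gather)
    try:
        yield
    finally:
        print("communication overlap flushed")


def make_ctx(optimize: bool):
    # Bind all keyword arguments up front so the call site below is identical
    # whether or not the optimization is enabled, mirroring the patch.
    if optimize:
        return partial(
            fake_optimize_communication,
            overlap_all_reduce=True,
            overlap_all_gather=True,
        )
    return nullcontext


for optimize in (False, True):
    ctx = make_ctx(optimize)
    for step in range(2):   # stands in for gradient accumulation steps
        with ctx():         # same call shape in both configurations
            pass            # forward_step(...) / backward_step(...) go here
```

Binding the keyword arguments through `partial` keeps the inner loop free of AxoNN-specific branching: only the construction of `ctx` differs between the optimized and unoptimized paths.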