From be17f4bdba9728bb7b8003b582151b620502f51b Mon Sep 17 00:00:00 2001 From: Kye Gomez Date: Sun, 1 Sep 2024 17:07:39 -0400 Subject: [PATCH] [CLEANUP] --- README.md | 25 +++++---- audio_encoder.py | 2 +- docs/zeta/index.md | 4 +- docs/zeta/nn/attention/multiquery.md | 4 +- docs/zeta/nn/modules/averagemodelmerger.md | 16 +++--- docs/zeta/nn/modules/fused_gelu_dense.md | 4 +- docs/zeta/nn/modules/slerpmodelmerger.md | 4 +- docs/zeta/quant/niva.md | 8 +-- docs/zeta/structs/autoregressivewrapper.md | 4 +- docs/zeta/structs/encoder.md | 12 ++-- docs/zeta/utils/main.md | 6 +- docs/zeta/utils/save_load.md | 2 +- docs/zeta/utils/save_load_wrapper.md | 8 +-- multi_query_attention.py | 16 ++++++ pyproject.toml | 2 +- tests/nn/attentions/test_mhaa.py | 8 +-- tests/nn/attentions/test_mqa.py | 10 ++-- tests/nn/modules/test_alr_block.py | 6 +- tests/nn/modules/test_avg_model_merger.py | 12 ++-- tests/nn/modules/test_full_feedforward.py | 62 ++++++++++----------- tests/nn/modules/test_slerp_model_merger.py | 6 +- tests/rl/test_vision_reward_model.py | 12 ++-- zeta/nn/attention/dilated_attention.py | 16 +++--- zeta/nn/attention/multiquery_attention.py | 40 ++++++------- zeta/nn/embeddings/positional.py | 2 +- zeta/nn/modules/README.md | 2 +- zeta/nn/modules/avg_model_merger.py | 16 +++--- zeta/nn/modules/deepseek_moe.py | 8 +-- zeta/nn/modules/g_shard_moe.py | 8 +-- zeta/nn/modules/gill_mapper.py | 4 +- zeta/nn/modules/mixtape.py | 14 ++--- zeta/nn/modules/slerp_model_merger.py | 16 +++--- zeta/nn/modules/subln.py | 18 +++--- zeta/nn/modules/xmoe/moe_layer.py | 8 +-- zeta/ops/async_softmax.py | 8 +-- zeta/rl/__init__.py | 2 +- zeta/rl/vision_model_rl.py | 4 +- zeta/structs/auto_regressive_wrapper.py | 4 +- zeta/structs/clip_encoder.py | 4 +- zeta/training/train.py | 4 +- zeta/utils/main.py | 2 +- 41 files changed, 217 insertions(+), 196 deletions(-) create mode 100644 multi_query_attention.py diff --git a/README.md b/README.md index a3175a6f..a35874fc 100644 --- a/README.md +++ b/README.md @@ -34,21 +34,26 @@ $ pip3 install -U zetascale ## Starting Your Journey -Creating a model empowered with the aforementioned breakthrough research features is a breeze. Here's how to quickly materialize the renowned Flash Attention +Creating a model empowered with the aforementioned breakthrough research features is a breeze. Here's how to quickly materialize the renowned Multi Query Attention ```python import torch +from zeta import MultiQueryAttention -from zeta.nn import FlashAttention - -q = torch.randn(2, 4, 6, 8) -k = torch.randn(2, 4, 10, 8) -v = torch.randn(2, 4, 10, 8) +# Model +model = MultiQueryAttention( + dim=512, + heads=8, +) -attention = FlashAttention(causal=False, dropout=0.1, flash=True) -output = attention(q, k, v) +# Input +text = torch.randn(2, 4, 512) +# Output +output, _, _ = model(text) print(output.shape) +print(output) + ``` @@ -352,8 +357,8 @@ model = YourModelClass() # Quantize the model dynamically, specifying layers to quantize niva( model=model, - model_path="path_to_pretrained_model_weights.pt", - output_path="quantized_model.pt", + model_path="path_to_pretrainedim_weights.pt", + output_path="quantizedim.pt", quant_type="dynamic", quantize_layers=[nn.Linear, nn.Conv2d], dtype=torch.qint8, diff --git a/audio_encoder.py b/audio_encoder.py index 17b833a3..54351edf 100644 --- a/audio_encoder.py +++ b/audio_encoder.py @@ -66,7 +66,7 @@ def __init__( ) transformer_encoder_layer = TransformerEncoderLayer( - d_model=cnn_channels * 8, + dim=cnn_channels * 8, nhead=nhead, dim_feedforward=dim_feedforward, ) diff --git a/docs/zeta/index.md b/docs/zeta/index.md index fe01fa10..648704b3 100644 --- a/docs/zeta/index.md +++ b/docs/zeta/index.md @@ -327,8 +327,8 @@ model = YourModelClass() # Quantize the model dynamically, specifying layers to quantize niva( model=model, - model_path="path_to_pretrained_model_weights.pt", - output_path="quantized_model.pt", + model_path="path_to_pretrainedim_weights.pt", + output_path="quantizedim.pt", quant_type="dynamic", quantize_layers=[nn.Linear, nn.Conv2d], dtype=torch.qint8, diff --git a/docs/zeta/nn/attention/multiquery.md b/docs/zeta/nn/attention/multiquery.md index 88aabb46..b857b6fd 100644 --- a/docs/zeta/nn/attention/multiquery.md +++ b/docs/zeta/nn/attention/multiquery.md @@ -15,7 +15,7 @@ class MultiQueryAttention(nn.Module): ``` ### Parameters: -- `d_model` (int): Dimension of the model. +- `dim` (int): Dimension of the model. - `heads` (int): Number of parallel attention heads. - `attn_impl` (str, optional): Attention implementation type, can be either 'triton', 'flash', or 'torch'. Default is 'triton'. - `clip_qkv` (Optional[float]): Clipping value for query, key, and value. If specified, qkv is clamped within the range [-clip_qkv, clip_qkv]. @@ -68,7 +68,7 @@ import torch from zeta.nn import MultiQueryAttention # Initialize the attention module -attention_layer = MultiQueryAttention(d_model=512, heads=8, attn_impl="torch") +attention_layer = MultiQueryAttention(dim=512, heads=8, attn_impl="torch") # Random input tensor x = torch.rand(16, 10, 512) # Batch of 16, sequence length 10, embedding size 512 diff --git a/docs/zeta/nn/modules/averagemodelmerger.md b/docs/zeta/nn/modules/averagemodelmerger.md index c62454a6..139e6ccf 100644 --- a/docs/zeta/nn/modules/averagemodelmerger.md +++ b/docs/zeta/nn/modules/averagemodelmerger.md @@ -19,8 +19,8 @@ class AverageModelMerger: model2 = nn.Linear(in_features=10, out_features=10) model3 = nn.Linear(in_features=10, out_features=10) merge = AverageModelMerger([model1, model2, model3]) - merged_model = merge.merge_models() - print(merged_model) + mergedim = merge.merge_models() + print(mergedim) """ ``` @@ -80,10 +80,10 @@ model3 = nn.Linear(in_features=10, out_features=10) merger = AverageModelMerger([model1, model2, model3]) # Merge models -merged_model = merger.merge_models() +mergedim = merger.merge_models() # Print merged model -print(merged_model) +print(mergedim) ``` ### Example 2 @@ -101,10 +101,10 @@ model3 = nn.Conv2d(3, 6, 5) merger = AverageModelMerger([model1, model2, model3]) # Merge models -merged_model = merger.merge_models() +mergedim = merger.merge_models() # Print merged model -print(merged_model) +print(mergedim) ``` ### Example 3 @@ -122,10 +122,10 @@ model3 = nn.CrossEntropyLoss() merger = AverageModelMerger([model1, model2, model3]) # Merge models -merged_model = merger.merge_models() +mergedim = merger.merge_models() # Print merged model -print(merged_model) +print(mergedim) ``` All the examples above demonstrate the basic usage of this class. In cases where you have multiple trained models (e.g., resultant from a k-fold cross-validation or models trained on different datasets), you can use this class to merge or average their weights. The resultant model will carry averaged weights, giving a balanced representation of all the models. diff --git a/docs/zeta/nn/modules/fused_gelu_dense.md b/docs/zeta/nn/modules/fused_gelu_dense.md index a83c6457..fe048f3f 100644 --- a/docs/zeta/nn/modules/fused_gelu_dense.md +++ b/docs/zeta/nn/modules/fused_gelu_dense.md @@ -117,7 +117,7 @@ import torch from zeta.nn import FusedDenseGELUDense # Create an instance of FusedDenseGELUDense with quantization -quantized_model = FusedDenseGELUDense( +quantizedim = FusedDenseGELUDense( dim=512, dim_out=1024, has_fp16_weights=True, threshold=4.0 ) @@ -125,7 +125,7 @@ quantized_model = FusedDenseGELUDense( x = torch.randn(1, 512) # Forward pass with quantization -out = quantized_model(x) +out = quantizedim(x) ``` ## 7. Additional Information diff --git a/docs/zeta/nn/modules/slerpmodelmerger.md b/docs/zeta/nn/modules/slerpmodelmerger.md index e3041329..91fe7296 100644 --- a/docs/zeta/nn/modules/slerpmodelmerger.md +++ b/docs/zeta/nn/modules/slerpmodelmerger.md @@ -49,10 +49,10 @@ model1 = nn.Linear(10, 10) model2 = nn.Linear(10, 10) merger = SLERPModelMerger(model1, model2, 0.5) -merged_model = merger.merge() +mergedim = merger.merge() # This will output the merged state_dict -print(merged_model.state_dict()) +print(mergedim.state_dict()) ``` The prints statement will output the state_dict of the merged model. The state_dict is a Python dictionary that maps each layer to its corresponding parameters (tensors). diff --git a/docs/zeta/quant/niva.md b/docs/zeta/quant/niva.md index 58e967a3..5bf78414 100644 --- a/docs/zeta/quant/niva.md +++ b/docs/zeta/quant/niva.md @@ -70,8 +70,8 @@ model = YourModelClass() # Quantize the model dynamically, specifying layers to quantize niva( model=model, - model_path="path_to_pretrained_model_weights.pt", - output_path="quantized_model.pt", + model_path="path_to_pretrainedim_weights.pt", + output_path="quantizedim.pt", quant_type="dynamic", quantize_layers=[nn.Linear, nn.Conv2d], dtype=torch.qint8, @@ -93,8 +93,8 @@ model = YourModelClass() # Quantize the entire model statically niva( model=model, - model_path="path_to_pretrained_model_weights.pt", - output_path="quantized_model.pt", + model_path="path_to_pretrainedim_weights.pt", + output_path="quantizedim.pt", quant_type="static", dtype=torch.qint8, ) diff --git a/docs/zeta/structs/autoregressivewrapper.md b/docs/zeta/structs/autoregressivewrapper.md index a4d1cd9f..82cc2b34 100644 --- a/docs/zeta/structs/autoregressivewrapper.md +++ b/docs/zeta/structs/autoregressivewrapper.md @@ -76,7 +76,7 @@ This method is particularly useful for generating multiple forecasted sequence p The `evaluate_and_select_best_solution()` method evaluates the solutions based on a reward model and returns the best one. ```python -def evaluate_and_select_best_solution(self, solutions, reward_model) +def evaluate_and_select_best_solution(self, solutions, rewardim) ``` @@ -113,7 +113,7 @@ The third example shows generating multiple solutions and selecting the best one ```python solutions = net.generate_n_solutions(start_tokens, n=5, seqlen=10) best_solution = net.evaluate_and_select_best_solution( - solutions, reward_model=lambda x: -x.sum() + solutions, rewardim=lambda x: -x.sum() ) ``` In the example above, the reward model simply returns the negative sum of the sequence, and the solution with lowest sum is selected as the best solution. diff --git a/docs/zeta/structs/encoder.md b/docs/zeta/structs/encoder.md index dd30767b..736661e0 100644 --- a/docs/zeta/structs/encoder.md +++ b/docs/zeta/structs/encoder.md @@ -33,9 +33,9 @@ from zeta.structs import AttentionLayers class MyEncoder(AttentionLayers): - def __init__(self, d_model, nhead, num_layers): - super().__init__(d_model=d_model, nhead=nhead, num_layers=num_layers) - self.linear = nn.Linear(d_model, d_model) + def __init__(self, dim, nhead, num_layers): + super().__init__(dim=dim, nhead=nhead, num_layers=num_layers) + self.linear = nn.Linear(dim, dim) def forward(self, x): x = super().forward(x) @@ -47,16 +47,16 @@ We built a custom encoder by extending the AttentionLayers, added a linear layer Firstly, let's initialize the model: ```python -model = MyEncoder(d_model=512, nhead=8, num_layers=6) +model = MyEncoder(dim=512, nhead=8, num_layers=6) ``` -The model is initialized with the dimensions of model `d_model=512`, number of heads `nhead=8`, and the number of layers `num_layers=6`. +The model is initialized with the dimensions of model `dim=512`, number of heads `nhead=8`, and the number of layers `num_layers=6`. Now, let's define some dummy input data and pass it through the model: ```python import torch -x = torch.randn(10, 32, 512) # (sequence_length, batch_size, d_model) +x = torch.randn(10, 32, 512) # (sequence_length, batch_size, dim) output = model(x) # forward pass print(output.shape) # torch.Size([10, 32, 512]) ``` diff --git a/docs/zeta/utils/main.md b/docs/zeta/utils/main.md index 26502fc0..eed24618 100644 --- a/docs/zeta/utils/main.md +++ b/docs/zeta/utils/main.md @@ -599,7 +599,7 @@ output = resnet_block(x, time_emb=time_emb) print(output.shape) ``` -## Function: load_model(path) +## Function: loadim(path) Load a model from a file. ### Parameters: @@ -610,9 +610,9 @@ Load a model from a file. ### Example: ```python -from zeta.utils.main import load_model +from zeta.utils.main import loadim -model = load_model("model_checkpoint.pth") +model = loadim("model_checkpoint.pth") print(model) ``` diff --git a/docs/zeta/utils/save_load.md b/docs/zeta/utils/save_load.md index 4cabd585..07c303ac 100644 --- a/docs/zeta/utils/save_load.md +++ b/docs/zeta/utils/save_load.md @@ -82,7 +82,7 @@ model = MyModel(32, 10) model.save("model.pt") # Load your model -loaded_model = MyModel.load("model.pt") +loadedim = MyModel.load("model.pt") ``` ### Example 2: Using the `save_load` with non-default arguments diff --git a/docs/zeta/utils/save_load_wrapper.md b/docs/zeta/utils/save_load_wrapper.md index 0cc403c9..c5ef27c6 100644 --- a/docs/zeta/utils/save_load_wrapper.md +++ b/docs/zeta/utils/save_load_wrapper.md @@ -90,7 +90,7 @@ my_model = MyModel() my_model.save("my_model.pth") # Load the model checkpoint -loaded_model = MyModel.load("my_model.pth") +loadedim = MyModel.load("my_model.pth") ``` #### Custom Methods and Hooks @@ -171,13 +171,13 @@ class VersionedModel(Module): # Create an instance of VersionedModel -versioned_model = VersionedModel() +versionedim = VersionedModel() # Save the model checkpoint -versioned_model.save("versioned_model.pth") +versionedim.save("versionedim.pth") # Load the model checkpoint with version compatibility check -loaded_versioned_model = VersionedModel.load("versioned_model.pth") +loaded_versionedim = VersionedModel.load("versionedim.pth") ``` ## 5. Additional Information diff --git a/multi_query_attention.py b/multi_query_attention.py new file mode 100644 index 00000000..aad02933 --- /dev/null +++ b/multi_query_attention.py @@ -0,0 +1,16 @@ +import torch +from zeta import MultiQueryAttention + +# Model +model = MultiQueryAttention( + dim=512, + heads=8, +) + +# Input +text = torch.randn(2, 4, 512) + +# Output +output, _, _ = model(text) +print(output.shape) +print(output) diff --git a/pyproject.toml b/pyproject.toml index 3634f0ca..7b821f6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "zetascale" -version = "2.6.7" +version = "2.6.9" description = "Rapidly Build, Optimize, and Train SOTA AI Models" authors = ["Zeta Team "] license = "MIT" diff --git a/tests/nn/attentions/test_mhaa.py b/tests/nn/attentions/test_mhaa.py index 3cbad5f6..3457c8cb 100644 --- a/tests/nn/attentions/test_mhaa.py +++ b/tests/nn/attentions/test_mhaa.py @@ -117,7 +117,7 @@ def test_attention_distribution(self): ) def setUp(self): - self.d_model = 128 + self.dim = 128 self.num_heads = 4 self.dilation_rate = 2 self.segment_size = 32 @@ -129,10 +129,10 @@ def setUp(self): self.batch_size = 10 self.seq_len = 100 - self.x = torch.rand(self.batch_size, self.seq_len, self.d_model) + self.x = torch.rand(self.batch_size, self.seq_len, self.dim) self.sparse_dilated_attention = MultiheadAttention( - self.d_model, + self.dim, self.num_heads, self.dilation_rate, self.segment_size, @@ -145,7 +145,7 @@ def setUp(self): def test_forward_pass(self): output = self.sparse_dilated_attention(self.x) self.assertEqual( - output.size(), (self.batch_size, self.seq_len, self.d_model) + output.size(), (self.batch_size, self.seq_len, self.dim) ) def test_attention_outputs(self): diff --git a/tests/nn/attentions/test_mqa.py b/tests/nn/attentions/test_mqa.py index e652160d..5626573c 100644 --- a/tests/nn/attentions/test_mqa.py +++ b/tests/nn/attentions/test_mqa.py @@ -5,16 +5,16 @@ def test_multiqueryattention_initialization(): - model = MultiQueryAttention(d_model=512, heads=8) + model = MultiQueryAttention(dim=512, heads=8) assert isinstance(model, MultiQueryAttention) - assert model.d_model == 512 + assert model.dim == 512 assert model.heads == 8 assert model.head_dim == 64 assert model.softmax_scale == 1 / 8 def test_multiqueryattention_forward(): - model = MultiQueryAttention(d_model=512, heads=8) + model = MultiQueryAttention(dim=512, heads=8) x = torch.randn(1, 10, 512) output, attn_weights, past_key_value = model(x) assert output.shape == (1, 10, 512) @@ -24,14 +24,14 @@ def test_multiqueryattention_forward(): @pytest.mark.parametrize("x_len", [0]) def test_multiqueryattention_forward_edge_cases(x_len): - model = MultiQueryAttention(d_model=512, heads=8) + model = MultiQueryAttention(dim=512, heads=8) x = torch.randn(1, x_len, 512) with pytest.raises(Exception): model(x) def test_multiqueryattention_forward_invalid_dimensions(): - model = MultiQueryAttention(d_model=512, heads=8) + model = MultiQueryAttention(dim=512, heads=8) x = torch.randn(1, 10, 256) with pytest.raises(Exception): model(x) diff --git a/tests/nn/modules/test_alr_block.py b/tests/nn/modules/test_alr_block.py index a3b80922..bc25b373 100644 --- a/tests/nn/modules/test_alr_block.py +++ b/tests/nn/modules/test_alr_block.py @@ -17,7 +17,7 @@ def alrblock_model(): @pytest.fixture -def feedforward_model(): +def feedforwardim(): return FeedForward(512, 2048, 0.1) @@ -27,8 +27,8 @@ def test_feedforward_creation(): assert isinstance(model, nn.Module) -def test_feedforward_forward(sample_input, feedforward_model): - output = feedforward_model(sample_input) +def test_feedforward_forward(sample_input, feedforwardim): + output = feedforwardim(sample_input) assert output.shape == sample_input.shape diff --git a/tests/nn/modules/test_avg_model_merger.py b/tests/nn/modules/test_avg_model_merger.py index 1b511aa8..df0cd9f0 100644 --- a/tests/nn/modules/test_avg_model_merger.py +++ b/tests/nn/modules/test_avg_model_merger.py @@ -16,9 +16,9 @@ def test_average_model_merger_merge_models(): model1 = nn.Linear(10, 10) model2 = nn.Linear(10, 10) merger = AverageModelMerger([model1, model2]) - merged_model = merger.merge_models() - assert isinstance(merged_model, nn.Module) - assert merged_model.state_dict().keys() == model1.state_dict().keys() + mergedim = merger.merge_models() + assert isinstance(mergedim, nn.Module) + assert mergedim.state_dict().keys() == model1.state_dict().keys() def test_average_model_merger_copy_model_structure(): @@ -33,10 +33,10 @@ def test_average_model_merger_merge_models_weights(): model1 = nn.Linear(10, 10) model2 = nn.Linear(10, 10) merger = AverageModelMerger([model1, model2]) - merged_model = merger.merge_models() - for param_tensor in merged_model.state_dict(): + mergedim = merger.merge_models() + for param_tensor in mergedim.state_dict(): assert torch.allclose( - merged_model.state_dict()[param_tensor], + mergedim.state_dict()[param_tensor], ( model1.state_dict()[param_tensor] + model2.state_dict()[param_tensor] diff --git a/tests/nn/modules/test_full_feedforward.py b/tests/nn/modules/test_full_feedforward.py index 93fa076e..78d4bcec 100644 --- a/tests/nn/modules/test_full_feedforward.py +++ b/tests/nn/modules/test_full_feedforward.py @@ -5,78 +5,78 @@ @pytest.fixture -def feed_forward_model(): +def feed_forwardim(): return FeedForward(768, 2048, 0.1) -def test_feed_forward_forward(feed_forward_model): +def test_feed_forward_forward(feed_forwardim): x = torch.randn(1, 768) - output = feed_forward_model(x) + output = feed_forwardim(x) assert output.shape == (1, 2048) -def test_feed_forward_relu_squared(feed_forward_model): - feed_forward_model_relu_squared = FeedForward( +def test_feed_forward_relu_squared(feed_forwardim): + feed_forwardim_relu_squared = FeedForward( 768, 2048, 0.1, relu_squared=True ) x = torch.randn(1, 768) - output = feed_forward_model_relu_squared(x) + output = feed_forwardim_relu_squared(x) assert output.shape == (1, 2048) -def test_feed_forward_post_act_ln(feed_forward_model): - feed_forward_model_post_act_ln = FeedForward( +def test_feed_forward_post_act_ln(feed_forwardim): + feed_forwardim_post_act_ln = FeedForward( 768, 2048, 0.1, post_act_ln=True ) x = torch.randn(1, 768) - output = feed_forward_model_post_act_ln(x) + output = feed_forwardim_post_act_ln(x) assert output.shape == (1, 2048) -def test_feed_forward_dropout(feed_forward_model): - feed_forward_model_dropout = FeedForward(768, 2048, 0.5) +def test_feed_forward_dropout(feed_forwardim): + feed_forwardim_dropout = FeedForward(768, 2048, 0.5) x = torch.randn(1, 768) - output = feed_forward_model_dropout(x) + output = feed_forwardim_dropout(x) assert output.shape == (1, 2048) -def test_feed_forward_no_bias(feed_forward_model): - feed_forward_model_no_bias = FeedForward(768, 2048, 0.1, no_bias=True) +def test_feed_forward_no_bias(feed_forwardim): + feed_forwardim_no_bias = FeedForward(768, 2048, 0.1, no_bias=True) x = torch.randn(1, 768) - output = feed_forward_model_no_bias(x) + output = feed_forwardim_no_bias(x) assert output.shape == (1, 2048) -def test_feed_forward_zero_init_output(feed_forward_model): - feed_forward_model_zero_init_output = FeedForward( +def test_feed_forward_zero_init_output(feed_forwardim): + feed_forwardim_zero_init_output = FeedForward( 768, 2048, 0.1, zero_init_output=True ) x = torch.randn(1, 768) - output = feed_forward_model_zero_init_output(x) + output = feed_forwardim_zero_init_output(x) assert output.shape == (1, 2048) assert torch.allclose(output, torch.zeros_like(output)) -def test_feed_forward_glu(feed_forward_model): - feed_forward_model_glu = FeedForward(768, 2048, 0.1, glu=True) +def test_feed_forward_glu(feed_forwardim): + feed_forwardim_glu = FeedForward(768, 2048, 0.1, glu=True) x = torch.randn(1, 768) - output = feed_forward_model_glu(x) + output = feed_forwardim_glu(x) assert output.shape == (1, 2048) -def test_feed_forward_glu_mult_bias(feed_forward_model): - feed_forward_model_glu_mult_bias = FeedForward( +def test_feed_forward_glu_mult_bias(feed_forwardim): + feed_forwardim_glu_mult_bias = FeedForward( 768, 2048, 0.1, glu=True, glu_mult_bias=True ) x = torch.randn(1, 768) - output = feed_forward_model_glu_mult_bias(x) + output = feed_forwardim_glu_mult_bias(x) assert output.shape == (1, 2048) -def test_feed_forward_swish(feed_forward_model): - feed_forward_model_swish = FeedForward(768, 2048, 0.1, swish=True) +def test_feed_forward_swish(feed_forwardim): + feed_forwardim_swish = FeedForward(768, 2048, 0.1, swish=True) x = torch.randn(1, 768) - output = feed_forward_model_swish(x) + output = feed_forwardim_swish(x) assert output.shape == (1, 2048) @@ -146,16 +146,16 @@ def test_feed_forward_invalid_relu_squared_post_act_ln(): def test_feed_forward_dim_out_larger(): - feed_forward_model_dim_out_larger = FeedForward(768, 3072, 0.1) + feed_forwardim_dim_out_larger = FeedForward(768, 3072, 0.1) x = torch.randn(1, 768) - output = feed_forward_model_dim_out_larger(x) + output = feed_forwardim_dim_out_larger(x) assert output.shape == (1, 3072) def test_feed_forward_dim_out_smaller(): - feed_forward_model_dim_out_smaller = FeedForward(768, 512, 0.1) + feed_forwardim_dim_out_smaller = FeedForward(768, 512, 0.1) x = torch.randn(1, 768) - output = feed_forward_model_dim_out_smaller(x) + output = feed_forwardim_dim_out_smaller(x) assert output.shape == (1, 512) diff --git a/tests/nn/modules/test_slerp_model_merger.py b/tests/nn/modules/test_slerp_model_merger.py index 5a83dcab..196531a7 100644 --- a/tests/nn/modules/test_slerp_model_merger.py +++ b/tests/nn/modules/test_slerp_model_merger.py @@ -18,9 +18,9 @@ def test_slerp_model_merger_merge(): model1 = nn.Linear(10, 10) model2 = nn.Linear(10, 10) merger = SLERPModelMerger(model1, model2, 0.5) - merged_model = merger.merge() - assert isinstance(merged_model, nn.Module) - assert merged_model.state_dict().keys() == model1.state_dict().keys() + mergedim = merger.merge() + assert isinstance(mergedim, nn.Module) + assert mergedim.state_dict().keys() == model1.state_dict().keys() def test_slerp_model_merger_slerp(): diff --git a/tests/rl/test_vision_reward_model.py b/tests/rl/test_vision_reward_model.py index 59b45726..1a50cbef 100644 --- a/tests/rl/test_vision_reward_model.py +++ b/tests/rl/test_vision_reward_model.py @@ -24,7 +24,7 @@ def test_residual_block_strides(stride): # 3. VisionRewardModel shape tests @pytest.mark.parametrize("batch_size", [1, 8, 32]) -def test_vision_reward_model_shapes(batch_size): +def test_vision_rewardim_shapes(batch_size): model = VisionRewardModel() sample_image = torch.randn(batch_size, 3, 32, 32) predicted_rewards = model(sample_image) @@ -32,7 +32,7 @@ def test_vision_reward_model_shapes(batch_size): # 4. VisionRewardModel outputs type check -def test_vision_reward_model_output_type(): +def test_vision_rewardim_output_type(): model = VisionRewardModel() sample_image = torch.randn(8, 3, 32, 32) predicted_rewards = model(sample_image) @@ -48,7 +48,7 @@ def test_residual_block_no_nan(): # 6. Ensure no NaN values in VisionRewardModel outputs -def test_vision_reward_model_no_nan(): +def test_vision_rewardim_no_nan(): model = VisionRewardModel() sample_image = torch.randn(8, 3, 32, 32) predicted_rewards = model(sample_image) @@ -56,7 +56,7 @@ def test_vision_reward_model_no_nan(): # 7. Ensure non-zero outputs for VisionRewardModel -def test_vision_reward_model_non_zero(): +def test_vision_rewardim_non_zero(): model = VisionRewardModel() sample_image = torch.randn(8, 3, 32, 32) predicted_rewards = model(sample_image) @@ -85,7 +85,7 @@ def test_residual_block_zero_input(): # 10. Testing zero inputs result in non-zero outputs for VisionRewardModel -def test_vision_reward_model_zero_input(): +def test_vision_rewardim_zero_input(): model = VisionRewardModel() sample_image = torch.zeros(8, 3, 32, 32) predicted_rewards = model(sample_image) @@ -94,7 +94,7 @@ def test_vision_reward_model_zero_input(): # Additional Testing for various shapes (e.g., larger images) @pytest.mark.parametrize("image_size", [32, 64, 128]) -def test_vision_reward_model_with_different_image_sizes(image_size): +def test_vision_rewardim_with_different_image_sizes(image_size): model = VisionRewardModel() sample_image = torch.randn(8, 3, image_size, image_size) predicted_rewards = model(sample_image) diff --git a/zeta/nn/attention/dilated_attention.py b/zeta/nn/attention/dilated_attention.py index 6ee2a7c2..92620e92 100644 --- a/zeta/nn/attention/dilated_attention.py +++ b/zeta/nn/attention/dilated_attention.py @@ -53,7 +53,7 @@ class DilatedAttention(BaseAttention): Dilated Attention Module. Arguments: - d_model: The dimension of the attention layers. + dim: The dimension of the attention layers. num_heads: The number of attention heads. dilation_rate: The dilation rate for dilated attention. segment_size: The segment size for dilated attention. @@ -66,7 +66,7 @@ class DilatedAttention(BaseAttention): The `DilatedAttention` class can be used as a module for neural networks and is especially suited for transformer architectures. Example: - attention = DilatedAttention(d_model=512, num_heads=8, dilation_rate=2, segment_size=64, use_xpos=True, use_rel_pos_bias=True) + attention = DilatedAttention(dim=512, num_heads=8, dilation_rate=2, segment_size=64, use_xpos=True, use_rel_pos_bias=True) output = attention(input_tensor) This will return the output tensor after applying dilated attention. The `use_xpos` and `use_rel_pos_bias` parameters allow for switching on positional encoding and relative positional bias respectively. @@ -74,7 +74,7 @@ class DilatedAttention(BaseAttention): def __init__( self, - d_model: int = None, + dim: int = None, num_heads: int = None, dilation_rate: int = None, segment_size: int = None, @@ -84,7 +84,7 @@ def __init__( use_rel_pos_bias: bool = False, ): super().__init__() - self.d_model = d_model + self.dim = dim self.num_heads = num_heads self.dilation_rate = dilation_rate @@ -101,14 +101,14 @@ def __init__( ) if use_xpos: - self.xpos = XPOS(head_dim=d_model // num_heads) + self.xpos = XPOS(head_dim=dim // num_heads) if use_rel_pos_bias: self.relative_bias = RelativePositionBias( num_buckets=32, max_distance=128, n_heads=num_heads ) # head offsets - self.head_offsets = nn.Parameter(torch.randn(num_heads, d_model)) + self.head_offsets = nn.Parameter(torch.randn(num_heads, dim)) def get_mask(self, i, j): return torch.ones((i, j), device=device, dtype=torch.bool).triu( @@ -128,7 +128,7 @@ def forward(self, x): x = self.xpos(x) # Split and sparsify - x = x.view(batch_size, -1, self.segment_size, self.d_model) + x = x.view(batch_size, -1, self.segment_size, self.dim) print(f"z after view shape: {x.shape}") x = x[:, :, :: self.dilation_rate, :] @@ -169,7 +169,7 @@ def forward(self, x): ) # Scatter and concatenate - attn_output = attn_output.reshape(batch_size, -1, self.d_model) + attn_output = attn_output.reshape(batch_size, -1, self.dim) print( f"attn_output scatter and concatenate: {attn_output.shape} and" f" {attn_output.dtype}" diff --git a/zeta/nn/attention/multiquery_attention.py b/zeta/nn/attention/multiquery_attention.py index 6fae16fa..914aec94 100644 --- a/zeta/nn/attention/multiquery_attention.py +++ b/zeta/nn/attention/multiquery_attention.py @@ -558,7 +558,7 @@ class MultiHeadAttention(nn.Module): def __init__( self, - d_model: int, + dim: int, heads: int, attn_impl: str = "triton", clip_qkv: Optional[float] = None, @@ -576,29 +576,29 @@ def __init__( self.clip_qkv = clip_qkv self.qk_ln = qk_ln - self.d_model = d_model + self.dim = dim self.heads = heads self.softmax_scale = softmax_scale if self.softmax_scale is None: - self.softmax_scale = 1 / math.sqrt(self.d_model / self.heads) + self.softmax_scale = 1 / math.sqrt(self.dim / self.heads) self.attn_dropout = attn_pdrop fc_kwargs = {} if fc_type != "te": fc_kwargs["device"] = device self.Wqkv = FC_CLASS_REGISTRY[fc_type]( - self.d_model, - 3 * self.d_model, + self.dim, + 3 * self.dim, **fc_kwargs, ) # for param init fn; enables shape based init of fused layers - fuse_splits = (d_model, 2 * d_model) + fuse_splits = (dim, 2 * dim) self.Wqkv._fused = (0, fuse_splits) # type: ignore if self.qk_ln: norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] - self.q_ln = norm_class(self.d_model, device=device) - self.k_ln = norm_class(self.d_model, device=device) + self.q_ln = norm_class(self.dim, device=device) + self.k_ln = norm_class(self.dim, device=device) if self.attn_impl == "flash": self.attn_fn = flash_attn_fn @@ -629,8 +629,8 @@ def __init__( raise ValueError(f"{attn_impl=} is an invalid setting.") self.out_proj = FC_CLASS_REGISTRY[fc_type]( - self.d_model, - self.d_model, + self.dim, + self.dim, **fc_kwargs, ) self.out_proj._is_residual = True # type: ignore @@ -688,7 +688,7 @@ class MultiQueryAttention(BaseAttention): def __init__( self, - d_model: int, + dim: int, heads: int, attn_impl: str = "torch", clip_qkv: Optional[float] = None, @@ -706,9 +706,9 @@ def __init__( self.clip_qkv = clip_qkv self.qk_ln = qk_ln - self.d_model = d_model + self.dim = dim self.heads = heads - self.head_dim = d_model // heads + self.head_dim = dim // heads self.softmax_scale = softmax_scale if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.head_dim) @@ -719,17 +719,17 @@ def __init__( fc_kwargs["device"] = device # - vchiley self.Wqkv = FC_CLASS_REGISTRY[fc_type]( - d_model, - d_model + 2 * self.head_dim, + dim, + dim + 2 * self.head_dim, **fc_kwargs, ) # for param init fn; enables shape based init of fused layers - fuse_splits = (d_model, d_model + self.head_dim) + fuse_splits = (dim, dim + self.head_dim) self.Wqkv._fused = (0, fuse_splits) # type: ignore if self.qk_ln: norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] - self.q_ln = norm_class(d_model, device=device) + self.q_ln = norm_class(dim, device=device) self.k_ln = norm_class(self.head_dim, device=device) if self.attn_impl == "flash": @@ -761,8 +761,8 @@ def __init__( raise ValueError(f"{attn_impl=} is an invalid setting.") self.out_proj = FC_CLASS_REGISTRY[fc_type]( - self.d_model, - self.d_model, + self.dim, + self.dim, **fc_kwargs, ) self.out_proj._is_residual = True # type: ignore @@ -782,7 +782,7 @@ def forward( qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv) query, key, value = qkv.split( - [self.d_model, self.head_dim, self.head_dim], dim=2 + [self.dim, self.head_dim, self.head_dim], dim=2 ) key_padding_mask = mask diff --git a/zeta/nn/embeddings/positional.py b/zeta/nn/embeddings/positional.py index e94c2bb4..b4d96df2 100644 --- a/zeta/nn/embeddings/positional.py +++ b/zeta/nn/embeddings/positional.py @@ -9,7 +9,7 @@ class PositionalEmbedding(nn.Embedding): Args: - d_model (int): Dimension of the model. + dim (int): Dimension of the model. max_len (int): Maximum length of the input sequence. padding_idx (int, optional): Index of the padding token. Defaults to 0. scale_grad_by_freq (bool, optional): If True, scale gradients by frequency. Defaults to False. diff --git a/zeta/nn/modules/README.md b/zeta/nn/modules/README.md index c4fb82c3..628f1535 100644 --- a/zeta/nn/modules/README.md +++ b/zeta/nn/modules/README.md @@ -322,6 +322,6 @@ model_path = f"runs:/{run.info.run_id}/model" version = register_model(model_uri=model_path, name=model_name) # Load the registered model for inference -loaded_model = Model.load(model_uri=f"models:/{model_name}/{version}") +loadedim = Model.load(model_uri=f"models:/{model_name}/{version}") By incorporating these additional shapeless and fluid containers, AI engineering can be made more seamless, efficient, and modular, ultimately leading to improved development and deployment of AI models and algorithms. message.txt diff --git a/zeta/nn/modules/avg_model_merger.py b/zeta/nn/modules/avg_model_merger.py index d3ee7cfb..8c72a84b 100644 --- a/zeta/nn/modules/avg_model_merger.py +++ b/zeta/nn/modules/avg_model_merger.py @@ -20,8 +20,8 @@ class AverageModelMerger: model2 = nn.Linear(in_features=10, out_features=10) model3 = nn.Linear(in_features=10, out_features=10) merge = AverageModelMerger([model1, model2, model3]) - merged_model = merge.merge_models() - print(merged_model) + mergedim = merge.merge_models() + print(mergedim) """ def __init__(self, models: List[nn.Module]): @@ -46,10 +46,10 @@ def merge_models(self) -> nn.Module: """ assert len(self.models) > 0, "models list must not be empty" - merged_model = self._copy_model_structure(self.models[0]) + mergedim = self._copy_model_structure(self.models[0]) # Initialize a state_dict for the merged model - merged_state_dict = merged_model.state_dict() + merged_state_dict = mergedim.state_dict() # Iterate over each parameter in the model's state_dict for key in merged_state_dict.keys(): @@ -59,8 +59,8 @@ def merge_models(self) -> nn.Module: ) / len(self.models) # Load the averaged state_dict into the merged model - merged_model.load_state_dict(merged_state_dict) - return merged_model + mergedim.load_state_dict(merged_state_dict) + return mergedim @staticmethod def _copy_model_structure(model: nn.Module) -> nn.Module: @@ -86,5 +86,5 @@ def _copy_model_structure(model: nn.Module) -> nn.Module: # model2 = nn.Linear(in_features=10, out_features=10) # model3 = nn.Linear(in_features=10, out_features=10) # merge = AverageModelMerger([model1, model2, model3]) -# merged_model = merge.merge_models() -# print(merged_model) +# mergedim = merge.merge_models() +# print(mergedim) diff --git a/zeta/nn/modules/deepseek_moe.py b/zeta/nn/modules/deepseek_moe.py index 0c5f3fb8..6d5bedc8 100644 --- a/zeta/nn/modules/deepseek_moe.py +++ b/zeta/nn/modules/deepseek_moe.py @@ -41,8 +41,8 @@ def __init__( self.gate = nn.Linear(dim, num_experts) def forward(self, x: Tensor): - batch_size, seq_len, d_model = x.shape - x_flat = x.view(-1, d_model) # Flatten for gating + batch_size, seq_len, dim = x.shape + x_flat = x.view(-1, dim) # Flatten for gating # Apply gating mechanism and ensure indices are within the valid range gate_scores = F.softmax(self.gate(x_flat), dim=-1) @@ -71,13 +71,13 @@ def forward(self, x: Tensor): # Example usage -d_model = 512 +dim = 512 num_experts = 16 d_ff = 2048 top_k = 2 num_shared_experts = 2 -moe_model = DeepSeekMoE(d_model, num_experts, d_ff, top_k, num_shared_experts) +moe_model = DeepSeekMoE(dim, num_experts, d_ff, top_k, num_shared_experts) input_tensor = torch.randn( 10, 15, 512 ) # Batch size of 10, sequence length 15, feature size of 512 diff --git a/zeta/nn/modules/g_shard_moe.py b/zeta/nn/modules/g_shard_moe.py index d26aecfb..818117f6 100644 --- a/zeta/nn/modules/g_shard_moe.py +++ b/zeta/nn/modules/g_shard_moe.py @@ -719,7 +719,7 @@ def forward( # assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts" # Implement Algorithm 2 from GShard paper. - d_model = input.shape[2] + dim = input.shape[2] # Pad to expected batch size input_shape = list(input.shape) expected_bsz = ( @@ -771,7 +771,7 @@ def forward( input_padding_mask = padded_input_padding_mask # Reshape into S tokens by dropping sequence dimension. - reshaped_input = input.reshape(-1, d_model) + reshaped_input = input.reshape(-1, dim) reshaped_input_shape = reshaped_input.shape reshaped_input_padding_mask = ( input_padding_mask.reshape(-1) @@ -851,7 +851,7 @@ def forward( # Re-shape after all-to-all: ecm -> gecm dispatched_input = dispatched_input.reshape( - self.all2all_size, self.num_local_experts, -1, d_model + self.all2all_size, self.num_local_experts, -1, dim ) chunks = dispatched_input.chunk(self.num_local_experts, dim=1) expert_outputs = [] @@ -864,7 +864,7 @@ def forward( # Re-shape back: gecm -> ecm expert_output = expert_output.reshape( - self.all2all_size * self.num_local_experts, -1, d_model + self.all2all_size * self.num_local_experts, -1, dim ) if has_tutel: diff --git a/zeta/nn/modules/gill_mapper.py b/zeta/nn/modules/gill_mapper.py index 01e8bc09..99eeee34 100644 --- a/zeta/nn/modules/gill_mapper.py +++ b/zeta/nn/modules/gill_mapper.py @@ -55,7 +55,7 @@ class GILLMapper(nn.Module): def __post_init__(self): super().__init__() self.transformer = nn.Transformer( - d_model=self.text_emb_size, + dim=self.text_emb_size, num_encoder_layers=self.num_encoder_layers, num_decoder_layers=self.num_decoder_layers, dim_feedforward=self.dim_ffn, @@ -74,7 +74,7 @@ def __post_init__(self): self.transformer_encoder = nn.TransformerEncoder( nn.TransformerEncoderLayer( - d_model=self.text_emb_size, + dim=self.text_emb_size, nhead=self.heads, dim_feedforward=self.dim_ffn, ), diff --git a/zeta/nn/modules/mixtape.py b/zeta/nn/modules/mixtape.py index 06362235..e4cb6ccc 100644 --- a/zeta/nn/modules/mixtape.py +++ b/zeta/nn/modules/mixtape.py @@ -4,10 +4,10 @@ class Mixtape(nn.Module): - def __init__(self, vocab_size, d_model, d1, d2, num_gates=4): + def __init__(self, vocab_size, dim, d1, d2, num_gates=4): super(Mixtape, self).__init__() self.vocab_size = vocab_size - self.d_model = d_model + self.dim = dim self.d1 = d1 self.d2 = d2 self.num_gates = num_gates @@ -20,12 +20,12 @@ def __init__(self, vocab_size, d_model, d1, d2, num_gates=4): # Parameters for context embeddings self.H = nn.Parameter( - torch.randn(self.num_gates, self.d_model, self.d1) + torch.randn(self.num_gates, self.dim, self.d1) ) # Token embeddings (not specified in the abstract, assuming needed) self.token_embeddings = nn.Parameter( - torch.randn(self.vocab_size, self.d_model) + torch.randn(self.vocab_size, self.dim) ) def forward(self, gc): @@ -35,7 +35,7 @@ def forward(self, gc): # Expanded gc to [batch_size, seq_length, 1, d1] for broadcasting hc = torch.tanh( torch.einsum("kij,btj->btki", self.H, gc) - ) # (batch_size, seq_length, num_gates, d_model) + ) # (batch_size, seq_length, num_gates, dim) # Compute pre-activation gate priors for each token and gate # Expanded gc for broadcasting with different parameters @@ -80,13 +80,13 @@ def forward(self, gc): # Example usage -d_model = 512 +dim = 512 d1 = 256 d2 = 128 vocab_size = 10000 seq_length = 20 -model = Mixtape(vocab_size=vocab_size, d_model=d_model, d1=d1, d2=d2) +model = Mixtape(vocab_size=vocab_size, dim=dim, d1=d1, d2=d2) gc = torch.randn( 10, seq_length, d1 ) # Simulated last-layer hidden states for a batch of 10 with sequence length 20 diff --git a/zeta/nn/modules/slerp_model_merger.py b/zeta/nn/modules/slerp_model_merger.py index 34b64089..1ece00c2 100644 --- a/zeta/nn/modules/slerp_model_merger.py +++ b/zeta/nn/modules/slerp_model_merger.py @@ -25,8 +25,8 @@ class SLERPModelMerger(nn.Module): model4 = nn.Linear(10, 10) merge = SLERPModelMerger(model1, model2, 0.5) - merged_model = merge.merge() - print(merged_model.state_dict()) + mergedim = merge.merge() + print(mergedim.state_dict()) """ @enforce_types @@ -48,14 +48,14 @@ def merge(self) -> nn.Module: Returns: nn.Module: A new model with merged weights. """ - merged_model = self._copy_model_structure(self.model1) + mergedim = self._copy_model_structure(self.model1) # Get the state dicts of both models state_dict1 = self.model1.state_dict() state_dict2 = self.model2.state_dict() # Init a state dict for the merged model - merged_state_dict = merged_model.state_dict() + merged_state_dict = mergedim.state_dict() for key in merged_state_dict.keys(): # Perform WELP for each parameter @@ -64,8 +64,8 @@ def merge(self) -> nn.Module: merged_state_dict[key] = self._slerp(w1, w2, self.t) # Load the mergd state dict into the new model - merged_model.load_state_dict(merged_state_dict) - return merged_model + mergedim.load_state_dict(merged_state_dict) + return mergedim @staticmethod @enforce_types @@ -119,5 +119,5 @@ def _copy_model_structure(model: nn.Module) -> nn.Module: # model4 = nn.Linear(10, 10) # merge = SLERPModelMerger(model1, model2, 0.5) -# merged_model = merge.merge() -# print(merged_model.state_dict()) +# mergedim = merge.merge() +# print(mergedim.state_dict()) diff --git a/zeta/nn/modules/subln.py b/zeta/nn/modules/subln.py index 95004db0..a8ab7d9b 100644 --- a/zeta/nn/modules/subln.py +++ b/zeta/nn/modules/subln.py @@ -10,7 +10,7 @@ class SubLN(nn.Module): Parameters: ----------- - d_model: int + dim: int The number of expected features in the input x γ: float, optional Gain value for weight initialization. Default is 1.0. @@ -22,21 +22,21 @@ class SubLN(nn.Module): import torch from zeta.nn.modules import SubLN - model = SubLN(d_model=512) + model = SubLN(dim=512) x = torch.randn(10, 512) out = model(x) print(out) """ - def __init__(self, d_model, γ=1.0): + def __init__(self, dim, γ=1.0): super().__init__() # Define necessary layers and operations - self.LN1 = nn.LayerNorm(d_model) - self.fin = nn.Linear(d_model, d_model) # Example layer for fin - self.fout = nn.Linear(d_model, d_model) # Example layer for fout - self.LN2 = nn.LayerNorm(d_model) + self.LN1 = nn.LayerNorm(dim) + self.fin = nn.Linear(dim, dim) # Example layer for fin + self.fout = nn.Linear(dim, dim) # Example layer for fout + self.LN2 = nn.LayerNorm(dim) # Weight initialization self._initialize_weights(γ) @@ -48,12 +48,12 @@ def forward(self, x): Parameters: ----------- x : torch.Tensor - Input tensor of shape [batch_size, d_model] + Input tensor of shape [batch_size, dim] Returns: -------- torch.Tensor - Output tensor of shape [batch_size, d_model] + Output tensor of shape [batch_size, dim] """ return x + self.fout(self.LN2(self.fin(self.LN1(x)))) diff --git a/zeta/nn/modules/xmoe/moe_layer.py b/zeta/nn/modules/xmoe/moe_layer.py index 67f70cfb..fa7e4255 100644 --- a/zeta/nn/modules/xmoe/moe_layer.py +++ b/zeta/nn/modules/xmoe/moe_layer.py @@ -126,7 +126,7 @@ def forward( # assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts" # Implement Algorithm 2 from GShard paper. - d_model = input.shape[2] + dim = input.shape[2] # Pad to expected batch size input_shape = list(input.shape) expected_bsz = ( @@ -178,7 +178,7 @@ def forward( input_padding_mask = padded_input_padding_mask # Reshape into S tokens by dropping sequence dimension. - reshaped_input = input.reshape(-1, d_model) + reshaped_input = input.reshape(-1, dim) reshaped_input_shape = reshaped_input.shape reshaped_input_padding_mask = ( input_padding_mask.reshape(-1) @@ -259,7 +259,7 @@ def forward( # Re-shape after all-to-all: ecm -> gecm dispatched_input = dispatched_input.reshape( - self.all2all_size, self.num_local_experts, -1, d_model + self.all2all_size, self.num_local_experts, -1, dim ) chunks = dispatched_input.chunk(self.num_local_experts, dim=1) expert_outputs = [] @@ -272,7 +272,7 @@ def forward( # Re-shape back: gecm -> ecm expert_output = expert_output.reshape( - self.all2all_size * self.num_local_experts, -1, d_model + self.all2all_size * self.num_local_experts, -1, dim ) if has_tutel: diff --git a/zeta/ops/async_softmax.py b/zeta/ops/async_softmax.py index a79f625e..be1ad4c1 100644 --- a/zeta/ops/async_softmax.py +++ b/zeta/ops/async_softmax.py @@ -50,15 +50,15 @@ def asynchronized_softmax(Q, K, V, unified_max_value): # Define the main class for the attention mechanism class AsynchronizedAttention(nn.Module): - def __init__(self, d_model, n_heads, unified_max_value): + def __init__(self, dim, n_heads, unified_max_value): super().__init__() - self.d_model = d_model + self.dim = dim self.n_heads = n_heads self.unified_max_value = unified_max_value - self.head_dim = d_model // n_heads + self.head_dim = dim // n_heads # Linear layers for Q, K, V projections - self.qkv_proj = nn.Linear(d_model, d_model * 3) + self.qkv_proj = nn.Linear(dim, dim * 3) def forward(self, x): batch_size, seq_length, _ = x.size() diff --git a/zeta/rl/__init__.py b/zeta/rl/__init__.py index 08d32d9e..a6877adc 100644 --- a/zeta/rl/__init__.py +++ b/zeta/rl/__init__.py @@ -7,7 +7,7 @@ ) from zeta.rl.hindsight_replay import HindsightExperienceReplay from zeta.rl.language_reward import LanguageReward -from zeta.rl.reward_model import RewardModel +from zeta.rl.rewardim import RewardModel __all__ = [ "RewardModel", diff --git a/zeta/rl/vision_model_rl.py b/zeta/rl/vision_model_rl.py index f2b64956..6224d89e 100644 --- a/zeta/rl/vision_model_rl.py +++ b/zeta/rl/vision_model_rl.py @@ -93,8 +93,8 @@ def forward(self, x): # output_tensor = res_block(sample_tensor) # # 2. Example for VisionRewardModel -# vision_reward_model = VisionRewardModel() +# vision_rewardim = VisionRewardModel() # sample_image = torch.randn(8, 3, 32, 32) -# predicted_rewards = vision_reward_model(sample_image) +# predicted_rewards = vision_rewardim(sample_image) # print(output_tensor.shape, predicted_rewards.shape) diff --git a/zeta/structs/auto_regressive_wrapper.py b/zeta/structs/auto_regressive_wrapper.py index 3c3da954..a257ac20 100644 --- a/zeta/structs/auto_regressive_wrapper.py +++ b/zeta/structs/auto_regressive_wrapper.py @@ -343,10 +343,10 @@ def generate_n_solutions(self, start_tokens, n, seqlen, **kwargs): def evaluate_and_select_best_solution( self, solutions, - reward_model, + rewardim, ): """Evaluate solutions and select the best one.""" - scores = [reward_model(solution) for solution in solutions] + scores = [rewardim(solution) for solution in solutions] best_solution_idx = scores.index(max(scores)) return solutions[best_solution_idx] diff --git a/zeta/structs/clip_encoder.py b/zeta/structs/clip_encoder.py index 41760a3a..fec7d5c2 100644 --- a/zeta/structs/clip_encoder.py +++ b/zeta/structs/clip_encoder.py @@ -16,13 +16,13 @@ def __init__(self, vision_tower, args, delay_load=False): self.select_feature = getattr(args, "mm_vision_select_feature", "patch") if not delay_load: - self.load_model() + self.loadim() else: self.cfg_only = CLIPVisionConfig.from_pretrained( self.vision_tower_name ) - def load_model(self): + def loadim(self): self.image_processor = CLIPImageProcessor.from_pretrained( self.vision_tower_name ) diff --git a/zeta/training/train.py b/zeta/training/train.py index ec8c86c7..5f2f9679 100644 --- a/zeta/training/train.py +++ b/zeta/training/train.py @@ -258,10 +258,10 @@ def Trainer( accelerator.print(f"Saving model to {output_dir}") if output_dir is not None: accelerator.wait_for_everyone() - unwrapped_model = accelerator.unwrap_model(model) + unwrappedim = accelerator.unwrap_model(model) with accelerator.main_process_first(): accelerator.save( - unwrapped_model.state_dict(), + unwrappedim.state_dict(), f"{output_dir}/final/final_model.pt", ) diff --git a/zeta/utils/main.py b/zeta/utils/main.py index 9b5bc791..0addd221 100644 --- a/zeta/utils/main.py +++ b/zeta/utils/main.py @@ -434,7 +434,7 @@ def forward(self, x, time_emb=None): return h + self.res_conv(x) -def load_model(path): +def loadim(path): with open(path, "rb") as f: return torch.load( f, map_location=torch.device("cpu"), weights_only=True