Skip to content

Commit

Permalink
[FEAT][Mixture of Mambas]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 20, 2023
1 parent 5e1132a commit ec26627
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 3 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "swarms-torch"
version = "0.1.1"
version = "0.1.2"
description = "swarms-torch - Pytorch"
license = "MIT"
authors = ["Kye Gomez <[email protected]>"]
Expand All @@ -31,6 +31,10 @@ packages = [
python = "^3.6"
torch = "*"
einops = "*"
zetascale = "*"
mamba-ssm = "*"
causal-conv1d = "1.1.0"




Expand Down
6 changes: 4 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
torch
einops
pandas

zetascale
mamba-ssm
causal-conv1d



Expand All @@ -11,4 +13,4 @@ pandas

mkdocs
mkdocs-material
mkdocs-glightbox
mkdocs-glightbox
82 changes: 82 additions & 0 deletions swarms_torch/mixture_of_mamba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
from torch import nn

try:
from mamba_ssm import Mamba
except ImportError:
print("Mamba not installed")


class MixtureOfMambas(nn.Module):
def __init__(
self,
num_mambas: int,
dim: int,
d_state: int,
d_conv,
expand: int,
aggregation_method: str = "average",
):
super(MixtureOfMambas, self).__init__()
self.num_mambas = num_mambas
self.dim = dim
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.aggregation_method = aggregation_method

self.models = nn.ModuleList()
for _ in range(num_mambas):
mamba_model = Mamba(dim, d_state, d_conv, d_conv, expand)
self.models.append(mamba_model)

def forward(self, x: torch.Tensor, weights=None):
"""Forward pass of the swarm
Args:
x (torch.Tensor): _description_
weights (_type_, optional): _description_. Defaults to None.
Raises:
ValueError: _description_
Returns:
_type_: _description_
"""
outputs = [model(x) for model in self.models]

if self.aggregation_method == "average":
return torch.mean(torch.stack(outputs), dim=0)
elif self.aggregation_method == "weighted":
return self.weighted_aggregate(outputs, weights)
else:
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")

def average_aggregate(self, outputs):
"""Average the outputs of the models in the swarm
Args:
outputs (_type_): _description_
Returns:
_type_: _description_
"""
return torch.mean(torch.stack(outputs), dim=0)

def weighted_aggegrate(self, outputs, weights):
"""Weighted average the outputs of the models in the swarm
Args:
outputs (_type_): _description_
weights (_type_): _description_
Raises:
ValueError: _description_
Returns:
_type_: _description_
"""
if weights is None or len(weights) != len(outputs):
raise ValueError("Weights must be the same length as outputs")
weighted_outputs = [weight * output for weight, output in zip(weights, outputs)]
return sum(weighted_outputs)
84 changes: 84 additions & 0 deletions tests/test_mixture_of_mamba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import pytest
import torch
from swarms_torch.mixture_of_mamba import MixtureOfMambas


@pytest.fixture
def mixture():
num_mambas = 5
dim = 10
d_state = 20
d_conv = 30
expand = 40
return MixtureOfMambas(num_mambas, dim, d_state, d_conv, expand)


def test_init(mixture):
assert mixture.num_mambas == 5
assert mixture.dim == 10
assert mixture.d_state == 20
assert mixture.d_conv == 30
assert mixture.expand == 40
assert len(mixture.models) == 5


def test_forward_average(mixture):
x = torch.rand((1, 10))
output = mixture.forward(x)
assert output.shape == (1, 10)


def test_forward_weighted(mixture):
x = torch.rand((1, 10))
weights = torch.ones(5)
mixture.aggregation_method = "weighted"
output = mixture.forward(x, weights)
assert output.shape == (1, 10)


def test_forward_invalid_aggregation(mixture):
x = torch.rand((1, 10))
mixture.aggregation_method = "invalid"
with pytest.raises(ValueError):
mixture.forward(x)


def test_average_aggregate(mixture):
outputs = [torch.rand((1, 10)) for _ in range(5)]
output = mixture.average_aggregate(outputs)
assert output.shape == (1, 10)


def test_weighted_aggregate(mixture):
outputs = [torch.rand((1, 10)) for _ in range(5)]
weights = torch.ones(5)
output = mixture.weighted_aggregate(outputs, weights)
assert output.shape == (1, 10)


def test_weighted_aggregate_invalid_weights(mixture):
outputs = [torch.rand((1, 10)) for _ in range(5)]
weights = torch.ones(4)
with pytest.raises(ValueError):
mixture.weighted_aggregate(outputs, weights)


def test_forward_different_dimensions(mixture):
x = torch.rand((2, 10))
with pytest.raises(ValueError):
mixture.forward(x)


def test_forward_no_weights(mixture):
x = torch.rand((1, 10))
mixture.aggregation_method = "weighted"
with pytest.raises(ValueError):
mixture.forward(x)


def test_forward_extra_weights(mixture):
x = torch.rand((1, 10))
weights = torch.ones(6)
mixture.aggregation_method = "weighted"
with pytest.raises(ValueError):
mixture.forward(x, weights)

0 comments on commit ec26627

Please sign in to comment.