Skip to content

Commit

Permalink
[MiT and SegFormer] Refactor Backbone Arg Names (#1958)
Browse files Browse the repository at this point in the history
* refactor mit

* update presets and tools

* update presets

* fix merging master into feature branch

* fix tests after merging master into feature branch

---------

Co-authored-by: Divyashree Sreepathihalli <[email protected]>
  • Loading branch information
DavidLandup0 and divyashreepathihalli authored Nov 2, 2024
1 parent 550d04f commit 4a86c89
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 149 deletions.
55 changes: 29 additions & 26 deletions keras_hub/src/models/mit/mit_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
class MiTBackbone(FeaturePyramidBackbone):
def __init__(
self,
depths,
layerwise_depths,
num_layers,
blockwise_num_heads,
blockwise_sr_ratios,
layerwise_num_heads,
layerwise_sr_ratios,
max_drop_path_rate,
patch_sizes,
strides,
layerwise_patch_sizes,
layerwise_strides,
image_shape=(None, None, 3),
hidden_dims=None,
**kwargs,
Expand All @@ -43,12 +43,12 @@ def __init__(
https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)
Args:
depths: The number of transformer encoders to be used per layer in the
layerwise_depths: The number of transformer encoders to be used per layer in the
network.
num_layers: int. The number of Transformer layers.
blockwise_num_heads: list of integers, the number of heads to use
layerwise_num_heads: list of integers, the number of heads to use
in the attention computation for each layer.
blockwise_sr_ratios: list of integers, the sequence reduction
layerwise_sr_ratios: list of integers, the sequence reduction
ratio to perform for each layer on the sequence before key and
value projections. If set to > 1, a `Conv2D` layer is used to
reduce the length of the sequence.
Expand Down Expand Up @@ -82,7 +82,10 @@ def __init__(
model.fit(images, labels, epochs=3)
```
"""
dpr = [x for x in np.linspace(0.0, max_drop_path_rate, sum(depths))]
dpr = [
x
for x in np.linspace(0.0, max_drop_path_rate, sum(layerwise_depths))
]

# === Layers ===
cur = 0
Expand All @@ -93,24 +96,24 @@ def __init__(
for i in range(num_layers):
patch_embed_layer = OverlappingPatchingAndEmbedding(
project_dim=hidden_dims[i],
patch_size=patch_sizes[i],
stride=strides[i],
patch_size=layerwise_patch_sizes[i],
stride=layerwise_strides[i],
name=f"patch_and_embed_{i}",
)
patch_embedding_layers.append(patch_embed_layer)

transformer_block = [
HierarchicalTransformerEncoder(
project_dim=hidden_dims[i],
num_heads=blockwise_num_heads[i],
sr_ratio=blockwise_sr_ratios[i],
num_heads=layerwise_num_heads[i],
sr_ratio=layerwise_sr_ratios[i],
drop_prob=dpr[cur + k],
name=f"hierarchical_encoder_{i}_{k}",
)
for k in range(depths[i])
for k in range(layerwise_depths[i])
]
transformer_blocks.append(transformer_block)
cur += depths[i]
cur += layerwise_depths[i]
layer_norms.append(keras.layers.LayerNormalization(epsilon=1e-5))

# === Functional Model ===
Expand All @@ -120,7 +123,7 @@ def __init__(
for i in range(num_layers):
# Compute new height/width after the `proj`
# call in `OverlappingPatchingAndEmbedding`
stride = strides[i]
stride = layerwise_strides[i]
new_height, new_width = (
int(ops.shape(x)[1] / stride),
int(ops.shape(x)[2] / stride),
Expand All @@ -138,30 +141,30 @@ def __init__(
super().__init__(inputs=image_input, outputs=x, **kwargs)

# === Config ===
self.depths = depths
self.layerwise_depths = layerwise_depths
self.image_shape = image_shape
self.hidden_dims = hidden_dims
self.pyramid_outputs = pyramid_outputs
self.num_layers = num_layers
self.blockwise_num_heads = blockwise_num_heads
self.blockwise_sr_ratios = blockwise_sr_ratios
self.layerwise_num_heads = layerwise_num_heads
self.layerwise_sr_ratios = layerwise_sr_ratios
self.max_drop_path_rate = max_drop_path_rate
self.patch_sizes = patch_sizes
self.strides = strides
self.layerwise_patch_sizes = layerwise_patch_sizes
self.layerwise_strides = layerwise_strides

def get_config(self):
config = super().get_config()
config.update(
{
"depths": self.depths,
"layerwise_depths": self.layerwise_depths,
"hidden_dims": self.hidden_dims,
"image_shape": self.image_shape,
"num_layers": self.num_layers,
"blockwise_num_heads": self.blockwise_num_heads,
"blockwise_sr_ratios": self.blockwise_sr_ratios,
"layerwise_num_heads": self.layerwise_num_heads,
"layerwise_sr_ratios": self.layerwise_sr_ratios,
"max_drop_path_rate": self.max_drop_path_rate,
"patch_sizes": self.patch_sizes,
"strides": self.strides,
"layerwise_patch_sizes": self.layerwise_patch_sizes,
"layerwise_strides": self.layerwise_strides,
}
)
return config
10 changes: 5 additions & 5 deletions keras_hub/src/models/mit/mit_backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
class MiTBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"depths": [2, 2],
"layerwise_depths": [2, 2],
"image_shape": (32, 32, 3),
"hidden_dims": [4, 8],
"num_layers": 2,
"blockwise_num_heads": [1, 2],
"blockwise_sr_ratios": [8, 4],
"layerwise_num_heads": [1, 2],
"layerwise_sr_ratios": [8, 4],
"max_drop_path_rate": 0.1,
"patch_sizes": [7, 3],
"strides": [4, 2],
"layerwise_patch_sizes": [7, 3],
"layerwise_strides": [4, 2],
}
self.input_size = 32
self.input_data = np.ones(
Expand Down
10 changes: 5 additions & 5 deletions keras_hub/src/models/mit/mit_image_classifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ def setUp(self):
self.images = np.ones((2, 32, 32, 3), dtype="float32")
self.labels = [0, 3]
self.backbone = MiTBackbone(
depths=[2, 2, 2, 2],
layerwise_depths=[2, 2, 2, 2],
image_shape=(32, 32, 3),
hidden_dims=[4, 8],
num_layers=2,
blockwise_num_heads=[1, 2],
blockwise_sr_ratios=[8, 4],
layerwise_num_heads=[1, 2],
layerwise_sr_ratios=[8, 4],
max_drop_path_rate=0.1,
patch_sizes=[7, 3],
strides=[4, 2],
layerwise_patch_sizes=[7, 3],
layerwise_strides=[4, 2],
)
self.init_kwargs = {
"backbone": self.backbone,
Expand Down
24 changes: 12 additions & 12 deletions keras_hub/src/models/mit/mit_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b0_ade20k_512/1",
"kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b0_ade20k_512/2",
},
"mit_b1_ade20k_512": {
"metadata": {
Expand All @@ -32,7 +32,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b1_ade20k_512/1",
"kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b1_ade20k_512/2",
},
"mit_b2_ade20k_512": {
"metadata": {
Expand All @@ -43,7 +43,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b2_ade20k_512/1",
"kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b2_ade20k_512/2",
},
"mit_b3_ade20k_512": {
"metadata": {
Expand All @@ -54,7 +54,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b3_ade20k_512/1",
"kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b3_ade20k_512/2",
},
"mit_b4_ade20k_512": {
"metadata": {
Expand All @@ -65,7 +65,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b4_ade20k_512/1",
"kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b4_ade20k_512/2",
},
"mit_b5_ade20k_640": {
"metadata": {
Expand All @@ -76,7 +76,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b5_ade20k_640/1",
"kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b5_ade20k_640/2",
},
"mit_b0_cityscapes_1024": {
"metadata": {
Expand All @@ -87,7 +87,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b0_cityscapes_1024/1",
"kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b0_cityscapes_1024/2",
},
"mit_b1_cityscapes_1024": {
"metadata": {
Expand All @@ -98,7 +98,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b1_cityscapes_1024/1",
"kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b1_cityscapes_1024/2",
},
"mit_b2_cityscapes_1024": {
"metadata": {
Expand All @@ -109,7 +109,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b2_cityscapes_1024/1",
"kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b2_cityscapes_1024/2",
},
"mit_b3_cityscapes_1024": {
"metadata": {
Expand All @@ -120,7 +120,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b3_cityscapes_1024/1",
"kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b3_cityscapes_1024/2",
},
"mit_b4_cityscapes_1024": {
"metadata": {
Expand All @@ -131,7 +131,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b4_cityscapes_1024/1",
"kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b4_cityscapes_1024/2",
},
"mit_b5_cityscapes_1024": {
"metadata": {
Expand All @@ -142,7 +142,7 @@
"official_name": "MiT",
"path": "mit",
},
"kaggle_handle": "kaggle://keras/mit/keras/mit_b5_cityscapes_1024/1",
"kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b5_cityscapes_1024/2",
},
}

Expand Down
45 changes: 0 additions & 45 deletions keras_hub/src/models/mit/mix_transformer_backbone_test.py

This file was deleted.

Loading

0 comments on commit 4a86c89

Please sign in to comment.