diff --git a/README.md b/README.md index c4a8dc4a7f..a351ab667f 100644 --- a/README.md +++ b/README.md @@ -241,7 +241,7 @@ If you have an addition to this list, please submit a pull request. - Parler-TTS, text-to-speech model. - Computer Vision Models. - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT, - ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera. + ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT. - yolo-v3, yolo-v8. - Segment-Anything Model (SAM). - SegFormer. diff --git a/candle-examples/examples/fastvit/README.md b/candle-examples/examples/fastvit/README.md index 499685bd3c..467e1032b1 100644 --- a/candle-examples/examples/fastvit/README.md +++ b/candle-examples/examples/fastvit/README.md @@ -12,9 +12,9 @@ $ cargo run --example fastvit --release -- --image candle-examples/examples/yolo loaded image Tensor[dims 3, 256, 256; f32] model built -mountain bike, all-terrain bike, off-roader: 43.45% -bicycle-built-for-two, tandem bicycle, tandem: 14.16% -unicycle, monocycle : 4.12% -crash helmet : 2.26% -alp : 1.40% +mountain bike, all-terrain bike, off-roader: 52.67% +bicycle-built-for-two, tandem bicycle, tandem: 7.93% +unicycle, monocycle : 3.46% +maillot : 1.32% +crash helmet : 1.28% ``` diff --git a/candle-transformers/src/models/fastvit.rs b/candle-transformers/src/models/fastvit.rs index a0b3cc3e57..8199874276 100644 --- a/candle-transformers/src/models/fastvit.rs +++ b/candle-transformers/src/models/fastvit.rs @@ -339,8 +339,8 @@ fn positional_encoding(dim: usize, vb: VarBuilder) -> Result> { fn attention(dim: usize, vb: VarBuilder) -> Result> { let qkv = linear_no_bias(dim, dim * 3, vb.pp("qkv"))?; let proj = linear(dim, dim, vb.pp("proj"))?; - let num_heads = 32; - let head_dim = dim / num_heads; + let head_dim = 32; + let num_heads = dim / head_dim; let scale = (head_dim as f64).powf(-0.5); Ok(Func::new(move |xs| { @@ -434,7 +434,7 @@ fn fastvit_patch_embed( ) -> Result> { let lk = conv_norm(in_channels, out_channels, 7, 2, vb.pp("proj.0.large_conv"))?; let sk = conv_norm(in_channels, out_channels, 3, 2, vb.pp("proj.0.small_conv"))?; - let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("proj.0.se")); + let se = squeeze_and_excitation(out_channels, out_channels / 4, vb.pp("proj.0.se")); let mb = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp("proj.1"))?; Ok(Func::new(move |xs| {