diff --git a/python/paddle/tests/test_vision_models.py b/python/paddle/tests/test_vision_models.py index a25a8f373c29c..1795dca69c3e0 100644 --- a/python/paddle/tests/test_vision_models.py +++ b/python/paddle/tests/test_vision_models.py @@ -41,6 +41,9 @@ def test_mobilenetv2_pretrained(self): def test_mobilenetv1(self): self.models_infer('mobilenet_v1') + def test_alexnet(self): + self.models_infer('AlexNet') + def test_vgg11(self): self.models_infer('vgg11') diff --git a/python/paddle/vision/models/__init__.py b/python/paddle/vision/models/__init__.py index d38f3b1722ee8..4adfde4a90cd2 100644 --- a/python/paddle/vision/models/__init__.py +++ b/python/paddle/vision/models/__init__.py @@ -28,6 +28,7 @@ from .vgg import vgg16 # noqa: F401 from .vgg import vgg19 # noqa: F401 from .lenet import LeNet # noqa: F401 +from .alexnet import AlexNet # noqa: F401 __all__ = [ #noqa 'ResNet', @@ -45,5 +46,6 @@ 'mobilenet_v1', 'MobileNetV2', 'mobilenet_v2', - 'LeNet' + 'LeNet', + 'AlexNet' ] diff --git a/python/paddle/vision/models/alexnet.py b/python/paddle/vision/models/alexnet.py new file mode 100644 index 0000000000000..938f4f24fb5c4 --- /dev/null +++ b/python/paddle/vision/models/alexnet.py @@ -0,0 +1,71 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.nn as nn + +__all__ = [] + + +class AlexNet(nn.Layer): + def __init__(self, num_classes=10): + """AlexNet model architecture from the + `"One weird trick..." `_ paper. + + Args: + num_classes (int): output dim of the classifier. If num_classes <=0, the classifier + will not be defined. Default: 10. + + Examples: + .. code-block:: python + + from paddle.vision.models import AlexNet + + model = AlexNet() + """ + super(AlexNet, self).__init__() + self.num_classes = num_classes + self.features = nn.Sequential( + nn.Conv2D(3, 64, kernel_size=3, stride=2, padding=1), + nn.ReLU(), + nn.MaxPool2D(kernel_size=3, stride=2), + nn.Conv2D(64, 192, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.MaxPool2D(kernel_size=3, stride=2), + nn.Conv2D(192, 384, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2D(384, 256, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2D(256, 256, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.MaxPool2D(kernel_size=3, stride=2), + ) + + if num_classes > 0: + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 1 * 1, 4096), + nn.ReLU(), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + x = self.features(x) + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.classifier(x) + return x \ No newline at end of file