Skip to content

Commit

Permalink
GlobalMaxAvgPooling2d
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Mar 30, 2024
1 parent 594b299 commit 068da09
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions pytorch_toolbelt/modules/pooling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Implementation of different pooling modules
"""

from typing import Union, Dict

import torch
Expand All @@ -18,6 +19,7 @@
"GlobalWeightedAvgPool2d",
"MILCustomPoolingModule",
"RMSPool",
"GlobalMaxAvgPooling2d",
]


Expand Down Expand Up @@ -203,3 +205,19 @@ def __repr__(self):
+ str(self.eps)
+ ")"
)


class GlobalMaxAvgPooling2d(nn.Module):
def __init__(self, flatten: bool = False):
super().__init__()
self.max_pooling = nn.AdaptiveMaxPool2d((1, 1))
self.avg_pooling = nn.AdaptiveAvgPool2d((1, 1))
self.flatten = flatten

def forward(self, x):
x_max = self.max_pooling(x).flatten(start_dim=1)
x_avg = self.avg_pooling(x).flatten(start_dim=1)
y = torch.cat([x_max, x_avg], dim=1)
if self.flatten:
y = torch.flatten(y, 1)
return y

0 comments on commit 068da09

Please sign in to comment.