Skip to content

Commit

Permalink
Fixing up resnetv1
Browse files Browse the repository at this point in the history
  • Loading branch information
dibyaghosh committed May 21, 2023
1 parent fc80b70 commit 6ef3f01
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions jaxrl_m/vision/resnet_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class ResNet(nn.Module):
num_filters: int = 64
dtype: Any = jnp.float32
act: Callable = nn.relu
output_fc_dim: int = None

@nn.compact
def __call__(self, x, train: bool = True):
Expand All @@ -94,6 +95,9 @@ def __call__(self, x, train: bool = True):
norm=norm,
act=self.act)(x)
x = jnp.mean(x, axis=(1, 2))
if self.output_fc_dim is not None:
x = nn.Dense(self.output_fc_dim, dtype=self.dtype)(x)

# x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
# x = jnp.asarray(x, self.dtype)
# # x = nn.log_softmax(x) # to match the Torch implementation at https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
Expand Down Expand Up @@ -211,3 +215,5 @@ def convert_pytorch_to_jax(pytorch_statedict, jax_variables, resnet_type="50"):
'resnetv1-200': partial(ResNet, stage_sizes=[3, 24, 36, 3],
block_cls=BottleneckResNetBlock)
}

vanilla_resnetv1_configs['vip'] = partial(vanilla_resnetv1_configs['resnetv1-50'], output_fc_dim=1024)

0 comments on commit 6ef3f01

Please sign in to comment.