Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

APIs to import weights from external frameworks. #148

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions examples/vgg_pytorch/pytorch_vgg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import jax.numpy as jn
import torch
import torchvision

from objax.zoo import vgg


def delta(x, y): # pytoch, jax
return jn.abs(x.detach().numpy() - y).max()


mo = vgg.vgg16(use_bn=False)
vgg.load_pretrained_weights_from_pytorch(mo)
print(mo.vars())

mt = torchvision.models.vgg16(pretrained=True)
mt.eval() # Wow that's error prone
x = torch.randn((4, 3, 224, 224))
yt = mt(x) # (4, 1000)

for name, param in mt.state_dict().items():
print(f'{name:40s} {tuple(param.shape)}')

yo = mo(x.numpy(), training=False)
print('Max difference:', jn.abs(yt.detach().numpy() - yo).max())
2 changes: 1 addition & 1 deletion objax/functional/core/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from objax.constants import ConvPadding
from objax.typing import JaxArray, ConvPaddingInt
from objax.util import to_padding, to_tuple
from objax.util.util import to_padding, to_tuple


def average_pool_2d(x: JaxArray,
Expand Down
2 changes: 1 addition & 1 deletion objax/io/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import jax.numpy as jn
import numpy as np

from objax.util import Renamer
from objax.util.util import Renamer
from objax.variable import TrainRef, VarCollection


Expand Down
2 changes: 1 addition & 1 deletion objax/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax.interpreters.pxla import ShardedDeviceArray

from objax.typing import JaxArray
from objax.util import override_args_kwargs, positional_args_names
from objax.util.util import override_args_kwargs, positional_args_names
from objax.variable import BaseVar, RandomState, VarCollection


Expand Down
2 changes: 1 addition & 1 deletion objax/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""This module contains type declarations for Objax."""

__all__ = ['FileOrStr', 'JaxArray', 'JaxDType']
__all__ = ['ConvPaddingInt', 'FileOrStr', 'JaxArray', 'JaxDType']

from typing import Union, IO, BinaryIO, Sequence, Tuple

Expand Down
3 changes: 2 additions & 1 deletion objax/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import image
from . import check
from . import convert
from . import image
from .util import *
2 changes: 2 additions & 0 deletions objax/util/convert/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import pytorch
from .convert import *
28 changes: 28 additions & 0 deletions objax/util/convert/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
__all__ = ['assign', 'import_weights']

import re
from typing import Dict, Callable

import jax.numpy as jn
import numpy as np

from objax.variable import BaseVar, VarCollection


def assign(x: BaseVar, v: np.ndarray):
x.assign(jn.array(v.reshape(x.value.shape)))


def import_weights(target_vc: VarCollection,
source_numpy: Dict[str, np.ndarray],
source_names: Dict[str, str],
numpy_convert: Dict[str, Callable[[BaseVar, np.ndarray], None]]):
module_var = re.compile(r'.*(\([^)]*\)\.[^(]*)$')
for k, v in target_vc.items():
s = source_names[k]
t = module_var.match(k).group(1)
if s not in source_numpy:
print(f'Skipping {k} ({s})')
continue
assert t in numpy_convert, f'Unhandled name {k}'
numpy_convert[t](v, source_numpy[s])
26 changes: 26 additions & 0 deletions objax/util/convert/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
__all__ = ['ARRAY_CONVERT', 'rename']

import re

from .convert import assign

ARRAY_CONVERT = {
'(BatchNorm2D).beta': assign,
'(BatchNorm2D).gamma': assign,
'(BatchNorm2D).running_mean': assign,
'(BatchNorm2D).running_var': assign,
'(Conv2D).b': assign,
'(Conv2D).w': lambda x, y: assign(x, y.transpose((2, 3, 1, 0))),
'(Linear).b': assign,
'(Linear).w': lambda x, y: assign(x, y.T),
}


def rename(x):
x = x.replace('(BatchNorm2D).gamma', '(BatchNorm2D).weight').replace('(BatchNorm2D).beta', '(BatchNorm2D).bias')
x = re.sub(r'\([^)]*\)', '', x)
x = re.sub(r'^\.', '', x)
x = re.sub('.w$', '.weight', x)
x = re.sub('.b$', '.bias', x)
x = x.replace('[', '.').replace(']', '')
return x
2 changes: 1 addition & 1 deletion objax/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
import numpy as np

from objax.typing import JaxArray
from objax.util import map_to_device, Renamer
from objax.util.check import assert_assigned_type_and_shape_match
from objax.util.util import map_to_device, Renamer


def reduce_mean(x: JaxArray) -> JaxArray:
Expand Down
206 changes: 65 additions & 141 deletions objax/zoo/vgg.py
Original file line number Diff line number Diff line change
@@ -1,144 +1,68 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module with VGG-19 implementation.

See https://arxiv.org/abs/1409.1556 for detail.
"""

import functools
import os
from urllib import request

import jax.numpy as jn
import numpy as np
__all__ = ['VGG', 'load_pretrained_weights_from_pytorch', 'vgg11', 'vgg13', 'vgg16', 'vgg19']

from typing import Union, Sequence

import objax
from objax.util.convert import import_weights, pytorch


class VGG(objax.Module):
def __init__(self, nin: int, nout: int, ops: Sequence[Union[str, int]], use_bn: bool, name: str):
self.name = name + ('_bn' if use_bn else '')
self.ops = tuple(ops)
n = nin
self.features = objax.nn.Sequential()
for v in ops:
if v == 'M':
self.features.append(objax.functional.max_pool_2d)
continue
self.features.append(objax.nn.Conv2D(n, v, 3, padding=1))
if use_bn:
self.features.append(objax.nn.BatchNorm2D(v, momentum=0.1, eps=1e-5))
self.features.append(objax.functional.relu)
n = v

self.classifier = objax.nn.Sequential([objax.nn.Linear(512 * 7 * 7, 4096), objax.functional.relu,
objax.nn.Dropout(0.5),
objax.nn.Linear(4096, 4096), objax.functional.relu,
objax.nn.Dropout(0.5),
objax.nn.Linear(4096, nout)])

def __call__(self, *args, **kwargs):
features = objax.functional.flatten(self.features(*args, **kwargs))
return self.classifier(features, **kwargs)

def __repr__(self):
use_bn = self.name.endswith('_bn')
name = self.name[:-3] if use_bn else self.name
return f'{self.__class__.__name__}(nin={self.features[0].w.value.shape[2]}, ' \
f'nout={self.features[0].w.value.shape[3]}, ops={self.ops}, use_bn={use_bn}, name={repr(name)})'


def load_pretrained_weights_from_pytorch(m: VGG):
import torchvision
torch_model = getattr(torchvision.models, m.name)(pretrained=True)
torch_model.eval() # Just a safety precaution.
numpy_arrays = {name: param.numpy() for name, param in torch_model.state_dict().items()}
numpy_names = {k: pytorch.rename(k) for k in m.vars().keys()}
import_weights(m.vars(), numpy_arrays, numpy_names, pytorch.ARRAY_CONVERT)


def vgg11(use_bn: bool):
ops = 64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
return VGG(3, 1000, ops, use_bn=use_bn, name='vgg11')


def vgg13(use_bn: bool):
ops = 64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
return VGG(3, 1000, ops, use_bn=use_bn, name='vgg13')


def vgg16(use_bn: bool):
ops = 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'
return VGG(3, 1000, ops, use_bn=use_bn, name='vgg16')


_VGG19_URL = 'https://github.com/machrisaa/tensorflow-vgg'
_VGG19_NPY = './objax/zoo/pretrained/vgg19.npy'
_SYNSET_URL = 'https://raw.githubusercontent.com/machrisaa/tensorflow-vgg/master/synset.txt'
_SYNSET_PATH = './objax/zoo/pretrained/synset.txt'


def preprocess(x):
bgr_mean = [103.939, 116.779, 123.68]
red, green, blue = [x[:, i, :, :] for i in range(3)]
return jn.stack([blue - bgr_mean[0], green - bgr_mean[1], red - bgr_mean[2]], axis=1)


def max_pool_2d(x):
return functools.partial(objax.functional.max_pool_2d,
size=2, strides=2, padding=objax.constants.ConvPadding.VALID)(x)


class VGG19(objax.nn.Sequential):
"""VGG19 implementation."""

def __init__(self, pretrained=False):
"""Creates VGG19 instance.

Args:
pretrained: if True load weights from ImageNet pretrained model.
"""
if not os.path.exists(_VGG19_NPY):
raise FileNotFoundError(
'You must download vgg19.npy from %s and save it to %s' % (_VGG19_URL, _VGG19_NPY))
if not os.path.exists(_SYNSET_PATH):
request.urlretrieve(_SYNSET_URL, _SYNSET_PATH)
self.data_dict = np.load(_VGG19_NPY, encoding='latin1', allow_pickle=True).item()
self.pretrained = pretrained
self.ops = self.build()
super().__init__(self.ops)

def build(self):
# inputs in [0, 255]
self.preprocess = preprocess
self.conv1_1 = objax.nn.Conv2D(nin=3, nout=64, k=3)
self.relu1_1 = objax.functional.relu
self.conv1_2 = objax.nn.Conv2D(nin=64, nout=64, k=3)
self.relu1_2 = objax.functional.relu
self.pool1 = max_pool_2d

self.conv2_1 = objax.nn.Conv2D(nin=64, nout=128, k=3)
self.relu2_1 = objax.functional.relu
self.conv2_2 = objax.nn.Conv2D(nin=128, nout=128, k=3)
self.relu2_2 = objax.functional.relu
self.pool2 = max_pool_2d

self.conv3_1 = objax.nn.Conv2D(nin=128, nout=256, k=3)
self.relu3_1 = objax.functional.relu
self.conv3_2 = objax.nn.Conv2D(nin=256, nout=256, k=3)
self.relu3_2 = objax.functional.relu
self.conv3_3 = objax.nn.Conv2D(nin=256, nout=256, k=3)
self.relu3_3 = objax.functional.relu
self.conv3_4 = objax.nn.Conv2D(nin=256, nout=256, k=3)
self.relu3_4 = objax.functional.relu
self.pool3 = max_pool_2d

self.conv4_1 = objax.nn.Conv2D(nin=256, nout=512, k=3)
self.relu4_1 = objax.functional.relu
self.conv4_2 = objax.nn.Conv2D(nin=512, nout=512, k=3)
self.relu4_2 = objax.functional.relu
self.conv4_3 = objax.nn.Conv2D(nin=512, nout=512, k=3)
self.relu4_3 = objax.functional.relu
self.conv4_4 = objax.nn.Conv2D(nin=512, nout=512, k=3)
self.relu4_4 = objax.functional.relu
self.pool4 = max_pool_2d

self.conv5_1 = objax.nn.Conv2D(nin=512, nout=512, k=3)
self.relu5_1 = objax.functional.relu
self.conv5_2 = objax.nn.Conv2D(nin=512, nout=512, k=3)
self.relu5_2 = objax.functional.relu
self.conv5_3 = objax.nn.Conv2D(nin=512, nout=512, k=3)
self.relu5_3 = objax.functional.relu
self.conv5_4 = objax.nn.Conv2D(nin=512, nout=512, k=3)
self.relu5_4 = objax.functional.relu
self.pool5 = max_pool_2d

self.flatten = objax.functional.flatten
self.fc6 = objax.nn.Linear(nin=512 * 7 * 7, nout=4096)
self.relu6 = objax.functional.relu
self.fc7 = objax.nn.Linear(nin=4096, nout=4096)
self.relu7 = objax.functional.relu
self.fc8 = objax.nn.Linear(nin=4096, nout=1000)

if self.pretrained:
for it in self.data_dict:
if it.startswith('conv'):
conv = getattr(self, it)
kernel, bias = self.data_dict[it]
conv.w = objax.TrainVar(jn.array(kernel))
conv.b = objax.TrainVar(jn.array(bias[:, None, None]))
setattr(self, it, conv)
elif it.startswith('fc'):
linear = getattr(self, it)
kernel, bias = self.data_dict[it]
if it == 'fc6':
kernel = kernel.reshape([7, 7, 512, -1]).transpose((2, 0, 1, 3)).reshape([512 * 7 * 7, -1])
linear.w = objax.TrainVar(jn.array(kernel))
linear.b = objax.TrainVar(jn.array(bias))
setattr(self, it, linear)

ops = [self.conv1_1, self.relu1_1, self.conv1_2, self.relu1_2, self.pool1,
self.conv2_1, self.relu2_1, self.conv2_2, self.relu2_2, self.pool2,
self.conv3_1, self.relu3_1, self.conv3_2, self.relu3_2,
self.conv3_3, self.relu3_3, self.conv3_4, self.relu3_4, self.pool3,
self.conv4_1, self.relu4_1, self.conv4_2, self.relu4_2,
self.conv4_3, self.relu4_3, self.conv4_4, self.relu4_4, self.pool4,
self.conv5_1, self.relu5_1, self.conv5_2, self.relu5_2,
self.conv5_3, self.relu5_3, self.conv5_4, self.relu5_4, self.pool5,
self.flatten, self.fc6, self.relu6, self.fc7, self.relu7, self.fc8]

return ops
def vgg19(use_bn: bool):
ops = 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'
return VGG(3, 1000, ops, use_bn=use_bn, name='vgg19')