Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

APIs to import weights from external frameworks. #148

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions examples/vgg/keras_vgg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np
import tensorflow as tf

import objax
from objax.zoo import vgg

mo = vgg.VGG16()
vgg.load_pretrained_weights_from_keras(mo)
print(mo.vars())

mk = tf.keras.applications.VGG16(weights='imagenet')
x = np.random.randn(4, 3, 224, 224)
yk = mk(x.transpose((0, 2, 3, 1))) # (4, 1000)

for name, param in ((weight.name, weight.numpy()) for layer in mk.layers for weight in layer.weights):
print(f'{name:40s} {tuple(param.shape)}')

yo = objax.functional.softmax(mo(x, training=False))
print('Max difference:', np.abs(yk - yo).max())
20 changes: 20 additions & 0 deletions examples/vgg/pytorch_vgg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import jax.numpy as jn
import torch
import torchvision

from objax.zoo import vgg

mo = vgg.VGG16()
vgg.load_pretrained_weights_from_pytorch(mo)
print(mo.vars())

mt = torchvision.models.vgg16(pretrained=True)
mt.eval() # Wow that's error prone
x = torch.randn((4, 3, 224, 224))
yt = mt(x) # (4, 1000)

for name, param in mt.state_dict().items():
print(f'{name:40s} {tuple(param.shape)}')

yo = mo(x.numpy(), training=False)
print('Max difference:', jn.abs(yt.detach().numpy() - yo).max())
2 changes: 1 addition & 1 deletion objax/functional/core/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from objax.constants import ConvPadding
from objax.typing import JaxArray, ConvPaddingInt
from objax.util import to_padding, to_tuple
from objax.util.util import to_padding, to_tuple


def average_pool_2d(x: JaxArray,
Expand Down
2 changes: 1 addition & 1 deletion objax/io/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import jax.numpy as jn
import numpy as np

from objax.util import Renamer
from objax.util.util import Renamer
from objax.variable import TrainRef, VarCollection


Expand Down
2 changes: 1 addition & 1 deletion objax/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax.interpreters.pxla import ShardedDeviceArray

from objax.typing import JaxArray
from objax.util import override_args_kwargs, positional_args_names
from objax.util.util import override_args_kwargs, positional_args_names
from objax.variable import BaseVar, RandomState, VarCollection


Expand Down
2 changes: 1 addition & 1 deletion objax/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""This module contains type declarations for Objax."""

__all__ = ['FileOrStr', 'JaxArray', 'JaxDType']
__all__ = ['ConvPaddingInt', 'FileOrStr', 'JaxArray', 'JaxDType']

from typing import Union, IO, BinaryIO, Sequence, Tuple

Expand Down
3 changes: 2 additions & 1 deletion objax/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import image
from . import check
from . import convert
from . import image
from .util import *
3 changes: 3 additions & 0 deletions objax/util/convert/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from . import keras
from . import pytorch
from .convert import *
28 changes: 28 additions & 0 deletions objax/util/convert/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
__all__ = ['assign', 'import_weights']

import re
from typing import Dict, Callable

import jax.numpy as jn
import numpy as np

from objax.variable import BaseVar, VarCollection


def assign(x: BaseVar, v: np.ndarray):
x.assign(jn.array(v.reshape(x.value.shape)))


def import_weights(target_vc: VarCollection,
source_numpy: Dict[str, np.ndarray],
target_to_source_names: Dict[str, str],
numpy_convert: Dict[str, Callable[[BaseVar, np.ndarray], None]]):
module_var = re.compile(r'.*(\([^)]*\)\.[^(]*)$')
for k, v in target_vc.items():
s = target_to_source_names[k]
t = module_var.match(k).group(1)
if s not in source_numpy:
print(f'Skipping {k} ({s})')
continue
assert t in numpy_convert, f'Unhandled name {k}'
numpy_convert[t](v, source_numpy[s])
27 changes: 27 additions & 0 deletions objax/util/convert/keras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
__all__ = ['ARRAY_CONVERT', 'rename']

import re

from .convert import assign

ARRAY_CONVERT = {
# '(BatchNorm2D).beta': assign,
# '(BatchNorm2D).gamma': assign,
# '(BatchNorm2D).running_mean': assign,
# '(BatchNorm2D).running_var': assign,
'(Conv2D).b': assign,
'(Conv2D).w': assign,
'(Linear).b': assign,
'(Linear).w': assign,
}


def rename(x):
# x = x.replace('(BatchNorm2D).gamma', '(BatchNorm2D).weight').replace('(BatchNorm2D).beta', '(BatchNorm2D).bias')
x = re.sub(r'\([^)]*\)', '', x)
x = re.sub(r'^\.', '', x)
x = re.sub(r'.w$', '/kernel', x)
x = re.sub(r'.b$', '/bias', x)
x = re.sub(r'\[|\]', '', x)
x = re.sub(r'\.', '_', x)
return x
26 changes: 26 additions & 0 deletions objax/util/convert/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
__all__ = ['ARRAY_CONVERT', 'rename']

import re

from .convert import assign

ARRAY_CONVERT = {
'(BatchNorm2D).beta': assign,
'(BatchNorm2D).gamma': assign,
'(BatchNorm2D).running_mean': assign,
'(BatchNorm2D).running_var': assign,
'(Conv2D).b': assign,
'(Conv2D).w': lambda x, y: assign(x, y.transpose((2, 3, 1, 0))),
'(Linear).b': assign,
'(Linear).w': lambda x, y: assign(x, y.T),
}


def rename(x):
x = x.replace('(BatchNorm2D).gamma', '(BatchNorm2D).weight').replace('(BatchNorm2D).beta', '(BatchNorm2D).bias')
x = re.sub(r'\([^)]*\)', '', x)
x = re.sub(r'^\.', '', x)
x = re.sub('.w$', '.weight', x)
x = re.sub('.b$', '.bias', x)
x = x.replace('[', '.').replace(']', '')
return x
2 changes: 1 addition & 1 deletion objax/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
import numpy as np

from objax.typing import JaxArray
from objax.util import map_to_device, Renamer
from objax.util.check import assert_assigned_type_and_shape_match
from objax.util.util import map_to_device, Renamer


def reduce_mean(x: JaxArray) -> JaxArray:
Expand Down
Loading