diff --git a/jaxrl_m/vision/__init__.py b/jaxrl_m/vision/__init__.py index d975d47..1308796 100644 --- a/jaxrl_m/vision/__init__.py +++ b/jaxrl_m/vision/__init__.py @@ -1,9 +1,12 @@ from jaxrl_m.vision.impala import impala_configs -from jaxrl_m.vision.bigvision_resnetv2 import resnetv2_configs +from jaxrl_m.vision.resnet_v2 import resnetv2_configs from jaxrl_m.vision.small_encoders import small_configs from jaxrl_m.vision.bridge_resnet_v1 import bridge_resnetv1_configs from jaxrl_m.vision.resnet_v1 import vanilla_resnetv1_configs +from jaxrl_m.vision.mae import mae_model_configs +from jaxrl_m.vision.vit import vit_configs + from jaxrl_m.vision import data_augmentations encoders = dict() @@ -11,4 +14,7 @@ encoders.update(resnetv2_configs) encoders.update(bridge_resnetv1_configs) encoders.update(vanilla_resnetv1_configs) -encoders.update(small_configs) \ No newline at end of file +encoders.update(small_configs) + +encoders.update(mae_model_configs) +encoders.update(vit_configs) \ No newline at end of file diff --git a/jaxrl_m/vision/bigvision_common.py b/jaxrl_m/vision/bigvision_common.py deleted file mode 100644 index f6c307e..0000000 --- a/jaxrl_m/vision/bigvision_common.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2022 Big Vision Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utilities shared across models.""" - -from absl import logging -import jaxrl_m.vision.bigvision_utils as u -import flax.linen as nn -import jax -import jax.numpy as jnp - - -def merge_params(loaded, inited, dont_load=()): - """Makes `loaded` pytree match `init`, warning or failing on mismatch. - - Args: - loaded: pytree of parameters, typically loaded from a checkpoint. - inited: pytree of parameter, typically coming from model init. - dont_load: List of regexes for parameters which shall not be taken - from `loaded`, either because they should remain at their init value, - or because they are missing on either side. - - Returns: - If successful, a new pytree which matches the structure of `init` - but contains values from `loaded`, except for `dont_load`. - - If structures don't match and mismatches are not covered by regexes in - `dont_load` argument, then raises an exception with more information. - """ - dont_load = u.check_and_compile_patterns(dont_load) - - def should_merge(name): - return not any(pattern.fullmatch(name) for pattern in dont_load) - - loaded_flat, _ = u.tree_flatten_with_names(loaded) - inited_flat, _ = u.tree_flatten_with_names(inited) - loaded_flat = {k: v for k, v in loaded_flat} - inited_flat = {k: v for k, v in inited_flat} - - # Let's first build the pytree from all common keys. - merged = {} - for name, init_val in inited_flat.items(): - # param is present in both. Load or ignore it! 
- if name in loaded_flat and should_merge(name): - merged[name] = loaded_flat[name] - else: - logging.info("Ignoring checkpoint and using init value for %s", name) - merged[name] = init_val - - def pp(title, names, indent=" "): # Just pretty-printing - if names: - return f"{title}:\n" + "\n".join(f"{indent}{k}" for k in sorted(names)) - else: - return "" - - # Now, if there are keys that only exist in inited or loaded, be helpful: - not_in_loaded = inited_flat.keys() - loaded_flat.keys() - not_in_inited = loaded_flat.keys() - inited_flat.keys() - logging.info(pp("Parameters in model but not in checkpoint", not_in_loaded)) - logging.info(pp("Parameters in checkpoint but not in model", not_in_inited)) - - # And now see if any of them are not explicitly ignored => an error - not_in_loaded = {k for k in not_in_loaded if should_merge(k)} - not_in_inited = {k for k in not_in_inited if should_merge(k)} - - if not_in_loaded or not_in_inited: - raise ValueError( - pp("Params in checkpoint", loaded_flat.keys()) + "\n" + - pp("Params in model (code)", inited_flat.keys()) + "\n" + - pp("Params in model (code) but not in checkpoint and not `dont_load`ed", - not_in_loaded, indent=" - ") + "\n" + # Special indent for tests. - pp("Params in checkpoint but not in model (code) and not `dont_load`ed", - not_in_inited, indent=" + ")) # Special indent for tests. - - return u.recover_tree(merged.keys(), merged.values()) \ No newline at end of file diff --git a/jaxrl_m/vision/bigvision_utils.py b/jaxrl_m/vision/bigvision_utils.py deleted file mode 100644 index 8b4e73e..0000000 --- a/jaxrl_m/vision/bigvision_utils.py +++ /dev/null @@ -1,919 +0,0 @@ -# Copyright 2022 Big Vision Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utils very specific to this project, not generic.""" - -import collections -import contextlib -import dataclasses -import functools -import io -import json -import multiprocessing -import multiprocessing.pool -import os -import re -import sys -import time -from typing import Mapping - -from absl import flags -from absl import logging -# import einops -import flax -import flax.jax_utils as flax_utils -import jax -import jax.numpy as jnp -import ml_collections as mlc -import numpy as np - -# import tensorflow.io.gfile as gfile - -# Registry = pp_registry.Registry - - -# pylint: disable=logging-fstring-interpolation - - -def pad_shard_unpad(wrapped, static_argnums=(0,), static_argnames=()): - """Wraps a function with code that pads, shards, then un-shards, un-pads. - - Args: - wrapped: the function to be wrapped. Signature is `params, *args, *kwargs`. - static_argnums: indices of arguments to `wrapped` that should _not_ be - padded and sharded, but instead be forwarded as-is. The default is (0,) - because by far the most common use-case is to pass `params` first. - static_argnames: names of kwargs to `wrapped` that should _not_ be padded - and sharded, but instead be forwarded as-is. 
- - Returns: - A new function that pads and shards its arguments before passing them to - the wrapped function, and un-shards and un-pads the returned pytree. - - This is useful for calling a pmap'ed function with inputs that aren't - divisible by the number of devices. A typical use is: - @pad_shard_unpad - @jax.pmap - def forward(params, x): ... - - Notes: - The padding is done in host-memory before being passed to the function, and - the values returned by the function are transferred back to host memory. - - The returned function is augmented with a new keyword-only argument - `min_device_batch` that, if specified, forces padding inputs to at least - this size per device. This can be useful to avoid recompiles for the last - batch and reduce memory fragmentation. - """ - - def pad_shard_unpad_wrapper(*args, min_device_batch=None, **kw): - d = jax.local_device_count() # d = devices, b = batch - batch_sizes = ( - {a.shape[0] for i, a in enumerate(args) if i not in static_argnums} | - {v.shape[0] for k, v in kw.items() if k not in static_argnames}) - assert len(batch_sizes) == 1, f"Inconsistent batch-sizes: {batch_sizes}" - b = batch_sizes.pop() - - def maybe_pad(x, actually_pad=True): - if not actually_pad: return x # For call-site convenience below. - _, *shape = x.shape - db, rest = divmod(b, d) - if rest: - x = np.concatenate([x, np.zeros((d - rest, *shape), x.dtype)], axis=0) - db += 1 - if min_device_batch and db < min_device_batch: - x = np.concatenate( - [x, np.zeros((d * (min_device_batch - db), *shape), x.dtype)]) - db = min_device_batch - return x.reshape(d, db, *shape) - - args = [maybe_pad(a, i not in static_argnums) for i, a in enumerate(args)] - kw = {k: maybe_pad(v, k not in static_argnames) for k, v in kw.items()} - out = wrapped(*args, **kw) - - def unpad(x): - # Transfer back before cutting, to reduce on-device shape diversity. - return einops.rearrange(jax.device_get(x), "d b ... -> (d b) ...")[:b] - return jax.tree_map(unpad, out) - - return pad_shard_unpad_wrapper - - -def onehot(labels, num_classes, on_value=1.0, off_value=0.0): - x = (labels[..., None] == jnp.arange(num_classes)[None]) - x = jax.lax.select(x, jnp.full(x.shape, on_value), - jnp.full(x.shape, off_value)) - return x.astype(jnp.float32) - - -def npload(fname): - with open(fname, "rb") as f: - data = f.read() - return dict(np.load(io.BytesIO(data), allow_pickle=False)) - - -def load_checkpoint(tree, npz): - """Loads a jax pytree from a npz file. - - Args: - tree: deprecated, use None. - Bwd-compat for old format that only stored values: the pytree structure. - npz: Either path to the checkpoint file (.npz), or a dict-like. - - Returns: - A pytree that is the checkpoint. - """ - if isinstance(npz, str): # If not already loaded, then load. - npz = npload(npz) - keys, values = zip(*list(npz.items())) - if tree: - checkpoint = tree.unflatten(values) - else: - checkpoint = recover_tree(keys, values) - return checkpoint - - -def load_params(tree, npz): - """Loads a parameters from a npz checkpoint. - - Args: - tree: deprecated, use None. - Bwd-compat for old format that only stored values: the pytree structure. - npz: Either path to the checkpoint file (.npz), or a dict-like. - - Returns: - A pytree that is the checkpoint. - - Notes: - The filename can contain an indicator like `/path/to/file.npz:keyname`, in - which case ["opt"]["params"]["keyname"] will become ["opt"]["params"] in - the returned checkpoint. 
This allows ANY model that uses this function to - load itself from a checkpoint that contains multiple sub-models, such as - checkpoints generated from Argus or Distillation trainers. - """ - key = None # Whether we want to extract only a sub-key of the model. - if isinstance(npz, str): - if ((":" in npz and "://" not in npz) or # Like /path/to/file:subtree_name - ("://" in npz and npz.count(":") == 2)): # Like gs://path/to/file:sub - npz, key = npz.rsplit(":", 1) - checkpoint = load_checkpoint(tree, npz) - if "params" in checkpoint: - # Checkpoint with optax state (after cl/423007216). - params = checkpoint["params"] - elif "opt" in checkpoint: - # Checkpoint with Flax optimizer. - params = checkpoint["opt"]["target"] - else: - # When open-sourcing, we usually shared only the params directly. - params = checkpoint - if key is not None: - params = tree_get(params, key) - return params - - -def prefetch_scalar(it, nprefetch=1, devices=None): - n_loc_dev = len(devices) if devices else jax.local_device_count() - repl_iter = (np.ones(n_loc_dev) * i for i in it) - return flax_utils.prefetch_to_device(repl_iter, nprefetch, devices) - - -def sigmoid_xent(*, logits, labels, reduction=True): - # NOTE: This implementation is stable, see these two: - # (internal link) - # https://github.com/google/jax/issues/2140 - log_p = jax.nn.log_sigmoid(logits) - log_not_p = jax.nn.log_sigmoid(-logits) - nll = -jnp.sum(labels * log_p + (1. - labels) * log_not_p, axis=-1) - return jnp.mean(nll) if reduction else nll - - -def softmax_xent(*, logits, labels, reduction=True, kl=False, axis=-1): - log_p = jax.nn.log_softmax(logits, axis=axis) - nll = -jnp.sum(labels * log_p, axis=axis) - if kl: - nll += jnp.sum(labels * jnp.log(jnp.clip(labels, 1e-8)), axis=axis) - return jnp.mean(nll) if reduction else nll - - -def weighted_softmax_xent(*, - logits, - labels, - reduction=True, - weights=None, - label_smoothing=0.0, - normalize=True): - """Compute weighted cross entropy. - - Args: - logits: [batch, length, num_classes] float array. - labels: categorical targets [batch, length] int array. - reduction: reduce across batch dim. - weights: None or array of shape [batch, length]. - label_smoothing: label smoothing constant, used to determine the on and off - values. - normalize: normalize each "sentence" loss by the number of tokens in it. - - Returns: - Tuple of scalar loss and batch normalizing factor. - """ - if logits.ndim != labels.ndim + 1: - raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" % - (str(logits.shape), str(labels.shape))) - vocab_size = logits.shape[-1] - confidence = 1.0 - label_smoothing - low_confidence = (1.0 - confidence) / (vocab_size - 1) - soft_targets = onehot( - labels, vocab_size, on_value=confidence, off_value=low_confidence) - - loss = -jnp.sum(soft_targets * jax.nn.log_softmax(logits), axis=-1) - - normalizing_factor = labels.shape[1] - if weights is not None: - loss = loss * weights - normalizing_factor = weights.sum(axis=1) - - loss = loss.sum(axis=1) - if normalize: - loss = loss / normalizing_factor - - return loss.mean() if reduction else loss - - -def accumulate_gradient(loss_and_grad_fn, params, images, labels, accum_steps): - """Accumulate gradient over multiple steps to save on memory.""" - # See (internal link) for details and experiments. 
- if accum_steps and accum_steps > 1: - assert images.shape[0] % accum_steps == 0, ( - f"Bad accum_steps {accum_steps} for batch size {images.shape[0]}") - step_size = images.shape[0] // accum_steps - l, g = loss_and_grad_fn(params, images[:step_size], labels[:step_size]) - def acc_grad_and_loss(i, l_and_g): - imgs = jax.lax.dynamic_slice(images, (i*step_size, 0, 0, 0), - (step_size,) + images.shape[1:]) - lbls = jax.lax.dynamic_slice(labels, (i*step_size, 0), - (step_size, labels.shape[1])) - li, gi = loss_and_grad_fn(params, imgs, lbls) - l, g = l_and_g - return (l + li, jax.tree_map(lambda x, y: x + y, g, gi)) - l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g)) - return jax.tree_map(lambda x: x / accum_steps, (l, g)) - else: - return loss_and_grad_fn(params, images, labels) - - -def itstime(step, every_n_steps, total_steps, host=None, last=True, first=True, - drop_close_to_last=0.25): - """Returns True if it's time to execute an action. - - Args: - step: the current step representing "now". - every_n_steps: the action should run every this many steps. - total_steps: the step number of the last step of training. - host: host number. If provided, only run if we are this process. - last: whether to run on the last step or not. - first: whether to run on the first step or not. - drop_close_to_last: if a step would run, but is this close (in terms of - fraction of every_n_step) to the last one, skip. - - Returns: - True if the action should be executed, False if not. - """ - - # This logic avoids running `itstime` "a few" steps before the last step. - # Canonical example: don't save checkpoint 2 steps before the last, and then - # at the last again; it's pointless and checkpoint timing will time out. - close_to_last = False - if drop_close_to_last and every_n_steps: - close_to_last = abs(step - total_steps) < drop_close_to_last * every_n_steps - - is_host = host is None or jax.process_index() == host - is_step = every_n_steps and (step % every_n_steps == 0) and not close_to_last - is_last = every_n_steps and step == total_steps - is_first = every_n_steps and step == 1 - return is_host and (is_step or (last and is_last) or (first and is_first)) - - -def checkpointing_timeout(writer, timeout): - # Make sure checkpoint writing is not a bottleneck - if writer is not None: - try: - writer.get(timeout=timeout) - except multiprocessing.TimeoutError as e: - raise TimeoutError( - "Checkpoint writing seems to be a bottleneck. Make sure you do " - "not do something wrong, like writing checkpoints to a distant " - "cell. In a case you are OK with checkpoint writing being a " - "bottleneck, you can configure `checkpoint_timeout` parameter") from e - - -def hms(s): - """Format time in hours/minutes/seconds.""" - if s < 60: - return f"{s:.0f}s" - m, s = divmod(s, 60) - if m < 60: - return f"{m:.0f}m{s:.0f}s" - h, m = divmod(m, 60) - return f"{h:.0f}h{m:.0f}m" # Seconds intentionally omitted. - - -class Chrono: - """Measures time and reports progress, hyper-specific to our train loops. - - Some concepts: - 1. This differentiates between three "types" of time: - - training time: the time spent on actual training (fprop/bprop/update) - - program time: overall time the program runs, including all overheads - - pause time: the chronometer can be paused (eg during evals). - 2. This handles a "warmup": the first step is skipped for training time - purposes, as it includes significant compilation overheads, which distort - estimates. - 3. `accum`ulates (i.e. 
integrates) timings, and save/load them across - restarts. - """ - - def __init__(self): - self.program_start_time = time.time() - self.train_start_time = None - self.train_start_step = None # When we started timing (after warmup) - - self.prev_time = None - self.prev_step = None - - self.pause_start = None - self.paused_time = 0 - - self.warmup = 2 # How many calls to `tick` to skip. - self.load() # Inits accum integrators. - self.note = "Chrono n/a" - - def inform(self, first_step, total_steps, global_bs, steps_per_epoch): - """Provide some extra info that's only known later in the program.""" - self.prev_step = first_step - self.first_step = first_step - self.total_steps = total_steps - self.steps_per_epoch = steps_per_epoch - self.global_bs = global_bs - if total_steps: - self.note = f"Steps:{first_step}/{total_steps} [{first_step/total_steps:.1%}]" - - def tick(self, step, measure, write_note): - """A chronometer tick.""" - now = time.time() - - # We do always count examples, regardless of the timing-related warmup that - # happens a few lines below. - ds = step - self.prev_step # Steps between ticks - self.prev_step = step - self.accum_examples_seen += ds * self.global_bs - measure("examples_seen", self.accum_examples_seen) - measure("epoch", step / self.steps_per_epoch) - - # We take the start as the second time `tick` is called, so we avoid - # measuring the overhead of compilation and don't include it in time - # estimates. - if self.warmup > 1: - self.warmup -= 1 - write_note(self.note) # This can help debugging. - return - if self.warmup == 1: - self.train_start_time = self.prev_time = now - self.train_start_step = step - self.accum_program_time += now - self.program_start_time - self.paused_time = 0 # Drop pauses that happened before timing starts. - self.warmup = 0 - write_note(self.note) # This can help debugging. - return - - # Measurement with micro-timings of current training steps speed. - # Time between ticks (ignoring pause) - dt = now - self.prev_time - self.paused_time - ncores = jax.device_count() # Global device count - measure("img/sec/core", self.global_bs * ds / dt / ncores) - - # Accumulate (integrate) times, good for plots. - self.accum_train_time += dt - self.accum_pause_time += self.paused_time - self.accum_program_time += dt + self.paused_time - - # Convert to, and log as, core hours. - core_hours = self.accum_train_time * ncores / 60 / 60 - devtype = jax.devices()[0].device_kind - measure(f"core_hours_{devtype}", core_hours) - measure("core_hours", core_hours) # For convenience as x-axis in sweeps. - - # Progress note with "global" full-program average timings - # (eg in program-time minus warmup) - dt = now - self.train_start_time # Time elapsed since end of warmup. - steps_timed = step - self.train_start_step - steps_todo = self.total_steps - step - self.note = f"Steps:{step}/{self.total_steps} [{step/self.total_steps:.1%}]" - self.note += f"\nWalltime:{hms(self.accum_program_time)}" - self.note += f" ({hms(self.accum_pause_time)} eval)" - self.note += f"\nETA:{hms(dt / steps_timed * steps_todo)}" - self.note += f"\nTotal train time:{hms(dt / steps_timed * self.total_steps)}" - write_note(self.note) - - self.prev_time = now - self.paused_time = 0 - - def pause(self, wait_for=()): - assert self.pause_start is None, "Don't pause twice." 
- jax.block_until_ready(wait_for) - self.pause_start = time.time() - - def resume(self): - self.paused_time += time.time() - self.pause_start - self.pause_start = None - - def save(self): - return dict( - accum_program_time=self.accum_program_time, - accum_train_time=self.accum_train_time, - accum_pause_time=self.accum_pause_time, - accum_examples_seen=self.accum_examples_seen, - ) - - def load(self, ckpt={}): # pylint: disable=dangerous-default-value - self.accum_program_time = ckpt.get("accum_program_time", 0.0) - self.accum_train_time = ckpt.get("accum_train_time", 0.0) - self.accum_pause_time = ckpt.get("accum_pause_time", 0.0) - self.accum_examples_seen = ckpt.get("accum_examples_seen", 0) - - -def _traverse_with_names(tree, with_inner_nodes=False): - """Traverses nested dicts/dataclasses and emits (leaf_name, leaf_val).""" - if dataclasses.is_dataclass(tree): - tree = flax.serialization.to_state_dict(tree) - # Don't output the non-leaf nodes. If the optimizer doesn't have a state - # the tree leaves can be Nones which was interpreted as a leaf by this - # function but not by the other functions (like jax.tree_map). - if tree is None: - return - elif isinstance(tree, Mapping): - keys = sorted(tree.keys()) - for key in keys: - for path, v in _traverse_with_names(tree[key], with_inner_nodes): - yield (key + "/" + path).rstrip("/"), v - if with_inner_nodes: - yield "", tree - elif isinstance(tree, (list, tuple)): - for idx in range(len(tree)): - for path, v in _traverse_with_names(tree[idx], with_inner_nodes): - yield (str(idx) + "/" + path).rstrip("/"), v - if with_inner_nodes: - yield "", tree - else: - yield "", tree - - -def tree_flatten_with_names(tree): - """Populates tree_flatten with leaf names. - - This function populates output of tree_flatten with leaf names, using a - custom traversal that produces names is provided. The custom traversal does - NOT have to traverse tree in the same order as jax, as we take care of - automatically aligning jax' and custom traversals. - - Args: - tree: python tree. - - Returns: - A list of values with names: [(name, value), ...] - """ - vals, tree_def = jax.tree_flatten(tree) - - # "Fake" token tree that is use to track jax internal tree traversal and - # adjust our custom tree traversal to be compatible with it. - tokens = range(len(vals)) - token_tree = tree_def.unflatten(tokens) - val_names, perm = zip(*_traverse_with_names(token_tree)) - inv_perm = np.argsort(perm) - - # Custom traverasal should visit the same number of leaves. - assert len(val_names) == len(vals) - - return [(val_names[i], v) for i, v in zip(inv_perm, vals)], tree_def - - -def tree_map_with_names(f, tree, *rest): - """Like jax.tree_map but with a filter on the leaf path name. - - Args: - f: A function with first parameter `name` (path-like "a/b/c") and remaining - parameters values of `tree` and `*rest` corresponding to the given `name` - Should return a new value for parameter `name`. - tree: The tree of parameters `f` should be applied to. - *rest: more trees of the exact same structure. - - Returns: - A tree identical in structure to `tree` and `*rest` but with the leaves the - result of calling `f` on corresponding name/leaves in `tree` and `*rest`. 
- """ - names_and_vals, tree_def = tree_flatten_with_names(tree) - names, vals = zip(*names_and_vals) - rest_vals = [list(zip(*tree_flatten_with_names(t)[0]))[1] for t in rest] - vals = [f(*name_and_vals) for name_and_vals in zip(names, vals, *rest_vals)] - return tree_def.unflatten(vals) - - -def tree_map_with_regex(f, tree, regex_rules, not_f=lambda x: x, name=None): - """Apply jax-style tree_map based on regex rules. - - Args: - f: a function that is being applied to every variable. - tree: jax tree of arrays. - regex_rules: a list of tuples `(pattern, args)`, where `pattern` is a regex - which used for variable matching and `args` are positional arguments - passed to `f`. If some variable is not matched, we apply `not_f` transform - which is id by default. If multiple patterns match, then only the first - rule is applied. - not_f: optional function which is applied to variables that do not match any - pattern. - name: a name of transform for logging purposes. - - Returns: - a tree, transformed by `f` according to the given rules. - """ - def _f(vname, v): - for pattern, arg in regex_rules: - if re.fullmatch(pattern, vname): - if name and jax.process_index() == 0: - logging.info("Applying %s to %s with %s due to `%s`", - name, vname, arg, pattern) - return f(v, arg) - return not_f(v) - return tree_map_with_names(_f, tree) - - -def tree_get(tree, name): - """Get an entry of pytree by flattened key name, eg a/b/c, with nice error. - - Args: - tree: the pytree to be queried. - name: the path to extract from the tree, see below for examples. - - Returns: - A few examples: - tree = {'a': 1, 'b': {'c': 2, 'd': 3}} - tree_get(tree, 'a') == 1 - tree_get(tree, 'b/c') == 2 - tree_get(tree, 'b') == {'c': 2, 'd': 3} - """ - flattened = dict(_traverse_with_names(tree, with_inner_nodes=True)) - try: - return flattened[name] - except KeyError as e: - class Msg(str): # Reason: https://stackoverflow.com/a/70114007/2366315 - def __repr__(self): - return str(self) - msg = "\n".join([name, "Available keys:", *flattened, ""]) - # Turn into configdict to use its "did you mean?" error message! - msg = mlc.ConfigDict(flattened)._generate_did_you_mean_message(name, msg) # pylint: disable=protected-access - raise KeyError(Msg(msg)) from e - - -def recover_dtype(a): - """Numpy's `save` stores bfloat16 type as "void" type, so we recover it.""" - if hasattr(a, "dtype") and a.dtype.type is np.void: - assert a.itemsize == 2, "Unknown dtype!" - return a.view(jax.numpy.bfloat16) - else: - return a - - -# Checkpoint names encode tree structure, you can check out this colab for an -# example of how to recover tree structure from names: -# (internal link) -def save_checkpoint(checkpoint, path, step_copy=None, compressed=False): - """Util for checkpointing: saves jax pytree objects to the disk. - - Args: - checkpoint: arbitrary jax pytree to be saved. - path: a path to save the checkpoint. - step_copy: creates a copy of the checkpoint with `path-{step_copy}` name. - compressed: whether to use np.savez or np.savez_compressed, useful if saving - large buffers that are easily compressed (e.g. repeated or integers). - """ - names_and_vals, _ = tree_flatten_with_names(checkpoint) - io_buffer = io.BytesIO() - - if compressed: - np.savez_compressed(io_buffer, **{k: v for k, v in names_and_vals}) - else: - np.savez(io_buffer, **{k: v for k, v in names_and_vals}) - - # In order to be robust to interruptions we first save checkpoint to the - # temporal file and then move to actual path name. 
- path_tmp = path + "-TEMPORARY" - with gfile.GFile(path_tmp, "wb") as f: - f.write(io_buffer.getvalue()) - gfile.rename(path_tmp, path, overwrite=True) - - if step_copy is not None: - gfile.copy(path, f"{path}-{step_copy:09d}", overwrite=True) - - -def recover_tree(keys, values): - """Recovers a tree as a nested dict from flat names and values. - - This function is useful to analyze checkpoints that are saved by our programs - without need to access the exact source code of the experiment. In particular, - it can be used to extract an reuse various subtrees of the scheckpoint, e.g. - subtree of parameters. - - Args: - keys: a list of keys, where '/' is used as separator between nodes. - values: a list of leaf values. - - Returns: - A nested tree-like dict. - """ - tree = {} - sub_trees = collections.defaultdict(list) - for k, v in zip(keys, values): - if "/" not in k: - tree[k] = v - else: - k_left, k_right = k.split("/", 1) - sub_trees[k_left].append((k_right, v)) - for k, kv_pairs in sub_trees.items(): - k_subtree, v_subtree = zip(*kv_pairs) - tree[k] = recover_tree(k_subtree, v_subtree) - return tree - - -def create_learning_rate_schedule( - global_batch_size, total_steps, steps_per_epoch=None, - base=0.0, decay_type="stair", - scale_with_batchsize=False, - warmup_steps=0, cooldown_steps=0, - warmup_epochs=0, cooldown_epochs=0, - **kw): - """Creates learning rate schedule, see (internal link) - - Args: - global_batch_size: The global batch-size optionally used for scaling. - total_steps: The total number of steps to run. - steps_per_epoch: How many steps form an epoch. Needed only if anything is - passed in terms of epochs. - base: The starting learning-rate (without warmup). - decay_type: 'linear' or 'cosine', 'rsqrt', 'stair'. - scale_with_batchsize: Whether or not to scale lr automatically. - warmup_steps: how many steps to warm up for. - cooldown_steps: how many steps to cool down for. - warmup_epochs: how many epochs to warm up for. - cooldown_epochs: how many epochs to cool down for. - **kw: extra arguments specific to individual decay_types. - - Returns: - A function learning_rate(step): float -> {"learning_rate": float}. - """ - - # For convenience, convert {warmup,cooldown}_epochs to _steps. - assert bool(warmup_epochs) + bool(warmup_steps) < 2, "Only one!" - assert bool(cooldown_epochs) + bool(cooldown_steps) < 2, "Only one!" - if warmup_epochs: - warmup_steps = warmup_epochs * steps_per_epoch - # Early catch hard to backtrack errors due to warmup_steps >= total_steps, - # but let it run for 0 and 1 steps used to eval and debug runs. - assert (total_steps <= 1) or (warmup_steps < total_steps), ( - "warmup_steps is >= total_steps") - if cooldown_epochs: - cooldown_steps = cooldown_epochs * steps_per_epoch - - def step_fn(step): - """Step to learning rate function.""" - lr = base - - # This implements the linear scaling rule following - # Goyal et al. at arxiv.org/abs/1706.02677. - # The reference batch size in literature is 256, so we scale the lr to - # adjust to the literature lr when bach_size changes. - if scale_with_batchsize: - lr = lr * global_batch_size / 256.0 - - progress = (step - warmup_steps) / float(total_steps - warmup_steps) - progress = jnp.clip(progress, 0.0, 1.0) - if decay_type in ("linear", "polynomial"): - power = kw.get("power", 1) - zero = kw.get("end", kw.get("linear_end", 0)) - lr = zero + (lr - zero) * (1.0 - progress) ** power - elif decay_type == "cosine": - lr = lr * 0.5 * (1. 
+ jnp.cos(jnp.pi * progress)) - elif decay_type == "rsqrt": - timescale = kw.get("timescale", 10_000) - shift = timescale - warmup_steps - lr = jnp.where( - warmup_steps < step, lr / jnp.sqrt((step + shift) / timescale), lr) - elif decay_type == "stair": - i = jnp.searchsorted(jnp.array(kw.get("steps", [])), step + 1) - lr = lr * jnp.take(jnp.array([1.0] + list(kw.get("mults", []))), i) - else: - raise ValueError(f"Unknown lr type {decay_type}") - - if warmup_steps: - lr = lr * jnp.minimum(1., step / warmup_steps) - if cooldown_steps: - lr = lr * jnp.minimum(1., (total_steps - step) / cooldown_steps) - - return jnp.asarray(lr, dtype=jnp.float32) - - return step_fn - - -def mixup(rng, *things, p=0.1, fold_in=None, n=2, **more_things): - """Perform mixup https://arxiv.org/abs/1710.09412. - - Args: - rng: The random key to use. - *things: further arguments are the arrays to be mixed. - p: the beta/dirichlet concentration parameter, typically 0.1 or 0.2. - fold_in: One of None, "host", "device", or "sample". Whether to sample a - global mixing coefficient, one per host, one per device, or one per - example, respectively. The latter is usually a bad idea. - n: with how many other images an image is mixed. Default mixup is n=2. - **more_things: further kwargs are arrays to be mixed. See also (internal link) - for further experiments and investigations. - - Returns: - A new rng key. A list of mixed *things. A dict of mixed **more_things. - """ - rng, rng_m = jax.random.split(rng, 2) - if fold_in == "host": - rng_m = jax.random.fold_in(rng_m, jax.process_index()) - elif fold_in in ("device", "sample"): - rng_m = jax.random.fold_in(rng_m, jax.lax.axis_index("batch")) - ashape = (len(things[0]),) if fold_in == "sample" else (1,) - alpha = jax.random.dirichlet(rng_m, jnp.array([p]*n), ashape) - # Sort alpha values in decreasing order. This avoids destroying examples when - # the concentration parameter p is very small, due to Dirichlet's symmetry. - alpha = -jnp.sort(-alpha, axis=-1) - def mix(batch): - if batch is None: return None # For call-side convenience! - def mul(a, b): # B * BHWC -> B111 * BHWC - return b * jnp.expand_dims(a, tuple(range(1, b.ndim))) - return sum(mul(alpha[:, i], jnp.roll(batch, i, axis=0)) for i in range(n)) - return rng, map(mix, things), {k: mix(v) for k, v in more_things.items()} - - -def sync_all_hosts(): - """Makes sure all hosts are synced.""" - if jax.process_count() > 1: - x = jnp.ones([jax.local_device_count()]) - x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, "i"), "i")(x)) - assert x[0] == jax.device_count() - - -def check_and_compile_patterns(patterns): - """Validates and compiles a list of param-patterns. - - The validation consists of checking for common mistakes, currently only that - the pattern does not start with a slash, because unlike FLAX, our parameter - names don't start with a slash. - - Args: - patterns: a single (string) pattern (regex), or a list of patterns. - - Returns: - A list of compiled and verified regexes. 
- """ - if isinstance(patterns, str): - patterns = [patterns] - - assert isinstance(patterns, (list, tuple)), patterns - - def check_and_compile(pattern): - assert not pattern.startswith("/"), ( - f"Big vision parameter names never start with '/': '{pattern}") - return re.compile(pattern) - - return list(map(check_and_compile, patterns)) - - -def make_mask_trees(tree, patterns, *, log=None): - """Returns a boolean mask tree for every pattern (only first match).""" - compiled_patterns = check_and_compile_patterns(patterns) - - def matchfirst(name, _): - matches = [] - for pattern in compiled_patterns: - matches.append(not any(matches) and bool(pattern.fullmatch(name))) - if log is not None and True in matches and jax.process_index() == 0: - logging.info("%s: %s - matched by %s", log, name, - patterns[matches.index(True)]) - return np.array(matches) - - multimask = tree_map_with_names(matchfirst, tree) - return [ - jax.tree_map(lambda matches, i=idx: matches[i], multimask) - for idx in range(len(patterns)) - ] - - -@contextlib.contextmanager -def profile(name, ttl=3 * 365 * 24 * 3600): - sess = startstop_prof_at_steps(None, name=name, ttl=ttl) - yield - startstop_prof_at_steps(sess, name=name, ttl=ttl) - - -def startstop_prof(sess, step=None, first_step=0, - log_steps=1, surround=20, **kw): - """Runs the profiler for `surround` steps around the next `log_steps`.""" - first_log = first_step + log_steps - (first_step % log_steps) - # don't start before first! - start = max(first_log - surround//2, first_step + 1) - return startstop_prof_at_steps(sess, step, start, start + surround, **kw) - - -def startstop_prof_at_steps( - sess, step=None, first_step=None, last_step=None, - name="steps", ttl=3 * 365 * 24 * 3600): - del sess, step, first_step, last_step, name, ttl - pass # TODO: implement using `jax.profiler` API. Needs workdir. - - -# This is a very minimal variant for open-sourcing. Our internal code makes use -# of multiple internal logging tools instead. -class BigVisionMetricWriter: - """A class for logging metrics.""" - - def __init__(self, xid=-1, wid=-1, workdir=None): - self.step_start(0) - if jax.process_index() != 0: return # Only one host shall write stuff. - - self.pool = multiprocessing.pool.ThreadPool(1) # 1 is important here. - self.fname = None - if workdir: - if xid != -1 and wid != -1: - self.fname = os.path.join(workdir, - f"big_vision_{xid}_{wid}_metrics.txt") - else: - self.fname = os.path.join(workdir, "big_vision_metrics.txt") - - def step_start(self, step): - self.step = step - self.step_metrics = {} - - def measure(self, name, value): - """Logs the metric value.""" - if jax.process_index() != 0: return # Only one host shall write stuff. - - # Convenience for accepting scalar np/DeviceArrays, as well as N-d single - # scalars, like [[[123]]] or similar, avoiding silly mistakes. - value = np.array(value).squeeze() - - # If the value is a scalar, we keep it in mind to append a line to the logs. - # If it has any structure, we instead just log its shape. 
- value = float(value) if value.ndim == 0 else value.shape - - logging.info(f"\u001b[35m[{self.step}]\u001b[0m {name} = {value}") - logging.flush() - self.step_metrics[name] = value - - return value # Just for convenience - - def step_end(self): - """Ends a training step, write its full row.""" - if not self.step_metrics: return - - def write(metrics): - with gfile.GFile(self.fname, "a") as f: - f.write(json.dumps({"step": self.step, **metrics}) + "\n") - - if self.fname: - self.pool.apply(lambda: None) # Potentially wait for past writes. - self.pool.apply_async(write, (self.step_metrics,)) - - def close(self): - self.step_end() - if jax.process_index() == 0: - self.pool.close() - self.pool.join() - - -def maybe_cleanup_workdir(workdir, cleanup, info): - """Potentially removes workdirs at end of run for cleanup.""" - if not workdir: - return - - if not cleanup: - info("Logs/checkpoints are in %s", workdir) - elif jax.process_index() == 0: - gfile.rmtree(workdir) - try: # Only need this on the last work-unit, if already empty. - gfile.remove(os.path.join(workdir, "..")) - except tf.errors.OpError: - pass \ No newline at end of file diff --git a/jaxrl_m/vision/mae.py b/jaxrl_m/vision/mae.py new file mode 100644 index 0000000..03836e8 --- /dev/null +++ b/jaxrl_m/vision/mae.py @@ -0,0 +1,726 @@ +# Implementation modified from https://github.com/young-geng/m3ae_public/tree/master/m3ae + +from typing import Callable, Optional, Any + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp + +from functools import partial + +LayerNorm = partial(nn.LayerNorm, epsilon=1e-5) + +def mask_union(mask1, mask2): + return jnp.logical_or(mask1 > 0, mask2 > 0).astype(jnp.float32) + + +def mask_intersection(mask1, mask2): + return jnp.logical_and(mask1 > 0, mask2 > 0).astype(jnp.float32) + + +def mask_not(mask): + return 1.0 - mask + + +def mask_select(mask, this, other=None): + if other is None: + other = jnp.array(0, dtype=this.dtype) + if len(this.shape) == 3: + mask = jnp.expand_dims(mask, axis=-1) + return jnp.where(mask == 0.0, this, other) + + +def no_mask(x): + return jnp.zeros(x.shape[:2]) + + +def all_mask(x): + return jnp.ones(x.shape[:2]) + + +def cross_entropy_loss_and_accuracy(logits, tokens, valid=None): + if valid is None: + valid = all_mask(tokens) + valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-5) + + token_log_prob = jnp.squeeze( + jnp.take_along_axis( + jax.nn.log_softmax(logits, axis=-1), + jnp.expand_dims(tokens, -1), + axis=-1, + ), + -1, + ) + token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0)) + loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length) + correct = jnp.where( + valid > 0.0, + jnp.argmax(logits, axis=-1) == tokens, + jnp.array(False) + ) + accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length) + return loss, accuracy + + +def patch_mse_loss(patch_output, patch_target, valid=None): + if valid is None: + valid = all_mask(patch_target) + valid_ratio = jnp.sum(valid, axis=-1) / valid.shape[-1] + return jnp.mean( + jnp.mean( + jnp.where( + valid > 0.0, + jnp.mean(jnp.square(patch_target - patch_output), axis=-1), + jnp.array(0.0), + ), + axis=-1, + ) / valid_ratio + ) + + +def extract_patches(inputs, patch_size): + batch, height, width, channels = inputs.shape + height, width = height // patch_size, width // patch_size + x = jnp.reshape(inputs, (batch, height, patch_size, width, patch_size, channels)) + x = jnp.swapaxes(x, 2, 3) + x = jnp.reshape(x, (batch, height * width, patch_size**2 * 
channels)) + return x + + +def merge_patches(inputs, patch_size): + batch, length, _ = inputs.shape + height = width = int(length**0.5) + x = jnp.reshape(inputs, (batch, height, width, patch_size, patch_size, -1)) + x = jnp.swapaxes(x, 2, 3) + x = jnp.reshape(x, (batch, height * patch_size, width * patch_size, -1)) + return x + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + assert embed_dim % 2 == 0 + omega = jnp.arange(embed_dim // 2, dtype=jnp.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = jnp.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = jnp.sin(out) # (M, D/2) + emb_cos = jnp.cos(out) # (M, D/2) + + emb = jnp.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_1d_sincos_pos_embed(embed_dim, length): + return jnp.expand_dims( + get_1d_sincos_pos_embed_from_grid( + embed_dim, jnp.arange(length, dtype=jnp.float32) + ), + 0 + ) + + +def get_2d_sincos_pos_embed(embed_dim, length): + grid_size = int(length ** 0.5) + assert grid_size * grid_size == length + def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + emb = jnp.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + grid_h = jnp.arange(grid_size, dtype=jnp.float32) + grid_w = jnp.arange(grid_size, dtype=jnp.float32) + grid = jnp.meshgrid(grid_w, grid_h) # here w goes first + grid = jnp.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + return jnp.expand_dims(pos_embed, 0) + + +def index_sequence(x, ids): + return x[:, ids, ...] 
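For reference, a minimal shape sketch of the patch and positional-embedding helpers defined above (extract_patches, merge_patches, get_2d_sincos_pos_embed). The 32x32 image size, 8-pixel patch size, and 192-dim embedding are illustrative choices only, not values taken from this change:

# Illustrative shape check; example sizes are assumptions, not from the diff.
import jax.numpy as jnp
from jaxrl_m.vision.mae import extract_patches, merge_patches, get_2d_sincos_pos_embed

imgs = jnp.zeros((4, 32, 32, 3))                    # (batch, H, W, C)
patches = extract_patches(imgs, patch_size=8)       # (4, 16, 8*8*3) = (4, 16, 192)
assert merge_patches(patches, patch_size=8).shape == imgs.shape  # exact inverse
pos = get_2d_sincos_pos_embed(embed_dim=192, length=patches.shape[1])
assert pos.shape == (1, 16, 192)                    # broadcasts over the batch when added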
+ + +def random_masking(x, rng, keep_len, padding_mask=None): + batch, length, dim = x.shape + noise = jax.random.uniform(rng, (length,), dtype=jnp.float32) + ids_shuffle = jnp.argsort(noise, axis=0) + ids_restore = jnp.argsort(ids_shuffle, axis=0) + kept = index_sequence(x, ids_shuffle[:keep_len]) + mask = jnp.ones([batch, length], dtype=jnp.float32) + mask = mask.at[:, :keep_len].set(0.0) + mask = index_sequence(mask, ids_restore) + + if padding_mask is None: + return kept, mask, ids_restore + + padding_mask_kept = index_sequence(padding_mask, ids_shuffle[:keep_len]) + return kept, mask, ids_restore, padding_mask_kept + +from typing import Tuple +class PatchEmbed(nn.Module): + patch_size: int = 16 + in_chans: int = 3 + embed_dim: int = 768 + + def setup(self): + self.proj = nn.Conv(features=self.embed_dim, + kernel_size=(self.patch_size, self.patch_size), + strides=(self.patch_size, self.patch_size), + padding=0) + + + def __call__(self, x): + B, H, W, C = x.shape + x = self.proj(x) # B, H // patch_size, W // patch_size, self.embed_dim + x = jnp.reshape(x, (B, -1, self.embed_dim)) + return x + + +class MLP(nn.Module): + hidden_dim: int + output_dim: int + depth: int + input_norm: bool = True + + @nn.compact + def __call__(self, inputs): + x = inputs + if self.input_norm: + x = LayerNorm()(x) + + for i in range(self.depth): + y = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.xavier_uniform())(x) + y = nn.gelu(y) + y = LayerNorm()(y) + if i > 0: + x = x + y + else: + x = y + + x = nn.Dense(self.output_dim, kernel_init=nn.initializers.xavier_uniform())(x) + return x + + +class DropPath(nn.Module): + dropout_prob: float = 0.0 + deterministic: Optional[bool] = None + + @nn.compact + def __call__(self, input, deterministic=None): + deterministic = nn.merge_param( + "deterministic", self.deterministic, deterministic + ) + if deterministic: + return input + keep_prob = 1 - self.dropout_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) + rng = self.make_rng("drop_path") + random_tensor = keep_prob + jax.random.uniform(rng, shape, dtype=jnp.float32) + random_tensor = jnp.floor(random_tensor) + return jnp.divide(input, keep_prob) * random_tensor + + +class TransformerMLP(nn.Module): + dim: int = 256 + out_dim: int = 256 + dropout: float = 0.0 + kernel_init: Callable = nn.initializers.xavier_uniform() + + @nn.compact + def __call__(self, inputs, deterministic=None): + x = nn.Dense( + self.dim, kernel_init=self.kernel_init, name="fc1" + )(inputs) + + x = nn.gelu(x) + x = nn.Dropout(self.dropout)(x, deterministic) + x = nn.Dense( + self.out_dim, kernel_init=self.kernel_init, name="fc2" + )(x) + x = nn.Dropout(self.dropout)(x, deterministic) + + return x + + +class Attention(nn.Module): + """Modified from flax_models to support mask""" + + dim: int + num_heads: int = 8 + use_bias: bool = False + att_drop: float = 0 + proj_drop: float = 0 + kernel_init: Callable = nn.initializers.xavier_uniform() + deterministic: Optional[bool] = None + + @nn.compact + def __call__(self, inputs, deterministic=None, padding_mask=None): + deterministic = nn.merge_param( + "deterministic", self.deterministic, deterministic + ) + batch, n, channels = inputs.shape + scale = (self.dim // self.num_heads) ** -0.5 + qkv = nn.Dense( + self.dim * 3, + use_bias=self.use_bias, + kernel_init=self.kernel_init, + name='qkv' + )(inputs) + qkv = jnp.reshape( + qkv, (batch, n, 3, self.num_heads, channels // self.num_heads) + ) + qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4)) + q, k, v = qkv[0], qkv[1], qkv[2] + + attention = 
(q @ jnp.swapaxes(k, -2, -1)) * scale + + if padding_mask is not None: + padding_mask = jnp.expand_dims(jnp.expand_dims(padding_mask, 1), 1) + padding_mask = jnp.broadcast_to(padding_mask, attention.shape) + attention = jnp.where(padding_mask > 0, jnp.array(-1e7), attention) + + attention = nn.softmax(attention, axis=-1) + self.sow('intermediates', 'attention', attention) + attention = nn.Dropout(self.att_drop)(attention, deterministic) + + x = (attention @ v).swapaxes(1, 2).reshape(batch, n, channels) + x = nn.Dense( + self.dim, kernel_init=nn.initializers.xavier_uniform(), + name='proj' + )(x) + x = nn.Dropout(self.proj_drop)(x, deterministic) + + return x + + +class Block(nn.Module): + emb_dim: int = 256 + num_heads: int = 8 + mlp_ratio: int = 4 + att_drop: float = 0.0 + drop: float = 0.0 + drop_path: float = 0.0 + + @nn.compact + def __call__(self, inputs, deterministic=False, padding_mask=None): + x = LayerNorm(name='norm1')(inputs) + x = Attention( + self.emb_dim, self.num_heads, True, self.att_drop, self.drop, + name='attn' + )(x, deterministic, padding_mask) + x = DropPath(self.drop_path)(x, deterministic) + inputs = inputs + x + + x = LayerNorm(name='norm2')(inputs) + x = TransformerMLP( + self.emb_dim * self.mlp_ratio, self.emb_dim, self.drop, + name='mlp' + )(x, deterministic) + x = DropPath(self.drop_path)(x, deterministic) + return inputs + x + + +class Transformer(nn.Module): + emb_dim: int = 1024 + depth: int = 24 + att_drop: float = 0 + drop: float = 0 + drop_path: float = 0 + num_heads: int = 16 + mlp_ratio: int = 4 + + @nn.compact + def __call__(self, x, deterministic=False, padding_mask=None): + for n in range(self.depth): + x = Block( + self.emb_dim, + self.num_heads, + self.mlp_ratio, + self.att_drop, + self.drop, + self.drop_path, + name=f'blocks_{n}' + )(x, deterministic, padding_mask) + print(x[0, 0, :5]) + + x = LayerNorm(name='norm')(x) + return x + +class MaskedAutoencoder(nn.Module): + emb_dim: int = 1024 + dec_emb_dim: int = 512 + depth: int = 24 + dec_depth: int = 8 + num_heads: int = 16 + dec_num_heads: int = 16 + mlp_ratio: int = 4 + + output_head_depth: int = 0 + att_drop: float = 0.0 + drop: float = 0.0 + drop_path: float = 0.0 + + image_mask_ratio: float = 0.75 + use_type_embedding: bool = True + image_output_dim: int = 768 + + # @staticmethod + # @nn.nowrap + # def get_default_config(updates=None): + # config = ConfigDict() + # config.model_type = config_dict.placeholder(str) + # config.emb_dim = 1024 + # config.dec_emb_dim = 512 + # config.depth = 24 + # config.dec_depth = 8 + # config.num_heads = 16 + # config.dec_num_heads = 16 + # config.mlp_ratio = 4 + + # config.output_head_depth = 0 + # # Dropout not applied in original MAE implementation. 
+ # config.att_drop = 0.0 + # config.drop = 0.0 + # config.drop_path = 0.0 + + # # Tuned default mask ratio + # config.image_mask_ratio = 0.75 + + # config.use_type_embedding = True + + # if updates is not None: + # config.update(ConfigDict(updates).copy_and_resolve_references()) + + # if config.model_type is not None: + # get_transformer_by_config(config.model_type, config) + + # return config + + @nn.nowrap + def rng_keys(self): + return ('params', 'noise', 'drop_path', 'dropout') + + @nn.nowrap + def no_decay_list(self): + # model specific no decay list + no_decay = [ + 'cls_token', 'encoder_image_type_embedding', 'image_mask_embedding', + 'bias', + ] + return no_decay + + def setup(self): + self.patch_embed = PatchEmbed(embed_dim=self.emb_dim) + # self.image_embedding = nn.Dense( + # self.emb_dim, + # kernel_init=nn.initializers.xavier_uniform() + # ) + # Type embeddings + if self.use_type_embedding: + self.encoder_image_type_embedding = self.param( + "encoder_image_type_embedding", + nn.initializers.normal(stddev=0.02, dtype=jnp.float32), + (1, 1, self.emb_dim), + ) + self.decoder_image_type_embedding = self.param( + "decoder_image_type_embedding", + nn.initializers.normal(stddev=0.02, dtype=jnp.float32), + (1, 1, self.dec_emb_dim), + ) + + # CLS and masks + self.cls_token = self.param( + "cls_token", + nn.initializers.normal(stddev=0.02, dtype=jnp.float32), + (1, 1, self.emb_dim), + ) + self.image_mask_embedding = self.param( + "mask_token", + nn.initializers.normal(stddev=0.02, dtype=jnp.float32), + (1, 1, self.dec_emb_dim), + ) + + self.encoder = Transformer( + emb_dim=self.emb_dim, + depth=self.depth, + att_drop=self.att_drop, + drop=self.drop, + drop_path=self.drop_path, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + ) + + self.decoder = Transformer( + emb_dim=self.dec_emb_dim, + depth=self.dec_depth, + att_drop=self.att_drop, + drop=self.drop, + drop_path=self.drop_path, + num_heads=self.dec_num_heads, + mlp_ratio=self.mlp_ratio, + ) + + self.decoder_input_projection = nn.Dense( + self.dec_emb_dim, + kernel_init=nn.initializers.xavier_uniform() + ) + + self.decoder_image_output = MLP( + self.dec_emb_dim, + self.image_output_dim, + self.output_head_depth, + input_norm=self.output_head_depth > 0, + name='decoder_image_output', + ) + + def get_type_embedding(self, name): + if self.use_type_embedding: + return { + 'encoder_image_type_embedding': self.encoder_image_type_embedding, + 'decoder_image_type_embedding': self.decoder_image_type_embedding, + }[name] + else: + return 0.0 + + def forward_representation(self, image, deterministic=False): + batch_size = image.shape[0] + # image_x = self.image_embedding(image) + image_x = self.patch_embed(image) + print(image_x.shape) + image_x = ( + image_x + + get_2d_sincos_pos_embed(self.emb_dim, image_x.shape[1]) + + self.get_type_embedding('encoder_image_type_embedding') + ) + cls_token = jnp.broadcast_to( + self.cls_token, (batch_size, 1, self.emb_dim) + ) + x = jnp.concatenate([cls_token, image_x], axis=1) + x = self.encoder(x, deterministic) + return x + + def forward_encoder(self, image, deterministic=False): + batch_size = image.shape[0] + # image_x = self.image_embedding(image) + image_x = self.patch_embed(image) + print(image_x.shape) + image_keep_length = int(image_x.shape[1] * (1.0 - self.image_mask_ratio)) + + image_x = ( + image_x + + get_2d_sincos_pos_embed(self.emb_dim, image_x.shape[1]) + + self.get_type_embedding('encoder_image_type_embedding') + ) + image_x, image_mask, image_ids_restore = random_masking( + 
image_x, self.make_rng("noise"), image_keep_length + ) + cls_token = jnp.broadcast_to( + self.cls_token, (batch_size, 1, self.emb_dim) + ) + x = jnp.concatenate([cls_token, image_x], axis=1) + x = self.encoder(x, deterministic) + + return x, image_mask, image_ids_restore + + def forward_decoder(self, x, image_ids_restore, deterministic=False): + batch_size = x.shape[0] + image_keep_length = int(image_ids_restore.shape[0] * (1.0 - self.image_mask_ratio)) + x = self.decoder_input_projection(x) + encoder_cls = x[:, :1, :] + image_x = x[:, 1:, :] + + masked_image_x = jnp.broadcast_to( + self.image_mask_embedding, + ( + batch_size, + image_ids_restore.shape[0] - image_keep_length, + self.dec_emb_dim, + ), + ) + + image_x = index_sequence( + jnp.concatenate([image_x, masked_image_x], axis=1), image_ids_restore + ) + + image_x = ( + image_x + + get_2d_sincos_pos_embed(self.dec_emb_dim, image_ids_restore.shape[0]) + + self.get_type_embedding('decoder_image_type_embedding') + ) + + x = jnp.concatenate([encoder_cls, image_x], axis=1) + x = self.decoder(x, deterministic) + image_x = x[:, 1:, :] + image_output = self.decoder_image_output(image_x) + + return image_output + + def __call__(self, image, deterministic=False): + return self.forward_representation(image, deterministic) + + def encode_and_decode(self, image, deterministic=False): + x, image_mask, image_ids_restore = self.forward_encoder(image, deterministic) + image_output = self.forward_decoder(x, image_ids_restore, deterministic) + return image_output, image_mask, x + +class LinearCLS(nn.Module): + num_classes: int = 1000 + pool: bool = False + + @nn.compact + def __call__(self, x, train=True): + if self.pool: + x = x[:, 1:, :].mean(axis=1) # global pool without cls token + else: + x = x[:, 0] + norm = partial( + nn.BatchNorm, + use_running_average=not train, + momentum=0.9, + epsilon=1e-5, + use_scale=False, + use_bias=False, + ) + x = norm(name="bn")(x) + logits = nn.Dense(self.num_classes)(x) + return logits + + +class ViTClassifier(nn.Module): + base_model: nn.Module + num_classes: int + global_pool: bool = False + stop_gradient: bool = False + + @nn.nowrap + def rng_keys(self): + return ('params', 'noise', 'drop_path') + + @nn.compact + def __call__(self, x, deterministic=False, features=False): + x = self.base_model.forward_representation(x, deterministic=deterministic) + if self.global_pool: + x = x[:, 1:, :].mean(axis=1) # global pool without cls token + else: + x = x[:, 0] + + z = x + + x = LayerNorm()(x) + x = nn.Dense(self.num_classes)(x) + logits = x + log_probs = nn.log_softmax(x, axis=1) + + if features: + return log_probs, logits, z + else: + return logits + + +def map_to_jax(pytorch_key): + if 'blocks' in pytorch_key[0]: + if 'decoder' in pytorch_key[0]: + jax_key = ('decoder', f'blocks_{pytorch_key[1]}') + pytorch_key[2:] + else: + jax_key = ('encoder', f'blocks_{pytorch_key[1]}') + pytorch_key[2:] + else: + if pytorch_key[0] == 'decoder_pred' : + jax_key = ('decoder_image_output', 'Dense_0', *pytorch_key[1:]) +# elif 'patch_embed' == pytorch_key[0]: +# jax_key = ('image_embedding', *pytorch_key[1:]) + elif 'decoder_embed' == pytorch_key[0]: + jax_key = ('decoder_input_projection', *pytorch_key[1:]) + elif 'decoder' in pytorch_key[0]: + jax_key = ('decoder', pytorch_key[0].partition('_')[2], *pytorch_key[1:]) + else: + if pytorch_key[0] in ['cls_token', 'mask_token', 'patch_embed']: + jax_key = pytorch_key + else: + jax_key = ('encoder', *pytorch_key) + + + if jax_key[-1] == "weight": + if 'norm' in jax_key[-2]: + jax_key = 
jax_key[:-1] + ("scale",) + else: + jax_key = jax_key[:-1] + ("kernel",) + return jax_key + + +def pytorch_statedict_to_jax(state_dict): + pytorch_dict = {tuple(k.split('.')): v for k, v in state_dict['model'].items()} + + jax_flat_dict = {map_to_jax(k): jnp.asarray(v) for k, v in pytorch_dict.items()} + for k in jax_flat_dict: + if k[-1] == 'kernel': + kernel = jax_flat_dict[k] + if kernel.ndim > 2: # Conv + kernel = jnp.transpose(kernel, (2, 3, 1, 0)) + else: + kernel = jnp.transpose(kernel, (1, 0)) + jax_flat_dict[k] = kernel + return flax.traverse_util.unflatten_dict(jax_flat_dict) + + +transformer_config_dicts = { + 'small': { + 'emb_dim': 384, + 'dec_emb_dim': 512, + 'depth': 12, + 'dec_depth': 8, + 'num_heads': 6, + 'dec_num_heads': 16, + 'mlp_ratio': 4, + }, + + 'base': { + 'emb_dim': 768, + 'dec_emb_dim': 512, + 'depth': 12, + 'dec_depth': 8, + 'num_heads': 12, + 'dec_num_heads': 16, + 'mlp_ratio': 4, + }, + + 'large': { + 'emb_dim': 1024, + 'dec_emb_dim': 512, + 'depth': 24, + 'dec_depth': 8, + 'num_heads': 16, + 'dec_num_heads': 16, + 'mlp_ratio': 4, + }, + + 'huge': { + 'emb_dim': 1280, + 'dec_emb_dim': 512, + 'depth': 32, + 'dec_depth': 8, + 'num_heads': 16, + 'dec_num_heads': 16, + 'mlp_ratio': 4, + }, + + 'debug': { + 'emb_dim': 1024, + 'dec_emb_dim': 512, + 'depth': 2, + 'dec_depth': 2, + 'num_heads': 16, + 'dec_num_heads': 16, + 'mlp_ratio': 4, + } +} + +mae_model_configs = { + f'mae_{size}': partial(MaskedAutoencoder, **config) + for size, config in transformer_config_dicts.items() +} \ No newline at end of file diff --git a/jaxrl_m/vision/pretrained_encoder.py b/jaxrl_m/vision/pretrained_encoder.py new file mode 100644 index 0000000..fb54954 --- /dev/null +++ b/jaxrl_m/vision/pretrained_encoder.py @@ -0,0 +1,106 @@ +from jaxrl_m.vision import resnet_v1 +import flax.linen as nn +import jax +import jax.numpy as jnp +import functools as ft + +import flax.linen as nn +from flax.core import freeze, unfreeze +import pickle +import flax.training.checkpoints as checkpoints +from flax.traverse_util import flatten_dict, unflatten_dict + +def preprocess_observations(obs, + normalize_imagenet=False, + resize=False, + final_shape=(224, 224), + center_crop=False, + pre_crop_shape=(256, 256)): + if resize: + if obs.shape[-3] != final_shape[0] or obs.shape[-2] != final_shape[1]: # Already resized + resize_shape = pre_crop_shape if center_crop else final_shape + if obs.shape[-3] != resize_shape[0] or obs.shape[-2] != resize_shape[1]: + print('Resizing to %s' % str(resize_shape)) + obs = jax.image.resize(obs, (*obs.shape[:-3], *resize_shape, 3), method='bilinear') + + if center_crop: + start_y, start_x = (pre_crop_shape[0] - final_shape[0]) // 2, (pre_crop_shape[1] - final_shape[1]) // 2 + obs = obs[..., start_y:start_y + final_shape[0], start_x:start_x + final_shape[1], :] + print('Cropping to %s' % str(obs.shape)) + + if normalize_imagenet: + obs = obs / 255.0 + obs = obs - jnp.array([0.485, 0.456, 0.406]) + obs = obs / jnp.array([0.229, 0.224, 0.225]) + + return obs + +class ResizingEncoder(nn.Module): + encoder: nn.Module + normalize_imagenet: bool = False + + resize: bool = True + final_shape: tuple = (224, 224) + center_crop: bool = False + pre_crop_shape: tuple = (256, 256) + + freeze_encoder: bool = False + default_kwargs: dict = None + + @nn.compact + def __call__(self, observations, **kwargs): + no_batch_dim = len(observations.shape) == 3 + if no_batch_dim: + print('Adding batch dimension') + observations = jnp.expand_dims(observations, 0) + observations = 
preprocess_observations(observations, self.normalize_imagenet, self.resize, self.final_shape, self.center_crop, self.pre_crop_shape) + if self.default_kwargs is not None: + kwargs = {**self.default_kwargs, **kwargs} + output = self.encoder(observations, **kwargs) + if self.freeze_encoder: + output = jax.lax.stop_gradient(output) + + if no_batch_dim: + print('Removing batch dimension') + output = jnp.squeeze(output, 0) + return output + +def merge_dicts(new_dict, restore_from, allow_extra=True, allow_missing=True): + new_dict_flat = flatten_dict(new_dict) + restore_from_flat = flatten_dict(restore_from) + + missing_from_new = set(restore_from_flat.keys()) - set(new_dict_flat.keys()) + missing_from_restore = set(new_dict_flat.keys()) - set(restore_from_flat.keys()) + if not allow_extra: + assert len(missing_from_new) == 0, 'Keys missing from new dict: %s' % str(missing_from_new) + elif len(missing_from_new) > 0: + print('Keys missing from target dict: %s' % str(missing_from_new)) + + if not allow_missing: + assert len(missing_from_restore) == 0, 'Keys missing from restore dict: %s' % str(missing_from_restore) + elif len(missing_from_restore) > 0: + print('Keys missing from restore dict: %s' % str(missing_from_restore)) + + new_dict_flat.update(restore_from_flat) + return unflatten_dict(new_dict_flat) + +def load_pretrained_params(pretrained_params, pretrained_extra_variables, params, extra_variables, prefix_key='encoder/encoder'): + params, extra_variables = unfreeze(params), unfreeze(extra_variables) + + sp = params + prefix_list = prefix_key.split('/') if prefix_key != '' else [] + for k in prefix_list: + sp = sp[k] + assert sp.keys() == pretrained_params.keys(), (sp.keys(), pretrained_params.keys()) + merge_dicts(sp, pretrained_params, True, True) # Just checking + sp.update(pretrained_params) + + for k in pretrained_extra_variables: + sp = extra_variables[k] + for kk in prefix_list: + sp = sp[kk] + assert sp.keys() == pretrained_extra_variables[k].keys(), (sp.keys(), pretrained_extra_variables[k].keys()) + sp.update(pretrained_extra_variables[k]) + + return freeze(params), freeze(extra_variables) + diff --git a/jaxrl_m/vision/bigvision_resnetv2.py b/jaxrl_m/vision/resnet_v2.py similarity index 56% rename from jaxrl_m/vision/bigvision_resnetv2.py rename to jaxrl_m/vision/resnet_v2.py index d019efc..612af4c 100644 --- a/jaxrl_m/vision/bigvision_resnetv2.py +++ b/jaxrl_m/vision/resnet_v2.py @@ -21,11 +21,6 @@ import re from typing import Optional, Sequence, Union -import jaxrl_m.vision.bigvision_utils as u -# from big_vision import utils as u -import jaxrl_m.vision.bigvision_common as common -# from big_vision.models import bit -# from big_vision.models import common import flax.linen as nn import jax.numpy as jnp import jax @@ -179,106 +174,8 @@ def get_block_desc(depth): 200: [3, 24, 36, 3] }.get(depth, depth) -def load(init_params, init_file, model_cfg, dont_load=()): - """Loads the TF-dumped NumPy or big_vision checkpoint. - - Args: - init_params: random init params from which the new head is taken. - init_file: comes from `config.model_init`, can either be an absolute - path (ie starts with /) to the checkpoint, or a string like - "L-imagenet2012" describing one of the variants from the paper. - model_cfg: the model configuration. - dont_load: list of param names to be reset to init. - - Returns: - The loaded parameters. - """ - - # Support for vanity model names from the paper. 
- vanity = { - 'FunMatch-224px-i1k82.8': 'gs://bit_models/distill/R50x1_224.npz', - 'FunMatch-160px-i1k80.5': 'gs://bit_models/distill/R50x1_160.npz', - } - if init_file[0] in ('L', 'M', 'S'): # The models from the original paper. - # Supported names are of the following type: - # - 'M' or 'S': the original "upstream" model without fine-tuning. - # - 'M-ILSVRC2012': i21k model fine-tuned on i1k. - # - 'M-run0-caltech101': i21k model fine-tuned on VTAB's caltech101. - # each VTAB fine-tuning was run 3x, so there's run0, run1, run2. - if '-' in init_file: - up, down = init_file[0], init_file[1:] - else: - up, down = init_file, '' - down = {'-imagenet2012': '-ILSVRC2012'}.get(down, down) # normalize - fname = f'BiT-{up}-R{model_cfg.depth}x{model_cfg.width}{down}.npz' - fname = f'gs://bit_models/{fname}' - else: - fname = vanity.get(init_file, init_file) - - params = u.load_params(None, fname) - params = maybe_convert_big_transfer_format(params) - return common.merge_params(params, init_params, dont_load) - - -def maybe_convert_big_transfer_format(params_tf): - """If the checkpoint comes from legacy codebase, convert it.""" - - # Only do anything at all if we recognize the format. - if 'resnet' not in params_tf: - return params_tf - - # For ease of processing and backwards compatibility, flatten again: - params_tf = dict(u.tree_flatten_with_names(params_tf)[0]) - - # Works around some files containing weird naming of variables: - for k in list(params_tf): - k2 = re.sub('/standardized_conv2d_\\d+/', '/standardized_conv2d/', k) - if k2 != k: - params_tf[k2] = params_tf[k] - del params_tf[k] - - params = { - 'root_block': {'conv_root': {'kernel': params_tf[ - 'resnet/root_block/standardized_conv2d/kernel']}}, - 'norm-pre-head': { - 'bias': params_tf['resnet/group_norm/beta'][None, None, None], - 'scale': params_tf['resnet/group_norm/gamma'][None, None, None], - }, - 'head': { - 'kernel': params_tf['resnet/head/conv2d/kernel'][0, 0], - 'bias': params_tf['resnet/head/conv2d/bias'], - } - } - - for block in ('block1', 'block2', 'block3', 'block4'): - params[block] = {} - units = set([re.findall(r'unit\d+', p)[0] for p in params_tf.keys() - if p.find(block) >= 0]) - for unit in units: - params[block][unit] = {} - for i, group in enumerate('abc', 1): - params[block][unit][f'conv{i}'] = { - 'kernel': params_tf[f'resnet/{block}/{unit}/{group}/standardized_conv2d/kernel'] # pylint: disable=line-too-long - } - params[block][unit][f'gn{i}'] = { - 'bias': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/beta'][None, None, None], # pylint: disable=line-too-long - 'scale': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/gamma'][None, None, None], # pylint: disable=line-too-long - } - - projs = [p for p in params_tf.keys() - if p.find(f'{block}/{unit}/a/proj') >= 0] - assert len(projs) <= 1 - if projs: - params[block][unit]['conv_proj'] = { - 'kernel': params_tf[projs[0]] - } - - return params - import functools as ft resnetv2_configs = { 'resnetv2-26-1': ft.partial(Model, num_classes=None, depth=26), - 'resnetv2-26-1-128': ft.partial(Model, num_classes=None, depth=26, image_shape=(128, 128)), 'resnetv2-50-1': ft.partial(Model, num_classes=None, depth=50), - 'resnetv2-50-1-128': ft.partial(Model, num_classes=None, depth=50, image_shape=(128, 128)), } \ No newline at end of file diff --git a/jaxrl_m/vision/vit.py b/jaxrl_m/vision/vit.py new file mode 100644 index 0000000..8f70eb4 --- /dev/null +++ b/jaxrl_m/vision/vit.py @@ -0,0 +1,355 @@ +# Copyright 2023 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Optional, Tuple, Type, Dict + +import flax.linen as nn +import jax.numpy as jnp + +from . import vit_resnet as models_resnet +from functools import partial + +Array = Any +PRNGKey = Any +Shape = Tuple[int] +Dtype = Any + + +class IdentityLayer(nn.Module): + """Identity layer, convenient for giving a name to an array.""" + + @nn.compact + def __call__(self, x): + return x + + +class AddPositionEmbs(nn.Module): + """Adds learned positional embeddings to the inputs. + + Attributes: + posemb_init: positional embedding initializer. + """ + + posemb_init: Callable[[PRNGKey, Shape, Dtype], Array] + + @nn.compact + def __call__(self, inputs): + """Applies the AddPositionEmbs module. + + Args: + inputs: Inputs to the layer. + + Returns: + Output tensor with shape `(bs, timesteps, in_dim)`. + """ + # inputs.shape is (batch_size, seq_len, emb_dim). + assert inputs.ndim == 3, ('Number of dimensions should be 3,' + ' but it is: %d' % inputs.ndim) + pos_emb_shape = (1, inputs.shape[1], inputs.shape[2]) + pe = self.param('pos_embedding', self.posemb_init, pos_emb_shape) + return inputs + pe + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block.""" + + mlp_dim: int + dtype: Dtype = jnp.float32 + out_dim: Optional[int] = None + dropout_rate: float = 0.1 + kernel_init: Callable[[PRNGKey, Shape, Dtype], + Array] = nn.initializers.xavier_uniform() + bias_init: Callable[[PRNGKey, Shape, Dtype], + Array] = nn.initializers.normal(stddev=1e-6) + + @nn.compact + def __call__(self, inputs, *, deterministic): + """Applies Transformer MlpBlock module.""" + actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim + x = nn.Dense( + features=self.mlp_dim, + dtype=self.dtype, + kernel_init=self.kernel_init, + bias_init=self.bias_init)( # pytype: disable=wrong-arg-types + inputs) + x = nn.gelu(x) + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) + output = nn.Dense( + features=actual_out_dim, + dtype=self.dtype, + kernel_init=self.kernel_init, + bias_init=self.bias_init)( # pytype: disable=wrong-arg-types + x) + output = nn.Dropout( + rate=self.dropout_rate)( + output, deterministic=deterministic) + return output + + +class Encoder1DBlock(nn.Module): + """Transformer encoder layer. + + Attributes: + inputs: input data. + mlp_dim: dimension of the mlp on top of attention block. + dtype: the dtype of the computation (default: float32). + dropout_rate: dropout rate. + attention_dropout_rate: dropout for attention heads. + deterministic: bool, deterministic or not (to apply dropout). + num_heads: Number of heads in nn.MultiHeadDotProductAttention + """ + + mlp_dim: int + num_heads: int + dtype: Dtype = jnp.float32 + dropout_rate: float = 0.1 + attention_dropout_rate: float = 0.1 + + @nn.compact + def __call__(self, inputs, *, deterministic): + """Applies Encoder1DBlock module. + + Args: + inputs: Inputs to the layer. + deterministic: Dropout will not be applied when set to true. 
+ + Returns: + output after transformer encoder block. + """ + + # Attention block. + assert inputs.ndim == 3, f'Expected (batch, seq, hidden) got {inputs.shape}' + x = nn.LayerNorm(dtype=self.dtype)(inputs) + x = nn.MultiHeadDotProductAttention( + dtype=self.dtype, + kernel_init=nn.initializers.xavier_uniform(), + broadcast_dropout=False, + deterministic=deterministic, + dropout_rate=self.attention_dropout_rate, + num_heads=self.num_heads)( + x, x) + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) + x = x + inputs + + # MLP block. + y = nn.LayerNorm(dtype=self.dtype)(x) + y = MlpBlock( + mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)( + y, deterministic=deterministic) + + return x + y + + +class Encoder(nn.Module): + """Transformer Model Encoder for sequence to sequence translation. + + Attributes: + num_layers: number of layers + mlp_dim: dimension of the mlp on top of attention block + num_heads: Number of heads in nn.MultiHeadDotProductAttention + dropout_rate: dropout rate. + attention_dropout_rate: dropout rate in self attention. + """ + + num_layers: int + mlp_dim: int + num_heads: int + dropout_rate: float = 0.1 + attention_dropout_rate: float = 0.1 + add_position_embedding: bool = True + + @nn.compact + def __call__(self, x, *, train): + """Applies Transformer model on the inputs. + + Args: + x: Inputs to the layer. + train: Set to `True` when training. + + Returns: + output of a transformer encoder. + """ + assert x.ndim == 3 # (batch, len, emb) + + if self.add_position_embedding: + x = AddPositionEmbs( + posemb_init=nn.initializers.normal(stddev=0.02), # from BERT. + name='posembed_input')( + x) + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) + + # Input Encoder + for lyr in range(self.num_layers): + x = Encoder1DBlock( + mlp_dim=self.mlp_dim, + dropout_rate=self.dropout_rate, + attention_dropout_rate=self.attention_dropout_rate, + name=f'encoderblock_{lyr}', + num_heads=self.num_heads)( + x, deterministic=not train) + encoded = nn.LayerNorm(name='encoder_norm')(x) + + return encoded + + +class VisionTransformer(nn.Module): + """VisionTransformer.""" + + patch_size: int + transformer: dict + hidden_size: int + resnet: dict = None + representation_size: Optional[int] = None + classifier: str = 'token' + head_bias_init: float = 0. + encoder: Type[nn.Module] = Encoder + model_name: Optional[str] = None + num_classes: int = None + + @nn.compact + def __call__(self, inputs, *, train): + + x = inputs + # (Possibly partial) ResNet root. + if self.resnet is not None: + width = int(64 * self.resnet['width_factor']) + + # Root block. + x = models_resnet.StdConv( + features=width, + kernel_size=(7, 7), + strides=(2, 2), + use_bias=False, + name='conv_root')( + x) + x = nn.GroupNorm(name='gn_root')(x) + x = nn.relu(x) + x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding='SAME') + + # ResNet stages. + if self.resnet['num_layers']: + x = models_resnet.ResNetStage( + block_size=self.resnet['num_layers'][0], + nout=width, + first_stride=(1, 1), + name='block1')( + x) + for i, block_size in enumerate(self.resnet['num_layers'][1:], 1): + x = models_resnet.ResNetStage( + block_size=block_size, + nout=width * 2**i, + first_stride=(2, 2), + name=f'block{i + 1}')( + x) + + n, h, w, c = x.shape + + # We can merge s2d+emb into a single conv; it's the same. 
+ x = nn.Conv( + features=self.hidden_size, + kernel_size=(self.patch_size, self.patch_size), + strides=(self.patch_size, self.patch_size), + padding='VALID', + name='embedding')( + x) + + # Here, x is a grid of embeddings. + + # (Possibly partial) Transformer. + if self.transformer is not None: + n, h, w, c = x.shape + x = jnp.reshape(x, [n, h * w, c]) + + # If we want to add a class token, add it here. + if self.classifier in ['token', 'token_unpooled']: + cls = self.param('cls', nn.initializers.zeros, (1, 1, c)) + cls = jnp.tile(cls, [n, 1, 1]) + x = jnp.concatenate([cls, x], axis=1) + + x = self.encoder(name='Transformer', **self.transformer)(x, train=train) + + if self.classifier == 'token': + x = x[:, 0] + elif self.classifier == 'gap': + x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2) + elif self.classifier in ['unpooled', 'token_unpooled']: + pass + else: + raise ValueError(f'Invalid classifier={self.classifier}') + + if self.representation_size is not None: + x = nn.Dense(features=self.representation_size, name='pre_logits')(x) + x = nn.tanh(x) + else: + x = IdentityLayer(name='pre_logits')(x) + + if self.num_classes: + x = nn.Dense( + features=self.num_classes, + name='head', + kernel_init=nn.initializers.zeros, + bias_init=nn.initializers.constant(self.head_bias_init))(x) + return x + +core_configs = { + 'ViT-Ti/16': { + 'patch_size': 16, + 'hidden_size': 192, + 'transformer': { + 'mlp_dim': 768, + 'num_heads': 3, + 'num_layers': 12, + 'attention_dropout_rate': 0.0, + 'dropout_rate': 0.0, + }, + }, + 'ViT-S/16': { + 'patch_size': 16, + 'hidden_size': 384, + 'transformer': { + 'mlp_dim': 1536, + 'num_heads': 6, + 'num_layers': 12, + 'attention_dropout_rate': 0.0, + 'dropout_rate': 0.0, + }, + }, + 'ViT-B/16': { + 'patch_size': 16, + 'hidden_size': 768, + 'transformer': { + 'mlp_dim': 3072, + 'num_heads': 12, + 'num_layers': 12, + 'attention_dropout_rate': 0.0, + 'dropout_rate': 0.0, + }, + }, + 'ViT-L/16': { + 'patch_size': 16, + 'hidden_size': 1024, + 'transformer': { + 'mlp_dim': 4096, + 'num_heads': 16, + 'num_layers': 24, + 'attention_dropout_rate': 0.0, + 'dropout_rate': 0.1, + } + }, +} + +vit_configs = { + k: partial(VisionTransformer, **config) + for k, config in core_configs.items() +} \ No newline at end of file diff --git a/jaxrl_m/vision/vit_resnet.py b/jaxrl_m/vision/vit_resnet.py new file mode 100644 index 0000000..61afdde --- /dev/null +++ b/jaxrl_m/vision/vit_resnet.py @@ -0,0 +1,106 @@ +# Copyright 2023 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Callable, Sequence, TypeVar + +from flax import linen as nn +import jax.numpy as jnp + +T = TypeVar('T') + + +def weight_standardize(w, axis, eps): + """Subtracts mean and divides by standard deviation.""" + w = w - jnp.mean(w, axis=axis) + w = w / (jnp.std(w, axis=axis) + eps) + return w + + +class StdConv(nn.Conv): + """Convolution with weight standardization.""" + + def param(self, + name: str, + init_fn: Callable[..., T], + *init_args) -> T: + param = super().param(name, init_fn, *init_args) + if name == 'kernel': + param = weight_standardize(param, axis=[0, 1, 2], eps=1e-5) + return param + + +class ResidualUnit(nn.Module): + """Bottleneck ResNet block.""" + + features: int + strides: Sequence[int] = (1, 1) + + @nn.compact + def __call__(self, x): + needs_projection = ( + x.shape[-1] != self.features * 4 or self.strides != (1, 1)) + + residual = x + if needs_projection: + residual = StdConv( + features=self.features * 4, + kernel_size=(1, 1), + strides=self.strides, + use_bias=False, + name='conv_proj')( + residual) + residual = nn.GroupNorm(name='gn_proj')(residual) + + y = StdConv( + features=self.features, + kernel_size=(1, 1), + use_bias=False, + name='conv1')( + x) + y = nn.GroupNorm(name='gn1')(y) + y = nn.relu(y) + y = StdConv( + features=self.features, + kernel_size=(3, 3), + strides=self.strides, + use_bias=False, + name='conv2')( + y) + y = nn.GroupNorm(name='gn2')(y) + y = nn.relu(y) + y = StdConv( + features=self.features * 4, + kernel_size=(1, 1), + use_bias=False, + name='conv3')( + y) + + y = nn.GroupNorm(name='gn3', scale_init=nn.initializers.zeros)(y) + y = nn.relu(residual + y) + return y + + +class ResNetStage(nn.Module): + """A ResNet stage.""" + + block_size: Sequence[int] + nout: int + first_stride: Sequence[int] + + @nn.compact + def __call__(self, x): + x = ResidualUnit(self.nout, strides=self.first_stride, name='unit1')(x) + for i in range(1, self.block_size): + x = ResidualUnit(self.nout, strides=(1, 1), name=f'unit{i + 1}')(x) + return x \ No newline at end of file
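
Usage sketches for the modules added above follow; they are illustrative only and are not part of the diff. First, a minimal sketch of restoring PyTorch MAE weights into the new MaskedAutoencoder via pytorch_statedict_to_jax. The checkpoint filename is hypothetical, and it is assumed (as in the official MAE release) that torch.load(...) yields a dict with a 'model' key whose converted tree matches the module's parameter structure.

    import jax
    import jax.numpy as jnp
    import torch  # only used to read the .pth checkpoint

    from jaxrl_m.vision import encoders
    from jaxrl_m.vision.mae import pytorch_statedict_to_jax

    model = encoders['mae_large']()                    # partial(MaskedAutoencoder, **'large' config)
    images = jnp.zeros((2, 224, 224, 3), jnp.float32)  # NHWC, already normalized

    # Convert the torch state dict to a nested Flax param tree: map_to_jax renames
    # the keys, and conv / dense kernels are transposed to (H, W, in, out) / (in, out).
    state_dict = torch.load('mae_pretrain_vit_large.pth', map_location='cpu')
    restored = pytorch_statedict_to_jax(state_dict)

    # __call__ -> forward_representation: unmasked CLS + patch tokens.
    tokens = model.apply({'params': restored}, images, deterministic=True)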
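
A similar sketch for the classifier wrappers in mae.py: LinearCLS is a linear probe over the CLS token (or pooled patch tokens) behind a frozen-affine BatchNorm, while ViTClassifier runs the MAE encoder end to end and adds a LayerNorm + Dense head. The rng names below simply mirror ViTClassifier.rng_keys(); whether they are all consumed depends on the encoder's deterministic path, so passing extras is harmless.

    import jax
    import jax.numpy as jnp

    from jaxrl_m.vision import encoders
    from jaxrl_m.vision.mae import ViTClassifier

    clf = ViTClassifier(base_model=encoders['mae_base'](), num_classes=1000)
    images = jnp.zeros((2, 224, 224, 3), jnp.float32)

    # rng_keys() returns ('params', 'noise', 'drop_path'); extra rngs are ignored
    # when deterministic=True disables masking and stochastic depth.
    rngs = {k: jax.random.PRNGKey(i) for i, k in enumerate(clf.rng_keys())}
    variables = clf.init(rngs, images, deterministic=True)

    logits = clf.apply(variables, images, deterministic=True)
    log_probs, logits, features = clf.apply(
        variables, images, deterministic=True, features=True)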
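
A sketch for pretrained_encoder.ResizingEncoder: it resizes and optionally center-crops raw observations, applies ImageNet normalization, adds/removes a batch dimension for unbatched inputs, and can stop-gradient through the wrapped backbone. The backbone choice here is arbitrary; any entry in the encoders registry should work, and the output type/shape is whatever that backbone returns.

    import jax
    import jax.numpy as jnp

    from jaxrl_m.vision import encoders
    from jaxrl_m.vision.pretrained_encoder import ResizingEncoder

    encoder = ResizingEncoder(
        encoder=encoders['resnetv2-50-1'](),
        normalize_imagenet=True,                 # /255, then ImageNet mean/std
        resize=True, final_shape=(224, 224),
        center_crop=True, pre_crop_shape=(256, 256),
        freeze_encoder=True,                     # jax.lax.stop_gradient on the output
    )

    obs = jnp.zeros((8, 128, 128, 3), jnp.float32)   # raw [0, 255]-scale frames
    variables = encoder.init(jax.random.PRNGKey(0), obs)
    features = encoder.apply(variables, obs)         # backbone-dependent output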
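
A self-contained toy example of load_pretrained_params, which splices a pretrained encoder's params (and extra variables such as batch_stats) into a freshly initialized parameter tree under prefix_key. The tiny trees below are stand-ins for real model.init(...) outputs.

    import jax.numpy as jnp
    from flax.core import freeze

    from jaxrl_m.vision.pretrained_encoder import load_pretrained_params

    # Agent-side trees, with the vision encoder nested under 'encoder'/'encoder'.
    params = freeze({'encoder': {'encoder': {'conv': {'kernel': jnp.zeros(3)}}}})
    extra = freeze({'batch_stats': {'encoder': {'encoder': {'bn': {'mean': jnp.zeros(3)}}}}})

    # Pretrained encoder checkpoint, rooted at the encoder itself.
    pretrained_params = {'conv': {'kernel': jnp.ones(3)}}
    pretrained_extra = {'batch_stats': {'bn': {'mean': jnp.ones(3)}}}

    new_params, new_extra = load_pretrained_params(
        pretrained_params, pretrained_extra, params, extra,
        prefix_key='encoder/encoder',
    )
    # new_params['encoder']['encoder']['conv']['kernel'] now holds the pretrained
    # value. Keys are asserted to match exactly at the prefix; merge_dicts itself
    # only reports (rather than fails on) extra or missing keys.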
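
A sketch for the new vit_configs registry entries, which are partials over VisionTransformer. With the default num_classes=None the model returns pre-logits features, and classifier='token' (the default) pools via the CLS token; 224x224 inputs are assumed here.

    import jax
    import jax.numpy as jnp

    from jaxrl_m.vision import encoders

    vit = encoders['ViT-B/16']()                       # patch 16, hidden_size 768
    images = jnp.zeros((2, 224, 224, 3), jnp.float32)  # 14x14 patches + CLS token

    variables = vit.init(jax.random.PRNGKey(0), images, train=False)
    features = vit.apply(variables, images, train=False)   # (2, 768) CLS features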
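
Finally, a small sketch for vit_resnet: StdConv is the weight-standardized convolution used by the hybrid ViT root and ResNet stages. It overrides .param() so the kernel is re-standardized (zero mean, unit variance over H, W and input channels) each time it is read, meaning checkpoints store the raw kernel.

    import jax
    import jax.numpy as jnp

    from jaxrl_m.vision.vit_resnet import StdConv, ResNetStage

    x = jnp.zeros((1, 224, 224, 3), jnp.float32)

    conv = StdConv(features=64, kernel_size=(7, 7), strides=(2, 2), use_bias=False)
    params = conv.init(jax.random.PRNGKey(0), x)
    y = conv.apply(params, x)                          # (1, 112, 112, 64), SAME padding

    # A bottleneck stage expands channels 4x (64 -> 256) and uses GroupNorm throughout.
    stage = ResNetStage(block_size=3, nout=64, first_stride=(1, 1))
    stage_params = stage.init(jax.random.PRNGKey(1), y)
    z = stage.apply(stage_params, y)                   # (1, 112, 112, 256)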