Commit

feat: implementation of UViT
AshishKumar4 committed Aug 17, 2024
1 parent 74c0075 commit f0bfd6c
Showing 7 changed files with 1,317 additions and 349 deletions.
205 changes: 93 additions & 112 deletions Diffusion flax linen.ipynb

Large diffs are not rendered by default.

256 changes: 234 additions & 22 deletions evaluate.ipynb
@@ -9,10 +9,11 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-08-13 03:17:55.397368: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-08-13 03:17:55.419844: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-08-13 03:17:55.426718: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-08-13 03:17:56.472841: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
"2024-08-17 12:13:19.640002: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-08-17 12:13:19.715866: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-08-17 12:13:19.736481: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-08-17 12:13:20.881285: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"There was a problem when trying to write in your cache folder (/home/mrwhite0racle/.cache/huggingface/hub). You should set the environment variable TRANSFORMERS_CACHE to a writable directory.\n"
]
}
],
@@ -100,7 +101,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -308,24 +309,33 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.\n",
"WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.\n"
"2024-08-16 02:07:16.356225: E external/local_tsl/tsl/platform/cloud/curl_http_request.cc:610] The transmission of request 0x9c95380 (URI: http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token) has been stuck at 0 of 0 bytes for 61 seconds and will be aborted. CURL timing information: lookup time: 0.091788 (No error), connect time: 0 (No error), pre-transfer time: 0 (No error), start-transfer time: 0 (No error)\n",
"2024-08-16 02:08:17.919157: E external/local_tsl/tsl/platform/cloud/curl_http_request.cc:610] The transmission of request 0x9c95380 (URI: http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token) has been stuck at 0 of 0 bytes for 61 seconds and will be aborted. CURL timing information: lookup time: 0.086971 (No error), connect time: 0 (No error), pre-transfer time: 0 (No error), start-transfer time: 0 (No error)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading model from checkpoint at step 120002\n",
"Loaded model from checkpoint at epoch 0 step 120002 1000000000.0\n",
"Generating states for DiffusionTrainer\n"
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mCannot execute code, session has been disposed. Please try restarting the Kernel."
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mCannot execute code, session has been disposed. Please try restarting the Kernel. \n",
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
@@ -338,9 +348,12 @@
"# checkpoint_id, model_config = (\"dataset-combined_30m/image_size-512/batch-512-v4-64_flaxdiff-0-1-8_ldm_dyn_scale_new_arch_combined_30\", CONFIG_BIG) # --> Good\n",
"# checkpoint_id, model_config = (\"dataset-combined_30m/image_size-128/batch-128-v4-16_flaxdiff-0-1-9_light_combined_30m_1\", OLD_CONFIG_SMALL) # --> Good\n",
"\n",
"checkpoint_id, model_config = (\"dataset-combined_online/image_size-128/batch-512-v4-64-_combined_online\", OLD_CONFIG_MEDIUM) # --> Good\n",
"# checkpoint_id, model_config = (\"dataset-combined_online/image_size-128/batch-512-v4-64-_combined_online\", OLD_CONFIG_MEDIUM) # --> Good\n",
"# checkpoint_id, model_config = (\"dataset-combined_30m/image_size-128/batch-512-v4-64-_combined_30m-finetuned\", OLD_CONFIG_MEDIUM) # --> Good\n",
"checkpoint_id, model_config = (\"combined-img128-good\", OLD_CONFIG_MEDIUM) # --> Good\n",
"\n",
"checkpoint_base_path = \"gs://flaxdiff-datasets-regional/checkpoints/\"\n",
"# checkpoint_base_path = \"gs://flaxdiff-datasets-regional/checkpoints/\"\n",
"checkpoint_base_path = \"./checkpoints/\"\n",
"\n",
"model_config = model_config.copy()\n",
"IMAGE_SIZE=128\n",
@@ -985,9 +998,6 @@
" \"An apple\",\n",
" \"A banana\",\n",
" \"An astronaut riding a horse in space\",\n",
" \"A beautiful naked girl on snow\",\n",
" \"A beautiful naked girl on snow\",\n",
" \"A beautiful naked girl on snow\",\n",
" ]\n",
"pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer)\n",
"samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,))\n",
@@ -2121,10 +2131,13 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax import linen as nn\n",
"from typing import Dict, Callable, Sequence, Any, Union, Tuple, Optional\n",
"from flax.typing import Dtype, PrecisionLike\n",
"import einops\n",
@@ -2483,14 +2496,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"172 μs ± 2.27 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
"686 µs ± 3.39 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
@@ -2532,6 +2545,205 @@
"# %timeit attention_block.apply(params, x)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init\n",
"\n",
"def unpatchify(x, channels=3):\n",
" patch_size = int((x.shape[2] // channels) ** 0.5)\n",
" h = w = int(x.shape[1] ** .5)\n",
" assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2], f\"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}\"\n",
" x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)\n",
" return x\n",
"\n",
"class PatchEmbedding(nn.Module):\n",
" patch_size: int\n",
" embedding_dim: int\n",
" dtype: Any = jnp.float32\n",
" precision: Any = jax.lax.Precision.HIGH\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" batch, height, width, channels = x.shape\n",
" assert height % self.patch_size == 0 and width % self.patch_size == 0, \"Image dimensions must be divisible by patch size\"\n",
" \n",
" x = nn.Conv(features=self.embedding_dim, \n",
" kernel_size=(self.patch_size, self.patch_size), \n",
" strides=(self.patch_size, self.patch_size),\n",
" dtype=self.dtype,\n",
" precision=self.precision)(x)\n",
" x = jnp.reshape(x, (batch, -1, self.embedding_dim))\n",
" return x\n",
"\n",
"class PositionalEncoding(nn.Module):\n",
" max_len: int\n",
" embedding_dim: int\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" pe = self.param('pos_encoding',\n",
" jax.nn.initializers.zeros,\n",
" (1, self.max_len, self.embedding_dim))\n",
" return x + pe[:, :x.shape[1], :]\n",
"\n",
"class UViT(nn.Module):\n",
" output_channels:int=3\n",
" patch_size: int = 16\n",
" emb_features:int=768,\n",
" num_layers: int = 12\n",
" num_heads: int = 12\n",
" dropout_rate: float = 0.1\n",
" dtype: Any = jnp.float32\n",
" precision: Any = jax.lax.Precision.HIGH\n",
" use_projection: bool = False\n",
" activation:Callable = jax.nn.swish\n",
" norm_groups:int=8\n",
" dtype: Optional[Dtype] = None\n",
" precision: PrecisionLike = None\n",
" kernel_init: Callable = partial(kernel_init, 1.0)\n",
"\n",
" def setup(self):\n",
" if self.norm_groups > 0:\n",
" self.norm = partial(nn.GroupNorm, self.norm_groups)\n",
" else:\n",
" self.norm = partial(nn.RMSNorm, 1e-5)\n",
" \n",
" @nn.compact\n",
" def __call__(self, x, temb, textcontext=None):\n",
" # Time embedding\n",
" temb = FourierEmbedding(features=self.emb_features)(temb)\n",
" temb = TimeProjection(features=self.emb_features)(temb)\n",
"\n",
" # Patch embedding\n",
" x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features, \n",
" dtype=self.dtype, precision=self.precision)(x)\n",
" num_patches = x.shape[1]\n",
" \n",
" context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), \n",
" dtype=self.dtype, precision=self.precision)(textcontext)\n",
" num_text_tokens = textcontext.shape[1]\n",
" \n",
" # print(f'Shape of x after patch embedding: {x.shape}, numPatches: {num_patches}, temb: {temb.shape}, context_emb: {context_emb.shape}')\n",
" \n",
" # Add time embedding\n",
" temb = jnp.expand_dims(temb, axis=1)\n",
" x = jnp.concatenate([x, temb, context_emb], axis=1)\n",
" # print(f'Shape of x after time embedding: {x.shape}')\n",
" \n",
" # Add positional encoding\n",
" x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)\n",
" \n",
" # print(f'Shape of x after positional encoding: {x.shape}')\n",
" \n",
" skips = []\n",
" # In blocks\n",
" for i in range(self.num_layers // 2):\n",
" x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, \n",
" dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n",
" use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True, \n",
" only_pure_attention=False,\n",
" kernel_init=self.kernel_init())(x)\n",
" skips.append(x)\n",
" \n",
" # Middle block\n",
" x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, \n",
" dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n",
" use_flash_attention=False, use_self_and_cross=True, force_fp32_for_softmax=True, \n",
" only_pure_attention=False,\n",
" kernel_init=self.kernel_init())(x)\n",
" \n",
" # # Out blocks\n",
" for i in range(self.num_layers // 2):\n",
" skip = jnp.concatenate([x, skips.pop()], axis=-1)\n",
" skip = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), \n",
" dtype=self.dtype, precision=self.precision)(skip)\n",
" x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, \n",
" dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n",
" use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True, \n",
" only_pure_attention=False,\n",
" kernel_init=self.kernel_init())(skip)\n",
" \n",
" # print(f'Shape of x after transformer blocks: {x.shape}')\n",
" x = self.norm()(x)\n",
" \n",
" # print(f'Shape of x after norm: {x.shape}')\n",
" \n",
" patch_dim = self.patch_size ** 2 * self.output_channels\n",
" x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)\n",
" # print(f'Shape of x after patch dense layer: {x.shape}, patch_dim: {patch_dim}')\n",
" x = x[:, 1 + num_text_tokens:, :]\n",
" x = unpatchify(x, channels=self.output_channels)\n",
" # print(f'Shape of x after final dense layer: {x.shape}')\n",
" x = nn.Dense(features=self.output_channels, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)\n",
" \n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8.78 ms ± 8.09 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"x = jnp.ones((8, 128, 128, 3), dtype=jnp.bfloat16)\n",
"temb = jnp.ones((8,), dtype=jnp.bfloat16)\n",
"textcontext = jnp.ones((8, 77, 768), dtype=jnp.bfloat16) \n",
"vit = UViT(patch_size=16, \n",
" emb_features=768, \n",
" num_layers=12, \n",
" num_heads=12, \n",
" dropout_rate=0.1, \n",
" dtype=jnp.bfloat16)\n",
"params = vit.init(jax.random.PRNGKey(0), x, temb, textcontext)\n",
"\n",
"@jax.jit\n",
"def apply(params, x, temb, textcontext):\n",
" return vit.apply(params, x, temb, textcontext)\n",
"\n",
"out = apply(params, x, temb, textcontext)\n",
"\n",
"%timeit apply(params, x, temb, textcontext)"
]
},
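{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch (not part of the original commit): count the parameters of the UViT\n",
"# instance initialized above, using only standard JAX tree utilities.\n",
"param_count = sum(p.size for p in jax.tree_util.tree_leaves(params))\n",
"print(f\"UViT parameters: {param_count / 1e6:.2f}M\")"
]
},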
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(8, 8, 8, 3)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 103,
@@ -2738,7 +2950,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.12.4"
}
},
"nbformat": 4,
