From f0bfd6c2a0f5ce62e6d3a42709cc8ed354948b4e Mon Sep 17 00:00:00 2001 From: Ashish Kumar Singh Date: Sat, 17 Aug 2024 13:13:16 +0530 Subject: [PATCH] feat: implementation of UViT --- Diffusion flax linen.ipynb | 205 ++++++------ evaluate.ipynb | 256 ++++++++++++-- flaxdiff/models/common.py | 35 +- flaxdiff/models/simple_vit.py | 145 ++++---- setup.py | 2 +- test modeling.ipynb | 410 +++++++++++++++++++++++ training.py | 613 ++++++++++++++++++++++++++-------- 7 files changed, 1317 insertions(+), 349 deletions(-) create mode 100644 test modeling.ipynb diff --git a/Diffusion flax linen.ipynb b/Diffusion flax linen.ipynb index 33b57ed..c984319 100644 --- a/Diffusion flax linen.ipynb +++ b/Diffusion flax linen.ipynb @@ -66,14 +66,15 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "There was a problem when trying to write in your cache folder (/home/mrwhite0racle/.cache/huggingface/hub). You should set the environment variable TRANSFORMERS_CACHE to a writable directory.\n" + "The dotenv extension is already loaded. To reload it, use:\n", + " %reload_ext dotenv\n" ] } ], @@ -120,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -171,7 +172,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -201,7 +202,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -304,7 +305,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -886,7 +887,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -924,25 +925,23 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'Optional' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mEfficientAttention\u001b[39;49;00m\u001b[43m(\u001b[49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mModule\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;250;43m \u001b[39;49m\u001b[38;5;124;43;03m\"\"\"\u001b[39;49;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124;43;03m Based on the pallas attention implementation.\u001b[39;49;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124;43;03m \"\"\"\u001b[39;49;00m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mquery_dim\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\n", + "Cell \u001b[0;32mIn[10], line 8\u001b[0m, in \u001b[0;36mEfficientAttention\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m heads: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m4\u001b[39m\n\u001b[1;32m 7\u001b[0m dim_head: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m64\u001b[39m\n\u001b[0;32m----> 8\u001b[0m dtype: \u001b[43mOptional\u001b[49m[Dtype] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 9\u001b[0m precision: PrecisionLike \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 10\u001b[0m use_bias: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n", + "\u001b[0;31mNameError\u001b[0m: name 'Optional' is not defined" + ] + } + ], "source": [ - "import jax.experimental.pallas.ops.gpu.attention\n", - "import jax.experimental.pallas.ops.tpu.flash_attention\n", - "# import jax.experimental.pallas.ops.attention\n", - "\n", - "from flaxdiff.models.simple_unet import l2norm, ConvLayer, TimeEmbedding, TimeProjection, Upsample, Downsample, ResidualBlock, PixelShuffle\n", - "from flaxdiff.models.simple_unet import FourierEmbedding\n", - "\n", - "from flaxdiff.models.attention import kernel_init\n", - "# from flash_attn_jax import flash_mha\n", - "# from flaxdiff.models.favor_fastattn import make_fast_generalized_attention, make_fast_softmax_attention\n", - "\n", - "# Kernel initializer to use\n", - "def kernel_init(scale, dtype=jnp.float32):\n", - " scale = max(scale, 1e-10)\n", - " return nn.initializers.variance_scaling(scale=scale, mode=\"fan_avg\", distribution=\"truncated_normal\", dtype=dtype)\n", "\n", "class EfficientAttention(nn.Module):\n", " \"\"\"\n", @@ -951,10 +950,11 @@ " query_dim: int\n", " heads: int = 4\n", " dim_head: int = 64\n", - " dtype: Any = jnp.float32\n", - " precision: Any = jax.lax.Precision.HIGHEST\n", + " dtype: Optional[Dtype] = None\n", + " precision: PrecisionLike = None\n", " use_bias: bool = True\n", - " kernel_init: Callable = lambda : kernel_init(1.0)\n", + " kernel_init: Callable = kernel_init(1.0)\n", + " force_fp32_for_softmax: bool = True\n", "\n", " def setup(self):\n", " inner_dim = self.dim_head * self.heads\n", @@ -964,7 +964,7 @@ " self.heads * self.dim_head,\n", " precision=self.precision, \n", " use_bias=self.use_bias, \n", - " kernel_init=self.kernel_init(), \n", + " kernel_init=self.kernel_init, \n", " dtype=self.dtype\n", " )\n", " self.query = dense(name=\"to_q\")\n", @@ -972,7 +972,7 @@ " self.value = dense(name=\"to_v\")\n", " \n", " self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision, \n", - " kernel_init=self.kernel_init(), dtype=self.dtype, name=\"to_out_0\")\n", + " kernel_init=self.kernel_init, dtype=self.dtype, name=\"to_out_0\")\n", " # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)\n", " \n", " def _reshape_tensor_to_head_dim(self, tensor):\n", @@ -1042,10 +1042,11 @@ " query_dim: int\n", " heads: int = 4\n", " dim_head: int = 64\n", - " dtype: Any = jnp.float32\n", - " precision: Any = jax.lax.Precision.HIGHEST\n", + " dtype: Optional[Dtype] = None\n", + " precision: PrecisionLike = None\n", " use_bias: bool = True\n", - " kernel_init: Callable = lambda : kernel_init(1.0)\n", + " kernel_init: Callable = kernel_init(1.0)\n", + " force_fp32_for_softmax: bool = True\n", "\n", " def setup(self):\n", " inner_dim = self.dim_head * self.heads\n", @@ -1055,7 +1056,7 @@ " axis=-1, \n", " precision=self.precision, \n", " use_bias=self.use_bias, \n", - " kernel_init=self.kernel_init(), \n", + " kernel_init=self.kernel_init, \n", " dtype=self.dtype\n", " )\n", " self.query = dense(name=\"to_q\")\n", @@ -1069,7 +1070,7 @@ " use_bias=self.use_bias, \n", " dtype=self.dtype, \n", " name=\"to_out_0\",\n", - " kernel_init=self.kernel_init()\n", + " kernel_init=self.kernel_init\n", " # kernel_init=jax.nn.initializers.xavier_uniform()\n", " )\n", "\n", @@ -1088,72 +1089,28 @@ " value = self.value(context)\n", " \n", " hidden_states = nn.dot_product_attention(\n", - " query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision\n", + " query, key, value, dtype=self.dtype, broadcast_dropout=False, \n", + " dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,\n", + " deterministic=True\n", " )\n", " proj = self.proj_attn(hidden_states)\n", " proj = proj.reshape(orig_x_shape)\n", " return proj\n", "\n", - "class BasicTransformerBlock(nn.Module):\n", - " # Has self and cross attention\n", - " query_dim: int\n", - " heads: int = 4\n", - " dim_head: int = 64\n", - " dtype: Any = jnp.float32\n", - " precision: Any = jax.lax.Precision.HIGHEST\n", - " use_bias: bool = True\n", - " kernel_init: Callable = lambda : kernel_init(1.0)\n", - " use_flash_attention:bool = False\n", - " use_cross_only:bool = False\n", - " \n", - " def setup(self):\n", - " if self.use_flash_attention:\n", - " attenBlock = EfficientAttention\n", - " else:\n", - " attenBlock = NormalAttention\n", - " \n", - " self.attention1 = attenBlock(\n", - " query_dim=self.query_dim,\n", - " heads=self.heads,\n", - " dim_head=self.dim_head,\n", - " name=f'Attention1',\n", - " precision=self.precision,\n", - " use_bias=self.use_bias,\n", - " dtype=self.dtype,\n", - " kernel_init=self.kernel_init\n", - " )\n", - " self.attention2 = attenBlock(\n", - " query_dim=self.query_dim,\n", - " heads=self.heads,\n", - " dim_head=self.dim_head,\n", - " name=f'Attention2',\n", - " precision=self.precision,\n", - " use_bias=self.use_bias,\n", - " dtype=self.dtype,\n", - " kernel_init=self.kernel_init\n", - " )\n", - " \n", - " self.ff = FlaxFeedForward(dim=self.query_dim)\n", - " self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)\n", - " self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)\n", - " self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)\n", - " \n", - " @nn.compact\n", - " def __call__(self, hidden_states, context=None):\n", - " # self attention\n", - " if not self.use_cross_only:\n", - " print(\"Using self attention\")\n", - " hidden_states = hidden_states + self.attention1(self.norm1(hidden_states))\n", - "\n", - " # cross attention\n", - " hidden_states = hidden_states + self.attention2(self.norm2(hidden_states), context)\n", - "\n", - " # feed forward\n", - " hidden_states = hidden_states + self.ff(self.norm3(hidden_states))\n", - " \n", - " return hidden_states\n", - "\n", "class FlaxGEGLU(nn.Module):\n", + " r\"\"\"\n", + " Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from\n", + " https://arxiv.org/abs/2002.05202.\n", + "\n", + " Parameters:\n", + " dim (:obj:`int`):\n", + " Input hidden states dimension\n", + " dropout (:obj:`float`, *optional*, defaults to 0.0):\n", + " Dropout rate\n", + " dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):\n", + " Parameters `dtype`\n", + " \"\"\"\n", + "\n", " dim: int\n", " dropout: float = 0.0\n", " dtype: jnp.dtype = jnp.float32\n", @@ -1165,10 +1122,27 @@ "\n", " def __call__(self, hidden_states):\n", " hidden_states = self.proj(hidden_states)\n", - " hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=3)\n", + " hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)\n", " return hidden_linear * nn.gelu(hidden_gelu)\n", " \n", "class FlaxFeedForward(nn.Module):\n", + " r\"\"\"\n", + " Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's\n", + " [`FeedForward`] class, with the following simplifications:\n", + " - The activation function is currently hardcoded to a gated linear unit from:\n", + " https://arxiv.org/abs/2002.05202\n", + " - `dim_out` is equal to `dim`.\n", + " - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].\n", + "\n", + " Parameters:\n", + " dim (:obj:`int`):\n", + " Inner hidden states dimension\n", + " dropout (:obj:`float`, *optional*, defaults to 0.0):\n", + " Dropout rate\n", + " dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):\n", + " Parameters `dtype`\n", + " \"\"\"\n", + "\n", " dim: int\n", " dtype: jnp.dtype = jnp.float32\n", " precision: Any = jax.lax.Precision.DEFAULT\n", @@ -1189,13 +1163,14 @@ " query_dim: int\n", " heads: int = 4\n", " dim_head: int = 64\n", - " dtype: Any = jnp.float32\n", - " precision: Any = jax.lax.Precision.HIGHEST\n", + " dtype: Optional[Dtype] = None\n", + " precision: PrecisionLike = None\n", " use_bias: bool = True\n", - " kernel_init: Callable = lambda : kernel_init(1.0)\n", + " kernel_init: Callable = kernel_init(1.0)\n", " use_flash_attention:bool = False\n", " use_cross_only:bool = False\n", " only_pure_attention:bool = False\n", + " force_fp32_for_softmax: bool = True\n", " \n", " def setup(self):\n", " if self.use_flash_attention:\n", @@ -1211,7 +1186,8 @@ " precision=self.precision,\n", " use_bias=self.use_bias,\n", " dtype=self.dtype,\n", - " kernel_init=self.kernel_init\n", + " kernel_init=self.kernel_init,\n", + " force_fp32_for_softmax=self.force_fp32_for_softmax\n", " )\n", " self.attention2 = attenBlock(\n", " query_dim=self.query_dim,\n", @@ -1221,7 +1197,8 @@ " precision=self.precision,\n", " use_bias=self.use_bias,\n", " dtype=self.dtype,\n", - " kernel_init=self.kernel_init\n", + " kernel_init=self.kernel_init,\n", + " force_fp32_for_softmax=self.force_fp32_for_softmax\n", " )\n", " \n", " self.ff = FlaxFeedForward(dim=self.query_dim)\n", @@ -1232,7 +1209,7 @@ " @nn.compact\n", " def __call__(self, hidden_states, context=None):\n", " if self.only_pure_attention:\n", - " return self.attention2(self.norm2(hidden_states), context)\n", + " return self.attention2(hidden_states, context)\n", " \n", " # self attention\n", " if not self.use_cross_only:\n", @@ -1249,28 +1226,30 @@ " heads: int = 4\n", " dim_head: int = 32\n", " use_linear_attention: bool = True\n", - " dtype: Any = jnp.float32\n", - " precision: Any = jax.lax.Precision.HIGH\n", + " dtype: Optional[Dtype] = None\n", + " precision: PrecisionLike = None\n", " use_projection: bool = False\n", - " use_flash_attention:bool = True\n", - " use_self_and_cross:bool = False\n", + " use_flash_attention:bool = False\n", + " use_self_and_cross:bool = True\n", " only_pure_attention:bool = False\n", + " force_fp32_for_softmax: bool = True\n", + " kernel_init: Callable = kernel_init(1.0)\n", "\n", " @nn.compact\n", " def __call__(self, x, context=None):\n", " inner_dim = self.heads * self.dim_head\n", - " B, H, W, C = x.shape\n", + " C = x.shape[-1]\n", " normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)\n", " if self.use_projection == True:\n", " if self.use_linear_attention:\n", " projected_x = nn.Dense(features=inner_dim, \n", " use_bias=False, precision=self.precision, \n", - " kernel_init=kernel_init(1.0),\n", + " kernel_init=self.kernel_init,\n", " dtype=self.dtype, name=f'project_in')(normed_x)\n", " else:\n", " projected_x = nn.Conv(\n", " features=inner_dim, kernel_size=(1, 1),\n", - " kernel_init=kernel_init(1.0),\n", + " kernel_init=self.kernel_init,\n", " strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,\n", " precision=self.precision, name=f'project_in_conv',\n", " )(normed_x)\n", @@ -1290,19 +1269,21 @@ " dtype=self.dtype,\n", " use_flash_attention=self.use_flash_attention,\n", " use_cross_only=(not self.use_self_and_cross),\n", - " only_pure_attention=self.only_pure_attention\n", + " only_pure_attention=self.only_pure_attention,\n", + " force_fp32_for_softmax=self.force_fp32_for_softmax,\n", + " kernel_init=self.kernel_init\n", " )(projected_x, context)\n", " \n", " if self.use_projection == True:\n", " if self.use_linear_attention:\n", " projected_x = nn.Dense(features=C, precision=self.precision, \n", " dtype=self.dtype, use_bias=False, \n", - " kernel_init=kernel_init(1.0),\n", + " kernel_init=self.kernel_init,\n", " name=f'project_out')(projected_x)\n", " else:\n", " projected_x = nn.Conv(\n", " features=C, kernel_size=(1, 1),\n", - " kernel_init=kernel_init(1.0),\n", + " kernel_init=self.kernel_init,\n", " strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,\n", " precision=self.precision, name=f'project_out_conv',\n", " )(projected_x)\n", diff --git a/evaluate.ipynb b/evaluate.ipynb index 3a36a6a..ed60279 100644 --- a/evaluate.ipynb +++ b/evaluate.ipynb @@ -9,10 +9,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-08-13 03:17:55.397368: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2024-08-13 03:17:55.419844: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2024-08-13 03:17:55.426718: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2024-08-13 03:17:56.472841: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" + "2024-08-17 12:13:19.640002: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2024-08-17 12:13:19.715866: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2024-08-17 12:13:19.736481: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "2024-08-17 12:13:20.881285: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", + "There was a problem when trying to write in your cache folder (/home/mrwhite0racle/.cache/huggingface/hub). You should set the environment variable TRANSFORMERS_CACHE to a writable directory.\n" ] } ], @@ -100,7 +101,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -308,7 +309,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -316,16 +317,25 @@ "output_type": "stream", "text": [ "WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.\n", - "WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by August 1st, 2024.\n" + "2024-08-16 02:07:16.356225: E external/local_tsl/tsl/platform/cloud/curl_http_request.cc:610] The transmission of request 0x9c95380 (URI: http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token) has been stuck at 0 of 0 bytes for 61 seconds and will be aborted. CURL timing information: lookup time: 0.091788 (No error), connect time: 0 (No error), pre-transfer time: 0 (No error), start-transfer time: 0 (No error)\n", + "2024-08-16 02:08:17.919157: E external/local_tsl/tsl/platform/cloud/curl_http_request.cc:610] The transmission of request 0x9c95380 (URI: http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token) has been stuck at 0 of 0 bytes for 61 seconds and will be aborted. CURL timing information: lookup time: 0.086971 (No error), connect time: 0 (No error), pre-transfer time: 0 (No error), start-transfer time: 0 (No error)\n" ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loading model from checkpoint at step 120002\n", - "Loaded model from checkpoint at epoch 0 step 120002 1000000000.0\n", - "Generating states for DiffusionTrainer\n" + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mCannot execute code, session has been disposed. Please try restarting the Kernel." + ] + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mCannot execute code, session has been disposed. Please try restarting the Kernel. \n", + "\u001b[1;31mView Jupyter log for further details." ] } ], @@ -338,9 +348,12 @@ "# checkpoint_id, model_config = (\"dataset-combined_30m/image_size-512/batch-512-v4-64_flaxdiff-0-1-8_ldm_dyn_scale_new_arch_combined_30\", CONFIG_BIG) # --> Good\n", "# checkpoint_id, model_config = (\"dataset-combined_30m/image_size-128/batch-128-v4-16_flaxdiff-0-1-9_light_combined_30m_1\", OLD_CONFIG_SMALL) # --> Good\n", "\n", - "checkpoint_id, model_config = (\"dataset-combined_online/image_size-128/batch-512-v4-64-_combined_online\", OLD_CONFIG_MEDIUM) # --> Good\n", + "# checkpoint_id, model_config = (\"dataset-combined_online/image_size-128/batch-512-v4-64-_combined_online\", OLD_CONFIG_MEDIUM) # --> Good\n", + "# checkpoint_id, model_config = (\"dataset-combined_30m/image_size-128/batch-512-v4-64-_combined_30m-finetuned\", OLD_CONFIG_MEDIUM) # --> Good\n", + "checkpoint_id, model_config = (\"combined-img128-good\", OLD_CONFIG_MEDIUM) # --> Good\n", "\n", - "checkpoint_base_path = \"gs://flaxdiff-datasets-regional/checkpoints/\"\n", + "# checkpoint_base_path = \"gs://flaxdiff-datasets-regional/checkpoints/\"\n", + "checkpoint_base_path = \"./checkpoints/\"\n", "\n", "model_config = model_config.copy()\n", "IMAGE_SIZE=128\n", @@ -985,9 +998,6 @@ " \"An apple\",\n", " \"A banana\",\n", " \"An astronaut riding a horse in space\",\n", - " \"A beautiful naked girl on snow\",\n", - " \"A beautiful naked girl on snow\",\n", - " \"A beautiful naked girl on snow\",\n", " ]\n", "pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer)\n", "samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,))\n", @@ -2121,10 +2131,13 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from flax import linen as nn\n", "from typing import Dict, Callable, Sequence, Any, Union, Tuple, Optional\n", "from flax.typing import Dtype, PrecisionLike\n", "import einops\n", @@ -2483,14 +2496,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "172 μs ± 2.27 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + "686 µs ± 3.39 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], @@ -2532,6 +2545,205 @@ "# %timeit attention_block.apply(params, x)" ] }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init\n", + "\n", + "def unpatchify(x, channels=3):\n", + " patch_size = int((x.shape[2] // channels) ** 0.5)\n", + " h = w = int(x.shape[1] ** .5)\n", + " assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2], f\"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}\"\n", + " x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)\n", + " return x\n", + "\n", + "class PatchEmbedding(nn.Module):\n", + " patch_size: int\n", + " embedding_dim: int\n", + " dtype: Any = jnp.float32\n", + " precision: Any = jax.lax.Precision.HIGH\n", + "\n", + " @nn.compact\n", + " def __call__(self, x):\n", + " batch, height, width, channels = x.shape\n", + " assert height % self.patch_size == 0 and width % self.patch_size == 0, \"Image dimensions must be divisible by patch size\"\n", + " \n", + " x = nn.Conv(features=self.embedding_dim, \n", + " kernel_size=(self.patch_size, self.patch_size), \n", + " strides=(self.patch_size, self.patch_size),\n", + " dtype=self.dtype,\n", + " precision=self.precision)(x)\n", + " x = jnp.reshape(x, (batch, -1, self.embedding_dim))\n", + " return x\n", + "\n", + "class PositionalEncoding(nn.Module):\n", + " max_len: int\n", + " embedding_dim: int\n", + "\n", + " @nn.compact\n", + " def __call__(self, x):\n", + " pe = self.param('pos_encoding',\n", + " jax.nn.initializers.zeros,\n", + " (1, self.max_len, self.embedding_dim))\n", + " return x + pe[:, :x.shape[1], :]\n", + "\n", + "class UViT(nn.Module):\n", + " output_channels:int=3\n", + " patch_size: int = 16\n", + " emb_features:int=768,\n", + " num_layers: int = 12\n", + " num_heads: int = 12\n", + " dropout_rate: float = 0.1\n", + " dtype: Any = jnp.float32\n", + " precision: Any = jax.lax.Precision.HIGH\n", + " use_projection: bool = False\n", + " activation:Callable = jax.nn.swish\n", + " norm_groups:int=8\n", + " dtype: Optional[Dtype] = None\n", + " precision: PrecisionLike = None\n", + " kernel_init: Callable = partial(kernel_init, 1.0)\n", + "\n", + " def setup(self):\n", + " if self.norm_groups > 0:\n", + " self.norm = partial(nn.GroupNorm, self.norm_groups)\n", + " else:\n", + " self.norm = partial(nn.RMSNorm, 1e-5)\n", + " \n", + " @nn.compact\n", + " def __call__(self, x, temb, textcontext=None):\n", + " # Time embedding\n", + " temb = FourierEmbedding(features=self.emb_features)(temb)\n", + " temb = TimeProjection(features=self.emb_features)(temb)\n", + "\n", + " # Patch embedding\n", + " x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features, \n", + " dtype=self.dtype, precision=self.precision)(x)\n", + " num_patches = x.shape[1]\n", + " \n", + " context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), \n", + " dtype=self.dtype, precision=self.precision)(textcontext)\n", + " num_text_tokens = textcontext.shape[1]\n", + " \n", + " # print(f'Shape of x after patch embedding: {x.shape}, numPatches: {num_patches}, temb: {temb.shape}, context_emb: {context_emb.shape}')\n", + " \n", + " # Add time embedding\n", + " temb = jnp.expand_dims(temb, axis=1)\n", + " x = jnp.concatenate([x, temb, context_emb], axis=1)\n", + " # print(f'Shape of x after time embedding: {x.shape}')\n", + " \n", + " # Add positional encoding\n", + " x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)\n", + " \n", + " # print(f'Shape of x after positional encoding: {x.shape}')\n", + " \n", + " skips = []\n", + " # In blocks\n", + " for i in range(self.num_layers // 2):\n", + " x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, \n", + " dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n", + " use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True, \n", + " only_pure_attention=False,\n", + " kernel_init=self.kernel_init())(x)\n", + " skips.append(x)\n", + " \n", + " # Middle block\n", + " x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, \n", + " dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n", + " use_flash_attention=False, use_self_and_cross=True, force_fp32_for_softmax=True, \n", + " only_pure_attention=False,\n", + " kernel_init=self.kernel_init())(x)\n", + " \n", + " # # Out blocks\n", + " for i in range(self.num_layers // 2):\n", + " skip = jnp.concatenate([x, skips.pop()], axis=-1)\n", + " skip = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), \n", + " dtype=self.dtype, precision=self.precision)(skip)\n", + " x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, \n", + " dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n", + " use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True, \n", + " only_pure_attention=False,\n", + " kernel_init=self.kernel_init())(skip)\n", + " \n", + " # print(f'Shape of x after transformer blocks: {x.shape}')\n", + " x = self.norm()(x)\n", + " \n", + " # print(f'Shape of x after norm: {x.shape}')\n", + " \n", + " patch_dim = self.patch_size ** 2 * self.output_channels\n", + " x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)\n", + " # print(f'Shape of x after patch dense layer: {x.shape}, patch_dim: {patch_dim}')\n", + " x = x[:, 1 + num_text_tokens:, :]\n", + " x = unpatchify(x, channels=self.output_channels)\n", + " # print(f'Shape of x after final dense layer: {x.shape}')\n", + " x = nn.Dense(features=self.output_channels, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)\n", + " \n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "8.78 ms ± 8.09 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "x = jnp.ones((8, 128, 128, 3), dtype=jnp.bfloat16)\n", + "temb = jnp.ones((8,), dtype=jnp.bfloat16)\n", + "textcontext = jnp.ones((8, 77, 768), dtype=jnp.bfloat16) \n", + "vit = UViT(patch_size=16, \n", + " emb_features=768, \n", + " num_layers=12, \n", + " num_heads=12, \n", + " dropout_rate=0.1, \n", + " dtype=jnp.bfloat16)\n", + "params = vit.init(jax.random.PRNGKey(0), x, temb, textcontext)\n", + "\n", + "@jax.jit\n", + "def apply(params, x, temb, textcontext):\n", + " return vit.apply(params, x, temb, textcontext)\n", + "\n", + "out = apply(params, x, temb, textcontext)\n", + "\n", + "%timeit apply(params, x, temb, textcontext)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(8, 8, 8, 3)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 103, @@ -2738,7 +2950,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/flaxdiff/models/common.py b/flaxdiff/models/common.py index e619668..55e7e5a 100644 --- a/flaxdiff/models/common.py +++ b/flaxdiff/models/common.py @@ -108,12 +108,13 @@ def __call__(self, x): class TimeProjection(nn.Module): features:int activation:Callable=jax.nn.gelu + kernel_init:Callable=partial(kernel_init, 1.0) @nn.compact def __call__(self, x): - x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x) + x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init())(x) x = self.activation(x) - x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x) + x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init())(x) x = self.activation(x) return x @@ -122,7 +123,7 @@ class SeparableConv(nn.Module): kernel_size:tuple=(3, 3) strides:tuple=(1, 1) use_bias:bool=False - kernel_init:Callable=kernel_init(1.0) + kernel_init:Callable=partial(kernel_init, 1.0) padding:str="SAME" dtype: Optional[Dtype] = None precision: PrecisionLike = None @@ -132,7 +133,7 @@ def __call__(self, x): in_features = x.shape[-1] depthwise = nn.Conv( features=in_features, kernel_size=self.kernel_size, - strides=self.strides, kernel_init=self.kernel_init, + strides=self.strides, kernel_init=self.kernel_init(), feature_group_count=in_features, use_bias=self.use_bias, padding=self.padding, dtype=self.dtype, @@ -140,7 +141,7 @@ def __call__(self, x): )(x) pointwise = nn.Conv( features=self.features, kernel_size=(1, 1), - strides=(1, 1), kernel_init=self.kernel_init, + strides=(1, 1), kernel_init=self.kernel_init(), use_bias=self.use_bias, dtype=self.dtype, precision=self.precision @@ -152,7 +153,7 @@ class ConvLayer(nn.Module): features:int kernel_size:tuple=(3, 3) strides:tuple=(1, 1) - kernel_init:Callable=kernel_init(1.0) + kernel_init:Callable=partial(kernel_init, 1.0) dtype: Optional[Dtype] = None precision: PrecisionLike = None @@ -163,7 +164,7 @@ def setup(self): features=self.features, kernel_size=self.kernel_size, strides=self.strides, - kernel_init=self.kernel_init, + kernel_init=self.kernel_init(), dtype=self.dtype, precision=self.precision ) @@ -182,7 +183,7 @@ def setup(self): features=self.features, kernel_size=self.kernel_size, strides=self.strides, - kernel_init=self.kernel_init, + kernel_init=self.kernel_init(), dtype=self.dtype, precision=self.precision ) @@ -191,7 +192,7 @@ def setup(self): features=self.features, kernel_size=self.kernel_size, strides=self.strides, - kernel_init=self.kernel_init, + kernel_init=self.kernel_init(), dtype=self.dtype, precision=self.precision ) @@ -205,6 +206,7 @@ class Upsample(nn.Module): activation:Callable=jax.nn.swish dtype: Optional[Dtype] = None precision: PrecisionLike = None + kernel_init:Callable=partial(kernel_init, 1.0) @nn.compact def __call__(self, x, residual=None): @@ -218,7 +220,8 @@ def __call__(self, x, residual=None): kernel_size=(3, 3), strides=(1, 1), dtype=self.dtype, - precision=self.precision + precision=self.precision, + kernel_init=self.kernel_init() )(out) if residual is not None: out = jnp.concatenate([out, residual], axis=-1) @@ -230,6 +233,7 @@ class Downsample(nn.Module): activation:Callable=jax.nn.swish dtype: Optional[Dtype] = None precision: PrecisionLike = None + kernel_init:Callable=partial(kernel_init, 1.0) @nn.compact def __call__(self, x, residual=None): @@ -239,7 +243,8 @@ def __call__(self, x, residual=None): kernel_size=(3, 3), strides=(2, 2), dtype=self.dtype, - precision=self.precision + precision=self.precision, + kernel_init=self.kernel_init() )(x) if residual is not None: if residual.shape[1] > out.shape[1]: @@ -264,7 +269,7 @@ class ResidualBlock(nn.Module): direction:str=None res:int=2 norm_groups:int=8 - kernel_init:Callable=kernel_init(1.0) + kernel_init:Callable=partial(kernel_init, 1.0) dtype: Optional[Dtype] = None precision: PrecisionLike = None named_norms:bool=False @@ -291,7 +296,7 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe features=self.features, kernel_size=self.kernel_size, strides=self.strides, - kernel_init=self.kernel_init, + kernel_init=self.kernel_init(), name="conv1", dtype=self.dtype, precision=self.precision @@ -316,7 +321,7 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe features=self.features, kernel_size=self.kernel_size, strides=self.strides, - kernel_init=self.kernel_init, + kernel_init=self.kernel_init(), name="conv2", dtype=self.dtype, precision=self.precision @@ -328,7 +333,7 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe features=self.features, kernel_size=(1, 1), strides=1, - kernel_init=self.kernel_init, + kernel_init=self.kernel_init(), name="residual_conv", dtype=self.dtype, precision=self.precision diff --git a/flaxdiff/models/simple_vit.py b/flaxdiff/models/simple_vit.py index 6abe5e8..b63d0ed 100644 --- a/flaxdiff/models/simple_vit.py +++ b/flaxdiff/models/simple_vit.py @@ -3,9 +3,20 @@ import jax import jax.numpy as jnp from flax import linen as nn -from typing import Callable, Any +from typing import Callable, Any, Optional, Tuple from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init from .attention import TransformerBlock +from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init +import einops +from flax.typing import Dtype, PrecisionLike +from functools import partial + +def unpatchify(x, channels=3): + patch_size = int((x.shape[2] // channels) ** 0.5) + h = w = int(x.shape[1] ** .5) + assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}" + x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size) + return x class PatchEmbedding(nn.Module): patch_size: int @@ -37,39 +48,28 @@ def __call__(self, x): (1, self.max_len, self.embedding_dim)) return x + pe[:, :x.shape[1], :] -class TransformerEncoder(nn.Module): - num_layers: int - num_heads: int - dropout_rate: float = 0.1 - dtype: Any = jnp.float32 - precision: Any = jax.lax.Precision.HIGH - use_projection: bool = False - - @nn.compact - def __call__(self, x, context=None): - for _ in range(self.num_layers): - x = TransformerBlock( - heads=self.num_heads, - dim_head=x.shape[-1] // self.num_heads, - dropout_rate=self.dropout_rate, - dtype=self.dtype, - precision=self.precision, - use_self_and_cross=True, - use_projection=self.use_projection, - )(x, context) - return x - -class VisionTransformer(nn.Module): +class UViT(nn.Module): + output_channels:int=3 patch_size: int = 16 - embedding_dim: int = 768 + emb_features:int=768, num_layers: int = 12 num_heads: int = 12 - emb_features: int = 256 dropout_rate: float = 0.1 dtype: Any = jnp.float32 precision: Any = jax.lax.Precision.HIGH use_projection: bool = False - + activation:Callable = jax.nn.swish + norm_groups:int=8 + dtype: Optional[Dtype] = None + precision: PrecisionLike = None + kernel_init: Callable = partial(kernel_init, 1.0) + + def setup(self): + if self.norm_groups > 0: + self.norm = partial(nn.GroupNorm, self.norm_groups) + else: + self.norm = partial(nn.RMSNorm, 1e-5) + @nn.compact def __call__(self, x, temb, textcontext=None): # Time embedding @@ -77,44 +77,65 @@ def __call__(self, x, temb, textcontext=None): temb = TimeProjection(features=self.emb_features)(temb) # Patch embedding - x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.embedding_dim, + x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features, dtype=self.dtype, precision=self.precision)(x) + num_patches = x.shape[1] - # Add positional encoding - x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.embedding_dim)(x) + context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), + dtype=self.dtype, precision=self.precision)(textcontext) + num_text_tokens = textcontext.shape[1] + + # print(f'Shape of x after patch embedding: {x.shape}, numPatches: {num_patches}, temb: {temb.shape}, context_emb: {context_emb.shape}') - num_patches = x.shape[1] - # Add time embedding temb = jnp.expand_dims(temb, axis=1) - x = jnp.concatenate([x, temb], axis=1) - - # Transformer encoder - x = TransformerEncoder( - num_layers=self.num_layers, - num_heads=self.num_heads, - dropout_rate=self.dropout_rate, - dtype=self.dtype, - precision=self.precision, - use_projection=self.use_projection - )(x, textcontext) - - x = x[:, :num_patches, :] - - # Reshape to image dimensions - batch, _, _ = x.shape - height = width = int((num_patches) ** 0.5) - x = jnp.reshape(x, (batch, height, width, self.embedding_dim)) - - # Final convolution to get the desired output channels - x = ConvLayer( - conv_type="conv", - features=3, - kernel_size=(3, 3), - strides=(1, 1), - kernel_init=kernel_init(0.0), - dtype=self.dtype, - precision=self.precision - )(x) - + x = jnp.concatenate([x, temb, context_emb], axis=1) + # print(f'Shape of x after time embedding: {x.shape}') + + # Add positional encoding + x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x) + + # print(f'Shape of x after positional encoding: {x.shape}') + + skips = [] + # In blocks + for i in range(self.num_layers // 2): + x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, + dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, + use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True, + only_pure_attention=False, + kernel_init=self.kernel_init())(x) + skips.append(x) + + # Middle block + x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, + dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, + use_flash_attention=False, use_self_and_cross=True, force_fp32_for_softmax=True, + only_pure_attention=False, + kernel_init=self.kernel_init())(x) + + # # Out blocks + for i in range(self.num_layers // 2): + skip = jnp.concatenate([x, skips.pop()], axis=-1) + skip = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), + dtype=self.dtype, precision=self.precision)(skip) + x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, + dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, + use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True, + only_pure_attention=False, + kernel_init=self.kernel_init())(skip) + + # print(f'Shape of x after transformer blocks: {x.shape}') + x = self.norm()(x) + + # print(f'Shape of x after norm: {x.shape}') + + patch_dim = self.patch_size ** 2 * self.output_channels + x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x) + # print(f'Shape of x after patch dense layer: {x.shape}, patch_dim: {patch_dim}') + x = x[:, 1 + num_text_tokens:, :] + x = unpatchify(x, channels=self.output_channels) + # print(f'Shape of x after final dense layer: {x.shape}') + x = nn.Dense(features=self.output_channels, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x) + return x \ No newline at end of file diff --git a/setup.py b/setup.py index a2da33f..1e88bf2 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( name='flaxdiff', packages=find_packages(), - version='0.1.22', + version='0.1.23', description='A versatile and easy to understand Diffusion library', long_description=open('README.md').read(), long_description_content_type='text/markdown', diff --git a/test modeling.ipynb b/test modeling.ipynb new file mode 100644 index 0000000..9b1f100 --- /dev/null +++ b/test modeling.ipynb @@ -0,0 +1,410 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "attention mode is flash\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import math\n", + "import einops\n", + "import torch.utils.checkpoint\n", + "# code from timm 0.3.2\n", + "import torch\n", + "import torch.nn as nn\n", + "import math\n", + "import warnings\n", + "\n", + "\n", + "def _no_grad_trunc_normal_(tensor, mean, std, a, b):\n", + " # Cut & paste from PyTorch official master until it's in a few official releases - RW\n", + " # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n", + " def norm_cdf(x):\n", + " # Computes standard normal cumulative distribution function\n", + " return (1. + math.erf(x / math.sqrt(2.))) / 2.\n", + "\n", + " if (mean < a - 2 * std) or (mean > b + 2 * std):\n", + " warnings.warn(\"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \"\n", + " \"The distribution of values may be incorrect.\",\n", + " stacklevel=2)\n", + "\n", + " with torch.no_grad():\n", + " # Values are generated by using a truncated uniform distribution and\n", + " # then using the inverse CDF for the normal distribution.\n", + " # Get upper and lower cdf values\n", + " l = norm_cdf((a - mean) / std)\n", + " u = norm_cdf((b - mean) / std)\n", + "\n", + " # Uniformly fill tensor with values from [l, u], then translate to\n", + " # [2l-1, 2u-1].\n", + " tensor.uniform_(2 * l - 1, 2 * u - 1)\n", + "\n", + " # Use inverse cdf transform for normal distribution to get truncated\n", + " # standard normal\n", + " tensor.erfinv_()\n", + "\n", + " # Transform to proper mean, std\n", + " tensor.mul_(std * math.sqrt(2.))\n", + " tensor.add_(mean)\n", + "\n", + " # Clamp to ensure it's in the proper range\n", + " tensor.clamp_(min=a, max=b)\n", + " return tensor\n", + "\n", + "\n", + "def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):\n", + " # type: (Tensor, float, float, float, float) -> Tensor\n", + " r\"\"\"Fills the input Tensor with values drawn from a truncated\n", + " normal distribution. The values are effectively drawn from the\n", + " normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n", + " with values outside :math:`[a, b]` redrawn until they are within\n", + " the bounds. The method used for generating the random values works\n", + " best when :math:`a \\leq \\text{mean} \\leq b`.\n", + " Args:\n", + " tensor: an n-dimensional `torch.Tensor`\n", + " mean: the mean of the normal distribution\n", + " std: the standard deviation of the normal distribution\n", + " a: the minimum cutoff value\n", + " b: the maximum cutoff value\n", + " Examples:\n", + " >>> w = torch.empty(3, 5)\n", + " >>> nn.init.trunc_normal_(w)\n", + " \"\"\"\n", + " return _no_grad_trunc_normal_(tensor, mean, std, a, b)\n", + "\n", + "\n", + "def drop_path(x, drop_prob: float = 0., training: bool = False):\n", + " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n", + "\n", + " This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n", + " the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n", + " See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n", + " changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n", + " 'survival rate' as the argument.\n", + "\n", + " \"\"\"\n", + " if drop_prob == 0. or not training:\n", + " return x\n", + " keep_prob = 1 - drop_prob\n", + " shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets\n", + " random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)\n", + " random_tensor.floor_() # binarize\n", + " output = x.div(keep_prob) * random_tensor\n", + " return output\n", + "\n", + "\n", + "class DropPath(nn.Module):\n", + " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n", + " \"\"\"\n", + " def __init__(self, drop_prob=None):\n", + " super(DropPath, self).__init__()\n", + " self.drop_prob = drop_prob\n", + "\n", + " def forward(self, x):\n", + " return drop_path(x, self.drop_prob, self.training)\n", + "\n", + "\n", + "class Mlp(nn.Module):\n", + " def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n", + " super().__init__()\n", + " out_features = out_features or in_features\n", + " hidden_features = hidden_features or in_features\n", + " self.fc1 = nn.Linear(in_features, hidden_features)\n", + " self.act = act_layer()\n", + " self.fc2 = nn.Linear(hidden_features, out_features)\n", + " self.drop = nn.Dropout(drop)\n", + "\n", + " def forward(self, x):\n", + " x = self.fc1(x)\n", + " x = self.act(x)\n", + " x = self.drop(x)\n", + " x = self.fc2(x)\n", + " x = self.drop(x)\n", + " return x\n", + "\n", + "if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):\n", + " ATTENTION_MODE = 'flash'\n", + "else:\n", + " try:\n", + " import xformers\n", + " import xformers.ops\n", + " ATTENTION_MODE = 'xformers'\n", + " except:\n", + " ATTENTION_MODE = 'math'\n", + "print(f'attention mode is {ATTENTION_MODE}')\n", + "\n", + "\n", + "def timestep_embedding(timesteps, dim, max_period=10000):\n", + " \"\"\"\n", + " Create sinusoidal timestep embeddings.\n", + "\n", + " :param timesteps: a 1-D Tensor of N indices, one per batch element.\n", + " These may be fractional.\n", + " :param dim: the dimension of the output.\n", + " :param max_period: controls the minimum frequency of the embeddings.\n", + " :return: an [N x dim] Tensor of positional embeddings.\n", + " \"\"\"\n", + " half = dim // 2\n", + " freqs = torch.exp(\n", + " -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n", + " ).to(device=timesteps.device)\n", + " args = timesteps[:, None].float() * freqs[None]\n", + " embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n", + " if dim % 2:\n", + " embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n", + " return embedding\n", + "\n", + "\n", + "def patchify(imgs, patch_size):\n", + " x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)\n", + " return x\n", + "\n", + "\n", + "def unpatchify(x, channels=3):\n", + " patch_size = int((x.shape[2] // channels) ** 0.5)\n", + " h = w = int(x.shape[1] ** .5)\n", + " assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]\n", + " x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)\n", + " return x\n", + "\n", + "\n", + "class Attention(nn.Module):\n", + " def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n", + " super().__init__()\n", + " self.num_heads = num_heads\n", + " head_dim = dim // num_heads\n", + " self.scale = qk_scale or head_dim ** -0.5\n", + "\n", + " self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n", + " self.attn_drop = nn.Dropout(attn_drop)\n", + " self.proj = nn.Linear(dim, dim)\n", + " self.proj_drop = nn.Dropout(proj_drop)\n", + "\n", + " def forward(self, x):\n", + " B, L, C = x.shape\n", + "\n", + " qkv = self.qkv(x)\n", + " if ATTENTION_MODE == 'flash':\n", + " qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()\n", + " q, k, v = qkv[0], qkv[1], qkv[2] # B H L D\n", + " x = torch.nn.functional.scaled_dot_product_attention(q, k, v)\n", + " x = einops.rearrange(x, 'B H L D -> B L (H D)')\n", + " elif ATTENTION_MODE == 'xformers':\n", + " qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)\n", + " q, k, v = qkv[0], qkv[1], qkv[2] # B L H D\n", + " x = xformers.ops.memory_efficient_attention(q, k, v)\n", + " x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)\n", + " elif ATTENTION_MODE == 'math':\n", + " qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)\n", + " q, k, v = qkv[0], qkv[1], qkv[2] # B H L D\n", + " attn = (q @ k.transpose(-2, -1)) * self.scale\n", + " attn = attn.softmax(dim=-1)\n", + " attn = self.attn_drop(attn)\n", + " x = (attn @ v).transpose(1, 2).reshape(B, L, C)\n", + " else:\n", + " raise NotImplemented\n", + "\n", + " x = self.proj(x)\n", + " x = self.proj_drop(x)\n", + " return x\n", + "\n", + "\n", + "class Block(nn.Module):\n", + "\n", + " def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,\n", + " act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):\n", + " super().__init__()\n", + " self.norm1 = norm_layer(dim)\n", + " self.attn = Attention(\n", + " dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)\n", + " self.norm2 = norm_layer(dim)\n", + " mlp_hidden_dim = int(dim * mlp_ratio)\n", + " self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)\n", + " self.skip_linear = nn.Linear(2 * dim, dim) if skip else None\n", + " self.use_checkpoint = use_checkpoint\n", + "\n", + " def forward(self, x, skip=None):\n", + " if self.use_checkpoint:\n", + " return torch.utils.checkpoint.checkpoint(self._forward, x, skip)\n", + " else:\n", + " return self._forward(x, skip)\n", + "\n", + " def _forward(self, x, skip=None):\n", + " if self.skip_linear is not None:\n", + " x = self.skip_linear(torch.cat([x, skip], dim=-1))\n", + " x = x + self.attn(self.norm1(x))\n", + " x = x + self.mlp(self.norm2(x))\n", + " return x\n", + "\n", + "\n", + "class PatchEmbed(nn.Module):\n", + " \"\"\" Image to Patch Embedding\n", + " \"\"\"\n", + " def __init__(self, patch_size, in_chans=3, embed_dim=768):\n", + " super().__init__()\n", + " self.patch_size = patch_size\n", + " self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n", + "\n", + " def forward(self, x):\n", + " B, C, H, W = x.shape\n", + " assert H % self.patch_size == 0 and W % self.patch_size == 0\n", + " x = self.proj(x).flatten(2).transpose(1, 2)\n", + " return x\n", + "\n", + "\n", + "class UViT(nn.Module):\n", + " def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,\n", + " qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False,\n", + " clip_dim=768, num_clip_token=77, conv=True, skip=True):\n", + " super().__init__()\n", + " self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models\n", + " self.in_chans = in_chans\n", + "\n", + " self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n", + " num_patches = (img_size // patch_size) ** 2\n", + "\n", + " self.time_embed = nn.Sequential(\n", + " nn.Linear(embed_dim, 4 * embed_dim),\n", + " nn.SiLU(),\n", + " nn.Linear(4 * embed_dim, embed_dim),\n", + " ) if mlp_time_embed else nn.Identity()\n", + "\n", + " self.context_embed = nn.Linear(clip_dim, embed_dim)\n", + "\n", + " self.extras = 1 + num_clip_token\n", + "\n", + " self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))\n", + "\n", + " self.in_blocks = nn.ModuleList([\n", + " Block(\n", + " dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n", + " norm_layer=norm_layer, use_checkpoint=use_checkpoint)\n", + " for _ in range(depth // 2)])\n", + "\n", + " self.mid_block = Block(\n", + " dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n", + " norm_layer=norm_layer, use_checkpoint=use_checkpoint)\n", + "\n", + " self.out_blocks = nn.ModuleList([\n", + " Block(\n", + " dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n", + " norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)\n", + " for _ in range(depth // 2)])\n", + "\n", + " self.norm = norm_layer(embed_dim)\n", + " self.patch_dim = patch_size ** 2 * in_chans\n", + " self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)\n", + " self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity()\n", + "\n", + " trunc_normal_(self.pos_embed, std=.02)\n", + " self.apply(self._init_weights)\n", + "\n", + " def _init_weights(self, m):\n", + " if isinstance(m, nn.Linear):\n", + " trunc_normal_(m.weight, std=.02)\n", + " if isinstance(m, nn.Linear) and m.bias is not None:\n", + " nn.init.constant_(m.bias, 0)\n", + " elif isinstance(m, nn.LayerNorm):\n", + " nn.init.constant_(m.bias, 0)\n", + " nn.init.constant_(m.weight, 1.0)\n", + "\n", + " @torch.jit.ignore\n", + " def no_weight_decay(self):\n", + " return {'pos_embed'}\n", + "\n", + " def forward(self, x, timesteps, context):\n", + " x = self.patch_embed(x)\n", + " B, L, D = x.shape\n", + "\n", + " time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))\n", + " time_token = time_token.unsqueeze(dim=1)\n", + " context_token = self.context_embed(context)\n", + " print(f\"Shape of context token: {context_token.shape}, time token: {time_token.shape}, x: {x.shape}\")\n", + " x = torch.cat((time_token, context_token, x), dim=1)\n", + " print(f\"Shape after concat: {x.shape}\")\n", + " x = x + self.pos_embed\n", + "\n", + " skips = []\n", + " for blk in self.in_blocks:\n", + " x = blk(x)\n", + " skips.append(x)\n", + "\n", + " x = self.mid_block(x)\n", + "\n", + " for blk in self.out_blocks:\n", + " x = blk(x, skips.pop())\n", + "\n", + " print(f\"Shape after transformers: {x.shape}\")\n", + "\n", + " x = self.norm(x)\n", + " x = self.decoder_pred(x)\n", + " assert x.size(1) == self.extras + L\n", + " x = x[:, self.extras:, :]\n", + " x = unpatchify(x, self.in_chans)\n", + " x = self.final_layer(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shape of context token: torch.Size([1, 77, 768]), time token: torch.Size([1, 1, 768]), x: torch.Size([1, 196, 768])\n", + "Shape after concat: torch.Size([1, 274, 768])\n", + "Shape before transformers: torch.Size([1, 274, 768])\n" + ] + } + ], + "source": [ + "x = torch.randn(1, 3, 224, 224)\n", + "timesteps = torch.randn(1,)\n", + "context = torch.randn(1, 77, 768)\n", + "model = UViT()\n", + "out = model(x, timesteps, context)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torch", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/training.py b/training.py index c6aef2a..2ef5b4d 100644 --- a/training.py +++ b/training.py @@ -1,5 +1,6 @@ from typing import Any, Tuple, Mapping, Callable, List, Dict from functools import partial +import flax.experimental import flax.jax_utils import flax.training import flax.training.dynamic_scale @@ -51,10 +52,18 @@ from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array from termcolor import colored +import warnings +import traceback + +warnings.filterwarnings("ignore") + + ##################################################################################################################### ################################################# Initialization #################################################### ##################################################################################################################### +os.environ['TOKENIZERS_PARALLELISM'] = "false" + class RandomClass(): def __init__(self, rng: jax.random.PRNGKey): @@ -93,6 +102,30 @@ def get_random_key(self): 7: "light_cyan" } +def _build_global_shape_and_sharding( + local_shape: tuple[int, ...], global_mesh: Mesh +) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]: + sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names)) + global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:] + return global_shape, sharding + + +def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array: + """Put local sharded array into local devices""" + global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh) + try: + local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0) + except ValueError as array_split_error: + raise ValueError( + f"Unable to put to devices shape {array.shape} with " + f"local device count {len(global_mesh.local_devices)} " + ) from array_split_error + local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices) + return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers) + +def convert_to_global_tree(global_mesh, pytree): + return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree) + ##################################################################################################################### ################################################## Data Pipeline #################################################### ##################################################################################################################### @@ -129,7 +162,6 @@ def encodePrompts(prompts, model, tokenizer=None): return embed_pooled, embed_labels_full - class CaptionProcessor: def __init__(self, tensor_type="pt", modelname="openai/clip-vit-large-patch14"): self.tokenizer = AutoTokenizer.from_pretrained(modelname) @@ -197,7 +229,7 @@ def map(self, element) -> Dict[str, jnp.array]: return augmenters # -----------------------------------------------------------------------------------------------# -# CC12m and other GCS data sources -------------------------------------------------------------# +# CC12m and other GCS data sources --------------------------------------------------------------# # -----------------------------------------------------------------------------------------------# def data_source_gcs(source='arrayrecord/laion-aesthetics-12m+mscoco-2017'): @@ -221,9 +253,6 @@ def data_source(base="/home/mrwhite0racle/gcs_mount"): return ds return data_source -def labelizer_gcs(sample): - return sample['txt'] - def unpack_dict_of_byte_arrays(packed_data): unpacked_dict = {} offset = 0 @@ -243,21 +272,25 @@ def unpack_dict_of_byte_arrays(packed_data): unpacked_dict[key] = byte_array return unpacked_dict +def image_augmenter(image, image_scale, method=cv2.INTER_AREA): + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (image_scale, image_scale), + interpolation=cv2.INTER_AREA) + return image + def gcs_augmenters(image_scale, method): - labelizer = labelizer_gcs + labelizer = lambda sample : sample['txt'] class augmenters(pygrain.MapTransform): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.caption_processor = CaptionProcessor(tensor_type="np") + self.image_augmenter = partial(image_augmenter, image_scale=image_scale, method=method) def map(self, element) -> Dict[str, jnp.array]: element = unpack_dict_of_byte_arrays(element) image = np.asarray(bytearray(element['jpg']), dtype="uint8") image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - image = cv2.resize(image, (image_scale, image_scale), - interpolation=cv2.INTER_AREA) - # image = (image - 127.5) / 127.5 + image = self.image_augmenter(image) caption = labelizer(element).decode('utf-8') results = self.caption_processor(caption) return { @@ -307,11 +340,20 @@ def map(self, element) -> Dict[str, jnp.array]: 'arrayrecord2/laion-aesthetics-12m+mscoco-2017', 'arrayrecord2/cc12m', 'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus', + "arrayrecord2/playground+leonardo_x4+cc3m.parquet", ]), "augmenter": gcs_augmenters, } } +def batch_mesh_map(mesh): + class augmenters(pygrain.MapTransform): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def map(self, batch) -> Dict[str, jnp.array]: + return convert_to_global_tree(mesh, batch) + return augmenters def get_dataset_grain( data_name="cc12m", @@ -319,12 +361,11 @@ def get_dataset_grain( image_scale=256, count=None, num_epochs=None, - text_encoders=None, method=jax.image.ResizeMethod.LANCZOS3, - grain_worker_count=32, - grain_read_thread_count=64, - grain_read_buffer_size=50, - grain_worker_buffer_size=20, + worker_count=32, + read_thread_count=64, + read_buffer_size=50, + worker_buffer_size=20, seed=0, dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/", ): @@ -333,7 +374,7 @@ def get_dataset_grain( augmenter = dataset["augmenter"](image_scale, method) local_batch_size = batch_size // jax.process_count() - model, tokenizer = text_encoders + model, tokenizer = defaultTextEncodeModel() null_labels, null_labels_full = encodePrompts([""], model, tokenizer) null_labels = np.array(null_labels[0], dtype=np.float16) @@ -347,24 +388,27 @@ def get_dataset_grain( shard_options=pygrain.ShardByJaxProcess(), ) - transformations = [ - augmenter(), - pygrain.Batch(local_batch_size, drop_remainder=True), - ] - - loader = pygrain.DataLoader( - data_source=data_source, - sampler=sampler, - operations=transformations, - worker_count=grain_worker_count, - read_options=pygrain.ReadOptions( - grain_read_thread_count, grain_read_buffer_size - ), - worker_buffer_size=grain_worker_buffer_size, - ) - def get_trainset(): + transformations = [ + augmenter(), + pygrain.Batch(local_batch_size, drop_remainder=True), + ] + + # if mesh != None: + # transformations += [batch_mesh_map(mesh)] + + loader = pygrain.DataLoader( + data_source=data_source, + sampler=sampler, + operations=transformations, + worker_count=worker_count, + read_options=pygrain.ReadOptions( + read_thread_count, read_buffer_size + ), + worker_buffer_size=worker_buffer_size, + ) return loader + return { "train": get_trainset, @@ -377,34 +421,252 @@ def get_trainset(): "tokenizer": tokenizer, } +# -----------------------------------------------------------------------------------------------# +# Dataloader for directly streaming images from urls --------------------------------------------# +# -----------------------------------------------------------------------------------------------# -##################################################################################################################### -############################################### Training Pipeline ################################################### -##################################################################################################################### +import albumentations as A +from flaxdiff.data.online_loader import OnlineStreamingDataLoader, dataMapper, \ + default_collate, load_dataset, concatenate_datasets, \ + ImageBatchIterator, default_image_processor, load_from_disk -def _build_global_shape_and_sharding( - local_shape: tuple[int, ...], global_mesh: Mesh -) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]: - sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names)) - global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:] - return global_shape, sharding +import threading +import queue +def default_image_processor( + image, image_shape, + min_image_shape=(128, 128), + upscale_interpolation=cv2.INTER_CUBIC, + downscale_interpolation=cv2.INTER_AREA, +): + try: + image = np.array(image) + if len(image.shape) != 3 or image.shape[2] != 3: + return None, 0, 0 + original_height, original_width = image.shape[:2] + # check if the image is too small + if min(original_height, original_width) < min(min_image_shape): + return None, original_height, original_width + # check if wrong aspect ratio + if max(original_height, original_width) / min(original_height, original_width) > 2.4: + return None, original_height, original_width + # check if the variance is too low + if np.std(image) < 1e-5: + return None, original_height, original_width + # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + downscale = max(original_width, original_height) > max(image_shape) + interpolation = downscale_interpolation if downscale else upscale_interpolation + + image = A.longest_max_size(image, max( + image_shape), interpolation=interpolation) + image = A.pad( + image, + min_height=image_shape[0], + min_width=image_shape[1], + border_mode=cv2.BORDER_CONSTANT, + value=[255, 255, 255], + ) + return image, original_height, original_width + except Exception as e: + # print("Error processing image", e, image_shape, interpolation) + # traceback.print_exc() + return None, 0, 0 -def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array: - """Put local sharded array into local devices""" - global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh) - try: - local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0) - except ValueError as array_split_error: - raise ValueError( - f"Unable to put to devices shape {array.shape} with " - f"local device count {len(global_mesh.local_devices)} " - ) from array_split_error - local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices) - return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers) -def convert_to_global_tree(global_mesh, pytree): - return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree) +class OnlineStreamingDataLoader(): + def __init__( + self, + dataset, + batch_size=64, + image_shape=(256, 256), + min_image_shape=(128, 128), + num_workers=16, + num_threads=512, + default_split="all", + pre_map_maker=dataMapper, + pre_map_def={ + "url": "URL", + "caption": "TEXT", + }, + global_process_count=1, + global_process_index=0, + prefetch=1000, + collate_fn=default_collate, + timeout=15, + retries=3, + image_processor=default_image_processor, + upscale_interpolation=cv2.INTER_CUBIC, + downscale_interpolation=cv2.INTER_AREA, + ): + if isinstance(dataset, str): + dataset_path = dataset + print("Loading dataset from path") + if "gs://" in dataset: + dataset = load_from_disk(dataset_path) + else: + dataset = load_dataset(dataset_path, split=default_split) + elif isinstance(dataset, list): + if isinstance(dataset[0], str): + print("Loading multiple datasets from paths") + dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset( + dataset_path, split=default_split) for dataset_path in dataset] + print("Concatenating multiple datasets") + dataset = concatenate_datasets(dataset) + dataset = dataset.shuffle(seed=0) + # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000) + self.dataset = dataset.shard( + num_shards=global_process_count, index=global_process_index) + print(f"Dataset length: {len(dataset)}") + self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, + min_image_shape=min_image_shape, + num_workers=num_workers, batch_size=batch_size, num_threads=num_threads, + timeout=timeout, retries=retries, image_processor=image_processor, + upscale_interpolation=upscale_interpolation, + downscale_interpolation=downscale_interpolation) + self.batch_size = batch_size + + # Launch a thread to load batches in the background + self.batch_queue = queue.Queue(prefetch) + + def batch_loader(): + for batch in self.iterator: + try: + self.batch_queue.put(collate_fn(batch)) + except Exception as e: + print("Error collating batch", e) + + self.loader_thread = threading.Thread(target=batch_loader) + self.loader_thread.start() + + def __iter__(self): + return self + + def __next__(self): + return self.batch_queue.get() + # return self.collate_fn(next(self.iterator)) + + def __len__(self): + return len(self.dataset) + +onlineDatasetMap = { + "combined_online": { + "source": [ + # "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet" + # "ChristophSchuhmann/MS_COCO_2017_URL_TEXT", + # "dclure/laion-aesthetics-12m-umap", + "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017", + # "gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M", + "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m", + "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m", + "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m", + "gs://flaxdiff-datasets-regional/datasets/cc12m", + "gs://flaxdiff-datasets-regional/datasets/cc3m", + "gs://flaxdiff-datasets-regional/datasets/playground-liked", + "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m", + "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m", + "gs://flaxdiff-datasets-regional/datasets/cc3m", + "gs://flaxdiff-datasets-regional/datasets/cc3m", + ] + } +} + +def generate_collate_fn(tokenizer): + caption_processor = CaptionProcessor(tensor_type="np") + def default_collate(batch): + try: + # urls = [sample["url"] for sample in batch] + captions = [sample["caption"] for sample in batch] + results = caption_processor(captions) + images = np.stack([sample["image"] for sample in batch], axis=0) + return { + "image": images, + "input_ids": results['input_ids'], + "attention_mask": results['attention_mask'], + } + except Exception as e: + print("Error in collate function", e, [sample["image"].shape for sample in batch]) + traceback.print_exc() + + return default_collate + +def get_dataset_online( + data_name="combined_online", + batch_size=64, + image_scale=256, + count=None, + num_epochs=None, + method=jax.image.ResizeMethod.LANCZOS3, + worker_count=32, + read_thread_count=64, + read_buffer_size=50, + worker_buffer_size=20, + seed=0, + dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/", + ): + local_batch_size = batch_size // jax.process_count() + + model, tokenizer = defaultTextEncodeModel() + + null_labels, null_labels_full = encodePrompts([""], model, tokenizer) + null_labels = np.array(null_labels[0], dtype=np.float16) + null_labels_full = np.array(null_labels_full[0], dtype=np.float16) + + sources = onlineDatasetMap[data_name]["source"] + dataloader = OnlineStreamingDataLoader( + sources, + batch_size=local_batch_size, + num_workers=worker_count, + num_threads=read_thread_count, + image_shape=(image_scale, image_scale), + global_process_count=jax.process_count(), + global_process_index=jax.process_index(), + prefetch=worker_buffer_size, + collate_fn=generate_collate_fn(tokenizer), + default_split="train", + ) + + def get_trainset(mesh: Mesh = None): + if mesh != None: + class dataLoaderWithMesh: + def __init__(self, dataloader, mesh): + self.dataloader = dataloader + self.mesh = mesh + self.tmp_queue = queue.Queue(worker_buffer_size) + def batch_loader(): + for batch in self.dataloader: + try: + self.tmp_queue.put(convert_to_global_tree(mesh, batch)) + except Exception as e: + print("Error processing batch", e) + self.loader_thread = threading.Thread(target=batch_loader) + self.loader_thread.start() + + def __iter__(self): + return self + + def __next__(self): + return self.tmp_queue.get() + + dataloader_with_mesh = dataLoaderWithMesh(dataloader, mesh) + + return dataloader_with_mesh + return dataloader + + return { + "train": get_trainset, + "train_len": len(dataloader) * jax.process_count(), + "local_batch_size": local_batch_size, + "global_batch_size": batch_size, + "null_labels": null_labels, + "null_labels_full": null_labels_full, + "model": model, + "tokenizer": tokenizer, + } + + +##################################################################################################################### +############################################### Training Pipeline ################################################### +##################################################################################################################### @struct.dataclass class Metrics(metrics.Collection): @@ -550,16 +812,16 @@ def init_state( self.best_state = best_state def get_state(self): - # return fully_replicated_host_local_array_to_global_array() - return jax.tree_util.tree_map(lambda x : np.array(x), self.state) + return self.get_np_tree(self.state) def get_best_state(self): - # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.best_state, jax.local_devices())) - return jax.tree_util.tree_map(lambda x : np.array(x), self.best_state) + return self.get_np_tree(self.best_state) def get_rngstate(self): - # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.rngstate, jax.local_devices())) - return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate) + return self.get_np_tree(self.rngstate) + + def get_np_tree(self, pytree): + return jax.tree_util.tree_map(lambda x : np.array(x), pytree) def checkpoint_path(self): path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower()) @@ -596,29 +858,35 @@ def load(self, checkpoint_path=None, checkpoint_step=None): rngstate = ckpt['rngs'] # Convert the state to a TrainState self.best_loss = ckpt['best_loss'] + if self.best_loss == 0: + # It cant be zero as that must have been some problem + self.best_loss = 1e9 current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps print( f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss']) return current_epoch, step, state, best_state, rngstate - def save(self, epoch=0, step=0): + def save(self, epoch=0, step=0, state=None, rngstate=None): print(f"Saving model at epoch {epoch} step {step}") - ckpt = { - # 'model': self.model, - 'rngs': self.get_rngstate(), - 'state': self.get_state(), - 'best_state': self.get_best_state(), - 'best_loss': np.array(self.best_loss), - 'epoch': epoch, - } try: - save_args = orbax_utils.save_args_from_target(ckpt) - self.checkpointer.save(step, ckpt, save_kwargs={ - 'save_args': save_args}, force=True) - self.checkpointer.wait_until_finished() - pass + ckpt = { + # 'model': self.model, + 'rngs': self.get_rngstate() if rngstate is None else self.get_np_tree(rngstate), + 'state': self.get_state() if state is None else self.get_np_tree(state), + 'best_state': self.get_best_state(), + 'best_loss': np.array(self.best_loss), + 'epoch': epoch, + } + try: + save_args = orbax_utils.save_args_from_target(ckpt) + self.checkpointer.save(step, ckpt, save_kwargs={ + 'save_args': save_args}, force=True) + self.checkpointer.wait_until_finished() + pass + except Exception as e: + print("Error saving checkpoint", e) except Exception as e: - print("Error saving checkpoint", e) + print("Error saving checkpoint outer", e) def _define_train_step(self, **kwargs): model = self.model @@ -711,19 +979,25 @@ def train_loop(current_step, pbar: tqdm.tqdm, train_state, rng_state): last_save_time = time.time() for i in range(steps_per_epoch): batch = next(train_ds) + if i == 0: + print(f"First batch loaded at step {current_step}") + if self.distributed_training and global_device_count > 1: - # Convert the local device batches to a unified global jax.Array + # # Convert the local device batches to a unified global jax.Array batch = convert_to_global_tree(self.mesh, batch) train_state, loss, rng_state = train_step(train_state, rng_state, batch, global_device_indexes) + if i == 0: + print(f"Training started for process index {process_index} at step {current_step}") + if self.distributed_training: - loss = jax.experimental.multihost_utils.process_allgather(loss) + # loss = jax.experimental.multihost_utils.process_allgather(loss) loss = jnp.mean(loss) # Just to make sure its a scaler value if loss <= 1e-6: # If the loss is too low, we can assume the model has diverged print(colored(f"Loss too low at step {current_step} => {loss}", 'red')) - # Exit the training loop + # Reset the model to the old state exit(1) epoch_loss += loss @@ -737,10 +1011,12 @@ def train_loop(current_step, pbar: tqdm.tqdm, train_state, rng_state): "train/step" : current_step, "train/loss": loss, }, step=current_step) - # Save the model every 40 minutes - if time.time() - last_save_time > 40 * 60: - print(f"Saving model after 40 minutes at step {current_step}") - self.save(current_epoch, current_step) + # Save the model every few steps + if i % 10000 == 0 and i > 0: + print(f"Saving model after 10000 step {current_step}") + print(f"Devices: {len(jax.devices())}") # To sync the devices + self.save(current_epoch, current_step, train_state, rng_state) + print(f"Saving done by process index {process_index}") last_save_time = time.time() print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green')) return epoch_loss, current_step, train_state, rng_state @@ -897,6 +1173,11 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc local_rng_state = RandomMarkovState(subkey) images = batch['image'] + + # First get the standard deviation of the images + # std = jnp.std(images, axis=(1, 2, 3)) + # is_non_zero = (std > 0) + images = jnp.array(images, dtype=jnp.float32) # normalize image images = (images - 127.5) / 127.5 @@ -929,8 +1210,11 @@ def model_loss(params): preds = model_output_transform.pred_transform( noisy_images, preds, rates) nloss = loss_fn(preds, expected_output) - # nloss = jnp.mean(nloss, axis=1) + # Ignore the loss contribution of images with zero standard deviation nloss *= noise_schedule.get_weights(noise_level) + # nloss = jnp.mean(nloss, axis=(1,2,3)) + # nloss = jnp.where(is_non_zero, nloss, 0) + # nloss = jnp.mean(nloss, where=nloss != 0) nloss = jnp.mean(nloss) loss = nloss return loss @@ -951,11 +1235,11 @@ def model_loss(params): new_state = train_state.apply_gradients(grads=grads) - if train_state.dynamic_scale: + if train_state.dynamic_scale is not None: # if is_fin == False the gradients contain Inf/NaNs and optimizer state and # params should be restored (= skip this step). select_fn = functools.partial(jnp.where, is_fin) - new_state = train_state.replace( + new_state = new_state.replace( opt_state=jax.tree_util.tree_map( select_fn, new_state.opt_state, train_state.opt_state ), @@ -1003,13 +1287,21 @@ def boolean_string(s): # Parse command-line arguments parser = argparse.ArgumentParser(description='Train a diffusion model') parser.add_argument('--GRAIN_WORKER_COUNT', type=int, - default=16, help='Number of grain workers') + default=32, help='Number of grain workers') +# parser.add_argument('--GRAIN_READ_THREAD_COUNT', type=int, +# default=512, help='Number of grain read threads') +# parser.add_argument('--GRAIN_READ_BUFFER_SIZE', type=int, +# default=80, help='Grain read buffer size') +# parser.add_argument('--GRAIN_WORKER_BUFFER_SIZE', type=int, +# default=500, help='Grain worker buffer size') +# parser.add_argument('--GRAIN_WORKER_COUNT', type=int, +# default=32, help='Number of grain workers') parser.add_argument('--GRAIN_READ_THREAD_COUNT', type=int, - default=64, help='Number of grain read threads') + default=128, help='Number of grain read threads') parser.add_argument('--GRAIN_READ_BUFFER_SIZE', type=int, - default=50, help='Grain read buffer size') + default=80, help='Grain read buffer size') parser.add_argument('--GRAIN_WORKER_BUFFER_SIZE', type=int, - default=20, help='Grain worker buffer size') + default=50, help='Grain worker buffer size') parser.add_argument('--batch_size', type=int, default=64, help='Batch size') parser.add_argument('--image_size', type=int, default=128, help='Image size') @@ -1030,6 +1322,11 @@ def boolean_string(s): parser.add_argument('--flash_attention', type=boolean_string, default=False, help='Use Flash Attention') parser.add_argument('--use_projection', type=boolean_string, default=False, help='Use projection') parser.add_argument('--use_self_and_cross', type=boolean_string, default=False, help='Use self and cross attention') +parser.add_argument('--only_pure_attention', type=boolean_string, default=True, help='Use only pure attention or proper transformer in the attention blocks') +parser.add_argument('--norm_groups', type=int, default=8, help='Number of normalization groups. 0 for RMSNorm') + +parser.add_argument('--named_norms', type=boolean_string, default=False, help='Use named norms') + parser.add_argument('--num_res_blocks', type=int, default=2, help='Number of residual blocks') parser.add_argument('--num_middle_res_blocks', type=int, default=1, help='Number of middle residual blocks') parser.add_argument('--activation', type=str, default='swish', help='activation to use') @@ -1066,6 +1363,7 @@ def boolean_string(s): default='{"modelname":"CompVis/stable-diffusion-v1-4"}', help='Autoencoder options as a dictionary') parser.add_argument('--use_dynamic_scale', type=boolean_string, default=False, help='Use dynamic scale for training') +parser.add_argument('--clip_grads', type=float, default=0, help='Clip gradients to this value') def main(args): resource.setrlimit( @@ -1076,6 +1374,7 @@ def main(args): resource.RLIMIT_OFILE, (65535, 65535)) + print("Initializing JAX") jax.distributed.initialize() # jax.config.update('jax_threefry_partitionable', True) @@ -1124,7 +1423,31 @@ def main(args): IMAGE_SIZE = args.image_size dataset_name = args.dataset - datalen = len(datasetMap[dataset_name]['source'](args.dataset_path)) + + if 'online' in dataset_name: + print("Using Online Dataset Generator") + dataset_generator = get_dataset_online + GRAIN_WORKER_BUFFER_SIZE *= 10 + GRAIN_READ_THREAD_COUNT *= 4 + else: + dataset_generator = get_dataset_grain + + data = dataset_generator( + args.dataset, + batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE, + worker_count=GRAIN_WORKER_COUNT, read_thread_count=GRAIN_READ_THREAD_COUNT, + read_buffer_size=GRAIN_READ_BUFFER_SIZE, worker_buffer_size=GRAIN_WORKER_BUFFER_SIZE, + seed=args.dataset_seed, + dataset_source=args.dataset_path, + ) + + if args.dataset_test: + dataset = iter(data['train']()) + + for _ in tqdm.tqdm(range(2000)): + batch = next(dataset) + + datalen = data['train_len'] batches = datalen // BATCH_SIZE # Define the configuration using the command-line arguments attention_configs = [ @@ -1133,12 +1456,18 @@ def main(args): if args.attention_heads > 0: attention_configs += [ - {"heads": args.attention_heads, "dtype": DTYPE, "flash_attention": args.flash_attention, - "use_projection": args.use_projection, "use_self_and_cross": args.use_self_and_cross}, + { + "heads": args.attention_heads, "dtype": DTYPE, "flash_attention": args.flash_attention, + "use_projection": args.use_projection, "use_self_and_cross": args.use_self_and_cross, + "only_pure_attention": args.only_pure_attention, + }, ] * (len(args.feature_depths) - 2) attention_configs += [ - {"heads": args.attention_heads, "dtype": DTYPE, "flash_attention": False, - "use_projection": False, "use_self_and_cross": False}, + { + "heads": args.attention_heads, "dtype": DTYPE, "flash_attention": False, + "use_projection": False, "use_self_and_cross": args.use_self_and_cross, + "only_pure_attention": args.only_pure_attention + }, ] else: print("Attention heads not provided, disabling attention") @@ -1169,6 +1498,8 @@ def main(args): "precision": PRECISION, "activation": args.activation, "output_channels": INPUT_CHANNELS, + "norm_groups": args.norm_groups, + "named_norms": args.named_norms, }, "dataset": { "name": dataset_name, @@ -1188,24 +1519,6 @@ def main(args): "autoencoder_opts": args.autoencoder_opts, } - text_encoders = defaultTextEncodeModel() - - data = get_dataset_grain( - CONFIG['dataset']['name'], - batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE, - grain_worker_count=GRAIN_WORKER_COUNT, grain_read_thread_count=GRAIN_READ_THREAD_COUNT, - grain_read_buffer_size=GRAIN_READ_BUFFER_SIZE, grain_worker_buffer_size=GRAIN_WORKER_BUFFER_SIZE, - text_encoders=text_encoders, - seed=args.dataset_seed, - dataset_source=args.dataset_path, - ) - - if args.dataset_test: - dataset = iter(data['train']()) - - for _ in tqdm.tqdm(range(3000)): - batch = next(dataset) - cosine_schedule = CosineNoiseSchedule(1000, beta_end=1) karas_ve_schedule = KarrasVENoiseScheduler( 1, sigma_max=80, rho=7, sigma_data=0.5) @@ -1235,6 +1548,12 @@ def main(args): decay_steps=batches * args.learning_rate_decay_epochs, end_value=args.learning_rate_end, ) solver = optimizer(learning_rate, **optimizer_opts) + + if args.clip_grads > 0: + solver = optax.chain( + optax.clip_by_global_norm(args.clip_grads), + solver, + ) wandb_config = { "project": "flaxdiff", @@ -1277,54 +1596,74 @@ def main(args): for tpu-v4-32 -python3 training.py --dataset=combined_30m --dataset_path='/home/mrwhite0racle/gcs_mount/'\ +python3 training.py --dataset=combined_online --dataset_path='/home/mrwhite0racle/gcs_mount/'\ --checkpoint_dir='flaxdiff-datasets-regional/checkpoints/' --checkpoint_fs='gcs'\ - --epochs=40 --batch_size=256 --image_size=128 \ + --epochs=40 --batch_size=256 --image_size=512 \ --learning_rate=9e-5 --num_res_blocks=3 --emb_features 512 \ --use_self_and_cross=False --precision=default --dtype=bfloat16 --attention_heads=16\ - --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-32_NEW_combined_30m_1'\ - --optimizer=adamw --feature_depths 128 256 512 512 --use_dynamic_scale=True \ - --load_from_checkpoint='gs://flaxdiff-datasets-regional/checkpoints/dataset-combined_aesthetic/image_size-128/batch-256-v4-32_flaxdiff-0-1-8__new-combined_1' + --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-32_ldm_data-online_big'\ + --optimizer=adamw --feature_depths 128 256 512 512 --autoencoder=stable_diffusion \ + --norm_groups 0 --clip_grads 0.5 --only_pure_attention=True +python3 training.py --dataset=combined_online --dataset_path='/home/mrwhite0racle/gcs_mount/'\ + --checkpoint_dir='flaxdiff-datasets-regional/checkpoints/' --checkpoint_fs='gcs'\ + --epochs=40 --batch_size=256 --image_size=128 \ + --learning_rate=1e-4 --num_res_blocks=3 --emb_features 512 \ + --use_self_and_cross=False --precision=default --dtype=bfloat16 --attention_heads=16\ + --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-32_data-online'\ + --optimizer=adamw --feature_depths 128 256 512 512 \ + --norm_groups 0 --clip_grads 0.5 --only_pure_attention=True + for tpu-v4-16 python3 training.py --dataset=combined_30m --dataset_path='/home/mrwhite0racle/gcs_mount/'\ --checkpoint_dir='flaxdiff-datasets-regional/checkpoints/' --checkpoint_fs='gcs'\ - --epochs=40 --batch_size=128 --image_size=512 \ - --learning_rate=8e-5 --num_res_blocks=3 \ - --use_self_and_cross=False --precision=default --attention_heads=16\ - --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-16_flaxdiff-0-1-8__combined_30m_ldm_1'\ - --learning_rate_schedule=cosine --learning_rate_peak=1e-4 --learning_rate_end=4e-5 --learning_rate_warmup_steps=5000 --learning_rate_decay_epochs=1\ - --optimizer=adamw --autoencoder=stable_diffusion --use_dynamic_scale=True\ - --load_from_checkpoint='gs://flaxdiff-datasets-regional/checkpoints/dataset-combined_aesthetic/image_size-512/batch-128-v4-16_flaxdiff-0-1-8_new-combined_ldm_1' + --epochs=40 --batch_size=128 --image_size=128 \ + --learning_rate=4e-5 --num_res_blocks=3 \ + --use_self_and_cross=False --dtype=bfloat16 --precision=default --attention_heads=8\ + --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-16_flaxdiff-0-1-9_light_combined_30m_1'\ + --optimizer=adamw --use_dynamic_scale=True --norm_groups 0 --only_pure_attention=False \ + --load_from_checkpoint='gs://flaxdiff-datasets-regional/checkpoints/dataset-combined_30m/image_size-128/batch-128-v4-16_flaxdiff-0-1-9_light_combined_30m_ldm_1' ---------------------------------------------------------------------------------------------------------------------------- Old --> for tpu-v4-64 -python3 training.py --dataset=combined_aesthetic --dataset_path='/home/mrwhite0racle/gcs_mount/'\ +python3 training.py --dataset=combined_online --dataset_path='/home/mrwhite0racle/gcs_mount/'\ --checkpoint_dir='flaxdiff-datasets-regional/checkpoints/' --checkpoint_fs='gcs'\ - --epochs=40 --batch_size=512 --image_size=512 \ - --learning_rate=9e-6 --num_res_blocks=4 \ - --use_self_and_cross=False --dtype=bfloat16 --precision=default --attention_heads=16\ - --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-64_flaxdiff-0-1-8_ldm_dyn_scale__new-combined_1'\ - --learning_rate_schedule=cosine --learning_rate_peak=4e-5 --learning_rate_end=9e-6 --learning_rate_warmup_steps=5000 --learning_rate_decay_epochs=2\ - --optimizer=adamw --autoencoder=stable_diffusion --use_dynamic_scale=True --feature_depths 128 256 512 512\ - --emb_features 512 --load_from_checkpoint='gs://flaxdiff-datasets-regional/checkpoints/dataset-combined_aesthetic/image_size-512/batch-512-v4-64_flaxdiff-0-1-8_ldm_dyn_scale_1' + --epochs=40 --batch_size=512 --image_size=512 --learning_rate=4e-5 \ + --num_res_blocks=4 --emb_features 512 --feature_depths 128 256 512 512 --norm_groups 0 --only_pure_attention=True --use_self_and_cross=False \ + --dtype=bfloat16 --precision=default --attention_heads=16\ + --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-64_ldm_combined_online-bigger'\ + --learning_rate_schedule=cosine --learning_rate_peak=2.7e-4 --learning_rate_end=9e-5 --learning_rate_warmup_steps=10000 --learning_rate_decay_epochs=2\ + --optimizer=adamw --autoencoder=stable_diffusion --clip_grads 0.5 + + + --load_from_checkpoint='gs://flaxdiff-datasets-regional/checkpoints/dataset-combined_30m/image_size-512/batch-512-v4-64_flaxdiff-0-1-8_ldm_dyn_scale_NEW_ARCH_combined_30' --learning_rate_schedule=cosine --learning_rate_peak=4e-5 --learning_rate_end=9e-6 --learning_rate_warmup_steps=5000 --learning_rate_decay_epochs=2\ + + +python3 training.py --dataset=combined_30m --dataset_path='/home/mrwhite0racle/gcs_mount/'\ + --checkpoint_dir='flaxdiff-datasets-regional/checkpoints/' --checkpoint_fs='gcs'\ + --epochs=40 --batch_size=256 --image_size=128 \ + --learning_rate=4e-5 --num_res_blocks=3 \ + --use_self_and_cross=False --precision=default --dtype=bfloat16 --attention_heads=16\ + --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-64_flaxdiff-0-1-10__new-combined_30m'\ + --optimizer=adamw --feature_depths 128 256 512 512 --use_dynamic_scale=True\ + --load_from_checkpoint='gs://flaxdiff-datasets-regional/checkpoints/dataset-combined_aesthetic/image_size-128/batch-256-v4-32_flaxdiff-0-1-8__new-combined_1' for tpu-v4-32 -python3 training.py --dataset=combined_aesthetic --dataset_path='/home/mrwhite0racle/gcs_mount/'\ +python3 training.py --dataset=combined_30m --dataset_path='/home/mrwhite0racle/gcs_mount/'\ --checkpoint_dir='flaxdiff-datasets-regional/checkpoints/' --checkpoint_fs='gcs'\ --epochs=40 --batch_size=256 --image_size=128 \ --learning_rate=8e-5 --num_res_blocks=3 \ --use_self_and_cross=False --precision=default --dtype=bfloat16 --attention_heads=16\ - --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-32_flaxdiff-0-1-8__new-combined_1'\ - --optimizer=adamw --feature_depths 128 256 512 512 --use_dynamic_scale=True\ + --experiment_name='dataset-{dataset}/image_size-{image_size}/batch-{batch_size}-v4-32_flaxdiff-0-1-9_combined_30m'\ + --optimizer=adamw --feature_depths 128 256 512 512 --use_dynamic_scale=True --named_norms=True --only_pure_attention=True\ --load_from_checkpoint='gs://flaxdiff-datasets-regional/checkpoints/dataset-combined_aesthetic/image_size-128/batch-256-v4-32_flaxdiff-0-1-8__3' for tpu-v4-16 @@ -1338,4 +1677,4 @@ def main(args): --learning_rate_schedule=cosine --learning_rate_peak=1e-4 --learning_rate_end=4e-5 --learning_rate_warmup_steps=5000 --learning_rate_decay_epochs=1\ --optimizer=adamw --autoencoder=stable_diffusion --use_dynamic_scale=True\ --load_from_checkpoint='gs://flaxdiff-datasets-regional/checkpoints/dataset-combined_aesthetic/image_size-512/batch-128-v4-16_flaxdiff-0-1-8__ldm_1' -""" \ No newline at end of file +"""