diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 21a96a1e2..b133660e4 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -14,6 +14,7 @@ from keras_hub.src.layers.modeling.reversible_embedding import ( ReversibleEmbedding, ) +from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding from keras_hub.src.layers.modeling.sine_position_encoding import ( SinePositionEncoding, diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 70585cec1..ba707c67e 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -162,6 +162,11 @@ ) from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_hub.src.models.flux.flux_model import FluxBackbone +from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage +from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( + FluxTextToImagePreprocessor, +) from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import ( diff --git a/keras_hub/src/layers/modeling/rms_normalization.py b/keras_hub/src/layers/modeling/rms_normalization.py new file mode 100644 index 000000000..f5d9f6929 --- /dev/null +++ b/keras_hub/src/layers/modeling/rms_normalization.py @@ -0,0 +1,34 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export + + +@keras_hub_export("keras_hub.layers.RMSNormalization") +class RMSNormalization(keras.layers.Layer): + """ + Root Mean Square (RMS) Normalization layer. + This layer normalizes the input tensor based on its RMS value and applies + a learned scaling factor. + Args: + input_dim: int. The dimensionality of the input tensor. + """ + + def __init__(self, input_dim): + super().__init__() + self.scale = self.add_weight( + name="scale", shape=(input_dim,), initializer="ones" + ) + + def call(self, x): + """ + Applies RMS normalization to the input tensor. + Args: + x: KerasTensor. Input tensor of shape (batch_size, input_dim). + Returns: + KerasTensor: The RMS-normalized tensor of the same shape (batch_size, input_dim), + scaled by the learned `scale` parameter. 
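+
+        A minimal usage sketch (shapes are illustrative):
+
+        ```python
+        layer = RMSNormalization(input_dim=8)
+        outputs = layer(keras.random.normal((2, 8)))  # shape: (2, 8)
+        ```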
+ """ + x = ops.cast(x, float) + rrms = ops.rsqrt(ops.mean(ops.square(x), axis=-1, keepdims=True) + 1e-6) + return (x * rrms) * self.scale diff --git a/keras_hub/src/models/flux/__init__.py b/keras_hub/src/models/flux/__init__.py new file mode 100644 index 000000000..02dffc1c4 --- /dev/null +++ b/keras_hub/src/models/flux/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.flux.flux_model import FluxBackbone +from keras_hub.src.models.flux.flux_presets import presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(presets, FluxBackbone) diff --git a/keras_hub/src/models/flux/flux_backbone_test.py b/keras_hub/src/models/flux/flux_backbone_test.py new file mode 100644 index 000000000..5a15e3b7f --- /dev/null +++ b/keras_hub/src/models/flux/flux_backbone_test.py @@ -0,0 +1,73 @@ +import pytest +from keras import ops + +from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder +from keras_hub.src.models.flux.flux_model import FluxBackbone +from keras_hub.src.models.vae.vae_backbone import VAEBackbone +from keras_hub.src.tests.test_case import TestCase + + +class FluxBackboneTest(TestCase): + def setUp(self): + vae = VAEBackbone( + [32, 32, 32, 32], + [1, 1, 1, 1], + [32, 32, 32, 32], + [1, 1, 1, 1], + # Use `mode` generate a deterministic output. + sampler_method="mode", + name="vae", + ) + clip_l = CLIPTextEncoder( + 20, 32, 32, 2, 2, 64, "quick_gelu", -2, name="clip_l" + ) + self.init_kwargs = { + "input_channels": 256, + "hidden_size": 1024, + "mlp_ratio": 2.0, + "num_heads": 8, + "depth": 4, + "depth_single_blocks": 8, + "axes_dim": [16, 56, 56], + "theta": 10_000, + "use_bias": True, + "guidance_embed": True, + "image_shape": (32, 256), + "text_shape": (32, 256), + "image_ids_shape": (32, 3), + "text_ids_shape": (32, 3), + "y_shape": (256,), + } + + self.pipeline_models = { + "vae": vae, + "clip_l": clip_l, + } + + self.input_data = { + "image": ops.ones((1, 32, 256)), + "image_ids": ops.ones((1, 32, 3)), + "text": ops.ones((1, 32, 256)), + "text_ids": ops.ones((1, 32, 3)), + "y": ops.ones((1, 256)), + "timesteps": ops.ones((1)), + "guidance": ops.ones((1)), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=FluxBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(1, 32, 256), + run_mixed_precision_check=False, + run_quantization_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=FluxBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/flux/flux_layers.py b/keras_hub/src/models/flux/flux_layers.py new file mode 100644 index 000000000..2dfd4e339 --- /dev/null +++ b/keras_hub/src/models/flux/flux_layers.py @@ -0,0 +1,494 @@ +import keras +from keras import layers +from keras import ops + +from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization +from keras_hub.src.models.flux.flux_maths import FluxRoPEAttention +from keras_hub.src.models.flux.flux_maths import RotaryPositionalEmbedding +from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors + + +class EmbedND(keras.Model): + """ + Embedding layer for N-dimensional inputs using Rotary Positional Embedding (RoPE). + + This layer applies RoPE embeddings across multiple axes of the input tensor and + concatenates the embeddings along a specified axis. + + Args: + theta. Rotational angle parameter for RoPE. + axes_dim. Dimensionality for each axis of the input tensor. 
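+
+    A minimal usage sketch (shapes mirror the backbone unit test and are
+    illustrative; the last axis of `ids` must have one entry per element of
+    `axes_dim`):
+
+    ```python
+    embedder = EmbedND(theta=10_000, axes_dim=[16, 56, 56])
+    ids = keras.ops.ones((1, 32, 3))
+    positional_encoding = embedder(ids)
+    ```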
+ """ + + def __init__(self, theta, axes_dim): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.rope = RotaryPositionalEmbedding() + + def build(self, input_shape): + n_axes = input_shape[-1] + for i in range(n_axes): + self.rope.build((input_shape[:-1] + (self.axes_dim[i],))) + + def call(self, ids): + """ + Computes the positional embeddings for each axis and concatenates them. + + Args: + ids: KerasTensor. Input tensor of shape (..., num_axes). + + Returns: + KerasTensor: Positional embeddings of shape (..., concatenated_dim, 1, ...). + """ + n_axes = ids.shape[-1] + emb = ops.concatenate( + [ + self.rope(ids[..., i], dim=self.axes_dim[i], theta=self.theta) + for i in range(n_axes) + ], + axis=-3, + ) + + return ops.expand_dims(emb, axis=1) + + +class MLPEmbedder(keras.Model): + """ + A simple multi-layer perceptron (MLP) embedder model. + + This model applies a linear transformation followed by the SiLU activation + function and another linear transformation to the input tensor. + + Args: + hidden_dim. The dimensionality of the hidden layer. + """ + + def __init__(self, hidden_dim): + super().__init__() + self.hidden_dim = hidden_dim + self.input_layer = layers.Dense(hidden_dim, use_bias=True) + self.silu = layers.Activation("silu") + self.output_layer = layers.Dense(hidden_dim, use_bias=True) + + def build(self, input_shape): + self.input_layer.build(input_shape) + self.output_layer.build((input_shape[0], self.input_layer.units)) + + def call(self, x): + """ + Applies the MLP embedding to the input tensor. + + Args: + x: KerasTensor. Input tensor of shape (batch_size, in_dim). + + Returns: + KerasTensor: Output tensor of shape (batch_size, hidden_dim) after applying + the MLP transformations. + """ + x = self.input_layer(x) + x = self.silu(x) + return self.output_layer(x) + + +class QKNorm(keras.layers.Layer): + """ + A layer that applies RMS normalization to query and key tensors. + + This layer normalizes the input query and key tensors using separate RMSNormalization + layers for each. + + Args: + input_dim. The dimensionality of the input query and key tensors. + """ + + def __init__(self, input_dim): + super().__init__() + self.query_norm = RMSNormalization(input_dim) + self.key_norm = RMSNormalization(input_dim) + + def build(self, input_shape): + self.query_norm.build(input_shape) + self.key_norm.build(input_shape) + + def call(self, q, k): + """ + Applies RMS normalization to the query and key tensors. + + Args: + q: KerasTensor. The query tensor of shape (batch_size, input_dim). + k: KerasTensor. The key tensor of shape (batch_size, input_dim). + + Returns: + tuple[KerasTensor, KerasTensor]: A tuple containing the normalized query and key tensors. + """ + q = self.query_norm(q) + k = self.key_norm(k) + return q, k + + +class SelfAttention(keras.Model): + """ + Multi-head self-attention layer with RoPE embeddings and RMS normalization. + + This layer performs self-attention over the input sequence and applies RMS + normalization to the query and key tensors before computing the attention scores. + + Args: + dim: int. Dimensionality of the input tensor. + num_heads: int. Number of attention heads. Default is 8. + use_bias: bool. Whether to use bias in the query, key, value projection layers. + Default is False. 
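+
+    A minimal usage sketch (dimensions are illustrative; the positional
+    encoding would normally come from `EmbedND`):
+
+    ```python
+    attention = SelfAttention(dim=1024, num_heads=8, use_bias=True)
+    # x: (batch_size, seq_len, dim)
+    # output = attention(x, positional_encoding)
+    ```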
+ """ + + def __init__(self, dim, num_heads=8, use_bias=False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.dim = dim + + self.qkv = layers.Dense(dim * 3, use_bias=use_bias) + self.norm = QKNorm(head_dim) + self.proj = layers.Dense(dim) + self.attention = FluxRoPEAttention() + + def build(self, input_shape): + self.qkv.build(input_shape) + head_dim = input_shape[-1] // self.num_heads + self.norm.build((None, input_shape[1], head_dim)) + self.proj.build((None, input_shape[1], input_shape[-1])) + + def call(self, x, positional_encoding): + """ + Applies self-attention with RoPE embeddings. + + Args: + x: KerasTensor. Input tensor of shape (batch_size, seq_len, dim). + positional_encoding: KerasTensor. Positional encoding tensor for RoPE. + + Returns: + KerasTensor: Output tensor after self-attention and projection. + """ + qkv = self.qkv(x) + q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=self.num_heads) + q, k = self.norm(q, k) + x = self.attention( + q=q, k=k, v=v, positional_encoding=positional_encoding + ) + x = self.proj(x) + return x + + +class Modulation(keras.Model): + """ + Modulation layer that produces shift, scale, and gate tensors. + + This layer applies a SiLU activation to the input tensor followed by a linear + transformation to generate modulation parameters. It can optionally generate two + sets of modulation parameters. + + Args: + dim: int. Dimensionality of the modulation output. + double: bool. Whether to generate two sets of modulation parameters. + """ + + def __init__(self, dim, double): + super().__init__() + self.dim = dim + self.is_double = double + self.multiplier = 6 if double else 3 + self.linear_projection = keras.layers.Dense( + self.multiplier * dim, use_bias=True + ) + + def build(self, input_shape): + self.linear_projection.build(input_shape) + + def call(self, x): + """ + Generates modulation parameters from the input tensor. + + Args: + x: KerasTensor. Input tensor. + + Returns: + tuple[ModulationOut, ModulationOut | None]: A tuple containing the shift, + scale, and gate tensors. If `double` is True, returns two sets of modulation parameters. + """ + x = keras.layers.Activation("silu")(x) + out = self.linear_projection(x) + out = ops.split( + out[:, None, :], indices_or_sections=self.multiplier, axis=-1 + ) + + first_output = {"shift": out[0], "scale": out[1], "gate": out[2]} + second_output = ( + {"shift": out[3], "scale": out[4], "gate": out[5]} + if self.is_double + else None + ) + + return first_output, second_output + + +class DoubleStreamBlock(keras.Model): + """ + A block that processes image and text inputs in parallel using + self-attention and MLP layers, with modulation. + + Args: + hidden_size: int. The hidden dimension size for the model. + num_heads: int. The number of attention heads. + mlp_ratio: float. The ratio of the MLP hidden dimension to the hidden size. + use_bias: bool, optional. Whether to include bias in QKV projection. Default is False. 
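+
+    A minimal usage sketch (dimensions mirror the backbone unit test and are
+    illustrative):
+
+    ```python
+    block = DoubleStreamBlock(hidden_size=1024, num_heads=8, mlp_ratio=2.0)
+    # image, text: (batch_size, seq_len, hidden_size)
+    # image, text = block(image, text, modulation_encoding, positional_encoding)
+    ```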
+ """ + + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio, + use_bias=False, + ): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + + self.image_mod = Modulation(hidden_size, double=True) + self.image_norm1 = keras.layers.LayerNormalization(epsilon=1e-6) + self.image_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, use_bias=use_bias + ) + + self.image_norm2 = keras.layers.LayerNormalization(epsilon=1e-6) + self.image_mlp = keras.Sequential( + [ + keras.layers.Dense(mlp_hidden_dim, use_bias=True), + keras.layers.Activation("gelu"), + keras.layers.Dense(hidden_size, use_bias=True), + ] + ) + + self.text_mod = Modulation(hidden_size, double=True) + self.text_norm1 = keras.layers.LayerNormalization(epsilon=1e-6) + self.text_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, use_bias=use_bias + ) + + self.text_norm2 = keras.layers.LayerNormalization(epsilon=1e-6) + self.text_mlp = keras.Sequential( + [ + keras.layers.Dense(mlp_hidden_dim, use_bias=True), + keras.layers.Activation("gelu"), + keras.layers.Dense(hidden_size, use_bias=True), + ] + ) + self.attention = FluxRoPEAttention() + + def call(self, image, text, modulation_encoding, positional_encoding): + """ + Forward pass for the DoubleStreamBlock. + + Args: + image: KerasTensor. Input image tensor. + text: KerasTensor. Input text tensor. + modulation_encoding: KerasTensor. Modulation vector. + positional_encoding: KerasTensor. Positional encoding tensor. + + Returns: + Tuple[KerasTensor, KerasTensor]: The modified image and text tensors. + """ + image_mod1, image_mod2 = self.image_mod(modulation_encoding) + text_mod1, text_mod2 = self.text_mod(modulation_encoding) + + # prepare image for attention + image_modulated = self.image_norm1(image) + image_modulated = ( + 1 + image_mod1["scale"] + ) * image_modulated + image_mod1["shift"] + image_qkv = self.image_attn.qkv(image_modulated) + + image_q, image_k, image_v = rearrange_symbolic_tensors( + image_qkv, K=3, H=self.num_heads + ) + image_q, image_k = self.image_attn.norm(image_q, image_k) + + # prepare text for attention + text_modulated = self.text_norm1(text) + text_modulated = (1 + text_mod1["scale"]) * text_modulated + text_mod1[ + "shift" + ] + text_qkv = self.text_attn.qkv(text_modulated) + + text_q, text_k, text_v = rearrange_symbolic_tensors( + text_qkv, K=3, H=self.num_heads + ) + + text_q, text_k = self.text_attn.norm(text_q, text_k) + + # run actual attention + q = ops.concatenate((text_q, image_q), axis=2) + k = ops.concatenate((text_k, image_k), axis=2) + v = ops.concatenate((text_v, image_v), axis=2) + + attn = self.attention( + q=q, k=k, v=v, positional_encoding=positional_encoding + ) + text_attn, image_attn = ( + attn[:, : text.shape[1]], + attn[:, text.shape[1] :], + ) + + # calculate the image blocks + image = image + image_mod1["gate"] * self.image_attn.proj(image_attn) + image = image + image_mod2["gate"] * self.image_mlp( + (1 + image_mod2["scale"]) * self.image_norm2(image) + + image_mod2["shift"] + ) + + # calculate the text blocks + text = text + text_mod1["gate"] * self.text_attn.proj(text_attn) + text = text + text_mod2["gate"] * self.text_mlp( + (1 + text_mod2["scale"]) * self.text_norm2(text) + + text_mod2["shift"] + ) + return image, text + + +class SingleStreamBlock(keras.Model): + """ + A DiT block with parallel linear layers. + + As described in https://arxiv.org/abs/2302.05442 and + adapted for the modulation interface. 
+ + Args: + hidden_size: int. The hidden dimension size for the model. + num_heads: int. The number of attention heads. + mlp_ratio: float, optional. The ratio of the MLP hidden dimension to the hidden size. Default is 4.0. + qk_scale: float, optional. Scaling factor for the query-key product. Default is None. + """ + + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + qk_scale=None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = keras.layers.Dense(hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = keras.layers.Dense(hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = keras.layers.LayerNormalization(epsilon=1e-6) + self.modulation = Modulation(hidden_size, double=False) + self.attention = FluxRoPEAttention() + + def build( + self, x_shape, modulation_encoding_shape, positional_encoding_shape + ): + self.linear1.build(x_shape) + self.linear2.build( + (x_shape[0], x_shape[1], self.hidden_size + self.mlp_hidden_dim) + ) + + self.modulation.build( + modulation_encoding_shape + ) # Build the modulation layer + + self.norm.build( + ( + x_shape[0], + self.num_heads, + x_shape[1], + x_shape[-1] // self.num_heads, + ) + ) + + def call(self, x, modulation_encoding, positional_encoding): + """ + Forward pass for the SingleStreamBlock. + + Args: + x: KerasTensor. Input tensor. + modulation_encoding: KerasTensor. Modulation vector. + positional_encoding: KerasTensor. Positional encoding tensor. + + Returns: + KerasTensor: The modified input tensor after processing. + """ + mod, _ = self.modulation(modulation_encoding) + x_mod = (1 + mod["scale"]) * self.pre_norm(x) + mod["shift"] + qkv, mlp = ops.split( + self.linear1(x_mod), [3 * self.hidden_size], axis=-1 + ) + + q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=self.num_heads) + q, k = self.norm(q, k) + + # compute attention + attn = self.attention( + q, k=k, v=v, positional_encoding=positional_encoding + ) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2( + ops.concatenate( + (attn, keras.activations.gelu(mlp, approximate=True)), 2 + ) + ) + return x + mod["gate"] * output + + +class LastLayer(keras.Model): + """ + Final layer for processing output tensors with adaptive normalization. + + Args: + hidden_size: int. The hidden dimension size for the model. + patch_size: int. The size of each patch. + output_channels: int. The number of output channels. + """ + + def __init__(self, hidden_size, patch_size, output_channels): + super().__init__() + self.norm_final = keras.layers.LayerNormalization(epsilon=1e-6) + self.linear = keras.layers.Dense( + patch_size * patch_size * output_channels, use_bias=True + ) + self.adaLN_modulation = keras.Sequential( + [ + keras.layers.Activation("silu"), + keras.layers.Dense(2 * hidden_size, use_bias=True), + ] + ) + + def call(self, x, modulation_encoding): + """ + Forward pass for the LastLayer. + + Args: + x: KerasTensor. Input tensor. + modulation_encoding: KerasTensor. Modulation vector. + + Returns: + KerasTensor: The output tensor after final processing. 
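+
+        A minimal shape sketch (values are illustrative):
+
+        ```python
+        final_layer = LastLayer(hidden_size=1024, patch_size=1, output_channels=256)
+        # x: (batch, tokens, hidden_size) -> (batch, tokens, patch_size ** 2 * output_channels)
+        # output = final_layer(x, modulation_encoding)
+        ```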
+ """ + shift, scale = ops.split( + self.adaLN_modulation(modulation_encoding), 2, axis=1 + ) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/keras_hub/src/models/flux/flux_maths.py b/keras_hub/src/models/flux/flux_maths.py new file mode 100644 index 000000000..fad9bc9e0 --- /dev/null +++ b/keras_hub/src/models/flux/flux_maths.py @@ -0,0 +1,219 @@ +import keras +from einops import rearrange +from keras import ops + + +class TimestepEmbedding(keras.layers.Layer): + """ + Creates sinusoidal timestep embeddings. + + Call arguments: + t: KerasTensor of shape (N,), representing N indices, one per batch element. + These values may be fractional. + dim: int. The dimension of the output. + max_period: int, optional. Controls the minimum frequency of the embeddings. Defaults to 10000. + time_factor: float, optional. A scaling factor applied to `t`. Defaults to 1000.0. + + Returns: + KerasTensor: A tensor of shape (N, D) representing the positional embeddings, + where N is the number of batch elements and D is the specified dimension `dim`. + """ + + def call(self, t, dim, max_period=10000, time_factor=1000.0): + t = time_factor * t + half_dim = dim // 2 + freqs = ops.exp( + ops.cast(-ops.log(max_period), dtype=t.dtype) + * ops.arange(half_dim, dtype=t.dtype) + / half_dim + ) + args = t[:, None] * freqs[None] + embedding = ops.concatenate([ops.cos(args), ops.sin(args)], axis=-1) + + if dim % 2 != 0: + embedding = ops.concatenate( + [embedding, ops.zeros_like(embedding[:, :1])], axis=-1 + ) + + return embedding + + +class RotaryPositionalEmbedding(keras.layers.Layer): + """ + Applies Rotary Positional Embedding (RoPE) to the input tensor. + + Call arguments: + pos: KerasTensor. The positional tensor with shape (..., n, d). + dim: int. The embedding dimension, should be even. + theta: int. The base frequency. + + Returns: + KerasTensor: The tensor with applied RoPE transformation. + """ + + def call(self, pos, dim, theta): + scale = ops.arange(0, dim, 2, dtype="float32") / dim + omega = 1.0 / (theta**scale) + out = ops.einsum("...n,d->...nd", pos, omega) + out = ops.stack( + [ops.cos(out), -ops.sin(out), ops.sin(out), ops.cos(out)], axis=-1 + ) + out = rearrange(out, "... n d (i j) -> ... n d i j", i=2, j=2) + return ops.cast(out, dtype="float32") + + +class ApplyRoPE(keras.layers.Layer): + """ + Applies the RoPE transformation to the query and key tensors. + + Call arguments: + xq: KerasTensor. The query tensor of shape (..., L, D). + xk: KerasTensor. The key tensor of shape (..., L, D). + freqs_cis: KerasTensor. The frequency complex numbers tensor with shape (..., 2). + + Returns: + tuple[KerasTensor, KerasTensor]: The transformed query and key tensors. + """ + + def call(self, xq, xk, freqs_cis): + xq_ = ops.reshape(xq, (*ops.shape(xq)[:-1], -1, 1, 2)) + xk_ = ops.reshape(xk, (*ops.shape(xk)[:-1], -1, 1, 2)) + + xq_out = ( + freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + ) + xk_out = ( + freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + ) + + return ops.reshape(xq_out, ops.shape(xq)), ops.reshape( + xk_out, ops.shape(xk) + ) + + +class FluxRoPEAttention(keras.layers.Layer): + """ + Computes the attention mechanism with the RoPE transformation applied to the query and key tensors. + + Args: + dropout_p: float, optional. Dropout probability. Defaults to 0.0. + is_causal: bool, optional. If True, applies causal masking. Defaults to False. + + Call arguments: + q: KerasTensor. 
Query tensor of shape (..., L, D). + k: KerasTensor. Key tensor of shape (..., S, D). + v: KerasTensor. Value tensor of shape (..., S, D). + positional_encoding: KerasTensor. Positional encoding tensor. + + Returns: + KerasTensor: The resulting tensor from the attention mechanism. + """ + + def __init__(self, dropout_p=0.0, is_causal=False): + super(FluxRoPEAttention, self).__init__() + self.dropout_p = dropout_p + self.is_causal = is_causal + + def call(self, q, k, v, positional_encoding): + # Apply the RoPE transformation + q, k = ApplyRoPE()(q, k, positional_encoding) + + # Scaled dot-product attention + x = scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout_p, is_causal=self.is_causal + ) + + x = rearrange(x, "B H L D -> B L (H D)") + return x + + +# TODO: This is probably already implemented in several places, but is needed to ensure numeric equivalence to the original +# implementation. It uses torch.functional.scaled_dot_product_attention() - do we have an equivalent already in Keras? +def scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, +): + """ + Computes the scaled dot-product attention. + + Args: + query: KerasTensor. Query tensor of shape (..., L, D). + key: KerasTensor. Key tensor of shape (..., S, D). + value: KerasTensor. Value tensor of shape (..., S, D). + attn_mask: KerasTensor, optional. Attention mask tensor. Defaults to None. + dropout_p: float, optional. Dropout probability. Defaults to 0.0. + is_causal: bool, optional. If True, applies causal masking. Defaults to False. + scale: float, optional. Scale factor for attention. Defaults to None. + + Returns: + KerasTensor: The output tensor from the attention mechanism. + """ + L, S = ops.shape(query)[-2], ops.shape(key)[-2] + scale_factor = ( + 1 / ops.sqrt(ops.cast(ops.shape(query)[-1], dtype=query.dtype)) + if scale is None + else scale + ) + attn_bias = ops.zeros((L, S), dtype=query.dtype) + + if is_causal: + assert attn_mask is None + temp_mask = ops.ones((L, S), dtype=ops.bool) + temp_mask = ops.tril(temp_mask, diagonal=0) + attn_bias = ops.where(temp_mask, attn_bias, float("-inf")) + + if attn_mask is not None: + if ops.shape(attn_mask)[-1] == 1: # If the mask is 3D + attn_bias += attn_mask + else: + attn_bias = ops.where(attn_mask, attn_bias, float("-inf")) + + # Compute attention weights + attn_weight = ( + ops.matmul(query, ops.transpose(key, axes=[0, 1, 3, 2])) * scale_factor + ) + attn_weight += attn_bias + attn_weight = keras.activations.softmax(attn_weight, axis=-1) + + if dropout_p > 0.0: + attn_weight = keras.layers.Dropout(dropout_p)( + attn_weight, training=True + ) + + return ops.matmul(attn_weight, value) + + +def rearrange_symbolic_tensors(qkv, K, H): + """ + Splits the qkv tensor into query (q), key (k), and value (v) components. + + Mimics rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=num_heads), + for graph-mode TensorFlow support when doing functional subclassing + models. + + Arguments: + qkv: np.ndarray. Input tensor of shape (B, L, K*H*D). + K: int. Number of components (q, k, v). + H: int. Number of attention heads. + + Returns: + tuple: q, k, v tensors of shape (B, H, L, D). 
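+
+    A minimal shape sketch (values are illustrative):
+
+    ```python
+    qkv = keras.ops.ones((1, 32, 3 * 8 * 128))  # (B, L, K * H * D)
+    q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=8)  # each (1, 8, 32, 128)
+    ```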
+ """ + # Get the shape of qkv and calculate L and D + B, L, dim = ops.shape(qkv) + D = dim // (K * H) + + # Reshape and transpose the qkv tensor + qkv_reshaped = ops.reshape(qkv, (B, L, K, H, D)) + qkv_transposed = ops.transpose(qkv_reshaped, (2, 0, 3, 1, 4)) + + # Split q, k, v along the first dimension (K) + qkv_splits = ops.split(qkv_transposed, K, axis=0) + q, k, v = [ops.squeeze(split, 0) for split in qkv_splits] + + return q, k, v diff --git a/keras_hub/src/models/flux/flux_model.py b/keras_hub/src/models/flux/flux_model.py new file mode 100644 index 000000000..8cc40a90f --- /dev/null +++ b/keras_hub/src/models/flux/flux_model.py @@ -0,0 +1,231 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.flux.flux_layers import DoubleStreamBlock +from keras_hub.src.models.flux.flux_layers import EmbedND +from keras_hub.src.models.flux.flux_layers import LastLayer +from keras_hub.src.models.flux.flux_layers import MLPEmbedder +from keras_hub.src.models.flux.flux_layers import SingleStreamBlock +from keras_hub.src.models.flux.flux_maths import TimestepEmbedding + + +@keras_hub_export("keras_hub.models.FluxBackbone") +class FluxBackbone(Backbone): + """ + Transformer model for flow matching on sequences. + + The model processes image and text data with associated positional and timestep + embeddings, and optionally applies guidance embedding. Double-stream blocks + handle separate image and text streams, while single-stream blocks combine + these streams. Ported from: https://github.com/black-forest-labs/flux + + Args: + input_channels: int. The number of input channels. + hidden_size: int. The hidden size of the transformer, must be divisible by `num_heads`. + mlp_ratio: float. The ratio of the MLP dimension to the hidden size. + num_heads: int. The number of attention heads. + depth: int. The number of double-stream blocks. + depth_single_blocks: int. The number of single-stream blocks. + axes_dim: list[int]. A list of dimensions for the positional embedding axes. + theta: int. The base frequency for positional embeddings. + use_bias: bool. Whether to apply bias to the query, key, and value projections. + guidance_embed: bool. If True, applies guidance embedding in the model. + + Call arguments: + image: KerasTensor. Image input tensor of shape (N, L, D) where N is the batch size, + L is the sequence length, and D is the feature dimension. + image_ids: KerasTensor. Image ID input tensor of shape (N, L, D) corresponding + to the image sequences. + text: KerasTensor. Text input tensor of shape (N, L, D). + text_ids: KerasTensor. Text ID input tensor of shape (N, L, D) corresponding + to the text sequences. + timesteps: KerasTensor. Timestep tensor used to compute positional embeddings. + y: KerasTensor. Additional vector input, such as target values. + guidance: KerasTensor, optional. Guidance input tensor used + in guidance-embedded models. + Raises: + ValueError: If `hidden_size` is not divisible by `num_heads`, or if `sum(axes_dim)` is not equal to the + positional embedding dimension. 
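+
+    Example:
+
+    ```python
+    # A small test-scale configuration (taken from the unit test); released
+    # checkpoints use much larger dimensions.
+    backbone = keras_hub.models.FluxBackbone(
+        input_channels=256,
+        hidden_size=1024,
+        mlp_ratio=2.0,
+        num_heads=8,
+        depth=4,
+        depth_single_blocks=8,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        use_bias=True,
+        guidance_embed=True,
+        image_shape=(32, 256),
+        text_shape=(32, 256),
+        image_ids_shape=(32, 3),
+        text_ids_shape=(32, 3),
+        y_shape=(256,),
+    )
+    ```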
+ """ + + def __init__( + self, + input_channels, + hidden_size, + mlp_ratio, + num_heads, + depth, + depth_single_blocks, + axes_dim, + theta, + use_bias, + guidance_embed=False, + # These will be inferred from the CLIP/T5 encoders later + image_shape=(None, 768, 3072), + text_shape=(None, 768, 3072), + image_ids_shape=(None, 768, 3072), + text_ids_shape=(None, 768, 3072), + y_shape=(None, 128), + **kwargs, + ): + + # === Layers === + self.positional_embedder = EmbedND(theta=theta, axes_dim=axes_dim) + self.image_input_embedder = keras.layers.Dense( + hidden_size, use_bias=True + ) + self.time_input_embedder = MLPEmbedder(hidden_dim=hidden_size) + self.vector_embedder = MLPEmbedder(hidden_dim=hidden_size) + self.guidance_input_embedder = ( + MLPEmbedder(hidden_dim=hidden_size) + if guidance_embed + else keras.layers.Identity() + ) + self.text_input_embedder = keras.layers.Dense(hidden_size) + + self.double_blocks = [ + DoubleStreamBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + use_bias=use_bias, + ) + for _ in range(depth) + ] + + self.single_blocks = [ + SingleStreamBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) + for _ in range(depth_single_blocks) + ] + + self.final_layer = LastLayer(hidden_size, 1, input_channels) + self.timestep_embedding = TimestepEmbedding() + self.guidance_embed = guidance_embed + + # === Functional Model === + image_input = keras.Input(shape=image_shape, name="image") + image_ids = keras.Input(shape=image_ids_shape, name="image_ids") + text_input = keras.Input(shape=text_shape, name="text") + text_ids = keras.Input(shape=text_ids_shape, name="text_ids") + y = keras.Input(shape=y_shape, name="y") + timesteps_input = keras.Input(shape=(), name="timesteps") + guidance_input = keras.Input(shape=(), name="guidance") + + # running on sequences image + image = self.image_input_embedder(image_input) + modulation_encoding = self.time_input_embedder( + self.timestep_embedding(timesteps_input, dim=256) + ) + if self.guidance_embed: + if guidance_input is None: + raise ValueError( + "Didn't get guidance strength for guidance distilled model." + ) + modulation_encoding = ( + modulation_encoding + + self.guidance_input_embedder( + self.timestep_embedding(guidance_input, dim=256) + ) + ) + + modulation_encoding = modulation_encoding + self.vector_embedder(y) + text = self.text_input_embedder(text_input) + + ids = keras.ops.concatenate((text_ids, image_ids), axis=1) + positional_encoding = self.positional_embedder(ids) + + for block in self.double_blocks: + image, text = block( + image=image, + text=text, + modulation_encoding=modulation_encoding, + positional_encoding=positional_encoding, + ) + + image = keras.ops.concatenate((text, image), axis=1) + for block in self.single_blocks: + image = block( + image, + modulation_encoding=modulation_encoding, + positional_encoding=positional_encoding, + ) + image = image[:, text.shape[1] :, ...] 
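+        # `image` now holds only the image tokens again; the text tokens that
+        # were concatenated in front of it for the single-stream blocks have
+        # been sliced off above.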
+ + image = self.final_layer( + image, modulation_encoding + ) # (N, T, patch_size ** 2 * output_channels) + + super().__init__( + inputs={ + "image": image_input, + "image_ids": image_ids, + "text": text_input, + "text_ids": text_ids, + "y": y, + "timesteps": timesteps_input, + "guidance": guidance_input, + }, + outputs=image, + **kwargs, + ) + + # === Config === + self.input_channels = input_channels + self.output_channels = self.input_channels + self.hidden_size = hidden_size + self.num_heads = num_heads + self.image_shape = image_shape + self.text_shape = text_shape + self.image_ids_shape = image_ids_shape + self.text_ids_shape = text_ids_shape + self.y_shape = y_shape + self.mlp_ratio = mlp_ratio + self.depth = depth + self.depth_single_blocks = depth_single_blocks + self.axes_dim = axes_dim + self.theta = theta + self.use_bias = use_bias + + def get_config(self): + config = super().get_config() + config.update( + { + "input_channels": self.input_channels, + "hidden_size": self.hidden_size, + "mlp_ratio": self.mlp_ratio, + "num_heads": self.num_heads, + "depth": self.depth, + "depth_single_blocks": self.depth_single_blocks, + "axes_dim": self.axes_dim, + "theta": self.theta, + "use_bias": self.use_bias, + "guidance_embed": self.guidance_embed, + "image_shape": self.image_shape, + "text_shape": self.text_shape, + "image_ids_shape": self.image_ids_shape, + "text_ids_shape": self.text_ids_shape, + "y_shape": self.y_shape, + } + ) + return config + + def encode_text_step(self, token_ids, negative_token_ids): + raise NotImplementedError("Not implemented yet") + + def encode(token_ids): + raise NotImplementedError("Not implemented yet") + + def encode_image_step(self, images): + raise NotImplementedError("Not implemented yet") + + def add_noise_step(self, latents, noises, step, num_steps): + raise NotImplementedError("Not implemented yet") + + def denoise_step( + self, + ): + raise NotImplementedError("Not implemented yet") + + def decode_step(self, latents): + raise NotImplementedError("Not implemented yet") diff --git a/keras_hub/src/models/flux/flux_presets.py b/keras_hub/src/models/flux/flux_presets.py new file mode 100644 index 000000000..af66ff074 --- /dev/null +++ b/keras_hub/src/models/flux/flux_presets.py @@ -0,0 +1,16 @@ +"""FLUX model preset configurations.""" + +presets = { + "schnell": { + "metadata": { + "description": ( + "A 12 billion parameter rectified flow transformer capable of generating images from text descriptions." + ), + "params": 124439808, + "official_name": "FLUX.1-schnell", + "path": "flux", + "model_card": "https://github.com/black-forest-labs/flux/blob/main/model_cards/FLUX.1-schnell.md", + }, + "kaggle_handle": "TBA", + }, +} diff --git a/keras_hub/src/models/flux/flux_text_to_image.py b/keras_hub/src/models/flux/flux_text_to_image.py new file mode 100644 index 000000000..792718214 --- /dev/null +++ b/keras_hub/src/models/flux/flux_text_to_image.py @@ -0,0 +1,142 @@ +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.flux.flux_model import FluxBackbone +from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( + FluxTextToImagePreprocessor, +) +from keras_hub.src.models.text_to_image import TextToImage + + +@keras_hub_export("keras_hub.models.FluxTextToImage") +class FluxTextToImage(TextToImage): + """An end-to-end Flux model for text-to-image generation. + + This model has a `generate()` method, which generates image based on a + prompt. 
+ + Args: + backbone: A `keras_hub.models.FluxBackbone` instance. + preprocessor: A + `keras_hub.models.FluxTextToImagePreprocessor` instance. + + Examples: + + Use `generate()` to do image generation. + ```python + text_to_image = keras_hub.models.FluxTextToImage.from_preset( + "TBA", height=512, width=512 + ) + text_to_image.generate( + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" + ) + + # Generate with batched prompts. + text_to_image.generate( + ["cute wallpaper art of a cat", "cute wallpaper art of a dog"] + ) + + # Generate with different `num_steps` and `guidance_scale`. + text_to_image.generate( + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + num_steps=50, + guidance_scale=5.0, + ) + + # Generate with `negative_prompts`. + text_to_image.generate( + { + "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "negative_prompts": "green color", + } + ) + ``` + """ + + backbone_cls = FluxBackbone + preprocessor_cls = FluxTextToImagePreprocessor + + def __init__( + self, + backbone, + preprocessor, + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + outputs = backbone.output + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def fit(self, *args, **kwargs): + raise NotImplementedError( + "Currently, `fit` is not supported for " "`FluxTextToImage`." + ) + + def generate_step( + self, + latents, + token_ids, + num_steps, + guidance_scale, + ): + """A compilable generation function for batched of inputs. + + This function represents the inner, XLA-compilable, generation function + for batched inputs. + + Args: + latents: A (batch_size, height, width, channels) tensor + containing the latents to start generation from. Typically, this + tensor is sampled from the Gaussian distribution. + token_ids: A pair of (batch_size, num_tokens) tensor containing the + tokens based on the input prompts and negative prompts. + num_steps: int. The number of diffusion steps to take. + guidance_scale: float. The classifier free guidance scale defined in + [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). Higher scale encourages to + generate images that are closely linked to prompts, usually at + the expense of lower image quality. + """ + token_ids, negative_token_ids = token_ids + + # Encode prompts. + embeddings = self.backbone.encode_text_step( + token_ids, negative_token_ids + ) + + # Denoise. + def body_fun(step, latents): + return self.backbone.denoise_step( + latents, + embeddings, + step, + num_steps, + guidance_scale, + ) + + latents = ops.fori_loop(0, num_steps, body_fun, latents) + + # Decode. 
+ return self.backbone.decode_step(latents) + + def generate( + self, + inputs, + num_steps=28, + guidance_scale=7.0, + seed=None, + ): + return super().generate( + inputs, + num_steps=num_steps, + guidance_scale=guidance_scale, + seed=seed, + ) diff --git a/keras_hub/src/models/flux/flux_text_to_image_preprocessor.py b/keras_hub/src/models/flux/flux_text_to_image_preprocessor.py new file mode 100644 index 000000000..6750850d4 --- /dev/null +++ b/keras_hub/src/models/flux/flux_text_to_image_preprocessor.py @@ -0,0 +1,73 @@ +import keras +from keras import layers + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.flux.flux_model import FluxBackbone +from keras_hub.src.models.preprocessor import Preprocessor + + +@keras_hub_export("keras_hub.models.FluxTextToImagePreprocessor") +class FluxTextToImagePreprocessor(Preprocessor): + """Flux text-to-image model preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.FluxTextToImagePreprocessor`. + + For use with generation, the layer exposes one methods + `generate_preprocess()`. + + Args: + clip_l_preprocessor: A `keras_hub.models.CLIPPreprocessor` instance. + t5_preprocessor: A optional `keras_hub.models.T5Preprocessor` instance. + """ + + backbone_cls = FluxBackbone + + def __init__( + self, + clip_l_preprocessor, + t5_preprocessor=None, + **kwargs, + ): + super().__init__(**kwargs) + self.clip_l_preprocessor = clip_l_preprocessor + self.t5_preprocessor = t5_preprocessor + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self.clip_l_preprocessor.sequence_length + + def build(self, input_shape): + self.built = True + + def generate_preprocess(self, x): + token_ids = {} + token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"] + if self.t5_preprocessor is not None: + token_ids["t5"] = self.t5_preprocessor(x)["token_ids"] + return token_ids + + def get_config(self): + config = super().get_config() + config.update( + { + "clip_l_preprocessor": layers.serialize( + self.clip_l_preprocessor + ), + "t5_preprocessor": layers.serialize(self.t5_preprocessor), + } + ) + return config + + @classmethod + def from_config(cls, config): + for layer_name in ( + "clip_l_preprocessor", + "t5_preprocessor", + ): + if layer_name in config and isinstance(config[layer_name], dict): + config[layer_name] = keras.layers.deserialize( + config[layer_name] + ) + return cls(**config) diff --git a/keras_hub/src/models/flux/flux_text_to_image_preprocessor_test.py b/keras_hub/src/models/flux/flux_text_to_image_preprocessor_test.py new file mode 100644 index 000000000..d9a3a9d0a --- /dev/null +++ b/keras_hub/src/models/flux/flux_text_to_image_preprocessor_test.py @@ -0,0 +1,51 @@ +import pytest + +from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor +from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer +from keras_hub.src.models.flux.flux_text_to_image_preprocessor import ( + FluxTextToImagePreprocessor, +) +from keras_hub.src.tests.test_case import TestCase + + +class FluxTextToImagePreprocessorTest(TestCase): + def setUp(self): + vocab = ["air", "plane", "port"] + vocab += ["<|endoftext|>", "<|startoftext|>"] + vocab = dict([(token, i) for i, token in enumerate(vocab)]) + merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"] + merges += ["po rt", "pla ne"] + clip_l_tokenizer = CLIPTokenizer( + vocabulary=vocab, merges=merges, pad_with_end_token=True + ) + clip_l_preprocessor = CLIPPreprocessor( + clip_l_tokenizer, 
sequence_length=8 + ) + self.init_kwargs = { + "clip_l_preprocessor": clip_l_preprocessor, + } + self.input_data = ["airplane"] + + def test_preprocessor_basics(self): + pytest.skip( + reason="TODO: enable after preprocessor flow is figured out" + ) + self.run_preprocessing_layer_test( + cls=FluxTextToImagePreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 4, 9, 5, 7, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + [[4, 9, 5, 7, 2, 0, 0, 0]], # Labels shifted. + [[1, 1, 1, 1, 1, 0, 0, 0]], # Zero out unlabeled examples. + ), + ) + + def test_generate_preprocess(self): + preprocessor = FluxTextToImagePreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(self.input_data) + self.assertIn("clip_l", x) + self.assertAllEqual(x["clip_l"][0], [4, 0, 1, 3, 3, 3, 3, 3]) diff --git a/requirements-common.txt b/requirements-common.txt index 2bdc4a572..c935b10f2 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,3 +19,6 @@ sentencepiece tensorflow-datasets safetensors pillow +# Will be replaced once https://github.com/keras-team/keras/issues/20332 +# is resolved +einops diff --git a/tools/checkpoint_conversion/convert_flux_checkpoints.py b/tools/checkpoint_conversion/convert_flux_checkpoints.py new file mode 100644 index 000000000..17961c58d --- /dev/null +++ b/tools/checkpoint_conversion/convert_flux_checkpoints.py @@ -0,0 +1,239 @@ +import os + +import keras +from safetensors import safe_open + +from keras_hub.src.models.flux.flux_model import Flux + +DOWNLOAD_URL = "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors" +keras.config.set_dtype_policy("mixed_float16") + + +def convert_mlpembedder_weights(weights_dict, keras_model, prefix): + in_layer_weight = weights_dict[f"{prefix}.in_layer.weight"].T + in_layer_bias = weights_dict[f"{prefix}.in_layer.bias"] + + out_layer_weight = weights_dict[f"{prefix}.out_layer.weight"].T + out_layer_bias = weights_dict[f"{prefix}.out_layer.bias"] + + keras_model.input_layer.set_weights([in_layer_weight, in_layer_bias]) + keras_model.output_layer.set_weights([out_layer_weight, out_layer_bias]) + + +def convert_selfattention_weights(weights_dict, keras_model, prefix): + qkv_weight = weights_dict[f"{prefix}.qkv.weight"].T + qkv_bias = weights_dict.get(f"{prefix}.qkv.bias") + + proj_weight = weights_dict[f"{prefix}.proj.weight"].T + proj_bias = weights_dict[f"{prefix}.proj.bias"] + + keras_model.qkv.set_weights( + [qkv_weight] + ([qkv_bias] if qkv_bias is not None else []) + ) + keras_model.proj.set_weights([proj_weight, proj_bias]) + + +def convert_modulation_weights(weights_dict, keras_model, prefix): + lin_weight = weights_dict[f"{prefix}.lin.weight"].T + lin_bias = weights_dict[f"{prefix}.lin.bias"] + + keras_model.linear_projection.set_weights([lin_weight, lin_bias]) + + +def convert_doublestreamblock_weights(weights_dict, keras_model, block_idx): + # Convert img_mod weights + convert_modulation_weights( + weights_dict, + keras_model.image_mod, + f"double_blocks.{block_idx}.img_mod", + ) + + # Convert txt_mod weights + convert_modulation_weights( + weights_dict, keras_model.text_mod, f"double_blocks.{block_idx}.txt_mod" + ) + + # Convert img_attn weights + convert_selfattention_weights( + weights_dict, + keras_model.image_attn, + f"double_blocks.{block_idx}.img_attn", + ) + + # Convert txt_attn weights + convert_selfattention_weights( + weights_dict, + keras_model.text_attention, + 
f"double_blocks.{block_idx}.txt_attn", + ) + + # Convert img_mlp weights (2 layers) + keras_model.image_mlp.layers[0].set_weights( + [ + weights_dict[f"double_blocks.{block_idx}.img_mlp.0.weight"].T, + weights_dict[f"double_blocks.{block_idx}.img_mlp.0.bias"], + ] + ) + keras_model.image_mlp.layers[2].set_weights( + [ + weights_dict[f"double_blocks.{block_idx}.img_mlp.2.weight"].T, + weights_dict[f"double_blocks.{block_idx}.img_mlp.2.bias"], + ] + ) + + # Convert txt_mlp weights (2 layers) + keras_model.text_mlp.layers[0].set_weights( + [ + weights_dict[f"double_blocks.{block_idx}.txt_mlp.0.weight"].T, + weights_dict[f"double_blocks.{block_idx}.txt_mlp.0.bias"], + ] + ) + keras_model.text_mlp.layers[2].set_weights( + [ + weights_dict[f"double_blocks.{block_idx}.txt_mlp.2.weight"].T, + weights_dict[f"double_blocks.{block_idx}.txt_mlp.2.bias"], + ] + ) + + +def convert_singlestreamblock_weights(weights_dict, keras_model, block_idx): + convert_modulation_weights( + weights_dict, + keras_model.modulation, + f"single_blocks.{block_idx}.modulation", + ) + + # Convert linear1 weights + keras_model.linear1.set_weights( + [ + weights_dict[f"single_blocks.{block_idx}.linear1.weight"].T, + weights_dict[f"single_blocks.{block_idx}.linear1.bias"], + ] + ) + + # Convert linear2 weights + keras_model.linear2.set_weights( + [ + weights_dict[f"single_blocks.{block_idx}.linear2.weight"].T, + weights_dict[f"single_blocks.{block_idx}.linear2.bias"], + ] + ) + + +def convert_lastlayer_weights(weights_dict, keras_model): + # Convert linear weights + keras_model.linear.set_weights( + [ + weights_dict["final_layer.linear.weight"].T, + weights_dict["final_layer.linear.bias"], + ] + ) + + # Convert adaLN_modulation weights + keras_model.adaLN_modulation.layers[1].set_weights( + [ + weights_dict["final_layer.adaLN_modulation.1.weight"].T, + weights_dict["final_layer.adaLN_modulation.1.bias"], + ] + ) + + +def convert_flux_weights(weights_dict, keras_model): + # Convert img_in weights + keras_model.image_input_embedder.set_weights( + [weights_dict["img_in.weight"].T, weights_dict["img_in.bias"]] + ) + + # Convert time_in weights (MLPEmbedder) + convert_mlpembedder_weights( + weights_dict, keras_model.time_input_embedder, "time_in" + ) + + # Convert vector_in weights (MLPEmbedder) + convert_mlpembedder_weights( + weights_dict, keras_model.vector_embedder, "vector_in" + ) + + # Convert guidance_in weights (if present) + if hasattr(keras_model, "guidance_embed"): + convert_mlpembedder_weights( + weights_dict, keras_model.guidance_input_embedder, "guidance_in" + ) + + # Convert txt_in weights + keras_model.text_input_embedder.set_weights( + [weights_dict["txt_in.weight"].T, weights_dict["txt_in.bias"]] + ) + + # Convert double_blocks weights + for block_idx in range(len(keras_model.double_blocks)): + convert_doublestreamblock_weights( + weights_dict, keras_model.double_blocks[block_idx], block_idx + ) + + # Convert single_blocks weights + for block_idx in range(len(keras_model.single_blocks)): + convert_singlestreamblock_weights( + weights_dict, keras_model.single_blocks[block_idx], block_idx + ) + + # Convert final_layer weights + convert_lastlayer_weights(weights_dict, keras_model.final_layer) + + +def main(_): + # get the original weights + print("Downloading weights") + + os.system(f"wget {DOWNLOAD_URL}") + + flux_weights = {} + with safe_open( + "flux1-schnell.safetensors", framework="pt", device="cpu" + ) as f: + for key in f.keys(): + flux_weights[key] = f.get_tensor(key) + + keras_model = Flux( + in_channels=64, 
+ hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + use_bias=True, + guidance_embed=False, + ) + + # Define input shapes + img_shape = (1, 96, 64) + txt_shape = (1, 96, 64) + img_ids_shape = (1, 96, 3) + txt_ids_shape = (1, 96, 3) + timestep_shape = (32,) + y_shape = (1, 64) + guidance_shape = (32,) + + # Build the model + keras_model.build( + ( + img_shape, + img_ids_shape, + txt_shape, + txt_ids_shape, + timestep_shape, + y_shape, + guidance_shape, + ) + ) + + convert_flux_weights(flux_weights, keras_model) + keras_model.save_to_preset("flux1-schnell") + + os.remove("flux1-schnell.safetensors") + + +if __name__ == "__main__": + main()