Skip to content

Commit

Permalink
save mdel on conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidLandup0 committed Nov 10, 2024
1 parent c6e20f6 commit dda8ec3
Showing 1 changed file with 4 additions and 22 deletions.
26 changes: 4 additions & 22 deletions tools/checkpoint_conversion/convert_flux_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,3 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Requires installation of source code from
# https://github.com/black-forest-labs/flux

import os

import keras
Expand All @@ -31,8 +16,8 @@ def convert_mlpembedder_weights(weights_dict, keras_model, prefix):
out_layer_weight = weights_dict[f"{prefix}.out_layer.weight"].T
out_layer_bias = weights_dict[f"{prefix}.out_layer.bias"]

keras_model.in_layer.set_weights([in_layer_weight, in_layer_bias])
keras_model.out_layer.set_weights([out_layer_weight, out_layer_bias])
keras_model.input_layer.set_weights([in_layer_weight, in_layer_bias])
keras_model.output_layer.set_weights([out_layer_weight, out_layer_bias])


def convert_selfattention_weights(weights_dict, keras_model, prefix):
Expand All @@ -52,7 +37,7 @@ def convert_modulation_weights(weights_dict, keras_model, prefix):
lin_weight = weights_dict[f"{prefix}.lin.weight"].T
lin_bias = weights_dict[f"{prefix}.lin.bias"]

keras_model.lin.set_weights([lin_weight, lin_bias])
keras_model.linear_projection.set_weights([lin_weight, lin_bias])


def convert_doublestreamblock_weights(weights_dict, keras_model, block_idx):
Expand Down Expand Up @@ -245,10 +230,7 @@ def main(_):
)

convert_flux_weights(flux_weights, keras_model)

# TODO:
# validation
# save
keras_model.save_to_preset("flux1-schnell")

os.remove("flux1-schnell.safetensors")

Expand Down

0 comments on commit dda8ec3

Please sign in to comment.