Skip to content

Commit

Permalink
Add jax_memory_cleanup argument in Task.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 12, 2024
1 parent d71f3ee commit 677de72
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def from_preset(
task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
task = load_serialized_object(preset, TASK_CONFIG_FILE)
if load_weights:
jax_memory_cleanup()
jax_memory_cleanup(task)
if check_file_exists(preset, TASK_WEIGHTS_FILE):
task.load_task_weights(get_file(preset, TASK_WEIGHTS_FILE))
task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
Expand Down

0 comments on commit 677de72

Please sign in to comment.