Commit 7b8dd2c
Merge model_cfg & model_kwargs before passing to model, allows SigLIP models to be trained with SigLIP loss via --siglip (avoid dupe arg)
rwightman committed Oct 22, 2023
1 parent 27e5037 commit 7b8dd2c
Showing 1 changed file with 4 additions and 3 deletions.
src/open_clip/factory.py: 7 changes (4 additions & 3 deletions)
@@ -240,13 +240,14 @@ def create_model(
             model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
         custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
 
+        model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)
         if custom_text:
             if "multimodal_cfg" in model_cfg:
-                model = CoCa(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)
+                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
             else:
-                model = CustomTextCLIP(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)
+                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
         else:
-            model = CLIP(**model_cfg, **model_kwargs, cast_dtype=cast_dtype)
+            model = CLIP(**model_cfg, cast_dtype=cast_dtype)
 
         if precision in ("fp16", "bf16"):
             dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
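For context on the "dupe arg" this commit avoids: before the change, model_cfg and model_kwargs were each double-splatted into the model constructor, so any key present in both dicts raised a TypeError at the call site. Below is a minimal sketch of that failure mode and the fix, not code from the repository: build() is a hypothetical stand-in for the CLIP/CustomTextCLIP/CoCa constructors, and init_logit_bias is an assumed example of a key that can arrive both from a SigLIP model's JSON config and from --siglip-style model_kwargs.

# Minimal sketch (not from the commit) of the duplicate-keyword problem.
# build() stands in for the model constructors; init_logit_bias is an
# assumed example of a key supplied by both config and kwargs.

def build(init_logit_bias=None, cast_dtype=None):
    return init_logit_bias

model_cfg = {'init_logit_bias': -10.0}     # key already in the model's JSON config
model_kwargs = {'init_logit_bias': -10.0}  # same key set again via CLI kwargs

# Old behavior: double-splatting both dicts fails when keys overlap.
# build(**model_cfg, **model_kwargs, cast_dtype=None)
# -> TypeError: build() got multiple values for keyword argument 'init_logit_bias'

# New behavior: merge first; keys in model_kwargs override model_cfg.
merged = dict(model_cfg, **model_kwargs)
build(**merged, cast_dtype=None)  # OK: one unambiguous value per keyword

Merging with dict(model_cfg, **model_kwargs) also makes the precedence explicit: caller-supplied kwargs win over values from the model config.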
