Skip to content

Commit

Permalink
Convert JIT model (on state dict load) to sd for pretrained='filename…
Browse files Browse the repository at this point in the history
….pt' support for OpenAI .pt files. Fix #622
  • Loading branch information
rwightman committed Sep 11, 2023
1 parent 79a20ee commit 64d42df
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/open_clip/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):
checkpoint = torch.load(checkpoint_path, map_location=map_location)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
elif isinstance(checkpoint, torch.jit.ScriptModule):
state_dict = checkpoint.state_dict()
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
else:
state_dict = checkpoint
if next(iter(state_dict.items()))[0].startswith('module'):
Expand Down

0 comments on commit 64d42df

Please sign in to comment.