Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into siglip_clipa_models
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Oct 11, 2023
2 parents 0316911 + 7e2d222 commit 9d8385e
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 17 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -415,11 +415,17 @@ When training a RN50 on YFCC the same hyperparameters as above are used, with th
Note that to use another model, like `ViT-B/32` or `RN50x4` or `RN50x16` or `ViT-B/16`, specify with `--model RN50x4`.
### Launch tensorboard:
### Logging
For tensorboard logging, run:
```bash
tensorboard --logdir=logs/tensorboard/ --port=7777
```
For wandb logging, we recommend looking at the `step` variable instead of `Step`, since the latter was not properly set in earlier versions of this codebase.
For older runs with models trained before https://github.com/mlfoundations/open_clip/pull/613, the `Step` variable should be ignored.
For newer runs, after that PR, the two variables are the same.
## Evaluation / Zero-Shot
We recommend https://github.com/LAION-AI/CLIP_benchmark#how-to-use for systematic evaluation on 40 datasets.
Expand Down
2 changes: 2 additions & 0 deletions docs/openclip_results.csv
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ RN50,openai,0.4810,0.5982,0.8329,0.7157,0.4030,0.2171,0.1623,0.1542,0.4154,0.408
ViT-B-16,commonpool_l_text_s1b_b8k,0.4760,0.5605,0.8720,0.9391,0.7054,0.1843,0.2373,0.0995,0.3941,0.3830,0.0451,0.7724,0.2317,0.4437,0.4835,0.2220,0.4770,0.6708,0.2686,0.2593,0.4911,0.5164,0.7049,0.7669,0.4857,0.4931,0.4663,0.6525,0.9523,0.6088,0.2122,0.6078,0.3730,0.4570,0.0623,0.5697,0.0000,0.5643,0.8564
ViT-B-16,commonpool_l_basic_s1b_b8k,0.4585,0.5155,0.8444,0.8289,0.5251,0.2061,0.2277,0.1173,0.4133,0.3820,0.0481,0.7461,0.2021,0.3932,0.4325,0.1913,0.4600,0.6087,0.3333,0.2809,0.4493,0.4357,0.6956,0.7151,0.5899,0.5387,0.4313,0.7216,0.9373,0.5974,0.1173,0.6015,0.3583,0.4812,0.0436,0.5712,0.0000,0.5421,0.8384
ViT-B-16,commonpool_l_s1b_b8k,0.4370,0.4593,0.8089,0.9133,0.6421,0.1594,0.2203,0.1177,0.3383,0.3348,0.0316,0.6735,0.2766,0.3448,0.3914,0.1592,0.4335,0.5265,0.2686,0.3603,0.4126,0.3681,0.5587,0.7093,0.5516,0.5118,0.4154,0.6060,0.9339,0.5713,0.3047,0.4948,0.2855,0.4777,0.0399,0.5102,0.0000,0.5654,0.8305
nllb-clip-large,v1,0.4227,0.3672,0.7234,0.9634,0.6797,0.2389,0.2254,0.0691,0.3447,0.5454,0.0216,0.4447,0.2462,0.3316,0.3233,0.2632,0.1725,0.5624,0.3727,0.2716,0.5268,0.0978,0.1283,0.7551,0.5417,0.5585,0.4983,0.3865,0.9811,0.5512,0.1725,0.6625,0.4004,0.4299,0.0403,0.5181,0.1419,0.6752,0.8305
nllb-clip-base,v1,0.3351,0.2432,0.5914,0.8435,0.4839,0.1531,0.2254,0.0312,0.2782,0.4104,0.0185,0.2962,0.1852,0.1838,0.2029,0.0921,0.2195,0.3656,0.3741,0.1821,0.2874,0.0850,0.0784,0.6802,0.5509,0.5420,0.3603,0.1921,0.9514,0.4708,0.1441,0.5200,0.3081,0.3904,0.0463,0.4873,0.0000,0.5456,0.7136
ViT-B-32,datacomp_m_s128m_b4k,0.3281,0.2972,0.7159,0.8252,0.5476,0.1365,0.2249,0.0453,0.2133,0.3393,0.0304,0.4168,0.1366,0.1930,0.2440,0.0493,0.4085,0.3402,0.2110,0.1147,0.1971,0.2965,0.4311,0.5459,0.5862,0.5316,0.2778,0.2803,0.8365,0.3637,0.1500,0.2241,0.1407,0.3287,0.0142,0.6669,0.0000,0.4498,0.6559
ViT-B-32,commonpool_m_clip_s128m_b4k,0.3278,0.2725,0.6678,0.8405,0.5549,0.1402,0.2238,0.0458,0.2176,0.2589,0.0215,0.3999,0.1586,0.1844,0.2247,0.0420,0.3925,0.3297,0.3235,0.1778,0.2093,0.2551,0.3828,0.6074,0.5210,0.5014,0.2641,0.4123,0.8370,0.3875,0.1931,0.2465,0.1476,0.3581,0.0154,0.5369,0.0000,0.4451,0.6610
RN50-quickgelu,cc12m,0.3260,0.3647,0.6581,0.5404,0.2079,0.2063,0.1574,0.0431,0.1910,0.2146,0.0226,0.4392,0.1284,0.2412,0.3098,0.0759,0.4160,0.4468,0.3713,0.1261,0.2320,0.2383,0.5651,0.4394,0.5033,0.4789,0.2137,0.1837,0.8751,0.4442,0.0918,0.5373,0.2891,0.3876,0.0476,0.5000,0.0000,0.4883,0.7119
Expand Down
11 changes: 11 additions & 0 deletions src/open_clip/hf_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,15 @@
},
"pooler": "cls_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/m2m_100
"m2m_100": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "encoder_attention_heads",
"layers": "encoder_layers",
},
"pooler": "cls_pooler",
},
}
2 changes: 1 addition & 1 deletion src/open_clip/model_configs/coca_roberta-ViT-B-32.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"text_cfg": {
"hf_model_name": "roberta-base",
"hf_tokenizer_name": "roberta-base",
"proj": "linear",
"hf_proj_type": "linear",
"width": 768,
"output_tokens": true
},
Expand Down
15 changes: 15 additions & 0 deletions src/open_clip/model_configs/nllb-clip-base.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "facebook/nllb-200-distilled-600M",
"hf_tokenizer_name": "facebook/nllb-200-distilled-600M",
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
16 changes: 16 additions & 0 deletions src/open_clip/model_configs/nllb-clip-large.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"hf_model_name": "facebook/nllb-200-distilled-1.3B",
"hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B",
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
7 changes: 7 additions & 0 deletions src/open_clip/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,13 @@ def _apcfg(url='', hf_hub='', **kwargs):
"ViT-bigG-14-CLIPA-336": dict(
datacomp1b=_apcfg(),
),

"nllb-clip-base": dict(
v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'),
),
"nllb-clip-large": dict(
v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'),
)
}


Expand Down
2 changes: 1 addition & 1 deletion src/open_clip/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def build_causal_mask(self):

def build_cls_mask(self, text, cast_dtype: torch.dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
Expand Down
35 changes: 23 additions & 12 deletions src/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,17 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
}
log_data.update({name:val.val for name,val in losses_m.items()})

for name, val in log_data.items():
name = "train/" + name
if tb_writer is not None:
tb_writer.add_scalar(name, val, step)
if args.wandb:
assert wandb is not None, 'Please install wandb.'
wandb.log({name: val, 'step': step})
log_data = {"train/" + name: val for name, val in log_data.items()}

if tb_writer is not None:
for name, val in log_data.items():
tb_writer.add_scalar(name, val, step)

if args.wandb:
assert wandb is not None, 'Please install wandb.'
log_data['step'] = step # for backwards compatibility
wandb.log(log_data, step=step)

# resetting batch / data time meters per log window
batch_time_m.reset()
data_time_m.reset()
Expand Down Expand Up @@ -329,19 +332,27 @@ def evaluate(model, data, epoch, args, tb_writer=None):
+ "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
)

log_data = {"val/" + name: val for name, val in metrics.items()}

if args.save_logs:
for name, val in metrics.items():
if tb_writer is not None:
tb_writer.add_scalar(f"val/{name}", val, epoch)
if tb_writer is not None:
for name, val in log_data.items():
tb_writer.add_scalar(name, val, epoch)

with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
f.write(json.dumps(metrics))
f.write("\n")

if args.wandb:
assert wandb is not None, 'Please install wandb.'
for name, val in metrics.items():
wandb.log({f"val/{name}": val, 'epoch': epoch})
if 'train' in data:
dataloader = data['train'].dataloader
num_batches_per_epoch = dataloader.num_batches // args.accum_freq
step = num_batches_per_epoch * epoch
else:
step = None
log_data['epoch'] = epoch
wandb.log(log_data, step=step)

return metrics

Expand Down
4 changes: 2 additions & 2 deletions tests/test_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
def test_poolers():
bs, sl, d = 2, 10, 5
h = torch.arange(sl).repeat(bs).reshape(bs, sl)[..., None] * torch.linspace(0.2, 1., d)
mask = torch.ones(bs, sl, dtype=torch.long)
mask[:2, 6:] = 0
mask = torch.ones(bs, sl, dtype=torch.bool)
mask[:2, 6:] = False
x = BaseModelOutput(h)
for name, cls in _POOLERS.items():
pooler = cls()
Expand Down

0 comments on commit 9d8385e

Please sign in to comment.