```python
from PIL import Image
import requests
import torch
import numpy as np
from transformers import BertTokenizer
from transformers import CLIPProcessor, CLIPModel, CLIPTextModel

query_texts = ['一个人', '一辆汽车', '两个男人', '两个女人']  # Input prompts; replace freely.

# Load the SkyCLIP bilingual (Chinese/English) text encoder
text_tokenizer = BertTokenizer.from_pretrained("./tokenizer")
text_encoder = CLIPTextModel.from_pretrained("./text_encoder").eval()
text = text_tokenizer(query_texts, return_tensors='pt', padding=True)['input_ids']

url = "http://images.cocodataset.org/val2017/000000040083.jpg"  # Any image URL works here.

# Load the CLIP image encoder
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
clip_text_proj = clip_model.text_projection
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
image = processor(images=Image.open(requests.get(url, stream=True).raw), return_tensors="pt")

with torch.no_grad():
    image_features = clip_model.get_image_features(**image)
    text_features = text_encoder(text)[0]
    # The sep_token plays the role of OpenAI CLIP's eot_token
    sep_index = torch.nonzero(text == text_tokenizer.sep_token_id)
    text_features = text_features[torch.arange(text.shape[0]), sep_index[:, 1]]
    # Apply the text projection matrix
    text_features = clip_text_proj(text_features)
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)
    # Cosine similarity; logit_scale is the learned temperature factor
    logit_scale = clip_model.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    print(np.around(probs, 3))
```
Quick question: in this code, what exactly are the `tokenizer` and the `text_encoder`?