Commit
add GIT captioning, refactoring, DataLoader
kohya-ss committed Feb 2, 2023
1 parent 8c3a52e commit 57d8483
Showing 9 changed files with 481 additions and 146 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -3,4 +3,5 @@ __pycache__
wd14_tagger_model
venv
*.egg-info
-build
+build
+.vscode
101 changes: 76 additions & 25 deletions finetune/make_captions.py
@@ -11,18 +11,59 @@
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from blip.blip import blip_decoder
# from Salesforce_BLIP.models.blip import blip_decoder
import library.train_util as train_util

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


IMAGE_SIZE = 384

# is resizing to a square really right? feels questionable, but the BLIP source does it this way, so we follow it
IMAGE_TRANSFORM = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
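# (the mean/std above are the CLIP image-normalization constants, which BLIP reuses)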

# would be nice to share this with the other scripts, but the processing differs slightly...
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths

def __len__(self):
return len(self.images)

def __getitem__(self, idx):
img_path = self.images[idx]

try:
image = Image.open(img_path).convert("RGB")
      # convert to tensor here so the DataLoader will accept it
tensor = IMAGE_TRANSFORM(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None

return (tensor, img_path)


def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
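
# usage sketch (not part of the commit): wiring the dataset above to a DataLoader with
# this collate function; the file names are illustrative
#   dataset = ImageLoadingTransformDataset(["img0001.png", "img0002.png"])
#   loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False,
#                                        collate_fn=collate_fn_remove_corrupted, drop_last=False)
#   for batch in loader:
#     for tensor, image_path in batch:  # unreadable images came back as None and were dropped
#       ...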


def main(args):
# fix the seed for reproducibility
  seed = args.seed  # + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

if not os.path.exists("blip"):
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path

@@ -31,24 +72,15 @@ def main(args):
os.chdir('finetune')

print(f"load images from {args.train_data_dir}")
-  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
-      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+  image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")

print(f"loading BLIP caption: {args.caption_weights}")
-  image_size = 384
-  model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json")
+  model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
model.eval()
model = model.to(DEVICE)
print("BLIP loaded")

-  # is resizing to a square really right? feels questionable, but the BLIP source does it this way
-  transform = transforms.Compose([
-      transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
-      transforms.ToTensor(),
-      transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
-  ])

  # run captioning
def run_batch(path_imgs):
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
@@ -66,18 +98,35 @@ def run_batch(path_imgs):
if args.debug:
print(image_path, caption)

  # option to use a DataLoader to speed up image loading
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingTransformDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
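    # each entry above mimics a one-item DataLoader batch: a list of (tensor, image_path)
    # tuples, where tensor=None tells the loop below to load the image itself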

b_imgs = []
-  for image_path in tqdm(image_paths, smoothing=0.0):
-    raw_image = Image.open(image_path)
-    if raw_image.mode != "RGB":
-      print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
-      raw_image = raw_image.convert("RGB")
-
-    image = transform(raw_image)
-    b_imgs.append((image_path, image))
-    if len(b_imgs) >= args.batch_size:
-      run_batch(b_imgs)
-      b_imgs.clear()
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue

img_tensor, image_path = data
if img_tensor is None:
try:
raw_image = Image.open(image_path)
if raw_image.mode != 'RGB':
raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue

b_imgs.append((image_path, img_tensor))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)

@@ -95,6 +144,8 @@ def run_batch(path_imgs):
parser.add_argument("--beam_search", action="store_true",
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
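With the new --max_data_loader_n_workers option, image loading and preprocessing move into DataLoader worker processes. A sketch of an invocation using only the flags visible in this diff (the positional arguments live in the truncated part of the argparse section and are abbreviated here):

    python finetune/make_captions.py --batch_size 8 --max_data_loader_n_workers 4 [positional args as before]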
136 changes: 136 additions & 0 deletions finetune/make_captions_by_git.py
@@ -0,0 +1,136 @@
import argparse
import os
import re

from PIL import Image
from tqdm import tqdm
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.generation.utils import GenerationMixin

import library.train_util as train_util


DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

PATTERN_REPLACE = [re.compile(r'with the (words?|letters?) (" ?[^"]*"|\w+)( on (the)? ?\w+)?'),
re.compile(r'that says (" ?[^"]*"|\w+)')]


# remove "with the word xxxx" phrases, which GIT hallucinates all the time
def remove_words(captions, debug):
removed_caps = []
for caption in captions:
cap = caption
for pat in PATTERN_REPLACE:
      cap = pat.sub("", cap)  # apply each pattern to the running result
if debug and cap != caption:
print(caption)
print(cap)
removed_caps.append(cap)
return removed_caps
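
# example (illustrative): remove_words(['a sign with the words "stop here"'], False)
# returns ['a sign '] -- the matched phrase is deleted and the rest is kept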


def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch


def main(args):
  # patch GIT so it also works when the batch size is larger than 1 (for transformers 4.26.0)
  org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
  curr_batch_size = [args.batch_size]  # wrapped in a list so it can be swapped when the last batch has fewer items than batch_size

  # input_ids must have the same number of rows as the batch size; the batch size cannot be
  # read from inside this function, so it is passed in from outside
  # patching at a higher level than this would be a lot more work
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
if input_ids.size()[0] != curr_batch_size[0]:
input_ids = input_ids.repeat(curr_batch_size[0], 1)
return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
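  # e.g. if the stock implementation returns a (1, 1) tensor of bos tokens while the current
  # batch holds 4 images, the patch repeats it into (4, 1) so generate() sees matching sizes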

print(f"load images from {args.train_data_dir}")
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")

  # ideally we would download the model explicitly instead of relying on the cache
print(f"loading GIT: {args.model_id}")
git_processor = AutoProcessor.from_pretrained(args.model_id)
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
print("GIT loaded")

  # run captioning
def run_batch(path_imgs):
imgs = [im for _, im in path_imgs]

curr_batch_size[0] = len(path_imgs)
    inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE)  # images are passed in PIL format
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)

if args.remove_words:
captions = remove_words(captions, args.debug)

for (image_path, _), caption in zip(path_imgs, captions):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)

  # option to use a DataLoader to speed up image loading
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]

b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue

image, image_path = data
if image is None:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue

b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()

if len(b_imgs) > 0:
run_batch(b_imgs)

print("done!")


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
parser.add_argument("--remove_words", action="store_true",
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
parser.add_argument("--debug", action="store_true", help="debug mode")

args = parser.parse_args()
main(args)
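
The new script's full argument list is visible above, so a typical invocation looks like this (the directory name is illustrative):

    python finetune/make_captions_by_git.py train_data --batch_size 4 --max_data_loader_n_workers 2 --remove_words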
36 changes: 17 additions & 19 deletions finetune/merge_captions_to_metadata.py
@@ -1,39 +1,35 @@
# This script is licensed under the Apache License 2.0
# (c) 2022 Kohya S. @kohya_ss

import argparse
import glob
import os
import json

from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util


def main(args):
-  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
-      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+  assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"

train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")

-  if args.in_json is None and os.path.isfile(args.out_json):
+  if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json

if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}")
-    with open(args.in_json, "rt", encoding='utf-8') as f:
-      metadata = json.load(f)
+    metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
print("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}

print("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
-    caption_path = os.path.splitext(image_path)[0] + args.caption_extension
-    with open(caption_path, "rt", encoding='utf-8') as f:
-      lines = f.readlines()
-      caption = lines[0].strip() if len(lines) > 0 else ""
+    caption_path = image_path.with_suffix(args.caption_extension)
+    caption = caption_path.read_text(encoding='utf-8').strip()

-    image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
+    image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}

@@ -43,21 +39,23 @@ def main(args):

  # write out the metadata and we're done
print(f"writing metadata: {args.out_json}")
-  with open(args.out_json, "wt", encoding='utf-8') as f:
-    json.dump(metadata, f, indent=2)
+  Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
print("done!")


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--debug", action="store_true", help="debug mode")

args = parser.parse_args()
[diffs for the remaining five changed files were not loaded]
