Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compatible with torch.utils.data.DataLoader #154

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pix2tex/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def update(self, **kwargs):
class Dataloader(DataLoader):
def __init__(self, dataset: Im2LatexDataset, batch_size=1, shuffle=False, *args, **kwargs):
self.dataset = dataset
self.dataset.update(batchsize=batch_size, shuffle=shuffle)
self.dataset.update(batchsize=batch_size, shuffle=shuffle, *args, **kwargs)
lukas-blecher marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(self.dataset, *args, shuffle=False, batch_size=None, **kwargs)

def __iter__(self):
Expand Down
6 changes: 3 additions & 3 deletions pix2tex/eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pix2tex.dataset.dataset import Im2LatexDataset
from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader
import argparse
import logging
import yaml
Expand Down Expand Up @@ -28,12 +28,12 @@ def detokenize(tokens, tokenizer):


@torch.no_grad()
def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'):
def evaluate(model: Model, dataset: Dataloader, args: Munch, num_batches: int = None, name: str = 'test'):
"""evaluates the model. Returns bleu score on the dataset

Args:
model (torch.nn.Module): the model
dataset (Im2LatexDataset): test dataset
dataset (Dataloader): test dataset
args (Munch): arguments
num_batches (int): How many batches to evaluate on. Defaults to None (all batches).
name (str, optional): name of the test e.g. val or test for wandb. Defaults to 'test'.
Expand Down
1 change: 1 addition & 0 deletions pix2tex/model/settings/config-vit.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
gpu_devices: null #[0,1,2,3,4,5,6,7]
num_workers: 0
betas:
- 0.9
- 0.999
Expand Down
1 change: 1 addition & 0 deletions pix2tex/model/settings/config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
gpu_devices: null #[0,1,2,3,4,5,6,7]
num_workers: 0
backbone_layers:
- 2
- 3
Expand Down
4 changes: 4 additions & 0 deletions pix2tex/model/settings/debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,7 @@ pad: False
pad_token: 0
bos_token: 1
eos_token: 2

#devices(GPU&CPU)
num_workers: 0
gpu_devices: null #[0,1,2,3,4,5,6,7]
20 changes: 10 additions & 10 deletions pix2tex/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pix2tex.dataset.dataset import Im2LatexDataset
from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader
import os
import argparse
import logging
Expand All @@ -16,12 +16,12 @@


def train(args):
dataloader = Im2LatexDataset().load(args.data)
dataloader.update(**args, test=False)
valdataloader = Im2LatexDataset().load(args.valdata)
train_dataset = Im2LatexDataset().load(args.data)
train_dataloader = Dataloader(train_dataset, **args, test=False)
lukas-blecher marked this conversation as resolved.
Show resolved Hide resolved
val_dataset = Im2LatexDataset().load(args.valdata)
valargs = args.copy()
valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
valdataloader.update(**valargs)
val_dataloader = Dataloader(val_dataset, **valargs)
lukas-blecher marked this conversation as resolved.
Show resolved Hide resolved
device = args.device
model = get_model(args)
if torch.cuda.is_available() and not args.no_cuda:
Expand All @@ -47,7 +47,7 @@ def save_models(e, step=0):
try:
for e in range(args.epoch, args.epochs):
args.epoch = e
dset = tqdm(iter(dataloader))
dset = tqdm(iter(train_dataloader))
lukas-blecher marked this conversation as resolved.
Show resolved Hide resolved
for i, (seq, im) in enumerate(dset):
if seq is not None and im is not None:
opt.zero_grad()
Expand All @@ -63,20 +63,20 @@ def save_models(e, step=0):
dset.set_description('Loss: %.4f' % total_loss)
if args.wandb:
wandb.log({'train/loss': total_loss})
if (i+1+len(dataloader)*e) % args.sample_freq == 0:
bleu_score, edit_distance, token_accuracy = evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
if (i+1+len(train_dataloader)*e) % args.sample_freq == 0:
bleu_score, edit_distance, token_accuracy = evaluate(model, val_dataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
if bleu_score > max_bleu and token_accuracy > max_token_acc:
max_bleu, max_token_acc = bleu_score, token_accuracy
save_models(e, step=i)
if (e+1) % args.save_freq == 0:
save_models(e, step=len(dataloader))
save_models(e, step=len(train_dataloader))
if args.wandb:
wandb.log({'train/epoch': e+1})
except KeyboardInterrupt:
if e >= 2:
save_models(e, step=i)
raise KeyboardInterrupt
save_models(e, step=len(dataloader))
save_models(e, step=len(train_dataloader))


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions pix2tex/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def parse_args(args, **kwargs) -> Munch:
args.update(kwargs)
args.wandb = not kwargs.debug and not args.debug
args.device = get_device(args, kwargs.no_cuda)
args.num_workers = args.get('num_workers', 0)
args.max_dimensions = [args.max_width, args.max_height]
args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)]
if 'decoder_args' not in args or args.decoder_args is None:
Expand Down