Commit: Swap nougat for texify

VikParuchuri committed Jan 2, 2024
1 parent bba88c2 commit d35d295
Showing 10 changed files with 125 additions and 1,588 deletions.
128 changes: 47 additions & 81 deletions marker/cleaners/equations.py
@@ -4,30 +4,26 @@
 from typing import List
 
 import torch
-from nougat import NougatModel
-from nougat.postprocessing import markdown_compatible
-from nougat.utils.checkpoint import get_checkpoint
+from texify.inference import batch_inference
+from texify.model.model import load_model
+from texify.model.processor import load_processor
 import re
 from PIL import Image, ImageDraw
-from nougat.utils.dataset import ImageDataset
 
 from marker.bbox import should_merge_blocks, merge_boxes
-from marker.debug.data import dump_nougat_debug_data
+from marker.debug.data import dump_equation_debug_data
 from marker.settings import settings
 from marker.schema import Page, Span, Line, Block, BlockType
-from nougat.utils.device import move_to_device
 import os
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+processor = load_processor()
 
-def load_nougat_model():
-    ckpt = get_checkpoint(None, model_tag=settings.NOUGAT_MODEL_NAME)
-    nougat_model = NougatModel.from_pretrained(ckpt)
-    if settings.TORCH_DEVICE != "cpu":
-        move_to_device(nougat_model, bf16=settings.CUDA, cuda=settings.CUDA)
-    nougat_model.eval()
-    return nougat_model
+def load_texify_model():
+    texify_model = load_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
+    return texify_model
 
 
 def mask_bbox(png_image, bbox, selected_bboxes):
@@ -54,76 +50,48 @@ def mask_bbox(png_image, bbox, selected_bboxes):
     return result
 
 
-def get_nougat_image(page, bbox, selected_bboxes):
-    pix = page.get_pixmap(dpi=settings.NOUGAT_DPI, clip=bbox)
-    png = pix.pil_tobytes(format="BMP")
+def get_masked_image(page, bbox, selected_bboxes):
+    pix = page.get_pixmap(dpi=settings.TEXIFY_DPI, clip=bbox)
+    png = pix.pil_tobytes(format="PNG")
     png_image = Image.open(io.BytesIO(png))
     png_image = mask_bbox(png_image, bbox, selected_bboxes)
     png_image = png_image.convert("RGB")
-
-    img_out = io.BytesIO()
-    png_image.save(img_out, format="BMP")
-    return img_out
-
-
-def replace_latex_fences(text):
-    # Replace block equations: \[ ... \] with $$...$$
-    text = re.sub(r'\\\[(.*?)\\\]', r'$$\1$$', text)
-
-    # Replace inline math: \( ... \) with $...$
-    text = re.sub(r'\\\((.*?)\\\)', r'$\1$', text)
-
-    return text
+    return png_image
 
 
-def get_nougat_text_batched(images, reformat_region_lens, nougat_model, batch_size):
+def get_latex_batched(images, reformat_region_lens, texify_model, batch_size):
     if len(images) == 0:
         return []
 
     predictions = [""] * len(images)
 
-    dataset = ImageDataset(
-        images,
-        partial(nougat_model.encoder.prepare_input, random_padding=False),
-    )
-
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=batch_size,
-        pin_memory=True,
-        shuffle=False,
-    )
-
-    for idx, sample in enumerate(dataloader):
+    for i in range(0, len(images), batch_size):
         # Dynamically set max length to save inference time
-        min_idx = idx * batch_size
+        min_idx = i
         max_idx = min(min_idx + batch_size, len(images))
         max_length = max(reformat_region_lens[min_idx:max_idx])
-        max_length = min(max_length, settings.NOUGAT_MODEL_MAX)
-        max_length += settings.NOUGAT_TOKEN_BUFFER
-
-        nougat_model.config.max_length = max_length
-        model_output = nougat_model.inference(image_tensors=sample, early_stopping=False)
-        for j, output in enumerate(model_output["predictions"]):
-            disclaimer = ""
-            token_count = get_total_nougat_tokens(output, nougat_model)
+        max_length = min(max_length, settings.TEXIFY_MODEL_MAX)
+        max_length += settings.TEXIFY_TOKEN_BUFFER
+
+        model_output = batch_inference(images[min_idx:max_idx], texify_model, processor, max_tokens=max_length)
+
+        for j, output in enumerate(model_output):
+            token_count = get_total_texify_tokens(output)
             if token_count >= max_length - 1:
-                disclaimer = "[TRUNCATED]"
+                output = ""
 
-            image_idx = idx * batch_size + j
-            predictions[image_idx] = (
-                replace_latex_fences(markdown_compatible(output)) + disclaimer
-            )
+            image_idx = i + j
+            predictions[image_idx] = output
     return predictions
 
 
-def get_total_nougat_tokens(text, nougat_model):
-    tokenizer = nougat_model.decoder.tokenizer
+def get_total_texify_tokens(text):
+    tokenizer = processor.tokenizer
     tokens = tokenizer(text)
     return len(tokens["input_ids"])
 
 
-def find_page_equation_regions(pnum, page, block_types, nougat_model):
+def find_page_equation_regions(pnum, page, block_types):
     i = 0
     equation_boxes = [b.bbox for b in block_types[pnum] if b.block_type == "Formula"]
     reformatted_blocks = set()
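Note on the batching logic above: the generation cap is set per batch rather than per image. It is the largest token count among the batch's regions, clamped to the model maximum, plus a fixed buffer; any output that reaches the cap is discarded, instead of being tagged [TRUNCATED] as the nougat path did. A worked micro-example of that rule, using made-up values for the two settings (the real defaults live in marker's settings, not here):

    # Illustration only: hypothetical values for the two settings.
    TEXIFY_MODEL_MAX = 384      # assumed overall token cap
    TEXIFY_TOKEN_BUFFER = 256   # assumed per-batch slack

    region_lens = [120, 40, 310]   # token counts for one batch of regions
    max_length = min(max(region_lens), TEXIFY_MODEL_MAX) + TEXIFY_TOKEN_BUFFER
    # min(310, 384) + 256 = 566 tokens allowed for this batch; any output
    # reaching max_length - 1 tokens is dropped as truncated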
@@ -151,7 +119,7 @@ def find_page_equation_regions(pnum, page, block_types, nougat_model):
             prev_block = page.blocks[j]
             prev_bbox = prev_block.bbox
             prelim_block_text = prev_block.prelim_text + " " + block_text
-            if get_total_nougat_tokens(prelim_block_text, nougat_model) >= settings.NOUGAT_MODEL_MAX:
+            if get_total_texify_tokens(prelim_block_text) >= settings.TEXIFY_MODEL_MAX:
                 break
 
             block_text = prelim_block_text
@@ -166,7 +134,7 @@ def find_page_equation_regions(pnum, page, block_types, nougat_model):
                     equation_boxes)) and i + 1 < len(page.blocks):
                 bbox = merge_boxes(bbox, next_bbox)
                 prelim_block_text = block_text + " " + next_block.prelim_text
-                if get_total_nougat_tokens(prelim_block_text, nougat_model) >= settings.NOUGAT_MODEL_MAX:
+                if get_total_texify_tokens(prelim_block_text) >= settings.TEXIFY_MODEL_MAX:
                     break
 
                 block_text = prelim_block_text
@@ -176,9 +144,9 @@ def find_page_equation_regions(pnum, page, block_types, nougat_model):
             next_block = page.blocks[i + 1]
             next_bbox = next_block.bbox
 
-        total_tokens = get_total_nougat_tokens(block_text, nougat_model)
+        total_tokens = get_total_texify_tokens(block_text)
         ordered_blocks = sorted(([sb[0] for sb in selected_blocks]))
-        if total_tokens < settings.NOUGAT_MODEL_MAX:
+        if total_tokens < settings.TEXIFY_MODEL_MAX:
             # Get indices of all blocks to merge
             reformat_regions.append(ordered_blocks)
             block_lens.append(total_tokens)
@@ -206,7 +174,7 @@ def get_bboxes_for_region(page, region):
     return bboxes, merged_box
 
 
-def replace_blocks_with_nougat_predictions(page_blocks: Page, merged_boxes, reformat_regions, predictions, pnum, nougat_model):
+def replace_blocks_with_latex(page_blocks: Page, merged_boxes, reformat_regions, predictions, pnum):
    new_blocks = []
     converted_spans = []
     current_region = 0
@@ -221,13 +189,12 @@ def replace_blocks_with_nougat_predictions(page_blocks: Page, merged_boxes, reformat_regions, predictions, pnum, nougat_model):
                 continue
 
             orig_block_text = " ".join([page_blocks.blocks[i].prelim_text for i in reformat_regions[current_region]])
-            nougat_text = predictions[current_region]
+            latex_text = predictions[current_region]
             conditions = [
-                len(nougat_text) > 0,
-                not any([word in nougat_text for word in settings.NOUGAT_HALLUCINATION_WORDS]),
-                get_total_nougat_tokens(nougat_text, nougat_model) < settings.NOUGAT_MODEL_MAX,  # Make sure we didn't run to the token max
-                len(nougat_text) > len(orig_block_text) * .8,
-                len(nougat_text.strip()) > 0
+                len(latex_text) > 0,
+                get_total_texify_tokens(latex_text) < settings.TEXIFY_MODEL_MAX,  # Make sure we didn't run to the overall token max
+                len(latex_text) > len(orig_block_text) * .8,
+                len(latex_text.strip()) > 0
             ]
 
             idx = reformat_regions[current_region][-1] + 1
@@ -241,7 +208,7 @@ def replace_blocks_with_nougat_predictions(page_blocks: Page, merged_boxes, reformat_regions, predictions, pnum, nougat_model):
                 block_line = Line(
                     spans=[
                         Span(
-                            text=nougat_text,
+                            text=latex_text,
                             bbox=merged_boxes[current_region],
                             span_id=f"{pnum}_{idx}_fixeq",
                             font="Latex",
@@ -261,15 +228,15 @@ def replace_blocks_with_nougat_predictions(page_blocks: Page, merged_boxes, reformat_regions, predictions, pnum, nougat_model):
     return new_blocks, success_count, fail_count, converted_spans
 
 
-def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]], nougat_model, batch_size=settings.NOUGAT_BATCH_SIZE):
+def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]], texify_model, batch_size=settings.TEXIFY_BATCH_SIZE):
     unsuccessful_ocr = 0
     successful_ocr = 0
 
     # Find potential equation regions, and length of text in each region
     reformat_regions = []
     reformat_region_lens = []
     for pnum, page in enumerate(blocks):
-        regions, region_lens = find_page_equation_regions(pnum, page, block_types, nougat_model)
+        regions, region_lens = find_page_equation_regions(pnum, page, block_types)
         reformat_regions.append(regions)
         reformat_region_lens.append(region_lens)
 
@@ -283,26 +250,25 @@ def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]], nougat_model, batch_size=settings.NOUGAT_BATCH_SIZE):
         page_obj = doc[page_idx]
         for reformat_region in reformat_regions_page:
             bboxes, merged_box = get_bboxes_for_region(blocks[page_idx], reformat_region)
-            png_image = get_nougat_image(page_obj, merged_box, bboxes)
+            png_image = get_masked_image(page_obj, merged_box, bboxes)
             images.append(png_image)
             merged_boxes.append(merged_box)
 
     # Make batched predictions
-    predictions = get_nougat_text_batched(images, flat_reformat_region_lens, nougat_model, batch_size)
+    predictions = get_latex_batched(images, flat_reformat_region_lens, texify_model, batch_size)
 
     # Replace blocks with predictions
     page_start = 0
     converted_spans = []
     for page_idx, reformat_regions_page in enumerate(reformat_regions):
         page_predictions = predictions[page_start:page_start + len(reformat_regions_page)]
         page_boxes = merged_boxes[page_start:page_start + len(reformat_regions_page)]
-        new_page_blocks, success_count, fail_count, converted_span = replace_blocks_with_nougat_predictions(
+        new_page_blocks, success_count, fail_count, converted_span = replace_blocks_with_latex(
             blocks[page_idx],
             page_boxes,
             reformat_regions_page,
             page_predictions,
-            page_idx,
-            nougat_model
+            page_idx
         )
         converted_spans.extend(converted_span)
         blocks[page_idx].blocks = new_page_blocks
@@ -311,6 +277,6 @@ def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]], nougat_model, batch_size=settings.NOUGAT_BATCH_SIZE):
         unsuccessful_ocr += fail_count
 
     # If debug mode is on, dump out conversions for comparison
-    dump_nougat_debug_data(doc, images, converted_spans)
+    dump_equation_debug_data(doc, images, converted_spans)
 
     return blocks, {"successful_ocr": successful_ocr, "unsuccessful_ocr": unsuccessful_ocr, "equations": eq_count}
6 changes: 3 additions & 3 deletions marker/convert.py
@@ -102,7 +102,7 @@ def convert_single_pdf(
         return "", out_meta
 
     # Unpack models from list
-    nougat_model, layoutlm_model, order_model, edit_model = model_lst
+    texify_model, layoutlm_model, order_model, edit_model = model_lst
 
     block_types = detect_document_block_types(
         doc,
@@ -146,8 +146,8 @@ def convert_single_pdf(
         doc,
         blocks,
         block_types,
-        nougat_model,
-        batch_size=int(settings.NOUGAT_BATCH_SIZE * parallel_factor)
+        texify_model,
+        batch_size=int(settings.TEXIFY_BATCH_SIZE * parallel_factor)
     )
     out_meta["block_stats"]["equations"] = eq_stats
 
4 changes: 2 additions & 2 deletions marker/debug/data.py
@@ -10,7 +10,7 @@
 import io
 
 
-def dump_nougat_debug_data(doc, images, converted_spans):
+def dump_equation_debug_data(doc, images, converted_spans):
     if not settings.DEBUG_DATA_FOLDER or settings.DEBUG_LEVEL == 0:
         return
 
@@ -55,7 +55,7 @@ def dump_bbox_debug_data(doc, blocks: List[Page]):
     for idx, page_blocks in enumerate(blocks):
         page = doc[idx]
 
-        pix = page.get_pixmap(dpi=settings.NOUGAT_DPI, annots=False, clip=page_blocks.bbox)
+        pix = page.get_pixmap(dpi=settings.TEXIFY_DPI, annots=False, clip=page_blocks.bbox)
         png = pix.pil_tobytes(format="PNG")
         png_image = Image.open(io.BytesIO(png))
         width, height = png_image.size
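The bbox debug dump now renders at settings.TEXIFY_DPI, matching the DPI used for the equation crops in equations.py. For reference, the render path in both places is PyMuPDF's get_pixmap with a clip rectangle, round-tripped through PNG bytes into PIL; a condensed sketch of that pattern (render_region is our name, not a marker function):

    import io
    import fitz  # PyMuPDF
    from PIL import Image

    def render_region(page, bbox, dpi):
        # Rasterize just the clipped bbox at the requested DPI, then
        # decode the PNG bytes into a PIL image.
        pix = page.get_pixmap(dpi=dpi, clip=bbox)
        return Image.open(io.BytesIO(pix.pil_tobytes(format="PNG")))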
6 changes: 3 additions & 3 deletions marker/models.py
@@ -1,4 +1,4 @@
-from marker.cleaners.equations import load_nougat_model
+from marker.cleaners.equations import load_texify_model
 from marker.ordering import load_ordering_model
 from marker.postprocessors.editor import load_editing_model
 from marker.segmentation import load_layout_model
@@ -8,6 +8,6 @@ def load_all_models():
     edit = load_editing_model()
     order = load_ordering_model()
     layout = load_layout_model()
-    nougat = load_nougat_model()
-    model_lst = [nougat, layout, order, edit]
+    texify = load_texify_model()
+    model_lst = [texify, layout, order, edit]
     return model_lst
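One thing to keep in sync: convert.py above unpacks this list positionally, so the order [texify, layout, order, edit] built here is part of the contract. A usage sketch with the names from this diff:

    from marker.models import load_all_models

    model_lst = load_all_models()
    # Positional unpack must mirror the list construction order above.
    texify_model, layoutlm_model, order_model, edit_model = model_lst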
7 changes: 3 additions & 4 deletions marker/ordering.py
@@ -19,7 +19,7 @@ def load_ordering_model():
     model = LayoutLMv3ForSequenceClassification.from_pretrained(
         settings.ORDERER_MODEL_NAME,
         torch_dtype=settings.MODEL_DTYPE,
-    ).to(settings.TORCH_DEVICE)
+    ).to(settings.TORCH_DEVICE_MODEL)
     model.eval()
     return model
 
@@ -65,12 +65,11 @@ def batch_inference(rgb_images, bboxes, words, model):
         max_length=128
     )
 
-    if settings.CUDA:
-        encoding["pixel_values"] = encoding["pixel_values"].to(torch.bfloat16)
+    encoding["pixel_values"] = encoding["pixel_values"].to(model.dtype)
 
     with torch.inference_mode():
         for k in ["bbox", "input_ids", "pixel_values", "attention_mask"]:
-            encoding[k] = encoding[k].to(settings.TORCH_DEVICE)
+            encoding[k] = encoding[k].to(model.device)
         outputs = model(**encoding)
         logits = outputs.logits
 
2 changes: 1 addition & 1 deletion marker/postprocessors/editor.py
@@ -16,7 +16,7 @@ def load_editing_model():
     model = T5ForTokenClassification.from_pretrained(
         settings.EDITOR_MODEL_NAME,
         torch_dtype=settings.MODEL_DTYPE,
-    ).to(settings.TORCH_DEVICE)
+    ).to(settings.TORCH_DEVICE_MODEL)
     model.eval()
 
     model.config.label2id = {
7 changes: 3 additions & 4 deletions marker/segmentation.py
@@ -27,7 +27,7 @@ def load_layout_model():
     model = LayoutLMv3ForTokenClassification.from_pretrained(
         settings.LAYOUT_MODEL_NAME,
         torch_dtype=settings.MODEL_DTYPE,
-    ).to(settings.TORCH_DEVICE)
+    ).to(settings.TORCH_DEVICE_MODEL)
 
     model.config.id2label = {
         0: "Caption",
@@ -173,10 +173,9 @@ def predict_block_types(encodings, layoutlm_model, batch_size):
 
         model_in = {}
         for k in ["bbox", "input_ids", "attention_mask", "pixel_values"]:
-            model_in[k] = torch.stack([b[k] for b in batch]).to(settings.TORCH_DEVICE)
+            model_in[k] = torch.stack([b[k] for b in batch]).to(layoutlm_model.device)
 
-        if settings.CUDA:
-            model_in["pixel_values"] = model_in["pixel_values"].to(torch.bfloat16)
+        model_in["pixel_values"] = model_in["pixel_values"].to(layoutlm_model.dtype)
 
         with torch.inference_mode():
             outputs = layoutlm_model(**model_in)