Nougat -> Texify, automatic gpu detection #55

Merged
merged 2 commits into from
Jan 2, 2024
18 changes: 9 additions & 9 deletions README.md
@@ -6,7 +6,7 @@ Marker converts PDF, EPUB, and MOBI to markdown. It's 10x faster than nougat, m
 - Removes headers/footers/other artifacts
 - Converts most equations to latex
 - Formats code blocks and tables
-- Support for multiple languages (although most testing is done in English). See `settings.py` for a language list.
+- Support for multiple languages (although most testing is done in English). See `settings.py` for a language list, or to add your own.
 - Works on GPU, CPU, or MPS

 ## How it works
@@ -15,7 +15,7 @@ Marker is a pipeline of deep learning models:

 - Extract text, OCR if necessary (heuristics, tesseract)
 - Detect page layout ([layout segmenter](https://huggingface.co/vikp/layout_segmenter), [column detector](https://huggingface.co/vikp/column_detector))
-- Clean and format each block (heuristics, [nougat](https://huggingface.co/facebook/nougat-base))
+- Clean and format each block (heuristics, [texify](https://huggingface.co/vikp/texify))
 - Combine blocks and postprocess complete text (heuristics, [pdf_postprocessor](https://huggingface.co/vikp/pdf_postprocessor_t5))

 Relying on autoregressive forward passes to generate text is slow and prone to hallucination/repetition. From the nougat paper: `We observed [repetition] in 1.5% of pages in the test set, but the frequency increases for out-of-domain documents.` In my anecdotal testing, repetitions happen on 5%+ of out-of-domain (non-arXiv) pages.
@@ -48,10 +48,10 @@ See [below](#benchmarks) for detailed speed and accuracy benchmarks, and instruc

 PDF is a tricky format, so marker will not always work perfectly. Here are some known limitations that are on the roadmap to address:

-- Marker will convert fewer equations to latex than nougat. This is because it has to first detect equations, then convert them without hallucination.
+- Marker will not convert 100% of equations to LaTeX. This is because it has to first detect equations, then convert them.
 - Whitespace and indentations are not always respected.
 - Not all lines/spans will be joined properly.
-- Languages similar to English (Spanish, French, German, Russian, etc) have the best support. There is provisional support for Chinese, Japanese, Korean, and Hindi, but it may not work as well.
+- Languages similar to English (Spanish, French, German, Russian, etc) have the best support. There is provisional support for Chinese, Japanese, Korean, and Hindi, but it may not work as well. You can add other languages by adding them to the `TESSERACT_LANGUAGES` and `SPELLCHECK_LANGUAGES` settings in `settings.py`.
 - This works best on digital PDFs that won't require a lot of OCR. It's optimized for speed, and limited OCR is used to fix errors.

 # Installation
@@ -88,17 +88,16 @@ First, clone the repo:
 - Install python requirements
   - `poetry install`
   - `poetry shell` to activate your poetry venv
-- On ARM macs (M1+), make sure to set the `TORCH_DEVICE` setting to `mps` (more details below) for a speedup

 # Usage

-First, some configuration:
+First, some configuration. Note that settings can be overridden with env vars, or in a `local.env` file in the root `marker` folder.

-- Set your torch device in the `local.env` file. For example, `TORCH_DEVICE=cuda` or `TORCH_DEVICE=mps`. `cpu` is the default.
+- Your torch device will be automatically detected, but you can manually set it also. For example, `TORCH_DEVICE=cuda` or `TORCH_DEVICE=mps`. `cpu` is the default.
   - If using GPU, set `INFERENCE_RAM` to your GPU VRAM (per GPU). For example, if you have 16 GB of VRAM, set `INFERENCE_RAM=16`.
   - Depending on your document types, marker's average memory usage per task can vary slightly. You can configure `VRAM_PER_TASK` to adjust this if you notice tasks failing with GPU out of memory errors.
 - Inspect the other settings in `marker/settings.py`. You can override any settings in the `local.env` file, or by setting environment variables.
-  - By default, the final editor model is off. Turn it on with `ENABLE_EDITOR_MODEL`.
+  - By default, the final editor model is off. Turn it on with `ENABLE_EDITOR_MODEL=true`.
   - By default, marker will use ocrmypdf for OCR, which is slower than base tesseract, but higher quality. You can change this with the `OCR_ENGINE` setting.

 ## Convert a single file
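The device-selection rule described in the configuration bullets of the hunk above (automatic detection, a manual `TORCH_DEVICE` override, `cpu` as the fallback) could look roughly like the sketch below. This is not marker's actual implementation: `pick_torch_device` and its parameters are hypothetical names, the CUDA-before-MPS preference order is an assumption, and the availability flags are passed in so the example runs without torch installed.

```python
import os
from typing import Mapping, Optional


def pick_torch_device(
    env: Optional[Mapping[str, str]] = None,
    cuda_available: bool = False,
    mps_available: bool = False,
) -> str:
    """Return a device name, preferring an explicit TORCH_DEVICE override.

    Hypothetical helper: the name, signature, and preference order are
    assumptions made for illustration, not taken from marker.
    """
    env = os.environ if env is None else env
    override = env.get("TORCH_DEVICE", "").strip()
    if override:
        return override  # a manual setting always wins
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"  # fallback when no accelerator is detected


# Example: nothing set in the environment, only MPS is available.
device = pick_torch_device({}, cuda_available=False, mps_available=True)
print(device)
```

In a real setup the two flags would typically come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.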
@@ -148,6 +147,8 @@ MIN_LENGTH=10000 METADATA_FILE=../pdf_meta.json NUM_DEVICES=4 NUM_WORKERS=15 bas
 - `NUM_WORKERS` is the number of parallel processes to run on each GPU. Per-GPU parallelism will not increase beyond `INFERENCE_RAM / VRAM_PER_TASK`.
 - `MIN_LENGTH` is the minimum number of characters that need to be extracted from a pdf before it will be considered for processing. If you're processing a lot of pdfs, I recommend setting this to avoid OCRing pdfs that are mostly images. (slows everything down)

+Note that the env variables above are specific to this script, and cannot be set in `local.env`.
+
 # Benchmarks

 Benchmarking PDF extraction quality is hard. I've created a test set by finding books and scientific papers that have a pdf version and a latex source. I convert the latex to text, and compare the reference to the output of text extraction methods.
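The `NUM_WORKERS` bullet in the hunk above says per-GPU parallelism will not exceed `INFERENCE_RAM / VRAM_PER_TASK`. A minimal sketch of that cap follows, assuming the ratio is rounded down to a whole number of tasks. `workers_per_gpu` is a hypothetical helper, and the `2.5` GB per-task figure is an illustrative value, not marker's default.

```python
def workers_per_gpu(num_workers: int, inference_ram_gb: float, vram_per_task_gb: float) -> int:
    """Cap the requested worker count by how many tasks fit in GPU memory.

    Hypothetical helper; rounding down is an assumption.
    """
    if vram_per_task_gb <= 0:
        raise ValueError("vram_per_task_gb must be positive")
    memory_cap = int(inference_ram_gb // vram_per_task_gb)
    return max(0, min(num_workers, memory_cap))


# 15 workers requested on a 16 GB GPU, with an assumed 2.5 GB per task.
print(workers_per_gpu(15, 16, 2.5))
```

Under these assumptions, raising `NUM_WORKERS` past the memory cap has no effect. Only more VRAM or a lower `VRAM_PER_TASK` would increase parallelism.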
@@ -203,7 +204,6 @@ I'm building a version that can be used commercially, by stripping out the depen
 Here are the non-commercial/restrictive dependencies:

 - LayoutLMv3: CC BY-NC-SA 4.0 . [Source](https://huggingface.co/microsoft/layoutlmv3-base)
-- Nougat: CC-BY-NC . [Source](https://github.com/facebookresearch/nougat)
 - PyMuPDF - GPL . [Source](https://pymupdf.readthedocs.io/en/latest/about.html#license-and-copyright)

 Other dependencies/datasets are openly licensed (doclaynet, byt5), or used in a way that is compatible with commercial usage (ghostscript).
128 changes: 47 additions & 81 deletions marker/cleaners/equations.py
@@ -4,30 +4,26 @@
 from typing import List

 import torch
-from nougat import NougatModel
-from nougat.postprocessing import markdown_compatible
-from nougat.utils.checkpoint import get_checkpoint
+from texify.inference import batch_inference
+from texify.model.model import load_model
+from texify.model.processor import load_processor
 import re
 from PIL import Image, ImageDraw
-from nougat.utils.dataset import ImageDataset

 from marker.bbox import should_merge_blocks, merge_boxes
-from marker.debug.data import dump_nougat_debug_data
+from marker.debug.data import dump_equation_debug_data
 from marker.settings import settings
 from marker.schema import Page, Span, Line, Block, BlockType
-from nougat.utils.device import move_to_device
 import os

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

+processor = load_processor()

-def load_nougat_model():
-    ckpt = get_checkpoint(None, model_tag=settings.NOUGAT_MODEL_NAME)
-    nougat_model = NougatModel.from_pretrained(ckpt)
-    if settings.TORCH_DEVICE != "cpu":
-        move_to_device(nougat_model, bf16=settings.CUDA, cuda=settings.CUDA)
-    nougat_model.eval()
-    return nougat_model
+
+def load_texify_model():
+    texify_model = load_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
+    return texify_model


 def mask_bbox(png_image, bbox, selected_bboxes):
@@ -54,76 +50,48 @@ def mask_bbox(png_image, bbox, selected_bboxes):
     return result


-def get_nougat_image(page, bbox, selected_bboxes):
-    pix = page.get_pixmap(dpi=settings.NOUGAT_DPI, clip=bbox)
-    png = pix.pil_tobytes(format="BMP")
+def get_masked_image(page, bbox, selected_bboxes):
+    pix = page.get_pixmap(dpi=settings.TEXIFY_DPI, clip=bbox)
+    png = pix.pil_tobytes(format="PNG")
     png_image = Image.open(io.BytesIO(png))
     png_image = mask_bbox(png_image, bbox, selected_bboxes)
     png_image = png_image.convert("RGB")
-
-    img_out = io.BytesIO()
-    png_image.save(img_out, format="BMP")
-    return img_out
-
-
-def replace_latex_fences(text):
-    # Replace block equations: \[ ... \] with $$...$$
-    text = re.sub(r'\\\[(.*?)\\\]', r'$$\1$$', text)
-
-    # Replace inline math: \( ... \) with $...$
-    text = re.sub(r'\\\((.*?)\\\)', r'$\1$', text)
-
-    return text
+    return png_image


-def get_nougat_text_batched(images, reformat_region_lens, nougat_model, batch_size):
+def get_latex_batched(images, reformat_region_lens, texify_model, batch_size):
     if len(images) == 0:
         return []

     predictions = [""] * len(images)

-    dataset = ImageDataset(
-        images,
-        partial(nougat_model.encoder.prepare_input, random_padding=False),
-    )
-
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=batch_size,
-        pin_memory=True,
-        shuffle=False,
-    )
-
-    for idx, sample in enumerate(dataloader):
+    for i in range(0, len(images), batch_size):
         # Dynamically set max length to save inference time
-        min_idx = idx * batch_size
+        min_idx = i
         max_idx = min(min_idx + batch_size, len(images))
         max_length = max(reformat_region_lens[min_idx:max_idx])
-        max_length = min(max_length, settings.NOUGAT_MODEL_MAX)
-        max_length += settings.NOUGAT_TOKEN_BUFFER
-
-        nougat_model.config.max_length = max_length
-        model_output = nougat_model.inference(image_tensors=sample, early_stopping=False)
-        for j, output in enumerate(model_output["predictions"]):
-            disclaimer = ""
-            token_count = get_total_nougat_tokens(output, nougat_model)
+        max_length = min(max_length, settings.TEXIFY_MODEL_MAX)
+        max_length += settings.TEXIFY_TOKEN_BUFFER
+
+        model_output = batch_inference(images[min_idx:max_idx], texify_model, processor, max_tokens=max_length)
+
+        for j, output in enumerate(model_output):
+            token_count = get_total_texify_tokens(output)
             if token_count >= max_length - 1:
-                disclaimer = "[TRUNCATED]"
+                output = ""

-            image_idx = idx * batch_size + j
-            predictions[image_idx] = (
-                replace_latex_fences(markdown_compatible(output)) + disclaimer
-            )
+            image_idx = i + j
+            predictions[image_idx] = output
     return predictions


-def get_total_nougat_tokens(text, nougat_model):
-    tokenizer = nougat_model.decoder.tokenizer
+def get_total_texify_tokens(text):
+    tokenizer = processor.tokenizer
     tokens = tokenizer(text)
     return len(tokens["input_ids"])


-def find_page_equation_regions(pnum, page, block_types, nougat_model):
+def find_page_equation_regions(pnum, page, block_types):
     i = 0
     equation_boxes = [b.bbox for b in block_types[pnum] if b.block_type == "Formula"]
     reformatted_blocks = set()
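The batching loop in `get_latex_batched` in the hunk above sets a per-batch token limit: the longest region in the batch, clamped to the model maximum, plus a buffer. The sketch below isolates just that arithmetic so it can be run without a model. `plan_batches` is a hypothetical helper, and the `model_max=384` and `token_buffer=256` values are illustrative assumptions, not marker's actual settings.

```python
from typing import List, Sequence, Tuple


def plan_batches(
    region_lens: Sequence[int],
    batch_size: int,
    model_max: int,
    token_buffer: int,
) -> List[Tuple[int, int, int]]:
    """Return (start, end, max_length) for each batch of regions.

    Mirrors the loop structure in the diff; the helper itself is
    hypothetical and performs no inference.
    """
    plans = []
    for start in range(0, len(region_lens), batch_size):
        end = min(start + batch_size, len(region_lens))
        # Longest region in this batch, clamped to the model limit...
        max_length = min(max(region_lens[start:end]), model_max)
        # ...plus headroom for tokens beyond the source text length.
        max_length += token_buffer
        plans.append((start, end, max_length))
    return plans


for plan in plan_batches([10, 50, 400, 20, 5], batch_size=2, model_max=384, token_buffer=256):
    print(plan)
```

Batches of short regions get a small generation limit, which is where the inference-time saving mentioned in the code comment comes from.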
@@ -151,7 +119,7 @@ def find_page_equation_regions(pnum, page, block_types, nougat_model):
                 prev_block = page.blocks[j]
                 prev_bbox = prev_block.bbox
                 prelim_block_text = prev_block.prelim_text + " " + block_text
-                if get_total_nougat_tokens(prelim_block_text, nougat_model) >= settings.NOUGAT_MODEL_MAX:
+                if get_total_texify_tokens(prelim_block_text) >= settings.TEXIFY_MODEL_MAX:
                     break

                 block_text = prelim_block_text
@@ -166,7 +134,7 @@ def find_page_equation_regions(pnum, page, block_types, nougat_model):
                     equation_boxes)) and i + 1 < len(page.blocks):
                 bbox = merge_boxes(bbox, next_bbox)
                 prelim_block_text = block_text + " " + next_block.prelim_text
-                if get_total_nougat_tokens(prelim_block_text, nougat_model) >= settings.NOUGAT_MODEL_MAX:
+                if get_total_texify_tokens(prelim_block_text) >= settings.TEXIFY_MODEL_MAX:
                     break

                 block_text = prelim_block_text
@@ -176,9 +144,9 @@ def find_page_equation_regions(pnum, page, block_types, nougat_model):
                     next_block = page.blocks[i + 1]
                     next_bbox = next_block.bbox

-        total_tokens = get_total_nougat_tokens(block_text, nougat_model)
+        total_tokens = get_total_texify_tokens(block_text)
         ordered_blocks = sorted(([sb[0] for sb in selected_blocks]))
-        if total_tokens < settings.NOUGAT_MODEL_MAX:
+        if total_tokens < settings.TEXIFY_MODEL_MAX:
             # Get indices of all blocks to merge
             reformat_regions.append(ordered_blocks)
             block_lens.append(total_tokens)
@@ -206,7 +174,7 @@ def get_bboxes_for_region(page, region):
     return bboxes, merged_box


-def replace_blocks_with_nougat_predictions(page_blocks: Page, merged_boxes, reformat_regions, predictions, pnum, nougat_model):
+def replace_blocks_with_latex(page_blocks: Page, merged_boxes, reformat_regions, predictions, pnum):
     new_blocks = []
     converted_spans = []
     current_region = 0
@@ -221,13 +189,12 @@ def replace_blocks_with_nougat_predictions(page_blocks: Page, merged_boxes, refo
             continue

         orig_block_text = " ".join([page_blocks.blocks[i].prelim_text for i in reformat_regions[current_region]])
-        nougat_text = predictions[current_region]
+        latex_text = predictions[current_region]
         conditions = [
-            len(nougat_text) > 0,
-            not any([word in nougat_text for word in settings.NOUGAT_HALLUCINATION_WORDS]),
-            get_total_nougat_tokens(nougat_text, nougat_model) < settings.NOUGAT_MODEL_MAX, # Make sure we didn't run to the token max
-            len(nougat_text) > len(orig_block_text) * .8,
-            len(nougat_text.strip()) > 0
+            len(latex_text) > 0,
+            get_total_texify_tokens(latex_text) < settings.TEXIFY_MODEL_MAX, # Make sure we didn't run to the overall token max
+            len(latex_text) > len(orig_block_text) * .8,
+            len(latex_text.strip()) > 0
         ]

         idx = reformat_regions[current_region][-1] + 1
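The `conditions` list in the hunk above decides whether a prediction is trusted enough to replace the original blocks. A minimal, model-free sketch of the same checks follows. `accept_prediction` is a hypothetical helper, the whitespace token counter is a stand-in for the real tokenizer, `model_max=384` is an illustrative value, and combining the checks with `all()` is an assumption, since the diff does not show how the list is consumed.

```python
from typing import Callable


def accept_prediction(
    latex_text: str,
    orig_text: str,
    count_tokens: Callable[[str], int],
    model_max: int,
    min_ratio: float = 0.8,
) -> bool:
    """Return True if a predicted LaTeX string passes the sanity checks."""
    conditions = [
        len(latex_text.strip()) > 0,                    # not empty or whitespace only
        count_tokens(latex_text) < model_max,           # did not run to the token max
        len(latex_text) > len(orig_text) * min_ratio,   # not much shorter than the source text
    ]
    return all(conditions)  # assumption: every check must pass


def whitespace_tokens(text: str) -> int:
    # Stand-in tokenizer for the example only.
    return len(text.split())


print(accept_prediction("$$x^2 + y^2 = z^2$$", "x2 + y2 = z2", whitespace_tokens, model_max=384))
```

The length-ratio check guards against outputs that silently drop most of the region, which the token-count checks alone would not catch.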
@@ -241,7 +208,7 @@ def replace_blocks_with_nougat_predictions(page_blocks: Page, merged_boxes, refo
             block_line = Line(
                 spans=[
                     Span(
-                        text=nougat_text,
+                        text=latex_text,
                         bbox=merged_boxes[current_region],
                         span_id=f"{pnum}_{idx}_fixeq",
                         font="Latex",
@@ -261,15 +228,15 @@ def replace_blocks_with_nougat_predictions(page_blocks: Page, merged_boxes, refo
     return new_blocks, success_count, fail_count, converted_spans


-def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]], nougat_model, batch_size=settings.NOUGAT_BATCH_SIZE):
+def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]], texify_model, batch_size=settings.TEXIFY_BATCH_SIZE):
     unsuccessful_ocr = 0
     successful_ocr = 0

     # Find potential equation regions, and length of text in each region
     reformat_regions = []
     reformat_region_lens = []
     for pnum, page in enumerate(blocks):
-        regions, region_lens = find_page_equation_regions(pnum, page, block_types, nougat_model)
+        regions, region_lens = find_page_equation_regions(pnum, page, block_types)
         reformat_regions.append(regions)
         reformat_region_lens.append(region_lens)

@@ -283,26 +250,25 @@ def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]
         page_obj = doc[page_idx]
         for reformat_region in reformat_regions_page:
             bboxes, merged_box = get_bboxes_for_region(blocks[page_idx], reformat_region)
-            png_image = get_nougat_image(page_obj, merged_box, bboxes)
+            png_image = get_masked_image(page_obj, merged_box, bboxes)
             images.append(png_image)
             merged_boxes.append(merged_box)

     # Make batched predictions
-    predictions = get_nougat_text_batched(images, flat_reformat_region_lens, nougat_model, batch_size)
+    predictions = get_latex_batched(images, flat_reformat_region_lens, texify_model, batch_size)

     # Replace blocks with predictions
     page_start = 0
     converted_spans = []
     for page_idx, reformat_regions_page in enumerate(reformat_regions):
         page_predictions = predictions[page_start:page_start + len(reformat_regions_page)]
         page_boxes = merged_boxes[page_start:page_start + len(reformat_regions_page)]
-        new_page_blocks, success_count, fail_count, converted_span = replace_blocks_with_nougat_predictions(
+        new_page_blocks, success_count, fail_count, converted_span = replace_blocks_with_latex(
             blocks[page_idx],
             page_boxes,
             reformat_regions_page,
             page_predictions,
-            page_idx,
-            nougat_model
+            page_idx
         )
         converted_spans.extend(converted_span)
         blocks[page_idx].blocks = new_page_blocks
@@ -311,6 +277,6 @@ def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]
         unsuccessful_ocr += fail_count

     # If debug mode is on, dump out conversions for comparison
-    dump_nougat_debug_data(doc, images, converted_spans)
+    dump_equation_debug_data(doc, images, converted_spans)

     return blocks, {"successful_ocr": successful_ocr, "unsuccessful_ocr": unsuccessful_ocr, "equations": eq_count}
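`replace_equations` flattens the regions of all pages into one list for batched prediction, then slices the flat predictions back out per page using `page_start`. The sketch below shows that slicing pattern on plain lists. `split_by_page` is a hypothetical helper, and advancing `page_start` by the page's region count is an assumption, because that line falls outside the hunks shown.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_by_page(flat_items: Sequence[T], regions_per_page: Sequence[int]) -> List[List[T]]:
    """Slice a flat, page-ordered list back into one list per page."""
    pages = []
    page_start = 0
    for count in regions_per_page:
        pages.append(list(flat_items[page_start:page_start + count]))
        page_start += count  # assumed: advance past this page's regions
    return pages


# Three predictions spread over pages with 2, 0, and 1 equation regions.
print(split_by_page(["eq_a", "eq_b", "eq_c"], [2, 0, 1]))
```

This only works if the flat list preserves page order, which is why the images and merged boxes are appended in the same page-by-page loop.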
6 changes: 3 additions & 3 deletions marker/convert.py
@@ -102,7 +102,7 @@ def convert_single_pdf(
         return "", out_meta

     # Unpack models from list
-    nougat_model, layoutlm_model, order_model, edit_model = model_lst
+    texify_model, layoutlm_model, order_model, edit_model = model_lst

     block_types = detect_document_block_types(
         doc,
@@ -146,8 +146,8 @@ def convert_single_pdf(
         doc,
         blocks,
         block_types,
-        nougat_model,
-        batch_size=int(settings.NOUGAT_BATCH_SIZE * parallel_factor)
+        texify_model,
+        batch_size=int(settings.TEXIFY_BATCH_SIZE * parallel_factor)
     )
     out_meta["block_stats"]["equations"] = eq_stats

4 changes: 2 additions & 2 deletions marker/debug/data.py
@@ -10,7 +10,7 @@
 import io


-def dump_nougat_debug_data(doc, images, converted_spans):
+def dump_equation_debug_data(doc, images, converted_spans):
     if not settings.DEBUG_DATA_FOLDER or settings.DEBUG_LEVEL == 0:
         return

@@ -55,7 +55,7 @@ def dump_bbox_debug_data(doc, blocks: List[Page]):
     for idx, page_blocks in enumerate(blocks):
         page = doc[idx]

-        pix = page.get_pixmap(dpi=settings.NOUGAT_DPI, annots=False, clip=page_blocks.bbox)
+        pix = page.get_pixmap(dpi=settings.TEXIFY_DPI, annots=False, clip=page_blocks.bbox)
         png = pix.pil_tobytes(format="PNG")
         png_image = Image.open(io.BytesIO(png))
         width, height = png_image.size
6 changes: 3 additions & 3 deletions marker/models.py
@@ -1,4 +1,4 @@
-from marker.cleaners.equations import load_nougat_model
+from marker.cleaners.equations import load_texify_model
 from marker.ordering import load_ordering_model
 from marker.postprocessors.editor import load_editing_model
 from marker.segmentation import load_layout_model
@@ -8,6 +8,6 @@ def load_all_models():
     edit = load_editing_model()
     order = load_ordering_model()
     layout = load_layout_model()
-    nougat = load_nougat_model()
-    model_lst = [nougat, layout, order, edit]
+    texify = load_texify_model()
+    model_lst = [texify, layout, order, edit]
     return model_lst