Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pdf demo #260

Open
Airseai6 opened this issue Jan 4, 2025 · 0 comments
Open

pdf demo #260

Airseai6 opened this issue Jan 4, 2025 · 0 comments
Labels
good first issue Good for newcomers

Comments

@Airseai6
Copy link

Airseai6 commented Jan 4, 2025

首先感谢大佬的代码，帮我解决了一些问题。不过我看到 demo 中处理的都是图片，而且每加载一次模型只处理一张图；而真正大批量使用时，PDF 更为常见。抛砖引玉，我写了一个串行处理 PDF 的脚本，由于理解不深，里面肯定有很多不当之处。希望大佬能在 README 中添加一个 PDF demo，用于多卡并行处理大量 PDF 的场景。

#!/usr/bin/env python3
import os
import argparse
import fitz  # PyMuPDF
from pathlib import Path
import torch
from PIL import Image
import time
from datetime import datetime
import json

from transformers import AutoTokenizer
from GOT.model import GOTQwenForCausalLM
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init, KeywordsStoppingCriteria
from GOT.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer

# Special tokens used to build the image section of the OCR prompt.
# OCRProcessor.process_page assembles it as
#   DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * N + DEFAULT_IM_END_TOKEN
DEFAULT_IMAGE_TOKEN = '<image>'  # not referenced elsewhere in this script
DEFAULT_IMAGE_PATCH_TOKEN = "<imgpad>"  # one placeholder per image token
DEFAULT_IM_START_TOKEN = "<img>"
DEFAULT_IM_END_TOKEN = "</img>"

def pdf_to_pil_image(page, dpi=300):
    """Render one PyMuPDF page into an in-memory RGB PIL image.

    Args:
        page: a ``fitz`` page object, e.g. from ``doc.load_page``.
        dpi: target resolution. PDF geometry is expressed at 72 units per
            inch, so both axes are scaled by ``dpi / 72``.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode; nothing is written to disk.
    """
    zoom = dpi / 72
    pixmap = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
    # get_pixmap() is called without alpha=True, so the samples are expected
    # to be packed 3-channel RGB (NOTE(review): confirm for the PyMuPDF version in use).
    dimensions = (pixmap.width, pixmap.height)
    return Image.frombytes("RGB", dimensions, pixmap.samples)

def clean_text(text):
    """Remove a known marker string from OCR output and trim whitespace.

    Every occurrence of ``'====new images batch size======: '`` (presumably
    debug output leaking into the decoded text) is deleted, then leading and
    trailing whitespace is stripped.
    """
    marker = '====new images batch size======: '
    without_marker = "".join(text.split(marker))
    return without_marker.strip()

def dynamic_preprocess(image, min_num=1, max_num=6, image_size=1024):
    """Split an image into a grid of ``image_size`` x ``image_size`` tiles.

    A (columns, rows) grid is chosen whose tile count is between *min_num*
    and *max_num* and whose columns/rows ratio is closest to the image's
    width/height ratio. The image is resized to exactly fill that grid and
    cut into tiles.

    Args:
        image: object exposing ``.size`` -> (width, height) and ``.resize``
            (a PIL image in this script).
        min_num: minimum number of tiles.
        max_num: maximum number of tiles.
        image_size: side length of each square tile in pixels.

    Returns:
        List of tiles in row-major order (left to right, top to bottom).
    """
    width, height = image.size
    image_aspect = width / height

    # Collect candidate grids. The insertion order matches the original set
    # comprehension, so tie-breaking below stays the same.
    grids = set()
    for n in range(min_num, max_num + 1):
        for cols in range(1, n + 1):
            for rows in range(1, n + 1):
                if min_num <= cols * rows <= max_num:
                    grids.add((cols, rows))
    grids = sorted(grids, key=lambda grid: grid[0] * grid[1])

    # min() keeps the first minimum, so on equal aspect error the grid with
    # fewer tiles wins.
    cols, rows = min(grids, key=lambda grid: abs(image_aspect - grid[0] / grid[1]))

    resized = image.resize((image_size * cols, image_size * rows))

    tiles = []
    for row in range(rows):
        for col in range(cols):
            left = col * image_size
            top = row * image_size
            tiles.append(resized.crop((left, top, left + image_size, top + image_size)))
    return tiles

class OCRProcessor:
    """Keep one GOT-OCR2.0 model resident and OCR page images with it.

    Everything is loaded once in ``__init__``, so many pages (and many PDFs)
    can be processed without reloading the weights.

    Attributes:
        tokenizer: tokenizer loaded from *model_path*.
        model: ``GOTQwenForCausalLM`` in eval mode, moved to CUDA as bfloat16.
        image_processor: ``BlipImageEvalProcessor`` built with image_size=1024.
        image_token_len: number of image placeholder tokens per tile.
        model_load_time: wall-clock seconds spent in ``__init__``.
    """

    def __init__(self, model_path):
        """Load tokenizer, model and image processor from *model_path*.

        Args:
            model_path: directory or hub id accepted by ``from_pretrained``.
        """
        load_start = time.time()

        disable_torch_init()
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = GOTQwenForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            device_map='cuda',
            use_safetensors=True,
            pad_token_id=151643,
        ).eval()
        self.model.to(device='cuda', dtype=torch.bfloat16)

        self.image_processor = BlipImageEvalProcessor(image_size=1024)
        self.image_token_len = 256

        self.model_load_time = time.time() - load_start

    def process_page(self, pil_image):
        """Run multi-crop OCR on one page image.

        Args:
            pil_image: the page as a PIL image (e.g. from ``pdf_to_pil_image``).

        Returns:
            ``(text, seconds)``: the cleaned OCR text and the wall-clock time
            spent on this page.
        """
        page_start = time.time()

        # Tile the page, preprocess every tile and stack them into one batch.
        sub_images = dynamic_preprocess(pil_image)
        image_tensors = torch.stack([self.image_processor(img) for img in sub_images])
        # Convert and copy to the GPU once. The same batch fills both slots of
        # the (tensor, tensor) pair below, so a second copy is pure overhead.
        # NOTE(review): tiles are float16 while the weights are bfloat16; this
        # relies on the bf16 autocast around generate() - confirm.
        image_batch = image_tensors.half().cuda()

        # The prompt reserves image_token_len placeholder tokens per tile.
        qs = 'OCR with format upon the patch reference: '
        qs = (DEFAULT_IM_START_TOKEN +
              DEFAULT_IMAGE_PATCH_TOKEN * self.image_token_len * len(sub_images) +
              DEFAULT_IM_END_TOKEN + '\n' + qs)

        conv = conv_templates["mpt"].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        inputs = self.tokenizer([prompt])
        input_ids = torch.as_tensor(inputs.input_ids).cuda()

        # The conversation separator is the stop keyword (handled by the
        # project's KeywordsStoppingCriteria) and is trimmed from the output.
        stop_str = conv.sep
        stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids)

        with torch.autocast("cuda", dtype=torch.bfloat16):
            output_ids = self.model.generate(
                input_ids,
                images=[(image_batch, image_batch)],
                do_sample=False,  # greedy decoding (num_beams=1, no sampling)
                num_beams=1,
                streamer=None,
                max_new_tokens=4096,
                stopping_criteria=[stopping_criteria],
            )

        # Decode only the tokens generated after the prompt.
        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]

        process_time = time.time() - page_start
        return clean_text(outputs), process_time

def get_file_size(file_path):
    """Return the size of *file_path* in MiB (bytes / 1024**2) as a float."""
    size_in_bytes = os.stat(file_path).st_size
    return size_in_bytes / 1024 ** 2

def process_pdf(pdf_path, output_dir, processor, pdf_dir):
    """OCR every page of one PDF and write the text to a mirrored ``.md`` file.

    The output file is ``<output_dir>/<path relative to pdf_dir>.md``. Each
    page is written under a ``## Page N`` heading.

    Args:
        pdf_path: path of the PDF to process.
        output_dir: root directory for markdown output.
        processor: object with ``process_page(pil_image) -> (text, seconds)``,
            e.g. an ``OCRProcessor``.
        pdf_dir: input root; the PDF's path relative to it is mirrored under
            *output_dir*.

    Returns:
        Dict of per-file statistics: paths, file size, page count and timings.
    """
    start_time = time.time()

    # Mirror the input directory structure under output_dir.
    rel_path = os.path.relpath(pdf_path, start=pdf_dir)
    output_path = Path(output_dir) / Path(rel_path).with_suffix('.md')
    output_path.parent.mkdir(parents=True, exist_ok=True)

    page_times = []

    # Context managers close the document and the output file even if OCR
    # fails part-way through (the document was previously never closed).
    with fitz.open(pdf_path) as doc:
        total_pages = len(doc)
        print(f"\nProcessing: {rel_path} ({total_pages} pages)")

        with open(output_path, 'w', encoding='utf-8') as f:
            for page_num in range(total_pages):
                print(f"  Page {page_num + 1}/{total_pages}", end='\r')
                page = doc.load_page(page_num)
                pil_image = pdf_to_pil_image(page)

                text, page_time = processor.process_page(pil_image)
                page_times.append(page_time)

                # Persist each page immediately. These writes were commented
                # out before, which left every .md output empty.
                f.write(f"## Page {page_num + 1}\n\n")
                f.write(text)
                f.write("\n\n")

    print()  # finish the '\r' progress line

    total_time = time.time() - start_time
    file_size = get_file_size(pdf_path)
    # A zero-page PDF must not raise ZeroDivisionError.
    average_time = total_time / total_pages if total_pages else 0.0

    return {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "pdf_path": str(pdf_path),
        "relative_path": str(rel_path),
        "output_path": str(output_path),
        "file_size_mb": round(file_size, 2),
        "total_pages": total_pages,
        "total_time": round(total_time, 2),
        "average_time_per_page": round(average_time, 2),
        "page_processing_times": [round(t, 2) for t in page_times],
    }

def _write_results(results_file, summary):
    """Write the run summary dict to *results_file* as indented UTF-8 JSON."""
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)


def main():
    """Command-line entry point: OCR one PDF, or every PDF under a directory.

    Markdown output mirrors the input tree under ``--output``. A
    ``results_<timestamp>.json`` summary in the output directory is rewritten
    after every PDF, successful or failed, so progress and errors survive an
    interruption.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pdf", type=str, required=True,
                        help="PDF file path or directory containing PDF files")
    parser.add_argument("--output", type=str, required=True,
                        help="Output directory for markdown files")
    parser.add_argument("--model", type=str, default="stepfun-ai/GOT-OCR2_0",
                        help="Model path")
    args = parser.parse_args()

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Relative output paths are computed against this input root.
    pdf_path = os.path.abspath(args.pdf)
    pdf_dir = pdf_path if os.path.isdir(pdf_path) else os.path.dirname(pdf_path)

    pdf_files = []
    if os.path.isfile(pdf_path):
        if pdf_path.lower().endswith('.pdf'):
            pdf_files.append(pdf_path)
    else:
        for root, _, files in os.walk(pdf_path):
            pdf_files.extend(
                os.path.join(root, name) for name in files if name.lower().endswith('.pdf')
            )
        # os.walk order depends on the filesystem; sort for reproducible runs.
        pdf_files.sort()

    if not pdf_files:
        print("No PDF files found!")
        return

    # Load the model once and reuse it for every PDF.
    processor = OCRProcessor(args.model)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = output_dir / f"results_{timestamp}.json"

    all_stats = []
    failed_count = 0
    for pdf_file in pdf_files:
        try:
            stats = process_pdf(pdf_file, output_dir, processor, pdf_dir)
        except Exception as e:  # per-file boundary: record the error and continue
            rel_path = os.path.relpath(pdf_file, start=pdf_dir)
            print(f"\nError processing {rel_path}: {e}")
            stats = {
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "pdf_path": str(pdf_file),
                "relative_path": str(rel_path),
                "error": str(e),
            }
            failed_count += 1
        all_stats.append(stats)

        # Rewrite the summary after every PDF, including failures. Previously
        # error entries were only persisted if a later PDF succeeded.
        _write_results(results_file, {
            "run_timestamp": timestamp,
            "total_pdfs": len(pdf_files),
            "processed_pdfs": len(all_stats),
            "failed_pdfs": failed_count,
            "model_path": args.model,
            "model_load_time": processor.model_load_time,
            "input_directory": pdf_dir,
            "output_directory": str(output_dir),
            "pdf_statistics": all_stats,
        })

    print(f"\nProcessing completed. Results saved to {results_file}")

if __name__ == "__main__":
    main()
@Ucas-HaoranWei Ucas-HaoranWei added the good first issue Good for newcomers label Jan 5, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
good first issue Good for newcomers
Projects
None yet
Development

No branches or pull requests

2 participants