We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
首先感谢大佬的代码,让我解决了一些问题。但我看 demo 中都是处理图片的,而且是加载一次模型只处理一张图,而真正大批量使用时还是 PDF 较多。抛砖引玉,我写了个串行处理 PDF 的脚本,理解不深,里面肯定有很多不当之处。希望大佬能在 README 中添加一个 PDF demo,用于多卡并行处理大量 PDF 的场景。
#!/usr/bin/env python3
"""Batch OCR of PDF files with GOT-OCR2.0.

Every page of every PDF is rendered to an image, tiled into 1024x1024 crops,
and run through the GOT model. The recognized text is written to a markdown
file whose path mirrors the input directory layout, and per-file timing
statistics are collected into a JSON results file in the output directory.
"""
import os
import argparse
import fitz  # PyMuPDF
from pathlib import Path
import torch
from PIL import Image
import time
from datetime import datetime
import json
from transformers import AutoTokenizer
from GOT.model import GOTQwenForCausalLM
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init, KeywordsStoppingCriteria
from GOT.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer

# Constants
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'


def pdf_to_pil_image(page, dpi=300):
    """Convert a PDF page to an RGB PIL image without saving to disk.

    Args:
        page: PyMuPDF page object to render.
        dpi: Render resolution. PDF user space is 72 units per inch, so the
            zoom factor applied to the page is ``dpi / 72``.

    Returns:
        The rendered page as a ``PIL.Image.Image`` in RGB mode.
    """
    zoom = dpi / 72
    pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
    return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)


def clean_text(text):
    """Strip the model's batch-size debug banner and surrounding whitespace."""
    text = text.replace('====new images batch size======: ', '')
    return text.strip()


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=1024):
    """Split *image* into a grid of ``image_size`` x ``image_size`` tiles.

    Adapted from the dynamic preprocessing in run_ocr_2.0_crop.py. Among all
    (cols, rows) grids with ``min_num <= cols * rows <= max_num``, the one
    whose cols/rows ratio is closest to the image's aspect ratio is chosen;
    candidates are ordered by tile count, so on ties a grid with fewer tiles
    wins. The image is resized to exactly fill that grid and cropped
    row-major into tiles. No extra thumbnail tile is added.

    Returns:
        List of PIL images, ``cols * rows`` long.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    target_ratios = sorted(
        {
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if min_num <= i * j <= max_num
        },
        key=lambda x: x[0] * x[1],
    )
    cols, rows = min(target_ratios, key=lambda r: abs(aspect_ratio - r[0] / r[1]))

    resized_img = image.resize((image_size * cols, image_size * rows))

    processed_images = []
    for i in range(cols * rows):
        col, row = i % cols, i // cols
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))
    return processed_images


class OCRProcessor:
    """Hold a GOT-OCR2.0 model so it is loaded once and reused for many pages."""

    def __init__(self, model_path):
        """Load the tokenizer, model and image processor from *model_path*.

        The wall-clock loading time (seconds) is stored in
        ``self.model_load_time``.
        """
        load_start = time.time()

        disable_torch_init()
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = GOTQwenForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            device_map='cuda',
            use_safetensors=True,
            pad_token_id=151643,
        ).eval()
        self.model.to(device='cuda', dtype=torch.bfloat16)
        self.image_processor = BlipImageEvalProcessor(image_size=1024)
        # Number of <imgpad> placeholder tokens inserted into the prompt per tile.
        self.image_token_len = 256

        self.model_load_time = time.time() - load_start

    def process_page(self, pil_image):
        """OCR a single page image.

        Args:
            pil_image: The rendered page.

        Returns:
            Tuple ``(text, seconds)``: the cleaned model output and the time
            spent processing this page.
        """
        page_start = time.time()

        sub_images = dynamic_preprocess(pil_image)
        image_tensors = torch.stack([self.image_processor(img) for img in sub_images])

        # Prompt = one block of image placeholder tokens covering all tiles,
        # followed by the formatted-OCR instruction.
        qs = 'OCR with format upon the patch reference: '
        qs = (DEFAULT_IM_START_TOKEN
              + DEFAULT_IMAGE_PATCH_TOKEN * self.image_token_len * len(sub_images)
              + DEFAULT_IM_END_TOKEN + '\n' + qs)
        conv = conv_templates["mpt"].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        inputs = self.tokenizer([prompt])
        input_ids = torch.as_tensor(inputs.input_ids).cuda()

        # Stop generation once the conversation separator is produced.
        stop_str = conv.sep
        stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids)

        with torch.autocast("cuda", dtype=torch.bfloat16):
            output_ids = self.model.generate(
                input_ids,
                images=[(image_tensors.half().cuda(), image_tensors.half().cuda())],
                do_sample=False,
                num_beams=1,
                streamer=None,
                max_new_tokens=4096,
                stopping_criteria=[stopping_criteria],
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]

        process_time = time.time() - page_start
        return clean_text(outputs), process_time


def get_file_size(file_path):
    """Return the size of *file_path* in megabytes (MiB)."""
    return os.path.getsize(file_path) / (1024 * 1024)


def process_pdf(pdf_path, output_dir, processor, pdf_dir):
    """OCR every page of one PDF and write the result as markdown.

    The output file is ``output_dir / <path relative to pdf_dir>.md``, with
    one ``## Page N`` section per page.

    Args:
        pdf_path: Path of the PDF to process.
        output_dir: Root directory for markdown output.
        processor: An ``OCRProcessor`` used for every page.
        pdf_dir: Input root; used to mirror the directory structure.

    Returns:
        Dict of timing/size statistics for this PDF.
    """
    start_time = time.time()

    # Mirror the input directory structure under output_dir.
    rel_path = os.path.relpath(pdf_path, start=pdf_dir)
    output_path = Path(output_dir) / Path(rel_path).with_suffix('.md')
    output_path.parent.mkdir(parents=True, exist_ok=True)

    page_times = []
    # Context managers guarantee the PDF handle and output file are closed
    # even if OCR of a page raises.
    with fitz.open(pdf_path) as doc, open(output_path, 'w', encoding='utf-8') as f:
        total_pages = len(doc)
        print(f"\nProcessing: {rel_path} ({total_pages} pages)")

        for page_num in range(total_pages):
            print(f"  Page {page_num + 1}/{total_pages}", end='\r')
            pil_image = pdf_to_pil_image(doc.load_page(page_num))
            text, page_time = processor.process_page(pil_image)
            page_times.append(page_time)

            f.write(f"## Page {page_num + 1}\n\n")
            f.write(text)
            f.write("\n\n")

    print()  # New line after progress indicator

    total_time = time.time() - start_time
    file_size = get_file_size(pdf_path)

    return {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "pdf_path": str(pdf_path),
        "relative_path": str(rel_path),
        "output_path": str(output_path),
        "file_size_mb": round(file_size, 2),
        "total_pages": total_pages,
        "total_time": round(total_time, 2),
        # Guard against PDFs with no pages.
        "average_time_per_page": round(total_time / total_pages, 2) if total_pages else 0.0,
        "page_processing_times": [round(t, 2) for t in page_times],
    }


def _collect_pdf_files(pdf_path):
    """Return ``.pdf`` paths at *pdf_path* (a single file or a directory tree), sorted."""
    if os.path.isfile(pdf_path):
        return [pdf_path] if pdf_path.lower().endswith('.pdf') else []
    return sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(pdf_path)
        for name in files
        if name.lower().endswith('.pdf')
    )


def main():
    """Command-line entry point: OCR one PDF or every PDF under a directory."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--pdf", type=str, required=True,
                        help="PDF file path or directory containing PDF files")
    parser.add_argument("--output", type=str, required=True,
                        help="Output directory for markdown files")
    parser.add_argument("--model", type=str, default="stepfun-ai/GOT-OCR2_0",
                        help="Model path")
    args = parser.parse_args()

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Relative output paths are computed against the input directory
    # (or the file's parent when a single PDF is given).
    pdf_path = os.path.abspath(args.pdf)
    pdf_dir = pdf_path if os.path.isdir(pdf_path) else os.path.dirname(pdf_path)

    pdf_files = _collect_pdf_files(pdf_path)
    if not pdf_files:
        print("No PDF files found!")
        return

    # Load the model once and reuse it for every page of every PDF.
    processor = OCRProcessor(args.model)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = output_dir / f"results_{timestamp}.json"

    all_stats = []
    for pdf_file in pdf_files:
        rel_path = os.path.relpath(pdf_file, start=pdf_dir)
        try:
            stats = process_pdf(pdf_file, output_dir, processor, pdf_dir)
        except Exception as e:  # keep going with the remaining PDFs
            print(f"\nError processing {rel_path}: {e}")
            stats = {
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "pdf_path": str(pdf_file),
                "relative_path": str(rel_path),
                "error": str(e),
            }
        all_stats.append(stats)

        # Rewrite the results after every PDF, successful or not, so that
        # progress (including errors) survives an interruption.
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump({
                "run_timestamp": timestamp,
                "total_pdfs": len(pdf_files),
                "processed_pdfs": len(all_stats),
                "failed_pdfs": sum(1 for s in all_stats if "error" in s),
                "model_path": args.model,
                "model_load_time": processor.model_load_time,
                "input_directory": pdf_dir,
                "output_directory": str(output_dir),
                "pdf_statistics": all_stats,
            }, f, indent=2, ensure_ascii=False)

    print(f"\nProcessing completed. Results saved to {results_file}")


if __name__ == "__main__":
    main()
The text was updated successfully, but these errors were encountered:
No branches or pull requests
首先感谢大佬的代码,让我解决了一些问题。但我看 demo 中都是处理图片的,而且是加载一次模型只处理一张图,而真正大批量使用时还是 PDF 较多。抛砖引玉,我写了个串行处理 PDF 的脚本,理解不深,里面肯定有很多不当之处。希望大佬能在 README 中添加一个 PDF demo,用于多卡并行处理大量 PDF 的场景。
The text was updated successfully, but these errors were encountered: