-
-
Notifications
You must be signed in to change notification settings - Fork 76
/
Copy pathcollect.py
237 lines (200 loc) · 11 KB
/
collect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import os
import argparse
import json
import asyncio
import shutil
from pipeline.download import (
remove_version_from_string,
download_pdf_from_arxiv,
download_pdf_from_openreview,
get_paper_from_arxiv_by_openreview
)
from pipeline.pdf_to_images import pdf_to_images
from pipeline.parse_media_html import get_media_from_html
from pipeline.crop import crop_figures
from pipeline.crop_doublecheck import doublecheck_figures
from pipeline.enrich_desc import enrich_description_from_images, enrich_description_from_html
from pipeline.reformat_tables import reformat_tables_from_html
from pipeline.extract_sections import extract_sections
from pipeline.extract_section_details import extract_section_details
from pipeline.extract_references import extract_references
from pipeline.extract_essentials import extract_essentials
from pipeline.extract_affiliation import extract_affiliation
from pipeline.extract_category import extract_category
from pipeline.write_script import write_script
from pipeline.script_to_speech import script_to_speech
from pipeline.utils import UploadedFiles
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the paper collection script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing ``parse_args()`` callers behave exactly as before.

    Returns:
        The populated ``argparse.Namespace``. Dashes in option names become
        underscores in attribute names (e.g. ``--arxiv-id`` -> ``arxiv_id``).
    """
    parser = argparse.ArgumentParser()

    # Paper identifiers: validation that at least one is present is done by
    # the caller, not by the parser.
    parser.add_argument('--arxiv-id', type=str, default=None, help='arXiv ID')
    parser.add_argument('--openreview-id', type=str, default=None, help='OpenReview ID')

    # NOTE: the option name keeps its historical spelling ("comparision")
    # because renaming it would break existing command lines.
    parser.add_argument(
        '--skip-comparision-openreview-arxiv',
        action='store_true',
        help='Skip checking whether the OpenReview paper is also available on arXiv',
    )
    parser.add_argument(
        '--skip-page-threshold',
        type=int,
        default=50,
        help='Skip the paper if the number of pages is greater than or equal to the threshold',
    )
    parser.add_argument('--workers', type=int, default=10, help='Number of workers')

    # Figure/table extraction backends used when cropping from page images.
    parser.add_argument('--use-upstage', action='store_true', help='Use Upstage to extract figures from images')
    parser.add_argument('--use-mineru', action='store_true', help='Use MinerU to extract figures from images')
    parser.add_argument('--stop-at-no-html', action='store_true', help='Stop if no HTML is found')

    parser.add_argument(
        '--known-affiliations-path',
        type=str,
        default='configs/known_affiliations.txt',
        help='Path to known affiliations',
    )
    parser.add_argument(
        '--known-categories-path',
        type=str,
        default='configs/known_categories.json',
        help='Path to known categories',
    )
    parser.add_argument(
        '--voice-synthesis',
        type=str,
        default=None,
        choices=['vertexai', 'local'],
        help='Voice synthesis service to use',
    )
    return parser.parse_args(argv)
def _download_paper_pdf(args):
    """Download the paper PDF and return ``(root_path, pdf_file_path)``.

    ``root_path`` is the directory name that every later artifact is written
    under. When the OpenReview paper is found on arXiv, ``args`` is mutated:
    ``args.arxiv_id`` is set to the discovered ID and ``args.openreview_id``
    is cleared, so the rest of the pipeline treats it as an arXiv paper.
    """
    root_path = None
    pdf_file_path = None

    if args.arxiv_id is not None:
        root_path = remove_version_from_string(args.arxiv_id)
        print(f"Downloading PDF from arXiv: {args.arxiv_id}")
        pdf_file_path = download_pdf_from_arxiv(root_path, args.arxiv_id)

    # Deliberately a second `if`, not `elif`: when both IDs are given the
    # OpenReview branch runs as well and its result wins.
    if args.openreview_id is not None:
        found_on_arxiv = False
        if args.skip_comparision_openreview_arxiv:
            print("Skipping the comparision between OpenReview and arXiv")
        else:
            print("Checking if the paper is on arXiv")
            found_on_arxiv, arxiv_id = get_paper_from_arxiv_by_openreview(args.openreview_id)

        if found_on_arxiv:
            print("The paper is on arXiv. Skip downloading the PDF from OpenReview")
            args.arxiv_id = arxiv_id
            # NOTE(review): unlike the direct arXiv branch above, the version
            # suffix is not stripped from the root path here - confirm whether
            # the looked-up ID can carry one.
            root_path = args.arxiv_id
            print(f"Downloading PDF from arXiv: {args.arxiv_id}")
            pdf_file_path = download_pdf_from_arxiv(root_path, args.arxiv_id)
            args.openreview_id = None
        else:
            print("The paper is not on arXiv. Downloading the PDF from OpenReview")
            root_path = args.openreview_id
            print(f"Downloading PDF from OpenReview: {args.openreview_id}")
            pdf_file_path = download_pdf_from_openreview(root_path, args.openreview_id)

    return root_path, pdf_file_path


def _write_json(data, path):
    """Serialize ``data`` as JSON into the file at ``path``."""
    with open(path, "w") as f:
        json.dump(data, f)


async def main(args):
    """Run the full collection pipeline for a single paper.

    Steps: download the PDF, render its pages to images, obtain figures and
    tables (from the arXiv HTML when available, otherwise by cropping page
    images), enrich them with descriptions, and extract essentials,
    affiliation, categories, sections and references. Results are written as
    JSON files under the paper's root directory; a podcast is generated when
    ``args.voice_synthesis == "vertexai"``.

    Args:
        args: Namespace produced by ``parse_args``. It may be mutated when an
            OpenReview paper is resolved to an arXiv ID.

    Returns:
        None. Returns early (after removing the paper directory) when the
        paper has too many pages, or when no HTML is found and
        ``args.stop_at_no_html`` is set.

    Raises:
        ValueError: If neither ``args.arxiv_id`` nor ``args.openreview_id``
            is provided.
    """
    print(args)

    # Fail with a clear error instead of an unbound-variable error further
    # down when the function is called without any paper identifier.
    if args.arxiv_id is None and args.openreview_id is None:
        raise ValueError("Either arxiv-id or openreview-id must be provided")

    use_html = True

    # 1. download pdf
    root_path, pdf_file_path = _download_paper_pdf(args)

    # The uploaded file handle is only used inside this context manager
    # (presumably it is cleaned up on exit - verify in pipeline.utils).
    with UploadedFiles(pdf_file_path) as uploaded_files:
        pdf_file_in_gemini = uploaded_files[0]

        # 2. convert pdf to images
        print("Converting PDF to images")
        image_paths = pdf_to_images(pdf_file_path, f"{root_path}/paper_images")
        if len(image_paths) >= args.skip_page_threshold:
            print(f"Too many images: {len(image_paths)}. Skip this paper.")
            shutil.rmtree(root_path)
            return

        # 3. get figures and tables, preferring the arXiv HTML version
        print("Using HTML to extract figures and tables")
        if args.arxiv_id is not None:
            figures, tables = get_media_from_html(args.arxiv_id)
            print(figures)
            print(tables)
            print("---")
            if figures is None or tables is None:
                if args.stop_at_no_html:
                    print("No HTML is found. Skip this paper.")
                    shutil.rmtree(root_path)
                    return
                use_html = False
        else:
            # HTML is only looked up for arXiv papers.
            use_html = False

        if not use_html:
            print("Cropping figures from images")
            figure_paths, table_paths = await crop_figures(
                image_paths,
                root_path,
                args.use_upstage,
                args.use_mineru,
                args.workers,
                pdf_file_path,
            )
            print(f"{len(figure_paths)} number of figures are extracted and saved {figure_paths}.")
            print(f"{len(table_paths)} number of tables are extracted and saved {table_paths}.")

        # 4. Double check if figure image file contains figure
        # NOTE: this step is currently disabled; the code is kept for reference.
        # if not use_html:
        #     print(f"Double checking if figure image file actually contians figure")
        #     # Filter out invalid figures and clean up files
        #     valid_figure_paths = await doublecheck_figures(figure_paths, pdf_file_in_gemini, args.workers, "figure")
        #     valid_figure_paths = [figure_paths[0]] if len(valid_figure_paths) == 0 else valid_figure_paths
        #     invalid_paths = set(figure_paths) - set(valid_figure_paths)
        #     for path in invalid_paths:
        #         os.remove(path)
        #     figure_paths = valid_figure_paths
        #     print(f"{len(figure_paths)} number of figures are remained. {figure_paths}.")
        #     print(f"Double checking if table file actually contians table")
        #     valid_table_paths = await doublecheck_figures(table_paths, pdf_file_in_gemini, args.workers, "table")
        #     valid_table_paths = [table_paths[0]] if len(valid_table_paths) == 0 else valid_table_paths
        #     invalid_paths = set(table_paths) - set(valid_table_paths)
        #     for path in invalid_paths:
        #         os.remove(path)
        #     table_paths = valid_table_paths
        #     print(f"{len(table_paths)} number of tables are remained. {table_paths}.")
        # else:
        #     print(f"Reformatting tables")
        #     tables = await reformat_tables_from_html(args.arxiv_id, tables, args.workers)

        # 5. associate each figure and table with description
        print("Associating each figure with relevant information")
        if not use_html:
            association_figure_results = await enrich_description_from_images(
                figure_paths, pdf_file_in_gemini, args.workers, "figure"
            )
            association_table_results = await enrich_description_from_images(
                table_paths, pdf_file_in_gemini, args.workers, "table"
            )
        else:
            print(figures)
            association_figure_results = await enrich_description_from_html(
                figures, pdf_file_in_gemini, args.workers, "figure"
            )
            association_table_results = await enrich_description_from_html(
                tables, pdf_file_in_gemini, args.workers, "table"
            )

        # 6. save the results
        print("Saving the figure information")
        association_figure_path = f"{root_path}/figures.json"
        _write_json(association_figure_results, association_figure_path)
        print(f"Figure information is saved to {association_figure_path}")

        print("Saving the table information")
        association_table_path = f"{root_path}/tables.json"
        print(association_table_path)
        # Saving tables stays best-effort (a failure does not abort the run),
        # but success is only reported when the file was actually written.
        try:
            _write_json(association_table_results, association_table_path)
        except Exception as e:
            print(f"Failed to save table information to {association_table_path}: {e}")
        else:
            print(f"Table information is saved to {association_table_path}")

        # 7. extract fundamental information from the pdf
        print("Extracting essential information from the pdf")
        essential_info = extract_essentials(pdf_file_in_gemini)

        # 8. extract affiliation and categories from the pdf
        print("Extracting affiliation from the pdf")
        affiliation = extract_affiliation(pdf_file_in_gemini, args.known_affiliations_path)
        essential_info["affiliation"] = affiliation["affiliation"]

        categories = extract_category(pdf_file_in_gemini, args.known_categories_path)
        essential_info["categories"] = categories

        if args.voice_synthesis == "vertexai":
            print("Generating podcast")
            print("Writing script")
            raw_script_path = f"{root_path}/raw_script.txt"
            script = write_script(pdf_file_in_gemini)
            with open(raw_script_path, "w", encoding="utf-8") as f:
                json.dump(script, f)

            podcast = script_to_speech(script, use_vertexai=True)
            podcast_path = f"{root_path}/podcast.wav"
            podcast.export(podcast_path, format="wav")
            print(f"Podcast is saved to {podcast_path}")
            essential_info["podcast_path"] = podcast_path
        elif args.voice_synthesis == "local":
            # Accepted by the CLI but has no implementation yet; say so
            # instead of silently producing no podcast.
            print("Local voice synthesis is not implemented yet. No podcast is generated.")

        print("Saving essential information")
        results_path = f"{root_path}/essential.json"
        _write_json(essential_info, results_path)
        print(f"Essential information is saved to {results_path}")

        # 9. extract sections from the pdf
        print("Extracting section list from the pdf")
        sections = extract_sections(pdf_file_in_gemini)["sections"]

        # 10. extract section details from the pdf
        print("Extracting section details from the pdf")
        section_detail_list = await extract_section_details(pdf_file_in_gemini, sections, args.workers)
        # Details are matched to sections by position.
        for i, details in enumerate(section_detail_list):
            sections[i]["details"] = details

        print("Saving section details")
        results_path = f"{root_path}/sections.json"
        _write_json(sections, results_path)
        print(f"Section details are saved to {results_path}")

        # 11. extract references from the pdf
        print("Extracting references from the pdf")
        references = extract_references(pdf_file_in_gemini, sections)

        print("Saving references")
        results_path = f"{root_path}/references.json"
        _write_json(references, results_path)
        print(f"References are saved to {results_path}")
if __name__ == "__main__":
    # Script entry point: read the CLI options, reject invalid combinations,
    # then drive the async pipeline to completion.
    args = parse_args()

    # A paper can only be located through one of its identifiers.
    if all(paper_id is None for paper_id in (args.arxiv_id, args.openreview_id)):
        raise ValueError("Either arxiv-id or openreview-id must be provided")

    # The two figure-extraction backends are mutually exclusive.
    if all((args.use_upstage, args.use_mineru)):
        raise ValueError("use-upstage and use-mineru cannot be provided at the same time")

    pipeline = main(args)
    asyncio.run(pipeline)