Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed type problems from llmclient #770

Merged
merged 28 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
869eb46
Fixed type problems
maykcaldas Dec 18, 2024
477fb7c
Added updated uv.lock
maykcaldas Dec 18, 2024
976cba4
Added comments in pyproject and pre-commit
maykcaldas Dec 18, 2024
688735f
Removed unneeded comment
maykcaldas Dec 18, 2024
9ca2e66
Updated tests to check the embeddings are lists
maykcaldas Dec 18, 2024
c411a63
Reverted type DocDetails from Text.Doc
maykcaldas Dec 19, 2024
5a2adc9
Preparing to merge updated code on main
maykcaldas Jan 13, 2025
659bb0f
Merge branch 'main' into fix-llmclient-types
maykcaldas Jan 13, 2025
598e19b
Removed unused type ignores
maykcaldas Jan 13, 2025
e358d26
added failure message to tests
maykcaldas Jan 13, 2025
aed1cb9
Reverted Text.doc type to and added explanatory TODO for future refe…
maykcaldas Jan 13, 2025
adf00cb
Fixed pre-commit issues
maykcaldas Jan 13, 2025
fccd820
Checks if Text.doc can be a DocDetails. If not, forces Doc
maykcaldas Jan 14, 2025
c70f3de
Fixed pylint error
maykcaldas Jan 14, 2025
5a81f71
Refactored logic to validate Text.doc
maykcaldas Jan 14, 2025
da9b04d
Fix pylint
maykcaldas Jan 14, 2025
a7bbea4
Cleaned up comments
maykcaldas Jan 14, 2025
be975ab
Avoided creating a model_validator.
maykcaldas Jan 15, 2025
c607522
Adjusted casting
maykcaldas Jan 15, 2025
682eae3
Merge branch 'main' into fix-llmclient-types
maykcaldas Jan 16, 2025
0ea2478
Forbid extras in Doc again
maykcaldas Jan 16, 2025
f8da02d
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Jan 16, 2025
ce08e0b
Implemented some suggestions
maykcaldas Jan 16, 2025
e431b07
Fixed pre-commit
maykcaldas Jan 16, 2025
eb063a4
Reverting uv.lock change
jamesbraza Jan 16, 2025
b237dec
Removed Doc type checkign from validate_all_fields
maykcaldas Jan 16, 2025
9d5e915
Resolving typing PR comments (#813)
jamesbraza Jan 16, 2025
14f3688
Added explictly extra=ignore back
maykcaldas Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions paperqa/agents/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ def table_formatter(
table.add_column("Title", style="cyan")
table.add_column("File", style="magenta")
for obj, filename in objects:
try:
display_name = cast(DocDetails, cast(Docs, obj).texts[0].doc).title
except AttributeError:
display_name = cast(Docs, obj).texts[0].doc.formatted_citation
table.add_row(cast(str, display_name)[:max_chars_per_column], filename)
docs = cast(Docs, obj) # Assume homogeneous objects
doc = docs.texts[0].doc
if isinstance(doc, DocDetails) and doc.title:
display_name: str = doc.title # Prefer title if available
else:
display_name = doc.formatted_citation
table.add_row(display_name[:max_chars_per_column], filename)
return table
raise NotImplementedError(
f"Object type {type(example_object)} can not be converted to table."
Expand Down
3 changes: 3 additions & 0 deletions paperqa/agents/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings
import zlib
from collections.abc import Callable, Collection, Sequence
from datetime import datetime
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, ClassVar
from uuid import UUID
Expand Down Expand Up @@ -70,6 +71,8 @@ def default(self, o):
return list(o)
if isinstance(o, os.PathLike):
return str(o)
if isinstance(o, datetime):
return o.isoformat()
return json.JSONEncoder.default(self, o)


Expand Down
2 changes: 1 addition & 1 deletion paperqa/clients/crossref.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ async def parse_crossref_to_doc_details(
elif len(date_parts) == 1:
publication_date = datetime(date_parts[0], 1, 1)

doc_details = DocDetails( # type: ignore[call-arg]
doc_details = DocDetails(
key=None if not bibtex else bibtex.split("{")[1].split(",")[0],
bibtex_type=CROSSREF_CONTENT_TYPE_TO_BIBTEX_MAPPING.get(
message.get("type", "other"), "misc"
Expand Down
2 changes: 1 addition & 1 deletion paperqa/clients/journal_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def _process(
# docname can be blank since the validation will add it
# remember, if both have docnames (i.e. key) they are
# wiped and re-generated with resultant data
return doc_details + DocDetails( # type: ignore[call-arg]
return doc_details + DocDetails(
source_quality=max(
[
self.data.get(query.journal.casefold(), DocDetails.UNDEFINED_JOURNAL_QUALITY), # type: ignore[union-attr]
Expand Down
2 changes: 1 addition & 1 deletion paperqa/clients/openalex.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def parse_openalex_to_doc_details(message: dict[str, Any]) -> DocDetails:

bibtex_type = BIBTEX_MAPPING.get(message.get("type") or "other", "misc")

return DocDetails( # type: ignore[call-arg]
return DocDetails(
key=None,
bibtex_type=bibtex_type,
bibtex=None,
Expand Down
2 changes: 1 addition & 1 deletion paperqa/clients/retractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async def _process(self, query: DOIQuery, doc_details: DocDetails) -> DocDetails
if not self.doi_set:
await self.load_data()

return doc_details + DocDetails(is_retracted=query.doi in self.doi_set) # type: ignore[call-arg]
return doc_details + DocDetails(is_retracted=query.doi in self.doi_set)

def query_creator(self, doc_details: DocDetails, **kwargs) -> DOIQuery | None:
try:
Expand Down
2 changes: 1 addition & 1 deletion paperqa/clients/semantic_scholar.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def parse_s2_to_doc_details(

journal_data = paper_data.get("journal") or {}

doc_details = DocDetails( # type: ignore[call-arg]
doc_details = DocDetails(
key=None if not bibtex else bibtex.split("{")[1].split(",")[0],
bibtex_type="article", # s2 should be basically all articles
bibtex=bibtex,
Expand Down
2 changes: 1 addition & 1 deletion paperqa/clients/unpaywall.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _create_doc_details(self, data: UnpaywallResponse) -> DocDetails:
if data.best_oa_location:
pdf_url = data.best_oa_location.url_for_pdf
license = data.best_oa_location.license # noqa: A001
return DocDetails( # type: ignore[call-arg]
return DocDetails(
authors=[
f"{author.given} {author.family}" for author in (data.z_authors or [])
],
Expand Down
2 changes: 1 addition & 1 deletion paperqa/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def map_fxn_summary(
text=Text(
text=text.text,
name=text.name,
doc=text.doc.__class__(**text.doc.model_dump(exclude={"embedding"})),
doc=text.doc.model_dump(exclude={"embedding"}),
),
score=score, # pylint: disable=possibly-used-before-assignment
**extras,
Expand Down
15 changes: 9 additions & 6 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,10 +470,12 @@ async def aadd_texts(
# 3. Update self
# NOTE: we defer adding texts to the texts index to retrieval time
# (e.g. `self.texts_index.add_texts_and_embeddings(texts)`)
self.docs[doc.dockey] = doc
self.texts += texts
self.docnames.add(doc.docname)
return True
if doc.docname and doc.dockey:
self.docs[doc.dockey] = doc
self.texts += texts
self.docnames.add(doc.docname)
return True
return False

def delete(
self,
Expand All @@ -489,8 +491,9 @@ def delete(
doc = next((doc for doc in self.docs.values() if doc.docname == name), None)
if doc is None:
return
self.docnames.remove(doc.docname)
dockey = doc.dockey
if doc.docname and doc.dockey:
self.docnames.remove(doc.docname)
dockey = doc.dockey
del self.docs[dockey]
self.deleted_dockeys.add(dockey)
self.texts = list(filter(lambda x: x.doc.dockey != dockey, self.texts))
Expand Down
8 changes: 7 additions & 1 deletion paperqa/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
from html2text import __version__ as html2text_version
from html2text import html2text

from paperqa.types import ChunkMetadata, Doc, ParsedMetadata, ParsedText, Text
from paperqa.types import (
ChunkMetadata,
Doc,
ParsedMetadata,
ParsedText,
Text,
)
from paperqa.utils import ImpossibleParsingError
from paperqa.version import __version__ as pqa_version

Expand Down
32 changes: 24 additions & 8 deletions paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import os
import re
import warnings
from collections.abc import Collection
from collections.abc import Collection, Mapping
from copy import deepcopy
from datetime import datetime
from typing import Any, ClassVar, cast
from uuid import UUID, uuid4
Expand Down Expand Up @@ -40,16 +41,23 @@


class Doc(Embeddable):
model_config = ConfigDict(extra="forbid")

docname: str
citation: str
dockey: DocKey
citation: str
overwrite_fields_from_metadata: bool = Field(
default=True,
description=(
"flag to overwrite fields from metadata when upgrading to a DocDetails"
),
)

@model_validator(mode="before")
@classmethod
def remove_computed_fields(cls, data: Mapping[str, Any]) -> dict[str, Any]:
    """Strip the computed 'formatted_citation' entry before field validation.

    With extra="forbid" on Doc, a round-tripped serialization that includes the
    computed field would otherwise fail validation, so it is dropped here.
    """
    cleaned = dict(data)
    cleaned.pop("formatted_citation", None)
    return cleaned

def __hash__(self) -> int:
    """Hash on the identifying pair of document name and key."""
    identity = (self.docname, self.dockey)
    return hash(identity)

Expand Down Expand Up @@ -80,7 +88,7 @@ def matches_filter_criteria(self, filter_criteria: dict) -> bool:
class Text(Embeddable):
text: str
name: str
doc: Doc
doc: Doc | DocDetails = Field(union_mode="left_to_right")

def __hash__(self) -> int:
return hash(self.text)
Expand Down Expand Up @@ -215,7 +223,7 @@ def filter_content_for_user(self) -> None:
text=Text(
text="",
**c.text.model_dump(exclude={"text", "embedding", "doc"}),
doc=Doc(**c.text.doc.model_dump(exclude={"embedding"})),
doc=c.text.doc.model_dump(exclude={"embedding"}),
),
)
for c in self.contexts
Expand Down Expand Up @@ -304,12 +312,18 @@ def reduce_content(self) -> str:


class DocDetails(Doc):
model_config = ConfigDict(validate_assignment=True)
model_config = ConfigDict(validate_assignment=True, extra="ignore")

# Sentinel to auto-populate a field within model_validator
AUTOPOPULATE_VALUE: ClassVar[str] = ""

citation: str = ""
docname: str = AUTOPOPULATE_VALUE
dockey: DocKey = AUTOPOPULATE_VALUE
citation: str = AUTOPOPULATE_VALUE
key: str | None = None
bibtex: str | None = Field(
default=None, description="Autogenerated from other represented fields."
default=AUTOPOPULATE_VALUE,
description="Autogenerated from other represented fields.",
)
authors: list[str] | None = None
publication_date: datetime | None = None
Expand Down Expand Up @@ -593,7 +607,9 @@ def populate_bibtex_key_citation( # noqa: PLR0912

@model_validator(mode="before")
@classmethod
def validate_all_fields(cls, data: dict[str, Any]) -> dict[str, Any]:
def validate_all_fields(cls, data: Mapping[str, Any]) -> dict[str, Any]:
data = deepcopy(data) # Avoid mutating input
data = dict(data)
data = cls.lowercase_doi_and_populate_doc_id(data)
data = cls.remove_invalid_authors(data)
data = cls.misc_string_cleaning(data)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ def test_cli_ask(agent_index_dir: Path, stub_data_dir: Path) -> None:
assert response.session.formatted_answer

search_result = search_query(
" ".join(response.session.formatted_answer.split()[:5]),
" ".join(response.session.formatted_answer.split()),
"answers",
settings,
)
found_answer = search_result[0][0]
assert isinstance(found_answer, AnswerResponse)
assert found_answer.model_dump_json() == response.model_dump_json()
assert found_answer.model_dump() == response.model_dump()


def test_cli_can_build_and_search_index(
Expand All @@ -80,5 +80,5 @@ def test_cli_can_build_and_search_index(
result = search_query("XAI", index_name, settings)
assert len(result) == 1
assert isinstance(result[0][0], Docs)
assert result[0][0].docnames == {"Wellawatte"}
assert all(d.startswith("Wellawatte") for d in result[0][0].docnames)
assert result[0][1] == "paper.pdf"
51 changes: 48 additions & 3 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,10 @@ def test_sparse_embedding(stub_data_dir: Path, vector_store: type[VectorStore])
citation="WikiMedia Foundation, 2023, Accessed now",
embedding_model=SparseEmbeddingModel(),
)
assert any(cast(list[float], docs.texts[0].embedding))
assert isinstance(
docs.texts[0].embedding, list
), "We require embeddings to be a list"
assert any(docs.texts[0].embedding), "We require embeddings to be populated"
assert all(
len(np.array(x.embedding).shape) == 1 for x in docs.texts
), "Embeddings should be 1D"
Expand All @@ -731,7 +734,10 @@ def test_hybrid_embedding(stub_data_dir: Path, vector_store: type[VectorStore])
citation="WikiMedia Foundation, 2023, Accessed now",
embedding_model=emb_model,
)
assert any(cast(list[float], docs.texts[0].embedding))
assert isinstance(
docs.texts[0].embedding, list
), "We require embeddings to be a list"
assert any(docs.texts[0].embedding), "We require embeddings to be populated"

# check the embeddings are the same size
assert docs.texts[0].embedding is not None
Expand Down Expand Up @@ -1237,7 +1243,7 @@ def test_answer_rename(recwarn) -> None:
],
)
def test_dois_resolve_to_correct_journals(doi_journals):
details = DocDetails(doi=doi_journals["doi"]) # type: ignore[call-arg]
details = DocDetails(doi=doi_journals["doi"])
assert details.journal == doi_journals["journal"]


Expand Down Expand Up @@ -1309,6 +1315,45 @@ def test_docdetails_merge_with_list_fields() -> None:
assert isinstance(merged_doc, DocDetails), "Merged doc should also be DocDetails"


def test_docdetails_deserialization() -> None:
    """Deserialize the same payload as Doc and DocDetails, without input mutation.

    Confirms: (1) Doc(**payload) stays a plain Doc, (2) DocDetails auto-populates
    docname/citation/key/bibtex fields, and (3) neither constructor mutates the
    caller's input dict.
    """
    payload = {
        "citation": "stub",
        "dockey": "stub",
        "docname": "Stub",
        "embedding": None,
        "formatted_citation": "stub",
        "overwrite_fields_from_metadata": True,
    }
    payload_snapshot = deepcopy(payload)

    doc = Doc(**payload)
    assert not isinstance(doc, DocDetails), "Should just be Doc, not DocDetails"
    assert payload == payload_snapshot, "Deserialization should not mutate input"

    # Expected auto-populated values for an all-unknown document.
    expected_fields = {
        "docname": "unknownauthorsUnknownyearunknowntitle",
        "citation": "Unknown authors. Unknown title. Unknown journal, Unknown year.",
        "overwrite_fields_from_metadata": True,
        "key": "unknownauthorsUnknownyearunknowntitle",
        "bibtex": (
            '@article{unknownauthorsUnknownyearunknowntitle,\n author = "authors,'
            ' Unknown",\n title = "Unknown title",\n year = "Unknown year",\n '
            ' journal = "Unknown journal"\n}\n'
        ),
        "other": {},
        "formatted_citation": (
            "Unknown authors. Unknown title. Unknown journal, Unknown year."
        ),
    }
    serialized = DocDetails(**payload).model_dump(exclude_none=True)
    for field_name, expected in expected_fields.items():
        assert serialized[field_name] == expected
    assert payload == payload_snapshot, "Deserialization should not mutate input"


@pytest.mark.vcr
@pytest.mark.parametrize("use_partition", [True, False])
@pytest.mark.asyncio
Expand Down
Loading