Skip to content

Commit

Permalink
core: Add ruff rules PYI
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jan 21, 2025
1 parent 536b44a commit 9d30c94
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 26 deletions.
14 changes: 5 additions & 9 deletions libs/core/langchain_core/prompts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
SkipValidation,
model_validator,
)
from typing_extensions import Self

from langchain_core._api import deprecated
from langchain_core.load import Serializable
Expand Down Expand Up @@ -455,11 +456,6 @@ async def aformat(self, **kwargs: Any) -> BaseMessage:
)


_StringImageMessagePromptTemplateT = TypeVar(
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
)


class _TextTemplateParam(TypedDict, total=False):
text: Union[str, dict]

Expand Down Expand Up @@ -487,13 +483,13 @@ def get_lc_namespace(cls) -> list[str]:

@classmethod
def from_template(
cls: type[_StringImageMessagePromptTemplateT],
cls: type[Self],
template: Union[str, list[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: PromptTemplateFormat = "f-string",
*,
partial_variables: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
) -> Self:
"""Create a class from a string template.
Args:
Expand Down Expand Up @@ -581,11 +577,11 @@ def from_template(

@classmethod
def from_template_file(
cls: type[_StringImageMessagePromptTemplateT],
cls: type[Self],
template_file: Union[str, Path],
input_variables: list[str],
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
) -> Self:
"""Create a class from a template file.
Args:
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4143,7 +4143,7 @@ def get_output_schema(
module_name=module,
)

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, RunnableGenerator):
if hasattr(self, "_transform") and hasattr(other, "_transform"):
return self._transform == other._transform
Expand Down Expand Up @@ -4516,7 +4516,7 @@ def get_graph(self, config: RunnableConfig | None = None) -> Graph:

return graph

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, RunnableLambda):
if hasattr(self, "func") and hasattr(other, "func"):
return self.func == other.func
Expand Down
14 changes: 12 additions & 2 deletions libs/core/langchain_core/utils/aiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,12 @@ class NoLock:
async def __aenter__(self) -> None:
pass

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
return False


Expand Down Expand Up @@ -223,7 +228,12 @@ def __iter__(self) -> Iterator[AsyncIterator[T]]:
async def __aenter__(self) -> "Tee[T]":
return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
await self.aclose()
return False

Expand Down
15 changes: 13 additions & 2 deletions libs/core/langchain_core/utils/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Generator, Iterable, Iterator
from contextlib import AbstractContextManager
from itertools import islice
from types import TracebackType
from typing import (
Any,
Generic,
Expand All @@ -22,7 +23,12 @@ class NoLock:
def __enter__(self) -> None:
pass

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Literal[False]:
return False


Expand Down Expand Up @@ -167,7 +173,12 @@ def __iter__(self) -> Iterator[Iterator[T]]:
def __enter__(self) -> "Tee[T]":
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Literal[False]:
self.close()
return False

Expand Down
7 changes: 1 addition & 6 deletions libs/core/langchain_core/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,12 +374,7 @@ def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...

def get_fields(
model: Union[
BaseModelV2,
BaseModelV1,
type[BaseModelV2],
type[BaseModelV1],
],
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
"""Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"):
Expand Down
2 changes: 1 addition & 1 deletion libs/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ python = ">=3.12.4"
[tool.poetry.extras]

[tool.ruff.lint]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "UP", "W", "YTT",]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "PYI", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "UP", "W", "YTT",]
ignore = [ "COM812", "UP007", "S110", "S112",]

[tool.coverage.run]
Expand Down
2 changes: 1 addition & 1 deletion libs/core/tests/unit_tests/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class AnyStr(str):
__slots__ = ()

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, str)


Expand Down
2 changes: 1 addition & 1 deletion libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2313,7 +2313,7 @@ class Bar(ToolOutputMixin):
def __init__(self, x: int) -> None:
self.x = x

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.x == other.x

@tool
Expand Down
4 changes: 2 additions & 2 deletions libs/core/tests/unit_tests/utils/test_function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,12 +969,12 @@ class Tool(typed_dict):
)
def test_convert_union_type_py_39() -> None:
@tool
def magic_function(input: int | float) -> str:
def magic_function(input: int | str) -> str:
"""Compute a magic function."""

result = convert_to_openai_function(magic_function)
assert result["parameters"]["properties"]["input"] == {
"anyOf": [{"type": "integer"}, {"type": "number"}]
"anyOf": [{"type": "integer"}, {"type": "string"}]
}


Expand Down

0 comments on commit 9d30c94

Please sign in to comment.