Skip to content

Commit

Permalink
core: Add ruff rule FBT003 (boolean-trap)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jan 25, 2025
1 parent dbb6b7b commit 1e1ba42
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 36 deletions.
7 changes: 5 additions & 2 deletions libs/core/langchain_core/_api/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning):


def _validate_deprecation_params(
pending: bool,
removal: str,
alternative: str,
alternative_import: str,
*,
pending: bool,
) -> None:
"""Validate the deprecation parameters."""
if pending and removal:
Expand Down Expand Up @@ -130,7 +131,9 @@ def deprecated(
def the_function_to_deprecate():
pass
"""
_validate_deprecation_params(pending, removal, alternative, alternative_import)
_validate_deprecation_params(
removal, alternative, alternative_import, pending=pending
)

def deprecate(
obj: T,
Expand Down
25 changes: 13 additions & 12 deletions libs/core/langchain_core/callbacks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def get_child(self, tag: Optional[str] = None) -> CallbackManager:
manager.add_tags(self.inheritable_tags)
manager.add_metadata(self.inheritable_metadata)
if tag is not None:
manager.add_tags([tag], False)
manager.add_tags([tag], inherit=False)
return manager


Expand Down Expand Up @@ -641,7 +641,7 @@ def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
manager.add_tags(self.inheritable_tags)
manager.add_metadata(self.inheritable_metadata)
if tag is not None:
manager.add_tags([tag], False)
manager.add_tags([tag], inherit=False)
return manager


Expand Down Expand Up @@ -1563,11 +1563,11 @@ def configure(
cls,
inheritable_callbacks,
local_callbacks,
verbose,
inheritable_tags,
local_tags,
inheritable_metadata,
local_metadata,
verbose=verbose,
)


Expand Down Expand Up @@ -2087,11 +2087,11 @@ def configure(
cls,
inheritable_callbacks,
local_callbacks,
verbose,
inheritable_tags,
local_tags,
inheritable_metadata,
local_metadata,
verbose=verbose,
)


Expand Down Expand Up @@ -2236,11 +2236,12 @@ def _configure(
callback_manager_cls: type[T],
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
inheritable_tags: Optional[list[str]] = None,
local_tags: Optional[list[str]] = None,
inheritable_metadata: Optional[dict[str, Any]] = None,
local_metadata: Optional[dict[str, Any]] = None,
*,
verbose: bool = False,
) -> T:
"""Configure the callback manager.
Expand Down Expand Up @@ -2314,13 +2315,13 @@ def _configure(
else (local_callbacks.handlers if local_callbacks else [])
)
for handler in local_handlers_:
callback_manager.add_handler(handler, False)
callback_manager.add_handler(handler, inherit=False)
if inheritable_tags or local_tags:
callback_manager.add_tags(inheritable_tags or [])
callback_manager.add_tags(local_tags or [], False)
callback_manager.add_tags(local_tags or [], inherit=False)
if inheritable_metadata or local_metadata:
callback_manager.add_metadata(inheritable_metadata or {})
callback_manager.add_metadata(local_metadata or {}, False)
callback_manager.add_metadata(local_metadata or {}, inherit=False)
if tracing_metadata:
callback_manager.add_metadata(tracing_metadata.copy())
if tracing_tags:
Expand Down Expand Up @@ -2355,18 +2356,18 @@ def _configure(
if debug:
pass
else:
callback_manager.add_handler(StdOutCallbackHandler(), False)
callback_manager.add_handler(StdOutCallbackHandler(), inherit=False)
if debug and not any(
isinstance(handler, ConsoleCallbackHandler)
for handler in callback_manager.handlers
):
callback_manager.add_handler(ConsoleCallbackHandler(), True)
callback_manager.add_handler(ConsoleCallbackHandler())
if tracing_v2_enabled_ and not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
if tracer_v2:
callback_manager.add_handler(tracer_v2, True)
callback_manager.add_handler(tracer_v2)
else:
try:
handler = LangChainTracer(
Expand All @@ -2378,7 +2379,7 @@ def _configure(
),
tags=tracing_tags,
)
callback_manager.add_handler(handler, True)
callback_manager.add_handler(handler)
except Exception as e:
logger.warning(
"Unable to load requested LangChainTracer."
Expand Down
5 changes: 3 additions & 2 deletions libs/core/langchain_core/language_models/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,7 @@ async def _agenerate_helper(
prompts: list[str],
stop: Optional[list[str]],
run_managers: list[AsyncCallbackManagerForLLMRun],
*,
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
Expand Down Expand Up @@ -1212,7 +1213,7 @@ async def agenerate(
prompts,
stop,
run_managers, # type: ignore[arg-type]
bool(new_arg_supported),
new_arg_supported=bool(new_arg_supported),
**kwargs, # type: ignore[arg-type]
)
return output
Expand All @@ -1235,7 +1236,7 @@ async def agenerate(
missing_prompts,
stop,
run_managers, # type: ignore[arg-type]
bool(new_arg_supported),
new_arg_supported=bool(new_arg_supported),
**kwargs, # type: ignore[arg-type]
)
llm_output = await aupdate_cache(
Expand Down
8 changes: 5 additions & 3 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _get_filtered_args(


def _parse_python_function_docstring(
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
function: Callable, annotations: dict, *, error_on_invalid_docstring: bool = False
) -> tuple[str, dict]:
"""Parse the function and argument descriptions from the docstring of a function.
Expand Down Expand Up @@ -1073,7 +1073,7 @@ def get_all_basemodel_annotations(
generic_map = dict(zip(generic_type_vars, get_args(parent)))
for field in getattr(parent_origin, "__annotations__", {}):
annotations[field] = _replace_type_vars(
annotations[field], generic_map, default_to_bound
annotations[field], generic_map, default_to_bound=default_to_bound
)

return {
Expand All @@ -1085,6 +1085,7 @@ def get_all_basemodel_annotations(
def _replace_type_vars(
type_: type,
generic_map: Optional[dict[TypeVar, type]] = None,
*,
default_to_bound: bool = True,
) -> type:
generic_map = generic_map or {}
Expand All @@ -1097,7 +1098,8 @@ def _replace_type_vars(
return type_
elif (origin := get_origin(type_)) and (args := get_args(type_)):
new_args = tuple(
_replace_type_vars(arg, generic_map, default_to_bound) for arg in args
_replace_type_vars(arg, generic_map, default_to_bound=default_to_bound)
for arg in args
)
return _py_38_safe_origin(origin)[new_args] # type: ignore[index]
else:
Expand Down
4 changes: 2 additions & 2 deletions libs/core/langchain_core/tracers/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _get_trace_callbacks(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
callback_manager.add_handler(tracer, True)
callback_manager.add_handler(tracer)
# If it already has a LangChainTracer, we don't need to add another one.
# this would likely mess up the trace hierarchy.
cb = callback_manager
Expand Down Expand Up @@ -217,4 +217,4 @@ def register_configure_hook(
)


register_configure_hook(run_collector_var, False)
register_configure_hook(run_collector_var, inheritable=False)
1 change: 1 addition & 0 deletions libs/core/langchain_core/utils/mustache.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def _html_escape(string: str) -> str:
def _get_key(
key: str,
scopes: Scopes,
*,
warn: bool,
keep: bool,
def_ldel: str,
Expand Down
2 changes: 1 addition & 1 deletion libs/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ python = ">=3.12.4"
[tool.poetry.extras]

[tool.ruff.lint]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",]
select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FBT003", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",]
ignore = [ "COM812", "UP007", "S110", "S112",]

[tool.coverage.run]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class CustomHandler(AsyncCallbackHandler):
called.
"""

def __init__(self, run_inline: bool) -> None:
def __init__(self, *, run_inline: bool) -> None:
"""Initialize the handler."""
self.run_inline = run_inline

Expand Down Expand Up @@ -91,7 +91,7 @@ async def set_counter_var() -> Any:
counter_var.reset(token)

class StatefulAsyncCallbackHandler(AsyncCallbackHandler):
def __init__(self, name: str, run_inline: bool = True):
def __init__(self, name: str, *, run_inline: bool = True):
self.name = name
self.run_inline = run_inline

Expand Down
22 changes: 11 additions & 11 deletions libs/core/tests/unit_tests/output_parsers/test_openai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@
]


def _get_iter(use_tool_calls: bool = False) -> Any:
def _get_iter(*, use_tool_calls: bool = False) -> Any:
if use_tool_calls:
list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS
else:
Expand All @@ -374,7 +374,7 @@ def input_iter(_: Any) -> Iterator[BaseMessage]:
return input_iter


def _get_aiter(use_tool_calls: bool = False) -> Any:
def _get_aiter(*, use_tool_calls: bool = False) -> Any:
if use_tool_calls:
list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS
else:
Expand All @@ -389,7 +389,7 @@ async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:

def test_partial_json_output_parser() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputToolsParser()

actual = list(chain.stream(None))
Expand All @@ -402,7 +402,7 @@ def test_partial_json_output_parser() -> None:

async def test_partial_json_output_parser_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputToolsParser()

actual = [p async for p in chain.astream(None)]
Expand All @@ -415,7 +415,7 @@ async def test_partial_json_output_parser_async() -> None:

def test_partial_json_output_parser_return_id() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputToolsParser(return_id=True)

actual = list(chain.stream(None))
Expand All @@ -434,7 +434,7 @@ def test_partial_json_output_parser_return_id() -> None:

def test_partial_json_output_key_parser() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")

actual = list(chain.stream(None))
Expand All @@ -444,7 +444,7 @@ def test_partial_json_output_key_parser() -> None:

async def test_partial_json_output_parser_key_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
input_iter = _get_aiter(use_tool_calls=use_tool_calls)

chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")

Expand All @@ -455,7 +455,7 @@ async def test_partial_json_output_parser_key_async() -> None:

def test_partial_json_output_key_parser_first_only() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)

chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
Expand All @@ -466,7 +466,7 @@ def test_partial_json_output_key_parser_first_only() -> None:

async def test_partial_json_output_parser_key_async_first_only() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
input_iter = _get_aiter(use_tool_calls=use_tool_calls)

chain = input_iter | JsonOutputKeyToolsParser(
key_name="NameCollector", first_tool_only=True
Expand Down Expand Up @@ -507,7 +507,7 @@ class NameCollector(BaseModel):

def test_partial_pydantic_output_parser() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_iter(use_tool_calls)
input_iter = _get_iter(use_tool_calls=use_tool_calls)

chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
Expand All @@ -519,7 +519,7 @@ def test_partial_pydantic_output_parser() -> None:

async def test_partial_pydantic_output_parser_async() -> None:
for use_tool_calls in [False, True]:
input_iter = _get_aiter(use_tool_calls)
input_iter = _get_aiter(use_tool_calls=use_tool_calls)

chain = input_iter | PydanticToolsParser(
tools=[NameCollector], first_tool_only=True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def _as_async_iterator(iterable: list) -> AsyncIterator:


async def _collect_events(
events: AsyncIterator[StreamEvent], with_nulled_ids: bool = True
events: AsyncIterator[StreamEvent], *, with_nulled_ids: bool = True
) -> list[StreamEvent]:
"""Collect the events and remove the run ids."""
materialized_events = [event async for event in events]
Expand Down

0 comments on commit 1e1ba42

Please sign in to comment.