From 1e1ba4240a221d5dc87540b990274762045761e9 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Sat, 25 Jan 2025 18:00:57 +0100 Subject: [PATCH] core: Add ruff rule FBT003 (boolean-trap) --- libs/core/langchain_core/_api/deprecation.py | 7 ++++-- libs/core/langchain_core/callbacks/manager.py | 25 ++++++++++--------- .../langchain_core/language_models/llms.py | 5 ++-- libs/core/langchain_core/tools/base.py | 8 +++--- libs/core/langchain_core/tracers/context.py | 4 +-- libs/core/langchain_core/utils/mustache.py | 1 + libs/core/pyproject.toml | 2 +- .../callbacks/test_async_callback_manager.py | 4 +-- .../output_parsers/test_openai_tools.py | 22 ++++++++-------- .../runnables/test_runnable_events_v2.py | 2 +- 10 files changed, 44 insertions(+), 36 deletions(-) diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index a2cccaa2be3a3..d243120d77fdc 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -44,10 +44,11 @@ class LangChainPendingDeprecationWarning(PendingDeprecationWarning): def _validate_deprecation_params( - pending: bool, removal: str, alternative: str, alternative_import: str, + *, + pending: bool, ) -> None: """Validate the deprecation parameters.""" if pending and removal: @@ -130,7 +131,9 @@ def deprecated( def the_function_to_deprecate(): pass """ - _validate_deprecation_params(pending, removal, alternative, alternative_import) + _validate_deprecation_params( + removal, alternative, alternative_import, pending=pending + ) def deprecate( obj: T, diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index f675b230c879b..95210e22e66c4 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -560,7 +560,7 @@ def get_child(self, tag: Optional[str] = None) -> CallbackManager: manager.add_tags(self.inheritable_tags) manager.add_metadata(self.inheritable_metadata) if tag is not None: - manager.add_tags([tag], False) + manager.add_tags([tag], inherit=False) return manager @@ -641,7 +641,7 @@ def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager: manager.add_tags(self.inheritable_tags) manager.add_metadata(self.inheritable_metadata) if tag is not None: - manager.add_tags([tag], False) + manager.add_tags([tag], inherit=False) return manager @@ -1563,11 +1563,11 @@ def configure( cls, inheritable_callbacks, local_callbacks, - verbose, inheritable_tags, local_tags, inheritable_metadata, local_metadata, + verbose=verbose, ) @@ -2087,11 +2087,11 @@ def configure( cls, inheritable_callbacks, local_callbacks, - verbose, inheritable_tags, local_tags, inheritable_metadata, local_metadata, + verbose=verbose, ) @@ -2236,11 +2236,12 @@ def _configure( callback_manager_cls: type[T], inheritable_callbacks: Callbacks = None, local_callbacks: Callbacks = None, - verbose: bool = False, inheritable_tags: Optional[list[str]] = None, local_tags: Optional[list[str]] = None, inheritable_metadata: Optional[dict[str, Any]] = None, local_metadata: Optional[dict[str, Any]] = None, + *, + verbose: bool = False, ) -> T: """Configure the callback manager. @@ -2314,13 +2315,13 @@ def _configure( else (local_callbacks.handlers if local_callbacks else []) ) for handler in local_handlers_: - callback_manager.add_handler(handler, False) + callback_manager.add_handler(handler, inherit=False) if inheritable_tags or local_tags: callback_manager.add_tags(inheritable_tags or []) - callback_manager.add_tags(local_tags or [], False) + callback_manager.add_tags(local_tags or [], inherit=False) if inheritable_metadata or local_metadata: callback_manager.add_metadata(inheritable_metadata or {}) - callback_manager.add_metadata(local_metadata or {}, False) + callback_manager.add_metadata(local_metadata or {}, inherit=False) if tracing_metadata: callback_manager.add_metadata(tracing_metadata.copy()) if tracing_tags: @@ -2355,18 +2356,18 @@ def _configure( if debug: pass else: - callback_manager.add_handler(StdOutCallbackHandler(), False) + callback_manager.add_handler(StdOutCallbackHandler(), inherit=False) if debug and not any( isinstance(handler, ConsoleCallbackHandler) for handler in callback_manager.handlers ): - callback_manager.add_handler(ConsoleCallbackHandler(), True) + callback_manager.add_handler(ConsoleCallbackHandler()) if tracing_v2_enabled_ and not any( isinstance(handler, LangChainTracer) for handler in callback_manager.handlers ): if tracer_v2: - callback_manager.add_handler(tracer_v2, True) + callback_manager.add_handler(tracer_v2) else: try: handler = LangChainTracer( @@ -2378,7 +2379,7 @@ def _configure( ), tags=tracing_tags, ) - callback_manager.add_handler(handler, True) + callback_manager.add_handler(handler) except Exception as e: logger.warning( "Unable to load requested LangChainTracer." diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 4ba16f516965d..b18e9c2afe1b7 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -1019,6 +1019,7 @@ async def _agenerate_helper( prompts: list[str], stop: Optional[list[str]], run_managers: list[AsyncCallbackManagerForLLMRun], + *, new_arg_supported: bool, **kwargs: Any, ) -> LLMResult: @@ -1212,7 +1213,7 @@ async def agenerate( prompts, stop, run_managers, # type: ignore[arg-type] - bool(new_arg_supported), + new_arg_supported=bool(new_arg_supported), **kwargs, # type: ignore[arg-type] ) return output @@ -1235,7 +1236,7 @@ async def agenerate( missing_prompts, stop, run_managers, # type: ignore[arg-type] - bool(new_arg_supported), + new_arg_supported=bool(new_arg_supported), **kwargs, # type: ignore[arg-type] ) llm_output = await aupdate_cache( diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index f0833cdc24a14..391f906908b47 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -108,7 +108,7 @@ def _get_filtered_args( def _parse_python_function_docstring( - function: Callable, annotations: dict, error_on_invalid_docstring: bool = False + function: Callable, annotations: dict, *, error_on_invalid_docstring: bool = False ) -> tuple[str, dict]: """Parse the function and argument descriptions from the docstring of a function. @@ -1073,7 +1073,7 @@ def get_all_basemodel_annotations( generic_map = dict(zip(generic_type_vars, get_args(parent))) for field in getattr(parent_origin, "__annotations__", {}): annotations[field] = _replace_type_vars( - annotations[field], generic_map, default_to_bound + annotations[field], generic_map, default_to_bound=default_to_bound ) return { @@ -1085,6 +1085,7 @@ def get_all_basemodel_annotations( def _replace_type_vars( type_: type, generic_map: Optional[dict[TypeVar, type]] = None, + *, default_to_bound: bool = True, ) -> type: generic_map = generic_map or {} @@ -1097,7 +1098,8 @@ def _replace_type_vars( return type_ elif (origin := get_origin(type_)) and (args := get_args(type_)): new_args = tuple( - _replace_type_vars(arg, generic_map, default_to_bound) for arg in args + _replace_type_vars(arg, generic_map, default_to_bound=default_to_bound) + for arg in args ) return _py_38_safe_origin(origin)[new_args] # type: ignore[index] else: diff --git a/libs/core/langchain_core/tracers/context.py b/libs/core/langchain_core/tracers/context.py index d36adc9bc8eaf..7d6568713077e 100644 --- a/libs/core/langchain_core/tracers/context.py +++ b/libs/core/langchain_core/tracers/context.py @@ -133,7 +133,7 @@ def _get_trace_callbacks( isinstance(handler, LangChainTracer) for handler in callback_manager.handlers ): - callback_manager.add_handler(tracer, True) + callback_manager.add_handler(tracer) # If it already has a LangChainTracer, we don't need to add another one. # this would likely mess up the trace hierarchy. cb = callback_manager @@ -217,4 +217,4 @@ def register_configure_hook( ) -register_configure_hook(run_collector_var, False) +register_configure_hook(run_collector_var, inheritable=False) diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index ee2ed8f2528f8..2ab78ac1fd587 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -331,6 +331,7 @@ def _html_escape(string: str) -> str: def _get_key( key: str, scopes: Scopes, + *, warn: bool, keep: bool, def_ldel: str, diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index f9b731c730163..2f47c56c2deab 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -44,7 +44,7 @@ python = ">=3.12.4" [tool.poetry.extras] [tool.ruff.lint] -select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",] +select = [ "ASYNC", "B", "C4", "COM", "DJ", "E", "EM", "EXE", "F", "FBT003", "FLY", "FURB", "I", "ICN", "INT", "LOG", "N", "NPY", "PD", "PIE", "Q", "RSE", "S", "SIM", "SLOT", "T10", "T201", "TID", "TRY", "UP", "W", "YTT",] ignore = [ "COM812", "UP007", "S110", "S112",] [tool.coverage.run] diff --git a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py index 38350f9d82f7b..ac565e43e39cd 100644 --- a/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py +++ b/libs/core/tests/unit_tests/callbacks/test_async_callback_manager.py @@ -41,7 +41,7 @@ class CustomHandler(AsyncCallbackHandler): called. """ - def __init__(self, run_inline: bool) -> None: + def __init__(self, *, run_inline: bool) -> None: """Initialize the handler.""" self.run_inline = run_inline @@ -91,7 +91,7 @@ async def set_counter_var() -> Any: counter_var.reset(token) class StatefulAsyncCallbackHandler(AsyncCallbackHandler): - def __init__(self, name: str, run_inline: bool = True): + def __init__(self, name: str, *, run_inline: bool = True): self.name = name self.run_inline = run_inline diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py index d4940bab2bbea..e5fe0f3076cd8 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -362,7 +362,7 @@ ] -def _get_iter(use_tool_calls: bool = False) -> Any: +def _get_iter(*, use_tool_calls: bool = False) -> Any: if use_tool_calls: list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS else: @@ -374,7 +374,7 @@ def input_iter(_: Any) -> Iterator[BaseMessage]: return input_iter -def _get_aiter(use_tool_calls: bool = False) -> Any: +def _get_aiter(*, use_tool_calls: bool = False) -> Any: if use_tool_calls: list_to_iter = STREAMED_MESSAGES_WITH_TOOL_CALLS else: @@ -389,7 +389,7 @@ async def input_iter(_: Any) -> AsyncIterator[BaseMessage]: def test_partial_json_output_parser() -> None: for use_tool_calls in [False, True]: - input_iter = _get_iter(use_tool_calls) + input_iter = _get_iter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputToolsParser() actual = list(chain.stream(None)) @@ -402,7 +402,7 @@ def test_partial_json_output_parser() -> None: async def test_partial_json_output_parser_async() -> None: for use_tool_calls in [False, True]: - input_iter = _get_aiter(use_tool_calls) + input_iter = _get_aiter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputToolsParser() actual = [p async for p in chain.astream(None)] @@ -415,7 +415,7 @@ async def test_partial_json_output_parser_async() -> None: def test_partial_json_output_parser_return_id() -> None: for use_tool_calls in [False, True]: - input_iter = _get_iter(use_tool_calls) + input_iter = _get_iter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputToolsParser(return_id=True) actual = list(chain.stream(None)) @@ -434,7 +434,7 @@ def test_partial_json_output_parser_return_id() -> None: def test_partial_json_output_key_parser() -> None: for use_tool_calls in [False, True]: - input_iter = _get_iter(use_tool_calls) + input_iter = _get_iter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector") actual = list(chain.stream(None)) @@ -444,7 +444,7 @@ def test_partial_json_output_key_parser() -> None: async def test_partial_json_output_parser_key_async() -> None: for use_tool_calls in [False, True]: - input_iter = _get_aiter(use_tool_calls) + input_iter = _get_aiter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector") @@ -455,7 +455,7 @@ async def test_partial_json_output_parser_key_async() -> None: def test_partial_json_output_key_parser_first_only() -> None: for use_tool_calls in [False, True]: - input_iter = _get_iter(use_tool_calls) + input_iter = _get_iter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputKeyToolsParser( key_name="NameCollector", first_tool_only=True @@ -466,7 +466,7 @@ def test_partial_json_output_key_parser_first_only() -> None: async def test_partial_json_output_parser_key_async_first_only() -> None: for use_tool_calls in [False, True]: - input_iter = _get_aiter(use_tool_calls) + input_iter = _get_aiter(use_tool_calls=use_tool_calls) chain = input_iter | JsonOutputKeyToolsParser( key_name="NameCollector", first_tool_only=True @@ -507,7 +507,7 @@ class NameCollector(BaseModel): def test_partial_pydantic_output_parser() -> None: for use_tool_calls in [False, True]: - input_iter = _get_iter(use_tool_calls) + input_iter = _get_iter(use_tool_calls=use_tool_calls) chain = input_iter | PydanticToolsParser( tools=[NameCollector], first_tool_only=True @@ -519,7 +519,7 @@ def test_partial_pydantic_output_parser() -> None: async def test_partial_pydantic_output_parser_async() -> None: for use_tool_calls in [False, True]: - input_iter = _get_aiter(use_tool_calls) + input_iter = _get_aiter(use_tool_calls=use_tool_calls) chain = input_iter | PydanticToolsParser( tools=[NameCollector], first_tool_only=True diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 698a4c4ddab85..cc9a13dc10e92 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -75,7 +75,7 @@ async def _as_async_iterator(iterable: list) -> AsyncIterator: async def _collect_events( - events: AsyncIterator[StreamEvent], with_nulled_ids: bool = True + events: AsyncIterator[StreamEvent], *, with_nulled_ids: bool = True ) -> list[StreamEvent]: """Collect the events and remove the run ids.""" materialized_events = [event async for event in events]