From e0ed709c85322ef06fba8bc40197a04a7b035901 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Tue, 14 Jan 2025 14:29:23 +0100 Subject: [PATCH] Bump ruff version to 0.9 --- .../langchain_core/_api/beta_decorator.py | 2 +- libs/core/langchain_core/_api/deprecation.py | 1 + libs/core/langchain_core/callbacks/manager.py | 3 +- .../core/langchain_core/indexing/in_memory.py | 3 +- .../langchain_core/language_models/base.py | 2 +- .../langchain_core/language_models/llms.py | 3 +- libs/core/langchain_core/messages/utils.py | 10 +- .../langchain_core/output_parsers/list.py | 2 +- libs/core/langchain_core/prompts/chat.py | 4 +- .../langchain_core/pydantic_v1/__init__.py | 2 +- libs/core/langchain_core/rate_limiters.py | 4 +- libs/core/langchain_core/runnables/base.py | 30 +-- libs/core/langchain_core/runnables/config.py | 4 +- .../langchain_core/runnables/graph_ascii.py | 15 +- libs/core/langchain_core/runnables/history.py | 4 +- .../langchain_core/runnables/passthrough.py | 30 +-- libs/core/langchain_core/runnables/retry.py | 4 +- libs/core/langchain_core/runnables/router.py | 6 +- .../langchain_core/tracers/event_stream.py | 2 +- libs/core/langchain_core/tracers/stdout.py | 4 +- libs/core/langchain_core/utils/formatting.py | 2 +- libs/core/langchain_core/utils/mustache.py | 7 +- libs/core/langchain_core/utils/pydantic.py | 2 + .../langchain_core/vectorstores/in_memory.py | 14 +- libs/core/poetry.lock | 48 ++-- libs/core/pyproject.toml | 2 +- .../chat_history/test_chat_history.py | 15 +- .../test_deterministic_embedding.py | 5 +- .../language_models/chat_models/test_cache.py | 2 +- .../tests/unit_tests/messages/test_utils.py | 2 +- .../unit_tests/output_parsers/test_json.py | 6 +- .../output_parsers/test_list_parser.py | 3 +- .../tests/unit_tests/prompts/test_chat.py | 27 ++- .../tests/unit_tests/prompts/test_few_shot.py | 2 +- .../tests/unit_tests/runnables/test_config.py | 12 +- .../unit_tests/runnables/test_fallbacks.py | 6 +- .../unit_tests/runnables/test_runnable.py | 78 +++--- .../runnables/test_runnable_events_v1.py | 19 +- .../runnables/test_runnable_events_v2.py | 36 ++- .../runnables/test_tracing_interops.py | 14 +- libs/core/tests/unit_tests/test_messages.py | 223 +++++++++--------- libs/core/tests/unit_tests/test_outputs.py | 92 ++++---- libs/core/tests/unit_tests/test_tools.py | 50 +++- .../unit_tests/tracers/test_memory_stream.py | 12 +- .../unit_tests/utils/test_function_calling.py | 4 +- .../core/tests/unit_tests/utils/test_usage.py | 10 +- 46 files changed, 473 insertions(+), 355 deletions(-) diff --git a/libs/core/langchain_core/_api/beta_decorator.py b/libs/core/langchain_core/_api/beta_decorator.py index 4d6810dbec347..ec3255762879b 100644 --- a/libs/core/langchain_core/_api/beta_decorator.py +++ b/libs/core/langchain_core/_api/beta_decorator.py @@ -215,7 +215,7 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: old_doc = inspect.cleandoc(old_doc or "").strip("\n") or "" components = [message, addendum] details = " ".join([component.strip() for component in components if component]) - new_doc = f".. beta::\n" f" {details}\n\n" f"{old_doc}\n" + new_doc = f".. 
beta::\n {details}\n\n{old_doc}\n" if inspect.iscoroutinefunction(obj): finalized = finalize(awarning_emitting_wrapper, new_doc) diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index 0e254053d8819..b6b26ef3dfc55 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -240,6 +240,7 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: exclude=obj.exclude, ), ) + elif isinstance(obj, FieldInfoV2): wrapped = None if not _obj_type: diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index cabe04b6a94d1..6c2a09ac9902c 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -385,8 +385,7 @@ async def _ahandle_event_for_handler( ) except Exception as e: logger.warning( - f"Error in {handler.__class__.__name__}.{event_name} callback:" - f" {repr(e)}" + f"Error in {handler.__class__.__name__}.{event_name} callback: {repr(e)}" ) if handler.raise_error: raise e diff --git a/libs/core/langchain_core/indexing/in_memory.py b/libs/core/langchain_core/indexing/in_memory.py index 7fd7adec6646e..cf12d0012f15d 100644 --- a/libs/core/langchain_core/indexing/in_memory.py +++ b/libs/core/langchain_core/indexing/in_memory.py @@ -1,3 +1,4 @@ +import operator import uuid from collections.abc import Sequence from typing import Any, Optional, cast @@ -80,5 +81,5 @@ def _get_relevant_documents( count = document.page_content.count(query) counts_by_doc.append((document, count)) - counts_by_doc.sort(key=lambda x: x[1], reverse=True) + counts_by_doc.sort(key=operator.itemgetter(1), reverse=True) return [doc.model_copy() for doc, count in counts_by_doc[: self.top_k]] diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index 051550dfe7f85..12445f2560f53 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -390,7 +390,7 @@ def get_num_tokens_from_messages( "Counting tokens in tool schemas is not yet supported. 
Ignoring tools.", stacklevel=2, ) - return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages]) + return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages) @classmethod def _all_required_field_names(cls) -> set: diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index 7fd47e627d739..08e1d041e0ec8 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -349,8 +349,7 @@ def _get_ls_params( # get default provider from class name default_provider = self.__class__.__name__ - if default_provider.endswith("LLM"): - default_provider = default_provider[:-3] + default_provider = default_provider.removesuffix("LLM") default_provider = default_provider.lower() ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="llm") diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 435bd1e2f1e07..81107c076fe73 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -1009,7 +1009,10 @@ def convert_to_openai_messages( ) raise ValueError(err) content.append( - {"type": "image_url", "image_url": block["image_url"]} + { + "type": "image_url", + "image_url": block["image_url"], + } ) # Anthropic and Bedrock converse format elif (block.get("type") == "image") or "image" in block: @@ -1128,7 +1131,10 @@ def convert_to_openai_messages( ) raise ValueError(msg) content.append( - {"type": "text", "text": json.dumps(block["json"])} + { + "type": "text", + "text": json.dumps(block["json"]), + } ) elif ( block.get("type") == "guard_content" diff --git a/libs/core/langchain_core/output_parsers/list.py b/libs/core/langchain_core/output_parsers/list.py index ebaca8f8ca94f..fc2326db8011e 100644 --- a/libs/core/langchain_core/output_parsers/list.py +++ b/libs/core/langchain_core/output_parsers/list.py @@ -225,7 +225,7 @@ class MarkdownListOutputParser(ListOutputParser): def get_format_instructions(self) -> str: """Return the format instructions for the Markdown list output.""" - return "Your response should be a markdown list, " "eg: `- foo\n- bar\n- baz`" + return "Your response should be a markdown list, eg: `- foo\n- bar\n- baz`" def parse(self, text: str) -> list[str]: """Parse the output of an LLM call. diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 1629962ba1333..a7c54c8268ddf 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -1402,9 +1402,7 @@ def _create_template_from_message_type( elif len(template) == 2 and isinstance(template[1], bool): var_name_wrapped, is_optional = template if not isinstance(var_name_wrapped, str): - msg = ( - "Expected variable name to be a string." f" Got: {var_name_wrapped}" - ) + msg = f"Expected variable name to be a string. Got: {var_name_wrapped}" raise ValueError(msg) if var_name_wrapped[0] != "{" or var_name_wrapped[-1] != "}": msg = ( diff --git a/libs/core/langchain_core/pydantic_v1/__init__.py b/libs/core/langchain_core/pydantic_v1/__init__.py index 6dfc11dc75513..b70686f1f4a90 100644 --- a/libs/core/langchain_core/pydantic_v1/__init__.py +++ b/libs/core/langchain_core/pydantic_v1/__init__.py @@ -2,7 +2,7 @@ from langchain_core._api.deprecation import warn_deprecated -## Create namespaces for pydantic v1 and v2. +# Create namespaces for pydantic v1 and v2. 
# This code must stay at the top of the file before other modules may # attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules. # diff --git a/libs/core/langchain_core/rate_limiters.py b/libs/core/langchain_core/rate_limiters.py index 11588020f6ba7..87b47b922d06a 100644 --- a/libs/core/langchain_core/rate_limiters.py +++ b/libs/core/langchain_core/rate_limiters.py @@ -248,14 +248,14 @@ async def aacquire(self, *, blocking: bool = True) -> bool: if not blocking: return self._consume() - while not self._consume(): + while not self._consume(): # noqa: ASYNC110 # This code ignores the ASYNC110 warning which is a false positive in this # case. # There is no external actor that can mark that the Event is done # since the tokens are managed by the rate limiter itself. # It needs to wake up to re-fill the tokens. # https://docs.astral.sh/ruff/rules/async-busy-wait/ - await asyncio.sleep(self.check_every_n_seconds) # ruff: noqa: ASYNC110 + await asyncio.sleep(self.check_every_n_seconds) return True diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 893f393d8b174..dc8dec05fe4c1 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -2868,7 +2868,7 @@ def config_specs(self) -> list[ConfigurableFieldSpec]: # calculate context dependencies specs_by_pos = groupby( [tup for tup in all_specs if tup[0].id.startswith(CONTEXT_CONFIG_PREFIX)], - lambda x: x[1], + itemgetter(1), ) next_deps: set[str] = set() deps_by_pos: dict[int, set[str]] = {} @@ -3012,7 +3012,7 @@ def invoke( for i, step in enumerate(self.steps): # mark each step as a child run config = patch_config( - config, callbacks=run_manager.get_child(f"seq:step:{i+1}") + config, callbacks=run_manager.get_child(f"seq:step:{i + 1}") ) context = copy_context() context.run(_set_config_context, config) @@ -3052,7 +3052,7 @@ async def ainvoke( for i, step in enumerate(self.steps): # mark each step as a child run config = patch_config( - config, callbacks=run_manager.get_child(f"seq:step:{i+1}") + config, callbacks=run_manager.get_child(f"seq:step:{i + 1}") ) context = copy_context() context.run(_set_config_context, config) @@ -3137,7 +3137,8 @@ def batch( [ # each step a child run of the corresponding root run patch_config( - config, callbacks=rm.get_child(f"seq:step:{stepidx+1}") + config, + callbacks=rm.get_child(f"seq:step:{stepidx + 1}"), ) for i, (rm, config) in enumerate(zip(run_managers, configs)) if i not in failed_inputs_map @@ -3169,7 +3170,7 @@ def batch( [ # each step a child run of the corresponding root run patch_config( - config, callbacks=rm.get_child(f"seq:step:{i+1}") + config, callbacks=rm.get_child(f"seq:step:{i + 1}") ) for rm, config in zip(run_managers, configs) ], @@ -3266,7 +3267,8 @@ async def abatch( [ # each step a child run of the corresponding root run patch_config( - config, callbacks=rm.get_child(f"seq:step:{stepidx+1}") + config, + callbacks=rm.get_child(f"seq:step:{stepidx + 1}"), ) for i, (rm, config) in enumerate(zip(run_managers, configs)) if i not in failed_inputs_map @@ -3298,7 +3300,7 @@ async def abatch( [ # each step a child run of the corresponding root run patch_config( - config, callbacks=rm.get_child(f"seq:step:{i+1}") + config, callbacks=rm.get_child(f"seq:step:{i + 1}") ) for rm, config in zip(run_managers, configs) ], @@ -3345,7 +3347,7 @@ def _transform( final_pipeline = cast(Iterator[Output], input) for idx, step in enumerate(steps): config = patch_config( - config, 
callbacks=run_manager.get_child(f"seq:step:{idx+1}") + config, callbacks=run_manager.get_child(f"seq:step:{idx + 1}") ) if idx == 0: final_pipeline = step.transform(final_pipeline, config, **kwargs) @@ -3374,7 +3376,7 @@ async def _atransform( for idx, step in enumerate(steps): config = patch_config( config, - callbacks=run_manager.get_child(f"seq:step:{idx+1}"), + callbacks=run_manager.get_child(f"seq:step:{idx + 1}"), ) if idx == 0: final_pipeline = step.atransform(final_pipeline, config, **kwargs) @@ -4406,7 +4408,7 @@ def get_input_schema( if dict_keys := get_function_first_arg_dict_keys(func): return create_model_v2( self.get_name("Input"), - field_definitions={key: (Any, ...) for key in dict_keys}, + field_definitions=dict.fromkeys(dict_keys, (Any, ...)), ) return super().get_input_schema(config) @@ -4530,7 +4532,7 @@ def __eq__(self, other: Any) -> bool: def __repr__(self) -> str: """A string representation of this Runnable.""" if hasattr(self, "func") and isinstance(self.func, itemgetter): - return f"RunnableLambda({str(self.func)[len('operator.'):]})" + return f"RunnableLambda({str(self.func)[len('operator.') :]})" elif hasattr(self, "func"): return f"RunnableLambda({get_lambda_source(self.func) or '...'})" elif hasattr(self, "afunc"): @@ -4791,8 +4793,7 @@ def _transform( recursion_limit = config["recursion_limit"] if recursion_limit <= 0: msg = ( - f"Recursion limit reached when invoking " - f"{self} with input {final}." + f"Recursion limit reached when invoking {self} with input {final}." ) raise RecursionError(msg) for chunk in output.stream( @@ -4915,8 +4916,7 @@ async def f(*args, **kwargs): # type: ignore[no-untyped-def] recursion_limit = config["recursion_limit"] if recursion_limit <= 0: msg = ( - f"Recursion limit reached when invoking " - f"{self} with input {final}." + f"Recursion limit reached when invoking {self} with input {final}." ) raise RecursionError(msg) async for chunk in output.astream( diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index 7c285137fafe3..e384092885482 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -110,8 +110,8 @@ class RunnableConfig(TypedDict, total=False): DEFAULT_RECURSION_LIMIT = 25 -var_child_runnable_config = ContextVar( - "child_runnable_config", default=RunnableConfig() +var_child_runnable_config: ContextVar[RunnableConfig | None] = ContextVar( + "child_runnable_config", default=None ) diff --git a/libs/core/langchain_core/runnables/graph_ascii.py b/libs/core/langchain_core/runnables/graph_ascii.py index f2a031ba43dbb..72f700d2d1d10 100644 --- a/libs/core/langchain_core/runnables/graph_ascii.py +++ b/libs/core/langchain_core/runnables/graph_ascii.py @@ -236,17 +236,20 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str: # NOTE: coordinates might me negative, so we need to shift # everything to the positive plane before we actually draw it. 
- xlist = [] - ylist = [] + xlist: list[float] = [] + ylist: list[float] = [] sug = _build_sugiyama_layout(vertices, edges) for vertex in sug.g.sV: # NOTE: moving boxes w/2 to the left - xlist.append(vertex.view.xy[0] - vertex.view.w / 2.0) - xlist.append(vertex.view.xy[0] + vertex.view.w / 2.0) - ylist.append(vertex.view.xy[1]) - ylist.append(vertex.view.xy[1] + vertex.view.h) + xlist.extend( + ( + vertex.view.xy[0] - vertex.view.w / 2.0, + vertex.view.xy[0] + vertex.view.w / 2.0, + ) + ) + ylist.extend((vertex.view.xy[1], vertex.view.xy[1] + vertex.view.h)) for edge in sug.g.sE: for x, y in edge.view._pts: diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 3d13da5a66d89..e6a88e4d2243a 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -590,9 +590,7 @@ def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: if missing_keys and parameter_names: example_input = {self.input_messages_key: "foo"} - example_configurable = { - missing_key: "[your-value-here]" for missing_key in missing_keys - } + example_configurable = dict.fromkeys(missing_keys, "[your-value-here]") example_config = {"configurable": example_configurable} msg = ( f"Missing keys {sorted(missing_keys)} in config['configurable'] " diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 7e7fa7b45bdf0..1083aa0f1eb41 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -472,9 +472,9 @@ def _invoke( config: RunnableConfig, **kwargs: Any, ) -> dict[str, Any]: - assert isinstance( - input, dict - ), "The input to RunnablePassthrough.assign() must be a dict." + assert isinstance(input, dict), ( + "The input to RunnablePassthrough.assign() must be a dict." + ) return { **input, @@ -500,9 +500,9 @@ async def _ainvoke( config: RunnableConfig, **kwargs: Any, ) -> dict[str, Any]: - assert isinstance( - input, dict - ), "The input to RunnablePassthrough.assign() must be a dict." + assert isinstance(input, dict), ( + "The input to RunnablePassthrough.assign() must be a dict." + ) return { **input, @@ -553,9 +553,9 @@ def _transform( ) # consume passthrough stream for chunk in for_passthrough: - assert isinstance( - chunk, dict - ), "The input to RunnablePassthrough.assign() must be a dict." + assert isinstance(chunk, dict), ( + "The input to RunnablePassthrough.assign() must be a dict." + ) # remove mapper keys from passthrough chunk, to be overwritten by map filtered = AddableDict( {k: v for k, v in chunk.items() if k not in mapper_keys} @@ -603,9 +603,9 @@ async def _atransform( ) # consume passthrough stream async for chunk in for_passthrough: - assert isinstance( - chunk, dict - ), "The input to RunnablePassthrough.assign() must be a dict." + assert isinstance(chunk, dict), ( + "The input to RunnablePassthrough.assign() must be a dict." + ) # remove mapper keys from passthrough chunk, to be overwritten by map output filtered = AddableDict( {k: v for k, v in chunk.items() if k not in mapper_keys} @@ -705,9 +705,9 @@ def get_name( return super().get_name(suffix, name=name) def _pick(self, input: dict[str, Any]) -> Any: - assert isinstance( - input, dict - ), "The input to RunnablePassthrough.assign() must be a dict." + assert isinstance(input, dict), ( + "The input to RunnablePassthrough.assign() must be a dict." 
+ ) if isinstance(self.keys, str): return input.get(self.keys) diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index 0469dd961b47a..234f025de1003 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -249,7 +249,7 @@ def pending(iterable: list[U]) -> list[U]: result = cast(list[Output], [e] * len(inputs)) outputs: list[Union[Output, Exception]] = [] - for idx, _ in enumerate(inputs): + for idx in range(len(inputs)): if idx in results_map: outputs.append(results_map[idx]) else: @@ -315,7 +315,7 @@ def pending(iterable: list[U]) -> list[U]: result = cast(list[Output], [e] * len(inputs)) outputs: list[Union[Output, Exception]] = [] - for idx, _ in enumerate(inputs): + for idx in range(len(inputs)): if idx in results_map: outputs.append(results_map[idx]) else: diff --git a/libs/core/langchain_core/runnables/router.py b/libs/core/langchain_core/runnables/router.py index 8d353648cd928..d5b504f4890f1 100644 --- a/libs/core/langchain_core/runnables/router.py +++ b/libs/core/langchain_core/runnables/router.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import AsyncIterator, Iterator, Mapping +from itertools import starmap from typing import ( Any, Callable, @@ -190,10 +191,7 @@ async def ainvoke( configs = get_config_list(config, len(inputs)) return await gather_with_concurrency( configs[0].get("max_concurrency"), - *( - ainvoke(runnable, input, config) - for runnable, input, config in zip(runnables, actual_inputs, configs) - ), + *starmap(ainvoke, zip(runnables, actual_inputs, configs)), ) def stream( diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index b7a1ddc853f67..b7f3db8595d61 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -863,7 +863,7 @@ async def _astream_events_implementation_v1( tags=log_entry["tags"], metadata=log_entry["metadata"], data=data, - parent_ids=[], # Not supported in v1 + parent_ids=[], # Not supported in v1 ) # Finally, we take care of the streaming output from the root chain diff --git a/libs/core/langchain_core/tracers/stdout.py b/libs/core/langchain_core/tracers/stdout.py index 22ace8bb70f98..3643724f9e2b4 100644 --- a/libs/core/langchain_core/tracers/stdout.py +++ b/libs/core/langchain_core/tracers/stdout.py @@ -160,7 +160,7 @@ def _on_llm_error(self, run: Run) -> None: def _on_tool_start(self, run: Run) -> None: crumbs = self.get_breadcrumbs(run) self.function_callback( - f'{get_colored_text("[tool/start]", color="green")} ' + f"{get_colored_text('[tool/start]', color='green')} " + get_bolded_text(f"[{crumbs}] Entering Tool run with input:\n") + f'"{run.inputs["input"].strip()}"' ) @@ -169,7 +169,7 @@ def _on_tool_end(self, run: Run) -> None: crumbs = self.get_breadcrumbs(run) if run.outputs: self.function_callback( - f'{get_colored_text("[tool/end]", color="blue")} ' + f"{get_colored_text('[tool/end]', color='blue')} " + get_bolded_text( f"[{crumbs}] [{elapsed(run)}] Exiting Tool run with output:\n" ) diff --git a/libs/core/langchain_core/utils/formatting.py b/libs/core/langchain_core/utils/formatting.py index d00431be85b1e..d2313c5ca46b1 100644 --- a/libs/core/langchain_core/utils/formatting.py +++ b/libs/core/langchain_core/utils/formatting.py @@ -44,7 +44,7 @@ def validate_input_variables( Raises: ValueError: If any input variables are not used in the format string. 
""" - dummy_inputs = {input_variable: "foo" for input_variable in input_variables} + dummy_inputs = dict.fromkeys(input_variables, "foo") super().format(format_string, **dummy_inputs) diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index 89d5d9fbbf144..37abb47393251 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -126,8 +126,7 @@ def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], s ChevronError: If the tag is unclosed. ChevronError: If the set delimiter tag is unclosed. """ - global _CURRENT_LINE - global _LAST_TAG_LINE + global _CURRENT_LINE, _LAST_TAG_LINE tag_types = { "!": "comment", @@ -144,7 +143,7 @@ def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], s try: tag, template = template.split(r_del, 1) except ValueError as e: - msg = "unclosed tag " f"at line {_CURRENT_LINE}" + msg = f"unclosed tag at line {_CURRENT_LINE}" raise ChevronError(msg) from e # Find the type meaning of the first character @@ -165,7 +164,7 @@ def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], s # Otherwise we should complain else: - msg = "unclosed set delimiter tag\n" f"at line {_CURRENT_LINE}" + msg = f"unclosed set delimiter tag\nat line {_CURRENT_LINE}" raise ChevronError(msg) elif ( diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index 65f12232f9fc6..704261addf92c 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -390,6 +390,7 @@ def get_fields( else: msg = f"Expected a Pydantic model. Got {type(model)}" raise TypeError(msg) + elif PYDANTIC_MAJOR_VERSION == 1: from pydantic import BaseModel as BaseModelV1_ @@ -398,6 +399,7 @@ def get_fields( # type: ignore[no-redef] ) -> dict[str, FieldInfoV1]: """Get the field names of a Pydantic model.""" return model.__fields__ # type: ignore + else: msg = f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}" raise ValueError(msg) diff --git a/libs/core/langchain_core/vectorstores/in_memory.py b/libs/core/langchain_core/vectorstores/in_memory.py index d6eb978d64304..ab32c7cdacbe9 100644 --- a/libs/core/langchain_core/vectorstores/in_memory.py +++ b/libs/core/langchain_core/vectorstores/in_memory.py @@ -189,7 +189,7 @@ def add_documents( for doc, vector in zip(documents, vectors): doc_id = next(id_iterator) - doc_id_ = doc_id if doc_id else str(uuid.uuid4()) + doc_id_ = doc_id or str(uuid.uuid4()) ids_.append(doc_id_) self.store[doc_id_] = { "id": doc_id_, @@ -221,7 +221,7 @@ async def aadd_documents( for doc, vector in zip(documents, vectors): doc_id = next(id_iterator) - doc_id_ = doc_id if doc_id else str(uuid.uuid4()) + doc_id_ = doc_id or str(uuid.uuid4()) ids_.append(doc_id_) self.store[doc_id_] = { "id": doc_id_, @@ -258,8 +258,7 @@ def get_by_ids(self, ids: Sequence[str], /) -> list[Document]: @deprecated( alternative="VectorStore.add_documents", message=( - "This was a beta API that was added in 0.2.11. " - "It'll be removed in 0.3.0." + "This was a beta API that was added in 0.2.11. It'll be removed in 0.3.0." 
), since="0.2.29", removal="1.0", @@ -268,7 +267,7 @@ def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse: vectors = self.embedding.embed_documents([item.page_content for item in items]) ids = [] for item, vector in zip(items, vectors): - doc_id = item.id if item.id else str(uuid.uuid4()) + doc_id = item.id or str(uuid.uuid4()) ids.append(doc_id) self.store[doc_id] = { "id": doc_id, @@ -284,8 +283,7 @@ def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse: @deprecated( alternative="VectorStore.aadd_documents", message=( - "This was a beta API that was added in 0.2.11. " - "It'll be removed in 0.3.0." + "This was a beta API that was added in 0.2.11. It'll be removed in 0.3.0." ), since="0.2.29", removal="1.0", @@ -298,7 +296,7 @@ async def aupsert( ) ids = [] for item, vector in zip(items, vectors): - doc_id = item.id if item.id else str(uuid.uuid4()) + doc_id = item.id or str(uuid.uuid4()) ids.append(doc_id) self.store[doc_id] = { "id": doc_id, diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index eadfb8d460c70..cba2a987d0d58 100644 --- a/libs/core/poetry.lock +++ b/libs/core/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "annotated-types" @@ -1200,7 +1200,7 @@ files = [ [[package]] name = "langchain-tests" -version = "0.3.7" +version = "0.3.8" description = "Standard tests for LangChain implementations" optional = false python-versions = ">=3.9,<4.0" @@ -1225,7 +1225,7 @@ url = "../standard-tests" [[package]] name = "langchain-text-splitters" -version = "0.3.4" +version = "0.3.5" description = "LangChain text splitting utilities" optional = false python-versions = ">=3.9,<4.0" @@ -1233,7 +1233,7 @@ files = [] develop = true [package.dependencies] -langchain-core = "^0.3.26" +langchain-core = "^0.3.29" [package.source] type = "directory" @@ -2658,29 +2658,29 @@ files = [ [[package]] name = "ruff" -version = "0.5.7" +version = "0.9.1" description = "An extremely fast Python linter and code formatter, written in Rust." 
optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.5.7-py3-none-linux_armv6l.whl", hash = "sha256:548992d342fc404ee2e15a242cdbea4f8e39a52f2e7752d0e4cbe88d2d2f416a"}, - {file = "ruff-0.5.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:00cc8872331055ee017c4f1071a8a31ca0809ccc0657da1d154a1d2abac5c0be"}, - {file = "ruff-0.5.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf3d86a1fdac1aec8a3417a63587d93f906c678bb9ed0b796da7b59c1114a1e"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a01c34400097b06cf8a6e61b35d6d456d5bd1ae6961542de18ec81eaf33b4cb8"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcc8054f1a717e2213500edaddcf1dbb0abad40d98e1bd9d0ad364f75c763eea"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f70284e73f36558ef51602254451e50dd6cc479f8b6f8413a95fcb5db4a55fc"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:a78ad870ae3c460394fc95437d43deb5c04b5c29297815a2a1de028903f19692"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ccd078c66a8e419475174bfe60a69adb36ce04f8d4e91b006f1329d5cd44bcf"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e31c9bad4ebf8fdb77b59cae75814440731060a09a0e0077d559a556453acbb"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d796327eed8e168164346b769dd9a27a70e0298d667b4ecee6877ce8095ec8e"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a09ea2c3f7778cc635e7f6edf57d566a8ee8f485f3c4454db7771efb692c499"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a36d8dcf55b3a3bc353270d544fb170d75d2dff41eba5df57b4e0b67a95bb64e"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9369c218f789eefbd1b8d82a8cf25017b523ac47d96b2f531eba73770971c9e5"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b88ca3db7eb377eb24fb7c82840546fb7acef75af4a74bd36e9ceb37a890257e"}, - {file = "ruff-0.5.7-py3-none-win32.whl", hash = "sha256:33d61fc0e902198a3e55719f4be6b375b28f860b09c281e4bdbf783c0566576a"}, - {file = "ruff-0.5.7-py3-none-win_amd64.whl", hash = "sha256:083bbcbe6fadb93cd86709037acc510f86eed5a314203079df174c40bbbca6b3"}, - {file = "ruff-0.5.7-py3-none-win_arm64.whl", hash = "sha256:2dca26154ff9571995107221d0aeaad0e75a77b5a682d6236cf89a58c70b76f4"}, - {file = "ruff-0.5.7.tar.gz", hash = "sha256:8dfc0a458797f5d9fb622dd0efc52d796f23f0a1493a9527f4e49a550ae9a7e5"}, + {file = "ruff-0.9.1-py3-none-linux_armv6l.whl", hash = "sha256:84330dda7abcc270e6055551aca93fdde1b0685fc4fd358f26410f9349cf1743"}, + {file = "ruff-0.9.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:3cae39ba5d137054b0e5b472aee3b78a7c884e61591b100aeb544bcd1fc38d4f"}, + {file = "ruff-0.9.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:50c647ff96f4ba288db0ad87048257753733763b409b2faf2ea78b45c8bb7fcb"}, + {file = "ruff-0.9.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0c8b149e9c7353cace7d698e1656ffcf1e36e50f8ea3b5d5f7f87ff9986a7ca"}, + {file = "ruff-0.9.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:beb3298604540c884d8b282fe7625651378e1986c25df51dec5b2f60cafc31ce"}, + {file = "ruff-0.9.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:39d0174ccc45c439093971cc06ed3ac4dc545f5e8bdacf9f067adf879544d969"}, + {file = "ruff-0.9.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:69572926c0f0c9912288915214ca9b2809525ea263603370b9e00bed2ba56dbd"}, + {file = "ruff-0.9.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:937267afce0c9170d6d29f01fcd1f4378172dec6760a9f4dface48cdabf9610a"}, + {file = "ruff-0.9.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:186c2313de946f2c22bdf5954b8dd083e124bcfb685732cfb0beae0c47233d9b"}, + {file = "ruff-0.9.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f94942a3bb767675d9a051867c036655fe9f6c8a491539156a6f7e6b5f31831"}, + {file = "ruff-0.9.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:728d791b769cc28c05f12c280f99e8896932e9833fef1dd8756a6af2261fd1ab"}, + {file = "ruff-0.9.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2f312c86fb40c5c02b44a29a750ee3b21002bd813b5233facdaf63a51d9a85e1"}, + {file = "ruff-0.9.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ae017c3a29bee341ba584f3823f805abbe5fe9cd97f87ed07ecbf533c4c88366"}, + {file = "ruff-0.9.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5dc40a378a0e21b4cfe2b8a0f1812a6572fc7b230ef12cd9fac9161aa91d807f"}, + {file = "ruff-0.9.1-py3-none-win32.whl", hash = "sha256:46ebf5cc106cf7e7378ca3c28ce4293b61b449cd121b98699be727d40b79ba72"}, + {file = "ruff-0.9.1-py3-none-win_amd64.whl", hash = "sha256:342a824b46ddbcdddd3abfbb332fa7fcaac5488bf18073e841236aadf4ad5c19"}, + {file = "ruff-0.9.1-py3-none-win_arm64.whl", hash = "sha256:1cd76c7f9c679e6e8f2af8f778367dca82b95009bc7b1a85a47f1521ae524fa7"}, + {file = "ruff-0.9.1.tar.gz", hash = "sha256:fd2b25ecaf907d6458fa842675382c8597b3c746a2dde6717fe3415425df0c17"}, ] [[package]] @@ -3138,4 +3138,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "65d2f612fead6395befc285353347bf82d09044ce832c278f8b35e4f179caebb" +content-hash = "789709f8646c52360c7937faf8c22febb07438c7a0adf0b941fee14f7a44b130" diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 65e40fafb09f8..03ecb64964c00 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -82,7 +82,7 @@ classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_ini "scripts/**" = [ "S",] [tool.poetry.group.lint.dependencies] -ruff = "^0.5" +ruff = "^0.9.1" [tool.poetry.group.typing.dependencies] diff --git a/libs/core/tests/unit_tests/chat_history/test_chat_history.py b/libs/core/tests/unit_tests/chat_history/test_chat_history.py index eb7690335f974..7557d6988f1dd 100644 --- a/libs/core/tests/unit_tests/chat_history/test_chat_history.py +++ b/libs/core/tests/unit_tests/chat_history/test_chat_history.py @@ -29,7 +29,10 @@ def clear(self) -> None: assert store[1] == HumanMessage(content="World") chat_history.add_messages( - [HumanMessage(content="Hello"), HumanMessage(content="World")] + [ + HumanMessage(content="Hello"), + HumanMessage(content="World"), + ] ) assert len(store) == 4 assert store[2] == HumanMessage(content="Hello") @@ -61,7 +64,10 @@ def clear(self) -> None: assert store[1] == HumanMessage(content="World") chat_history.add_messages( - [HumanMessage(content="Hello"), HumanMessage(content="World")] + [ + HumanMessage(content="Hello"), + HumanMessage(content="World"), + ] ) assert len(store) == 4 assert store[2] == HumanMessage(content="Hello") @@ -85,7 +91,10 @@ def clear(self) -> None: chat_history = BulkAddHistory() await 
chat_history.aadd_messages( - [HumanMessage(content="Hello"), HumanMessage(content="World")] + [ + HumanMessage(content="Hello"), + HumanMessage(content="World"), + ] ) assert await chat_history.aget_messages() == [ HumanMessage(content="Hello"), diff --git a/libs/core/tests/unit_tests/embeddings/test_deterministic_embedding.py b/libs/core/tests/unit_tests/embeddings/test_deterministic_embedding.py index 5ad33b7e307cd..20596e7c4435e 100644 --- a/libs/core/tests/unit_tests/embeddings/test_deterministic_embedding.py +++ b/libs/core/tests/unit_tests/embeddings/test_deterministic_embedding.py @@ -12,5 +12,8 @@ def test_deterministic_fake_embeddings() -> None: assert fake.embed_query(text) != fake.embed_query("Goodbye world!") assert fake.embed_documents([text, text]) == fake.embed_documents([text, text]) assert fake.embed_documents([text, text]) != fake.embed_documents( - [text, "Goodbye world!"] + [ + text, + "Goodbye world!", + ] ) diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py index 0d5b89de7c354..b35db669695b9 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py @@ -227,7 +227,7 @@ def test_global_cache_batch() -> None: assert results[0].content == results[1].content assert {results[0].content, results[1].content}.issubset({"hello", "goodbye"}) - ## RACE CONDITION -- note behavior is different from async + # RACE CONDITION -- note behavior is different from async # Now, reset cache and test the race condition # For now we just hard-code the result, if this changes # we can investigate further diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 3a32c2984952a..18b95ce75a447 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -663,7 +663,7 @@ def create_image_data() -> str: def create_base64_image(format: str = "jpeg") -> str: data = create_image_data() - return f"data:image/{format};base64,{data}" # noqa: E501 + return f"data:image/{format};base64,{data}" def test_convert_to_openai_messages_single_message() -> None: diff --git a/libs/core/tests/unit_tests/output_parsers/test_json.py b/libs/core/tests/unit_tests/output_parsers/test_json.py index 326cfc16cd9b1..108b9f2a9cf79 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_json.py +++ b/libs/core/tests/unit_tests/output_parsers/test_json.py @@ -613,6 +613,6 @@ class Sample(BaseModel): parser = SimpleJsonOutputParser(pydantic_object=Sample) format_instructions = parser.get_format_instructions() - assert ( - "科学文章的标题" in format_instructions - ), "Unicode characters should not be escaped" + assert "科学文章的标题" in format_instructions, ( + "Unicode characters should not be escaped" + ) diff --git a/libs/core/tests/unit_tests/output_parsers/test_list_parser.py b/libs/core/tests/unit_tests/output_parsers/test_list_parser.py index 11bd11b6a0b92..fc578788d0f40 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_list_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_list_parser.py @@ -121,7 +121,8 @@ def test_numbered_list() -> None: def test_markdown_list() -> None: parser = MarkdownListOutputParser() text1 = ( - "Your response should be a numbered - not a list item - list with each item on a new line." 
# noqa: E501 + "Your response should be a numbered - not a list item - " + "list with each item on a new line." "For example: \n- foo\n- bar\n- baz" ) diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 6249aa6f47893..0d018e0418137 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -124,9 +124,7 @@ def test_create_system_message_prompt_template_from_template_partial() -> None: partial_variables={"instructions": json_prompt_instructions}, ) assert graph_analyst_template.format(history="history") == SystemMessage( - content="\n Your instructions are:\n " - " {}\n History:\n " - "history\n " + content="\n Your instructions are:\n {}\n History:\n history\n " ) @@ -235,7 +233,11 @@ def test_chat_prompt_template_from_messages( """Test creating a chat prompt template from messages.""" chat_prompt_template = ChatPromptTemplate.from_messages(messages) assert sorted(chat_prompt_template.input_variables) == sorted( - ["context", "foo", "bar"] + [ + "context", + "foo", + "bar", + ] ) assert len(chat_prompt_template.messages) == 4 @@ -378,7 +380,11 @@ def test_chat_prompt_template_with_messages( messages + [HumanMessage(content="foo")] ) assert sorted(chat_prompt_template.input_variables) == sorted( - ["context", "foo", "bar"] + [ + "context", + "foo", + "bar", + ] ) assert len(chat_prompt_template.messages) == 5 prompt_value = chat_prompt_template.format_prompt( @@ -836,7 +842,10 @@ async def test_messages_prompt_accepts_list() -> None: # Assert still raises a nice error prompt = ChatPromptTemplate( - [("system", "You are a {foo}"), MessagesPlaceholder("history")] + [ + ("system", "You are a {foo}"), + MessagesPlaceholder("history"), + ] ) with pytest.raises(TypeError): prompt.invoke([("user", "Hi there")]) # type: ignore @@ -873,7 +882,11 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None: def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> None: prompt = ChatPromptTemplate.from_messages( - [("system", "foo"), MessagesPlaceholder("bar"), ("human", "baz")] + [ + ("system", "foo"), + MessagesPlaceholder("bar"), + ("human", "baz"), + ] ) assert dumpd(MessagesPlaceholder("bar")) == snapshot(name="placeholder") assert load(dumpd(MessagesPlaceholder("bar"))) == MessagesPlaceholder("bar") diff --git a/libs/core/tests/unit_tests/prompts/test_few_shot.py b/libs/core/tests/unit_tests/prompts/test_few_shot.py index 27b0208285fe8..97d8eeaeb2625 100644 --- a/libs/core/tests/unit_tests/prompts/test_few_shot.py +++ b/libs/core/tests/unit_tests/prompts/test_few_shot.py @@ -243,7 +243,7 @@ def test_prompt_jinja2_functionality( ) output = prompt.format(foo="hello", bar="bye") expected_output = ( - "Starting with hello\n\n" "happy: sad\n\n" "tall: short\n\n" "Ending with bye" + "Starting with hello\n\nhappy: sad\n\ntall: short\n\nEnding with bye" ) assert output == expected_output diff --git a/libs/core/tests/unit_tests/runnables/test_config.py b/libs/core/tests/unit_tests/runnables/test_config.py index f5243f4e78f9f..678a8086220d6 100644 --- a/libs/core/tests/unit_tests/runnables/test_config.py +++ b/libs/core/tests/unit_tests/runnables/test_config.py @@ -49,12 +49,12 @@ def test_ensure_config() -> None: }, ) config = ctx.run(ensure_config, cast(RunnableConfig, arg)) - assert ( - len(arg["callbacks"]) == 1 - ), "ensure_config should not modify the original config" - assert ( - json.dumps({**arg, "callbacks": []}) == arg_str - ), "ensure_config 
should not modify the original config" + assert len(arg["callbacks"]) == 1, ( + "ensure_config should not modify the original config" + ) + assert json.dumps({**arg, "callbacks": []}) == arg_str, ( + "ensure_config should not modify the original config" + ) assert config is not arg assert config["callbacks"] is not arg["callbacks"] assert config["metadata"] is not arg["metadata"] diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py index 731b3ddaa62aa..e3ebdf794206b 100644 --- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -215,7 +215,11 @@ async def test_abatch() -> None: ) with pytest.raises(RuntimeError): await runnable_with_single.abatch( - [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}] + [ + {"text": "foo"}, + {"text": "bar"}, + {"text": "baz"}, + ] ) actual = await runnable_with_single.abatch( [{"text": "foo"}, {"text": "bar"}, {"text": "baz"}], return_exceptions=True diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index b06cb381e80e6..34df87de0cee6 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -622,32 +622,29 @@ async def aget_values_typed(input: InputType) -> OutputType: "byebye": input["yo"], } - assert ( - _normalize_schema( - RunnableLambda( - aget_values_typed # type: ignore[arg-type] - ).get_input_jsonschema() - ) - == _normalize_schema( - { - "$defs": { - "InputType": { - "properties": { - "variable_name": { - "title": "Variable " "Name", - "type": "string", - }, - "yo": {"title": "Yo", "type": "integer"}, + assert _normalize_schema( + RunnableLambda( + aget_values_typed # type: ignore[arg-type] + ).get_input_jsonschema() + ) == _normalize_schema( + { + "$defs": { + "InputType": { + "properties": { + "variable_name": { + "title": "Variable Name", + "type": "string", }, - "required": ["variable_name", "yo"], - "title": "InputType", - "type": "object", - } - }, - "allOf": [{"$ref": "#/$defs/InputType"}], - "title": "aget_values_typed_input", - } - ) + "yo": {"title": "Yo", "type": "integer"}, + }, + "required": ["variable_name", "yo"], + "title": "InputType", + "type": "object", + } + }, + "allOf": [{"$ref": "#/$defs/InputType"}], + "title": "aget_values_typed_input", + } ) if PYDANTIC_VERSION >= (2, 9): @@ -2793,7 +2790,10 @@ async def test_router_runnable( assert result == "4" result2 = chain.batch( - [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + [ + {"key": "math", "question": "2 + 2"}, + {"key": "english", "question": "2 + 2"}, + ] ) assert result2 == ["4", "2"] @@ -2801,7 +2801,10 @@ async def test_router_runnable( assert result == "4" result2 = await chain.abatch( - [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + [ + {"key": "math", "question": "2 + 2"}, + {"key": "english", "question": "2 + 2"}, + ] ) assert result2 == ["4", "2"] @@ -2855,7 +2858,10 @@ def router(input: dict[str, Any]) -> Runnable: assert result == "4" result2 = chain.batch( - [{"key": "math", "question": "2 + 2"}, {"key": "english", "question": "2 + 2"}] + [ + {"key": "math", "question": "2 + 2"}, + {"key": "english", "question": "2 + 2"}, + ] ) assert result2 == ["4", "2"] @@ -2863,7 +2869,10 @@ def router(input: dict[str, Any]) -> Runnable: assert result == "4" result2 = await chain.abatch( - [{"key": "math", "question": "2 + 
2"}, {"key": "english", "question": "2 + 2"}] + [ + {"key": "math", "question": "2 + 2"}, + {"key": "english", "question": "2 + 2"}, + ] ) assert result2 == ["4", "2"] @@ -3058,7 +3067,10 @@ def test_map_stream() -> None: assert len(streamed_chunks) == len(llm_res) chain_pick_two = chain.assign(hello=RunnablePick("llm").pipe(llm)).pick( - ["llm", "hello"] + [ + "llm", + "hello", + ] ) assert chain_pick_two.get_output_jsonschema() == { @@ -5445,7 +5457,11 @@ def test_schema_for_prompt_and_chat_model() -> None: chain = prompt | chat assert ( chain.invoke( - {"model_json_schema": "hello", "_private": "goodbye", "json": "json"} + { + "model_json_schema": "hello", + "_private": "goodbye", + "json": "json", + } ).content == chat_res ) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index 7389d887769d8..dd9e960a2a84f 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -795,7 +795,10 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: async def test_event_stream_with_simple_chain() -> None: """Test as event stream.""" template = ChatPromptTemplate.from_messages( - [("system", "You are Cat Agent 007"), ("human", "{question}")] + [ + ("system", "You are Cat Agent 007"), + ("human", "{question}"), + ] ).with_config({"run_name": "my_template", "tags": ["my_template"]}) infinite_cycle = cycle( @@ -1681,7 +1684,10 @@ def fail(inputs: str) -> None: async def test_with_llm() -> None: """Test with regular llm.""" prompt = ChatPromptTemplate.from_messages( - [("system", "You are Cat Agent 007"), ("human", "{question}")] + [ + ("system", "You are Cat Agent 007"), + ("human", "{question}"), + ] ).with_config({"run_name": "my_template", "tags": ["my_template"]}) llm = FakeStreamingListLLM(responses=["abc"]) @@ -1730,7 +1736,7 @@ async def test_with_llm() -> None: { "data": { "input": { - "prompts": ["System: You are Cat Agent 007\n" "Human: hello"] + "prompts": ["System: You are Cat Agent 007\nHuman: hello"] } }, "event": "on_llm_start", @@ -1743,7 +1749,7 @@ async def test_with_llm() -> None: { "data": { "input": { - "prompts": ["System: You are Cat Agent 007\n" "Human: hello"] + "prompts": ["System: You are Cat Agent 007\nHuman: hello"] }, "output": { "generations": [ @@ -1918,7 +1924,10 @@ def get_by_session_id(session_id: str) -> BaseChatMessageHistory: return InMemoryHistory(messages=store[session_id]) infinite_cycle = cycle( - [AIMessage(content="hello", id="ai3"), AIMessage(content="world", id="ai4")] + [ + AIMessage(content="hello", id="ai3"), + AIMessage(content="world", id="ai4"), + ] ) prompt = ChatPromptTemplate.from_messages( diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 8ceb4bf38b5f1..ff38cf3e95a42 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -55,12 +55,12 @@ def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]: for event in events: assert "run_id" in event, f"Event {event} does not have a run_id." assert "parent_ids" in event, f"Event {event} does not have parent_ids." - assert isinstance( - event["run_id"], str - ), f"Event {event} run_id is not a string." - assert isinstance( - event["parent_ids"], list - ), f"Event {event} parent_ids is not a list." 
+ assert isinstance(event["run_id"], str), ( + f"Event {event} run_id is not a string." + ) + assert isinstance(event["parent_ids"], list), ( + f"Event {event} parent_ids is not a list." + ) return cast( list[StreamEvent], @@ -828,7 +828,10 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: async def test_event_stream_with_simple_chain() -> None: """Test as event stream.""" template = ChatPromptTemplate.from_messages( - [("system", "You are Cat Agent 007"), ("human", "{question}")] + [ + ("system", "You are Cat Agent 007"), + ("human", "{question}"), + ] ).with_config({"run_name": "my_template", "tags": ["my_template"]}) infinite_cycle = cycle( @@ -1628,7 +1631,10 @@ def fail(inputs: str) -> None: async def test_with_llm() -> None: """Test with regular llm.""" prompt = ChatPromptTemplate.from_messages( - [("system", "You are Cat Agent 007"), ("human", "{question}")] + [ + ("system", "You are Cat Agent 007"), + ("human", "{question}"), + ] ).with_config({"run_name": "my_template", "tags": ["my_template"]}) llm = FakeStreamingListLLM(responses=["abc"]) @@ -1677,7 +1683,7 @@ async def test_with_llm() -> None: { "data": { "input": { - "prompts": ["System: You are Cat Agent 007\n" "Human: hello"] + "prompts": ["System: You are Cat Agent 007\nHuman: hello"] } }, "event": "on_llm_start", @@ -1690,7 +1696,7 @@ async def test_with_llm() -> None: { "data": { "input": { - "prompts": ["System: You are Cat Agent 007\n" "Human: hello"] + "prompts": ["System: You are Cat Agent 007\nHuman: hello"] }, "output": { "generations": [ @@ -1865,7 +1871,10 @@ def get_by_session_id(session_id: str) -> BaseChatMessageHistory: return InMemoryHistory(messages=store[session_id]) infinite_cycle = cycle( - [AIMessage(content="hello", id="ai3"), AIMessage(content="world", id="ai4")] + [ + AIMessage(content="hello", id="ai3"), + AIMessage(content="world", id="ai4"), + ] ) prompt = ChatPromptTemplate.from_messages( @@ -2372,7 +2381,10 @@ def passthrough_to_trigger_issue(x: str) -> str: return x chain = passthrough_to_trigger_issue | model.with_config( - {"tags": ["hello"], "callbacks": callbacks} + { + "tags": ["hello"], + "callbacks": callbacks, + } ) return await chain.ainvoke(query) diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index 0743929f86120..3409d04f23401 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -333,7 +333,7 @@ def parent(a: int) -> int: "other_thing": "RunnableParallel", "after": "RunnableSequence", } - assert len(posts) == sum([1 if isinstance(n, str) else len(n) for n in name_order]) + assert len(posts) == sum(1 if isinstance(n, str) else len(n) for n in name_order) prev_dotted_order = None dotted_order_map = {} id_map = {} @@ -360,9 +360,9 @@ def parent(a: int) -> int: if prev_dotted_order is not None and not str( expected_parents[name] ).startswith("RunnableParallel"): - assert ( - dotted_order > prev_dotted_order - ), f"{name} not after {name_order[i-1]}" + assert dotted_order > prev_dotted_order, ( + f"{name} not after {name_order[i - 1]}" + ) prev_dotted_order = dotted_order if name in dotted_order_map: msg = f"Duplicate name {name}" @@ -377,9 +377,9 @@ def parent(a: int) -> int: dotted_order = dotted_order_map[name] if parent_ is not None: parent_dotted_order = dotted_order_map[parent_] - assert dotted_order.startswith( - parent_dotted_order - ), f"{name}, {parent_dotted_order} not in 
{dotted_order}" + assert dotted_order.startswith(parent_dotted_order), ( + f"{name}, {parent_dotted_order} not in {dotted_order}" + ) assert str(parent_id_map[name]) == str(id_map[parent_]) else: assert dotted_order.split(".")[0] == dotted_order diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index aafcd15e1bb81..d9928b729f396 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -47,48 +47,47 @@ def test_message_init() -> None: def test_message_chunks() -> None: assert AIMessageChunk(content="I am", id="ai3") + AIMessageChunk( content=" indeed." - ) == AIMessageChunk( - content="I am indeed.", id="ai3" - ), "MessageChunk + MessageChunk should be a MessageChunk" + ) == AIMessageChunk(content="I am indeed.", id="ai3"), ( + "MessageChunk + MessageChunk should be a MessageChunk" + ) - assert ( - AIMessageChunk(content="I am", id="ai2") - + HumanMessageChunk(content=" indeed.", id="human1") - == AIMessageChunk(content="I am indeed.", id="ai2") - ), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side" # noqa: E501 + assert AIMessageChunk(content="I am", id="ai2") + HumanMessageChunk( + content=" indeed.", id="human1" + ) == AIMessageChunk(content="I am indeed.", id="ai2"), ( + "MessageChunk + MessageChunk should be a MessageChunk " + "of same class as the left side" + ) - assert ( - AIMessageChunk(content="", additional_kwargs={"foo": "bar"}) - + AIMessageChunk(content="", additional_kwargs={"baz": "foo"}) - == AIMessageChunk(content="", additional_kwargs={"foo": "bar", "baz": "foo"}) - ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501 + assert AIMessageChunk( + content="", additional_kwargs={"foo": "bar"} + ) + AIMessageChunk(content="", additional_kwargs={"baz": "foo"}) == AIMessageChunk( + content="", additional_kwargs={"foo": "bar", "baz": "foo"} + ), ( + "MessageChunk + MessageChunk should be a MessageChunk " + "with merged additional_kwargs" + ) - assert ( - AIMessageChunk( - content="", additional_kwargs={"function_call": {"name": "web_search"}} - ) - + AIMessageChunk( - content="", additional_kwargs={"function_call": {"arguments": None}} - ) - + AIMessageChunk( - content="", additional_kwargs={"function_call": {"arguments": "{\n"}} - ) - + AIMessageChunk( - content="", - additional_kwargs={ - "function_call": {"arguments": ' "query": "turtles"\n}'} - }, - ) - == AIMessageChunk( - content="", - additional_kwargs={ - "function_call": { - "name": "web_search", - "arguments": '{\n "query": "turtles"\n}', - } - }, - ) - ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501 + assert AIMessageChunk( + content="", additional_kwargs={"function_call": {"name": "web_search"}} + ) + AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": None}} + ) + AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": "{\n"}} + ) + AIMessageChunk( + content="", + additional_kwargs={"function_call": {"arguments": ' "query": "turtles"\n}'}}, + ) == AIMessageChunk( + content="", + additional_kwargs={ + "function_call": { + "name": "web_search", + "arguments": '{\n "query": "turtles"\n}', + } + }, + ), ( + "MessageChunk + MessageChunk should be a MessageChunk " + "with merged additional_kwargs" + ) # Test tool calls assert ( @@ -181,97 +180,107 @@ def test_message_chunks() -> None: def test_chat_message_chunks() -> None: assert 
ChatMessageChunk(role="User", content="I am", id="ai4") + ChatMessageChunk( role="User", content=" indeed." - ) == ChatMessageChunk( - id="ai4", role="User", content="I am indeed." - ), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk" + ) == ChatMessageChunk(id="ai4", role="User", content="I am indeed."), ( + "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk" + ) with pytest.raises(ValueError): ChatMessageChunk(role="User", content="I am") + ChatMessageChunk( role="Assistant", content=" indeed." ) - assert ( - ChatMessageChunk(role="User", content="I am") - + AIMessageChunk(content=" indeed.") - == ChatMessageChunk(role="User", content="I am indeed.") - ), "ChatMessageChunk + other MessageChunk should be a ChatMessageChunk with the left side's role" # noqa: E501 + assert ChatMessageChunk(role="User", content="I am") + AIMessageChunk( + content=" indeed." + ) == ChatMessageChunk(role="User", content="I am indeed."), ( + "ChatMessageChunk + other MessageChunk should be a ChatMessageChunk " + "with the left side's role" + ) assert AIMessageChunk(content="I am") + ChatMessageChunk( role="User", content=" indeed." - ) == AIMessageChunk( - content="I am indeed." - ), "Other MessageChunk + ChatMessageChunk should be a MessageChunk as the left side" + ) == AIMessageChunk(content="I am indeed."), ( + "Other MessageChunk + ChatMessageChunk should be a MessageChunk " + "as the left side" + ) def test_complex_ai_message_chunks() -> None: assert AIMessageChunk(content=["I am"], id="ai4") + AIMessageChunk( content=[" indeed."] - ) == AIMessageChunk( - id="ai4", content=["I am", " indeed."] - ), "Content concatenation with arrays of strings should naively combine" + ) == AIMessageChunk(id="ai4", content=["I am", " indeed."]), ( + "Content concatenation with arrays of strings should naively combine" + ) assert AIMessageChunk(content=[{"index": 0, "text": "I am"}]) + AIMessageChunk( content=" indeed." 
- ) == AIMessageChunk( - content=[{"index": 0, "text": "I am"}, " indeed."] - ), "Concatenating mixed content arrays should naively combine them" + ) == AIMessageChunk(content=[{"index": 0, "text": "I am"}, " indeed."]), ( + "Concatenating mixed content arrays should naively combine them" + ) - assert ( - AIMessageChunk(content=[{"index": 0, "text": "I am"}]) - + AIMessageChunk(content=[{"index": 0, "text": " indeed."}]) - == AIMessageChunk(content=[{"index": 0, "text": "I am indeed."}]) - ), "Concatenating when both content arrays are dicts with the same index should merge" # noqa: E501 + assert AIMessageChunk(content=[{"index": 0, "text": "I am"}]) + AIMessageChunk( + content=[{"index": 0, "text": " indeed."}] + ) == AIMessageChunk(content=[{"index": 0, "text": "I am indeed."}]), ( + "Concatenating when both content arrays are dicts with the same index " + "should merge" + ) assert AIMessageChunk(content=[{"index": 0, "text": "I am"}]) + AIMessageChunk( content=[{"text": " indeed."}] - ) == AIMessageChunk( - content=[{"index": 0, "text": "I am"}, {"text": " indeed."}] - ), "Concatenating when one chunk is missing an index should not merge or throw" # noqa: E501 + ) == AIMessageChunk(content=[{"index": 0, "text": "I am"}, {"text": " indeed."}]), ( + "Concatenating when one chunk is missing an index should not merge or throw" + ) - assert ( - AIMessageChunk(content=[{"index": 0, "text": "I am"}]) - + AIMessageChunk(content=[{"index": 2, "text": " indeed."}]) - == AIMessageChunk( - content=[{"index": 0, "text": "I am"}, {"index": 2, "text": " indeed."}] - ) - ), "Concatenating when both content arrays are dicts with a gap between indexes should not result in a holey array" # noqa: E501 + assert AIMessageChunk(content=[{"index": 0, "text": "I am"}]) + AIMessageChunk( + content=[{"index": 2, "text": " indeed."}] + ) == AIMessageChunk( + content=[{"index": 0, "text": "I am"}, {"index": 2, "text": " indeed."}] + ), ( + "Concatenating when both content arrays are dicts with a gap between indexes " + "should not result in a holey array" + ) - assert ( - AIMessageChunk(content=[{"index": 0, "text": "I am"}]) - + AIMessageChunk(content=[{"index": 1, "text": " indeed."}]) - == AIMessageChunk( - content=[{"index": 0, "text": "I am"}, {"index": 1, "text": " indeed."}] - ) - ), "Concatenating when both content arrays are dicts with separate indexes should not merge" # noqa: E501 + assert AIMessageChunk(content=[{"index": 0, "text": "I am"}]) + AIMessageChunk( + content=[{"index": 1, "text": " indeed."}] + ) == AIMessageChunk( + content=[{"index": 0, "text": "I am"}, {"index": 1, "text": " indeed."}] + ), ( + "Concatenating when both content arrays are dicts with separate indexes " + "should not merge" + ) - assert ( - AIMessageChunk(content=[{"index": 0, "text": "I am", "type": "text_block"}]) - + AIMessageChunk( - content=[{"index": 0, "text": " indeed.", "type": "text_block"}] - ) - == AIMessageChunk( - content=[{"index": 0, "text": "I am indeed.", "type": "text_block"}] - ) - ), "Concatenating when both content arrays are dicts with the same index and type should merge" # noqa: E501 + assert AIMessageChunk( + content=[{"index": 0, "text": "I am", "type": "text_block"}] + ) + AIMessageChunk( + content=[{"index": 0, "text": " indeed.", "type": "text_block"}] + ) == AIMessageChunk( + content=[{"index": 0, "text": "I am indeed.", "type": "text_block"}] + ), ( + "Concatenating when both content arrays are dicts with the same index and type " + "should merge" + ) - assert ( - 
AIMessageChunk(content=[{"index": 0, "text": "I am", "type": "text_block"}]) - + AIMessageChunk( - content=[{"index": 0, "text": " indeed.", "type": "text_block_delta"}] - ) - == AIMessageChunk( - content=[{"index": 0, "text": "I am indeed.", "type": "text_block"}] - ) - ), "Concatenating when both content arrays are dicts with the same index and different types should merge without updating type" # noqa: E501 + assert AIMessageChunk( + content=[{"index": 0, "text": "I am", "type": "text_block"}] + ) + AIMessageChunk( + content=[{"index": 0, "text": " indeed.", "type": "text_block_delta"}] + ) == AIMessageChunk( + content=[{"index": 0, "text": "I am indeed.", "type": "text_block"}] + ), ( + "Concatenating when both content arrays are dicts with the same index " + "and different types should merge without updating type" + ) - assert ( - AIMessageChunk(content=[{"index": 0, "text": "I am", "type": "text_block"}]) - + AIMessageChunk(content="", response_metadata={"extra": "value"}) - == AIMessageChunk( - content=[{"index": 0, "text": "I am", "type": "text_block"}], - response_metadata={"extra": "value"}, - ) - ), "Concatenating when one content is an array and one is an empty string should not add a new item, but should concat other fields" # noqa: E501 + assert AIMessageChunk( + content=[{"index": 0, "text": "I am", "type": "text_block"}] + ) + AIMessageChunk( + content="", response_metadata={"extra": "value"} + ) == AIMessageChunk( + content=[{"index": 0, "text": "I am", "type": "text_block"}], + response_metadata={"extra": "value"}, + ), ( + "Concatenating when one content is an array and one is an empty string " + "should not add a new item, but should concat other fields" + ) def test_function_message_chunks() -> None: @@ -290,9 +299,9 @@ def test_function_message_chunks() -> None: def test_ai_message_chunks() -> None: assert AIMessageChunk(example=True, content="I am") + AIMessageChunk( example=True, content=" indeed." - ) == AIMessageChunk( - example=True, content="I am indeed." - ), "AIMessageChunk + AIMessageChunk should be a AIMessageChunk" + ) == AIMessageChunk(example=True, content="I am indeed."), ( + "AIMessageChunk + AIMessageChunk should be a AIMessageChunk" + ) with pytest.raises(ValueError): AIMessageChunk(example=True, content="I am") + AIMessageChunk( diff --git a/libs/core/tests/unit_tests/test_outputs.py b/libs/core/tests/unit_tests/test_outputs.py index a2f6e3e80e2b8..d783d1634a99d 100644 --- a/libs/core/tests/unit_tests/test_outputs.py +++ b/libs/core/tests/unit_tests/test_outputs.py @@ -5,24 +5,25 @@ def test_generation_chunk() -> None: assert GenerationChunk(text="Hello, ") + GenerationChunk( text="world!" - ) == GenerationChunk( - text="Hello, world!" 
- ), "GenerationChunk + GenerationChunk should be a GenerationChunk" - - assert ( - GenerationChunk(text="Hello, ") - + GenerationChunk(text="world!", generation_info={"foo": "bar"}) - == GenerationChunk(text="Hello, world!", generation_info={"foo": "bar"}) - ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501 - - assert ( - GenerationChunk(text="Hello, ") - + GenerationChunk(text="world!", generation_info={"foo": "bar"}) - + GenerationChunk(text="!", generation_info={"baz": "foo"}) - == GenerationChunk( - text="Hello, world!!", generation_info={"foo": "bar", "baz": "foo"} - ) - ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501 + ) == GenerationChunk(text="Hello, world!"), ( + "GenerationChunk + GenerationChunk should be a GenerationChunk" + ) + + assert GenerationChunk(text="Hello, ") + GenerationChunk( + text="world!", generation_info={"foo": "bar"} + ) == GenerationChunk(text="Hello, world!", generation_info={"foo": "bar"}), ( + "GenerationChunk + GenerationChunk should be a GenerationChunk " + "with merged generation_info" + ) + + assert GenerationChunk(text="Hello, ") + GenerationChunk( + text="world!", generation_info={"foo": "bar"} + ) + GenerationChunk(text="!", generation_info={"baz": "foo"}) == GenerationChunk( + text="Hello, world!!", generation_info={"foo": "bar", "baz": "foo"} + ), ( + "GenerationChunk + GenerationChunk should be a GenerationChunk " + "with merged generation_info" + ) def test_chat_generation_chunk() -> None: @@ -30,31 +31,32 @@ def test_chat_generation_chunk() -> None: message=HumanMessageChunk(content="Hello, ") ) + ChatGenerationChunk( message=HumanMessageChunk(content="world!") + ) == ChatGenerationChunk(message=HumanMessageChunk(content="Hello, world!")), ( + "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk" + ) + + assert ChatGenerationChunk( + message=HumanMessageChunk(content="Hello, ") + ) + ChatGenerationChunk( + message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"} + ) == ChatGenerationChunk( + message=HumanMessageChunk(content="Hello, world!"), + generation_info={"foo": "bar"}, + ), ( + "GenerationChunk + GenerationChunk should be a GenerationChunk " + "with merged generation_info" + ) + + assert ChatGenerationChunk( + message=HumanMessageChunk(content="Hello, ") + ) + ChatGenerationChunk( + message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"} + ) + ChatGenerationChunk( + message=HumanMessageChunk(content="!"), generation_info={"baz": "foo"} ) == ChatGenerationChunk( - message=HumanMessageChunk(content="Hello, world!") - ), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk" - - assert ( - ChatGenerationChunk(message=HumanMessageChunk(content="Hello, ")) - + ChatGenerationChunk( - message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"} - ) - == ChatGenerationChunk( - message=HumanMessageChunk(content="Hello, world!"), - generation_info={"foo": "bar"}, - ) - ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501 - - assert ( - ChatGenerationChunk(message=HumanMessageChunk(content="Hello, ")) - + ChatGenerationChunk( - message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"} - ) - + ChatGenerationChunk( - message=HumanMessageChunk(content="!"), generation_info={"baz": "foo"} - ) - == ChatGenerationChunk( - message=HumanMessageChunk(content="Hello, 
world!!"), - generation_info={"foo": "bar", "baz": "foo"}, - ) - ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501 + message=HumanMessageChunk(content="Hello, world!!"), + generation_info={"foo": "bar", "baz": "foo"}, + ), ( + "GenerationChunk + GenerationChunk should be a GenerationChunk " + "with merged generation_info" + ) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 03e7fcc01e5e4..d18577f55c211 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -1572,7 +1572,12 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None: } assert tool_.invoke({"x": 5, "y": "bar"}) == "bar" assert tool_.invoke( - {"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"} + { + "name": "foo", + "args": {"x": 5, "y": "bar"}, + "id": "123", + "type": "tool_call", + } ) == ToolMessage("bar", tool_call_id="123", name="foo") expected_error = ( ValidationError if not isinstance(tool_, InjectedTool) else TypeError @@ -1615,7 +1620,12 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None: } assert tool_.invoke({"x": 5, "y": "bar"}) == "bar" assert tool_.invoke( - {"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"} + { + "name": "foo", + "args": {"x": 5, "y": "bar"}, + "id": "123", + "type": "tool_call", + } ) == ToolMessage("bar", tool_call_id="123", name="foo") expected_error = ( ValidationError if not isinstance(tool_, InjectedTool) else TypeError @@ -1655,7 +1665,12 @@ def test_tool_injected_arg() -> None: } assert tool_.invoke({"x": 5, "y": "bar"}) == "bar" assert tool_.invoke( - {"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"} + { + "name": "foo", + "args": {"x": 5, "y": "bar"}, + "id": "123", + "type": "tool_call", + } ) == ToolMessage("bar", tool_call_id="123", name="foo") expected_error = ( ValidationError if not isinstance(tool_, InjectedTool) else TypeError @@ -1716,7 +1731,12 @@ def _run(self, x: int, y: str) -> Any: } assert tool_.invoke({"x": 5, "y": "bar"}) == "bar" assert tool_.invoke( - {"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"} + { + "name": "foo", + "args": {"x": 5, "y": "bar"}, + "id": "123", + "type": "tool_call", + } ) == ToolMessage("bar", tool_call_id="123", name="foo") expected_error = ( ValidationError if not isinstance(tool_, InjectedTool) else TypeError @@ -1988,6 +2008,7 @@ def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None: class ModelA(BaseModel1, Generic[A], extra="allow"): a: A + else: from pydantic import BaseModel as BaseModel2 from pydantic import ConfigDict @@ -2273,7 +2294,12 @@ def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore assert foo.invoke( - {"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"} + { + "type": "tool_call", + "args": {"x": 0}, + "name": "foo", + "id": "bar", + } ) == ToolMessage(0, tool_call_id="bar") # type: ignore with pytest.raises(ValueError): @@ -2285,7 +2311,12 @@ def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMess return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore assert foo2.invoke( - {"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"} + { + "type": "tool_call", + "args": {"x": 0}, + "name": "foo", + "id": "bar", + } ) == ToolMessage(0, tool_call_id="bar") 
# type: ignore @@ -2322,7 +2353,12 @@ def foo(x: int) -> Bar: return Bar(x=x) assert foo.invoke( - {"type": "tool_call", "args": {"x": 0}, "name": "foo", "id": "bar"} + { + "type": "tool_call", + "args": {"x": 0}, + "name": "foo", + "id": "bar", + } ) == Bar(x=0) diff --git a/libs/core/tests/unit_tests/tracers/test_memory_stream.py b/libs/core/tests/unit_tests/tracers/test_memory_stream.py index 451ab35bb678e..74541f42adb0c 100644 --- a/libs/core/tests/unit_tests/tracers/test_memory_stream.py +++ b/libs/core/tests/unit_tests/tracers/test_memory_stream.py @@ -54,9 +54,9 @@ async def consumer() -> AsyncIterator[dict]: # To verify that the producer and consumer are running in parallel, we # expect the delta_time to be smaller than the sleep delay in the producer # * # of items = 30 ms - assert ( - math.isclose(delta_time, 0, abs_tol=0.010) is True - ), f"delta_time: {delta_time}" + assert math.isclose(delta_time, 0, abs_tol=0.010) is True, ( + f"delta_time: {delta_time}" + ) async def test_queue_for_streaming_via_sync_call() -> None: @@ -107,9 +107,9 @@ async def consumer() -> AsyncIterator[dict]: # To verify that the producer and consumer are running in parallel, we # expect the delta_time to be smaller than the sleep delay in the producer # * # of items = 30 ms - assert ( - math.isclose(delta_time, 0, abs_tol=0.010) is True - ), f"delta_time: {delta_time}" + assert math.isclose(delta_time, 0, abs_tol=0.010) is True, ( + f"delta_time: {delta_time}" + ) def test_send_to_closed_stream() -> None: diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index bf1a4f56337fe..03803e8657d15 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -13,8 +13,8 @@ from typing import TypedDict as TypingTypedDict import pytest -from pydantic import BaseModel as BaseModelV2Maybe # pydantic: ignore -from pydantic import Field as FieldV2Maybe # pydantic: ignore +from pydantic import BaseModel as BaseModelV2Maybe # pydantic: ignore +from pydantic import Field as FieldV2Maybe # pydantic: ignore from typing_extensions import ( TypedDict as ExtensionsTypedDict, ) diff --git a/libs/core/tests/unit_tests/utils/test_usage.py b/libs/core/tests/unit_tests/utils/test_usage.py index 0d845d0078920..099917219b48a 100644 --- a/libs/core/tests/unit_tests/utils/test_usage.py +++ b/libs/core/tests/unit_tests/utils/test_usage.py @@ -1,3 +1,5 @@ +import operator + import pytest from langchain_core.utils.usage import _dict_int_op @@ -6,7 +8,7 @@ def test_dict_int_op_add() -> None: left = {"a": 1, "b": 2} right = {"b": 3, "c": 4} - result = _dict_int_op(left, right, lambda x, y: x + y) + result = _dict_int_op(left, right, operator.add) assert result == {"a": 1, "b": 5, "c": 4} @@ -20,7 +22,7 @@ def test_dict_int_op_subtract() -> None: def test_dict_int_op_nested() -> None: left = {"a": 1, "b": {"c": 2, "d": 3}} right = {"a": 2, "b": {"c": 1, "e": 4}} - result = _dict_int_op(left, right, lambda x, y: x + y) + result = _dict_int_op(left, right, operator.add) assert result == {"a": 3, "b": {"c": 3, "d": 3, "e": 4}} @@ -28,11 +30,11 @@ def test_dict_int_op_max_depth_exceeded() -> None: left = {"a": {"b": {"c": 1}}} right = {"a": {"b": {"c": 2}}} with pytest.raises(ValueError): - _dict_int_op(left, right, lambda x, y: x + y, max_depth=2) + _dict_int_op(left, right, operator.add, max_depth=2) def test_dict_int_op_invalid_types() -> None: left = {"a": 1, "b": "string"} right = 
{"a": 2, "b": 3} with pytest.raises(ValueError): - _dict_int_op(left, right, lambda x, y: x + y) + _dict_int_op(left, right, operator.add)