Skip to content

Commit

Permalink
langchain: Bump ruff version to 0.9
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jan 14, 2025
1 parent d9b856a commit 4d0d1d4
Show file tree
Hide file tree
Showing 27 changed files with 315 additions and 248 deletions.
3 changes: 1 addition & 2 deletions libs/langchain/langchain/agents/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def initialize_agent(
pass
else:
raise ValueError(
"Somehow both `agent` and `agent_path` are None, "
"this should never happen."
"Somehow both `agent` and `agent_path` are None, this should never happen."
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj,
Expand Down
3 changes: 1 addition & 2 deletions libs/langchain/langchain/agents/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ def load_agent_from_config(
if load_from_tools:
if llm is None:
raise ValueError(
"If `load_from_llm_and_tools` is set to True, "
"then LLM must be provided"
"If `load_from_llm_and_tools` is set to True, then LLM must be provided"
)
if tools is None:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain/langchain/callbacks/tracers/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ def on_text(
except TracerException:
crumbs_str = ""
self.function_callback(
f'{get_colored_text("[text]", color="blue")}'
f' {get_bolded_text(f"{crumbs_str}New text:")}\n{text}'
f"{get_colored_text('[text]', color='blue')}"
f" {get_bolded_text(f'{crumbs_str}New text:')}\n{text}"
)
1 change: 1 addition & 0 deletions libs/langchain/langchain/chains/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def from_llm_and_api_docs(
@property
def _chain_type(self) -> str:
return "api_chain"

except ImportError:

class APIChain: # type: ignore[no-redef]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _call(
response = (
f"{api_response.status_code}: {api_response.reason}"
+ f"\nFor {name} "
+ f"Called with args: {args.get('params','')}"
+ f"Called with args: {args.get('params', '')}"
)
else:
try:
Expand Down
5 changes: 4 additions & 1 deletion libs/langchain/langchain/chains/openai_tools/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def create_extraction_chain_pydantic(
if not isinstance(pydantic_schemas, list):
pydantic_schemas = [pydantic_schemas]
prompt = ChatPromptTemplate.from_messages(
[("system", system_message), ("user", "{input}")]
[
("system", system_message),
("user", "{input}"),
]
)
functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas]
tools = [{"type": "function", "function": d} for d in functions]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
"If the context isn't useful, return the original answer."
)
CHAT_REFINE_PROMPT = ChatPromptTemplate.from_messages(
[("human", "{question}"), ("ai", "{existing_answer}"), ("human", refine_template)]
[
("human", "{question}"),
("ai", "{existing_answer}"),
("human", refine_template),
]
)
REFINE_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=DEFAULT_REFINE_PROMPT,
Expand All @@ -60,7 +64,10 @@
"answer any questions"
)
CHAT_QUESTION_PROMPT = ChatPromptTemplate.from_messages(
[("system", chat_qa_prompt_template), ("human", "{question}")]
[
("system", chat_qa_prompt_template),
("human", "{question}"),
]
)
QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=DEFAULT_TEXT_QA_PROMPT,
Expand Down
6 changes: 4 additions & 2 deletions libs/langchain/langchain/chains/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ def _call(
_input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
for i, chain in enumerate(self.chains):
_input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}"))
_input = chain.run(
_input, callbacks=_run_manager.get_child(f"step_{i + 1}")
)
if self.strip_outputs:
_input = _input.strip()
_run_manager.on_text(
Expand All @@ -196,7 +198,7 @@ async def _acall(
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
for i, chain in enumerate(self.chains):
_input = await chain.arun(
_input, callbacks=_run_manager.get_child(f"step_{i+1}")
_input, callbacks=_run_manager.get_child(f"step_{i + 1}")
)
if self.strip_outputs:
_input = _input.strip()
Expand Down
6 changes: 5 additions & 1 deletion libs/langchain/langchain/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,11 @@ def with_config(
queued_declarative_operations = list(self._queued_declarative_operations)
if remaining_config:
queued_declarative_operations.append(
("with_config", (), {"config": remaining_config})
(
"with_config",
(),
{"config": remaining_config},
)
)
return _ConfigurableModel(
default_config={**self._default_config, **model_params},
Expand Down
3 changes: 1 addition & 2 deletions libs/langchain/langchain/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ def init_embeddings(
if not model:
providers = _SUPPORTED_PROVIDERS.keys()
raise ValueError(
"Must specify model name. "
f"Supported providers are: {', '.join(providers)}"
f"Must specify model name. Supported providers are: {', '.join(providers)}"
)

provider, model_name = _infer_model_and_provider(model, provider=provider)
Expand Down
15 changes: 12 additions & 3 deletions libs/langchain/langchain/evaluation/embedding_distance/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,10 @@ async def _acall(
Dict[str, Any]: The computed score.
"""
embedded = await self.embeddings.aembed_documents(
[inputs["prediction"], inputs["reference"]]
[
inputs["prediction"],
inputs["reference"],
]
)
vectors = np.array(embedded)
score = self._compute_score(vectors)
Expand Down Expand Up @@ -427,7 +430,10 @@ def _call(
"""
vectors = np.array(
self.embeddings.embed_documents(
[inputs["prediction"], inputs["prediction_b"]]
[
inputs["prediction"],
inputs["prediction_b"],
]
)
)
score = self._compute_score(vectors)
Expand All @@ -449,7 +455,10 @@ async def _acall(
Dict[str, Any]: The computed score.
"""
embedded = await self.embeddings.aembed_documents(
[inputs["prediction"], inputs["prediction_b"]]
[
inputs["prediction"],
inputs["prediction_b"],
]
)
vectors = np.array(embedded)
score = self._compute_score(vectors)
Expand Down
10 changes: 8 additions & 2 deletions libs/langchain/langchain/memory/chat_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_messages(
[HumanMessage(content=input_str), AIMessage(content=output_str)]
[
HumanMessage(content=input_str),
AIMessage(content=output_str),
]
)

async def asave_context(
Expand All @@ -80,7 +83,10 @@ async def asave_context(
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
await self.chat_memory.aadd_messages(
[HumanMessage(content=input_str), AIMessage(content=output_str)]
[
HumanMessage(content=input_str),
AIMessage(content=output_str),
]
)

def clear(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ def rerank(
result_dicts = []
for res in results:
result_dicts.append(
{"index": res.index, "relevance_score": res.relevance_score}
{
"index": res.index,
"relevance_score": res.relevance_score,
}
)
return result_dicts

Expand Down
5 changes: 3 additions & 2 deletions libs/langchain/langchain/retrievers/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def rank_fusion(
retriever.invoke(
query,
patch_config(
config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
config, callbacks=run_manager.get_child(tag=f"retriever_{i + 1}")
),
)
for i, retriever in enumerate(self.retrievers)
Expand Down Expand Up @@ -265,7 +265,8 @@ async def arank_fusion(
retriever.ainvoke(
query,
patch_config(
config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
config,
callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
),
)
for i, retriever in enumerate(self.retrievers)
Expand Down
3 changes: 1 addition & 2 deletions libs/langchain/langchain/smith/evaluation/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,7 @@ def _get_prompt(inputs: Dict[str, Any]) -> str:
if "prompt" in inputs:
if not isinstance(inputs["prompt"], str):
raise InputFormatError(
"Expected string for 'prompt', got"
f" {type(inputs['prompt']).__name__}"
f"Expected string for 'prompt', got {type(inputs['prompt']).__name__}"
)
prompts = [inputs["prompt"]]
elif "prompts" in inputs:
Expand Down
Loading

0 comments on commit 4d0d1d4

Please sign in to comment.