From f6f7457d5f01bc0488e6048f3b3f061209046fa3 Mon Sep 17 00:00:00 2001 From: HavenDV Date: Sat, 1 Jun 2024 14:21:20 +0400 Subject: [PATCH] feat: Added ability to pass settings inside LLM chain. --- src/Core/src/Chains/Chain.cs | 6 ++++-- .../StackableChains/Agents/ReActAgentExecutorChain.cs | 5 ++++- src/Core/src/Chains/StackableChains/LLMChain.cs | 7 +++++-- src/Meta/test/WikiTests.cs | 2 +- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/Core/src/Chains/Chain.cs b/src/Core/src/Chains/Chain.cs index 42a6b124..e9dd12b5 100644 --- a/src/Core/src/Chains/Chain.cs +++ b/src/Core/src/Chains/Chain.cs @@ -75,13 +75,15 @@ public static DoChain Do( /// /// /// + /// /// public static LLMChain LLM( IChatModel llm, string inputKey = "text", - string outputKey = "text") + string outputKey = "text", + ChatSettings? settings = null) { - return new LLMChain(llm, inputKey, outputKey); + return new LLMChain(llm, inputKey, outputKey, settings); } /// diff --git a/src/Core/src/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs b/src/Core/src/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs index 435481b3..018d44f4 100644 --- a/src/Core/src/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs +++ b/src/Core/src/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs @@ -108,7 +108,10 @@ private void InitializeChain() | Set(toolNames, "tool_names") | LoadMemory(_conversationBufferMemory, outputKey: "history") | Template(_reActPrompt) - | Chain.LLM(_model).UseCache(_useCache) + | Chain.LLM(_model, settings: new ChatSettings + { + StopSequences = ["Observation", "[END]"], + }).UseCache(_useCache) | UpdateMemory(_conversationBufferMemory, requestKey: "input", responseKey: "text") | ReActParser(inputKey: "text", outputKey: ReActAnswer); diff --git a/src/Core/src/Chains/StackableChains/LLMChain.cs b/src/Core/src/Chains/StackableChains/LLMChain.cs index 9c9ad670..0ae267d4 100644 --- a/src/Core/src/Chains/StackableChains/LLMChain.cs +++ b/src/Core/src/Chains/StackableChains/LLMChain.cs @@ -10,6 +10,7 @@ public class LLMChain : BaseStackableChain { private readonly IChatModel _llm; private bool _useCache; + private ChatSettings _settings; private const string CACHE_DIR = "cache"; @@ -17,12 +18,14 @@ public class LLMChain : BaseStackableChain public LLMChain( IChatModel llm, string inputKey = "prompt", - string outputKey = "text" + string outputKey = "text", + ChatSettings? settings = null ) { InputKeys = new[] { inputKey }; OutputKeys = new[] { outputKey }; _llm = llm; + _settings = settings ?? new ChatSettings(); } string? GetCachedAnswer(string prompt) @@ -63,7 +66,7 @@ protected override async Task InternalCallAsync( } } - var response = await _llm.GenerateAsync(prompt, cancellationToken: cancellationToken).ConfigureAwait(false); + var response = await _llm.GenerateAsync(prompt, settings: _settings, cancellationToken: cancellationToken).ConfigureAwait(false); responseContent = response.Messages.Last().Content; if (_useCache) SaveCachedAnswer(prompt, responseContent); diff --git a/src/Meta/test/WikiTests.cs b/src/Meta/test/WikiTests.cs index f73a3cec..9bb2a3ea 100644 --- a/src/Meta/test/WikiTests.cs +++ b/src/Meta/test/WikiTests.cs @@ -50,7 +50,7 @@ public async Task AgentWithOllamaReact() var apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new InvalidOperationException("OpenAI API key is not set"); - var llm = new Gpt35TurboModel(apiKey); + var llm = new Gpt35TurboModel(apiKey).UseConsoleForDebug(); // create a google search var searchApiKey =