From 49d6a13b016c35e6557589087b6544231e0a8798 Mon Sep 17 00:00:00 2001 From: Tom Bruyneel Date: Wed, 6 Mar 2024 22:24:20 +0100 Subject: [PATCH] add function calling --- src/ConversationalSearchPlatform.sln | 5 +- .../ConversationalSearchEndpoints.cs | 7 +- .../Preferences/ChatComponent/Chat.razor | 2 +- ...ersationalSearchPlatform.BackOffice.csproj | 4 + .../Extensions/ChatBuilderExtensions.cs | 4 +- .../Extensions/ChatChoiceExtensions.cs | 4 + .../Extensions/ChatResultExtensions.cs | 25 +- .../Jobs/WebsitePageIndexingJob.cs | 4 +- .../Implementations/ConversationService.cs | 351 ++++++++++++++++-- .../Services/Models/ConversationHistory.cs | 34 +- .../Services/Models/HoldConversation.cs | 6 +- .../Services/Models/StreamResult.cs | 2 + .../Weaviate/Queries/GetByPromptFiltered.cs | 10 +- src/nuget.config | 8 + 14 files changed, 373 insertions(+), 93 deletions(-) create mode 100644 src/nuget.config diff --git a/src/ConversationalSearchPlatform.sln b/src/ConversationalSearchPlatform.sln index e4d76a4..07fd622 100644 --- a/src/ConversationalSearchPlatform.sln +++ b/src/ConversationalSearchPlatform.sln @@ -4,6 +4,9 @@ Microsoft Visual Studio Solution File, Format Version 12.00 VisualStudioVersion = 17.8.34525.116 MinimumVisualStudioVersion = 10.0.40219.1 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{33CCB6E3-745D-42A7-A75B-C9B30AD359E4}" + ProjectSection(SolutionItems) = preProject + nuget.config = nuget.config + EndProjectSection EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ConversationalSearchPlatform.BackOffice", "backoffice\ConversationalSearchPlatform.BackOffice\ConversationalSearchPlatform.BackOffice.csproj", "{72303287-C899-41CE-95A2-09534F6B3066}" EndProject @@ -11,7 +14,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ConversationalSearchPlatfor EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ConversationalSearchPlatform.Widget", "widget\ConversationalSearchPlatform.Widget\ConversationalSearchPlatform.Widget.csproj", "{687A9754-5AAB-47F8-AA76-9A959C1D613E}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Jint.Fetch", "Jint.Fetch\Jint.Fetch.csproj", "{88306706-982B-44EB-8678-96E709FB17C6}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Jint.Fetch", "Jint.Fetch\Jint.Fetch.csproj", "{88306706-982B-44EB-8678-96E709FB17C6}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution diff --git a/src/backoffice/ConversationalSearchPlatform.BackOffice/Api/Conversation/ConversationalSearchEndpoints.cs b/src/backoffice/ConversationalSearchPlatform.BackOffice/Api/Conversation/ConversationalSearchEndpoints.cs index 1712077..b0cf87b 100644 --- a/src/backoffice/ConversationalSearchPlatform.BackOffice/Api/Conversation/ConversationalSearchEndpoints.cs +++ b/src/backoffice/ConversationalSearchPlatform.BackOffice/Api/Conversation/ConversationalSearchEndpoints.cs @@ -15,6 +15,7 @@ using Finbuckle.MultiTenant; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Mvc; +using Microsoft.Identity.Client; using Swashbuckle.AspNetCore.Filters; namespace ConversationalSearchPlatform.BackOffice.Api.Conversation; @@ -142,7 +143,7 @@ CancellationToken cancellationToken var tenantId = httpContext.GetTenantHeader(); - var holdConversation = new HoldConversation(conversationId, tenantId, request.Prompt, request.Context, request.Debug, (Language)request.Language); + var holdConversation = new HoldConversation(conversationId, tenantId, new Rystem.OpenAi.Chat.ChatMessage() { Content = request.Prompt, Role = Rystem.OpenAi.Chat.ChatRole.User }, request.Context, request.Debug, (Language)request.Language); await foreach (var crr in conversationService .ConverseStreamingAsync( @@ -187,7 +188,7 @@ private static async Task HandleHoldConversation IConversationService conversationService, CancellationToken cancellationToken) { - var holdConversation = new HoldConversation(conversationId, tenantId, request.Prompt, request.Context, request.Debug, (Language)request.Language); + var holdConversation = new HoldConversation(conversationId, tenantId, new Rystem.OpenAi.Chat.ChatMessage() { Content = request.Prompt, Role = Rystem.OpenAi.Chat.ChatRole.User }, request.Context, request.Debug, (Language)request.Language); var response = await conversationService.ConverseAsync(holdConversation, cancellationToken); return MapToApiResponse(response); @@ -307,7 +308,7 @@ private static async Task HandleHoldConversationWebSocketMessage( } var tenantId = tenant.Id!; - var holdConversation = new HoldConversation(request.ConversationId.Value, tenantId, request.Prompt, request.Context, request.Debug, (Language)request.Language); + var holdConversation = new HoldConversation(request.ConversationId.Value, tenantId, new Rystem.OpenAi.Chat.ChatMessage() { Content = request.Prompt, Role = Rystem.OpenAi.Chat.ChatRole.User }, request.Context, request.Debug, (Language)request.Language); await foreach (var crr in conversationService .ConverseStreamingAsync( diff --git a/src/backoffice/ConversationalSearchPlatform.BackOffice/Components/Pages/Preferences/ChatComponent/Chat.razor b/src/backoffice/ConversationalSearchPlatform.BackOffice/Components/Pages/Preferences/ChatComponent/Chat.razor index 42cfeec..63677b6 100644 --- a/src/backoffice/ConversationalSearchPlatform.BackOffice/Components/Pages/Preferences/ChatComponent/Chat.razor +++ b/src/backoffice/ConversationalSearchPlatform.BackOffice/Components/Pages/Preferences/ChatComponent/Chat.razor @@ -192,7 +192,7 @@ try { Loading = true; - conversationResult = await ConversationService.ConverseAsync(new HoldConversation(CurrentConversationId!.Value, TenantInfo.Id!, prompt, cleanedContextVariables, Debug, Language)); + conversationResult = await ConversationService.ConverseAsync(new HoldConversation(CurrentConversationId!.Value, TenantInfo.Id!, new Rystem.OpenAi.Chat.ChatMessage() { Content = prompt, Role = Rystem.OpenAi.Chat.ChatRole.User }, cleanedContextVariables, Debug, Language)); } finally { diff --git a/src/backoffice/ConversationalSearchPlatform.BackOffice/ConversationalSearchPlatform.BackOffice.csproj b/src/backoffice/ConversationalSearchPlatform.BackOffice/ConversationalSearchPlatform.BackOffice.csproj index 5809b0f..3f52729 100644 --- a/src/backoffice/ConversationalSearchPlatform.BackOffice/ConversationalSearchPlatform.BackOffice.csproj +++ b/src/backoffice/ConversationalSearchPlatform.BackOffice/ConversationalSearchPlatform.BackOffice.csproj @@ -75,4 +75,8 @@ <_ContentIncludedByDefault Remove="Components\TryItOut\TryItOut.razor" /> + + + + diff --git a/src/backoffice/ConversationalSearchPlatform.BackOffice/Extensions/ChatBuilderExtensions.cs b/src/backoffice/ConversationalSearchPlatform.BackOffice/Extensions/ChatBuilderExtensions.cs index 8a61787..ac91a35 100644 --- a/src/backoffice/ConversationalSearchPlatform.BackOffice/Extensions/ChatBuilderExtensions.cs +++ b/src/backoffice/ConversationalSearchPlatform.BackOffice/Extensions/ChatBuilderExtensions.cs @@ -9,8 +9,8 @@ public static ChatRequestBuilder AddPreviousMessages(this ChatRequestBuilder cha { foreach (var conversation in previousMessages) { - chatRequestBuilder.AddUserMessage(conversation.Prompt); - chatRequestBuilder.AddAssistantMessage(conversation.Response); + chatRequestBuilder.AddMessage(conversation.Prompt); + chatRequestBuilder.AddMessage(conversation.Response); } return chatRequestBuilder; diff --git a/src/backoffice/ConversationalSearchPlatform.BackOffice/Extensions/ChatChoiceExtensions.cs b/src/backoffice/ConversationalSearchPlatform.BackOffice/Extensions/ChatChoiceExtensions.cs index 4029f3e..feafa04 100644 --- a/src/backoffice/ConversationalSearchPlatform.BackOffice/Extensions/ChatChoiceExtensions.cs +++ b/src/backoffice/ConversationalSearchPlatform.BackOffice/Extensions/ChatChoiceExtensions.cs @@ -19,6 +19,10 @@ public static bool IsAnswerCompleted(this ChatChoice chunk, ILogger logger) completed = true; break; + case { FinishReason: "function_call" }: + completed = true; + break; + case { FinishReason: "length" }: completed = true; logger.LogDebug("Stopped due to length"); diff --git a/src/backoffice/ConversationalSearchPlatform.BackOffice/Extensions/ChatResultExtensions.cs b/src/backoffice/ConversationalSearchPlatform.BackOffice/Extensions/ChatResultExtensions.cs index e16459a..ae209fd 100644 --- a/src/backoffice/ConversationalSearchPlatform.BackOffice/Extensions/ChatResultExtensions.cs +++ b/src/backoffice/ConversationalSearchPlatform.BackOffice/Extensions/ChatResultExtensions.cs @@ -4,20 +4,23 @@ namespace ConversationalSearchPlatform.BackOffice.Extensions; public static class ChatResultExtensions { - public static string CombineAnswers(this ChatResult chatResult) + public static ChatMessage GetFirstAnswer(this ChatResult chatResult) { - var answers = chatResult + var answer = chatResult .Choices? - .Select(choice => choice.Message) - .Where(message => message != null) - .Select(message => message!) - .Where(msg => msg.Role == ChatRole.Assistant) - .Select(message => message.Content) - .Where(content => content != null) ?? - Enumerable.Empty(); + .FirstOrDefault()? + .Message; - return string.Join(Environment.NewLine, answers) - .ReplaceLineEndings(); + if (answer == null) + { + answer = new ChatMessage() + { + Role = ChatRole.Assistant, + Content = string.Empty, + }; + } + + return answer; } public static string CombineStreamAnswer(this StreamingChatResult chatResult) diff --git a/src/backoffice/ConversationalSearchPlatform.BackOffice/Jobs/WebsitePageIndexingJob.cs b/src/backoffice/ConversationalSearchPlatform.BackOffice/Jobs/WebsitePageIndexingJob.cs index 16e5283..1f864f2 100644 --- a/src/backoffice/ConversationalSearchPlatform.BackOffice/Jobs/WebsitePageIndexingJob.cs +++ b/src/backoffice/ConversationalSearchPlatform.BackOffice/Jobs/WebsitePageIndexingJob.cs @@ -300,8 +300,8 @@ private async Task CreateEntry(ApplicationDbContext db, string tenantId, Website { foreach (var node in nodes) { - var cleanText = Regex.Replace(node.InnerText, @"\s+", " ").Trim(); - cleanText = WebUtility.HtmlDecode(cleanText); + //var cleanText = Regex.Replace(node.InnerText, @"\s+", " ").Trim(); + var cleanText = WebUtility.HtmlDecode(node.InnerText); if (!string.IsNullOrEmpty(cleanText)) { diff --git a/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Implementations/ConversationService.cs b/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Implementations/ConversationService.cs index 0ce2659..4828df3 100644 --- a/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Implementations/ConversationService.cs +++ b/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Implementations/ConversationService.cs @@ -1,5 +1,9 @@ +using System.Diagnostics.Eventing.Reader; +using System.Dynamic; using System.Runtime.CompilerServices; using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; using System.Text.RegularExpressions; using ConversationalSearchPlatform.BackOffice.Data.Entities; using ConversationalSearchPlatform.BackOffice.Exceptions; @@ -11,7 +15,11 @@ using ConversationalSearchPlatform.BackOffice.Tenants; using Finbuckle.MultiTenant; using GraphQL; +using Jint; +using Jint.Fetch; using Microsoft.Extensions.Caching.Memory; +using Newtonsoft.Json.Linq; +using Polly; using Rystem.OpenAi; using Rystem.OpenAi.Chat; using Language = ConversationalSearchPlatform.BackOffice.Services.Models.Language; @@ -83,13 +91,13 @@ public async Task ConverseAsync(HoldConversation h var (chatBuilder, textReferences, imageReferences) = await BuildChatAsync(holdConversation, conversationHistory, cancellationToken); - string answer = string.Empty; + ChatMessage answer = new ChatMessage() { Role = ChatRole.Assistant }; // don't give an answer when no references are found if (textReferences.Count == 0 && imageReferences.Count == 0) { shouldEndConversation = true; - answer = "I'm sorry, but I couldn't find relevant information in my database. Try asking a new question, please."; + answer.Content = "I'm sorry, but I couldn't find relevant information in my database. Try asking a new question, please."; } else { @@ -101,11 +109,47 @@ public async Task ConverseAsync(HoldConversation h conversationHistory.Model ); - answer = chatResult.Result.CombineAnswers(); + answer = chatResult.Result.GetFirstAnswer(); + if (answer.Function != null) + { + // call the function + conversationHistory.AppendToConversation(holdConversation.UserPrompt, answer); + + // add the function reply + var functionReply = await Task.Run(() => { + + var functionReplyTask = CallFunction(answer.Function.Name, answer.Function.Arguments); + + return functionReplyTask; + }).ConfigureAwait(false); + + // create new chatbuilder request with product reference + ChatMessage functionMessage = new ChatMessage() + { + Role = ChatRole.Function, + Content = functionReply, + Name = answer.Function.Name, + }; + + holdConversation.UserPrompt = functionMessage; + (chatBuilder, textReferences, imageReferences) = await BuildChatAsync(holdConversation, conversationHistory, cancellationToken); + chatResult = await chatBuilder.ExecuteAndCalculateCostAsync(false, cancellationToken); + _telemetryService.RegisterGPTUsage( + holdConversation.ConversationId, + holdConversation.TenantId, + chatResult.Result.Usage ?? throw new InvalidOperationException("No usage was passed in after executing an OpenAI call"), + conversationHistory.Model + ); + answer = chatResult.Result.GetFirstAnswer(); + conversationHistory.AppendToConversation(functionMessage, answer); + } + else + { + conversationHistory.AppendToConversation(holdConversation.UserPrompt, answer); + } } conversationHistory.HasEnded = shouldEndConversation; - conversationHistory.AppendToConversation(holdConversation.UserPrompt, answer); conversationHistory.SaveConversationHistory(_conversationsCache, cacheKey); var conversationReferencedResult = ParseAnswerWithReferences(holdConversation, conversationHistory, textReferences, imageReferences, shouldEndConversation); @@ -142,6 +186,7 @@ public async IAsyncEnumerable ConverseStreamingAsy } conversationHistory.HasEnded = shouldEndConversation; + ChatMessage composedMessage = new ChatMessage(); await foreach (var entry in chatBuilder .ExecuteAsStreamAndCalculateCostAsync(false, cancellationToken) @@ -153,6 +198,7 @@ public async IAsyncEnumerable ConverseStreamingAsy textReferences, imageReferences, shouldEndConversation, + composedMessage, cancellationToken) ) .Where(result => result is { IsOk: true, Value: not null }) @@ -161,6 +207,48 @@ public async IAsyncEnumerable ConverseStreamingAsy { yield return entry; } + + if (composedMessage.Function != null) + { + // add the function reply + var functionReply = await Task.Run(() => { + + var functionReplyTask = CallFunction(composedMessage.Function.Name, composedMessage.Function.Arguments); + + return functionReplyTask; + }).ConfigureAwait(false); + + // create new chatbuilder request with product reference + ChatMessage functionMessage = new ChatMessage() + { + Role = ChatRole.Function, + Content = functionReply, + Name = composedMessage.Function.Name, + }; + + holdConversation.UserPrompt = functionMessage; + (chatBuilder, textReferences, imageReferences) = await BuildChatAsync(holdConversation, conversationHistory, cancellationToken); + + await foreach (var entry in chatBuilder + .ExecuteAsStreamAndCalculateCostAsync(false, cancellationToken) + .SelectAwait(streamEntry => ProcessStreamedChatChunk( + holdConversation, + streamEntry, + cacheKey, + conversationHistory, + textReferences, + imageReferences, + shouldEndConversation, + composedMessage, + cancellationToken) + ) + .Where(result => result is { IsOk: true, Value: not null }) + .Select(result => result.Value!) + .WithCancellation(cancellationToken)) + { + yield return entry; + } + } } public async Task GetConversationContext(GetConversationContext getConversationContext) @@ -176,6 +264,96 @@ public async Task GetConversationContext(GetConversationCon return new ConversationContext(tags.ToList()); } + private string CallFunction(string? functionName, string? arguments) + { + dynamic argumentsObj = JObject.Parse(arguments); + + var engine = new Engine(); + engine.SetValue("log", new Action(Console.WriteLine)) + .SetValue("fetch", new Func>((uri, options) => FetchClass.Fetch(uri, FetchClass.ExpandoToOptionsObject(options)))) + .SetValue("__result", 0) + .SetValue("genderCtx", argumentsObj.gender) + .SetValue("mobilityCtx", argumentsObj.mobility) + .SetValue("incontinence_levelCtx", argumentsObj.incontinence_level); + + try + { + engine.Execute( + """ +(async () => { + async function GetRecommendedProducts(gender, mobility, incontinence_level) { + log(`Call product selector with: ${gender}, ${mobility}, ${incontinence_level}`); + // create body + var genderData; + if (gender == "male") { + genderData = {"data-title":"men_inco,women_men_inco", "data-details":"Men", "data-description":"men"} + } else { + genderData = {"data-title":"women_inco,women_men_inco", "data-details":"Women", "data-description":"women"} + } + + var mobilityData; + switch(mobility) { + case "mobile": + mobilityData = {"data-title":"FullMobility", "data-details":"fullMobilityScore", "data-description":"Able"}; + break; + case "needs_help_toilet": + mobilityData = {"data-title":"NeedAssistance", "data-details":"needAssistanceScore", "data-description":"Needs"}; + break; + case "bedridden": + mobilityData = {"data-title":"Bedridden", "data-details":"bedriddenScore", "data-description":"Unable"}; + break; + } + + var absorptionData; + switch(incontinence_level) { + case "small": + absorptionData = {"data-title":"0_5-drop-inco,1-drop-inco", "data-details":"0_5-drop-inco,1-drop-inco", "data-description":"VeryLight"}; + break; + case "light": + absorptionData = {"data-title":"1_5-drop-inco,2-drop-inco,2_5-drop-inco", "data-details":"1_5-drop-inco,2_5-drop-inco,2-drop-inco", "data-description":"Light"}; + break; + case "moderate": + absorptionData = {"data-title":"3-drop-inco,3_5-drop-inco,4-drop-inco,4_5-drop-inco,5-drop-inco", "data-details":"3_5-drop-inco,3-drop-inco,4_5-drop-inco,4-drop-inco,5-drop-inco", "data-description":"Medium"}; + break; + case "heavy": + absorptionData = {"data-title":"5_5-drop-inco,6-drop-inco,6_5-drop-inco,7-drop-inco", "data-details":"5_5-drop-inco,6_5-drop-inco,6-drop-inco,7-drop-inco", "data-description":"Heavy"}; + break; + case "very_heavy": + absorptionData = {"data-title":"7_5-drop-inco,8-drop-inco,8_5-drop-inco,9-drop-inco", "data-details":"7_5-drop-inco,8_5-drop-inco,8-drop-inco,9-drop-inco", "data-description":"VeryHeavy"}; + break; + } + + var body = JSON.stringify({ + User: {"data-title":"Homecare", "data-details":"Pharmacist", "data-description":"Pharmacist"}, + Gender: genderData, + Mobility: mobilityData, + Absorption: absorptionData + }); + var response = await fetch("https://www.tena.co.uk/professionals/api/Services/ProductFinder/GetProductSelectorResult", { Method: "POST", Body: { inputData: body } }); + var content = await response.json(); + + var considerationProducts = [content.recommendedProduct.recommendedProduct, ...content.considerationProducts.products]; + + return considerationProducts.map((p) => { + return { productName: p.productName }; + }); + }; + + __result = await GetRecommendedProducts(genderCtx, mobilityCtx, incontinence_levelCtx); +})(); +"""); + } + catch (Exception e) + { + Console.WriteLine(e.Message); + } + + + string jsonString = JsonSerializer.Serialize(engine.GetValue("__result").ToObject()); + + return jsonString; + } + private ValueTask> ProcessStreamedChatChunk(HoldConversation holdConversation, CostResult streamEntry, @@ -184,6 +362,7 @@ private ValueTask> ProcessStreamedCha List textReferences, List imageReferences, bool shouldEndConversation, + ChatMessage composedMessage, CancellationToken cancellationToken) { var chunk = streamEntry.Result.LastChunk.Choices? @@ -209,20 +388,44 @@ private ValueTask> ProcessStreamedCha conversationHistory.StreamingResponseChunks.Add(chunkedAnswer); conversationHistory.IsLastChunk = true; - result = ParseAnswerWithReferences(holdConversation, conversationHistory, textReferences, imageReferences, shouldEndConversation); + // function call or not? + var composedResult = streamEntry.Result.Composed; + var message = streamEntry.Result.Composed.GetFirstAnswer(); + if (message.Function == null) + { + result = ParseAnswerWithReferences(holdConversation, conversationHistory, textReferences, imageReferences, shouldEndConversation); + + conversationHistory.StreamingResponseChunks.Clear(); + conversationHistory.IsStreaming = false; + + conversationHistory.AppendToConversation(holdConversation.UserPrompt, message); + } + else + { + // call the function + conversationHistory.AppendToConversation(holdConversation.UserPrompt, message); + } - conversationHistory.StreamingResponseChunks.Clear(); - conversationHistory.IsStreaming = false; - var answer = streamEntry.Result.CombineStreamAnswer(); - conversationHistory.AppendToConversation(holdConversation.UserPrompt, answer); conversationHistory.SaveConversationHistory(_conversationsCache, cacheKey); + + composedMessage.Content = message.Content; + composedMessage.Function = message.Function; + composedMessage.Role = message.Role; + composedMessage.Name = message.Name; } else { - conversationHistory.IsLastChunk = false; - var chunkedAnswer = chunk.Delta?.Content ?? string.Empty; - conversationHistory.StreamingResponseChunks.Add(chunkedAnswer); - result = ParseAnswerWithReferences(holdConversation, conversationHistory, textReferences, imageReferences, shouldEndConversation); + if (!string.IsNullOrEmpty(chunk.Delta?.Content)) + { + conversationHistory.IsLastChunk = false; + var chunkedAnswer = chunk.Delta?.Content ?? string.Empty; + conversationHistory.StreamingResponseChunks.Add(chunkedAnswer); + result = ParseAnswerWithReferences(holdConversation, conversationHistory, textReferences, imageReferences, shouldEndConversation); + } + else + { + return ValueTask.FromResult(StreamResult.Skip("NoChunkContent")); + } } if (conversationHistory.DebugEnabled) @@ -253,7 +456,7 @@ private ValueTask> ProcessStreamedCha conversationHistory.InitializeDebugInformation(); } - holdConversation.UserPrompt = holdConversation.UserPrompt.Trim(); + holdConversation.UserPrompt.Content = holdConversation.UserPrompt.Content?.Trim() ?? ""; var promptBuilder = new PromptBuilder(new StringBuilder((await ResourceHelper.GetEmbeddedResourceTextAsync(ResourceHelper.BasePromptFile)).Trim())) .ReplaceTenantPrompt(GetTenantPrompt(tenant)) @@ -264,12 +467,16 @@ private ValueTask> ProcessStreamedCha // add history of conversation to vector context foreach (var promptResponse in conversationHistory.PromptResponses.TakeLast(2)) { - vectorPrompt.AppendLine(promptResponse.Prompt); - vectorPrompt.AppendLine(promptResponse.Response); + vectorPrompt.AppendLine(promptResponse.Prompt.Content); + + if (!string.IsNullOrEmpty(promptResponse.Response.Content)) + { + vectorPrompt.AppendLine(promptResponse.Response.Content); + } } // add last user prompt - vectorPrompt.AppendLine(holdConversation.UserPrompt); + vectorPrompt.AppendLine(holdConversation.UserPrompt.Content); // convert to keywords var keywords = await _keywordExtractorService.ExtractKeywordAsync(vectorPrompt.ToString()); @@ -309,7 +516,7 @@ private ValueTask> ProcessStreamedCha ragClass.Sources.Add(new RAGSource() { - ReferenceId = $"{index}", + ReferenceId = $"{index}", Properties = properties, }); @@ -325,26 +532,56 @@ private ValueTask> ProcessStreamedCha Guid TENA_ID = Guid.Parse("CCFA9314-ABE6-403A-9E21-2B31D95A5258"); - /*if (Guid.Parse(tenantId) == TENA_ID) + if (Guid.Parse(tenantId) == TENA_ID) { // specialleke voor tena - var articleNumber = Regex.Match(holdConversation.UserPrompt, @"\d+").Value; + var articleNumber = Regex.Match(holdConversation.UserPrompt.Content ?? "", @"\d+").Value; if (!string.IsNullOrEmpty(articleNumber)) { - if (!productReferences.Any(p => p.ArticleNumber == articleNumber)) + var ragClass = ragDocument.Classes.FirstOrDefault(r => r.Name == "Product"); + + if (ragClass != null) { - var articleNumberReferences = await GetProductReferenceById(articleNumber, - nameof(WebsitePage), - tenantId, - "English", - ConversationReferenceType.Product.ToString(), - cancellationToken); - - productReferences.AddRange(articleNumberReferences); + if (!ragClass.Sources.Any(p => p.Properties["ArticleNumber"] == articleNumber)) + { + var articleNumberReferences = await GetProductReferenceById(articleNumber, + nameof(WebsitePage), + tenantId, + "English", + ConversationReferenceType.Product.ToString(), + cancellationToken); + + foreach (var reference in articleNumberReferences) + { + index++; + + Dictionary properties = new Dictionary(); + properties["Content"] = reference.Content; + properties["Title"] = reference.Title; + + if (ragClass.Name == "Product") + { + properties["ArticleNumber"] = reference.ArticleNumber; + properties["Packaging"] = reference.Packaging; + } + + ragClass.Sources.Add(new RAGSource() + { + ReferenceId = $"{index}", + Properties = properties, + }); + + indexedTextReferences.Add(new SortedSearchReference() + { + Index = index, + TextSearchReference = reference, + }); + } + } } } - }*/ + } // TODO: restore this, but better /*var imageReferences = await GetImageReferences( @@ -356,7 +593,7 @@ private ValueTask> ProcessStreamedCha var imageReferences = new List(); var ragString = await ragDocument.GenerateXMLStringAsync(); - + var systemPrompt = promptBuilder .ReplaceRAGDocument(ragString) .Build(); @@ -365,11 +602,45 @@ private ValueTask> ProcessStreamedCha var chatBuilder = _openAiFactory.CreateChat() .RequestWithSystemMessage(systemPrompt) .AddPreviousMessages(conversationHistory.PromptResponses) - .AddUserMessage(holdConversation.UserPrompt) - .AddUserMessage("Do not give me any information that is not mentioned in the document") - .WithModel(chatModel) + .AddMessage(holdConversation.UserPrompt); + + if (holdConversation.UserPrompt.Role == ChatRole.User) + { + chatBuilder.AddUserMessage("Do not give me any information that is not mentioned in the document. Only use the functions you have been provided with."); + } + + chatBuilder.WithModel(chatModel) .WithTemperature(0.75); + if (Guid.Parse(tenantId) == TENA_ID) + { + chatBuilder.WithFunction(new System.Text.Json.Serialization.JsonFunction() + { + Name = "get_product_recommendation", + Description = "Get a recommendation for a Tena incontenince product based on gender, level of incontinence and mobility", + Parameters = new JsonFunctionNonPrimitiveProperty() + .AddEnum("gender", new JsonFunctionEnumProperty + { + Type = "string", + Enums = new List { "male", "female" }, + Description = "The gender of the person derived from the context, male (he/him) or female (she/her)", + }) + .AddEnum("mobility", new JsonFunctionEnumProperty + { + Type = "string", + Enums = new List { "mobile", "needs_help_toilet", "bedridden" }, + Description = "The level of mobility of the incontinent person from mobile and being able to go to the toilet himself to bedridden", + }) + .AddEnum("incontinence_level", new JsonFunctionEnumProperty + { + Type = "string", + Enums = new List { "small", "light", "moderate", "heavy", "very_heavy" }, + Description = "How heavy the urine loss is, from very small drops to a full cup of urine loss. Ranging between; Small (drops making the underwear damp), Light (leakages making the underwear fairly wet), moderate (Quarter cup, making the underwear quite wet), heavy (Half cup, making the underwear very wet) and very heavy (Full cup, emptying more than half bladder). Do not make this data up, it should be explicetely mentioned by the user.", + }) + .AddRequired("gender", "mobility", "incontinence_level") + }); + } + if (conversationHistory.DebugEnabled) { conversationHistory.AppendPreRequestDebugInformation(chatBuilder, references, imageReferences, promptBuilder); @@ -391,7 +662,7 @@ private async Task GetTenantAsync(string tenantId) return tenant; } - + private static ConversationReferencedResult ParseAnswerWithReferences(HoldConversation holdConversation, ConversationHistory conversationHistory, IReadOnlyCollection textReferences, @@ -413,14 +684,14 @@ private static ConversationReferencedResult ParseAnswerWithReferences(HoldConver { shouldReturnFullMessage = true; var (_, answer, _, _) = conversationHistory.PromptResponses.Last(); - mergedAnswer = answer; + mergedAnswer = answer.Content; } Console.WriteLine($"Merged answer {mergedAnswer}"); List? validReferences = null; if (shouldReturnFullMessage) - { + { validReferences = DetermineValidReferences(textReferences, mergedAnswer); if (conversationHistory.DebugEnabled) @@ -577,7 +848,8 @@ private async Task> GetTextReferences( return search .GroupBy(p => p.Source) - .Select(grouping => { + .Select(grouping => + { var first = grouping.First(); first.Text = string.Join(" ", grouping.Select(g => g.Text).ToList()); @@ -654,7 +926,8 @@ private async Task> GetProductReferenceById( return search .GroupBy(p => p.Source) - .Select(grouping => { + .Select(grouping => + { var first = grouping.First(); first.Text = string.Join(" ", grouping.Select(g => g.Text).ToList()); diff --git a/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/ConversationHistory.cs b/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/ConversationHistory.cs index 6682589..71ea31c 100644 --- a/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/ConversationHistory.cs +++ b/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/ConversationHistory.cs @@ -5,21 +5,10 @@ namespace ConversationalSearchPlatform.BackOffice.Services.Models; -public enum PromptType -{ - User, - FunctionResponse, -} - -public enum ResponseType -{ - Assistant, - FunctionRequest, -} public class ConversationExchange { - public ConversationExchange(string prompt, string response, List promptKeywords, List responseKeywords) + public ConversationExchange(ChatMessage prompt, ChatMessage response, List promptKeywords, List responseKeywords) { Prompt = prompt; Response = response; @@ -27,29 +16,18 @@ public ConversationExchange(string prompt, string response, List promptK ResponseKeywords = responseKeywords; } - public ConversationExchange(string prompt, string response) + public ConversationExchange(ChatMessage prompt, ChatMessage response) { Prompt = prompt; Response = response; } - public ConversationExchange(string prompt, PromptType promptType, string response, ResponseType responseType) - { - Prompt = prompt; - PromptType = promptType; - Response = response; - ResponseType = responseType; - } - - public PromptType PromptType { get; set; } = PromptType.User; - public ResponseType ResponseType { get; set; } = ResponseType.Assistant; - - public string Prompt { get; init; } - public string Response { get; init; } + public ChatMessage Prompt { get; init; } + public ChatMessage Response { get; init; } public List PromptKeywords { get; set; } = new(); public List ResponseKeywords { get; set; } = new(); - public void Deconstruct(out string prompt, out string response, out List promptKeywords, out List responseKeywords) + public void Deconstruct(out ChatMessage prompt, out ChatMessage response, out List promptKeywords, out List responseKeywords) { prompt = Prompt; response = Response; @@ -77,7 +55,7 @@ public class ConversationHistory(ChatModel model, int amountOfSearchReferences) public string GetAllStreamingResponseChunksMerged() => string.Join(null, StreamingResponseChunks); - public void AppendToConversation(string prompt, string answer) + public void AppendToConversation(ChatMessage prompt, ChatMessage answer) { PromptResponses.Add(new ConversationExchange(prompt, answer)); } diff --git a/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/HoldConversation.cs b/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/HoldConversation.cs index 2dc9b4b..073836f 100644 --- a/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/HoldConversation.cs +++ b/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/HoldConversation.cs @@ -1,8 +1,10 @@ +using Rystem.OpenAi.Chat; + namespace ConversationalSearchPlatform.BackOffice.Services.Models; public record HoldConversation(Guid ConversationId, string TenantId, - string UserPrompt, + ChatMessage UserPrompt, IDictionary ConversationContext, bool Debug, Language Language = Language.English) @@ -10,7 +12,7 @@ public record HoldConversation(Guid ConversationId, public Guid ConversationId { get; private set; } = ConversationId; public string TenantId { get; private set; } = TenantId; - public string UserPrompt { get; set; } = UserPrompt; + public ChatMessage UserPrompt { get; set; } = UserPrompt; public Language Language { get; private set; } = Language; public IDictionary ConversationContext { get; private set; } = ConversationContext; public bool Debug { get; init; } = Debug; diff --git a/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/StreamResult.cs b/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/StreamResult.cs index bc0da89..6d5ea83 100644 --- a/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/StreamResult.cs +++ b/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/StreamResult.cs @@ -16,6 +16,8 @@ private StreamResult(T? Value, Exception? Error, string? SkipReason) public static StreamResult Ok(T result) => new(result, default, default); + public static StreamResult FunctionCall(T result) => new(result, default, default); + public static StreamResult Fail(Exception exception) => new(default, exception, default); diff --git a/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/Weaviate/Queries/GetByPromptFiltered.cs b/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/Weaviate/Queries/GetByPromptFiltered.cs index 066bff3..af395df 100644 --- a/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/Weaviate/Queries/GetByPromptFiltered.cs +++ b/src/backoffice/ConversationalSearchPlatform.BackOffice/Services/Models/Weaviate/Queries/GetByPromptFiltered.cs @@ -22,9 +22,7 @@ public static GraphQLRequest Request(T @params) where T : IQueryParams string cleanQuery = HttpUtility.JavaScriptStringEncode(queryParams.query.ReplaceLineEndings(" ")); var vectorAsJsonArray = JsonSerializer.Serialize(queryParams.Vector); - return new GraphQLRequest - { - Query = $$""" + var query = $$""" { Get { {{queryParams.CollectionName}}( @@ -63,7 +61,11 @@ public static GraphQLRequest Request(T @params) where T : IQueryParams } } } - """ + """; + + return new GraphQLRequest + { + Query = query, }; } diff --git a/src/nuget.config b/src/nuget.config new file mode 100644 index 0000000..4f31113 --- /dev/null +++ b/src/nuget.config @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file