diff --git a/pkg/externalfunctions/ansysgpt.go b/pkg/externalfunctions/ansysgpt.go index b279074..793cb6e 100644 --- a/pkg/externalfunctions/ansysgpt.go +++ b/pkg/externalfunctions/ansysgpt.go @@ -306,7 +306,7 @@ func AnsysGPTPerformLLMRequest(finalQuery string, history []sharedtypes.Historic streamChannel := make(chan string, 400) // Start a goroutine to transfer the data from the response channel to the stream channel - go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, 0, 0, "", false, "") + go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, "", 0, 0, "", "", false, "") // Return the stream channel return "", &streamChannel @@ -685,10 +685,12 @@ func AisPerformLLMFinalRequest(systemTemplate string, prohibitedWords []string, errorList1 []string, errorList2 []string, + tokenCountEndpoint string, previousInputTokenCount int, previousOutputTokenCount int, tokenCountModelName string, - isStream bool) (message string, stream *chan string) { + isStream bool, + userEmail string) (message string, stream *chan string) { logging.Log.Debugf(internalstates.Ctx, "Performing LLM final request") @@ -777,7 +779,7 @@ func AisPerformLLMFinalRequest(systemTemplate string, totalInputTokenCount := previousInputTokenCount + inputTokenCount // Start a goroutine to transfer the data from the response channel to the stream channel. - go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, true, totalInputTokenCount, previousOutputTokenCount, tokenCountModelName, true, contextString) + go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, true, tokenCountEndpoint, totalInputTokenCount, previousOutputTokenCount, tokenCountModelName, userEmail, true, contextString) return "", &streamChannel } diff --git a/pkg/externalfunctions/llmhandler.go b/pkg/externalfunctions/llmhandler.go index bdc74f0..b433806 100644 --- a/pkg/externalfunctions/llmhandler.go +++ b/pkg/externalfunctions/llmhandler.go @@ -256,7 +256,7 @@ func PerformGeneralRequest(input string, history []sharedtypes.HistoricMessage, streamChannel := make(chan string, 400) // Start a goroutine to transfer the data from the response channel to the stream channel - go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, 0, 0, "", false, "") + go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, "", 0, 0, "", "", false, "") // Return the stream channel return "", &streamChannel @@ -313,7 +313,7 @@ func PerformGeneralRequestWithImages(input string, history []sharedtypes.Histori streamChannel := make(chan string, 400) // Start a goroutine to transfer the data from the response channel to the stream channel - go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, 0, 0, "", false, "") + go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, "", 0, 0, "", "", false, "") // Return the stream channel return "", &streamChannel @@ -371,7 +371,7 @@ func PerformGeneralRequestSpecificModel(input string, history []sharedtypes.Hist streamChannel := make(chan string, 400) // Start a goroutine to transfer the data from the response channel to the stream channel - go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, 0, 0, "", false, "") + go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, "", 0, 0, "", "", false, "") // Return the stream channel return "", &streamChannel @@ -427,7 +427,7 @@ func PerformCodeLLMRequest(input string, history []sharedtypes.HistoricMessage, streamChannel := make(chan string, 400) // Start a goroutine to transfer the data from the response channel to the stream channel - go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, validateCode, false, 0, 0, "", false, "") + go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, validateCode, false, "", 0, 0, "", "", false, "") // Return the stream channel return "", &streamChannel diff --git a/pkg/externalfunctions/privatefunctions.go b/pkg/externalfunctions/privatefunctions.go index 726edba..3c6a009 100644 --- a/pkg/externalfunctions/privatefunctions.go +++ b/pkg/externalfunctions/privatefunctions.go @@ -40,9 +40,11 @@ func transferDatafromResponseToStreamChannel( streamChannel *chan string, validateCode bool, sendTokenCount bool, + tokenCountEndpoint string, previousInputTokenCount int, previousOutputTokenCount int, tokenCountModelName string, + userEmail string, sendContex bool, contex string) { @@ -81,10 +83,21 @@ func transferDatafromResponseToStreamChannel( // send the error message to the stream channel and exit function *streamChannel <- fmt.Sprintf("$&$error$&$:$&$Error getting token count: %v$&$", err) } + + // calculate the total token count + totalInputTokenCount := previousInputTokenCount totalOuputTokenCount := previousOutputTokenCount + outputTokenCount - // append the token count message to the final message - finalMessage += fmt.Sprintf("$&$input_token_count$&$:$&$%d$&$;$&$output_token_count$&$:$&$%d$&$;", previousInputTokenCount, totalOuputTokenCount) + // send the token count to the token count endpoint + err = sendTokenCountToEndpoint(userEmail, tokenCountEndpoint, totalInputTokenCount, totalOuputTokenCount) + if err != nil { + logging.Log.Errorf(internalstates.Ctx, "Error sending token count: %v\n", err) + // send the error message to the stream channel and exit function + *streamChannel <- fmt.Sprintf("$&$error$&$:$&$Error in updating token count: %v$&$", err) + } else { + // append the token count message to the final message + finalMessage += fmt.Sprintf("$&$input_token_count$&$:$&$%d$&$;$&$output_token_count$&$:$&$%d$&$;", totalInputTokenCount, totalOuputTokenCount) + } } // check for contex @@ -131,6 +144,68 @@ func transferDatafromResponseToStreamChannel( } } +// sendTokenCount sends the token count to the token count endpoint +// +// Parameters: +// - userEmail: the email of the user +// - tokenCountEndpoint: the endpoint to send the token count to +// - inputTokenCount: the number of input tokens +// - ouputTokenCount: the number of output tokens +// +// Returns: +// - err: an error if the request fails +func sendTokenCountToEndpoint(userEmail string, tokenCountEndpoint string, inputTokenCount int, ouputTokenCount int) (err error) { + defer func() { + r := recover() + if r != nil { + err = fmt.Errorf("panic in sendTokenCount: %v", r) + } + }() + + // verify that endpoint is filled + if tokenCountEndpoint == "" { + return fmt.Errorf("no token count endpoint provided") + } + + // Create the request + requestBody := TokenCountUpdateRequest{ + Email: userEmail, + InputToken: inputTokenCount, + OutputToken: ouputTokenCount, + Plattform: "Allie", + } + + // Convert payload to JSON + jsonData, err := json.Marshal(requestBody) + if err != nil { + return fmt.Errorf("error marshalling JSON: %v", err) + } + + // Create a new HTTP request + request, err := http.NewRequest("POST", tokenCountEndpoint, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("error creating request: %v", err) + } + + // Set headers + request.Header.Set("Content-Type", "application/json") + + // Create an HTTP client and make the request + client := &http.Client{} + resp, err := client.Do(request) + if err != nil { + return fmt.Errorf("error making request: %v", err) + } + defer resp.Body.Close() + + // Check the response status + if resp.StatusCode != 200 { + return fmt.Errorf("response status unequal 200: %v", resp.Status) + } + + return nil +} + // sendChatRequestNoHistory sends a chat request to LLM without history // // Parameters: @@ -1502,7 +1577,7 @@ func performGeneralRequest(input string, history []sharedtypes.HistoricMessage, streamChannel := make(chan string, 400) // Start a goroutine to transfer the data from the response channel to the stream channel. - go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, 0, 0, "", false, "") + go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, "", 0, 0, "", "", false, "") // Return the stream channel. return "", &streamChannel, nil diff --git a/pkg/externalfunctions/types.go b/pkg/externalfunctions/types.go index b4744ad..43021f2 100644 --- a/pkg/externalfunctions/types.go +++ b/pkg/externalfunctions/types.go @@ -208,3 +208,10 @@ type DataExtractionSplitterServiceRequest struct { type DataExtractionSplitterServiceResponse struct { Chunks []string `json:"chunks"` } + +type TokenCountUpdateRequest struct { + Email string `json:"email"` + InputToken int `json:"input_token"` + OutputToken int `json:"output_token"` + Plattform string `json:"plattform"` +}