From ed635461ba9cf57c6af4871580f86a02505cc4ce Mon Sep 17 00:00:00 2001 From: Allen Firstenberg Date: Mon, 30 Dec 2024 13:24:10 -0500 Subject: [PATCH] fix(google-common,google-*) Gemini 2.0 support (#7435) Co-authored-by: jacoblee93 --- .../src/chat_models.ts | 39 +- .../src/tests/chat_models.test.ts | 145 +++++ .../src/tests/data/chat-6-mock.json | 85 +++ libs/langchain-google-common/src/types.ts | 108 +++- .../src/utils/common.ts | 18 +- .../src/utils/gemini.ts | 106 ++- .../src/tests/chat_models.int.test.ts | 178 +++-- .../src/tests/chat_models.int.test.ts | 612 +++++++++++++++++- 8 files changed, 1216 insertions(+), 75 deletions(-) create mode 100644 libs/langchain-google-common/src/tests/data/chat-6-mock.json diff --git a/libs/langchain-google-common/src/chat_models.ts b/libs/langchain-google-common/src/chat_models.ts index 7d476b94cf2c..15bf21fd3c94 100644 --- a/libs/langchain-google-common/src/chat_models.ts +++ b/libs/langchain-google-common/src/chat_models.ts @@ -33,6 +33,7 @@ import { GoogleAIBaseLanguageModelCallOptions, GoogleAIAPI, GoogleAIAPIParams, + GoogleSearchToolSetting, } from "./types.js"; import { convertToGeminiTools, @@ -97,10 +98,39 @@ export class ChatConnection extends AbstractGoogleLLMConnection< return true; } + computeGoogleSearchToolAdjustmentFromModel(): Exclude< + GoogleSearchToolSetting, + boolean + > { + if (this.modelName.startsWith("gemini-1.0")) { + return "googleSearchRetrieval"; + } else if (this.modelName.startsWith("gemini-1.5")) { + return "googleSearchRetrieval"; + } else { + return "googleSearch"; + } + } + + computeGoogleSearchToolAdjustment( + apiConfig: GeminiAPIConfig + ): Exclude { + const adj = apiConfig.googleSearchToolAdjustment; + if (adj === undefined || adj === true) { + return this.computeGoogleSearchToolAdjustmentFromModel(); + } else { + return adj; + } + } + buildGeminiAPI(): GoogleAIAPI { + const apiConfig: GeminiAPIConfig = + (this.apiConfig as GeminiAPIConfig) ?? {}; + const googleSearchToolAdjustment = + this.computeGoogleSearchToolAdjustment(apiConfig); const geminiConfig: GeminiAPIConfig = { useSystemInstruction: this.useSystemInstruction, - ...(this.apiConfig as GeminiAPIConfig), + googleSearchToolAdjustment, + ...apiConfig, }; return getGeminiAPI(geminiConfig); } @@ -208,7 +238,12 @@ export abstract class ChatGoogleBase } buildApiKey(fields?: GoogleAIBaseLLMInput): string | undefined { - return fields?.apiKey ?? getEnvironmentVariable("GOOGLE_API_KEY"); + if (fields?.platformType !== "gcp") { + return fields?.apiKey ?? getEnvironmentVariable("GOOGLE_API_KEY"); + } else { + // GCP doesn't support API Keys + return undefined; + } } buildClient( diff --git a/libs/langchain-google-common/src/tests/chat_models.test.ts b/libs/langchain-google-common/src/tests/chat_models.test.ts index aa15be74ed79..5726d9fd445e 100644 --- a/libs/langchain-google-common/src/tests/chat_models.test.ts +++ b/libs/langchain-google-common/src/tests/chat_models.test.ts @@ -1105,6 +1105,151 @@ describe("Mock ChatGoogle - Gemini", () => { // console.log(JSON.stringify(record?.opts?.data, null, 1)); }); + + test("6. GoogleSearchRetrievalTool result", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-6-mock.json", + }; + + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-1.5-pro-002", + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + expect(result).toHaveProperty("response_metadata"); + expect(result.response_metadata).toHaveProperty("groundingMetadata"); + expect(result.response_metadata).toHaveProperty("groundingSupport"); + expect(Array.isArray(result.response_metadata.groundingSupport)).toEqual( + true + ); + expect(result.response_metadata.groundingSupport).toHaveLength(4); + }); + + test("6. GoogleSearchRetrievalTool request 1.5 ", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-6-mock.json", + }; + + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-1.5-pro-002", + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + + expect(record.opts.data.tools[0]).toHaveProperty("googleSearchRetrieval"); + }); + + test("6. GoogleSearchRetrievalTool request 2.0 ", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-6-mock.json", + }; + + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-2.0-flash", + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + + expect(record.opts.data.tools[0]).toHaveProperty("googleSearch"); + }); + + test("6. GoogleSearchTool request 1.5 ", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-6-mock.json", + }; + + const searchTool = { + googleSearch: {}, + }; + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-1.5-pro-002", + temperature: 0, + maxRetries: 0, + }).bindTools([searchTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + + expect(record.opts.data.tools[0]).toHaveProperty("googleSearchRetrieval"); + }); + + test("6. GoogleSearchTool request 2.0 ", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-6-mock.json", + }; + + const searchTool = { + googleSearch: {}, + }; + const model = new ChatGoogle({ + authOptions, + modelName: "gemini-2.0-flash", + temperature: 0, + maxRetries: 0, + }).bindTools([searchTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + + expect(record.opts.data.tools[0]).toHaveProperty("googleSearch"); + }); }); describe("Mock ChatGoogle - Anthropic", () => { diff --git a/libs/langchain-google-common/src/tests/data/chat-6-mock.json b/libs/langchain-google-common/src/tests/data/chat-6-mock.json new file mode 100644 index 000000000000..796fdcf9bcee --- /dev/null +++ b/libs/langchain-google-common/src/tests/data/chat-6-mock.json @@ -0,0 +1,85 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "The Los Angeles Dodgers won the 2024 World Series, defeating the New York Yankees 4-1 in the series. The Dodgers clinched the title with a 7-6 comeback victory in Game 5 at Yankee Stadium on Wednesday, October 30th. This was their eighth World Series title overall and their second in the past five years. It was also their first World Series win in a full season since 1988. Mookie Betts earned his third World Series ring (2018, 2020, and 2024), becoming the only active player with three championships. Shohei Ohtani, in his first year with the Dodgers, also experienced his first post-season appearance.\n" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "groundingMetadata": { + "searchEntryPoint": { + "renderedContent": "\n
\n
\n \n \n \n \n \n \n \n \n \n \n \n \n \n
\n
\n \n
\n" + }, + "groundingChunks": [ + { + "web": { + "uri": "https://vertexaisearch.cloud.google.com/grounding-api-redirect/AYygrcTYmdnM71OvWYUTG4JggmRj8cIIgA2KtKas5RPj09CiALB4n8hl-SfCD6r8WnimL2psBoYmEN9ng9sENjpeP5VxgLMTlm0zgxhrWFfx3yA6B_n0N9j-BgHLISAUi-_Ql4_Buyw68Svq-3v6BgrXzn9hLOtK", + "title": "bbc.com" + } + }, + { + "web": { + "uri": "https://vertexaisearch.cloud.google.com/grounding-api-redirect/AYygrcQRhhvHTdpb8OMOEMVxv9fkevPoMWMnhrpuC7E0E0R94xmFxT9Vv5na1hMrfHGKxVZ9aE3PgCAs5nftC3iAkeD7B6ZTfKGH2Im1CqssMM7zorGx1Ds5_7QPPBDQps_JvpkOuvRluGCVg8KwNaIU-hm3Kg==", + "title": "mlb.com" + } + }, + { + "web": { + "uri": "https://vertexaisearch.cloud.google.com/grounding-api-redirect/AYygrcSwvb2t622A2ZpKxqOWKy16L1mEUvmsAJoHjaR7uffKO71SeZkpdRXRsST9HJzJkGSkMF9kOaXGoDtcvUrttqKYOQHvHSUBYO7LWMlU00KyNlSoQzrBsgN4KuJ4O4acnNyNCSVX3-E=", + "title": "youtube.com" + } + } + ], + "groundingSupports": [ + { + "segment": { + "endIndex": 100, + "text": "The Los Angeles Dodgers won the 2024 World Series, defeating the New York Yankees 4-1 in the series." + }, + "groundingChunkIndices": [0], + "confidenceScores": [0.95898277] + }, + { + "segment": { + "startIndex": 308, + "endIndex": 377, + "text": "It was also their first World Series win in a full season since 1988." + }, + "groundingChunkIndices": [1], + "confidenceScores": [0.96841997] + }, + { + "segment": { + "startIndex": 379, + "endIndex": 508, + "text": "Mookie Betts earned his third World Series ring (2018, 2020, and 2024), becoming the only active player with three championships." + }, + "groundingChunkIndices": [2], + "confidenceScores": [0.99043523] + }, + { + "segment": { + "startIndex": 510, + "endIndex": 611, + "text": "Shohei Ohtani, in his first year with the Dodgers, also experienced his first post-season appearance." + }, + "groundingChunkIndices": [0], + "confidenceScores": [0.95767003] + } + ], + "webSearchQueries": ["2024 MLB World Series winner"] + }, + "avgLogprobs": -0.040494912748883484 + } + ], + "usageMetadata": { + "promptTokenCount": 13, + "candidatesTokenCount": 157, + "totalTokenCount": 170 + }, + "modelVersion": "gemini-1.5-pro-002" +} diff --git a/libs/langchain-google-common/src/types.ts b/libs/langchain-google-common/src/types.ts index b88b3e01d090..3b702cda6f87 100644 --- a/libs/langchain-google-common/src/types.ts +++ b/libs/langchain-google-common/src/types.ts @@ -299,6 +299,71 @@ export type GeminiSafetyRating = { probability: string; } & Record; +export interface GeminiCitationMetadata { + citations: GeminiCitation[]; +} + +export interface GeminiCitation { + startIndex: number; + endIndex: number; + uri: string; + title: string; + license: string; + publicationDate: GoogleTypeDate; +} + +export interface GoogleTypeDate { + year: number; // 1-9999 or 0 to specify a date without a year + month: number; // 1-12 or 0 to specify a year without a month and day + day: number; // Must be from 1 to 31 and valid for the year and month, or 0 to specify a year by itself or a year and month where the day isn't significant +} + +export interface GeminiGroundingMetadata { + webSearchQueries?: string[]; + searchEntryPoint?: GeminiSearchEntryPoint; + groundingChunks: GeminiGroundingChunk[]; + groundingSupports?: GeminiGroundingSupport[]; + retrievalMetadata?: GeminiRetrievalMetadata; +} + +export interface GeminiSearchEntryPoint { + renderedContent?: string; + sdkBlob?: string; // Base64 encoded JSON representing array of tuple. +} + +export interface GeminiGroundingChunk { + web: GeminiGroundingChunkWeb; + retrievedContext: GeminiGroundingChunkRetrievedContext; +} + +export interface GeminiGroundingChunkWeb { + uri: string; + title: string; +} + +export interface GeminiGroundingChunkRetrievedContext { + uri: string; + title: string; + text: string; +} + +export interface GeminiGroundingSupport { + segment: GeminiSegment; + groundingChunkIndices: number[]; + confidenceScores: number[]; +} + +export interface GeminiSegment { + partIndex: number; + startIndex: number; + endIndex: number; + text: string; +} + +export interface GeminiRetrievalMetadata { + googleSearchDynamicRetrievalScore: number; +} + // The "system" content appears to only be valid in the systemInstruction export type GeminiRole = "system" | "user" | "model" | "function"; @@ -307,12 +372,37 @@ export interface GeminiContent { role: GeminiRole; // Vertex AI requires the role } +/* + * If additional attributes are added here, they should also be + * added to the attributes below + */ export interface GeminiTool { functionDeclarations?: GeminiFunctionDeclaration[]; - googleSearchRetrieval?: GoogleSearchRetrieval; + googleSearchRetrieval?: GoogleSearchRetrieval; // Gemini-1.5 + googleSearch?: GoogleSearch; // Gemini-2.0 retrieval?: VertexAIRetrieval; } +/* + * The known strings in this type should match those in GeminiSearchToolAttribuets + */ +export type GoogleSearchToolSetting = + | boolean + | "googleSearchRetrieval" + | "googleSearch" + | string; + +export const GeminiSearchToolAttributes = [ + "googleSearchRetrieval", + "googleSearch", +]; + +export const GeminiToolAttributes = [ + "functionDeclaration", + "retrieval", + ...GeminiSearchToolAttributes, +]; + export interface GoogleSearchRetrieval { dynamicRetrievalConfig?: { mode?: string; @@ -320,6 +410,8 @@ export interface GoogleSearchRetrieval { }; } +export interface GoogleSearch {} + export interface VertexAIRetrieval { vertexAiSearch: { datastore: string; @@ -385,6 +477,8 @@ interface GeminiResponseCandidate { index: number; tokenCount?: number; safetyRatings: GeminiSafetyRating[]; + citationMetadata?: GeminiCitationMetadata; + groundingMetadata?: GeminiGroundingMetadata; } interface GeminiResponsePromptFeedback { @@ -467,6 +561,18 @@ export interface GeminiAPIConfig { safetyHandler?: GoogleAISafetyHandler; mediaManager?: MediaManager; useSystemInstruction?: boolean; + + /** + * How to handle the Google Search tool, since the name (and format) + * of the tool changes between Gemini 1.5 and Gemini 2.0. + * true - Change based on the model version. (Default) + * false - Do not change the tool name provided + * string value - Use this as the attribute name for the search + * tool, adapting any tool attributes if possible. + * When the model is created, a "true" or default setting + * will be changed to a string based on the model. + */ + googleSearchToolAdjustment?: GoogleSearchToolSetting; } export type GoogleAIAPIConfig = GeminiAPIConfig | AnthropicAPIConfig; diff --git a/libs/langchain-google-common/src/utils/common.ts b/libs/langchain-google-common/src/utils/common.ts index b40ce25fe3fc..4194f9578b9e 100644 --- a/libs/langchain-google-common/src/utils/common.ts +++ b/libs/langchain-google-common/src/utils/common.ts @@ -1,10 +1,11 @@ import { isOpenAITool } from "@langchain/core/language_models/base"; import { isLangChainTool } from "@langchain/core/utils/function_calling"; import { isModelGemini, validateGeminiParams } from "./gemini.js"; -import type { +import { GeminiFunctionDeclaration, GeminiFunctionSchema, GeminiTool, + GeminiToolAttributes, GoogleAIBaseLanguageModelCallOptions, GoogleAIModelParams, GoogleAIModelRequestParams, @@ -61,11 +62,24 @@ function processToolChoice( throw new Error("Object inputs for tool_choice not supported."); } +function isGeminiTool(tool: GoogleAIToolType): tool is GeminiTool { + for (const toolAttribute of GeminiToolAttributes) { + if (toolAttribute in tool) { + return true; + } + } + return false; +} + +function isGeminiNonFunctionTool(tool: GoogleAIToolType): tool is GeminiTool { + return isGeminiTool(tool) && !("functionDeclaration" in tool); +} + export function convertToGeminiTools(tools: GoogleAIToolType[]): GeminiTool[] { const geminiTools: GeminiTool[] = []; let functionDeclarationsIndex = -1; tools.forEach((tool) => { - if ("googleSearchRetrieval" in tool || "retrieval" in tool) { + if (isGeminiNonFunctionTool(tool)) { geminiTools.push(tool); } else { if (functionDeclarationsIndex === -1) { diff --git a/libs/langchain-google-common/src/utils/gemini.ts b/libs/langchain-google-common/src/utils/gemini.ts index 23c46f3783db..7f532983d6a4 100644 --- a/libs/langchain-google-common/src/utils/gemini.ts +++ b/libs/langchain-google-common/src/utils/gemini.ts @@ -37,6 +37,7 @@ import type { GeminiPartFunctionCall, GoogleAIAPI, GeminiAPIConfig, + GeminiGroundingSupport, } from "../types.js"; import { GoogleAISafetyError } from "./safety.js"; import { MediaBlob } from "../experimental/utils/media_core.js"; @@ -48,6 +49,7 @@ import { GeminiTool, GoogleAIModelRequestParams, GoogleAIToolType, + GeminiSearchToolAttributes, } from "../types.js"; import { zodToGeminiParameters } from "./zod_to_gemini_parameters.js"; @@ -690,6 +692,8 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { severity: rating.severity, severity_score: rating.severityScore, })), + citation_metadata: data.candidates[0]?.citationMetadata, + grounding_metadata: data.candidates[0]?.groundingMetadata, finish_reason: data.candidates[0]?.finishReason, }; } @@ -749,7 +753,29 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { }); } - function responseToChatGenerations( + function groundingSupportByPart( + groundingSupports?: GeminiGroundingSupport[] + ): GeminiGroundingSupport[][] { + const ret: GeminiGroundingSupport[][] = []; + + if (!groundingSupports || groundingSupports.length === 0) { + return []; + } + + groundingSupports?.forEach((groundingSupport) => { + const segment = groundingSupport?.segment; + const partIndex = segment?.partIndex ?? 0; + if (ret[partIndex]) { + ret[partIndex].push(groundingSupport); + } else { + ret[partIndex] = [groundingSupport]; + } + }); + + return ret; + } + + function responseToGroundedChatGenerations( response: GoogleLLMResponse ): ChatGeneration[] { const parts = responseToParts(response); @@ -758,7 +784,46 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { return []; } - let ret = parts.map((part) => partToChatGeneration(part)); + // Citation and grounding information connected to each part / ChatGeneration + // to make sure they are available in downstream filters. + const candidate = (response?.data as GenerateContentResponseData) + ?.candidates?.[0]; + const groundingMetadata = candidate?.groundingMetadata; + const citationMetadata = candidate?.citationMetadata; + const groundingParts = groundingSupportByPart( + groundingMetadata?.groundingSupports + ); + + const ret = parts.map((part, index) => { + const gen = partToChatGeneration(part); + if (!gen.generationInfo) { + gen.generationInfo = {}; + } + if (groundingMetadata) { + gen.generationInfo.groundingMetadata = groundingMetadata; + const groundingPart = groundingParts[index]; + if (groundingPart) { + gen.generationInfo.groundingSupport = groundingPart; + } + } + if (citationMetadata) { + gen.generationInfo.citationMetadata = citationMetadata; + } + return gen; + }); + + return ret; + } + + function responseToChatGenerations( + response: GoogleLLMResponse + ): ChatGeneration[] { + let ret = responseToGroundedChatGenerations(response); + + if (ret.length === 0) { + return []; + } + if (ret.every((item) => typeof item.message.content === "string")) { const combinedContent = ret.map((item) => item.message.content).join(""); const combinedText = ret.map((item) => item.text).join(""); @@ -1015,17 +1080,44 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { }; } + function searchToolName(tool: GeminiTool): string | undefined { + for (const name of GeminiSearchToolAttributes) { + if (name in tool) { + return name; + } + } + return undefined; + } + + function cleanGeminiTool(tool: GeminiTool): GeminiTool { + const orig = searchToolName(tool); + const adj = config?.googleSearchToolAdjustment; + if (orig && adj && adj !== orig) { + return { + [adj as string]: {}, + }; + } else { + return tool; + } + } + function formatTools(parameters: GoogleAIModelRequestParams): GeminiTool[] { const tools: GoogleAIToolType[] | undefined = parameters?.tools; if (!tools || tools.length === 0) { return []; } - // Group all LangChain tools into a single functionDeclarations array - const langChainTools = tools.filter(isLangChainTool); - const otherTools = tools.filter( - (tool) => !isLangChainTool(tool) - ) as GeminiTool[]; + // Group all LangChain tools into a single functionDeclarations array. + // Gemini Tools may be normalized to different tool names + const langChainTools: StructuredToolParams[] = []; + const otherTools: GeminiTool[] = []; + tools.forEach((tool) => { + if (isLangChainTool(tool)) { + langChainTools.push(tool); + } else { + otherTools.push(cleanGeminiTool(tool as GeminiTool)); + } + }); const result: GeminiTool[] = [...otherTools]; diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts index ddcdf579a394..6d1606614bd1 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts @@ -59,18 +59,45 @@ const calculatorTool = tool((_) => "no-op", { }), }); -describe("GAuth Gemini Chat", () => { +/* + * Which models do we want to run the test suite against? + */ +const testGeminiModelNames = [ + ["gemini-1.5-pro-002"], + ["gemini-1.5-flash-002"], + ["gemini-2.0-flash-exp"], + // ["gemini-2.0-flash-thinking-exp-1219"], +]; + +/* + * Some models may have usage quotas still. + * For those models, set how long (in millis) to wait in between each test. + */ +const testGeminiModelDelay: Record = { + "gemini-2.0-flash-exp": 5000, + "gemini-2.0-flash-thinking-exp-1219": 5000, +}; + +describe.each(testGeminiModelNames)("GAuth Gemini Chat (%s)", (modelName) => { let recorder: GoogleRequestRecorder; let callbacks: BaseCallbackHandler[]; - beforeEach(() => { + beforeEach(async () => { recorder = new GoogleRequestRecorder(); callbacks = [recorder, new GoogleRequestLogger()]; + + const delay = testGeminiModelDelay[modelName] ?? 0; + if (delay) { + console.log(`Delaying for ${delay}ms`); + // eslint-disable-next-line no-promise-executor-return + await new Promise((resolve) => setTimeout(resolve, delay)); + } }); test("invoke", async () => { const model = new ChatVertexAI({ callbacks, + modelName, }); const res = await model.invoke("What is 1 + 1?"); expect(res).toBeDefined(); @@ -84,8 +111,10 @@ describe("GAuth Gemini Chat", () => { expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/); }); - test("generate", async () => { - const model = new ChatVertexAI(); + test(`generate`, async () => { + const model = new ChatVertexAI({ + modelName, + }); const messages: BaseMessage[] = [ new SystemMessage( "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." @@ -103,12 +132,13 @@ describe("GAuth Gemini Chat", () => { expect(typeof aiMessage.content).toBe("string"); const text = aiMessage.content as string; - expect(["H", "T"]).toContainEqual(text); + expect(["H", "T"]).toContainEqual(text.trim()); }); test("stream", async () => { const model = new ChatVertexAI({ callbacks, + modelName, }); const input: BaseLanguageModelInput = new ChatPromptValue([ new SystemMessage( @@ -153,7 +183,11 @@ describe("GAuth Gemini Chat", () => { ], }, ]; - const model = new ChatVertexAI().bind({ tools }); + const model = new ChatVertexAI({ + modelName, + }).bind({ + tools, + }); const result = await model.invoke("Run a test on the cobalt project"); expect(result).toHaveProperty("content"); expect(result.content).toBe(""); @@ -197,7 +231,11 @@ describe("GAuth Gemini Chat", () => { ], }, ]; - const model = new ChatVertexAI().bind({ tools }); + const model = new ChatVertexAI({ + modelName, + }).bind({ + tools, + }); const toolResult = { testPassed: true, }; @@ -241,7 +279,9 @@ describe("GAuth Gemini Chat", () => { required: ["location"], }, }; - const model = new ChatVertexAI().withStructuredOutput(tool); + const model = new ChatVertexAI({ + modelName, + }).withStructuredOutput(tool); const result = await model.invoke("What is the weather in Paris?"); expect(result).toHaveProperty("location"); }); @@ -275,7 +315,7 @@ describe("GAuth Gemini Chat", () => { resolvers: [resolver], }); const model = new ChatGoogle({ - modelName: "gemini-1.5-flash", + modelName, apiConfig: { mediaManager, }, @@ -320,6 +360,7 @@ describe("GAuth Gemini Chat", () => { const model = new ChatVertexAI({ temperature: 0, maxOutputTokens: 10, + modelName, }); let res: AIMessageChunk | null = null; for await (const chunk of await model.stream( @@ -347,6 +388,7 @@ describe("GAuth Gemini Chat", () => { const model = new ChatVertexAI({ temperature: 0, streamUsage: false, + modelName, }); let res: AIMessageChunk | null = null; for await (const chunk of await model.stream( @@ -366,6 +408,7 @@ describe("GAuth Gemini Chat", () => { const model = new ChatVertexAI({ temperature: 0, maxOutputTokens: 10, + modelName, }); const res = await model.invoke("Why is the sky blue? Be concise."); // console.log(res); @@ -384,6 +427,7 @@ describe("GAuth Gemini Chat", () => { const modelWithStreaming = new ChatVertexAI({ maxOutputTokens: 50, streaming: true, + modelName, }); let totalTokenCount = 0; @@ -407,7 +451,7 @@ describe("GAuth Gemini Chat", () => { test("Can force a model to invoke a tool", async () => { const model = new ChatVertexAI({ - model: "gemini-1.5-pro", + modelName, }); const modelWithTools = model.bind({ tools: [calculatorTool, weatherTool], @@ -425,8 +469,10 @@ describe("GAuth Gemini Chat", () => { expect(result.tool_calls?.[0].args).toHaveProperty("expression"); }); - test("ChatGoogleGenerativeAI can stream tools", async () => { - const model = new ChatVertexAI({}); + test(`stream tools`, async () => { + const model = new ChatVertexAI({ + modelName, + }); const weatherTool = tool( (_) => "The weather in San Francisco today is 18 degrees and sunny.", @@ -474,7 +520,7 @@ describe("GAuth Gemini Chat", () => { const audioMimeType = "audio/wav"; const model = new ChatVertexAI({ - model: "gemini-1.5-flash", + model: modelName, temperature: 0, maxRetries: 0, }); @@ -505,6 +551,65 @@ describe("GAuth Gemini Chat", () => { expect(typeof response.content).toBe("string"); expect((response.content as string).length).toBeGreaterThan(15); }); + + test("Supports GoogleSearchRetrievalTool", async () => { + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = new ChatVertexAI({ + modelName, + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + }); + + test("Supports GoogleSearchTool", async () => { + const searchTool: GeminiTool = { + googleSearch: {}, + }; + const model = new ChatVertexAI({ + modelName, + temperature: 0, + maxRetries: 0, + }).bindTools([searchTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + }); + + test("Can stream GoogleSearchRetrievalTool", async () => { + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = new ChatVertexAI({ + modelName, + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const stream = await model.stream("Who won the 2024 MLB World Series?"); + let finalMsg: AIMessageChunk | undefined; + for await (const msg of stream) { + finalMsg = finalMsg ? concat(finalMsg, msg) : msg; + } + if (!finalMsg) { + throw new Error("finalMsg is undefined"); + } + expect(finalMsg.content as string).toContain("Dodgers"); + }); }); describe("GAuth Anthropic Chat", () => { @@ -617,50 +722,3 @@ describe("GAuth Anthropic Chat", () => { expect(toolCalls?.[0].args).toHaveProperty("location"); }); }); - -describe("GoogleSearchRetrievalTool", () => { - test("Supports GoogleSearchRetrievalTool", async () => { - const searchRetrievalTool = { - googleSearchRetrieval: { - dynamicRetrievalConfig: { - mode: "MODE_DYNAMIC", - dynamicThreshold: 0.7, // default is 0.7 - }, - }, - }; - const model = new ChatVertexAI({ - model: "gemini-1.5-pro", - temperature: 0, - maxRetries: 0, - }).bindTools([searchRetrievalTool]); - - const result = await model.invoke("Who won the 2024 MLB World Series?"); - expect(result.content as string).toContain("Dodgers"); - }); - - test("Can stream GoogleSearchRetrievalTool", async () => { - const searchRetrievalTool = { - googleSearchRetrieval: { - dynamicRetrievalConfig: { - mode: "MODE_DYNAMIC", - dynamicThreshold: 0.7, // default is 0.7 - }, - }, - }; - const model = new ChatVertexAI({ - model: "gemini-1.5-pro", - temperature: 0, - maxRetries: 0, - }).bindTools([searchRetrievalTool]); - - const stream = await model.stream("Who won the 2024 MLB World Series?"); - let finalMsg: AIMessageChunk | undefined; - for await (const msg of stream) { - finalMsg = finalMsg ? concat(finalMsg, msg) : msg; - } - if (!finalMsg) { - throw new Error("finalMsg is undefined"); - } - expect(finalMsg.content as string).toContain("Dodgers"); - }); -}); diff --git a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts index 0e10359599b3..e66bab6f06ca 100644 --- a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts @@ -1,12 +1,13 @@ /* eslint-disable import/no-extraneous-dependencies */ -import { StructuredTool } from "@langchain/core/tools"; +import { StructuredTool, tool } from "@langchain/core/tools"; import { z } from "zod"; -import { test } from "@jest/globals"; +import { expect, test } from "@jest/globals"; import { AIMessage, AIMessageChunk, BaseMessage, BaseMessageChunk, + BaseMessageLike, HumanMessage, HumanMessageChunk, MessageContentComplex, @@ -19,7 +20,19 @@ import { MediaManager, SimpleWebBlobStore, } from "@langchain/google-common/experimental/utils/media_core"; -import { ChatGoogle } from "../chat_models.js"; +import { + GeminiTool, + GooglePlatformType, + GoogleRequestRecorder, +} from "@langchain/google-common"; +import { BaseCallbackHandler } from "@langchain/core/callbacks/base"; +import { concat } from "@langchain/core/utils/stream"; +import fs from "fs/promises"; +import { + ChatPromptTemplate, + MessagesPlaceholder, +} from "@langchain/core/prompts"; +import { ChatGoogle, ChatGoogleInput } from "../chat_models.js"; import { BlobStoreAIStudioFile } from "../media.js"; class WeatherTool extends StructuredTool { @@ -247,3 +260,596 @@ describe("Google APIKey Chat", () => { } }); }); + +const weatherTool = tool((_) => "no-op", { + name: "get_weather", + description: + "Get the weather of a specific location and return the temperature in Celsius.", + schema: z.object({ + location: z.string().describe("The name of city to get the weather for."), + }), +}); + +const calculatorTool = tool((_) => "no-op", { + name: "calculator", + description: "Calculate the result of a math expression.", + schema: z.object({ + expression: z.string().describe("The math expression to calculate."), + }), +}); + +/* + * Which models do we want to run the test suite against + * and on which platforms? + */ +const testGeminiModelNames = [ + { + modelName: "gemini-1.5-pro-002", + platformType: "gai", + apiVersion: "v1beta", + }, + { modelName: "gemini-1.5-pro-002", platformType: "gcp", apiVersion: "v1" }, + { + modelName: "gemini-1.5-flash-002", + platformType: "gai", + apiVersion: "v1beta", + }, + { modelName: "gemini-1.5-flash-002", platformType: "gcp", apiVersion: "v1" }, + { + modelName: "gemini-2.0-flash-exp", + platformType: "gai", + apiVersion: "v1beta", + }, + { modelName: "gemini-2.0-flash-exp", platformType: "gcp", apiVersion: "v1" }, + + // Flash Thinking doesn't have functions or other features + // {modelName: "gemini-2.0-flash-thinking-exp", platformType: "gai"}, + // {modelName: "gemini-2.0-flash-thinking-exp", platformType: "gcp"}, +]; + +/* + * Some models may have usage quotas still. + * For those models, set how long (in millis) to wait in between each test. + */ +const testGeminiModelDelay: Record = { + "gemini-2.0-flash-exp": 10000, + "gemini-2.0-flash-thinking-exp-1219": 10000, +}; + +describe.each(testGeminiModelNames)( + "Webauth ($platformType) Gemini Chat ($modelName)", + ({ modelName, platformType, apiVersion }) => { + let recorder: GoogleRequestRecorder; + let callbacks: BaseCallbackHandler[]; + + function newChatGoogle(fields?: ChatGoogleInput): ChatGoogle { + // const logger = new GoogleRequestLogger(); + recorder = new GoogleRequestRecorder(); + callbacks = [recorder]; + + return new ChatGoogle({ + modelName, + platformType: platformType as GooglePlatformType, + apiVersion, + callbacks, + ...(fields ?? {}), + }); + } + + beforeEach(async () => { + const delay = testGeminiModelDelay[modelName] ?? 0; + if (delay) { + console.log(`Delaying for ${delay}ms`); + // eslint-disable-next-line no-promise-executor-return + await new Promise((resolve) => setTimeout(resolve, delay)); + } + }); + + test("invoke", async () => { + const model = newChatGoogle(); + const res = await model.invoke("What is 1 + 1?"); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); + + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); + + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/); + + expect(res).toHaveProperty("response_metadata"); + expect(res.response_metadata).not.toHaveProperty("groundingMetadata"); + expect(res.response_metadata).not.toHaveProperty("groundingSupport"); + + console.log(recorder); + }); + + test(`generate`, async () => { + const model = newChatGoogle(); + const messages: BaseMessage[] = [ + new SystemMessage( + "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." + ), + new HumanMessage("Flip it"), + new AIMessage("T"), + new HumanMessage("Flip the coin again"), + ]; + const res = await model.predictMessages(messages); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); + + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); + + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(["H", "T"]).toContainEqual(text.trim()); + }); + + test("stream", async () => { + const model = newChatGoogle(); + const input: BaseLanguageModelInput = new ChatPromptValue([ + new SystemMessage( + "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." + ), + new HumanMessage("Flip it"), + new AIMessage("T"), + new HumanMessage("Flip the coin again"), + ]); + const res = await model.stream(input); + const resArray: BaseMessageChunk[] = []; + for await (const chunk of res) { + resArray.push(chunk); + } + expect(resArray).toBeDefined(); + expect(resArray.length).toBeGreaterThanOrEqual(1); + + const lastChunk = resArray[resArray.length - 1]; + expect(lastChunk).toBeDefined(); + expect(lastChunk._getType()).toEqual("ai"); + }); + + test("function", async () => { + const tools: GeminiTool[] = [ + { + functionDeclarations: [ + { + name: "test", + description: + "Run a test with a specific name and get if it passed or failed", + parameters: { + type: "object", + properties: { + testName: { + type: "string", + description: "The name of the test that should be run.", + }, + }, + required: ["testName"], + }, + }, + ], + }, + ]; + const model = newChatGoogle().bind({ + tools, + }); + const result = await model.invoke("Run a test on the cobalt project"); + expect(result).toHaveProperty("content"); + expect(result.content).toBe(""); + const args = result?.lc_kwargs?.additional_kwargs; + expect(args).toBeDefined(); + expect(args).toHaveProperty("tool_calls"); + expect(Array.isArray(args.tool_calls)).toBeTruthy(); + expect(args.tool_calls).toHaveLength(1); + const call = args.tool_calls[0]; + expect(call).toHaveProperty("type"); + expect(call.type).toBe("function"); + expect(call).toHaveProperty("function"); + const func = call.function; + expect(func).toBeDefined(); + expect(func).toHaveProperty("name"); + expect(func.name).toBe("test"); + expect(func).toHaveProperty("arguments"); + expect(typeof func.arguments).toBe("string"); + expect(func.arguments.replaceAll("\n", "")).toBe('{"testName":"cobalt"}'); + }); + + test("function reply", async () => { + const tools: GeminiTool[] = [ + { + functionDeclarations: [ + { + name: "test", + description: + "Run a test with a specific name and get if it passed or failed", + parameters: { + type: "object", + properties: { + testName: { + type: "string", + description: "The name of the test that should be run.", + }, + }, + required: ["testName"], + }, + }, + ], + }, + ]; + const model = newChatGoogle().bind({ + tools, + }); + const toolResult = { + testPassed: true, + }; + const messages: BaseMessageLike[] = [ + new HumanMessage("Run a test on the cobalt project."), + new AIMessage("", { + tool_calls: [ + { + id: "test", + type: "function", + function: { + name: "test", + arguments: '{"testName":"cobalt"}', + }, + }, + ], + }), + new ToolMessage(JSON.stringify(toolResult), "test"), + ]; + const res = await model.stream(messages); + const resArray: BaseMessageChunk[] = []; + for await (const chunk of res) { + resArray.push(chunk); + } + // console.log(JSON.stringify(resArray, null, 2)); + }); + + test("withStructuredOutput", async () => { + const tool = { + name: "get_weather", + description: + "Get the weather of a specific location and return the temperature in Celsius.", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The name of city to get the weather for.", + }, + }, + required: ["location"], + }, + }; + const model = newChatGoogle().withStructuredOutput(tool); + const result = await model.invoke("What is the weather in Paris?"); + expect(result).toHaveProperty("location"); + }); + + // test("media - fileData", async () => { + // class MemStore extends InMemoryStore { + // get length() { + // return Object.keys(this.store).length; + // } + // } + // const aliasMemory = new MemStore(); + // const aliasStore = new BackedBlobStore({ + // backingStore: aliasMemory, + // defaultFetchOptions: { + // actionIfBlobMissing: undefined, + // }, + // }); + // const backingStore = new BlobStoreGoogleCloudStorage({ + // uriPrefix: new GoogleCloudStorageUri( + // "gs://test-langchainjs/mediatest/" + // ), + // defaultStoreOptions: { + // actionIfInvalid: "prefixPath", + // }, + // }); + // const blobStore = new ReadThroughBlobStore({ + // baseStore: aliasStore, + // backingStore, + // }); + // const resolver = new SimpleWebBlobStore(); + // const mediaManager = new MediaManager({ + // store: blobStore, + // resolvers: [resolver], + // }); + // const model = newChatGoogle({ + // apiConfig: { + // mediaManager, + // }, + // }); + + // const message: MessageContentComplex[] = [ + // { + // type: "text", + // text: "What is in this image?", + // }, + // { + // type: "media", + // fileUri: "https://js.langchain.com/v0.2/img/brand/wordmark.png", + // }, + // ]; + + // const messages: BaseMessage[] = [ + // new HumanMessageChunk({ content: message }), + // ]; + + // try { + // const res = await model.invoke(messages); + + // console.log(res); + + // expect(res).toBeDefined(); + // expect(res._getType()).toEqual("ai"); + + // const aiMessage = res as AIMessageChunk; + // expect(aiMessage.content).toBeDefined(); + + // expect(typeof aiMessage.content).toBe("string"); + // const text = aiMessage.content as string; + // expect(text).toMatch(/LangChain/); + // } catch (e) { + // console.error(e); + // throw e; + // } + // }); + + test("Stream token count usage_metadata", async () => { + const model = newChatGoogle({ + temperature: 0, + maxOutputTokens: 10, + }); + let res: AIMessageChunk | null = null; + for await (const chunk of await model.stream( + "Why is the sky blue? Be concise." + )) { + if (!res) { + res = chunk; + } else { + res = res.concat(chunk); + } + } + // console.log(res); + expect(res?.usage_metadata).toBeDefined(); + if (!res?.usage_metadata) { + return; + } + expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.output_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.total_tokens).toBe( + res.usage_metadata.input_tokens + res.usage_metadata.output_tokens + ); + }); + + test("streamUsage excludes token usage", async () => { + const model = newChatGoogle({ + temperature: 0, + streamUsage: false, + }); + let res: AIMessageChunk | null = null; + for await (const chunk of await model.stream( + "Why is the sky blue? Be concise." + )) { + if (!res) { + res = chunk; + } else { + res = res.concat(chunk); + } + } + // console.log(res); + expect(res?.usage_metadata).not.toBeDefined(); + }); + + test("Invoke token count usage_metadata", async () => { + const model = newChatGoogle({ + temperature: 0, + maxOutputTokens: 10, + }); + const res = await model.invoke("Why is the sky blue? Be concise."); + // console.log(res); + expect(res?.usage_metadata).toBeDefined(); + if (!res?.usage_metadata) { + return; + } + expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.output_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.total_tokens).toBe( + res.usage_metadata.input_tokens + res.usage_metadata.output_tokens + ); + }); + + test("Streaming true constructor param will stream", async () => { + const modelWithStreaming = newChatGoogle({ + maxOutputTokens: 50, + streaming: true, + }); + + let totalTokenCount = 0; + let tokensString = ""; + const result = await modelWithStreaming.invoke("What is 1 + 1?", { + callbacks: [ + ...callbacks, + { + handleLLMNewToken: (tok) => { + totalTokenCount += 1; + tokensString += tok; + }, + }, + ], + }); + + expect(result).toBeDefined(); + expect(result.content).toBe(tokensString); + + expect(totalTokenCount).toBeGreaterThan(1); + }); + + test("Can force a model to invoke a tool", async () => { + const model = newChatGoogle(); + const modelWithTools = model.bind({ + tools: [calculatorTool, weatherTool], + tool_choice: "calculator", + }); + + const result = await modelWithTools.invoke( + "Whats the weather like in paris today? What's 1836 plus 7262?" + ); + + expect(result.tool_calls).toHaveLength(1); + expect(result.tool_calls?.[0]).toBeDefined(); + if (!result.tool_calls?.[0]) return; + expect(result.tool_calls?.[0].name).toBe("calculator"); + expect(result.tool_calls?.[0].args).toHaveProperty("expression"); + }); + + test(`stream tools`, async () => { + const model = newChatGoogle(); + + const weatherTool = tool( + (_) => "The weather in San Francisco today is 18 degrees and sunny.", + { + name: "current_weather_tool", + description: "Get the current weather for a given location.", + schema: z.object({ + location: z + .string() + .describe("The location to get the weather for."), + }), + } + ); + + const modelWithTools = model.bindTools([weatherTool]); + const stream = await modelWithTools.stream( + "Whats the weather like today in San Francisco?" + ); + let finalChunk: AIMessageChunk | undefined; + for await (const chunk of stream) { + finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk); + } + + expect(finalChunk).toBeDefined(); + if (!finalChunk) return; + + const toolCalls = finalChunk.tool_calls; + expect(toolCalls).toBeDefined(); + if (!toolCalls) { + throw new Error("tool_calls not in response"); + } + expect(toolCalls.length).toBe(1); + expect(toolCalls[0].name).toBe("current_weather_tool"); + expect(toolCalls[0].args).toHaveProperty("location"); + }); + + async function fileToBase64(filePath: string): Promise { + const fileData = await fs.readFile(filePath); + const base64String = Buffer.from(fileData).toString("base64"); + return base64String; + } + + test("Gemini can understand audio", async () => { + // Update this with the correct path to an audio file on your machine. + const audioPath = + "../langchain-google-genai/src/tests/data/gettysburg10.wav"; + const audioMimeType = "audio/wav"; + + const model = newChatGoogle({ + temperature: 0, + maxRetries: 0, + }); + + const audioBase64 = await fileToBase64(audioPath); + + const prompt = ChatPromptTemplate.fromMessages([ + new MessagesPlaceholder("audio"), + ]); + + const chain = prompt.pipe(model); + const response = await chain.invoke({ + audio: new HumanMessage({ + content: [ + { + type: "media", + mimeType: audioMimeType, + data: audioBase64, + }, + { + type: "text", + text: "Summarize the content in this audio. ALso, what is the speaker's tone?", + }, + ], + }), + }); + + expect(typeof response.content).toBe("string"); + expect((response.content as string).length).toBeGreaterThan(15); + }); + + test("Supports GoogleSearchRetrievalTool", async () => { + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = newChatGoogle({ + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + expect(result).toHaveProperty("response_metadata"); + expect(result.response_metadata).toHaveProperty("groundingMetadata"); + expect(result.response_metadata).toHaveProperty("groundingSupport"); + }); + + test("Supports GoogleSearchTool", async () => { + const searchTool: GeminiTool = { + googleSearch: {}, + }; + const model = newChatGoogle({ + temperature: 0, + maxRetries: 0, + }).bindTools([searchTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + expect(result).toHaveProperty("response_metadata"); + expect(result.response_metadata).toHaveProperty("groundingMetadata"); + expect(result.response_metadata).toHaveProperty("groundingSupport"); + }); + + test("Can stream GoogleSearchRetrievalTool", async () => { + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = newChatGoogle({ + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const stream = await model.stream("Who won the 2024 MLB World Series?"); + let finalMsg: AIMessageChunk | undefined; + for await (const msg of stream) { + finalMsg = finalMsg ? concat(finalMsg, msg) : msg; + } + if (!finalMsg) { + throw new Error("finalMsg is undefined"); + } + expect(finalMsg.content as string).toContain("Dodgers"); + }); + } +);