Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add session context for a user mongodb #7436

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 72 additions & 3 deletions libs/langchain-azure-cosmosdb/src/chat_histories/mongodb.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ export interface AzureCosmosDBMongoChatHistoryDBConfig {
readonly collectionName?: string;
}

export type ChatSessionMongo = {
id: string;
context: Record<string, unknown>;
};

const ID_KEY = "sessionId";
const ID_USER = "userId";

export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHistory {
lc_namespace = ["langchain", "stores", "message", "azurecosmosdb"];
Expand All @@ -33,6 +39,8 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis

private initPromise?: Promise<void>;

private context: Record<string, unknown> = {};

private readonly client: MongoClient | undefined;

private database: Db;
Expand All @@ -41,11 +49,14 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis

private sessionId: string;

private userId: string;

initialize: () => Promise<void>;

constructor(
dbConfig: AzureCosmosDBMongoChatHistoryDBConfig,
sessionId: string
sessionId: string,
userId: string
) {
super();

Expand All @@ -70,6 +81,7 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
const collectionName = dbConfig.collectionName ?? "chatHistory";

this.sessionId = sessionId;
this.userId = userId ?? "anonymous";

// Deferring initialization to the first call to `initialize`
this.initialize = () => {
Expand Down Expand Up @@ -120,6 +132,7 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis

const document = await this.collection.findOne({
[ID_KEY]: this.sessionId,
[ID_USER]: this.userId,
});
const messages = document?.messages || [];
return mapStoredMessagesToChatMessages(messages);
Expand All @@ -134,10 +147,12 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
await this.initialize();

const messages = mapChatMessagesToStoredMessages([message]);
const context = await this.getContext();
await this.collection.updateOne(
{ [ID_KEY]: this.sessionId },
{ [ID_KEY]: this.sessionId, [ID_USER]: this.userId },
{
$push: { messages: { $each: messages } } as PushOperator<Document>,
$set: { context },
},
{ upsert: true }
);
Expand All @@ -150,6 +165,60 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
async clear(): Promise<void> {
await this.initialize();

await this.collection.deleteOne({ [ID_KEY]: this.sessionId });
await this.collection.deleteOne({
[ID_KEY]: this.sessionId,
[ID_USER]: this.userId,
});
}

async getAllSessions(): Promise<ChatSessionMongo[]> {
await this.initialize();
const documents = await this.collection.find().toArray();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing user filter here


const chatSessions: ChatSessionMongo[] = documents.map((doc) => ({
id: doc[ID_KEY],
user_id: doc[ID_USER],
context: doc.context || {},
}));

return chatSessions;
}

async clearAllSessions() {
await this.initialize();
try {
await this.collection.deleteMany({});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would delete all sessions for all users, which might not be the intended behavior... I think you missed the user handling here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, add user id as handle for sessions, thank you for your comment!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you forgot to update this call to filter it for the current user only

} catch (error) {
console.error("Error clearing chat history sessions:", error);
throw error;
}
}

async getContext(): Promise<Record<string, unknown>> {
await this.initialize();

const document = await this.collection.findOne({
[ID_KEY]: this.sessionId,
[ID_USER]: this.userId,
});
this.context = document?.context || this.context;
return this.context;
}

async setContext(context: Record<string, unknown>): Promise<void> {
await this.initialize();

try {
await this.collection.updateOne(
{ [ID_KEY]: this.sessionId },
{
$set: { context },
},
{ upsert: true }
);
} catch (error) {
console.error("Error setting chat history context", error);
throw error;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ test("Test Azure Cosmos MongoDB history store", async () => {
};

const sessionId = new ObjectId().toString();
const userId = new ObjectId().toString();
const chatHistory = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId
sessionId,
userId
);

const blankResult = await chatHistory.getMessages();
Expand Down Expand Up @@ -70,9 +72,11 @@ test("Test clear Azure Cosmos MongoDB history store", async () => {
};

const sessionId = new ObjectId().toString();
const userId = new ObjectId().toString();
const chatHistory = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId
sessionId,
userId
);

await chatHistory.addUserMessage("Who is the best vocalist?");
Expand All @@ -93,3 +97,50 @@ test("Test clear Azure Cosmos MongoDB history store", async () => {

await mongoClient.close();
});

test("Test getAllSessions and clearAllSessions", async () => {
expect(process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING).toBeDefined();

// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const mongoClient = new MongoClient(
process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING!
);
const dbcfg: AzureCosmosDBMongoChatHistoryDBConfig = {
client: mongoClient,
connectionString: process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING,
databaseName: "langchain",
collectionName: "chathistory",
};

const sessionId1 = new ObjectId().toString();
const userId1 = new ObjectId().toString();
const sessionId2 = new ObjectId().toString();
const userId2 = new ObjectId().toString();

const chatHistory1 = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId1,
userId1
);
const chatHistory2 = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId2,
userId2
);

await chatHistory1.addUserMessage("What is AI?");
await chatHistory1.addAIChatMessage("AI stands for Artificial Intelligence.");
await chatHistory2.addUserMessage("What is the best programming language?");
await chatHistory2.addAIChatMessage("It depends on the use case.");

const allSessions = await chatHistory1.getAllSessions();
expect(allSessions.length).toBe(2);
expect(allSessions[0].id).toBe(sessionId1);
expect(allSessions[1].id).toBe(sessionId2);

await chatHistory1.clearAllSessions();
const clearedSessions = await chatHistory1.getAllSessions();
expect(clearedSessions.length).toBe(0);

await mongoClient.close();
});