Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
crisjy committed Jan 12, 2025
1 parent 46c2ab3 commit 4fdbd6a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 12 deletions.
29 changes: 21 additions & 8 deletions libs/langchain-azure-cosmosdb/src/chat_histories/mongodb.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ export interface AzureCosmosDBMongoChatHistoryDBConfig {
readonly collectionName?: string;
}

export type ChatSession = {
export type ChatSessionMongo = {
id: string;
context: Record<string, unknown>;
};

const ID_KEY = "sessionId";
const ID_USER = "userId";

export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHistory {
lc_namespace = ["langchain", "stores", "message", "azurecosmosdb"];
Expand All @@ -48,11 +49,14 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis

private sessionId: string;

private userId: string;

initialize: () => Promise<void>;

constructor(
dbConfig: AzureCosmosDBMongoChatHistoryDBConfig,
sessionId: string
sessionId: string,
userId: string
) {
super();

Expand All @@ -77,6 +81,7 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
const collectionName = dbConfig.collectionName ?? "chatHistory";

this.sessionId = sessionId;
this.userId = userId ?? "anonymous";

// Deferring initialization to the first call to `initialize`
this.initialize = () => {
Expand Down Expand Up @@ -127,6 +132,7 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis

const document = await this.collection.findOne({
[ID_KEY]: this.sessionId,
[ID_USER]: this.userId,
});
const messages = document?.messages || [];
return mapStoredMessagesToChatMessages(messages);
Expand All @@ -143,7 +149,7 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
const messages = mapChatMessagesToStoredMessages([message]);
const context = await this.getContext();
await this.collection.updateOne(
{ [ID_KEY]: this.sessionId },
{ [ID_KEY]: this.sessionId, [ID_USER]: this.userId },
{
$push: { messages: { $each: messages } } as PushOperator<Document>,
$set: { context },
Expand All @@ -159,15 +165,19 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
async clear(): Promise<void> {
await this.initialize();

await this.collection.deleteOne({ [ID_KEY]: this.sessionId });
await this.collection.deleteOne({
[ID_KEY]: this.sessionId,
[ID_USER]: this.userId,
});
}

async getAllSessions(): Promise<ChatSession[]> {
async getAllSessions(): Promise<ChatSessionMongo[]> {
await this.initialize();
const documents = await this.collection.find().toArray();

const chatSessions: ChatSession[] = documents.map((doc) => ({
const chatSessions: ChatSessionMongo[] = documents.map((doc) => ({
id: doc[ID_KEY],
user_id: doc[ID_USER],
context: doc.context || {},
}));

Expand All @@ -179,7 +189,8 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
try {
await this.collection.deleteMany({});
} catch (error) {
console.log("Error clearing sessions:", error);
console.error("Error clearing chat history sessions:", error);
throw error;
}
}

Expand All @@ -188,6 +199,7 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis

const document = await this.collection.findOne({
[ID_KEY]: this.sessionId,
[ID_USER]: this.userId,
});
this.context = document?.context || this.context;
return this.context;
Expand All @@ -205,7 +217,8 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
{ upsert: true }
);
} catch (error) {
console.log("Error setting context", error);
console.error("Error setting chat history context", error);
throw error;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ test("Test Azure Cosmos MongoDB history store", async () => {
};

const sessionId = new ObjectId().toString();
const userId = new ObjectId().toString();
const chatHistory = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId
sessionId,
userId
);

const blankResult = await chatHistory.getMessages();
Expand Down Expand Up @@ -70,9 +72,11 @@ test("Test clear Azure Cosmos MongoDB history store", async () => {
};

const sessionId = new ObjectId().toString();
const userId = new ObjectId().toString();
const chatHistory = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId
sessionId,
userId
);

await chatHistory.addUserMessage("Who is the best vocalist?");
Expand Down Expand Up @@ -109,15 +113,19 @@ test("Test getAllSessions and clearAllSessions", async () => {
};

const sessionId1 = new ObjectId().toString();
const userId1 = new ObjectId().toString();
const sessionId2 = new ObjectId().toString();
const userId2 = new ObjectId().toString();

const chatHistory1 = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId1
sessionId1,
userId1
);
const chatHistory2 = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId2
sessionId2,
userId2
);

await chatHistory1.addUserMessage("What is AI?");
Expand Down

0 comments on commit 4fdbd6a

Please sign in to comment.