Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

azure-cosmosdb[minor] add session context and retrieve all sessions for a user #7242

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions examples/src/memory/azure_cosmosdb_nosql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,29 @@ const res1 = await chainWithHistory.invoke(
);
console.log({ res1 });
/*
{ res1: 'Hi Jim! How can I assist you today?' }
{ res1: 'Hi Jim! How can I assist you today?' }
*/

const res2 = await chainWithHistory.invoke(
{ input: "What did I just say my name was?" },
{ configurable: { sessionId: "langchain-test-session" } }
);
console.log({ res2 });

/*
{ res2: { response: 'You said your name was Jim.' }
*/
*/

// Give this session a title
const chatHistory = (await chainWithHistory.getMessageHistory(
"langchain-test-session"
)) as AzureCosmsosDBNoSQLChatMessageHistory;
chatHistory.setContext({ title: "Introducing Jim" });

// List all session for the user
const sessions = await chatHistory.getAllSessions();
console.log(sessions);
/*
[
{ sessionId: 'langchain-test-session', context: { title: "Introducing Jim" } }
]
*/
6 changes: 3 additions & 3 deletions libs/langchain-azure-cosmosdb/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@azure/cosmos": "4.0.1-beta.3",
"@azure/identity": "^4.2.0",
"mongodb": "^6.8.0"
"@azure/cosmos": "^4.2.0",
"@azure/identity": "^4.5.0",
"mongodb": "^6.10.0"
},
"peerDependencies": {
"@langchain/core": ">=0.2.21 <0.4.0"
Expand Down
6 changes: 4 additions & 2 deletions libs/langchain-azure-cosmosdb/src/azure_cosmosdb_nosql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import {
IndexingPolicy,
SqlParameter,
SqlQuerySpec,
VectorEmbedding,
VectorEmbeddingPolicy,
VectorIndex,
} from "@azure/cosmos";
import { DefaultAzureCredential, TokenCredential } from "@azure/identity";

Expand Down Expand Up @@ -186,7 +188,7 @@ export class AzureCosmosDBNoSQLVectorStore extends VectorStore {
distanceFunction: "cosine",
// Will be determined automatically during initialization
dimensions: 0,
},
} as VectorEmbedding,
];
}

Expand All @@ -195,7 +197,7 @@ export class AzureCosmosDBNoSQLVectorStore extends VectorStore {
{
path: "/vector",
type: "quantizedFlat",
},
} as VectorIndex,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe as const works but not a huge deal

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried it but it still complained 😞

];
}

Expand Down
62 changes: 57 additions & 5 deletions libs/langchain-azure-cosmosdb/src/chat_histories.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import { Container, CosmosClient, CosmosClientOptions } from "@azure/cosmos";
import {
Container,
CosmosClient,
CosmosClientOptions,
ErrorResponse,
} from "@azure/cosmos";
import { DefaultAzureCredential, TokenCredential } from "@azure/identity";
import { BaseListChatMessageHistory } from "@langchain/core/chat_history";
import {
Expand All @@ -12,6 +17,14 @@ const USER_AGENT_SUFFIX = "langchainjs-cdbnosql-chathistory-javascript";
const DEFAULT_DATABASE_NAME = "chatHistoryDB";
const DEFAULT_CONTAINER_NAME = "chatHistoryContainer";

/**
* Lightweight type for listing chat sessions.
*/
export type ChatSession = {
id: string;
context: Record<string, unknown>;
};

/**
* Type for the input to the `AzureCosmosDBNoSQLChatMessageHistory` constructor.
*/
Expand Down Expand Up @@ -68,7 +81,6 @@ export interface AzureCosmosDBNoSQLChatMessageHistoryInput {
* );
* ```
*/

export class AzureCosmsosDBNoSQLChatMessageHistory extends BaseListChatMessageHistory {
lc_namespace = ["langchain", "stores", "message", "azurecosmosdb"];

Expand All @@ -90,6 +102,8 @@ export class AzureCosmsosDBNoSQLChatMessageHistory extends BaseListChatMessageHi

private initPromise?: Promise<void>;

private context: Record<string, unknown> = {};

constructor(chatHistoryInput: AzureCosmosDBNoSQLChatMessageHistoryInput) {
super();

Expand Down Expand Up @@ -175,9 +189,11 @@ export class AzureCosmsosDBNoSQLChatMessageHistory extends BaseListChatMessageHi
this.messageList = await this.getMessages();
this.messageList.push(message);
const messages = mapChatMessagesToStoredMessages(this.messageList);
const context = await this.getContext();
await this.container.items.upsert({
id: this.sessionId,
userId: this.userId,
context,
messages,
});
}
Expand All @@ -188,17 +204,53 @@ export class AzureCosmsosDBNoSQLChatMessageHistory extends BaseListChatMessageHi
await this.container.item(this.sessionId, this.userId).delete();
}

async clearAllSessionsForUser(userId: string) {
async clearAllSessions() {
await this.initializeContainer();
const query = {
query: "SELECT c.id FROM c WHERE c.userId = @userId",
parameters: [{ name: "@userId", value: userId }],
parameters: [{ name: "@userId", value: this.userId }],
};
const { resources: userSessions } = await this.container.items
.query(query)
.fetchAll();
for (const userSession of userSessions) {
await this.container.item(userSession.id, userId).delete();
await this.container.item(userSession.id, this.userId).delete();
}
}

async getAllSessions(): Promise<ChatSession[]> {
await this.initializeContainer();
const query = {
query: "SELECT c.id, c.context FROM c WHERE c.userId = @userId",
parameters: [{ name: "@userId", value: this.userId }],
};
const { resources: userSessions } = await this.container.items
.query(query)
.fetchAll();
return userSessions ?? [];
}

async getContext(): Promise<Record<string, unknown>> {
const document = await this.container
.item(this.sessionId, this.userId)
.read();
this.context = document.resource?.context || this.context;
return this.context;
}

async setContext(context: Record<string, unknown>): Promise<void> {
await this.initializeContainer();
this.context = context || {};
try {
await this.container
.item(this.sessionId, this.userId)
.patch([{ op: "replace", path: "/context", value: this.context }]);
} catch (_error: unknown) {
const error = _error as ErrorResponse;
// If document does not exist yet, context will be set when adding the first message
if (error?.code !== 404) {
throw error;
}
}
}
}
3 changes: 2 additions & 1 deletion libs/langchain-azure-cosmosdb/src/tests/caches.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import {
CosmosClient,
IndexingMode,
VectorEmbedding,
VectorEmbeddingPolicy,
} from "@azure/cosmos";
import { DefaultAzureCredential } from "@azure/identity";
Expand Down Expand Up @@ -33,7 +34,7 @@ function vectorEmbeddingPolicy(
dataType: "float32",
distanceFunction,
dimensions: dimension,
},
} as VectorEmbedding,
],
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,41 @@ test("Test clear all sessions for a user", async () => {
const result2 = await chatHistory1.getMessages();
expect(result2).toEqual(expectedMessages);

await chatHistory1.clearAllSessionsForUser("user1");
await chatHistory1.clearAllSessions();

const deletedResult1 = await chatHistory1.getMessages();
const deletedResult2 = await chatHistory2.getMessages();
expect(deletedResult1).toStrictEqual([]);
expect(deletedResult2).toStrictEqual([]);
});

test("Test set context and get all sessions for a user", async () => {
const session1 = {
userId: "user1",
databaseName: DATABASE_NAME,
containerName: CONTAINER_NAME,
sessionId: new ObjectId().toString(),
};
const context1 = { title: "Best vocalist" };
const chatHistory1 = new AzureCosmsosDBNoSQLChatMessageHistory(session1);

await chatHistory1.setContext(context1);
await chatHistory1.addUserMessage("Who is the best vocalist?");
await chatHistory1.addAIMessage("Ozzy Osbourne");

const chatHistory2 = new AzureCosmsosDBNoSQLChatMessageHistory({
...session1,
sessionId: new ObjectId().toString(),
});
const context2 = { title: "Best guitarist" };

await chatHistory2.addUserMessage("Who is the best guitarist?");
await chatHistory2.addAIMessage("Jimi Hendrix");
await chatHistory2.setContext(context2);

const sessions = await chatHistory1.getAllSessions();

expect(sessions.length).toBe(2);
expect(sessions[0].context).toEqual(context1);
expect(sessions[1].context).toEqual(context2);
});
Loading