Skip to content

Commit

Permalink
implemented bindTools
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jun 27, 2024
1 parent 562f208 commit 35c13c1
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 3 deletions.
3 changes: 2 additions & 1 deletion libs/langchain-aws/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@
"release-it": "^15.10.1",
"rollup": "^4.5.2",
"ts-jest": "^29.1.0",
"typescript": "<5.2.0"
"typescript": "<5.2.0",
"zod": "^3.22.4"
},
"publishConfig": {
"access": "public"
Expand Down
19 changes: 19 additions & 0 deletions libs/langchain-aws/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { AIMessageChunk } from "@langchain/core/messages";
import type {
ToolDefinition,
BaseLanguageModelCallOptions,
BaseLanguageModelInput,
} from "@langchain/core/language_models/base";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
Expand All @@ -26,6 +27,7 @@ import {
} from "@aws-sdk/credential-provider-node";
import type { DocumentType as __DocumentType } from "@smithy/types";
import { StructuredToolInterface } from "@langchain/core/tools";
import { Runnable } from "@langchain/core/runnables";
import {
BedrockToolChoice,
ConverseCommandParams,
Expand Down Expand Up @@ -229,6 +231,23 @@ export class ChatBedrockConverse
this.additionalModelRequestFields = rest?.additionalModelRequestFields;
}

override bindTools(
tools: (
| StructuredToolInterface
| BedrockTool
| ToolDefinition
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>
)[],
kwargs?: Partial<this["ParsedCallOptions"]>
): Runnable<
BaseLanguageModelInput,
AIMessageChunk,
this["ParsedCallOptions"]
> {
return this.bind({ tools: convertToConverseTools(tools), ...kwargs });
}

// Replace
_llmType() {
return "chat_bedrock_converse";
Expand Down
8 changes: 7 additions & 1 deletion libs/langchain-aws/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,13 @@ export function isBedrockTool(tool: unknown): tool is BedrockTool {
}

export function convertToConverseTools(
tools: (StructuredToolInterface | ToolDefinition | BedrockTool)[]
tools: (
| StructuredToolInterface
| ToolDefinition
| BedrockTool
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>
)[]
): BedrockTool[] {
if (tools.every(isOpenAITool)) {
return tools.map((tool) => ({
Expand Down
35 changes: 34 additions & 1 deletion libs/langchain-aws/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
/* eslint-disable no-process-env */
import { test, expect } from "@jest/globals";
import { AIMessageChunk, HumanMessage } from "@langchain/core/messages";
import { tool } from "@langchain/core/tools";
import { z } from "zod";
import { ChatBedrockConverse } from "../chat_models.js";

const baseConstructorArgs: Partial<
Expand All @@ -25,7 +27,7 @@ test("Test ChatBedrockConverse can invoke", async () => {
expect(res.content).not.toContain("world");
});

test.only("Test ChatBedrockConverse stream method", async () => {
test("Test ChatBedrockConverse stream method", async () => {
const model = new ChatBedrockConverse({
...baseConstructorArgs,
maxTokens: 50,
Expand Down Expand Up @@ -161,3 +163,34 @@ test("populates ID field on AIMessage", async () => {
expect(finalChunk?.id?.length).toBeGreaterThan(1);
expect(finalChunk?.id?.startsWith("chatcmpl-")).toBe(true);
});

test.only("Test ChatBedrockConverse can invoke tools", async () => {
const model = new ChatBedrockConverse({
...baseConstructorArgs,
});
const tools = [
tool(
(input) => {
console.log("tool", input);
return "Hello";
},
{
name: "get_weather",
description: "Get the weather",
schema: z.object({
location: z.string().describe("Location to get the weather for"),
}),
}
),
];
const modelWithTools = model.bindTools(tools);
const result = await modelWithTools.invoke([
new HumanMessage("Get the weather for London"),
]);

expect(result.tool_calls).toBeDefined();
expect(result.tool_calls).toHaveLength(1);
console.log("result.tool_calls?.[0]", result.tool_calls?.[0]);
expect(result.tool_calls?.[0].name).toBe("get_weather");
expect(result.tool_calls?.[0].id).toBeDefined();
});
1 change: 1 addition & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -9967,6 +9967,7 @@ __metadata:
rollup: ^4.5.2
ts-jest: ^29.1.0
typescript: <5.2.0
zod: ^3.22.4
zod-to-json-schema: ^3.22.5
languageName: unknown
linkType: soft
Expand Down

0 comments on commit 35c13c1

Please sign in to comment.