fix(community): Upgrade node-llama-cpp to be compatible with version 3 (#7135)

Co-authored-by: Jacky Chen <[email protected]>
rd4cake and Jacky3003 authored Nov 12, 2024
1 parent a1530da commit 2a7a2b8
Showing 22 changed files with 908 additions and 465 deletions.
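
The core API change in this upgrade is that node-llama-cpp 3 loads models asynchronously (via `getLlama()`), so every LlamaCpp wrapper in `@langchain/community` is now created through a static `initialize()` factory instead of a synchronous constructor. A minimal before/after sketch based on the updated examples in this commit; the model path is a placeholder and the question string is illustrative:

```typescript
import { ChatLlamaCpp } from "@langchain/community/chat_models/llama_cpp";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

// node-llama-cpp@2 (before): synchronous constructor.
// const model = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.5 });

// node-llama-cpp@3 (after): async static factory that loads the model first.
const model = await ChatLlamaCpp.initialize({
  modelPath: llamaPath,
  temperature: 0.5,
});

const response = await model.invoke("Where do llamas come from?");
console.log(response.content);
```

The same pattern applies to `LlamaCpp` (LLM) and `LlamaCppEmbeddings`, as the example diffs below show.
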
4 changes: 2 additions & 2 deletions docs/core_docs/docs/integrations/chat/llama_cpp.mdx
@@ -12,14 +12,14 @@ This module is based on the [node-llama-cpp](https://github.com/withcatai/node-l

## Setup

-You'll need to install major version `2` of the [node-llama-cpp](https://github.com/withcatai/node-llama-cpp) module to communicate with your local model.
+You'll need to install major version `3` of the [node-llama-cpp](https://github.com/withcatai/node-llama-cpp) module to communicate with your local model.

import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx";

<IntegrationInstallTooltip></IntegrationInstallTooltip>

```bash npm2yarn
-npm install -S node-llama-cpp@2 @langchain/community @langchain/core
+npm install -S node-llama-cpp@3 @langchain/community @langchain/core
```

You will also need a local Llama 2 model (or a model supported by [node-llama-cpp](https://github.com/withcatai/node-llama-cpp)). You will need to pass the path to this model to the LlamaCpp module as a part of the parameters (see example).
4 changes: 2 additions & 2 deletions docs/core_docs/docs/integrations/llms/llama_cpp.mdx
@@ -12,10 +12,10 @@ This module is based on the [node-llama-cpp](https://github.com/withcatai/node-l

## Setup

-You'll need to install major version `2` of the [node-llama-cpp](https://github.com/withcatai/node-llama-cpp) module to communicate with your local model.
+You'll need to install major version `3` of the [node-llama-cpp](https://github.com/withcatai/node-llama-cpp) module to communicate with your local model.

```bash npm2yarn
-npm install -S node-llama-cpp@2
+npm install -S node-llama-cpp@3
```

import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx";
4 changes: 2 additions & 2 deletions docs/core_docs/docs/integrations/text_embedding/llama_cpp.mdx
@@ -12,10 +12,10 @@ This module is based on the [node-llama-cpp](https://github.com/withcatai/node-l

## Setup

-You'll need to install major version `2` of the [node-llama-cpp](https://github.com/withcatai/node-llama-cpp) module to communicate with your local model.
+You'll need to install major version `3` of the [node-llama-cpp](https://github.com/withcatai/node-llama-cpp) module to communicate with your local model.

```bash npm2yarn
-npm install -S node-llama-cpp@2
+npm install -S node-llama-cpp@3
```

import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx";
2 changes: 1 addition & 1 deletion examples/src/embeddings/llama_cpp_basic.ts
@@ -2,7 +2,7 @@ import { LlamaCppEmbeddings } from "@langchain/community/embeddings/llama_cpp";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

-const embeddings = new LlamaCppEmbeddings({
+const embeddings = await LlamaCppEmbeddings.initialize({
modelPath: llamaPath,
});

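
The embeddings wrapper follows the same factory pattern in version 3. A short sketch of end-to-end usage; the call to `embedQuery` and the query text are assumptions for illustration, only the `initialize()` call comes from this diff:

```typescript
import { LlamaCppEmbeddings } from "@langchain/community/embeddings/llama_cpp";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

// v3: the embeddings wrapper is also created through the async factory.
const embeddings = await LlamaCppEmbeddings.initialize({
  modelPath: llamaPath,
});

// Embed a single query string (standard Embeddings interface usage).
const vector = await embeddings.embedQuery("Hello Llama!");
console.log(vector.length);
```
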
2 changes: 1 addition & 1 deletion examples/src/embeddings/llama_cpp_docs.ts
@@ -4,7 +4,7 @@ const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const documents = ["Hello World!", "Bye Bye!"];

-const embeddings = new LlamaCppEmbeddings({
+const embeddings = await LlamaCppEmbeddings.initialize({
modelPath: llamaPath,
});

2 changes: 1 addition & 1 deletion examples/src/models/chat/integration_llama_cpp.ts
@@ -3,7 +3,7 @@ import { HumanMessage } from "@langchain/core/messages";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

-const model = new ChatLlamaCpp({ modelPath: llamaPath });
+const model = await ChatLlamaCpp.initialize({ modelPath: llamaPath });

const response = await model.invoke([
new HumanMessage({ content: "My name is John." }),
5 changes: 4 additions & 1 deletion examples/src/models/chat/integration_llama_cpp_chain.ts
@@ -4,7 +4,10 @@ import { PromptTemplate } from "@langchain/core/prompts";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

-const model = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.5 });
+const model = await ChatLlamaCpp.initialize({
+modelPath: llamaPath,
+temperature: 0.5,
+});

const prompt = PromptTemplate.fromTemplate(
"What is a good name for a company that makes {product}?"
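
The remainder of this chain example is not shown in the loaded hunk. A minimal sketch of how the updated model could be composed with the prompt, assuming the usual LCEL `.pipe()` composition; only the `initialize()` call and the prompt template are taken from this diff:

```typescript
import { ChatLlamaCpp } from "@langchain/community/chat_models/llama_cpp";
import { PromptTemplate } from "@langchain/core/prompts";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const model = await ChatLlamaCpp.initialize({
  modelPath: llamaPath,
  temperature: 0.5,
});

const prompt = PromptTemplate.fromTemplate(
  "What is a good name for a company that makes {product}?"
);

// Compose prompt and model into a runnable chain and invoke it
// (the chain shape and the product value are assumptions for illustration).
const chain = prompt.pipe(model);
const response = await chain.invoke({ product: "colorful socks" });
console.log(response.content);
```
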
5 changes: 4 additions & 1 deletion examples/src/models/chat/integration_llama_cpp_stream.ts
@@ -2,7 +2,10 @@ import { ChatLlamaCpp } from "@langchain/community/chat_models/llama_cpp";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

-const model = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.7 });
+const model = await ChatLlamaCpp.initialize({
+modelPath: llamaPath,
+temperature: 0.7,
+});

const stream = await model.stream("Tell me a short story about a happy Llama.");

@@ -3,7 +3,10 @@ import { SystemMessage, HumanMessage } from "@langchain/core/messages";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

-const model = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.7 });
+const model = await ChatLlamaCpp.initialize({
+modelPath: llamaPath,
+temperature: 0.7,
+});

const controller = new AbortController();

@@ -3,7 +3,10 @@ import { SystemMessage, HumanMessage } from "@langchain/core/messages";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

-const llamaCpp = new ChatLlamaCpp({ modelPath: llamaPath, temperature: 0.7 });
+const llamaCpp = await ChatLlamaCpp.initialize({
+modelPath: llamaPath,
+temperature: 0.7,
+});

const stream = await llamaCpp.stream([
new SystemMessage(
2 changes: 1 addition & 1 deletion examples/src/models/chat/integration_llama_cpp_system.ts
@@ -3,7 +3,7 @@ import { SystemMessage, HumanMessage } from "@langchain/core/messages";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

-const model = new ChatLlamaCpp({ modelPath: llamaPath });
+const model = await ChatLlamaCpp.initialize({ modelPath: llamaPath });

const response = await model.invoke([
new SystemMessage(
2 changes: 1 addition & 1 deletion examples/src/models/llm/llama_cpp.ts
@@ -3,7 +3,7 @@ import { LlamaCpp } from "@langchain/community/llms/llama_cpp";
const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";
const question = "Where do Llamas come from?";

-const model = new LlamaCpp({ modelPath: llamaPath });
+const model = await LlamaCpp.initialize({ modelPath: llamaPath });

console.log(`You: ${question}`);
const response = await model.invoke(question);
5 changes: 4 additions & 1 deletion examples/src/models/llm/llama_cpp_stream.ts
@@ -2,7 +2,10 @@ import { LlamaCpp } from "@langchain/community/llms/llama_cpp";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

-const model = new LlamaCpp({ modelPath: llamaPath, temperature: 0.7 });
+const model = await LlamaCpp.initialize({
+modelPath: llamaPath,
+temperature: 0.7,
+});

const prompt = "Tell me a short story about a happy Llama.";

2 changes: 1 addition & 1 deletion libs/langchain-community/package.json
@@ -187,7 +187,7 @@
"mongodb": "^5.2.0",
"mysql2": "^3.9.8",
"neo4j-driver": "^5.17.0",
"node-llama-cpp": "^2",
"node-llama-cpp": "3.1.1",
"notion-to-md": "^3.1.0",
"officeparser": "^4.0.4",
"pdf-parse": "1.1.1",
93 changes: 59 additions & 34 deletions libs/langchain-community/src/chat_models/llama_cpp.ts
@@ -3,7 +3,11 @@ import {
LlamaModel,
LlamaContext,
LlamaChatSession,
-type ConversationInteraction,
+type Token,
+ChatUserMessage,
+ChatModelResponse,
+ChatHistoryItem,
+getLlama,
} from "node-llama-cpp";

import {
@@ -47,7 +51,7 @@ export interface LlamaCppCallOptions extends BaseLanguageModelCallOptions {
* @example
* ```typescript
* // Initialize the ChatLlamaCpp model with the path to the model binary file.
-* const model = new ChatLlamaCpp({
+* const model = await ChatLlamaCpp.initialize({
* modelPath: "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin",
* temperature: 0.5,
* });
@@ -87,20 +91,35 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
return "ChatLlamaCpp";
}

-constructor(inputs: LlamaCppInputs) {
+public constructor(inputs: LlamaCppInputs) {
super(inputs);
this.maxTokens = inputs?.maxTokens;
this.temperature = inputs?.temperature;
this.topK = inputs?.topK;
this.topP = inputs?.topP;
this.trimWhitespaceSuffix = inputs?.trimWhitespaceSuffix;
-this._model = createLlamaModel(inputs);
-this._context = createLlamaContext(this._model, inputs);
this._session = null;
}

+/**
+* Initializes the llama_cpp model for usage in the chat models wrapper.
+* @param inputs - the inputs passed onto the model.
+* @returns A Promise that resolves to the ChatLlamaCpp type class.
+*/
+public static async initialize(
+inputs: LlamaBaseCppInputs
+): Promise<ChatLlamaCpp> {
+const instance = new ChatLlamaCpp(inputs);
+const llama = await getLlama();
+
+instance._model = await createLlamaModel(inputs, llama);
+instance._context = await createLlamaContext(instance._model, inputs);
+
+return instance;
+}
+
_llmType() {
return "llama2_cpp";
return "llama_cpp";
}

/** @ignore */
@@ -146,7 +165,9 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
signal: options.signal,
onToken: async (tokens: number[]) => {
options.onToken?.(tokens);
-await runManager?.handleLLMNewToken(this._context.decode(tokens));
+await runManager?.handleLLMNewToken(
+this._model.detokenize(tokens.map((num) => num as Token))
+);
},
maxTokens: this?.maxTokens,
temperature: this?.temperature,
@@ -180,20 +201,23 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
};

const prompt = this._buildPrompt(input);
+const sequence = this._context.getSequence();

const stream = await this.caller.call(async () =>
-this._context.evaluate(this._context.encode(prompt), promptOptions)
+sequence.evaluate(this._model.tokenize(prompt), promptOptions)
);

for await (const chunk of stream) {
yield new ChatGenerationChunk({
-text: this._context.decode([chunk]),
+text: this._model.detokenize([chunk]),
message: new AIMessageChunk({
-content: this._context.decode([chunk]),
+content: this._model.detokenize([chunk]),
}),
generationInfo: {},
});
-await runManager?.handleLLMNewToken(this._context.decode([chunk]) ?? "");
+await runManager?.handleLLMNewToken(
+this._model.detokenize([chunk]) ?? ""
+);
}
}

@@ -202,12 +226,12 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
let prompt = "";
let sysMessage = "";
let noSystemMessages: BaseMessage[] = [];
-let interactions: ConversationInteraction[] = [];
+let interactions: ChatHistoryItem[] = [];

// Let's see if we have a system message
-if (messages.findIndex((msg) => msg._getType() === "system") !== -1) {
+if (messages.findIndex((msg) => msg.getType() === "system") !== -1) {
const sysMessages = messages.filter(
(message) => message._getType() === "system"
(message) => message.getType() === "system"
);

const systemMessageContent = sysMessages[sysMessages.length - 1].content;
@@ -222,7 +246,7 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {

// Now filter out the system messages
noSystemMessages = messages.filter(
(message) => message._getType() !== "system"
(message) => message.getType() !== "system"
);
} else {
noSystemMessages = messages;
@@ -231,9 +255,7 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
// Lets see if we just have a prompt left or are their previous interactions?
if (noSystemMessages.length > 1) {
// Is the last message a prompt?
-if (
-noSystemMessages[noSystemMessages.length - 1]._getType() === "human"
-) {
+if (noSystemMessages[noSystemMessages.length - 1].getType() === "human") {
const finalMessageContent =
noSystemMessages[noSystemMessages.length - 1].content;
if (typeof finalMessageContent !== "string") {
@@ -261,23 +283,23 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
// Now lets construct a session according to what we got
if (sysMessage !== "" && interactions.length > 0) {
this._session = new LlamaChatSession({
-context: this._context,
-conversationHistory: interactions,
+contextSequence: this._context.getSequence(),
systemPrompt: sysMessage,
});
+this._session.setChatHistory(interactions);
} else if (sysMessage !== "" && interactions.length === 0) {
this._session = new LlamaChatSession({
-context: this._context,
+contextSequence: this._context.getSequence(),
systemPrompt: sysMessage,
});
} else if (sysMessage === "" && interactions.length > 0) {
this._session = new LlamaChatSession({
-context: this._context,
-conversationHistory: interactions,
+contextSequence: this._context.getSequence(),
});
+this._session.setChatHistory(interactions);
} else {
this._session = new LlamaChatSession({
-context: this._context,
+contextSequence: this._context.getSequence(),
});
}

@@ -287,8 +309,8 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
// This builds a an array of interactions
protected _convertMessagesToInteractions(
messages: BaseMessage[]
-): ConversationInteraction[] {
-const result: ConversationInteraction[] = [];
+): ChatHistoryItem[] {
+const result: ChatHistoryItem[] = [];

for (let i = 0; i < messages.length; i += 2) {
if (i + 1 < messages.length) {
@@ -299,10 +321,13 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
"ChatLlamaCpp does not support non-string message content."
);
}
-result.push({
-prompt,
-response,
-});
+const llamaPrompt: ChatUserMessage = { type: "user", text: prompt };
+const llamaResponse: ChatModelResponse = {
+type: "model",
+response: [response],
+};
+result.push(llamaPrompt);
+result.push(llamaResponse);
}
}

@@ -313,19 +338,19 @@ export class ChatLlamaCpp extends SimpleChatModel<LlamaCppCallOptions> {
const prompt = input
.map((message) => {
let messageText;
if (message._getType() === "human") {
if (message.getType() === "human") {
messageText = `[INST] ${message.content} [/INST]`;
} else if (message._getType() === "ai") {
} else if (message.getType() === "ai") {
messageText = message.content;
} else if (message._getType() === "system") {
} else if (message.getType() === "system") {
messageText = `<<SYS>> ${message.content} <</SYS>>`;
} else if (ChatMessage.isInstance(message)) {
messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice(
1
)}: ${message.content}`;
} else {
console.warn(
-`Unsupported message type passed to llama_cpp: "${message._getType()}"`
+`Unsupported message type passed to llama_cpp: "${message.getType()}"`
);
messageText = "";
}
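
Besides the async factory, this file replaces node-llama-cpp 2's `ConversationInteraction` history with version 3's `ChatHistoryItem` entries (`ChatUserMessage` and `ChatModelResponse`), applies them through `session.setChatHistory()` instead of a `conversationHistory` option, and switches from `context.encode()`/`context.decode()` to `model.tokenize()`/`model.detokenize()`. From the caller's side the conversion stays internal; a short usage sketch in which the message contents are illustrative:

```typescript
import { ChatLlamaCpp } from "@langchain/community/chat_models/llama_cpp";
import {
  AIMessage,
  HumanMessage,
  SystemMessage,
} from "@langchain/core/messages";

const llamaPath = "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin";

const model = await ChatLlamaCpp.initialize({ modelPath: llamaPath });

// Earlier turns are passed as ordinary LangChain messages; internally each
// human/AI pair becomes a ChatUserMessage ({ type: "user", text }) followed by
// a ChatModelResponse ({ type: "model", response: [...] }) and is applied to
// the LlamaChatSession via setChatHistory().
const response = await model.invoke([
  new SystemMessage("You are a concise assistant."),
  new HumanMessage("My name is John."),
  new AIMessage("Nice to meet you, John!"),
  new HumanMessage("What is my name?"),
]);

console.log(response.content);
```
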
