Skip to content

Commit

Permalink
feat(pinecone): Add support for Pinecone /embed endpoint (#7203)
Browse files Browse the repository at this point in the history
Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
aulorbe and jacoblee93 authored Nov 17, 2024
1 parent ae80cbf commit 762ed46
Show file tree
Hide file tree
Showing 20 changed files with 691 additions and 82 deletions.
344 changes: 344 additions & 0 deletions docs/core_docs/docs/integrations/text_embedding/pinecone.ipynb

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions examples/src/embeddings/pinecone.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { PineconeEmbeddings } from "@langchain/pinecone";

/**
 * Example entry point: builds a `PineconeEmbeddings` instance with default
 * settings, logs it, then embeds a single query string and logs the vector.
 */
export const run = async () => {
  const embedder = new PineconeEmbeddings();
  // Prints out model metadata
  console.log({ model: embedder });

  const question =
    "What would be a good company name a company that makes colorful socks?";
  const vector = await embedder.embedQuery(question);
  console.log({ res: vector });
};

await run();
2 changes: 1 addition & 1 deletion examples/src/indexes/vector_stores/pinecone/delete_docs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { OpenAIEmbeddings } from "@langchain/openai";
import { PineconeStore } from "@langchain/pinecone";

// Instantiate a new Pinecone client, which will automatically read the
// env vars: PINECONE_API_KEY and PINECONE_ENVIRONMENT which come from
// env var PINECONE_API_KEY, which comes from
// the Pinecone dashboard at https://app.pinecone.io

const pinecone = new Pinecone();
Expand Down
3 changes: 2 additions & 1 deletion examples/src/indexes/vector_stores/pinecone/index_docs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ import { Pinecone } from "@pinecone-database/pinecone";
import { Document } from "@langchain/core/documents";
import { OpenAIEmbeddings } from "@langchain/openai";
import { PineconeStore } from "@langchain/pinecone";
// import { Index } from "@upstash/vector";

// Instantiate a new Pinecone client, which will automatically read the
// env vars: PINECONE_API_KEY and PINECONE_ENVIRONMENT which come from
// env var PINECONE_API_KEY, which comes from
// the Pinecone dashboard at https://app.pinecone.io

const pinecone = new Pinecone();
Expand Down
2 changes: 1 addition & 1 deletion examples/src/indexes/vector_stores/pinecone/mmr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { OpenAIEmbeddings } from "@langchain/openai";
import { PineconeStore } from "@langchain/pinecone";

// Instantiate a new Pinecone client, which will automatically read the
// env vars: PINECONE_API_KEY and PINECONE_ENVIRONMENT which come from
// env var PINECONE_API_KEY, which comes from
// the Pinecone dashboard at https://app.pinecone.io

const pinecone = new Pinecone();
Expand Down
2 changes: 1 addition & 1 deletion examples/src/indexes/vector_stores/pinecone/query_docs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { OpenAIEmbeddings } from "@langchain/openai";
import { PineconeStore } from "@langchain/pinecone";

// Instantiate a new Pinecone client, which will automatically read the
// env vars: PINECONE_API_KEY and PINECONE_ENVIRONMENT which come from
// env var PINECONE_API_KEY, which comes from
// the Pinecone dashboard at https://app.pinecone.io

const pinecone = new Pinecone();
Expand Down
10 changes: 2 additions & 8 deletions examples/src/retrievers/pinecone_self_query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,8 @@ const attributeInfo: AttributeInfo[] = [
* Next, we instantiate a vector store. This is where we store the embeddings of the documents.
* We also need to provide an embeddings object. This is used to embed the documents.
*/
if (
!process.env.PINECONE_API_KEY ||
!process.env.PINECONE_ENVIRONMENT ||
!process.env.PINECONE_INDEX
) {
throw new Error(
"PINECONE_ENVIRONMENT and PINECONE_API_KEY and PINECONE_INDEX must be set"
);
if (!process.env.PINECONE_API_KEY || !process.env.PINECONE_INDEX) {
throw new Error("PINECONE_API_KEY and PINECONE_INDEX must be set");
}

const pinecone = new Pinecone();
Expand Down
1 change: 0 additions & 1 deletion libs/langchain-community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@
"@neondatabase/serverless": "^0.9.1",
"@notionhq/client": "^2.2.10",
"@opensearch-project/opensearch": "^2.2.0",
"@pinecone-database/pinecone": "^1.1.0",
"@planetscale/database": "^1.8.0",
"@premai/prem-sdk": "^0.3.25",
"@qdrant/js-client-rest": "^1.8.2",
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain-pinecone/jest.config.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ module.exports = {
setupFiles: ["dotenv/config"],
testTimeout: 20_000,
passWithNoTests: true,
collectCoverageFrom: ["src/**/*.ts"]
};
collectCoverageFrom: ["src/**/*.ts"],
};
2 changes: 1 addition & 1 deletion libs/langchain-pinecone/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"author": "Pinecone, Inc",
"license": "MIT",
"dependencies": {
"@pinecone-database/pinecone": "^3.0.0 || ^4.0.0",
"@pinecone-database/pinecone": "^4.0.0",
"flat": "^5.0.2",
"uuid": "^10.0.0"
},
Expand Down
16 changes: 16 additions & 0 deletions libs/langchain-pinecone/src/client.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { Pinecone, PineconeConfiguration } from "@pinecone-database/pinecone";
import { getEnvironmentVariable } from "@langchain/core/utils/env";

/**
 * Create a Pinecone client.
 *
 * An API key must be available either on `config.apiKey` or in the
 * `PINECONE_API_KEY` environment variable. An explicitly supplied key takes
 * precedence, so callers that pass `apiKey` do not also need the environment
 * variable to be set.
 *
 * @param config - Optional Pinecone client configuration. When omitted, the
 *   `Pinecone` constructor is called with no arguments.
 * @returns A new `Pinecone` client instance.
 * @throws Error if no API key is passed in `config` and `PINECONE_API_KEY` is
 *   unset or empty.
 */
export function getPineconeClient(config?: PineconeConfiguration): Pinecone {
  // An explicit key in the config is sufficient on its own.
  if (config?.apiKey) {
    return new Pinecone(config);
  }

  const envApiKey = getEnvironmentVariable("PINECONE_API_KEY");
  if (envApiKey === undefined || envApiKey === "") {
    throw new Error("PINECONE_API_KEY must be set in environment");
  }

  return config ? new Pinecone(config) : new Pinecone();
}
139 changes: 139 additions & 0 deletions libs/langchain-pinecone/src/embeddings.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/* eslint-disable arrow-body-style */

import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import {
EmbeddingsList,
Pinecone,
PineconeConfiguration,
} from "@pinecone-database/pinecone";
import { getPineconeClient } from "./client.js";

/**
 * PineconeEmbeddingsParams holds the optional fields a user can pass to a
 * Pinecone embedding model.
 *
 * Note: the entries of `params` are model-specific. Read more about
 * model-specific parameters in the [Pinecone
 * documentation](https://docs.pinecone.io/guides/inference/understanding-inference#model-specific-parameters).
 */
export interface PineconeEmbeddingsParams extends EmbeddingsParams {
/** Model to use to generate embeddings. Default is "multilingual-e5-large". */
model?: string;
/** Additional parameters to pass to the embedding model. */
params?: Record<string, string>;
}

/**
 * PineconeEmbeddings generates embeddings using the Pinecone Inference API.
 *
 * Documents are always embedded with `inputType: "passage"` and queries with
 * `inputType: "query"`; any other entries of `params` are forwarded to the
 * embed call unchanged.
 */
export class PineconeEmbeddings
  extends Embeddings
  implements PineconeEmbeddingsParams
{
  /** Pinecone client used to call the inference endpoint. */
  client: Pinecone;

  /** Name of the embedding model; defaults to "multilingual-e5-large". */
  model: string;

  /**
   * Parameters forwarded to the embed call. `inputType` is initialised to
   * "passage" and is set to "query" once `embedQuery` has been called.
   */
  params: Record<string, string>;

  /**
   * @param fields - Optional embedding settings (`model`, `params`, and the
   *   base `EmbeddingsParams`) plus Pinecone client configuration.
   *   `maxRetries` defaults to 3. The client configuration fields are only
   *   used when `apiKey` is provided; otherwise the client is created without
   *   an explicit configuration.
   */
  constructor(
    fields?: Partial<PineconeEmbeddingsParams> & Partial<PineconeConfiguration>
  ) {
    const defaultFields = { maxRetries: 3, ...fields };
    super(defaultFields);

    if (defaultFields.apiKey) {
      const config: PineconeConfiguration = {
        apiKey: defaultFields.apiKey,
        controllerHostUrl: defaultFields.controllerHostUrl,
        fetchApi: defaultFields.fetchApi,
        additionalHeaders: defaultFields.additionalHeaders,
        sourceTag: defaultFields.sourceTag,
      };
      this.client = getPineconeClient(config);
    } else {
      this.client = getPineconeClient();
    }

    this.model = defaultFields.model || "multilingual-e5-large";

    // The default input type intentionally wins over a user-supplied
    // `inputType`, so documents are embedded as passages.
    this.params = { ...defaultFields.params, inputType: "passage" };
  }

  /**
   * Generate embeddings for a list of input strings using the configured
   * embedding model.
   *
   * @param texts - List of input strings for which to generate embeddings.
   * @returns One vector per returned embedding that carries values.
   * @throws Error if `texts` is empty.
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    if (texts.length === 0) {
      throw new Error(
        "At least one document is required to generate embeddings"
      );
    }

    // Build per-request params instead of reading `this.params.inputType`
    // directly: `embedQuery` leaves that field set to "query", which would
    // otherwise make every later document be embedded as a query.
    const requestParams = { ...this.params, inputType: "passage" };

    const embeddings: EmbeddingsList = await this.caller.call(() =>
      this.client.inference.embed(this.model, texts, requestParams)
    );

    const embeddingsList: number[][] = [];
    for (const embedding of embeddings) {
      if (embedding.values) {
        embeddingsList.push(embedding.values);
      }
    }
    return embeddingsList;
  }

  /**
   * Generate an embedding for a given query string using the configured
   * embedding model.
   *
   * @param text - Query string for which to generate embeddings.
   * @returns The embedding vector, or an empty array if the response carries
   *   no values.
   * @throws Error if `text` is empty.
   */
  async embedQuery(text: string): Promise<number[]> {
    // Validate first so that a rejected call does not alter instance state.
    if (!text) {
      throw new Error("No query passed for which to generate embeddings");
    }

    // Kept for backward compatibility: callers can observe
    // `params.inputType === "query"` after a query has been embedded.
    this.params.inputType = "query";
    const requestParams = { ...this.params, inputType: "query" };

    const embeddings: EmbeddingsList = await this.caller.call(() =>
      this.client.inference.embed(this.model, [text], requestParams)
    );

    // Guard against an empty response as well as an embedding without values.
    return embeddings[0]?.values ?? [];
  }
}
1 change: 1 addition & 0 deletions libs/langchain-pinecone/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
export * from "./vectorstores.js";
export * from "./translator.js";
export * from "./embeddings.js";
39 changes: 39 additions & 0 deletions libs/langchain-pinecone/src/tests/client.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { Pinecone } from "@pinecone-database/pinecone";
import { getPineconeClient } from "../client.js";

// Integration tests for getPineconeClient; they expect PINECONE_API_KEY to be
// set in the environment.
describe("Tests for getPineconeClient", () => {
  test("Happy path for getPineconeClient with and without `config` obj passed", async () => {
    const client = getPineconeClient();
    expect(client).toBeInstanceOf(Pinecone);
    expect(client).toHaveProperty("config"); // Config is always set to *at least* the user's api key

    const clientWithConfig = getPineconeClient({
      // eslint-disable-next-line no-process-env
      apiKey: process.env.PINECONE_API_KEY!,
      additionalHeaders: { header: "value" },
    });
    expect(clientWithConfig).toBeInstanceOf(Pinecone);
    // Assert on the client that was built from the explicit config.
    // Unfortunately cannot assert on contents of config b/c it's a private
    // attribute of the Pinecone class
    expect(clientWithConfig).toHaveProperty("config");
  });

  test("Unhappy path: expect getPineconeClient to throw error if reset PINECONE_API_KEY to empty string", async () => {
    // eslint-disable-next-line no-process-env
    const originalApiKey = process.env.PINECONE_API_KEY;
    try {
      // eslint-disable-next-line no-process-env
      process.env.PINECONE_API_KEY = "";
      // getPineconeClient throws synchronously, so assert on the call itself.
      expect(() => getPineconeClient()).toThrow(Error);
      expect(() => getPineconeClient()).toThrow(
        "PINECONE_API_KEY must be set in environment"
      );
    } finally {
      // Restore the original value of PINECONE_API_KEY. Assigning `undefined`
      // to an env var stores the string "undefined", so delete it instead.
      if (originalApiKey === undefined) {
        // eslint-disable-next-line no-process-env
        delete process.env.PINECONE_API_KEY;
      } else {
        // eslint-disable-next-line no-process-env
        process.env.PINECONE_API_KEY = originalApiKey;
      }
    }
  });
});
15 changes: 15 additions & 0 deletions libs/langchain-pinecone/src/tests/client.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { getPineconeClient } from "../client.js";

// Unit test for getPineconeClient's API-key validation.
describe("Tests for getPineconeClient", () => {
  test("Confirm getPineconeClient throws error when PINECONE_API_KEY is not set", async () => {
    /* eslint-disable-next-line no-process-env */
    const originalApiKey = process.env.PINECONE_API_KEY;
    try {
      /* eslint-disable-next-line no-process-env */
      process.env.PINECONE_API_KEY = "";
      // getPineconeClient throws synchronously, so assert on the call itself.
      expect(() => getPineconeClient()).toThrow(Error);
      expect(() => getPineconeClient()).toThrow(
        "PINECONE_API_KEY must be set in environment"
      );
    } finally {
      // Put the environment back so the cleared key does not leak into other
      // tests. Assigning `undefined` would store the string "undefined".
      if (originalApiKey === undefined) {
        /* eslint-disable-next-line no-process-env */
        delete process.env.PINECONE_API_KEY;
      } else {
        /* eslint-disable-next-line no-process-env */
        process.env.PINECONE_API_KEY = originalApiKey;
      }
    }
  });
});
59 changes: 59 additions & 0 deletions libs/langchain-pinecone/src/tests/embeddings.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import { PineconeEmbeddings } from "../embeddings.js";

// Integration tests for PineconeEmbeddings; they call the Pinecone Inference
// API and expect PINECONE_API_KEY to be set in the environment.
describe("Integration tests for Pinecone embeddings", () => {
  test("Happy path: defaults for both embedDocuments and embedQuery", async () => {
    const model = new PineconeEmbeddings();
    expect(model.model).toBe("multilingual-e5-large");
    expect(model.params).toEqual({ inputType: "passage" });

    const docs = ["hello", "world"];
    const embeddings = await model.embedDocuments(docs);
    expect(embeddings.length).toBe(docs.length);

    const query = "hello";
    const queryEmbedding = await model.embedQuery(query);
    expect(queryEmbedding.length).toBeGreaterThan(0);
  });

  test("Happy path: custom `params` obj passed to embedDocuments and embedQuery", async () => {
    const model = new PineconeEmbeddings({
      params: { customParam: "value" },
    });
    expect(model.model).toBe("multilingual-e5-large");
    expect(model.params).toEqual({
      inputType: "passage",
      customParam: "value",
    });

    const docs = ["hello", "world"];
    const embeddings = await model.embedDocuments(docs);
    expect(embeddings.length).toBe(docs.length);
    expect(embeddings[0].length).toBe(1024); // Assert correct dims on random doc
    expect(model.model).toBe("multilingual-e5-large");
    expect(model.params).toEqual({
      inputType: "passage", // Maintain default inputType for docs
      customParam: "value",
    });

    const query = "hello";
    const queryEmbedding = await model.embedQuery(query);
    expect(model.model).toBe("multilingual-e5-large");
    expect(queryEmbedding.length).toBe(1024);
    expect(model.params).toEqual({
      inputType: "query", // Change inputType for query
      customParam: "value",
    });
  });

  test("Unhappy path: embedDocuments and embedQuery throw when empty objs are passed", async () => {
    const model = new PineconeEmbeddings();
    await expect(model.embedDocuments([])).rejects.toThrow();
    await expect(model.embedQuery("")).rejects.toThrow();
  });

  test("Unhappy path: PineconeEmbeddings throws when invalid model is passed", async () => {
    // Retries are disabled so the expected API failure surfaces immediately
    // instead of being retried until the test times out.
    const model = new PineconeEmbeddings({
      model: "invalid-model",
      maxRetries: 0,
    });
    // Use non-empty inputs: empty inputs are rejected by local validation
    // before the model name is ever sent to Pinecone, which would make this
    // test pass without exercising the invalid model at all.
    await expect(model.embedDocuments(["hello"])).rejects.toThrow();
    await expect(model.embedQuery("hello")).rejects.toThrow();
  });
});
Loading

0 comments on commit 762ed46

Please sign in to comment.