From 39ef80982e2df3795bfba3885a78aef52f9f8824 Mon Sep 17 00:00:00 2001 From: arkadyark-cohere Date: Mon, 25 Mar 2024 19:54:32 +0000 Subject: [PATCH 1/3] Add cohere model provider --- needlehaystack/providers/__init__.py | 3 +- needlehaystack/providers/cohere.py | 57 ++++++++++++++++++++++++++++ needlehaystack/run.py | 6 ++- requirements.txt | 3 +- 4 files changed, 65 insertions(+), 4 deletions(-) create mode 100644 needlehaystack/providers/cohere.py diff --git a/needlehaystack/providers/__init__.py b/needlehaystack/providers/__init__.py index 0974e290..4e514513 100644 --- a/needlehaystack/providers/__init__.py +++ b/needlehaystack/providers/__init__.py @@ -1,3 +1,4 @@ from .anthropic import Anthropic +from .cohere import Cohere from .model import ModelProvider -from .openai import OpenAI \ No newline at end of file +from .openai import OpenAI diff --git a/needlehaystack/providers/cohere.py b/needlehaystack/providers/cohere.py new file mode 100644 index 00000000..076dc234 --- /dev/null +++ b/needlehaystack/providers/cohere.py @@ -0,0 +1,57 @@ +import os +import pkg_resources + +from operator import itemgetter +from typing import Optional + +from cohere import Client, AsyncClient + +from .model import ModelProvider + +class Cohere(ModelProvider): + DEFAULT_MODEL_KWARGS: dict = dict(max_tokens = 50, + temperature = 0.3) + + def __init__(self, + model_name: str = "command-r", + model_kwargs: dict = DEFAULT_MODEL_KWARGS): + """ + :param model_name: The name of the model. Default is 'command-r'. + :param model_kwargs: Model configuration. 
Default is {max_tokens: 50, temperature: 0.3} """ + + api_key = os.getenv('NIAH_MODEL_API_KEY') + if (not api_key): + raise ValueError("NIAH_MODEL_API_KEY must be in env.") + + self.model_name = model_name + self.model_kwargs = model_kwargs + self.api_key = api_key + + self.client = AsyncClient(api_key=self.api_key) + + async def evaluate_model(self, prompt: str) -> str: + response = await self.client.chat(message=prompt[-1]["message"], chat_history=prompt[:-1], model=self.model_name, **self.model_kwargs) + return response.text + + def generate_prompt(self, context: str, retrieval_question: str) -> str | list[dict[str, str]]: + return [{ + "role": "System", + "message": "You are a helpful AI bot that answers questions for a user. Keep your response short and direct" + }, + { + "role": "User", + "message": context + }, + { + "role": "User", + "message": f"{retrieval_question} Don't give information outside the document or repeat your findings" + }] + + def encode_text_to_tokens(self, text: str) -> list[int]: + if not text: return [] + return Client().tokenize(text=text, model=self.model_name).tokens + + def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str: + # Assuming you have a different decoder for Anthropic + return Client().detokenize(tokens=tokens[:context_length], model=self.model_name).text diff --git a/needlehaystack/run.py b/needlehaystack/run.py index d38519bc..8edbccba 100644 --- a/needlehaystack/run.py +++ b/needlehaystack/run.py @@ -6,7 +6,7 @@ from . 
import LLMNeedleHaystackTester, LLMMultiNeedleHaystackTester from .evaluators import Evaluator, LangSmithEvaluator, OpenAIEvaluator -from .providers import Anthropic, ModelProvider, OpenAI +from .providers import Anthropic, ModelProvider, OpenAI, Cohere load_dotenv() @@ -63,6 +63,8 @@ def get_model_to_test(args: CommandArgs) -> ModelProvider: return OpenAI(model_name=args.model_name) case "anthropic": return Anthropic(model_name=args.model_name) + case "cohere": + return Cohere(model_name=args.model_name) case _: raise ValueError(f"Invalid provider: {args.provider}") @@ -109,4 +111,4 @@ def main(): tester.start_test() if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/requirements.txt b/requirements.txt index e281cdc6..5d044d16 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ anyio==3.7.1 attrs==23.1.0 certifi==2023.11.17 charset-normalizer==3.3.2 +cohere>=5.0.0 dataclasses-json==0.6.3 distro==1.8.0 filelock==3.13.1 @@ -46,4 +47,4 @@ tqdm==4.66.1 typing-inspect==0.9.0 typing_extensions==4.8.0 urllib3==2.1.0 -yarl==1.9.3 \ No newline at end of file +yarl==1.9.3 From 8af85a8feee4becffab1c57ac02967410992e291 Mon Sep 17 00:00:00 2001 From: arkadyark-cohere Date: Mon, 25 Mar 2024 19:57:09 +0000 Subject: [PATCH 2/3] Update README.md --- README.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c7c22262..3b2a5fa6 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ A simple 'needle in a haystack' analysis to test in-context retrieval ability of long context LLMs. -Supported model providers: OpenAI, Anthropic +Supported model providers: OpenAI, Anthropic, Cohere Get the behind the scenes on the [overview video](https://youtu.be/KwRRuiCCdmc). @@ -46,9 +46,9 @@ pip install needlehaystack Start using the package by calling the entry point `needlehaystack.run_test` from command line. 
-You can then run the analysis on OpenAI or Anthropic models with the following command line arguments: +You can then run the analysis on OpenAI, Anthropic, or Cohere models with the following command line arguments: -- `provider` - The provider of the model, available options are `openai` and `anthropic`. Defaults to `openai` +- `provider` - The provider of the model, available options are `openai`, `anthropic`, and `cohere`. Defaults to `openai` - `evaluator` - The evaluator, which can either be a `model` or `LangSmith`. See more on `LangSmith` below. If using a `model`, only `openai` is currently supported. Defaults to `openai`. - `model_name` - Model name of the language model accessible by the provider. Defaults to `gpt-3.5-turbo-0125` - `evaluator_model_name` - Model name of the language model accessible by the evaluator. Defaults to `gpt-3.5-turbo-0125` @@ -69,6 +69,11 @@ Following command runs the test for anthropic model `claude-2.1` for a single co needlehaystack.run_test --provider anthropic --model_name "claude-2.1" --document_depth_percents "[50]" --context_lengths "[2000]" ``` +Following command runs the test for cohere model `command-r` for a single context length of 2000 and single document depth of 50%. + +```zsh +needlehaystack.run_test --provider cohere --model_name "command-r" --document_depth_percents "[50]" --context_lengths "[2000]" +``` ### For Contributors 1. Fork and clone the repository. 
From 64698fdda989c8641e618233a447ab52d87ca6ee Mon Sep 17 00:00:00 2001 From: arkadyark-cohere Date: Tue, 26 Mar 2024 13:07:40 +0000 Subject: [PATCH 3/3] Update requirements, clean up cohere SDK code, add langchain to support multi-needle --- needlehaystack/providers/cohere.py | 72 ++++++++++++++++++++++++++---- requirements.txt | 5 ++- 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/needlehaystack/providers/cohere.py b/needlehaystack/providers/cohere.py index 076dc234..c600c0dd 100644 --- a/needlehaystack/providers/cohere.py +++ b/needlehaystack/providers/cohere.py @@ -3,6 +3,8 @@ from operator import itemgetter from typing import Optional +from langchain.prompts import PromptTemplate +from langchain_cohere import ChatCohere from cohere import Client, AsyncClient @@ -30,23 +32,33 @@ def __init__(self, self.client = AsyncClient(api_key=self.api_key) - async def evaluate_model(self, prompt: str) -> str: - response = await self.client.chat(message=prompt[-1]["message"], chat_history=prompt[:-1], model=self.model_name, **self.model_kwargs) + async def evaluate_model(self, prompt: tuple[str, list[dict, str, str]]) -> str: + message, chat_history = prompt + response = await self.client.chat(message=message, chat_history=chat_history, model=self.model_name, **self.model_kwargs) return response.text - def generate_prompt(self, context: str, retrieval_question: str) -> str | list[dict[str, str]]: - return [{ + def generate_prompt(self, context: str, retrieval_question: str) -> tuple[str, list[dict[str, str]]]: + ''' + Prepares a chat-formatted prompt + Args: + context (str): The needle in a haystack context + retrieval_question (str): The needle retrieval question + + Returns: + tuple[str, list[dict[str, str]]]: prompt encoded as last message, and chat history + + ''' + return ( + f"{retrieval_question} Don't give information outside the document or repeat your findings", + [{ "role": "System", "message": "You are a helpful AI bot that answers questions for 
a user. Keep your response short and direct" }, { "role": "User", "message": context - }, - { - "role": "User", - "message": f"{retrieval_question} Don't give information outside the document or repeat your findings" }] + ) def encode_text_to_tokens(self, text: str) -> list[int]: if not text: return [] @@ -55,3 +67,47 @@ def encode_text_to_tokens(self, text: str) -> list[int]: def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str: # Assuming you have a different decoder for Anthropic return Client().detokenize(tokens=tokens[:context_length], model=self.model_name).text + + def get_langchain_runnable(self, context: str): + """ + Creates a LangChain runnable that constructs a prompt based on a given context and a question. + + Args: + context (str): The context or background information relevant to the user's question. + This context is provided to the model to aid in generating relevant and accurate responses. + + Returns: + Runnable: A LangChain runnable object that can be executed to obtain the model's response to a + dynamically provided question. The runnable encapsulates the entire process from prompt + generation to response retrieval. + + Example: + To use the runnable: + - Define the context and question. + - Execute the runnable with these parameters to get the model's response. + """ + + + template = """Human: You are a helpful AI bot that answers questions for a user. Keep your response short and direct \n + + {context} + + Here is the user question: + + {question} + + Don't give information outside the document or repeat your findings. 
+ Assistant: Here is the most relevant information in the documents:""" + + api_key = os.getenv('NIAH_MODEL_API_KEY') + model = ChatCohere(cohere_api_key=api_key, temperature=0.3, model=self.model_name) + prompt = PromptTemplate( + template=template, + input_variables=["context", "question"], + ) + chain = ( {"context": lambda x: context, + "question": itemgetter("question")} + | prompt + | model + ) + return chain diff --git a/requirements.txt b/requirements.txt index 5d044d16..a6709d25 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ anyio==3.7.1 attrs==23.1.0 certifi==2023.11.17 charset-normalizer==3.3.2 -cohere>=5.0.0 +cohere>=5.1.2 dataclasses-json==0.6.3 distro==1.8.0 filelock==3.13.1 @@ -22,10 +22,11 @@ jsonpatch==1.33 jsonpointer==2.4 langchain==0.1.9 langchain-community>=0.0.24 -langchain-core==0.1.26 +langchain-core>=0.1.26 langsmith>=0.1.8 langchain_openai langchain_anthropic +langchain_cohere marshmallow==3.20.1 multidict==6.0.4 mypy-extensions==1.0.0