diff --git a/README.md b/README.md index c7c22262..3b2a5fa6 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ A simple 'needle in a haystack' analysis to test in-context retrieval ability of long context LLMs. -Supported model providers: OpenAI, Anthropic +Supported model providers: OpenAI, Anthropic, Cohere Get the behind the scenes on the [overview video](https://youtu.be/KwRRuiCCdmc). @@ -46,9 +46,9 @@ pip install needlehaystack Start using the package by calling the entry point `needlehaystack.run_test` from command line. -You can then run the analysis on OpenAI or Anthropic models with the following command line arguments: +You can then run the analysis on OpenAI, Anthropic, or Cohere models with the following command line arguments: -- `provider` - The provider of the model, available options are `openai` and `anthropic`. Defaults to `openai` +- `provider` - The provider of the model, available options are `openai`, `anthropic`, and `cohere`. Defaults to `openai` - `evaluator` - The evaluator, which can either be a `model` or `LangSmith`. See more on `LangSmith` below. If using a `model`, only `openai` is currently supported. Defaults to `openai`. - `model_name` - Model name of the language model accessible by the provider. Defaults to `gpt-3.5-turbo-0125` - `evaluator_model_name` - Model name of the language model accessible by the evaluator. Defaults to `gpt-3.5-turbo-0125` @@ -69,6 +69,11 @@ Following command runs the test for anthropic model `claude-2.1` for a single co needlehaystack.run_test --provider anthropic --model_name "claude-2.1" --document_depth_percents "[50]" --context_lengths "[2000]" ``` +Following command runs the test for cohere model `command-r` for a single context length of 2000 and single document depth of 50%. + +```zsh +needlehaystack.run_test --provider cohere --model_name "command-r" --document_depth_percents "[50]" --context_lengths "[2000]" +``` ### For Contributors 1. Fork and clone the repository. diff --git a/needlehaystack/providers/__init__.py b/needlehaystack/providers/__init__.py index 0974e290..4e514513 100644 --- a/needlehaystack/providers/__init__.py +++ b/needlehaystack/providers/__init__.py @@ -1,3 +1,4 @@ from .anthropic import Anthropic +from .cohere import Cohere from .model import ModelProvider -from .openai import OpenAI \ No newline at end of file +from .openai import OpenAI diff --git a/needlehaystack/providers/cohere.py b/needlehaystack/providers/cohere.py new file mode 100644 index 00000000..c600c0dd --- /dev/null +++ b/needlehaystack/providers/cohere.py @@ -0,0 +1,113 @@ +import os +import pkg_resources + +from operator import itemgetter +from typing import Optional +from langchain.prompts import PromptTemplate +from langchain_cohere import ChatCohere + +from cohere import Client, AsyncClient + +from .model import ModelProvider + +class Cohere(ModelProvider): + DEFAULT_MODEL_KWARGS: dict = dict(max_tokens = 50, + temperature = 0.3) + + def __init__(self, + model_name: str = "command-r", + model_kwargs: dict = DEFAULT_MODEL_KWARGS): + """ + :param model_name: The name of the model. Default is 'command-r'. + :param model_kwargs: Model configuration. Default is {max_tokens_to_sample: 300, temperature: 0} + """ + + api_key = os.getenv('NIAH_MODEL_API_KEY') + if (not api_key): + raise ValueError("NIAH_MODEL_API_KEY must be in env.") + + self.model_name = model_name + self.model_kwargs = model_kwargs + self.api_key = api_key + + self.client = AsyncClient(api_key=self.api_key) + + async def evaluate_model(self, prompt: tuple[str, list[dict, str, str]]) -> str: + message, chat_history = prompt + response = await self.client.chat(message=message, chat_history=chat_history, model=self.model_name, **self.model_kwargs) + return response.text + + def generate_prompt(self, context: str, retrieval_question: str) -> tuple[str, list[dict[str, str]]]: + ''' + Prepares a chat-formatted prompt + Args: + context (str): The needle in a haystack context + retrieval_question (str): The needle retrieval question + + Returns: + tuple[str, list[dict[str, str]]]: prompt encoded as last message, and chat history + + ''' + return ( + f"{retrieval_question} Don't give information outside the document or repeat your findings", + [{ + "role": "System", + "message": "You are a helpful AI bot that answers questions for a user. Keep your response short and direct" + }, + { + "role": "User", + "message": context + }] + ) + + def encode_text_to_tokens(self, text: str) -> list[int]: + if not text: return [] + return Client().tokenize(text=text, model=self.model_name).tokens + + def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str: + # Assuming you have a different decoder for Anthropic + return Client().detokenize(tokens=tokens[:context_length], model=self.model_name).text + + def get_langchain_runnable(self, context: str): + """ + Creates a LangChain runnable that constructs a prompt based on a given context and a question. + + Args: + context (str): The context or background information relevant to the user's question. + This context is provided to the model to aid in generating relevant and accurate responses. + + Returns: + str: A LangChain runnable object that can be executed to obtain the model's response to a + dynamically provided question. The runnable encapsulates the entire process from prompt + generation to response retrieval. + + Example: + To use the runnable: + - Define the context and question. + - Execute the runnable with these parameters to get the model's response. + """ + + + template = """Human: You are a helpful AI bot that answers questions for a user. Keep your response short and direct" \n + + {context} + + Here is the user question: + + {question} + + Don't give information outside the document or repeat your findings. + Assistant: Here is the most relevant information in the documents:""" + + api_key = os.getenv('NIAH_MODEL_API_KEY') + model = ChatCohere(cohere_api_key=api_key, temperature=0.3, model=self.model_name) + prompt = PromptTemplate( + template=template, + input_variables=["context", "question"], + ) + chain = ( {"context": lambda x: context, + "question": itemgetter("question")} + | prompt + | model + ) + return chain diff --git a/needlehaystack/run.py b/needlehaystack/run.py index d38519bc..8edbccba 100644 --- a/needlehaystack/run.py +++ b/needlehaystack/run.py @@ -6,7 +6,7 @@ from . import LLMNeedleHaystackTester, LLMMultiNeedleHaystackTester from .evaluators import Evaluator, LangSmithEvaluator, OpenAIEvaluator -from .providers import Anthropic, ModelProvider, OpenAI +from .providers import Anthropic, ModelProvider, OpenAI, Cohere load_dotenv() @@ -63,6 +63,8 @@ def get_model_to_test(args: CommandArgs) -> ModelProvider: return OpenAI(model_name=args.model_name) case "anthropic": return Anthropic(model_name=args.model_name) + case "cohere": + return Cohere(model_name=args.model_name) case _: raise ValueError(f"Invalid provider: {args.provider}") @@ -109,4 +111,4 @@ def main(): tester.start_test() if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/requirements.txt b/requirements.txt index 4e47291d..610cb52a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ anyio==3.7.1 attrs==23.1.0 certifi==2023.11.17 charset-normalizer==3.3.2 +cohere>=5.1.2 dataclasses-json==0.6.3 distro==1.8.0 filelock==3.13.1 @@ -21,10 +22,11 @@ jsonpatch==1.33 jsonpointer==2.4 langchain==0.1.9 langchain-community>=0.0.24 -langchain-core==0.1.26 +langchain-core>=0.1.26 langsmith>=0.1.8 langchain_openai langchain_anthropic +langchain_cohere marshmallow==3.20.1 multidict==6.0.4 mypy-extensions==1.0.0