diff --git a/README.md b/README.md
index c7c22262..3b2a5fa6 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
A simple 'needle in a haystack' analysis to test in-context retrieval ability of long context LLMs.
-Supported model providers: OpenAI, Anthropic
+Supported model providers: OpenAI, Anthropic, Cohere
Get the behind the scenes on the [overview video](https://youtu.be/KwRRuiCCdmc).
@@ -46,9 +46,9 @@ pip install needlehaystack
Start using the package by calling the entry point `needlehaystack.run_test` from command line.
-You can then run the analysis on OpenAI or Anthropic models with the following command line arguments:
+You can then run the analysis on OpenAI, Anthropic, or Cohere models with the following command line arguments:
-- `provider` - The provider of the model, available options are `openai` and `anthropic`. Defaults to `openai`
+- `provider` - The provider of the model, available options are `openai`, `anthropic`, and `cohere`. Defaults to `openai`
- `evaluator` - The evaluator, which can either be a `model` or `LangSmith`. See more on `LangSmith` below. If using a `model`, only `openai` is currently supported. Defaults to `openai`.
- `model_name` - Model name of the language model accessible by the provider. Defaults to `gpt-3.5-turbo-0125`
- `evaluator_model_name` - Model name of the language model accessible by the evaluator. Defaults to `gpt-3.5-turbo-0125`
@@ -69,6 +69,13 @@ Following command runs the test for anthropic model `claude-2.1` for a single co
needlehaystack.run_test --provider anthropic --model_name "claude-2.1" --document_depth_percents "[50]" --context_lengths "[2000]"
```
+The following command runs the test for the Cohere model `command-r` with a single context length of 2000 and a single document depth of 50%.
+
+```zsh
+needlehaystack.run_test --provider cohere --model_name "command-r" --document_depth_percents "[50]" --context_lengths "[2000]"
+```
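+
+As with the other providers, the Cohere API key is read from the `NIAH_MODEL_API_KEY` environment variable.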
### For Contributors
1. Fork and clone the repository.
diff --git a/needlehaystack/providers/__init__.py b/needlehaystack/providers/__init__.py
index 0974e290..4e514513 100644
--- a/needlehaystack/providers/__init__.py
+++ b/needlehaystack/providers/__init__.py
@@ -1,3 +1,4 @@
from .anthropic import Anthropic
+from .cohere import Cohere
from .model import ModelProvider
-from .openai import OpenAI
\ No newline at end of file
+from .openai import OpenAI
diff --git a/needlehaystack/providers/cohere.py b/needlehaystack/providers/cohere.py
new file mode 100644
index 00000000..c600c0dd
--- /dev/null
+++ b/needlehaystack/providers/cohere.py
@@ -0,0 +1,119 @@
+import os
+
+from operator import itemgetter
+from typing import Optional
+from langchain.prompts import PromptTemplate
+from langchain_cohere import ChatCohere
+
+from cohere import Client, AsyncClient
+
+from .model import ModelProvider
+
+class Cohere(ModelProvider):
+ DEFAULT_MODEL_KWARGS: dict = dict(max_tokens = 50,
+ temperature = 0.3)
+
+ def __init__(self,
+ model_name: str = "command-r",
+ model_kwargs: dict = DEFAULT_MODEL_KWARGS):
+ """
+ :param model_name: The name of the model. Default is 'command-r'.
+        :param model_kwargs: Model configuration. Default is {max_tokens: 50, temperature: 0.3}
+ """
+
+ api_key = os.getenv('NIAH_MODEL_API_KEY')
+        if not api_key:
+            raise ValueError("NIAH_MODEL_API_KEY must be set in the environment.")
+
+ self.model_name = model_name
+ self.model_kwargs = model_kwargs
+ self.api_key = api_key
+
+ self.client = AsyncClient(api_key=self.api_key)
+
+    async def evaluate_model(self, prompt: tuple[str, list[dict[str, str]]]) -> str:
+ message, chat_history = prompt
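+        # Cohere's chat endpoint takes the latest user message separately from
+        # prior turns: the question travels in `message`, while the system
+        # prompt and haystack context are passed as `chat_history`.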
+ response = await self.client.chat(message=message, chat_history=chat_history, model=self.model_name, **self.model_kwargs)
+ return response.text
+
+ def generate_prompt(self, context: str, retrieval_question: str) -> tuple[str, list[dict[str, str]]]:
+        """
+ Prepares a chat-formatted prompt
+ Args:
+ context (str): The needle in a haystack context
+ retrieval_question (str): The needle retrieval question
+
+ Returns:
+ tuple[str, list[dict[str, str]]]: prompt encoded as last message, and chat history
+
+        """
+ return (
+ f"{retrieval_question} Don't give information outside the document or repeat your findings",
+ [{
+ "role": "System",
+ "message": "You are a helpful AI bot that answers questions for a user. Keep your response short and direct"
+ },
+ {
+ "role": "User",
+ "message": context
+ }]
+ )
+
+ def encode_text_to_tokens(self, text: str) -> list[int]:
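+        # Token counts come from Cohere's hosted tokenizer for the active model.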
+ if not text: return []
+        return Client(api_key=self.api_key).tokenize(text=text, model=self.model_name).tokens
+
+ def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
+        # Truncate to context_length (a None bound keeps the full list), then
+        # detokenize through Cohere's API.
+        return Client(api_key=self.api_key).detokenize(tokens=tokens[:context_length], model=self.model_name).text
+
+ def get_langchain_runnable(self, context: str):
+ """
+ Creates a LangChain runnable that constructs a prompt based on a given context and a question.
+
+ Args:
+ context (str): The context or background information relevant to the user's question.
+ This context is provided to the model to aid in generating relevant and accurate responses.
+
+        Returns:
+            A LangChain runnable object that can be executed to obtain the model's response to a
+            dynamically provided question. The runnable encapsulates the entire process from prompt
+            generation to response retrieval.
+
+        Example:
+            To use the runnable, define the context and question, then invoke
+            the chain with the question under the "question" key:
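+
+                chain = Cohere().get_langchain_runnable(context)
+                response = chain.invoke({"question": retrieval_question})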
+ """
+
+ template = """Human: You are a helpful AI bot that answers questions for a user. Keep your response short and direct" \n
+
+ {context}
+
+ Here is the user question:
+
+ {question}
+
+ Don't give information outside the document or repeat your findings.
+ Assistant: Here is the most relevant information in the documents:"""
+
+        model = ChatCohere(cohere_api_key=self.api_key, temperature=0.3, model=self.model_name)
+ prompt = PromptTemplate(
+ template=template,
+ input_variables=["context", "question"],
+ )
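+        # Bind the fixed haystack context and thread the caller's question
+        # through the prompt template into the chat model.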
+ chain = ( {"context": lambda x: context,
+ "question": itemgetter("question")}
+ | prompt
+ | model
+ )
+ return chain
diff --git a/needlehaystack/run.py b/needlehaystack/run.py
index d38519bc..8edbccba 100644
--- a/needlehaystack/run.py
+++ b/needlehaystack/run.py
@@ -6,7 +6,7 @@
from . import LLMNeedleHaystackTester, LLMMultiNeedleHaystackTester
from .evaluators import Evaluator, LangSmithEvaluator, OpenAIEvaluator
-from .providers import Anthropic, ModelProvider, OpenAI
+from .providers import Anthropic, Cohere, ModelProvider, OpenAI
load_dotenv()
@@ -63,6 +63,8 @@ def get_model_to_test(args: CommandArgs) -> ModelProvider:
return OpenAI(model_name=args.model_name)
case "anthropic":
return Anthropic(model_name=args.model_name)
+ case "cohere":
+ return Cohere(model_name=args.model_name)
case _:
raise ValueError(f"Invalid provider: {args.provider}")
@@ -109,4 +111,4 @@ def main():
tester.start_test()
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/requirements.txt b/requirements.txt
index 4e47291d..610cb52a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ anyio==3.7.1
attrs==23.1.0
certifi==2023.11.17
charset-normalizer==3.3.2
+cohere>=5.1.2
dataclasses-json==0.6.3
distro==1.8.0
filelock==3.13.1
@@ -21,10 +22,11 @@ jsonpatch==1.33
jsonpointer==2.4
langchain==0.1.9
langchain-community>=0.0.24
-langchain-core==0.1.26
+langchain-core>=0.1.26
langsmith>=0.1.8
langchain_openai
langchain_anthropic
+langchain_cohere
marshmallow==3.20.1
multidict==6.0.4
mypy-extensions==1.0.0