Skip to content

Commit

Permalink
Merge pull request #40 from arkadyark-cohere/add-cohere
Browse files Browse the repository at this point in the history
Add cohere model provider
  • Loading branch information
gkamradt authored Apr 12, 2024
2 parents 916227b + 31f1117 commit 7b90d28
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 7 deletions.
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

A simple 'needle in a haystack' analysis to test in-context retrieval ability of long context LLMs.

Supported model providers: OpenAI, Anthropic
Supported model providers: OpenAI, Anthropic, Cohere

Get the behind the scenes on the [overview video](https://youtu.be/KwRRuiCCdmc).

Expand Down Expand Up @@ -46,9 +46,9 @@ pip install needlehaystack

Start using the package by calling the entry point `needlehaystack.run_test` from command line.

You can then run the analysis on OpenAI or Anthropic models with the following command line arguments:
You can then run the analysis on OpenAI, Anthropic, or Cohere models with the following command line arguments:

- `provider` - The provider of the model, available options are `openai` and `anthropic`. Defaults to `openai`
- `provider` - The provider of the model, available options are `openai`, `anthropic`, and `cohere`. Defaults to `openai`
- `evaluator` - The evaluator, which can either be a `model` or `LangSmith`. See more on `LangSmith` below. If using a `model`, only `openai` is currently supported. Defaults to `openai`.
- `model_name` - Model name of the language model accessible by the provider. Defaults to `gpt-3.5-turbo-0125`
- `evaluator_model_name` - Model name of the language model accessible by the evaluator. Defaults to `gpt-3.5-turbo-0125`
Expand All @@ -69,6 +69,11 @@ Following command runs the test for anthropic model `claude-2.1` for a single co
needlehaystack.run_test --provider anthropic --model_name "claude-2.1" --document_depth_percents "[50]" --context_lengths "[2000]"
```

Following command runs the test for cohere model `command-r` for a single context length of 2000 and single document depth of 50%.

```zsh
needlehaystack.run_test --provider cohere --model_name "command-r" --document_depth_percents "[50]" --context_lengths "[2000]"
```
### For Contributors

1. Fork and clone the repository.
Expand Down
3 changes: 2 additions & 1 deletion needlehaystack/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .anthropic import Anthropic
from .cohere import Cohere
from .model import ModelProvider
from .openai import OpenAI
from .openai import OpenAI
113 changes: 113 additions & 0 deletions needlehaystack/providers/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import os
import pkg_resources

from operator import itemgetter
from typing import Optional
from langchain.prompts import PromptTemplate
from langchain_cohere import ChatCohere

from cohere import Client, AsyncClient

from .model import ModelProvider

class Cohere(ModelProvider):
DEFAULT_MODEL_KWARGS: dict = dict(max_tokens = 50,
temperature = 0.3)

def __init__(self,
model_name: str = "command-r",
model_kwargs: dict = DEFAULT_MODEL_KWARGS):
"""
:param model_name: The name of the model. Default is 'command-r'.
:param model_kwargs: Model configuration. Default is {max_tokens_to_sample: 300, temperature: 0}
"""

api_key = os.getenv('NIAH_MODEL_API_KEY')
if (not api_key):
raise ValueError("NIAH_MODEL_API_KEY must be in env.")

self.model_name = model_name
self.model_kwargs = model_kwargs
self.api_key = api_key

self.client = AsyncClient(api_key=self.api_key)

async def evaluate_model(self, prompt: tuple[str, list[dict, str, str]]) -> str:
message, chat_history = prompt
response = await self.client.chat(message=message, chat_history=chat_history, model=self.model_name, **self.model_kwargs)
return response.text

def generate_prompt(self, context: str, retrieval_question: str) -> tuple[str, list[dict[str, str]]]:
'''
Prepares a chat-formatted prompt
Args:
context (str): The needle in a haystack context
retrieval_question (str): The needle retrieval question
Returns:
tuple[str, list[dict[str, str]]]: prompt encoded as last message, and chat history
'''
return (
f"{retrieval_question} Don't give information outside the document or repeat your findings",
[{
"role": "System",
"message": "You are a helpful AI bot that answers questions for a user. Keep your response short and direct"
},
{
"role": "User",
"message": context
}]
)

def encode_text_to_tokens(self, text: str) -> list[int]:
if not text: return []
return Client().tokenize(text=text, model=self.model_name).tokens

def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
# Assuming you have a different decoder for Anthropic
return Client().detokenize(tokens=tokens[:context_length], model=self.model_name).text

def get_langchain_runnable(self, context: str):
"""
Creates a LangChain runnable that constructs a prompt based on a given context and a question.
Args:
context (str): The context or background information relevant to the user's question.
This context is provided to the model to aid in generating relevant and accurate responses.
Returns:
str: A LangChain runnable object that can be executed to obtain the model's response to a
dynamically provided question. The runnable encapsulates the entire process from prompt
generation to response retrieval.
Example:
To use the runnable:
- Define the context and question.
- Execute the runnable with these parameters to get the model's response.
"""


template = """Human: You are a helpful AI bot that answers questions for a user. Keep your response short and direct" \n
<document_content>
{context}
</document_content>
Here is the user question:
<question>
{question}
</question>
Don't give information outside the document or repeat your findings.
Assistant: Here is the most relevant information in the documents:"""

api_key = os.getenv('NIAH_MODEL_API_KEY')
model = ChatCohere(cohere_api_key=api_key, temperature=0.3, model=self.model_name)
prompt = PromptTemplate(
template=template,
input_variables=["context", "question"],
)
chain = ( {"context": lambda x: context,
"question": itemgetter("question")}
| prompt
| model
)
return chain
6 changes: 4 additions & 2 deletions needlehaystack/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from . import LLMNeedleHaystackTester, LLMMultiNeedleHaystackTester
from .evaluators import Evaluator, LangSmithEvaluator, OpenAIEvaluator
from .providers import Anthropic, ModelProvider, OpenAI
from .providers import Anthropic, ModelProvider, OpenAI, Cohere

load_dotenv()

Expand Down Expand Up @@ -63,6 +63,8 @@ def get_model_to_test(args: CommandArgs) -> ModelProvider:
return OpenAI(model_name=args.model_name)
case "anthropic":
return Anthropic(model_name=args.model_name)
case "cohere":
return Cohere(model_name=args.model_name)
case _:
raise ValueError(f"Invalid provider: {args.provider}")

Expand Down Expand Up @@ -109,4 +111,4 @@ def main():
tester.start_test()

if __name__ == "__main__":
main()
main()
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ anyio==3.7.1
attrs==23.1.0
certifi==2023.11.17
charset-normalizer==3.3.2
cohere>=5.1.2
dataclasses-json==0.6.3
distro==1.8.0
filelock==3.13.1
Expand All @@ -21,10 +22,11 @@ jsonpatch==1.33
jsonpointer==2.4
langchain==0.1.9
langchain-community>=0.0.24
langchain-core==0.1.26
langchain-core>=0.1.26
langsmith>=0.1.8
langchain_openai
langchain_anthropic
langchain_cohere
marshmallow==3.20.1
multidict==6.0.4
mypy-extensions==1.0.0
Expand Down

0 comments on commit 7b90d28

Please sign in to comment.