diff --git a/needlehaystack/providers/cohere.py b/needlehaystack/providers/cohere.py
index 076dc234..c600c0dd 100644
--- a/needlehaystack/providers/cohere.py
+++ b/needlehaystack/providers/cohere.py
@@ -3,6 +3,8 @@
from operator import itemgetter
from typing import Optional
+from langchain.prompts import PromptTemplate
+from langchain_cohere import ChatCohere
from cohere import Client, AsyncClient
@@ -30,23 +32,33 @@ def __init__(self,
self.client = AsyncClient(api_key=self.api_key)
- async def evaluate_model(self, prompt: str) -> str:
- response = await self.client.chat(message=prompt[-1]["message"], chat_history=prompt[:-1], model=self.model_name, **self.model_kwargs)
+    async def evaluate_model(self, prompt: tuple[str, list[dict[str, str]]]) -> str:
+ message, chat_history = prompt
+ response = await self.client.chat(message=message, chat_history=chat_history, model=self.model_name, **self.model_kwargs)
return response.text
- def generate_prompt(self, context: str, retrieval_question: str) -> str | list[dict[str, str]]:
- return [{
+ def generate_prompt(self, context: str, retrieval_question: str) -> tuple[str, list[dict[str, str]]]:
+        """
+        Prepares a chat-formatted prompt.
+
+        Args:
+            context (str): The needle-in-a-haystack context
+            retrieval_question (str): The needle retrieval question
+
+        Returns:
+            tuple[str, list[dict[str, str]]]: prompt encoded as last message, and chat history
+        """
+ return (
+ f"{retrieval_question} Don't give information outside the document or repeat your findings",
+ [{
"role": "System",
"message": "You are a helpful AI bot that answers questions for a user. Keep your response short and direct"
},
{
"role": "User",
"message": context
- },
- {
- "role": "User",
- "message": f"{retrieval_question} Don't give information outside the document or repeat your findings"
}]
+ )
def encode_text_to_tokens(self, text: str) -> list[int]:
if not text: return []
@@ -55,3 +67,47 @@ def encode_text_to_tokens(self, text: str) -> list[int]:
def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
# Assuming you have a different decoder for Anthropic
return Client().detokenize(tokens=tokens[:context_length], model=self.model_name).text
+
+ def get_langchain_runnable(self, context: str):
+ """
+ Creates a LangChain runnable that constructs a prompt based on a given context and a question.
+
+ Args:
+ context (str): The context or background information relevant to the user's question.
+ This context is provided to the model to aid in generating relevant and accurate responses.
+
+        Returns:
+            A LangChain runnable object that can be invoked with a question to obtain the
+            model's response. The runnable encapsulates the entire process from prompt
+            generation to response retrieval.
+
+ Example:
+ To use the runnable:
+ - Define the context and question.
+ - Execute the runnable with these parameters to get the model's response.
+ """
+
+
+ template = """Human: You are a helpful AI bot that answers questions for a user. Keep your response short and direct" \n
+
+ {context}
+
+ Here is the user question:
+
+ {question}
+
+ Don't give information outside the document or repeat your findings.
+ Assistant: Here is the most relevant information in the documents:"""
+
+        api_key = os.getenv('NIAH_MODEL_API_KEY')  # NOTE(review): relies on 'os' being imported at module level — confirm the file already imports os, as this patch does not add it
+ model = ChatCohere(cohere_api_key=api_key, temperature=0.3, model=self.model_name)
+ prompt = PromptTemplate(
+ template=template,
+ input_variables=["context", "question"],
+ )
+ chain = ( {"context": lambda x: context,
+ "question": itemgetter("question")}
+ | prompt
+ | model
+ )
+ return chain
diff --git a/requirements.txt b/requirements.txt
index 5d044d16..a6709d25 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ anyio==3.7.1
attrs==23.1.0
certifi==2023.11.17
charset-normalizer==3.3.2
-cohere>=5.0.0
+cohere>=5.1.2
dataclasses-json==0.6.3
distro==1.8.0
filelock==3.13.1
@@ -22,10 +22,11 @@ jsonpatch==1.33
jsonpointer==2.4
langchain==0.1.9
langchain-community>=0.0.24
-langchain-core==0.1.26
+langchain-core>=0.1.26
langsmith>=0.1.8
langchain_openai
langchain_anthropic
+langchain_cohere
marshmallow==3.20.1
multidict==6.0.4
mypy-extensions==1.0.0