diff --git a/needlehaystack/providers/cohere.py b/needlehaystack/providers/cohere.py
index 076dc234..c600c0dd 100644
--- a/needlehaystack/providers/cohere.py
+++ b/needlehaystack/providers/cohere.py
@@ -3,6 +3,8 @@
 from operator import itemgetter
 from typing import Optional
 
+from langchain.prompts import PromptTemplate
+from langchain_cohere import ChatCohere
 from cohere import Client, AsyncClient
 
 
@@ -30,23 +32,33 @@ def __init__(self,
 
         self.client = AsyncClient(api_key=self.api_key)
 
-    async def evaluate_model(self, prompt: str) -> str:
-        response = await self.client.chat(message=prompt[-1]["message"], chat_history=prompt[:-1], model=self.model_name, **self.model_kwargs)
+    async def evaluate_model(self, prompt: tuple[str, list[dict[str, str]]]) -> str:
+        message, chat_history = prompt
+        response = await self.client.chat(message=message, chat_history=chat_history, model=self.model_name, **self.model_kwargs)
         return response.text
 
-    def generate_prompt(self, context: str, retrieval_question: str) -> str | list[dict[str, str]]:
-        return [{
+    def generate_prompt(self, context: str, retrieval_question: str) -> tuple[str, list[dict[str, str]]]:
+        '''
+        Prepares a chat-formatted prompt
+        Args:
+            context (str): The needle in a haystack context
+            retrieval_question (str): The needle retrieval question
+
+        Returns:
+            tuple[str, list[dict[str, str]]]: prompt encoded as last message, and chat history
+
+        '''
+        return (
+            f"{retrieval_question} Don't give information outside the document or repeat your findings",
+            [{
             "role": "System",
             "message": "You are a helpful AI bot that answers questions for a user. Keep your response short and direct"
         },
         {
             "role": "User",
             "message": context
-        },
-        {
-            "role": "User",
-            "message": f"{retrieval_question} Don't give information outside the document or repeat your findings"
         }]
+        )
 
     def encode_text_to_tokens(self, text: str) -> list[int]:
         if not text: return []
@@ -55,3 +67,47 @@
     def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
         # Assuming you have a different decoder for Anthropic
         return Client().detokenize(tokens=tokens[:context_length], model=self.model_name).text
+
+    def get_langchain_runnable(self, context: str):
+        """
+        Creates a LangChain runnable that constructs a prompt based on a given context and a question.
+
+        Args:
+            context (str): The context or background information relevant to the user's question.
+                This context is provided to the model to aid in generating relevant and accurate responses.
+
+        Returns:
+            str: A LangChain runnable object that can be executed to obtain the model's response to a
+            dynamically provided question. The runnable encapsulates the entire process from prompt
+            generation to response retrieval.
+
+        Example:
+            To use the runnable:
+            - Define the context and question.
+            - Execute the runnable with these parameters to get the model's response.
+        """
+
+
+        template = """Human: You are a helpful AI bot that answers questions for a user. Keep your response short and direct \n
+
+        {context}
+
+        Here is the user question:
+
+        {question}
+
+        Don't give information outside the document or repeat your findings.
+        Assistant: Here is the most relevant information in the documents:"""
+
+        api_key = os.getenv('NIAH_MODEL_API_KEY')
+        model = ChatCohere(cohere_api_key=api_key, temperature=0.3, model=self.model_name)
+        prompt = PromptTemplate(
+            template=template,
+            input_variables=["context", "question"],
+        )
+        chain = ( {"context": lambda x: context,
+                "question": itemgetter("question")}
+            | prompt
+            | model
+            )
+        return chain
diff --git a/requirements.txt b/requirements.txt
index 5d044d16..a6709d25 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ anyio==3.7.1
 attrs==23.1.0
 certifi==2023.11.17
 charset-normalizer==3.3.2
-cohere>=5.0.0
+cohere>=5.1.2
 dataclasses-json==0.6.3
 distro==1.8.0
 filelock==3.13.1
@@ -22,10 +22,11 @@ jsonpatch==1.33
 jsonpointer==2.4
 langchain==0.1.9
 langchain-community>=0.0.24
-langchain-core==0.1.26
+langchain-core>=0.1.26
 langsmith>=0.1.8
 langchain_openai
 langchain_anthropic
+langchain_cohere
 marshmallow==3.20.1
 multidict==6.0.4
 mypy-extensions==1.0.0