Update requirements, clean up cohere SDK code, add langchain to support multi-needle
arkadyark-cohere committed Mar 26, 2024
1 parent 8af85a8 commit 64698fd
Showing 2 changed files with 67 additions and 10 deletions.
72 changes: 64 additions & 8 deletions needlehaystack/providers/cohere.py
@@ -3,6 +3,8 @@

from operator import itemgetter
from typing import Optional
from langchain.prompts import PromptTemplate
from langchain_cohere import ChatCohere

from cohere import Client, AsyncClient

@@ -30,23 +32,33 @@ def __init__(self,

        self.client = AsyncClient(api_key=self.api_key)

    async def evaluate_model(self, prompt: str) -> str:
        response = await self.client.chat(message=prompt[-1]["message"], chat_history=prompt[:-1], model=self.model_name, **self.model_kwargs)
    async def evaluate_model(self, prompt: tuple[str, list[dict[str, str]]]) -> str:
        message, chat_history = prompt
        response = await self.client.chat(message=message, chat_history=chat_history, model=self.model_name, **self.model_kwargs)
        return response.text

    def generate_prompt(self, context: str, retrieval_question: str) -> str | list[dict[str, str]]:
        return [{
    def generate_prompt(self, context: str, retrieval_question: str) -> tuple[str, list[dict[str, str]]]:
        '''
        Prepares a chat-formatted prompt
        Args:
            context (str): The needle in a haystack context
            retrieval_question (str): The needle retrieval question
        Returns:
            tuple[str, list[dict[str, str]]]: prompt encoded as last message, and chat history
        '''
        return (
            f"{retrieval_question} Don't give information outside the document or repeat your findings",
            [{
                "role": "System",
                "message": "You are a helpful AI bot that answers questions for a user. Keep your response short and direct"
            },
            {
                "role": "User",
                "message": context
            },
            {
                "role": "User",
                "message": f"{retrieval_question} Don't give information outside the document or repeat your findings"
            }]
        )

    def encode_text_to_tokens(self, text: str) -> list[int]:
        if not text: return []
@@ -55,3 +67,47 @@ def encode_text_to_tokens(self, text: str) -> list[int]:
    def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
        # Use Cohere's detokenize endpoint to turn token ids back into text
        return Client().detokenize(tokens=tokens[:context_length], model=self.model_name).text

    def get_langchain_runnable(self, context: str):
        """
        Creates a LangChain runnable that constructs a prompt based on a given context and a question.
        Args:
            context (str): The context or background information relevant to the user's question.
                This context is provided to the model to aid in generating relevant and accurate responses.
        Returns:
            A LangChain runnable object that can be executed to obtain the model's response to a
            dynamically provided question. The runnable encapsulates the entire process from prompt
            generation to response retrieval.
        Example:
            To use the runnable:
            - Define the context and question.
            - Execute the runnable with these parameters to get the model's response.
        """

        template = """Human: You are a helpful AI bot that answers questions for a user. Keep your response short and direct. \n
        <document_content>
        {context}
        </document_content>
        Here is the user question:
        <question>
        {question}
        </question>
        Don't give information outside the document or repeat your findings.
        Assistant: Here is the most relevant information in the documents:"""

        api_key = os.getenv('NIAH_MODEL_API_KEY')
        model = ChatCohere(cohere_api_key=api_key, temperature=0.3, model=self.model_name)
        prompt = PromptTemplate(
            template=template,
            input_variables=["context", "question"],
        )
        chain = ({"context": lambda x: context,
                  "question": itemgetter("question")}
                 | prompt
                 | model)
        return chain
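
For reference, here is a minimal usage sketch of the updated provider (not part of the commit). It assumes the class above is exported as Cohere from needlehaystack.providers, that its constructor accepts a model_name argument, and that NIAH_MODEL_API_KEY is set in the environment.

# Hypothetical usage sketch -- import path, constructor arguments, and model name
# are assumptions for illustration, not part of this commit.
import asyncio

from needlehaystack.providers import Cohere  # assumed export path

async def main():
    provider = Cohere(model_name="command-r")  # assumed constructor signature

    context = "... long haystack text containing the needle(s) ..."
    question = "What is the secret ingredient?"

    # New tuple-based interface: (last message, chat history), matching
    # Cohere's chat API where earlier turns go in chat_history.
    message, chat_history = provider.generate_prompt(context, question)
    answer = await provider.evaluate_model((message, chat_history))
    print(answer)

    # Multi-needle path via LangChain: the runnable closes over the context,
    # so only the question is supplied at invocation time.
    chain = provider.get_langchain_runnable(context)
    print(chain.invoke({"question": question}))

asyncio.run(main())
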
5 changes: 3 additions & 2 deletions requirements.txt
@@ -6,7 +6,7 @@ anyio==3.7.1
attrs==23.1.0
certifi==2023.11.17
charset-normalizer==3.3.2
cohere>=5.0.0
cohere>=5.1.2
dataclasses-json==0.6.3
distro==1.8.0
filelock==3.13.1
@@ -22,10 +22,11 @@ jsonpatch==1.33
jsonpointer==2.4
langchain==0.1.9
langchain-community>=0.0.24
langchain-core==0.1.26
langchain-core>=0.1.26
langsmith>=0.1.8
langchain_openai
langchain_anthropic
langchain_cohere
marshmallow==3.20.1
multidict==6.0.4
mypy-extensions==1.0.0
