-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
122 lines (98 loc) · 4.27 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import asyncio
import contextlib
import os
import sys
import threading
import urllib.parse
from typing import List, Optional

from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import GPT4All, LlamaCpp
from langchain.vectorstores import Chroma
# Pull variables from a local .env file (if one exists) into os.environ
# before any setting below is read.
load_dotenv()
app = FastAPI()

# Runtime settings, all read from the environment. Each is None when its
# variable is unset, except SOURCE_DIRECTORY, which has a default.
embeddings_model_name = os.getenv("EMBEDDINGS_MODEL_NAME")
persist_directory = os.getenv("PERSIST_DIRECTORY")
model_type = os.getenv("MODEL_TYPE")  # expected: "LlamaCpp" or "GPT4All"
model_path = os.getenv("MODEL_PATH")
# NOTE: kept as the raw string from the environment (no int conversion).
model_n_ctx = os.getenv("MODEL_N_CTX")
source_directory = os.getenv("SOURCE_DIRECTORY", "source_documents")
from constants import CHROMA_SETTINGS
async def test_embedding():
    """Smoke-test the ingestion pipeline.

    Writes a one-line probe file into ``source_directory`` and runs
    ``ingest.py --collection test`` with the current Python interpreter. It
    then deletes the probe file, even if ingestion raised. The outcome is
    printed. A failed ingestion is reported but does not raise, so it does
    not abort the caller.
    """
    # Create the source folder if it doesn't exist.
    os.makedirs(source_directory, exist_ok=True)
    # Put the probe file in the configured source directory, the same one
    # that was just created.
    file_path = os.path.join(source_directory, "test.txt")
    with open(file_path, "w", encoding="utf-8") as file:
        file.write("This is a test.")
    try:
        # sys.executable runs ingest.py under the same interpreter/venv as
        # this app; an argument list avoids the shell; the async subprocess
        # does not block the event loop while ingestion runs.
        process = await asyncio.create_subprocess_exec(
            sys.executable, "ingest.py", "--collection", "test"
        )
        return_code = await process.wait()
    finally:
        # Always delete the probe file, even if launching ingestion failed.
        with contextlib.suppress(FileNotFoundError):
            os.remove(file_path)
    if return_code == 0:
        print("embeddings working")
    else:
        print(f"embeddings test failed: ingest.py exited with code {return_code}")
async def model_download():
match model_type:
case "LlamaCpp":
url = "https://gpt4all.io/models/ggml-gpt4all-l13b-snoozy.bin"
case "GPT4All":
url = "https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin"
folder = "models"
parsed_url = urllib.parse.urlparse(url)
filename = os.path.join(folder, os.path.basename(parsed_url.path))
# Check if the file already exists
if os.path.exists(filename):
print("File already exists.")
return
# Create the folder if it doesn't exist
os.makedirs(folder, exist_ok=True)
# Run wget command to download the file
os.system(f"wget {url} -O {filename}")
global model_path
model_path = filename
os.environ['MODEL_PATH'] = filename
print("model downloaded")
# Startup hook: smoke-test embeddings, then fetch the LLM if needed.
@app.on_event("startup")
async def startup_event():
    """Run the startup steps one after another, in the listed order."""
    for step in (test_embedding, model_download):
        await step()
# Landing route: answers the service root with a fixed greeting.
@app.get("/")
async def root():
    """Return a static message confirming the API is up."""
    greeting = "Hello, the APIs are now ready for your embeds and queries!"
    return dict(message=greeting)
@app.post("/embed")
async def embed(files: List[UploadFile], collection_name: Optional[str] = None):
    """Save uploaded files, ingest them into a collection, then delete them.

    Args:
        files: Uploaded documents. Each is written into ``source_directory``
            under the final path component of its client-supplied filename.
        collection_name: Target collection passed to ``ingest.py``. Defaults
            to the sanitised name of the first uploaded file.

    Returns:
        ``{"message": ..., "saved_files": [paths the uploads were saved to]}``.

    Raises:
        HTTPException: 400 when no files are given or a filename is unusable;
            500 when ``ingest.py`` exits with a non-zero status.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")
    os.makedirs(source_directory, exist_ok=True)
    saved_files = []
    try:
        for file in files:
            # The filename is client-controlled: keep only its last path
            # component so it cannot point outside source_directory.
            safe_name = os.path.basename(file.filename or "")
            if safe_name in ("", ".", ".."):
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid upload filename: {file.filename!r}",
                )
            file_path = os.path.join(source_directory, safe_name)
            # Record the path before writing so a partially written file is
            # still cleaned up below.
            saved_files.append(file_path)
            with open(file_path, "wb") as f:
                f.write(await file.read())
            if collection_name is None:
                # Default the collection to the first file's name.
                collection_name = safe_name
        # Argument list, no shell: collection_name comes straight from the
        # request. The async subprocess keeps the event loop responsive.
        process = await asyncio.create_subprocess_exec(
            sys.executable, "ingest.py", "--collection", collection_name
        )
        return_code = await process.wait()
    finally:
        # Remove the uploads whether or not ingestion succeeded. Duplicate
        # upload names map to the same path, so tolerate missing files.
        for path in saved_files:
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)
    if return_code != 0:
        raise HTTPException(
            status_code=500,
            detail=f"Ingestion failed: ingest.py exited with code {return_code}",
        )
    return {"message": "Files embedded successfully", "saved_files": saved_files}
@app.post("/retrieve")
async def query(query: str, collection_name:str):
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
db = Chroma(persist_directory=persist_directory,collection_name=collection_name, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
retriever = db.as_retriever()
# Prepare the LLM
callbacks = [StreamingStdOutCallbackHandler()]
match model_type:
case "LlamaCpp":
llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
case "GPT4All":
llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
case _default:
print(f"Model {model_type} not supported!")
exit;
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
# Get the answer from the chain
res = qa(query)
print(res)
answer, docs = res['result'], res['source_documents']
return {"results": answer, "docs":docs}