diff --git a/chat/router.py b/chat/router.py
index e2f001e..ab00cb9 100644
--- a/chat/router.py
+++ b/chat/router.py
@@ -22,27 +22,33 @@ class Question(BaseModel):
     query: str
 
 
+router = APIRouter()
+
 @router.post("/generate_answer")
 def generate_answer(value: Question):
     try:
+        # Load the llama model and tokenizer from the pretrained model
         llama_model, llama_tokenizer = FastLanguageModel.from_pretrained(
-            model_name = "Antonio27/llama3-8b-4-bit-for-sugar",
-            max_seq_length = max_seq_length,
-            dtype = dtype,
-            load_in_4bit = load_in_4bit,
+            model_name="Antonio27/llama3-8b-4-bit-for-sugar",
+            max_seq_length=max_seq_length,
+            dtype=dtype,
+            load_in_4bit=load_in_4bit,
         )
 
+        # Load the gemma model and tokenizer from the pretrained model
         gemma_model, gemma_tokenizer = FastLanguageModel.from_pretrained(
-            model_name = "unsloth/gemma-2-9b-it-bnb-4bit",
-            max_seq_length = max_seq_length,
-            dtype = dtype,
-            load_in_4bit = load_in_4bit,
+            model_name="unsloth/gemma-2-9b-it-bnb-4bit",
+            max_seq_length=max_seq_length,
+            dtype=dtype,
+            load_in_4bit=load_in_4bit,
         )
 
+        # Prepare llama model for inference
         FastLanguageModel.for_inference(llama_model)
         llama_tokenizer.pad_token = llama_tokenizer.eos_token
         llama_tokenizer.add_eos_token = True
 
+        # Tokenize the input question for the llama model
         inputs = llama_tokenizer(
         [
             alpaca_prompt.format(
@@ -57,21 +63,26 @@ def generate_answer(value: Question):
             )
         ], return_tensors="pt").to("cuda")
 
+        # Generate output using the llama model
         outputs = llama_model.generate(**inputs, max_new_tokens=256, temperature=0.6)
         decoded_outputs = llama_tokenizer.batch_decode(outputs)
 
+        # Extract the response text
         response_text = decoded_outputs[0]
 
+        # Use regex to find the response section in the output
         match = re.search(r"### Response:(.*?)(?=\n###|$)", response_text, re.DOTALL)
         if match:
             initial_response = match.group(1).strip()
         else:
             initial_response = ""
 
+        # Prepare gemma model for inference
         FastLanguageModel.for_inference(gemma_model)
         gemma_tokenizer.pad_token = gemma_tokenizer.eos_token
         gemma_tokenizer.add_eos_token = True
 
+        # Tokenize the initial response for the gemma model
         inputs = gemma_tokenizer(
         [
             alpaca_prompt.format(
@@ -87,17 +98,21 @@ def generate_answer(value: Question):
             )
         ], return_tensors="pt").to("cuda")
 
+        # Generate adjusted output using the gemma model
         outputs = gemma_model.generate(**inputs, max_new_tokens=256, temperature=0.6)
         decoded_outputs = gemma_tokenizer.batch_decode(outputs)
 
+        # Extract the adjusted response text
         response_text = decoded_outputs[0]
 
+        # Use regex to find the response section in the output
         match = re.search(r"### Response:(.*?)(?=\n###|$)", response_text, re.DOTALL)
         if match:
             adjusted_response = match.group(1).strip()
         else:
             adjusted_response = ""
 
+        # Return the final adjusted response in a success dictionary
         return {
             'success': True,
             'response': {
diff --git a/main.py b/main.py
index d1e736e..a685474 100644
--- a/main.py
+++ b/main.py
@@ -6,17 +6,21 @@
 from chat.router import router as chat_router
 # from piggy.router import router as piggy_router
 
+# Create a FastAPI application instance with custom documentation URL
 app = FastAPI(
     docs_url="/sugar-ai/docs",
 )
 
+# Include the chat router with a specified prefix for endpoint paths
 app.include_router(chat_router, prefix="/sugar-ai/chat")
+# Include the piggy router with a specified prefix for endpoint paths (currently commented out)
 # app.include_router(piggy_router, prefix="/sugar-ai/piggy")
 
+# Add CORS middleware to allow cross-origin requests from any origin
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
+    allow_origins=["*"],  # Allow requests from any origin
+    allow_credentials=True,  # Allow sending of credentials (e.g., cookies)
+    allow_methods=["*"],  # Allow all HTTP methods
+    allow_headers=["*"],  # Allow all headers
 )
\ No newline at end of file