Skip to content

Commit

Permalink
Modifications on FastAPI
Browse files Browse the repository at this point in the history
  • Loading branch information
XXXJumpingFrogXXX committed Oct 16, 2024
1 parent ce06570 commit 332db65
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
31 changes: 23 additions & 8 deletions chat/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,33 @@
class Question(BaseModel):
query: str

router = APIRouter()

@router.post("/generate_answer")
def generate_answer(value: Question):
try:
# Load the llama model and tokenizer from the pretrained model
llama_model, llama_tokenizer = FastLanguageModel.from_pretrained(
model_name = "Antonio27/llama3-8b-4-bit-for-sugar",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
model_name="Antonio27/llama3-8b-4-bit-for-sugar",
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
)

# Load the gemma model and tokenizer from the pretrained model
gemma_model, gemma_tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/gemma-2-9b-it-bnb-4bit",
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
model_name="unsloth/gemma-2-9b-it-bnb-4bit",
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
)

# Prepare llama model for inference
FastLanguageModel.for_inference(llama_model)
llama_tokenizer.pad_token = llama_tokenizer.eos_token
llama_tokenizer.add_eos_token = True

# Tokenize the input question for the llama model
inputs = llama_tokenizer(
[
alpaca_prompt.format(
Expand All @@ -57,21 +63,26 @@ def generate_answer(value: Question):
)
], return_tensors="pt").to("cuda")

# Generate output using the llama model
outputs = llama_model.generate(**inputs, max_new_tokens=256, temperature=0.6)
decoded_outputs = llama_tokenizer.batch_decode(outputs)

# Extract the response text
response_text = decoded_outputs[0]

# Use regex to find the response section in the output
match = re.search(r"### Response:(.*?)(?=\n###|$)", response_text, re.DOTALL)
if match:
initial_response = match.group(1).strip()
else:
initial_response = ""

# Prepare gemma model for inference
FastLanguageModel.for_inference(gemma_model)
gemma_tokenizer.pad_token = gemma_tokenizer.eos_token
gemma_tokenizer.add_eos_token = True

# Tokenize the initial response for the gemma model
inputs = gemma_tokenizer(
[
alpaca_prompt.format(
Expand All @@ -87,17 +98,21 @@ def generate_answer(value: Question):
)
], return_tensors="pt").to("cuda")

# Generate adjusted output using the gemma model
outputs = gemma_model.generate(**inputs, max_new_tokens=256, temperature=0.6)
decoded_outputs = gemma_tokenizer.batch_decode(outputs)

# Extract the adjusted response text
response_text = decoded_outputs[0]

# Use regex to find the response section in the output
match = re.search(r"### Response:(.*?)(?=\n###|$)", response_text, re.DOTALL)
if match:
adjusted_response = match.group(1).strip()
else:
adjusted_response = ""

# Return the final adjusted response in a success dictionary
return {
'success': True,
'response': {
Expand Down
12 changes: 8 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@
from chat.router import router as chat_router
# from piggy.router import router as piggy_router

# Create a FastAPI application instance with custom documentation URL
app = FastAPI(
docs_url="/sugar-ai/docs",
)

# Include the chat router with a specified prefix for endpoint paths
app.include_router(chat_router, prefix="/sugar-ai/chat")
# Include the piggy router with a specified prefix for endpoint paths (currently commented out)
# app.include_router(piggy_router, prefix="/sugar-ai/piggy")

# Add CORS middleware to allow cross-origin requests from any origin
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_origins=["*"], # Allow requests from any origin
allow_credentials=True, # Allow sending of credentials (e.g., cookies)
allow_methods=["*"], # Allow all HTTP methods
allow_headers=["*"], # Allow all headers
)

0 comments on commit 332db65

Please sign in to comment.