This commit includes the following changes:
(1) Created a new main.py file: established a basic FastAPI setup to improve application structure and scalability. This includes the initial configuration and middleware, laying the groundwork for future development.
(2) Renamed the original main.py to original_main.py to preserve the previous version as a reference for legacy code, easing the transition without losing historical context.
(3) Refactored the project structure: kept the existing piggy directory and introduced a chat directory. Separating routers and APIs this way improves modularity, making it easier to manage and extend each project independently (a sketch of this pattern follows after this list).
These changes aim to improve code organization and prepare the project for scalable development with FastAPI.
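To illustrate the per-directory router pattern described in (3): each feature keeps its own APIRouter in its own directory, and main.py mounts it under a URL prefix. The piggy stub below is hypothetical (this commit only touches chat/router.py and main.py); its endpoint name and request model are assumptions used to show the shape of the pattern, not the actual piggy code.

# piggy/router.py -- hypothetical stub showing the per-directory router pattern
from fastapi import APIRouter
from pydantic import BaseModel

class CodeRequest(BaseModel):   # assumed request shape, not from this commit
    code: str

router = APIRouter()

@router.post("/analyze")
def analyze(value: CodeRequest):
    # Placeholder logic; the real piggy endpoint is not part of this diff
    return {"success": True, "length": len(value.code)}

With such a module in place, main.py would enable it by uncommenting the piggy import and the corresponding app.include_router(piggy_router, prefix="/sugar-ai/piggy") line.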
XXXJumpingFrogXXX committed Oct 27, 2024
1 parent 0999a45 commit c24a43f
Showing 2 changed files with 31 additions and 12 deletions.
31 changes: 23 additions & 8 deletions chat/router.py
@@ -22,27 +22,33 @@
class Question(BaseModel):
    query: str

router = APIRouter()

@router.post("/generate_answer")
def generate_answer(value: Question):
    try:
        # Load the llama model and tokenizer from the pretrained model
        llama_model, llama_tokenizer = FastLanguageModel.from_pretrained(
            model_name="Antonio27/llama3-8b-4-bit-for-sugar",
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit,
        )

        # Load the gemma model and tokenizer from the pretrained model
        gemma_model, gemma_tokenizer = FastLanguageModel.from_pretrained(
            model_name="unsloth/gemma-2-9b-it-bnb-4bit",
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit,
        )

        # Prepare llama model for inference
        FastLanguageModel.for_inference(llama_model)
        llama_tokenizer.pad_token = llama_tokenizer.eos_token
        llama_tokenizer.add_eos_token = True

        # Tokenize the input question for the llama model
        inputs = llama_tokenizer(
            [
                alpaca_prompt.format(
@@ -57,21 +63,26 @@ def generate_answer(value: Question):
                )
            ], return_tensors="pt").to("cuda")

        # Generate output using the llama model
        outputs = llama_model.generate(**inputs, max_new_tokens=256, temperature=0.6)
        decoded_outputs = llama_tokenizer.batch_decode(outputs)

        # Extract the response text
        response_text = decoded_outputs[0]

        # Use regex to find the response section in the output
        match = re.search(r"### Response:(.*?)(?=\n###|$)", response_text, re.DOTALL)
        if match:
            initial_response = match.group(1).strip()
        else:
            initial_response = ""

        # Prepare gemma model for inference
        FastLanguageModel.for_inference(gemma_model)
        gemma_tokenizer.pad_token = gemma_tokenizer.eos_token
        gemma_tokenizer.add_eos_token = True

        # Tokenize the initial response for the gemma model
        inputs = gemma_tokenizer(
            [
                alpaca_prompt.format(
@@ -87,17 +98,21 @@ def generate_answer(value: Question):
                )
            ], return_tensors="pt").to("cuda")

        # Generate adjusted output using the gemma model
        outputs = gemma_model.generate(**inputs, max_new_tokens=256, temperature=0.6)
        decoded_outputs = gemma_tokenizer.batch_decode(outputs)

        # Extract the adjusted response text
        response_text = decoded_outputs[0]

        # Use regex to find the response section in the output
        match = re.search(r"### Response:(.*?)(?=\n###|$)", response_text, re.DOTALL)
        if match:
            adjusted_response = match.group(1).strip()
        else:
            adjusted_response = ""

        # Return the final adjusted response in a success dictionary
        return {
            'success': True,
            'response': {
12 changes: 8 additions & 4 deletions main.py
@@ -6,17 +6,21 @@
from chat.router import router as chat_router
# from piggy.router import router as piggy_router

# Create a FastAPI application instance with custom documentation URL
app = FastAPI(
    docs_url="/sugar-ai/docs",
)

# Include the chat router with a specified prefix for endpoint paths
app.include_router(chat_router, prefix="/sugar-ai/chat")
# Include the piggy router with a specified prefix for endpoint paths (currently commented out)
# app.include_router(piggy_router, prefix="/sugar-ai/piggy")

# Add CORS middleware to allow cross-origin requests from any origin
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],       # Allow requests from any origin
    allow_credentials=True,    # Allow sending of credentials (e.g., cookies)
    allow_methods=["*"],       # Allow all HTTP methods
    allow_headers=["*"],       # Allow all headers
)
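As a usage sketch (not part of this commit): assuming the app is served locally, for example with uvicorn main:app, a client could exercise the new chat route as shown below. The host, port, and example question are assumptions.

# client_example.py -- hypothetical client for the /sugar-ai/chat/generate_answer endpoint
import requests

resp = requests.post(
    "http://localhost:8000/sugar-ai/chat/generate_answer",
    json={"query": "What is a FastAPI router?"},   # example question (assumption)
)
resp.raise_for_status()
print(resp.json())   # expected shape: {"success": True, "response": {...}}

The interactive API docs would be served at /sugar-ai/docs, as configured above.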
