Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Added Audio Input for Generating Q&A #35

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
130 changes: 77 additions & 53 deletions backend/server.py
Original file line number Diff line number Diff line change
@@ -1,114 +1,138 @@
import http.server
import json
import socketserver
import urllib.parse
from http.server import BaseHTTPRequestHandler

import librosa
import torch
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Wav2Vec2ForCTC,
    Wav2Vec2Tokenizer,
    pipeline,
)

# Address and port the QA HTTP server binds to (loopback only).
IP = "127.0.0.1"
PORT = 8000

def summarize(text):
    """Summarize *text* with the default transformers summarization pipeline.

    Args:
        text: The input passage to condense.

    Returns:
        The pipeline's ``summary_text`` string (capped at 110 tokens).

    Note:
        The pipeline is rebuilt on every call, which re-loads the model;
        callers on a hot path should cache it at module level.
    """
    summarizer = pipeline("summarization")
    return summarizer(text, max_length=110)[0]["summary_text"]


def generate_question(context, answer, model_path, tokenizer_path):
    """Generate a question for which *answer* is the answer, given *context*.

    Args:
        context: Passage the question should be grounded in.
        answer: Target answer phrase to condition the generation on.
        model_path: Path or hub id of a T5 question-generation checkpoint.
        tokenizer_path: Path or hub id of the matching T5 tokenizer.

    Returns:
        The generated question as a decoded string (special tokens stripped).

    Note:
        The model and tokenizer are re-loaded from disk on every call;
        consider loading them once at startup if latency matters.
    """
    model = T5ForConditionalGeneration.from_pretrained(model_path)
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # The checkpoint expects this "answer: ... context: ..." prompt format.
    input_text = f"answer: {answer} context: {context}"

    inputs = tokenizer.encode_plus(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Inference only -- no gradients needed.
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids, attention_mask=attention_mask, max_length=32
        )

    return tokenizer.decode(output[0], skip_special_tokens=True)

def generate_keyphrases(abstract, model_path, tokenizer_path):
    """Extract keyphrases from *abstract* using a T5 keyword-detection model.

    Args:
        abstract: Source text to mine for keyphrases.
        model_path: Path or hub id of the keyphrase-generation checkpoint.
        tokenizer_path: Path or hub id of the matching T5 tokenizer.

    Returns:
        A list of non-empty, whitespace-stripped keyphrase strings. The
        model is assumed to emit a single comma-separated sequence.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = T5ForConditionalGeneration.from_pretrained(model_path)
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
    model.to(device)

    # Prompt prefix the checkpoint was fine-tuned with.
    input_text = f"detect keyword: abstract: {abstract}"
    input_ids = tokenizer.encode(
        input_text,
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt",
    ).to(device)
    output = model.generate(input_ids)
    keyphrases = tokenizer.decode(output[0], skip_special_tokens=True).split(",")
    return [x.strip() for x in keyphrases if x != ""]


def generate_qa(text):
    """Build a question->answer mapping from *text*.

    Pipeline: keyphrase extraction (model A) produces candidate answers,
    then a question is generated for each answer (model B).

    Args:
        text: Input passage (used directly; summarization is disabled).

    Returns:
        Dict mapping each generated question to its answer phrase.
        Questions that collide overwrite earlier entries.
    """
    # Summarization step intentionally disabled: text_summary = summarize(text)
    text_summary = text

    model_a, model_b = "./models/modelA", "./models/modelB"
    # Local tokenizers exist but the stock t5-base tokenizer is used instead.
    tokenizer_a, tokenizer_b = "t5-base", "t5-base"

    answers = generate_keyphrases(text_summary, model_a, tokenizer_a)

    qa = {}
    for answer in answers:
        question = generate_question(text_summary, answer, model_b, tokenizer_b)
        qa[question] = answer

    return qa



def process_audio(audio_file):
    """Transcribe an audio file to text with Wav2Vec2 (greedy CTC decode).

    Args:
        audio_file: Path to an audio file readable by librosa.

    Returns:
        The transcription string (Wav2Vec2 base-960h outputs upper-case
        English text).

    Note:
        The model/tokenizer are downloaded or loaded on every call;
        cache them at module level if transcription is frequent.
    """
    # Wav2Vec2 expects 16 kHz mono input.
    audio, _ = librosa.load(audio_file, sr=16000)

    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")

    input_values = tokenizer(audio, return_tensors="pt").input_values
    logits = model(input_values).logits

    # Greedy decoding: most likely token at each frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = tokenizer.batch_decode(predicted_ids)[0]
    print("transcription", transcription)
    return transcription

class QARequestHandler(BaseHTTPRequestHandler):
    """HTTP handler that turns a posted JSON payload into Q&A pairs.

    Expected request body: ``{"input_type": "text"|"audio", "input_data": ...}``.
    For "text", input_data is the passage itself; for "audio", it is a path
    handed to process_audio(). The response is the Q&A dict as JSON.
    """

    def do_POST(self):
        # NOTE(review): 200 is sent before the body is parsed, so malformed
        # JSON still gets a 200 status -- kept to preserve behavior.
        self.send_response(200)
        self.send_header("Content-type", "text/plain")
        self.end_headers()

        content_length = int(self.headers["Content-Length"])
        post_data = self.rfile.read(content_length)

        parsed_data = json.loads(post_data)
        input_type = parsed_data.get("input_type")
        input_data = parsed_data.get("input_data")

        if input_type == "text":
            qa = generate_qa(input_data)
        elif input_type == "audio":
            audio_text = process_audio(input_data)
            qa = generate_qa(audio_text)
        else:
            # Unknown input_type: respond with an empty mapping.
            qa = {}

        self.wfile.write(json.dumps(qa).encode("utf-8"))
        self.wfile.flush()


def main():
    """Serve QARequestHandler forever on (IP, PORT).

    Blocks until interrupted; the TCPServer context manager closes the
    listening socket on exit.
    """
    with socketserver.TCPServer((IP, PORT), QARequestHandler) as server:
        print(f"Server started at http://{IP}:{PORT}")
        server.serve_forever()

# Script entry point: only start the server when run directly.
if __name__ == "__main__":
    main()
15 changes: 8 additions & 7 deletions backend/test_server.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import requests
from sample_input import sample_input
url = "http://127.0.0.1:8000"

# The server parses the raw body with json.loads and reads "input_type" /
# "input_data", so the request must be sent as JSON -- a form-encoded
# {"input_text": ...} body would fail to parse server-side.
response = requests.post(
    url, json={"input_type": "text", "input_data": sample_input}
)

result = response.json()

for question, answer in result.items():
    print(f"Question: {question}")
    print(f"Answer: {answer}")
    print("-" * 30)
8 changes: 3 additions & 5 deletions extension/html/text_input.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
<html>
<head>
<title>EduAid: Text Input</title>
<!-- <link href="https://fonts.googleapis.com/css?family=Roboto:400,500" rel="stylesheet">
<link rel="stylesheet" href="./popup.css"> -->
<script src="../pdfjs-3.9.179-dist/build/pdf.js"></script>
<link href='https://fonts.googleapis.com/css?family=Inter' rel='stylesheet'>
<link rel="stylesheet" href="../styles/text_input.css">
Expand All @@ -15,7 +13,9 @@ <h1>EduAid</h1>
</header>
<main>
<h3>Generate QnA</h3>
<textarea id="text-input" placeholder="Paste your text here"></textarea>
<textarea class="text-input" id="text-input" placeholder="Paste your text here"></textarea>
<div class="or-divider">OR</div>
<input id="audio-input" type="text" placeholder="Enter audio path">
<div>
<input type="file" id="file-upload" accept=".pdf" hidden>
<label for="file-upload" id="upload-label">Upload PDF &#128196</label>
Expand All @@ -24,11 +24,9 @@ <h3>Generate QnA</h3>
<button id="back-button">Back</button>
<button id="next-button">Next</button>
</div>
<!-- ******************* -->
<div id="loading-screen" class="loading-screen">
<div class="loading-spinner"></div>
</div>
<!-- ****************** -->
</main>
<script src="../js/text_input.js"></script>
</body>
Expand Down
111 changes: 58 additions & 53 deletions extension/js/question_generation.js
Original file line number Diff line number Diff line change
@@ -1,56 +1,61 @@
// Wires up the question-generation results page: viewing the generated
// Q&A pairs in a modal, saving them as a text file, and navigating back.
document.addEventListener("DOMContentLoaded", function () {
  const saveButton = document.getElementById("save-button");
  const backButton = document.getElementById("back-button");
  const viewQuestionsButton = document.getElementById("view-questions-button");
  // Q&A pairs are produced by the previous page and passed via localStorage.
  const qaPairs = JSON.parse(localStorage.getItem("qaPairs"));
  const modalClose = document.querySelector("[data-close-modal]");
  const modal = document.querySelector("[data-modal]");

  viewQuestionsButton.addEventListener("click", function () {
    const modalQuestionList = document.getElementById("modal-question-list");
    modalQuestionList.innerHTML = ""; // clear previous content

    for (const [question, answer] of Object.entries(qaPairs)) {
      const questionElement = document.createElement("li");
      if (question.includes("Options:")) {
        // MCQ-style entry: split out the options and letter them a), b), ...
        const options = question.split("Options: ")[1].split(", ");
        const formattedOptions = options.map(
          (opt, index) => `${String.fromCharCode(97 + index)}) ${opt}`
        );
        questionElement.textContent = `Question: ${
          question.split(" Options:")[0]
        }\n${formattedOptions.join("\n")}`;
      } else {
        questionElement.textContent = `Question: ${question}\n\nAnswer: ${answer}\n`;
      }
      modalQuestionList.appendChild(questionElement);
    }
    modal.showModal();
  });

  modalClose.addEventListener("click", function () {
    modal.close();
  });

  saveButton.addEventListener("click", async function () {
    let textContent = "EduAid Generated QnA:\n\n";
    for (const [question, answer] of Object.entries(qaPairs)) {
      textContent += `Question: ${question}\nAnswer: ${answer}\n\n`;
    }

    // Download via a temporary object-URL-backed anchor element.
    const blob = new Blob([textContent], { type: "text/plain" });
    const blobUrl = URL.createObjectURL(blob);

    const downloadLink = document.createElement("a");
    downloadLink.href = blobUrl;
    downloadLink.download = "questions_and_answers.txt";
    downloadLink.style.display = "none";

    document.body.appendChild(downloadLink);
    downloadLink.click();

    // Clean up the temporary anchor and release the blob URL.
    document.body.removeChild(downloadLink);
    URL.revokeObjectURL(blobUrl);
  });

  backButton.addEventListener("click", function () {
    window.location.href = "../html/text_input.html";
  });
});
Loading