
Commit

✨ Questions Eval: some improvements on computing score and better prompts
simonmeoni committed Oct 8, 2024
1 parent a753dc6 commit 37f20d8
Showing 7 changed files with 139 additions and 59 deletions.
2 changes: 1 addition & 1 deletion lib/questions_eval/bash/experiments/short.sh
@@ -1 +1 @@
-python run.py -m model=gpt-4o,gpt-4o-mini samples=50,100
+python run.py -m model=llama3.1-405b-local question_model=gpt-4o samples=50 num_questions=6
1 change: 1 addition & 0 deletions lib/questions_eval/bash/experiments/tiny.sh
@@ -0,0 +1 @@
+python run.py -m model=gpt-4o,gpt-4o-mini samples=20 num_questions=10
2 changes: 1 addition & 1 deletion lib/questions_eval/configs/model/llama3.1-405b-local.yaml
@@ -1,2 +1,2 @@
 _target_: langchain_community.llms.huggingface_text_gen_inference.HuggingFaceTextGenInference
-inference_server_url: http://20:216:186:42:8080
+inference_server_url: http://20.216.186.42:8080
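
For context, a minimal sketch (assumptions: hydra and omegaconf are installed, and the endpoint is the corrected URL above) of how a config in this shape becomes an LLM client, mirroring the hydra.utils.instantiate(cfg.model) call in run.py further down:

import hydra
from omegaconf import OmegaConf

# `_target_` names the class to construct; every other key is passed as a constructor kwarg.
model_cfg = OmegaConf.create(
    {
        "_target_": "langchain_community.llms.huggingface_text_gen_inference."
        "HuggingFaceTextGenInference",
        "inference_server_url": "http://20.216.186.42:8080",
    }
)
llm = hydra.utils.instantiate(model_cfg)  # -> HuggingFaceTextGenInference instance
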
2 changes: 2 additions & 0 deletions lib/questions_eval/configs/question_model/gpt-4o.yaml
@@ -0,0 +1,2 @@
+_target_: langchain_community.chat_models.ChatOpenAI
+model_name: gpt-4o
@@ -0,0 +1,2 @@
+_target_: langchain_community.llms.huggingface_text_gen_inference.HuggingFaceTextGenInference
+inference_server_url: http://20.216.186.42:8080
51 changes: 34 additions & 17 deletions lib/questions_eval/configs/run.yaml
@@ -1,32 +1,49 @@
 # @package _global_
 defaults:
   - _self_
-  - model: gpt-4o.yaml
+  - model: gpt-4o-mini.yaml
+  - question_model: gpt-4o.yaml

-samples: 1
-num_questions: 2
+samples: 2
+num_questions: 5

 prompts:
   transcription: >-
-    Given the following medical instruction and description, generate a synthetic transcription:
-    Instruction: {instruction} Description: {description}
+    As a clinician assistant, you must write a clinical report given these patients information.
+    keywords: {keywords}, {derived_keywords}
+    description: {description}
+    medical specialty: {medical_specialty}
     Synthetic Transcription:
   question: >-
-    Given the following transcription, generate a yes/no question that can be answered using the
-    information in the transcription. The question must be formulated in a way that the correct
-    answer is always 'yes':
-    Transcription: {transcription}
-    Question:
+    As a clinical assistant, please formulate {num_questions} critical, concise and closed-ended
+    questions (in a YES/NO format) that thoroughly scrutinize the document. The questions
+    generated should ALWAYS result in a ‘YES’ based on the given text. Questions should be
+    about the content of the document and not include any qualifier of the clarity, justification
+    or definition.
+    **Note**
+    The questions have to be STRICTLY closed-ended and should not be subjective or open to
+    human interpretation.
+    You should return in a JSON format. The JSON should be a list of dictionaries where each
+    dictionary will have two keys:
+    - ‘question’: specifying the question
+    - ‘answer’: either YES or NO.
+    The given text should be able to answer ‘YES’ for each generated question.
+    Document: {transcription}
+    JSON:
   evaluation: >-
-    Compare the following pairs of transcriptions and questions:
+    As a clinical assistant, answer the following questions with a YES or NO, grounded on the
+    text content only. Do not use any external knowledge. If you cannot answer the question
+    based on the provided text, please respond with ‘IDK’.
+    **Note**
+    You should respond either YES, NO or IDK.
-    Transcription : {transcription}
+    Document : {transcription}
-    Question : {question} Based on the given transcription, answer the question with only 'yes',
-    'no', or 'idk' (if you don't know or can't determine from the information provided). Answer:
+    Question : {question}
+    Answer:
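
The reworked question prompt asks the model for a JSON list of question/answer pairs instead of a single free-text question. A hypothetical response for num_questions=2 (the question texts are invented; only the structure is prescribed by the prompt and consumed by the JSON parser in run.py below):

questions = [
    {"question": "Does the report mention the patient's medical specialty?", "answer": "YES"},
    {"question": "Is a description of the patient's condition included?", "answer": "YES"},
]
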
138 changes: 98 additions & 40 deletions lib/questions_eval/run.py
@@ -1,10 +1,15 @@
 import hydra
 import pandas as pd
 import wandb
+from botocore.parsers import JSONParser
+from coverage import coverage
 from datasets import load_dataset
 from dotenv import load_dotenv
+from langchain.output_parsers import OutputFixingParser
 from langchain.prompts import PromptTemplate
 from langchain.schema import StrOutputParser
+from langchain_core.output_parsers import JsonOutputParser
+from langchain_core.prompts import ChatPromptTemplate
 from omegaconf import DictConfig, OmegaConf
 from pandas import json_normalize
 from tqdm import tqdm
@@ -13,34 +18,82 @@


 def is_yes(result):
-    return "yes" in result[:5].lower()
+    return True if "yes" in result[:5].lower() else False


+def is_idk(result):
+    return True if "idk" in result[:5].lower() else False
+
+
+def is_no(result):
+    return True if "no" in result[:5].lower() else False
+
+
 def generate_data(row, num_questions, transcription_chain, question_chain):
     synthetic_transcription = transcription_chain.invoke(
-        {"instruction": row["instruction"], "description": row["description"]}
+        {
+            "keywords": row["keywords"],
+            "derived_keywords": row["derived_keywords"],
+            "description": row["description"],
+            "medical_specialty": row["medical_specialty"],
+        }
     ).strip()

     data = []
-    for _ in range(num_questions):
-        real_question = question_chain.invoke({"transcription": row["transcription"]}).strip()
-        synthetic_question = question_chain.invoke(
-            {"transcription": synthetic_transcription}
-        ).strip()
+    real_question = question_chain.invoke(
+        {
+            "transcription": row["transcription"],
+            "num_questions": num_questions,
+        }
+    )
+    synthetic_question = question_chain.invoke(
+        {
+            "transcription": synthetic_transcription,
+            "num_questions": num_questions,
+        }
+    )
+
+    min_length = min(len(real_question), len(synthetic_question))
+    real_question = real_question[:min_length]
+    synthetic_question = synthetic_question[:min_length]
+
+    for sq, q in zip(synthetic_question, real_question):
         data.append(
             {
-                "question": real_question,
-                "synthetic_question": synthetic_question,
+                "question": q["question"],
+                "synthetic_question": sq["question"],
                 "answer": "yes",
                 "synthetic_answer": "yes",
                 "synthetic_transcription": synthetic_transcription,
                 "transcription": row["transcription"],
             }
         )

     return data
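
A quick illustration of the new pairing logic, with hypothetical chain outputs of unequal length:

# Hypothetical parsed outputs (abbreviated); the real lists hold question dicts.
real_question = [
    {"question": "Q1", "answer": "YES"},
    {"question": "Q2", "answer": "YES"},
    {"question": "Q3", "answer": "YES"},
]
synthetic_question = [
    {"question": "S1", "answer": "YES"},
    {"question": "S2", "answer": "YES"},
]

# Truncating both lists to min_length = 2 and zipping yields the pairs
# (S1, Q1) and (S2, Q2); each pair becomes one row appended to `data`.
min_length = min(len(real_question), len(synthetic_question))
pairs = list(zip(synthetic_question[:min_length], real_question[:min_length]))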


+def compute_conformity(
+    synthetic_transcript_question,
+    transcript_synthetic_question,
+    transcript_question,
+    synthetic_transcript_synthetic_question,
+):
+    score = 2
+    if (
+        is_yes(synthetic_transcript_question) != is_yes(transcript_question)
+        or is_idk(synthetic_transcript_question) != is_idk(transcript_question)
+        or is_no(synthetic_transcript_question) != is_no(transcript_question)
+    ):
+        score -= 1
+    if (
+        is_yes(transcript_synthetic_question) != is_yes(synthetic_transcript_synthetic_question)
+        or is_idk(transcript_synthetic_question) != is_idk(synthetic_transcript_synthetic_question)
+        or is_no(transcript_synthetic_question) != is_no(synthetic_transcript_synthetic_question)
+    ):
+        score -= 1
+    return float(score) / 2
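
A worked example of the conformity score, using answer strings as they would come back from the evaluation chain: the score starts at 2, loses a point for each of the two cross-checks that disagree, and is normalized to [0, 1].

# Arguments: answers for (synthetic doc, real question), (real doc, synthetic question),
# (real doc, real question), (synthetic doc, synthetic question).
compute_conformity("YES", "YES", "YES", "YES")  # 1.0 - both cross-checks agree
compute_conformity("IDK", "YES", "YES", "YES")  # 0.5 - the real question flips YES -> IDK
compute_conformity("IDK", "NO", "YES", "YES")   # 0.0 - both questions flip between documents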


 def evaluate(row, evaluation_chain):
     results = {}
     for i in ["synthetic_transcription", "transcription"]:
@@ -49,32 +102,39 @@ def evaluate(row, evaluation_chain):
{"transcription": row[i], "question": row[j]}
)

raw_score = sum(1 for result in results.values() if is_yes(result))
synthetic_score = 1 if is_yes(results["synthetic_transcription/synthetic_question"]) else 0
real_score = 1 if is_yes(results["transcription/question"]) else 0
strict_synthetic_qa_score = (
1
if is_yes(results["synthetic_transcription/synthetic_question"])
and is_yes(results["transcription/synthetic_question"])
else 0
)
strict_qa_score = (
1
if is_yes(results["transcription/question"])
and is_yes(results["synthetic_transcription/question"])
else 0
coverage = 1 if is_idk(results["synthetic_transcription/question"]) else 0
consistency = 1 if not is_idk(results["transcription/synthetic_question"]) else 0
conformity = compute_conformity(
results["synthetic_transcription/question"],
results["transcription/synthetic_question"],
results["transcription/question"],
results["synthetic_transcription/synthetic_question"],
)
results["synthetic_score"] = synthetic_score
results["raw_score"] = raw_score / 4
results["real_score"] = real_score
results["strict_synthetic_qa_score"] = strict_synthetic_qa_score
results["strict_qa_score"] = strict_qa_score
results["consistency"] = consistency
results["conformity"] = conformity
results["coverage"] = coverage
return results
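
Putting it together, a hypothetical set of evaluation-chain answers and the metrics evaluate() derives from them (the values follow directly from the code above):

results = {
    "transcription/question": "YES",
    "synthetic_transcription/question": "IDK",
    "transcription/synthetic_question": "YES",
    "synthetic_transcription/synthetic_question": "YES",
}
# coverage    = 1    (is_idk on synthetic_transcription/question)
# consistency = 1    (not is_idk on transcription/synthetic_question)
# conformity  = 0.5  (the real question flipped YES -> IDK between the two documents)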


-def create_chain(template, input, llm):
-    prompt = PromptTemplate(input=input, template=template)
-    return prompt | llm | StrOutputParser()
+def create_chain(template, llm):
+    chat_template = ChatPromptTemplate.from_messages(
+        [
+            ("system", "You are an helpful clinical assistant."),
+            ("human", template),
+        ]
+    )
+    return chat_template | llm | StrOutputParser()
+
+
+def create_question_chain(template, llm):
+    chat_template = ChatPromptTemplate.from_messages(
+        [
+            ("system", "You are an helpful clinical assistant."),
+            ("human", template),
+        ]
+    )
+    fix_output_parser = OutputFixingParser.from_llm(llm, parser=JsonOutputParser())
+    return chat_template | llm | fix_output_parser
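
For reference, a minimal usage sketch of the chain built above, assuming cfg and question_llm as defined in main() below. OutputFixingParser first applies JsonOutputParser; if the completion is not valid JSON it asks the LLM to repair it, so the chain returns Python objects rather than raw text.

question_chain = create_question_chain(cfg.prompts.question, question_llm)
parsed = question_chain.invoke(
    {"transcription": "Patient presents with chest pain ...", "num_questions": 3}
)
# `parsed` is a list of dicts, e.g. [{"question": "...", "answer": "YES"}, ...]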


@hydra.main(config_path="./configs", config_name="run.yaml")
@@ -83,20 +143,18 @@ def main(cfg: DictConfig):
     wandb.config.update(OmegaConf.to_container(cfg, resolve=True))

     llm = hydra.utils.instantiate(cfg.model)
+    question_llm = hydra.utils.instantiate(cfg.question_model)
     transcription_chain = create_chain(
         cfg.prompts.transcription,
-        ["instruction", "description"],
         llm,
     )
-    question_chain = create_chain(
+    question_chain = create_question_chain(
         cfg.prompts.question,
-        ["transcription"],
-        llm,
+        question_llm,
     )
     evaluation_chain = create_chain(
         cfg.prompts.evaluation,
-        ["transcription", "question"],
-        llm,
+        question_llm,
     )

     # Load and process dataset
@@ -138,12 +196,12 @@ def main(cfg: DictConfig):
     # Log results in wandb
     log_dict = {
         f"{stat}/score/{score_type}": (
-            df_joined[f"{score_type}_score"].agg(stat)
+            df_joined[f"{score_type}"].agg(stat)
             if stat == "sum"
-            else df_joined[f"{score_type}_score"].agg("sum") / len(df_joined)
+            else df_joined[f"{score_type}"].agg("sum") / len(df_joined)
         )
         for stat in ["sum", "mean"]
-        for score_type in ["raw", "synthetic", "real", "strict_synthetic_qa", "strict_qa"]
+        for score_type in ["consistency", "conformity", "coverage"]
     }
     for key, value in log_dict.items():
         wandb.run.summary[key] = value
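
With the updated score_type list, the comprehension above writes six summary keys to wandb; their meaning follows directly from the f-string and the aggregations:

# sum/score/consistency  -> df_joined["consistency"].agg("sum")
# mean/score/consistency -> df_joined["consistency"].agg("sum") / len(df_joined)
# ...plus the analogous sum/mean keys for "conformity" and "coverage".
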
