
Commit

chore: remove duplicate file
honghanhh committed Oct 31, 2024
1 parent 034c49b commit 8284158
Showing 2 changed files with 157 additions and 444 deletions.
246 changes: 157 additions & 89 deletions lib/questions_eval/run.py
@@ -55,10 +55,48 @@ def is_no(result: str) -> bool:
return True if "no" in result[:5].lower() else False
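
# Illustrative sketch (not part of the diff): the answer parsers above look
# only at the first few characters of the model's reply; is_yes and is_idk
# are assumed to follow the same prefix check as the is_no shown here.
def _looks_no(result: str) -> bool:
    return "no" in result[:5].lower()


def _looks_yes(result: str) -> bool:
    return "yes" in result[:5].lower()


assert _looks_no("No, the summary never mentions a biopsy.")
assert _looks_yes("Yes. The summary states the dosage explicitly.")
assert not _looks_no("Yes, but only indirectly.")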


def create_synthetic_summary(summary_chain, row):
"""
Create synthetic summary
Args:
summary_chain: The chain for generating summaries
row: The row to generate the summary for
Returns:
The synthetic summary
"""
return summary_chain.invoke(
{
"section_title": row["section_title"],
"text": row["text"],
}
).strip()


def create_question(question_chain, column, num_questions):
"""
Create real/synthetic questions
Args:
question_chain: The chain for generating questions
column: The summary (real or synthetic) to generate questions from
num_questions: The number of questions to generate
Returns:
The generated questions
"""
return question_chain.invoke(
{
"summary": column,
"num_questions": num_questions,
}
)
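
# Illustrative usage sketch (assumptions: summary_chain and question_chain are
# LangChain runnables built with create_chain further down, and the question
# chain returns a list of {"question": ...} dicts, as generate_data expects).
example_row = {
    "section_title": "History of Present Illness",
    "text": "Patient reports intermittent chest pain over the past two weeks.",
}
synthetic_summary = create_synthetic_summary(summary_chain, example_row)
questions = create_question(question_chain, synthetic_summary, num_questions=3)
for item in questions:
    print(item["question"])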


def generate_data(
row: pd.Series,
num_questions: int,
transcription_chain,
summary_chain,
question_chain,
):
"""
@@ -67,62 +105,47 @@ def generate_data(
Args:
row: text data
num_questions: number of questions to generate
transcription_chain: generated synthetic transcription
summary_chain: chain used to generate the synthetic summary
question_chain: chain used to generate questions from a summary
Returns:
The merged data with the following columns:
question, synthetic_question, answer, synthetic_answer,
synthetic_transcription, transcription
synthetic_summary, summary
"""
synthetic_transcription = transcription_chain.invoke(
{
"keywords": row["keywords"],
"derived_keywords": row["derived_keywords"],
"description": row["description"],
"medical_specialty": row["medical_specialty"],
}
).strip()
# Generate synthetic summary
synthetic_summary = create_synthetic_summary(summary_chain, row)

data = []
real_question = question_chain.invoke(
{
"transcription": row["transcription"],
"num_questions": num_questions,
}
)
synthetic_question = question_chain.invoke(
{
"transcription": synthetic_transcription,
"num_questions": num_questions,
}
)
# Generate questions for the real and synthetic summaries
real_question = create_question(question_chain, row["summary"], num_questions)
synthetic_question = create_question(question_chain, synthetic_summary, num_questions)

min_length = min(len(real_question), len(synthetic_question))
real_question = real_question[:min_length]
synthetic_question = synthetic_question[:min_length]

data = []
for sq, q in zip(synthetic_question, real_question):
data.append(
{
"question": q["question"],
"synthetic_question": sq["question"],
"answer": "yes",
"synthetic_answer": "yes",
"synthetic_transcription": synthetic_transcription,
"transcription": row["transcription"],
"synthetic_summary": synthetic_summary,
"summary": row["summary"],
}
)

return data
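
# Sketch of the pairing step in generate_data: the two question lists can come
# back with different lengths, so both are truncated to the shorter one before
# being zipped into per-question records (the hard-coded "yes" labels presumably
# reflect that each question should be answerable from its own summary).
real_qs = [{"question": "Was a CT scan ordered?"}, {"question": "Is the patient diabetic?"}]
synth_qs = [{"question": "Does the note mention imaging?"}]
n = min(len(real_qs), len(synth_qs))
paired = list(zip(synth_qs[:n], real_qs[:n]))
assert len(paired) == 1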


def compute_conformity(
synthetic_transcript_question: str,
transcript_synthetic_question: str,
transcript_question: str,
synthetic_transcript_synthetic_question: str,
synthetic_summary_question: str,
summary_synthetic_question: str,
summary_question: str,
synthetic_summary_synthetic_question: str,
) -> float:
"""
Calculate conformity score.
@@ -131,25 +154,25 @@ def compute_conformity(
is "YES", or vice versa, and computing 100 − X
Args:
synthetic_transcript_question: Synthetic transcript for groundtruth question
transcript_synthetic_question: Groundtruth transcript for synthetic question
transcript_question: Groundtruth transcript for groundtruth question
synthetic_transcript_synthetic_question: Synthetic transcript for synthetic question
synthetic_summary_question: Answer from the synthetic summary to the ground-truth question
summary_synthetic_question: Answer from the ground-truth summary to the synthetic question
summary_question: Answer from the ground-truth summary to the ground-truth question
synthetic_summary_synthetic_question: Answer from the synthetic summary to the synthetic question
Returns:
The conformity score
"""
score = 2
if (
is_yes(synthetic_transcript_question) != is_yes(transcript_question)
or is_idk(synthetic_transcript_question) != is_idk(transcript_question)
or is_no(synthetic_transcript_question) != is_no(transcript_question)
is_yes(synthetic_summary_question) != is_yes(summary_question)
or is_idk(synthetic_summary_question) != is_idk(summary_question)
or is_no(synthetic_summary_question) != is_no(summary_question)
):
score -= 1
if (
is_yes(transcript_synthetic_question) != is_yes(synthetic_transcript_synthetic_question)
or is_idk(transcript_synthetic_question) != is_idk(synthetic_transcript_synthetic_question)
or is_no(transcript_synthetic_question) != is_no(synthetic_transcript_synthetic_question)
is_yes(summary_synthetic_question) != is_yes(synthetic_summary_synthetic_question)
or is_idk(summary_synthetic_question) != is_idk(synthetic_summary_synthetic_question)
or is_no(summary_synthetic_question) != is_no(synthetic_summary_synthetic_question)
):
score -= 1
return float(score) / 2
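
# Worked example (sketch): the score starts at 2, loses one point if the
# synthetic and ground-truth summaries disagree (yes/no/idk class) on the
# ground-truth question, loses another for the synthetic question, and is then
# normalised to [0, 1]. Assumes is_idk, like is_no, keys off the answer prefix.
example = compute_conformity(
    synthetic_summary_question="No, this is not mentioned.",      # disagrees with...
    summary_question="Yes, the summary states it.",               # ...this answer -> -1
    summary_synthetic_question="Yes.",                            # agrees with...
    synthetic_summary_synthetic_question="Yes, clearly stated.",  # ...this answer
)
assert example == 0.5
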
@@ -168,19 +191,19 @@ def evaluate(row, evaluation_chain) -> dict:
"""
results = {}
for i, j in itertools.product(
["synthetic_transcription", "transcription"],
["synthetic_summary", "summary"],
["synthetic_question", "question"],
):
results[f"{i}/{j}"] = evaluation_chain.invoke({"transcription": row[i], "question": row[j]})
results[f"{i}/{j}"] = evaluation_chain.invoke({"summary": row[i], "question": row[j]})

# Compute conformity, consistency and coverage
coverage = 1 if is_idk(results["synthetic_transcription/question"]) else 0
consistency = 1 if not is_idk(results["transcription/synthetic_question"]) else 0
coverage = 1 if is_idk(results["synthetic_summary/question"]) else 0
consistency = 1 if not is_idk(results["summary/synthetic_question"]) else 0
conformity = compute_conformity(
results["synthetic_transcription/question"],
results["transcription/synthetic_question"],
results["transcription/question"],
results["synthetic_transcription/synthetic_question"],
results["synthetic_summary/question"],
results["summary/synthetic_question"],
results["summary/question"],
results["synthetic_summary/synthetic_question"],
)
results["consistency"] = consistency
results["conformity"] = conformity
@@ -196,17 +219,18 @@ def create_chain(template: str, llm: str, is_question_chain: bool):
template: The template for the prompt
llm: The language model to use
is_question_chain: Boolean indicating whether the chain is used for
question generation (True) or transcription generation (False)
question generation (True) or summary generation (False)
Returns:
The chain of models for either question or transcription generation
The chain of models for either question or summary generation
"""
chat_template = ChatPromptTemplate.from_messages(
[
("system", "You are an helpful clinical assistant."),
("human", template),
]
)

return (
chat_template
| llm
@@ -218,68 +242,112 @@ def create_chain(template: str, llm: str, is_question_chain: bool):
)
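
# Minimal sketch of how such a chain composes with LangChain (assumption: the
# collapsed part of create_chain's return pipes the model into an output
# parser, e.g. a JSON/list parser when is_question_chain is True and a plain
# string parser otherwise).
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

def sketch_chain(template: str, llm, is_question_chain: bool):
    prompt = ChatPromptTemplate.from_messages(
        [("system", "You are a helpful clinical assistant."), ("human", template)]
    )
    parser = JsonOutputParser() if is_question_chain else StrOutputParser()
    return prompt | llm | parser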


@hydra.main(config_path="./configs", config_name="run.yaml")
def main(cfg: DictConfig):
# Initialize WandB and log the models
wandb.init(project="document-cross-validation", entity="clinical-dream-team")
wandb.config.update(OmegaConf.to_container(cfg, resolve=True))

llm = hydra.utils.instantiate(cfg.model)
question_llm = hydra.utils.instantiate(cfg.question_model)
transcription_chain = create_chain(cfg.prompts.transcription, llm, False)
question_chain = create_chain(cfg.prompts.question, question_llm, True)
evaluation_chain = create_chain(cfg.prompts.evaluation, question_llm, False)

# Load and process dataset
loaded_dataset = load_dataset(cfg.dataset, split="train")
df = loaded_dataset.to_pandas().iloc[: cfg.samples]

tqdm.pandas(desc="Generating data...")

ds_questions = [
item
for _, row in df.progress_apply(
generate_data,
axis=1,
args=[cfg.num_questions, transcription_chain, question_chain],
).items()
for item in row
]
def merge_evaluate_df(df: pd.DataFrame, row: pd.DataFrame, evaluation_chain) -> pd.DataFrame:
"""
Evaluate the generated data
df_questions = pd.DataFrame(ds_questions)
Args:
df: The original dataset dataframe to merge the evaluation results onto
row: The dataframe of generated question records, evaluated row-wise
evaluation_chain: The evaluation chain
Returns:
df_joined: The concatenated dataframe with additional evaluation results,
including conformity, consistency and coverage
"""
# Evaluate
tqdm.pandas(desc="Evaluating...")
df_questions["evaluation"] = df_questions.progress_apply(
row["evaluation"] = row.progress_apply(
evaluate,
args=[evaluation_chain],
axis=1,
)
json_df = json_normalize(df_questions["evaluation"])
json_df = json_normalize(row["evaluation"])

# Combine the original dataframe with the extracted JSON data
df_questions = pd.concat([df_questions, json_df], axis=1)
del df_questions["evaluation"]
row = pd.concat([row, json_df], axis=1)
del row["evaluation"]

# Join df_questions and df
df_joined = df.merge(
df_questions, left_on="transcription", right_on="transcription", how="right"
)
df_joined = df.merge(row, left_on="summary", right_on="summary", how="right")
print(f"Shape of joined dataframe: {df_joined.shape}")
return df_joined
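
# Sketch of the normalisation step above: evaluate() returns one dict per row,
# and json_normalize spreads those dicts into one column per key before they
# are concatenated back onto the question dataframe.
import pandas as pd
from pandas import json_normalize

tiny = pd.DataFrame({"summary": ["s1"]})
tiny["evaluation"] = [{"consistency": 1, "conformity": 0.5, "coverage": 0}]
tiny = pd.concat([tiny, json_normalize(tiny["evaluation"])], axis=1)
del tiny["evaluation"]
print(tiny.columns.tolist())  # ['summary', 'consistency', 'conformity', 'coverage']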

# Log results in wandb

def log_wandb(df: pd.DataFrame) -> dict:
"""
Log the results in wandb
Args:
df: The dataframe to log
Returns:
log_dict: The dictionary containing the results
"""
log_dict = {
f"{stat}/score/{score_type}": (
df_joined[f"{score_type}"].agg(stat)
df[f"{score_type}"].agg(stat)
if stat == "sum"
else df_joined[f"{score_type}"].agg("sum") / len(df_joined)
else df[f"{score_type}"].agg("sum") / len(df)
)
for stat in ["sum", "mean"]
for score_type in ["consistency", "conformity", "coverage"]
}
return log_dict
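
# Sketch of the keys produced by log_wandb: for each score column it logs the
# raw sum and the mean (sum divided by the number of rows).
import pandas as pd

toy = pd.DataFrame({"consistency": [1, 0], "conformity": [1.0, 0.5], "coverage": [0, 1]})
print(sorted(log_wandb(toy)))
# ['mean/score/conformity', 'mean/score/consistency', 'mean/score/coverage',
#  'sum/score/conformity', 'sum/score/consistency', 'sum/score/coverage']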


def process_questions(df: pd.DataFrame, num_questions: int, summary_chain, question_chain):
"""
Generate questions for the summaries
Args:
df: The dataframe to generate questions for
num_questions: The number of questions to generate
summary_chain: The chain for generating summaries
question_chain: The chain for generating questions
Returns:
ds_questions: The generated questions
"""
ds_questions = [
item
for _, row in df.progress_apply(
generate_data,
axis=1,
args=[num_questions, summary_chain, question_chain],
).items()
for item in row
]
return ds_questions
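
# Sketch of the flattening in process_questions: progress_apply yields one
# list of question records per dataframe row, and the nested comprehension
# concatenates those lists into a single flat list.
nested = [[{"question": "q1"}, {"question": "q2"}], [{"question": "q3"}]]
flat = [item for row in nested for item in row]
assert len(flat) == 3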


@hydra.main(config_path="./configs", config_name="run_mimoracle.yaml")
def main(cfg: DictConfig):
# Initialize WandB and log the models
wandb.init(project="document-cross-validation", entity="clinical-dream-team")
wandb.config.update(OmegaConf.to_container(cfg, resolve=True))

llm = hydra.utils.instantiate(cfg.model)
question_llm = hydra.utils.instantiate(cfg.question_model)
summary_chain = create_chain(cfg.prompts.summary, llm, False)
question_chain = create_chain(cfg.prompts.question, question_llm, True)
evaluation_chain = create_chain(cfg.prompts.evaluation, question_llm, False)

# Load and process dataset
loaded_dataset = load_dataset(cfg.dataset, split="train")
df = loaded_dataset.to_pandas().iloc[: cfg.samples]

tqdm.pandas(desc="Generating data...")
ds_questions = process_questions(df, cfg.num_questions, summary_chain, question_chain)
print(f"Shape of generated data: {len(ds_questions)}")
df_questions = pd.DataFrame(ds_questions)

# Evaluate
df_joined = merge_evaluate_df(df, df_questions, evaluation_chain)

# Log results in wandb
log_dict = log_wandb(df_joined)
for key, value in log_dict.items():
wandb.run.summary[key] = value
wandb.log({"dataset/evaluation": wandb.Table(dataframe=df_joined)})
wandb.log({"dataset/evaluation_mimoracle_gpt4o_retest": wandb.Table(dataframe=df_joined)})
wandb.finish()


