From 93d37f98e61e7df8ca264370db142e3d9f7480af Mon Sep 17 00:00:00 2001 From: honghanhh Date: Thu, 17 Oct 2024 14:07:18 +0200 Subject: [PATCH] chore: refactor evaluate function, same comlexity, better readability --- lib/questions_eval/run.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/lib/questions_eval/run.py b/lib/questions_eval/run.py index 8211f3e..9ebc044 100644 --- a/lib/questions_eval/run.py +++ b/lib/questions_eval/run.py @@ -1,6 +1,7 @@ import hydra import pandas as pd import wandb +import itertools from datasets import load_dataset from dotenv import load_dotenv from langchain.output_parsers import OutputFixingParser @@ -164,11 +165,12 @@ def evaluate(row, evaluation_chain) -> dict: The evaluation results, including conformity, consistency and coverage """ results = {} - for i in ["synthetic_transcription", "transcription"]: - for j in ["synthetic_question", "question"]: - results[f"{i}/{j}"] = evaluation_chain.invoke( - {"transcription": row[i], "question": row[j]} - ) + # Same time complexity as the nested loop O(1) but better readability and maintainability + for i, j in itertools.product(["synthetic_transcription", "transcription"], ["synthetic_question", "question"]): + results[f"{i}/{j}"] = evaluation_chain.invoke( + {"transcription": row[i], "question": row[j]} + ) + # Compute conformity, consistency and coverage coverage = 1 if is_idk(results["synthetic_transcription/question"]) else 0 consistency = 1 if not is_idk(results["transcription/synthetic_question"]) else 0