Skip to content

Commit

Permalink
run pre-commit - fix arg type
Browse files Browse the repository at this point in the history
Signed-off-by: Mírian Silva <[email protected]
  • Loading branch information
mirianfsilva committed Nov 26, 2024
1 parent 8f6de73 commit 4863977
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 15 deletions.
10 changes: 6 additions & 4 deletions lm_eval/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,13 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if args.examples:
assert args.limit is None, "If --examples is not None, then --limit must be None."
assert (
args.limit is None
), "If --examples is not None, then --limit must be None."
limit = None
with open(args.examples, 'r') as json_file:
examples = json.load(json_file)
with open(args.examples, "r") as json_file:
examples = json.load(json_file)

if args.tasks is None:
eval_logger.error("Need to specify task to evaluate.")
sys.exit()
Expand Down
24 changes: 18 additions & 6 deletions lm_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,9 @@ def build_all_requests(
limit = None

doc_id_docs = list(
self.doc_iterator(rank=rank, limit=limit, examples=examples, world_size=world_size)
self.doc_iterator(
rank=rank, limit=limit, examples=examples, world_size=world_size
)
)

num_docs = len(doc_id_docs)
Expand Down Expand Up @@ -677,18 +679,28 @@ def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
)

def doc_iterator(
self, *, rank: int = 0,
self,
*,
rank: int = 0,
limit: Union[int, None] = None,
examples: Optional[List[int]] = None,
world_size: int = 1
world_size: int = 1,
) -> Iterator[Tuple[int, Any]]:
if examples:
n = self.eval_docs.to_pandas().shape[0]
assert all([e<n for e in examples]), f"Elements of --examples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
assert all(
[e < n for e in examples]
), f"Elements of --examples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
doc_iterator = utils.create_iterator(
enumerate(datasets.Dataset.from_pandas(self.eval_docs.to_pandas().iloc[examples,:].reset_index(drop=True))),
enumerate(
datasets.Dataset.from_pandas(
self.eval_docs.to_pandas()
.iloc[examples, :]
.reset_index(drop=True)
)
),
rank=int(rank),
limit=None, #limit does not matter here since we are selecting samples directly
limit=None, # limit does not matter here since we are selecting samples directly
world_size=int(world_size),
)
else:
Expand Down
15 changes: 10 additions & 5 deletions lm_eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def simple_evaluate(
rewrite_requests_cache: bool = False,
delete_requests_cache: bool = False,
limit: Optional[Union[int, float]] = None,
examples: Optional[Dict] = None,
examples: Optional[dict] = None,
bootstrap_iters: int = 100000,
check_integrity: bool = False,
write_out: bool = False,
Expand Down Expand Up @@ -365,7 +365,7 @@ def evaluate(
lm: "LM",
task_dict,
limit: Optional[int] = None,
examples: Optional[Dict] = None,
examples: Optional[dict] = None,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
bootstrap_iters: Optional[int] = 100000,
Expand Down Expand Up @@ -535,11 +535,16 @@ def evaluate(
# iterate over different filters used
for filter_key in task.instances[0].filtered_resps.keys():
doc_iterator = task.doc_iterator(
rank=RANK, limit=limit, examples=examples[task_output.task_name], world_size=WORLD_SIZE
rank=RANK,
limit=limit,
examples=examples[task_output.task_name],
world_size=WORLD_SIZE,
)
for doc_id, doc in doc_iterator:
if examples: doc_id_true = examples[task_output.task_name][doc_id]
else: doc_id_true = doc_id
if examples:
doc_id_true = examples[task_output.task_name][doc_id]
else:
doc_id_true = doc_id
requests = instances_by_doc_id[doc_id]
metrics = task.process_results(
doc, [req.filtered_resps[filter_key] for req in requests]
Expand Down

0 comments on commit 4863977

Please sign in to comment.