Skip to content

Commit

Permalink
Filters bugfix; add metrics and filter to logged sample (#2517)
Browse files Browse the repository at this point in the history
* allow !function filters

* bugfix

* nit

* add `filter` to logged samples

* add `filter` and `metric` to logged samples to identification

* convert `metric` to `metrics`: list
  • Loading branch information
baberabb authored Nov 28, 2024
1 parent 0ef7548 commit 5680a2e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
12 changes: 8 additions & 4 deletions lm_eval/api/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Callable, Dict
from typing import Callable, Dict, Union

import evaluate as hf_evaluate

Expand Down Expand Up @@ -185,8 +185,12 @@ def decorate(cls):
return decorate


def get_filter(filter_name: str) -> type:
def get_filter(filter_name: Union[str, Callable]) -> Callable:
try:
return FILTER_REGISTRY[filter_name]
except KeyError:
eval_logger.warning(f"filter `{filter_name}` is not registered!")
except KeyError as e:
if callable(filter_name):
return filter_name
else:
eval_logger.warning(f"filter `{filter_name}` is not registered!")
raise e
2 changes: 2 additions & 0 deletions lm_eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,8 @@ def evaluate(
"filtered_resps": [
req.filtered_resps[filter_key] for req in requests
],
"filter": filter_key,
"metrics": list(metrics.keys()),
"doc_hash": hash_string(
json.dumps(
requests[0].doc,
Expand Down
8 changes: 5 additions & 3 deletions lm_eval/filters/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,18 @@ def filter_set(inst):
if match:
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = [m for m in match if m]
if match:
match = match[0]
else:
match = self.fallback
match = match.strip()
else:
match = self.fallback
filtered.append(match)
return filtered

# print(resps)
filtered_resps = list(map(lambda x: filter_set(x), resps))
# print(filtered_resps)

return filtered_resps

Expand Down

0 comments on commit 5680a2e

Please sign in to comment.