validate_dataset.py
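"""Validate an execution-eval dataset (APPS or GSM8K) by running its examples and logging failures."""
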
import argparse
import json
import logging
from pathlib import Path

from datasets import load_dataset

from code_execution.code_trees import safe_ast_parse
from code_execution.eval_dataset import apps
from code_execution.eval_dataset import evaluate_apps
from code_execution.eval_dataset import evaluate_gsm8k

logging.basicConfig(
    level=logging.DEBUG,
    format="[%(asctime)s - %(levelname)s] %(message)s",
)


def process_ex(ex):
    """Decode the JSON-encoded ``solutions`` and ``input_output`` columns.

    Malformed or missing values are mapped to ``None`` rather than raising.
    """
    try:
        solutions = json.loads(ex["solutions"])
    except (TypeError, json.JSONDecodeError):
        solutions = None
    try:
        io = json.loads(ex["input_output"])
    except (TypeError, json.JSONDecodeError):
        io = None
    return {
        **ex,
        "solutions": solutions,
        "input_output": io,
    }


def main(flags):
    if flags.dataset == "gsm8k":
        dataset = load_dataset("codeparrot/gsm8k", split="test")
        if flags.num_examples:
            print(f"Selecting {flags.num_examples:,} examples")
            dataset = dataset.select(list(range(flags.num_examples)))
        metrics, results = evaluate_gsm8k(
            dataset.to_list(),
            num_workers=flags.num_workers,
            timeout=4,
            execution_kwargs={"log_freq": 10},
        )
    elif flags.dataset == "apps":
        dataset = load_dataset("codeparrot/apps", split="test")
        if flags.num_examples:
            print(f"Selecting {flags.num_examples:,} examples")
            dataset = dataset.select(list(range(flags.num_examples)))
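        # Convert each raw APPS row into the structure expected by evaluate_apps.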
        dataset = dataset.map(
            apps.process_raw_example, load_from_cache_file=False
        )
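        # Drop reference solutions that fail to parse as valid Python.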
        dataset = dataset.map(
            lambda x: {
                **x,
                "solutions": [
                    s for s in x["solutions"] if safe_ast_parse(s) is not None
                ],
            }
        )
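        # Keep only problems that retain at least one solution and one test input.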
        dataset = dataset.filter(
            lambda x: len(x["solutions"]) > 0 and len(x["inputs"]) > 0
        )
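        # Optionally cap the number of solutions evaluated per problem.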
        if flags.sol_per:
            dataset = dataset.map(
                lambda x: {
                    **x,
                    "solutions": x["solutions"][: flags.sol_per],
                }
            )
        print(f"Processing {len(dataset):,} examples")
        metrics, results = evaluate_apps(
            dataset.to_list(),
            num_workers=flags.num_workers,
            timeout=4,
            execution_kwargs={"log_freq": 10},
            command_timeout=3.0,
            early_stopping=True,
            max_commands=3,
        )
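    # Dump every failing prediction to scratch/failed.jsonl for inspection.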
    out_path = Path("scratch")
    out_path.mkdir(exist_ok=True)
    with open(out_path / "failed.jsonl", "w") as f:
        for r in results:
            for p in r["predictions"]:
                if not p["passed"]:
                    f.write(
                        json.dumps({"problem": r["problem_id"], **p}) + "\n"
                    )
    print(json.dumps(metrics, indent=2))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", type=str, choices=["apps", "gsm8k"])
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--num_examples", type=int, default=None)
    parser.add_argument("--sol_per", type=int, default=None)
    main(parser.parse_args())