-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
136 lines (120 loc) · 3.88 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Main entry for the project
"""
import argparse
import itertools
from typing import Type
from evaluate import evaluate_dataset, evaluate_in_threads
from solver import ZSCoTSolver, PSCoTSolver, GiveAListSolver, CoTSolver
from loader import AddSub, GSM8K, AQuA, CoinFlip, Problem
from logger import ThreadLogger
logger = ThreadLogger()
def parse_range(range_str: str) -> range:
    """
    Convert a comma-separated string into a ``range`` object.

    A single integer ``"n"`` yields ``range(n, n + 1)`` (just that one
    problem); ``"start,end"`` yields ``range(start, end)``. Any other
    shape is rejected with an ``argparse.ArgumentTypeError`` so argparse
    reports a clean usage error.
    """
    try:
        pieces = range_str.split(",")
        if len(pieces) == 1:
            # One number selects exactly that single problem index.
            start = int(pieces[0])
            return range(start, start + 1)
        if len(pieces) == 2:
            return range(int(pieces[0]), int(pieces[1]))
        raise ValueError("Invalid input format")
    except ValueError as exc:
        raise argparse.ArgumentTypeError(
            "Range must be a single integer or two comma-separated integers"
        ) from exc
def build_args() -> argparse.Namespace:
    """
    Declare the command-line interface and parse ``sys.argv``.

    Returns the parsed namespace with ``solver``/``dataset`` (lists of
    choice names), optional ``model`` and ``range``, and a ``debug`` flag.
    """
    parser = argparse.ArgumentParser(
        description="Use this script to quickly test the effect of a Solver solving a problem in the dataset"
    )
    parser.add_argument(
        "--solver",
        nargs="+",
        type=str,
        required=True,
        choices=["zero_shot", "plan_and_solve", "give_a_list"],
        help="Solver to be tested",
    )
    parser.add_argument(
        "--dataset",
        nargs="+",
        type=str,
        required=True,
        choices=["AddSub", "GSM8K", "AQuA", "CoinFlip"],
        help="Dataset to be tested",
    )
    parser.add_argument(
        "--debug", action="store_true", help="set logger to debug level"
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Model to be tested",
    )
    parser.add_argument(
        "--range",
        # parse_range turns "n" or "start,end" into a range object.
        type=parse_range,
        help="Range of the problems to be tested",
    )
    return parser.parse_args()
def main():
    """
    Main function to avoid global namespace pollution.

    Resolves the CLI solver/dataset names to their implementing classes
    and hands the evaluation work off to ``evaluate_in_threads``, which
    runs every (solver, dataset) combination.
    """
    args = build_args()
    # Maps each CLI solver choice to the class that implements it.
    solver_class_map: dict[str, Type[CoTSolver]] = {
        "zero_shot": ZSCoTSolver,
        "plan_and_solve": PSCoTSolver,
        "give_a_list": GiveAListSolver,
    }
    solvers: list[str] = args.solver
    datasets: list[str] = args.dataset
    range_arg: range | None = args.range
    # Fall back to a default model when --model is not supplied.
    model = args.model if args.model else "gpt-4o-mini"
    evaluate_in_threads(
        # Concrete lists (not one-shot map iterators) so the callee can
        # iterate them safely more than once.
        solvers=[solver_class_map[name] for name in solvers],
        # Dataset CLI names match the loader class names imported above,
        # so a globals() lookup resolves each to its class.
        datasets=[globals()[name] for name in datasets],
        range_arg=range_arg,
        model=model,
        debug=args.debug,
    )


if __name__ == "__main__":
    main()