-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathrun_reflexion.py
101 lines (87 loc) · 2.9 KB
/
run_reflexion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import time
from cmdline import args
from lang import can_be_solution_whole
from lang import code_of_msg
from prompts import prompt, min_lines, check_func, check_string
from scoring import calculate_score_err_whole
from reflection import reflect
import llm
from common import limit_tokens
from lang_config import LANG
import wandb
# Optional Weights & Biases tracking. This runs at import time and only when
# args.use_wandb is truthy; args.dict() is recorded as the run config.
if args.use_wandb:
    _wandb_kwargs = {
        "entity": args.wandb_entity,
        "project": args.wandb_project,
        "group": args.wandb_group,
        "config": args.dict(),
        "name": args.wandb_name,
    }
    wandb.init(**_wandb_kwargs)
    # Drop the temporary so it does not linger as a module-level name.
    del _wandb_kwargs
def buildPrompt(prompt, initial_code, text=None, err=None):
    """Assemble the next prompt for the reflexion loop.

    When ``text`` is given, ``reflect`` is called on the code extracted from
    it (via ``code_of_msg``) together with the verifier error ``err``; the
    reflection, if non-empty, is appended under a ``TURN:`` header. Any
    previous ``CODE`` section is cut off first, and a fresh ``CODE:`` section
    is appended that opens a fenced block for the configured language and
    seeds it with ``initial_code``.

    Args:
        prompt: Current prompt text (may already contain earlier TURN/CODE
            sections from previous calls).
        initial_code: Code the model is asked to continue from.
        text: Previous model output, or None for the first prompt.
        err: Verifier error for ``text``; passed through to ``reflect``.

    Returns:
        The new prompt string.
    """
    r = None if text is None else reflect(code_of_msg(text), None, err)
    if "CODE" in prompt:
        # Drop the previously appended CODE section (and everything after it).
        # NOTE(review): this cuts at the FIRST "CODE" anywhere in the prompt,
        # including the task text or an earlier reflection -- confirm intended.
        prompt = prompt.partition("CODE")[0]
    if r:
        prompt += "\nTURN:\n"
        prompt += r
    # Always emit the CODE marker so the next call can strip this section.
    prompt += "\n\nCODE:\n"
    # Fence with the configured language (matching the code tag main() splits
    # on) rather than a hard-coded "dafny"; identical output when LANG is Dafny.
    prompt += f"\n\n```{LANG.lower()}\n" + initial_code
    return prompt
def trial(prompt, initial_code, trial_id=0):
    """Make one generation attempt and check it with the verifier.

    The model is queried once with ``prompt`` (at most 1000 new tokens), the
    output is scored with ``calculate_score_err_whole``, and it counts as a
    solution only when the score is positive and ``can_be_solution_whole``
    accepts it under the configured ``min_lines``/``check_func``/
    ``check_string`` constraints.

    Returns:
        ``(solved, payload, stats)``. On success ``payload`` is the model
        output; otherwise it is the next prompt built by ``buildPrompt`` from
        this attempt and its verifier error. ``stats`` holds per-trial
        metrics: trial_id, time, text, is_solution, score_sign, n_tokens.
    """
    tokens_before = llm.token_counter
    started_at = time.time()
    output = llm.generate_full(prompt, max_new_tokens=1000)
    score, err = calculate_score_err_whole(output)

    # score_sign: 0 when unscored, +1 when positive, -1 otherwise.
    if score is None:
        sign = 0
        solved = False
    elif score > 0:
        sign = 1
        solved = can_be_solution_whole(output, min_lines, check_func, check_string)
    else:
        sign = -1
        solved = False

    # Time and token counts are taken here, before buildPrompt runs below, so
    # they cover generation and scoring only (not the reflection step).
    stats = {
        "trial_id": trial_id,
        "time": time.time() - started_at,
        "text": output,
        "is_solution": 1 if solved else 0,
        "score_sign": sign,
        "n_tokens": llm.token_counter - tokens_before,
    }

    if not solved:
        return False, buildPrompt(prompt, initial_code, output, err), stats
    print("DONE")
    return True, output, stats
def main(prompt=prompt):
    """Run the reflexion loop until a solution is found or limit_tokens() says stop.

    If ``prompt`` contains a fenced block opened with the configured language
    tag (``LANG.lower()``), the text before the tag is used as the task and
    the text after it as the initial code the model must continue. Each
    iteration calls ``trial``; per-trial and final stats are logged to wandb
    when ``args.use_wandb`` is set.

    Raises:
        ValueError: if the code tag appears more than once in ``prompt``, or
            the initial code after the tag contains a ``` fence.
    """
    code_tag = f"```{LANG.lower()}\n"
    if code_tag in prompt:
        parts = prompt.split(code_tag)
        # Explicit checks instead of asserts so validation survives `python -O`.
        if len(parts) != 2:
            raise ValueError(
                f"prompt must contain {code_tag!r} at most once; "
                f"found {len(parts) - 1} occurrences"
            )
        prompt, initial_code = parts
        if "```" in initial_code:
            raise ValueError("initial code after the code tag must not contain '```'")
    else:
        initial_code = ""
    prompt += "\nOnly provide the code; do not provide any explanation or commentary.\n"
    prompt = buildPrompt(prompt, initial_code)

    init_time = time.time()
    done = False
    trials = 0
    while not done:
        trials += 1
        # On failure `trial` returns the next prompt; on success it returns
        # the solution text (the loop then ends, so `prompt` is not reused).
        solved, prompt, stats = trial(prompt, initial_code, trial_id=trials)
        if args.use_wandb:
            wandb.log(stats)
        print("Token counter: ", llm.token_counter, ", Trial: ", trials)
        done = solved or limit_tokens()
    print("Token limit: ", limit_tokens())

    if args.use_wandb:
        wandb.log({
            "final/n_trials": trials,
            "final/n_tokens": llm.token_counter,
            "final/time": time.time() - init_time,
            "final/solved": 1 if solved else 0,
        })


if __name__ == "__main__":
    main()