Skip to content

Commit

Permalink
pre-commit format
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaHSR committed Mar 19, 2024
1 parent 0eec611 commit 7eb379b
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 46 deletions.
1 change: 0 additions & 1 deletion data/inference/make_datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
import re


def extract_diff(response):
Expand Down
6 changes: 2 additions & 4 deletions data/inference/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@

try:
# 设置你想要传递给脚本的命令行参数
sys.argv = ['run_api.py', '--dataset_name_or_path', 'princeton-nlp/SWE-bench_oracle', '--output_dir',
'./outputs']
sys.argv = ["run_api.py", "--dataset_name_or_path", "princeton-nlp/SWE-bench_oracle", "--output_dir", "./outputs"]
# 执行脚本
runpy.run_path(path_name='run_api.py', run_name='__main__')
runpy.run_path(path_name="run_api.py", run_name="__main__")
finally:
# 恢复原始的sys.argv以避免对后续代码的潜在影响
sys.argv = original_argv

4 changes: 2 additions & 2 deletions data/inference/run_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm.auto import tqdm

from data.inference.const import SCIKIT_LEARN_IDS
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.utils import count_string_tokens
from metagpt.utils.recovery_util import save_history
from data.inference.const import SCIKIT_LEARN_IDS

# Replace with your own
MAX_TOKEN = 128000
Expand Down Expand Up @@ -71,7 +71,7 @@ async def openai_inference(
for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
di = DataInterpreter(use_reflection=use_reflection)
instance_id = datum["instance_id"]

if instance_id in existing_ids:
continue
output_dict = {"instance_id": instance_id}
Expand Down
80 changes: 41 additions & 39 deletions metagpt/roles/di/data_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,66 +43,66 @@ class DataInterpreter(Role):
tool_recommender: ToolRecommender = None
react_mode: Literal["plan_and_act", "react"] = "plan_and_act"
max_react_loop: int = 10 # used for react mode

@model_validator(mode="after")
def set_plan_and_tool(self) -> "Interpreter":
    """Post-validation hook: wire up react mode, plan flag, tools, and first action.

    Returns:
        The validated instance (pydantic ``mode="after"`` validator convention).
    """
    self._set_react_mode(react_mode=self.react_mode, max_react_loop=self.max_react_loop, auto_run=self.auto_run)
    # Convenience flag: planning is active only in "plan_and_act" mode.
    # NOTE: this overwrites any use_plan value passed in by the caller.
    self.use_plan = self.react_mode == "plan_and_act"
    if self.tools:
        # NOTE(review): assumes BM25ToolRecommender accepts the raw tools list — confirm upstream.
        self.tool_recommender = BM25ToolRecommender(tools=self.tools)
    self.set_actions([WriteAnalysisCode])
    self._set_state(0)
    return self

@property
def working_memory(self):
    """Shortcut to the working memory held by this role's runtime context."""
    runtime_context = self.rc
    return runtime_context.working_memory

async def _think(self) -> bool:
    """Decide via the LLM whether another action is needed ('react' mode).

    Returns:
        True when a further action should run, False when the loop should stop.
    """
    requirement = self.get_memories()[0].content
    history = self.working_memory.get()

    # Fresh run: seed working memory with the user requirement and act unconditionally.
    if not history:
        self.working_memory.add(self.get_memories()[0])
        self._set_state(0)
        return True

    think_prompt = REACT_THINK_PROMPT.format(user_requirement=requirement, context=history)
    raw_reply = await self.llm.aask(think_prompt)
    parsed = json.loads(CodeParser.parse_code(block=None, text=raw_reply))
    self.working_memory.add(Message(content=parsed["thoughts"], role="assistant"))

    need_action = parsed["state"]
    if need_action:
        self._set_state(0)
    else:
        self._set_state(-1)
    return need_action

async def _act(self) -> Message:
    """'react' mode action step: write and run code, then wrap it as a Message."""
    generated_code, _, _ = await self._write_and_exec_code()
    return Message(
        content=generated_code,
        role="assistant",
        cause_by=WriteAnalysisCode,
    )

async def _plan_and_act(self) -> Message:
    """Run the full plan-and-act cycle, always shutting down the code executor.

    Returns:
        The final response Message from the planning/acting loop.
    """
    try:
        rsp = await super()._plan_and_act()
    finally:
        # Terminate the notebook executor even when planning/acting raises,
        # so the kernel process is not leaked on failure.
        await self.execute_code.terminate()
    return rsp

async def _act_on_task(self, current_task: Task) -> TaskResult:
    """'plan_and_act' mode: run code for the current task and wrap the outcome in a TaskResult."""
    code, result, is_success = await self._write_and_exec_code()
    return TaskResult(code=code, result=result, is_success=is_success)

async def _write_and_exec_code(self, max_retry: int = 3):
counter = 0
success = False

# plan info
plan_status = self.planner.get_plan_status() if self.use_plan else ""

# tool info
if self.tools:
context = (
Expand All @@ -112,66 +112,68 @@ async def _write_and_exec_code(self, max_retry: int = 3):
tool_info = await self.tool_recommender.get_recommended_tool_info(context=context, plan=plan)
else:
tool_info = ""

# data info
await self._check_data()

while not success and counter < max_retry:
### write code ###
code, cause_by = await self._write_code(counter, plan_status, tool_info)

self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))

### execute code ###
import pdb;pdb.set_trace()
import pdb

pdb.set_trace()
result, success = await self.execute_code.run(code)
print(result)

self.working_memory.add(Message(content=result, role="user", cause_by=ExecuteNbCode))

### process execution result ###
counter += 1

if not success and counter >= max_retry:
logger.info("coding failed!")
review, _ = await self.planner.ask_review(auto_run=False, trigger=ReviewConst.CODE_REVIEW_TRIGGER)
if ReviewConst.CHANGE_WORDS[0] in review:
counter = 0 # redo the task again with help of human suggestions

return code, result, success

async def _write_code(
    self,
    counter: int,
    plan_status: str = "",
    tool_info: str = "",
):
    """Generate analysis code via the current WriteAnalysisCode action.

    Args:
        counter: Retry count; reflection is used only after the first failed trial.
        plan_status: Serialized plan status; empty when planning is disabled.
        tool_info: Recommended-tool info for the prompt; empty when no tools are set.

    Returns:
        Tuple of (generated code string, the todo action that produced it).
    """
    todo = self.rc.todo  # todo is WriteAnalysisCode
    logger.info(f"ready to {todo.name}")
    # Reflect only after at least one failed attempt, and only when reflection is enabled.
    use_reflection = counter > 0 and self.use_reflection

    user_requirement = self.get_memories()[0].content

    code = await todo.run(
        user_requirement=user_requirement,
        plan_status=plan_status,
        tool_info=tool_info,
        working_memory=self.working_memory.get(),
        use_reflection=use_reflection,
    )

    return code, todo

async def _check_data(self):
if (
not self.use_plan
or not self.planner.plan.get_finished_tasks()
or self.planner.plan.current_task.task_type
not in [
TaskType.DATA_PREPROCESS.type_name,
TaskType.FEATURE_ENGINEERING.type_name,
TaskType.MODEL_TRAIN.type_name,
]
not self.use_plan
or not self.planner.plan.get_finished_tasks()
or self.planner.plan.current_task.task_type
not in [
TaskType.DATA_PREPROCESS.type_name,
TaskType.FEATURE_ENGINEERING.type_name,
TaskType.MODEL_TRAIN.type_name,
]
):
return
logger.info("Check updated data")
Expand Down

0 comments on commit 7eb379b

Please sign in to comment.