Add text-only model support to M1 (#5344)
Modify M1 agents to support text-only model settings.
This allows M1 to be used with non-vision models such as o3-mini and Llama 3.1+.
afourney authored Feb 4, 2025
1 parent 517e3f0 commit cf6fa77
Showing 8 changed files with 113 additions and 92 deletions.
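
As a usage sketch of what this commit enables: M1 can now be constructed around a text-only client. This is illustrative only; the client class and team class exist in the repo, but the model choice, task, and surrounding script are assumptions, not part of this commit.

```python
# Hypothetical sketch: driving MagenticOne (M1) with a text-only model.
# The model name and task below are illustrative assumptions.
import asyncio

from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.teams.magentic_one import MagenticOne


async def main() -> None:
    # o3-mini reports vision=False; M1 now only requires function calling
    # and JSON output, so a text-only client like this validates.
    client = OpenAIChatCompletionClient(model="o3-mini")
    m1 = MagenticOne(client=client)
    result = await m1.run(task="Find and summarize today's top HN story.")
    print(result)


asyncio.run(main())
```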
python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py
@@ -2,7 +2,7 @@
import logging
from typing import Any, Dict, List, Mapping

from autogen_core import AgentId, CancellationToken, DefaultTopicId, Image, MessageContext, event, rpc
from autogen_core import AgentId, CancellationToken, DefaultTopicId, MessageContext, event, rpc
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
@@ -24,6 +24,7 @@
ToolCallSummaryMessage,
)
from ....state import MagenticOneOrchestratorState
from ....utils import content_to_str, remove_images
from .._base_group_chat_manager import BaseGroupChatManager
from .._events import (
GroupChatAgentResponse,
@@ -138,15 +139,17 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None
# Create the initial task ledger
#################################
# Combine all message contents for task
self._task = " ".join([self._content_to_str(msg.content) for msg in message.messages])
self._task = " ".join([content_to_str(msg.content) for msg in message.messages])
planning_conversation: List[LLMMessage] = []

# 1. GATHER FACTS
# create a closed book task and generate a response and update the chat history
planning_conversation.append(
UserMessage(content=self._get_task_ledger_facts_prompt(self._task), source=self._name)
)
response = await self._model_client.create(planning_conversation, cancellation_token=ctx.cancellation_token)
response = await self._model_client.create(
self._get_compatible_context(planning_conversation), cancellation_token=ctx.cancellation_token
)

assert isinstance(response.content, str)
self._facts = response.content
@@ -157,7 +160,9 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None
planning_conversation.append(
UserMessage(content=self._get_task_ledger_plan_prompt(self._team_description), source=self._name)
)
response = await self._model_client.create(planning_conversation, cancellation_token=ctx.cancellation_token)
response = await self._model_client.create(
self._get_compatible_context(planning_conversation), cancellation_token=ctx.cancellation_token
)

assert isinstance(response.content, str)
self._plan = response.content
@@ -281,7 +286,7 @@ async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None
assert self._max_json_retries > 0
key_error: bool = False
for _ in range(self._max_json_retries):
response = await self._model_client.create(context, json_output=True)
response = await self._model_client.create(self._get_compatible_context(context), json_output=True)
ledger_str = response.content
try:
assert isinstance(ledger_str, str)
@@ -397,7 +402,9 @@ async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None
update_facts_prompt = self._get_task_ledger_facts_update_prompt(self._task, self._facts)
context.append(UserMessage(content=update_facts_prompt, source=self._name))

response = await self._model_client.create(context, cancellation_token=cancellation_token)
response = await self._model_client.create(
self._get_compatible_context(context), cancellation_token=cancellation_token
)

assert isinstance(response.content, str)
self._facts = response.content
@@ -407,7 +414,9 @@ async def _update_task_ledger(self, cancellation_token: CancellationToken) -> None
update_plan_prompt = self._get_task_ledger_plan_update_prompt(self._team_description)
context.append(UserMessage(content=update_plan_prompt, source=self._name))

response = await self._model_client.create(context, cancellation_token=cancellation_token)
response = await self._model_client.create(
self._get_compatible_context(context), cancellation_token=cancellation_token
)

assert isinstance(response.content, str)
self._plan = response.content
@@ -420,7 +429,9 @@ async def _prepare_final_answer(self, reason: str, cancellation_token: CancellationToken) -> None
final_answer_prompt = self._get_final_answer_prompt(self._task)
context.append(UserMessage(content=final_answer_prompt, source=self._name))

response = await self._model_client.create(context, cancellation_token=cancellation_token)
response = await self._model_client.create(
self._get_compatible_context(context), cancellation_token=cancellation_token
)
assert isinstance(response.content, str)
message = TextMessage(content=response.content, source=self._name)

@@ -464,15 +475,9 @@ def _thread_to_context(self) -> List[LLMMessage]:
context.append(UserMessage(content=m.content, source=m.source))
return context

def _content_to_str(self, content: str | List[str | Image]) -> str:
"""Convert the content to a string."""
if isinstance(content, str):
return content
def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
"""Ensure that the messages are compatible with the underlying client, by removing images if needed."""
if self._model_client.model_info["vision"]:
return messages
else:
result: List[str] = []
for c in content:
if isinstance(c, str):
result.append(c)
else:
result.append("<image>")
return "\n".join(result)
return remove_images(messages)
python/packages/autogen-agentchat/src/autogen_agentchat/utils/__init__.py
@@ -2,6 +2,6 @@
This module implements various utilities common to AgentChat agents and teams.
"""

from ._utils import remove_images
from ._utils import content_to_str, remove_images

__all__ = ["remove_images"]
__all__ = ["content_to_str", "remove_images"]
python/packages/autogen-agentchat/src/autogen_agentchat/utils/_utils.py
@@ -1,11 +1,17 @@
from typing import List
from typing import List, Union

from autogen_core import Image
from autogen_core.models import LLMMessage, UserMessage
from autogen_core import FunctionCall, Image
from autogen_core.models import FunctionExecutionResult, LLMMessage, UserMessage

# Type aliases for convenience
_UserContent = Union[str, List[Union[str, Image]]]
_AssistantContent = Union[str, List[FunctionCall]]
_FunctionExecutionContent = List[FunctionExecutionResult]
_SystemContent = str

def _image_content_to_str(content: str | List[str | Image]) -> str:
"""Convert the content of an LLMMessageto a string."""

def content_to_str(content: _UserContent | _AssistantContent | _FunctionExecutionContent | _SystemContent) -> str:
"""Convert the content of an LLMMessage to a string."""
if isinstance(content, str):
return content
else:
@@ -16,7 +22,7 @@ def _image_content_to_str(content: str | List[str | Image]) -> str:
elif isinstance(c, Image):
result.append("<image>")
else:
raise AssertionError("Received unexpected content type.")
result.append(str(c))

return "\n".join(result)

@@ -26,7 +32,7 @@ def remove_images(messages: List[LLMMessage]) -> List[LLMMessage]:
str_messages: List[LLMMessage] = []
for message in messages:
if isinstance(message, UserMessage) and isinstance(message.content, list):
str_messages.append(UserMessage(content=_image_content_to_str(message.content), source=message.source))
str_messages.append(UserMessage(content=content_to_str(message.content), source=message.source))
else:
str_messages.append(message)
return str_messages
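
A brief usage sketch of the two helpers above, showing how image content is flattened; the sample message is made up for illustration:

```python
# Sketch: flattening multimodal content with the helpers defined above.
import PIL.Image

from autogen_agentchat.utils import content_to_str, remove_images
from autogen_core import Image as AGImage
from autogen_core.models import UserMessage

img = AGImage.from_pil(PIL.Image.new("RGB", (8, 8)))
msg = UserMessage(content=["What is shown here?", img], source="user")

# Each Image in the content list becomes the "<image>" placeholder.
print(content_to_str(msg.content))  # What is shown here?\n<image>

# remove_images rewrites multimodal UserMessages as plain text and
# passes every other message through unchanged.
text_only = remove_images([msg])
assert isinstance(text_only[0].content, str)
```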
python/packages/autogen-ext/src/autogen_ext/agents/file_surfer/_file_surfer.py
@@ -9,6 +9,7 @@
MultiModalMessage,
TextMessage,
)
from autogen_agentchat.utils import remove_images
from autogen_core import CancellationToken, FunctionCall
from autogen_core.models import (
AssistantMessage,
@@ -126,7 +127,7 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, str]:
)

create_result = await self._model_client.create(
messages=history + [context_message, task_message],
messages=self._get_compatible_context(history + [context_message, task_message]),
tools=[
TOOL_OPEN_PATH,
TOOL_PAGE_DOWN,
@@ -172,3 +173,10 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, str]:

final_response = "TERMINATE"
return False, final_response

def _get_compatible_context(self, messages: List[LLMMessage]) -> List[LLMMessage]:
"""Ensure that the messages are compatible with the underlying client, by removing images if needed."""
if self._model_client.model_info["vision"]:
return messages
else:
return remove_images(messages)
python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py
@@ -24,6 +24,7 @@
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, TextMessage
from autogen_agentchat.utils import content_to_str, remove_images
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel, FunctionCall
from autogen_core import Image as AGImage
from autogen_core.models import (
@@ -40,7 +41,13 @@
from typing_extensions import Self

from ._events import WebSurferEvent
from ._prompts import WEB_SURFER_OCR_PROMPT, WEB_SURFER_QA_PROMPT, WEB_SURFER_QA_SYSTEM_MESSAGE, WEB_SURFER_TOOL_PROMPT
from ._prompts import (
WEB_SURFER_OCR_PROMPT,
WEB_SURFER_QA_PROMPT,
WEB_SURFER_QA_SYSTEM_MESSAGE,
WEB_SURFER_TOOL_PROMPT_MM,
WEB_SURFER_TOOL_PROMPT_TEXT,
)
from ._set_of_mark import add_set_of_mark
from ._tool_definitions import (
TOOL_CLICK,
@@ -56,7 +63,6 @@
TOOL_WEB_SEARCH,
)
from ._types import InteractiveRegion, UserContent
from ._utils import message_content_to_str
from .playwright_controller import PlaywrightController


@@ -215,8 +221,7 @@ def __init__(
raise ValueError(
"The model does not support function calling. MultimodalWebSurfer requires a model that supports function calling."
)
if model_client.model_info["vision"] is False:
raise ValueError("The model is not multimodal. MultimodalWebSurfer requires a multimodal model.")

self._model_client = model_client
self.headless = headless
self.browser_channel = browser_channel
@@ -404,7 +409,7 @@ async def on_messages_stream(
self.model_usage: List[RequestUsage] = []
try:
content = await self._generate_reply(cancellation_token=cancellation_token)
self._chat_history.append(AssistantMessage(content=message_content_to_str(content), source=self.name))
self._chat_history.append(AssistantMessage(content=content_to_str(content), source=self.name))
final_usage = RequestUsage(
prompt_tokens=sum([u.prompt_tokens for u in self.model_usage]),
completion_tokens=sum([u.completion_tokens for u in self.model_usage]),
@@ -434,22 +439,8 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> UserContent:

assert self._page is not None

# Clone the messages to give context, removing old screenshots
history: List[LLMMessage] = []
for m in self._chat_history:
assert isinstance(m, UserMessage | AssistantMessage | SystemMessage)
assert isinstance(m.content, str | list)

if isinstance(m.content, str):
history.append(m)
else:
content = message_content_to_str(m.content)
if isinstance(m, UserMessage):
history.append(UserMessage(content=content, source=m.source))
elif isinstance(m, AssistantMessage):
history.append(AssistantMessage(content=content, source=m.source))
elif isinstance(m, SystemMessage):
history.append(SystemMessage(content=content))
# Clone the messages, removing old screenshots
history: List[LLMMessage] = remove_images(self._chat_history)

# Ask the page for interactive elements, then prepare the state-of-mark screenshot
rects = await self._playwright_controller.get_interactive_rects(self._page)
@@ -512,22 +503,37 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> UserContent:

tool_names = "\n".join([t["name"] for t in tools])

text_prompt = WEB_SURFER_TOOL_PROMPT.format(
url=self._page.url,
visible_targets=visible_targets,
other_targets_str=other_targets_str,
focused_hint=focused_hint,
tool_names=tool_names,
).strip()

# Scale the screenshot for the MLM, and close the original
scaled_screenshot = som_screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
som_screenshot.close()
if self.to_save_screenshots:
scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png")) # type: ignore
if self._model_client.model_info["vision"]:
text_prompt = WEB_SURFER_TOOL_PROMPT_MM.format(
url=self._page.url,
visible_targets=visible_targets,
other_targets_str=other_targets_str,
focused_hint=focused_hint,
tool_names=tool_names,
).strip()

# Scale the screenshot for the MLM, and close the original
scaled_screenshot = som_screenshot.resize((self.MLM_WIDTH, self.MLM_HEIGHT))
som_screenshot.close()
if self.to_save_screenshots:
scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png")) # type: ignore

# Add the message
history.append(UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.name))
else:
visible_text = await self._playwright_controller.get_visible_text(self._page)

# Add the multimodal message and make the request
history.append(UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.name))
text_prompt = WEB_SURFER_TOOL_PROMPT_TEXT.format(
url=self._page.url,
visible_targets=visible_targets,
other_targets_str=other_targets_str,
focused_hint=focused_hint,
tool_names=tool_names,
visible_text=visible_text.strip(),
).strip()

# Add the message
history.append(UserMessage(content=text_prompt, source=self.name))

response = await self._model_client.create(
history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py
@@ -1,4 +1,4 @@
WEB_SURFER_TOOL_PROMPT = """
WEB_SURFER_TOOL_PROMPT_MM = """
Consider the following screenshot of a web browser, which is open to the page '{url}'. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below:
{visible_targets}{other_targets_str}{focused_hint}
@@ -13,6 +13,27 @@
- on some other website entirely (in which case actions like performing a new web search might be the best option)
"""

WEB_SURFER_TOOL_PROMPT_TEXT = """
Your web browser is open to the page '{url}'. The following text is visible in the viewport:
```
{visible_text}
```
You have also identified the following interactive components:
{visible_targets}{other_targets_str}{focused_hint}
You are to respond to the most recent request by selecting an appropriate tool from the following set, or by answering the question directly if possible without tools:
{tool_names}
When deciding between tools, consider if the request can be best addressed by:
- the contents of the current viewport (in which case actions like clicking links, clicking buttons, inputting text might be most appropriate, or hovering over element)
- contents found elsewhere on the full webpage (in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate)
- on some other website entirely (in which case actions like performing a new web search might be the best option)
"""

WEB_SURFER_OCR_PROMPT = """
Please transcribe all visible text on this page, including both main content and the labels of UI elements.
"""

python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_utils.py

This file was deleted.

python/packages/autogen-ext/src/autogen_ext/teams/magentic_one.py
@@ -142,7 +142,7 @@ def __init__(

def _validate_client_capabilities(self, client: ChatCompletionClient) -> None:
capabilities = client.model_info
required_capabilities = ["vision", "function_calling", "json_output"]
required_capabilities = ["function_calling", "json_output"]

if not all(capabilities.get(cap) for cap in required_capabilities):
warnings.warn(
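
With "vision" dropped from the required capabilities, any client reporting function calling and JSON output now passes this validation. A sketch of how a text-only model might declare that; the endpoint, model name, and API key below are placeholder assumptions:

```python
# Sketch: model_info for a text-only model behind an OpenAI-compatible
# endpoint. base_url, model, and api_key are placeholders.
from autogen_core.models import ModelInfo
from autogen_ext.models.openai import OpenAIChatCompletionClient

llama_info = ModelInfo(
    vision=False,           # no longer required by M1 after this change
    function_calling=True,  # still required
    json_output=True,       # still required
    family="unknown",
)

client = OpenAIChatCompletionClient(
    model="llama-3.1-70b-instruct",
    base_url="http://localhost:8000/v1",
    api_key="placeholder",
    model_info=llama_info,
)
```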
