Skip to content

Commit

Permalink
Assistant agent drop images when not provided with a vision-capable model. (#5351)
Browse files Browse the repository at this point in the history

Allow AssistantAgent to drop images when not equipped with a multi-modal model.

Adds a corresponding utility function, which can be used in autogen-ext and teams, to accomplish the same.
  • Loading branch information
afourney authored Feb 4, 2025
1 parent 5df5bde commit 517e3f0
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@
HandoffMessage,
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
MultiModalMessage,
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from ..state import AssistantAgentState
from ..utils import remove_images
from ._base_chat_agent import BaseChatAgent

event_logger = logging.getLogger(EVENT_LOGGER_NAME)
Expand Down Expand Up @@ -375,8 +375,6 @@ async def on_messages_stream(
) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
# Add messages to the model context.
for msg in messages:
if isinstance(msg, MultiModalMessage) and self._model_client.model_info["vision"] is False:
raise ValueError("The model does not support vision.")
if isinstance(msg, HandoffMessage):
# Add handoff context to the model context.
for context_msg in msg.context:
Expand All @@ -398,7 +396,7 @@ async def on_messages_stream(
yield memory_query_event_msg

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + await self._model_context.get_messages()
llm_messages = self._get_compatible_context(self._system_messages + await self._model_context.get_messages())
model_result: CreateResult | None = None
if self._model_client_stream:
# Stream the model client.
Expand Down Expand Up @@ -494,7 +492,9 @@ async def on_messages_stream(

if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + await self._model_context.get_messages()
llm_messages = self._get_compatible_context(
self._system_messages + await self._model_context.get_messages()
)
reflection_model_result: CreateResult | None = None
if self._model_client_stream:
# Stream the model client.
Expand Down Expand Up @@ -575,6 +575,13 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
# Load the model context state.
await self._model_context.load_state(assistant_agent_state.llm_context)

def _get_compatible_context(self, messages: List[LLMMessage]) -> Sequence[LLMMessage]:
    """Return a version of ``messages`` that the underlying model client can accept.

    Args:
        messages: The messages about to be sent to the model client.

    Returns:
        ``messages`` itself when the client's ``model_info["vision"]`` entry is
        truthy; otherwise the result of ``remove_images(messages)``.
    """
    supports_vision = self._model_client.model_info["vision"]
    return messages if supports_vision else remove_images(messages)

def _to_config(self) -> AssistantAgentConfig:
"""Convert the assistant agent to a declarative config."""

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
This module implements various utilities common to AgentChat agents and teams.
"""

from ._utils import remove_images

__all__ = ["remove_images"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import List

from autogen_core import Image
from autogen_core.models import LLMMessage, UserMessage


def _image_content_to_str(content: str | List[str | Image]) -> str:
    """Flatten the content of an LLMMessage into a single string.

    Args:
        content: Either a plain string or a list mixing strings and ``Image`` items.

    Returns:
        ``content`` unchanged when it is already a string. For a list, the string
        items are kept, each ``Image`` is replaced by the literal ``"<image>"``
        placeholder, and the pieces are joined with newlines.

    Raises:
        AssertionError: If a list item is neither a ``str`` nor an ``Image``.
    """
    if isinstance(content, str):
        return content

    def _as_text(item: object) -> str:
        # Strings pass through untouched; images collapse to a fixed placeholder.
        if isinstance(item, str):
            return item
        if isinstance(item, Image):
            return "<image>"
        raise AssertionError("Received unexpected content type.")

    return "\n".join(_as_text(item) for item in content)


def remove_images(messages: List[LLMMessage]) -> List[LLMMessage]:
    """Return a new list of LLMMessages in which images are replaced by text.

    Every ``UserMessage`` whose content is a list is rebuilt as a ``UserMessage``
    with the same ``source`` and string content produced by
    ``_image_content_to_str``. All other messages are placed in the result
    unchanged (the same objects, not copies).

    Args:
        messages: The messages to process; the input list itself is not modified.

    Returns:
        A new list with the same length and order as ``messages``.
    """
    return [
        UserMessage(content=_image_content_to_str(msg.content), source=msg.source)
        if isinstance(msg, UserMessage) and isinstance(msg.content, list)
        else msg
        for msg in messages
    ]
49 changes: 43 additions & 6 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,15 @@
from autogen_core import FunctionCall, Image
from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import CreateResult, FunctionExecutionResult, LLMMessage, RequestUsage
from autogen_core.models import (
AssistantMessage,
CreateResult,
FunctionExecutionResult,
LLMMessage,
RequestUsage,
SystemMessage,
UserMessage,
)
from autogen_core.models._model_client import ModelFamily
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
Expand Down Expand Up @@ -541,15 +549,44 @@ async def test_invalid_model_capabilities() -> None:
FunctionTool(_echo_function, description="Echo"),
],
)
await agent.run(task=TextMessage(source="user", content="Test"))

with pytest.raises(ValueError):
agent = AssistantAgent(name="assistant", model_client=model_client, handoffs=["agent2"])
await agent.run(task=TextMessage(source="user", content="Test"))

with pytest.raises(ValueError):
agent = AssistantAgent(name="assistant", model_client=model_client)
# Generate a random base64 image.
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))

@pytest.mark.asyncio
async def test_remove_images(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that ``AssistantAgent._get_compatible_context`` follows the client's vision flag.

    With ``vision: False`` the multimodal user message must come back with string
    content; with ``vision: True`` its content must still be a list.
    """

    def _make_client(vision: bool) -> OpenAIChatCompletionClient:
        # Both clients share the same dummy model name and key; only "vision" differs.
        return OpenAIChatCompletionClient(
            model="random-model",
            api_key="",
            model_info={
                "vision": vision,
                "function_calling": False,
                "json_output": False,
                "family": ModelFamily.UNKNOWN,
            },
        )

    text_only_client = _make_client(False)
    vision_client = _make_client(True)

    encoded_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
    messages: List[LLMMessage] = [
        SystemMessage(content="System.1"),
        UserMessage(content=["User.1", Image.from_base64(encoded_image)], source="user.1"),
        AssistantMessage(content="Assistant.1", source="assistant.1"),
        UserMessage(content="User.2", source="assistant.2"),
    ]

    # (agent name, model client, expected type of the second message's content)
    cases = (
        ("assistant_1", text_only_client, str),
        ("assistant_2", vision_client, list),
    )
    for agent_name, client, expected_type in cases:
        agent = AssistantAgent(name=agent_name, model_client=client)
        compatible = agent._get_compatible_context(messages)  # type: ignore
        assert len(compatible) == 4
        assert isinstance(compatible[1].content, expected_type)


@pytest.mark.asyncio
Expand Down
36 changes: 36 additions & 0 deletions python/packages/autogen-agentchat/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import List

import pytest
from autogen_agentchat.utils import remove_images
from autogen_core import Image
from autogen_core.models import AssistantMessage, LLMMessage, SystemMessage, UserMessage


@pytest.mark.asyncio
async def test_remove_images() -> None:
    """Check that ``remove_images`` rewrites only the multimodal user message.

    The result must keep the length, order and message types of the input, leave
    the text-only messages' content and source intact, and turn the image item
    into the ``<image>`` placeholder.
    """
    encoded_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
    messages: List[LLMMessage] = [
        SystemMessage(content="System.1"),
        UserMessage(content=["User.1", Image.from_base64(encoded_image)], source="user.1"),
        AssistantMessage(content="Assistant.1", source="assistant.1"),
        UserMessage(content="User.2", source="assistant.2"),
    ]

    stripped = remove_images(messages)

    # Same number of messages out as in.
    assert len(stripped) == 4

    # System message: type and content preserved.
    assert isinstance(stripped[0], SystemMessage)
    assert stripped[0].content == messages[0].content

    # Multimodal user message: still a UserMessage, image replaced by a placeholder.
    assert isinstance(stripped[1], UserMessage)
    assert stripped[1].content == "User.1\n<image>"

    # Assistant message: type, content and source preserved.
    assert isinstance(stripped[2], AssistantMessage)
    assert isinstance(messages[2], AssistantMessage)
    assert stripped[2].content == messages[2].content
    assert stripped[2].source == messages[2].source

    # Text-only user message: type, content and source preserved.
    assert isinstance(stripped[3], UserMessage)
    assert isinstance(messages[3], UserMessage)
    assert stripped[3].content == messages[3].content
    assert stripped[3].source == messages[3].source

0 comments on commit 517e3f0

Please sign in to comment.