diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
index d24d4330db9..4d3aecf724b 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -42,13 +42,13 @@
     HandoffMessage,
     MemoryQueryEvent,
     ModelClientStreamingChunkEvent,
-    MultiModalMessage,
     TextMessage,
     ToolCallExecutionEvent,
     ToolCallRequestEvent,
     ToolCallSummaryMessage,
 )
 from ..state import AssistantAgentState
+from ..utils import remove_images
 from ._base_chat_agent import BaseChatAgent
 
 event_logger = logging.getLogger(EVENT_LOGGER_NAME)
@@ -375,8 +375,6 @@ async def on_messages_stream(
     ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
         # Add messages to the model context.
         for msg in messages:
-            if isinstance(msg, MultiModalMessage) and self._model_client.model_info["vision"] is False:
-                raise ValueError("The model does not support vision.")
             if isinstance(msg, HandoffMessage):
                 # Add handoff context to the model context.
                 for context_msg in msg.context:
@@ -398,7 +396,7 @@
             yield memory_query_event_msg
 
         # Generate an inference result based on the current model context.
-        llm_messages = self._system_messages + await self._model_context.get_messages()
+        llm_messages = self._get_compatible_context(self._system_messages + await self._model_context.get_messages())
         model_result: CreateResult | None = None
         if self._model_client_stream:
             # Stream the model client.
@@ -494,7 +492,9 @@
 
         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
-            llm_messages = self._system_messages + await self._model_context.get_messages()
+            llm_messages = self._get_compatible_context(
+                self._system_messages + await self._model_context.get_messages()
+            )
             reflection_model_result: CreateResult | None = None
             if self._model_client_stream:
                 # Stream the model client.
@@ -575,6 +575,13 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
         # Load the model context state.
         await self._model_context.load_state(assistant_agent_state.llm_context)
 
+    def _get_compatible_context(self, messages: List[LLMMessage]) -> Sequence[LLMMessage]:
+        """Ensure that the messages are compatible with the underlying client, by removing images if needed."""
+        if self._model_client.model_info["vision"]:
+            return messages
+        else:
+            return remove_images(messages)
+
     def _to_config(self) -> AssistantAgentConfig:
         """Convert the assistant agent to a declarative config."""
 
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/utils/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/utils/__init__.py
new file mode 100644
index 00000000000..7fada74afd1
--- /dev/null
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/utils/__init__.py
@@ -0,0 +1,7 @@
+"""
+This module implements various utilities common to AgentChat agents and teams.
+"""
+
+from ._utils import remove_images
+
+__all__ = ["remove_images"]
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/utils/_utils.py b/python/packages/autogen-agentchat/src/autogen_agentchat/utils/_utils.py
new file mode 100644
index 00000000000..c7065277093
--- /dev/null
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/utils/_utils.py
@@ -0,0 +1,32 @@
+from typing import List
+
+from autogen_core import Image
+from autogen_core.models import LLMMessage, UserMessage
+
+
+def _image_content_to_str(content: str | List[str | Image]) -> str:
+    """Convert the content of an LLMMessage to a string."""
+    if isinstance(content, str):
+        return content
+    else:
+        result: List[str] = []
+        for c in content:
+            if isinstance(c, str):
+                result.append(c)
+            elif isinstance(c, Image):
+                result.append("")
+            else:
+                raise AssertionError("Received unexpected content type.")
+
+    return "\n".join(result)
+
+
+def remove_images(messages: List[LLMMessage]) -> List[LLMMessage]:
+    """Remove images from a list of LLMMessages."""
+    str_messages: List[LLMMessage] = []
+    for message in messages:
+        if isinstance(message, UserMessage) and isinstance(message.content, list):
+            str_messages.append(UserMessage(content=_image_content_to_str(message.content), source=message.source))
+        else:
+            str_messages.append(message)
+    return str_messages
diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
index 6c2852c667c..3db9924368f 100644
--- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py
+++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py
@@ -21,7 +21,15 @@
 from autogen_core import FunctionCall, Image
 from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult
 from autogen_core.model_context import BufferedChatCompletionContext
-from autogen_core.models import CreateResult, FunctionExecutionResult, LLMMessage, RequestUsage
+from autogen_core.models import (
+    AssistantMessage,
+    CreateResult,
+    FunctionExecutionResult,
+    LLMMessage,
+    RequestUsage,
+    SystemMessage,
+    UserMessage,
+)
 from autogen_core.models._model_client import ModelFamily
 from autogen_core.tools import FunctionTool
 from autogen_ext.models.openai import OpenAIChatCompletionClient
@@ -541,15 +549,44 @@ async def test_invalid_model_capabilities() -> None:
                 FunctionTool(_echo_function, description="Echo"),
             ],
         )
+        await agent.run(task=TextMessage(source="user", content="Test"))
 
     with pytest.raises(ValueError):
         agent = AssistantAgent(name="assistant", model_client=model_client, handoffs=["agent2"])
+        await agent.run(task=TextMessage(source="user", content="Test"))
 
-    with pytest.raises(ValueError):
-        agent = AssistantAgent(name="assistant", model_client=model_client)
-        # Generate a random base64 image.
-        img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
-        await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
+
+@pytest.mark.asyncio
+async def test_remove_images(monkeypatch: pytest.MonkeyPatch) -> None:
+    model = "random-model"
+    model_client_1 = OpenAIChatCompletionClient(
+        model=model,
+        api_key="",
+        model_info={"vision": False, "function_calling": False, "json_output": False, "family": ModelFamily.UNKNOWN},
+    )
+    model_client_2 = OpenAIChatCompletionClient(
+        model=model,
+        api_key="",
+        model_info={"vision": True, "function_calling": False, "json_output": False, "family": ModelFamily.UNKNOWN},
+    )
+
+    img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
+    messages: List[LLMMessage] = [
+        SystemMessage(content="System.1"),
+        UserMessage(content=["User.1", Image.from_base64(img_base64)], source="user.1"),
+        AssistantMessage(content="Assistant.1", source="assistant.1"),
+        UserMessage(content="User.2", source="assistant.2"),
+    ]
+
+    agent_1 = AssistantAgent(name="assistant_1", model_client=model_client_1)
+    result = agent_1._get_compatible_context(messages)  # type: ignore
+    assert len(result) == 4
+    assert isinstance(result[1].content, str)
+
+    agent_2 = AssistantAgent(name="assistant_2", model_client=model_client_2)
+    result = agent_2._get_compatible_context(messages)  # type: ignore
+    assert len(result) == 4
+    assert isinstance(result[1].content, list)
 
 
 @pytest.mark.asyncio
diff --git a/python/packages/autogen-agentchat/tests/test_utils.py b/python/packages/autogen-agentchat/tests/test_utils.py
new file mode 100644
index 00000000000..36e98929b91
--- /dev/null
+++ b/python/packages/autogen-agentchat/tests/test_utils.py
@@ -0,0 +1,36 @@
+from typing import List
+
+import pytest
+from autogen_agentchat.utils import remove_images
+from autogen_core import Image
+from autogen_core.models import AssistantMessage, LLMMessage, SystemMessage, UserMessage
+
+
+@pytest.mark.asyncio
+async def test_remove_images() -> None:
+    img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
+    messages: List[LLMMessage] = [
+        SystemMessage(content="System.1"),
+        UserMessage(content=["User.1", Image.from_base64(img_base64)], source="user.1"),
+        AssistantMessage(content="Assistant.1", source="assistant.1"),
+        UserMessage(content="User.2", source="assistant.2"),
+    ]
+
+    result = remove_images(messages)
+
+    # Check all the invariants
+    assert len(result) == 4
+    assert isinstance(result[0], SystemMessage)
+    assert isinstance(result[1], UserMessage)
+    assert isinstance(result[2], AssistantMessage)
+    assert isinstance(result[3], UserMessage)
+    assert result[0].content == messages[0].content
+    assert result[2].content == messages[2].content
+    assert result[3].content == messages[3].content
+    assert isinstance(messages[2], AssistantMessage)
+    assert isinstance(messages[3], UserMessage)
+    assert result[2].source == messages[2].source
+    assert result[3].source == messages[3].source
+
+    # Check that the image was removed.
+    assert result[1].content == "User.1\n"