Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update distributed group chat example to use streaming API calls as well #4440

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ This example runs a gRPC server using [WorkerAgentRuntimeHost](../../src/autogen
### Setup Python Environment

1. Create a virtual environment as instructed in [README](../../../../../../../../README.md).
2. Run `uv pip install chainlit` in the same virtual environment
2. Run `uv pip install pydantic==2.10.1 chainlit` in the same virtual environment. We have to pin the pydantic version due to [this issue](https://github.com/Chainlit/chainlit/issues/1544)

### General Configuration

Expand Down Expand Up @@ -111,4 +111,3 @@ graph TD;
## TODO:

- [ ] Properly handle chat restarts. It complains about group chat manager being already registered
- [ ] Add streaming to the UI like [this example](https://docs.chainlit.io/advanced-features/streaming) when [this bug](https://github.com/microsoft/autogen/issues/4213) is resolved
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import random
from typing import Awaitable, Callable, List
from typing import AsyncGenerator, Awaitable, Callable, List, Union
from uuid import uuid4

from _types import GroupChatMessage, MessageChunk, RequestToSpeak, UIAgentConfig
Expand All @@ -14,6 +14,7 @@
SystemMessage,
UserMessage,
)
from autogen_core.components.models._types import CreateResult
from rich.console import Console
from rich.markdown import Markdown

Expand Down Expand Up @@ -51,21 +52,23 @@ async def handle_request_to_speak(self, message: RequestToSpeak, ctx: MessageCon
self._chat_history.append(
UserMessage(content=f"Transferred to {self.id.type}, adopt the persona immediately.", source="system")
)
completion = await self._model_client.create([self._system_message] + self._chat_history)
assert isinstance(completion.content, str)
self._chat_history.append(AssistantMessage(content=completion.content, source=self.id.type))

console_message = f"\n{'-'*80}\n**{self.id.type}**: {completion.content}"
self.console.print(Markdown(console_message))

await publish_message_to_ui_and_backend(
stream_output = self._model_client.create_stream(
messages=[self._system_message] + self._chat_history, max_consecutive_empty_chunk_tolerance=3
)
create_stream_result = await publish_message_stream_to_ui_and_backend(
runtime=self,
source=self.id.type,
user_message=completion.content,
stream_output=stream_output,
ui_config=self._ui_config,
group_chat_topic_type=self._group_chat_topic_type,
)

if create_stream_result is not None:
self._chat_history.append(AssistantMessage(content=create_stream_result.content, source=self.id.type))

console_message = f"\n{'-'*80}\n**{self.id.type}**: {create_stream_result.content}"
self.console.print(Markdown(console_message))


class GroupChatManager(RoutedAgent):
def __init__(
Expand Down Expand Up @@ -168,12 +171,72 @@ async def handle_message_chunk(self, message: MessageChunk, ctx: MessageContext)
await self._on_message_chunk_func(message)


async def publish_message_stream_to_ui(
runtime: RoutedAgent | WorkerAgentRuntime,
source: str,
ui_config: UIAgentConfig,
stream_output: AsyncGenerator[Union[str, CreateResult], None],
) -> None:
"""Publishes a stream of messages to the UI."""
message_id = str(uuid4())
async for chunk in stream_output:
if isinstance(chunk, str):
msg_chunk = MessageChunk(message_id=message_id, text=str(chunk), author=source, finished=False)

await runtime.publish_message(
msg_chunk,
DefaultTopicId(type=ui_config.topic_type),
)
await asyncio.sleep(random.uniform(ui_config.min_delay, ui_config.max_delay))
elif isinstance(chunk, CreateResult):
print("Ok, finished the message!")
await runtime.publish_message(
MessageChunk(message_id=message_id, text=" ", author=source, finished=True),
DefaultTopicId(type=ui_config.topic_type),
)


async def publish_message_stream_to_ui_and_backend(
runtime: RoutedAgent | WorkerAgentRuntime,
source: str,
ui_config: UIAgentConfig,
group_chat_topic_type: str,
stream_output: AsyncGenerator[Union[str, CreateResult], None],
) -> None | CreateResult:
"""Publishes a stream of messages to both the UI and backend."""
message_id = str(uuid4())
async for chunk in stream_output:
if isinstance(chunk, str):
msg_chunk = MessageChunk(message_id=message_id, text=str(chunk), author=source, finished=False)

await runtime.publish_message(
msg_chunk,
DefaultTopicId(type=ui_config.topic_type),
)
await asyncio.sleep(random.uniform(ui_config.min_delay, ui_config.max_delay))
elif isinstance(chunk, CreateResult):
print("Ok, finished the message!")
await runtime.publish_message(
MessageChunk(message_id=message_id, text=" ", author=source, finished=True),
DefaultTopicId(type=ui_config.topic_type),
)
# Publish message to backend
await runtime.publish_message(
GroupChatMessage(body=UserMessage(content=str(chunk.content), source=source)),
topic_id=DefaultTopicId(type=group_chat_topic_type),
)
return chunk

return None


async def publish_message_to_ui(
runtime: RoutedAgent | WorkerAgentRuntime,
source: str,
user_message: str,
ui_config: UIAgentConfig,
) -> None:
"""Publishes a single message to the UI."""
message_id = str(uuid4())
# Stream the message to UI
message_chunks = (
Expand All @@ -200,6 +263,7 @@ async def publish_message_to_ui_and_backend(
ui_config: UIAgentConfig,
group_chat_topic_type: str,
) -> None:
"""Publishes a single message to both the UI and backend."""
# Publish messages for ui
await publish_message_to_ui(
runtime=runtime,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#!/bin/bash
# # Start a new tmux session named 'distributed_group_chat'
# These line are to supress https://stackoverflow.com/questions/78780089
export GRPC_VERBOSITY=ERROR
export GLOG_minloglevel=2

tmux new-session -d -s distributed_group_chat

# # Split the terminal into 2 vertical panes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def create_stream(
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
max_consecutive_empty_chunk_tolerance: int = 0,
) -> AsyncGenerator[Union[str, CreateResult], None]: ...

def actual_usage(self) -> RequestUsage: ...
Expand Down
1 change: 1 addition & 0 deletions python/packages/autogen-core/tests/test_tool_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def create_stream(
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
max_consecutive_empty_chunk_tolerance: int = 0,
) -> AsyncGenerator[Union[str, CreateResult], None]:
raise NotImplementedError()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,6 @@ async def create_stream(
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
*,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jackgerrits do we need this here? Can I possibly be breaking anything by removing it?
I have added the max_consecutive_empty_chunk_tolerance to the ChatCompletionClient and it's implementations and had to remove this * as it was raising type check errors (in theory non named parameters, the *, should be after the named parameters, but I was getting pyright and mypy errors for it complaining about number of arguments in the overriden methods ...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we might need to put the * before tools, both in the interface and the implementations. @jackgerrits

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a separate PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all extra params should come after the * to avoid too many positional params

max_consecutive_empty_chunk_tolerance: int = 0,
) -> AsyncGenerator[Union[str, CreateResult], None]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ async def create_stream(
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
max_consecutive_empty_chunk_tolerance: int = 0,
) -> AsyncGenerator[Union[str, CreateResult], None]:
"""Return the next completion as a stream."""
if self._current_index >= len(self.chat_completions):
Expand Down
Loading