Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: save/load test for dotnet agents #5284

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>
return ValueTask.FromResult(agent);
});

// Ensure the agent is actually created
// Ensure the agent id is registered
AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false);

// Validate agent ID
Expand Down Expand Up @@ -148,23 +148,50 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>
}

[Fact]
public async Task AgentShouldSaveStateCorrectlyTest()
public async Task AgentShouldSaveLoadStateCorrectlyTest()
{
var runtime = new InProcessRuntime();
await runtime.StartAsync();

Logger<BaseAgent> logger = new(new LoggerFactory());
TestAgent agent = new TestAgent(new AgentId("TestType", "TestKey"), runtime, logger);
SubscribedSaveLoadAgent agent = null!;

await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) =>
{
agent = new SubscribedSaveLoadAgent(id, runtime, logger);
return ValueTask.FromResult(agent);
});

// Ensure the agent id is registered
AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false);

// Validate agent ID
agentId.Should().Be(agent.Id, "Agent ID should match the registered agent");

await runtime.RegisterImplicitAgentSubscriptionsAsync<SubscribedSaveLoadAgent>("MyAgent");

var topicType = "TestTopic";

await runtime.PublishMessageAsync(new TextMessage { Source = topicType, Content = "test" }, new TopicId(topicType)).ConfigureAwait(true);

await runtime.RunUntilIdleAsync();

agent.ReceivedMessages.Any().Should().BeTrue("Agent should receive messages when subscribed.");

// Save the state
var savedState = await agent.SaveStateAsync();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lokitoth can we require that the result of this be json serializable? This is a requirement on the python side too afaik


// Ensure the state contains receivedMessages
savedState.Should().ContainKey("receivedMessages");
savedState["receivedMessages"].Should().BeOfType<Dictionary<string, object>>();

var state = await agent.SaveStateAsync();
// Create a new instance of the agent to simulate a restart
var newAgent = new SubscribedSaveLoadAgent(agent.Id, runtime, logger);

// Ensure state is a dictionary
state.Should().NotBeNull();
state.Should().BeOfType<Dictionary<string, object>>();
state.Should().BeEmpty("Default SaveStateAsync should return an empty dictionary.");
// Load the saved state into the new agent
await newAgent.LoadStateAsync(savedState);

// Add a sample value and verify it updates correctly
state["testKey"] = "testValue";
state.Should().ContainKey("testKey").WhoseValue.Should().Be("testValue");
// Verify that the loaded state contains the received message
newAgent.ReceivedMessages.Should().ContainKey(topicType).WhoseValue.Should().Be("test");
}
}
37 changes: 36 additions & 1 deletion dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public ValueTask<string> HandleAsync(RpcTextMessage item, MessageContext message
/// Key: source
/// Value: message
/// </summary>
private readonly Dictionary<string, object> _receivedMessages = new();
protected Dictionary<string, object> _receivedMessages = new();
public Dictionary<string, object> ReceivedMessages => _receivedMessages;
}

Expand All @@ -72,3 +72,38 @@ public SubscribedAgent(AgentId id,
{
}
}

[TypeSubscription("TestTopic")]
public class SubscribedSaveLoadAgent : TestAgent
{
private const string SavedStateKey = "receivedMessages";

public SubscribedSaveLoadAgent(AgentId id,
IAgentRuntime runtime,
Logger<BaseAgent>? logger = null) : base(id, runtime, logger)
{
}

public override ValueTask<IDictionary<string, object>> SaveStateAsync()
{
return ValueTask.FromResult<IDictionary<string, object>>(new Dictionary<string, object>
{
{ SavedStateKey, new Dictionary<string, object>(_receivedMessages) } // Save _receivedMessages
});
}

public override ValueTask LoadStateAsync(IDictionary<string, object> state)
{
if (state.TryGetValue(SavedStateKey, out var loadedMessagesObj) &&
loadedMessagesObj is Dictionary<string, object> loadedMessages)
{
_receivedMessages.Clear();
foreach (var kvp in loadedMessages)
{
_receivedMessages[kvp.Key] = kvp.Value;
}
}

return ValueTask.CompletedTask;
}
}
Loading