diff --git a/src/Custom/Assistants/Streaming/AsyncStreamingUpdateCollection.cs b/src/Custom/Assistants/Streaming/AsyncStreamingUpdateCollection.cs index b3f87cf0..c7640f02 100644 --- a/src/Custom/Assistants/Streaming/AsyncStreamingUpdateCollection.cs +++ b/src/Custom/Assistants/Streaming/AsyncStreamingUpdateCollection.cs @@ -3,6 +3,8 @@ using System.ClientModel.Primitives; using System.Collections.Generic; using System.Diagnostics; +using System.Linq; +using System.Net.ServerSentEvents; using System.Threading; using System.Threading.Tasks; @@ -31,7 +33,7 @@ public override IAsyncEnumerator GetAsyncEnumerator(Cancellatio private sealed class AsyncStreamingUpdateEnumerator : IAsyncEnumerator { - private const string _terminalData = "[DONE]"; + private static ReadOnlySpan TerminalData => "[DONE]"u8; private readonly Func> _getResultAsync; private readonly AsyncStreamingUpdateCollection _enumerable; @@ -44,7 +46,7 @@ private sealed class AsyncStreamingUpdateEnumerator : IAsyncEnumerator? _events; + private IAsyncEnumerator>? _events; private IEnumerator? _updates; private StreamingUpdate? 
_current; @@ -84,7 +86,7 @@ async ValueTask IAsyncEnumerator.MoveNextAsync() if (await _events.MoveNextAsync().ConfigureAwait(false)) { - if (_events.Current.Data == _terminalData) + if (_events.Current.Data.AsSpan().SequenceEqual(TerminalData)) { _current = default; return false; @@ -104,7 +106,7 @@ async ValueTask IAsyncEnumerator.MoveNextAsync() return false; } - private async Task> CreateEventEnumeratorAsync() + private async Task>> CreateEventEnumeratorAsync() { ClientResult result = await _getResultAsync().ConfigureAwait(false); PipelineResponse response = result.GetRawResponse(); @@ -115,7 +117,7 @@ private async Task> CreateEventEnumeratorAsync throw new InvalidOperationException("Unable to create result from response with null ContentStream"); } - AsyncServerSentEventEnumerable enumerable = new(response.ContentStream); + IAsyncEnumerable> enumerable = SseParser.Create(response.ContentStream, (_, bytes) => bytes.ToArray()).EnumerateAsync(); return enumerable.GetAsyncEnumerator(_cancellationToken); } diff --git a/src/Custom/Assistants/Streaming/StreamingUpdate.cs b/src/Custom/Assistants/Streaming/StreamingUpdate.cs index 577b5ac3..7940993d 100644 --- a/src/Custom/Assistants/Streaming/StreamingUpdate.cs +++ b/src/Custom/Assistants/Streaming/StreamingUpdate.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Net.ServerSentEvents; using System.Text.Json; namespace OpenAI.Assistants; @@ -38,7 +39,7 @@ internal StreamingUpdate(StreamingUpdateReason updateKind) UpdateKind = updateKind; } - internal static IEnumerable FromEvent(ServerSentEvent sseItem) + internal static IEnumerable FromEvent(SseItem sseItem) { StreamingUpdateReason updateKind = StreamingUpdateReasonExtensions.FromSseEventLabel(sseItem.EventType); using JsonDocument dataDocument = JsonDocument.Parse(sseItem.Data); diff --git a/src/Custom/Assistants/Streaming/StreamingUpdateCollection.cs b/src/Custom/Assistants/Streaming/StreamingUpdateCollection.cs index 4f6e6c94..85f51273 100644 --- 
a/src/Custom/Assistants/Streaming/StreamingUpdateCollection.cs +++ b/src/Custom/Assistants/Streaming/StreamingUpdateCollection.cs @@ -4,6 +4,7 @@ using System.Collections; using System.Collections.Generic; using System.Diagnostics; +using System.Net.ServerSentEvents; #nullable enable @@ -30,7 +31,7 @@ public override IEnumerator GetEnumerator() private sealed class StreamingUpdateEnumerator : IEnumerator { - private const string _terminalData = "[DONE]"; + private static ReadOnlySpan TerminalData => "[DONE]"u8; private readonly Func _getResult; private readonly StreamingUpdateCollection _enumerable; @@ -42,7 +43,7 @@ private sealed class StreamingUpdateEnumerator : IEnumerator // // get _updates from sse event // foreach (var update in _updates) { ... } // } - private IEnumerator? _events; + private IEnumerator>? _events; private IEnumerator? _updates; private StreamingUpdate? _current; @@ -81,7 +82,7 @@ public bool MoveNext() if (_events.MoveNext()) { - if (_events.Current.Data == _terminalData) + if (_events.Current.Data.AsSpan().SequenceEqual(TerminalData)) { _current = default; return false; @@ -101,7 +102,7 @@ public bool MoveNext() return false; } - private IEnumerator CreateEventEnumerator() + private IEnumerator> CreateEventEnumerator() { ClientResult result = _getResult(); PipelineResponse response = result.GetRawResponse(); @@ -112,7 +113,7 @@ private IEnumerator CreateEventEnumerator() throw new InvalidOperationException("Unable to create result from response with null ContentStream"); } - ServerSentEventEnumerable enumerable = new(response.ContentStream); + IEnumerable> enumerable = SseParser.Create(response.ContentStream, (_, bytes) => bytes.ToArray()).Enumerate(); return enumerable.GetEnumerator(); } diff --git a/src/Custom/Chat/Internal/AsyncStreamingChatCompletionUpdateCollection.cs b/src/Custom/Chat/Internal/AsyncStreamingChatCompletionUpdateCollection.cs index 320cedbf..55bbeb7d 100644 --- 
a/src/Custom/Chat/Internal/AsyncStreamingChatCompletionUpdateCollection.cs +++ b/src/Custom/Chat/Internal/AsyncStreamingChatCompletionUpdateCollection.cs @@ -3,6 +3,7 @@ using System.ClientModel.Primitives; using System.Collections.Generic; using System.Diagnostics; +using System.Net.ServerSentEvents; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -32,7 +33,7 @@ public override IAsyncEnumerator GetAsyncEnumerat private sealed class AsyncStreamingChatUpdateEnumerator : IAsyncEnumerator { - private const string _terminalData = "[DONE]"; + private static ReadOnlySpan TerminalData => "[DONE]"u8; private readonly Func> _getResultAsync; private readonly AsyncStreamingChatCompletionUpdateCollection _enumerable; @@ -45,7 +46,7 @@ private sealed class AsyncStreamingChatUpdateEnumerator : IAsyncEnumerator? _events; + private IAsyncEnumerator>? _events; private IEnumerator? _updates; private StreamingChatCompletionUpdate? _current; @@ -85,7 +86,7 @@ async ValueTask IAsyncEnumerator.MoveNextAs if (await _events.MoveNextAsync().ConfigureAwait(false)) { - if (_events.Current.Data == _terminalData) + if (_events.Current.Data.AsSpan().SequenceEqual(TerminalData)) { _current = default; return false; @@ -106,7 +107,7 @@ async ValueTask IAsyncEnumerator.MoveNextAs return false; } - private async Task> CreateEventEnumeratorAsync() + private async Task>> CreateEventEnumeratorAsync() { ClientResult result = await _getResultAsync().ConfigureAwait(false); PipelineResponse response = result.GetRawResponse(); @@ -117,7 +118,7 @@ private async Task> CreateEventEnumeratorAsync throw new InvalidOperationException("Unable to create result from response with null ContentStream"); } - AsyncServerSentEventEnumerable enumerable = new(response.ContentStream); + IAsyncEnumerable> enumerable = SseParser.Create(response.ContentStream, (_, bytes) => bytes.ToArray()).EnumerateAsync(); return enumerable.GetAsyncEnumerator(_cancellationToken); } diff --git 
a/src/Custom/Chat/Internal/StreamingChatCompletionUpdateCollection.cs b/src/Custom/Chat/Internal/StreamingChatCompletionUpdateCollection.cs index fe454699..0e51009e 100644 --- a/src/Custom/Chat/Internal/StreamingChatCompletionUpdateCollection.cs +++ b/src/Custom/Chat/Internal/StreamingChatCompletionUpdateCollection.cs @@ -4,6 +4,7 @@ using System.Collections; using System.Collections.Generic; using System.Diagnostics; +using System.Net.ServerSentEvents; using System.Text.Json; #nullable enable @@ -31,7 +32,7 @@ public override IEnumerator GetEnumerator() private sealed class StreamingChatUpdateEnumerator : IEnumerator { - private const string _terminalData = "[DONE]"; + private static ReadOnlySpan TerminalData => "[DONE]"u8; private readonly Func _getResult; private readonly StreamingChatCompletionUpdateCollection _enumerable; @@ -43,7 +44,7 @@ private sealed class StreamingChatUpdateEnumerator : IEnumerator? _events; + private IEnumerator>? _events; private IEnumerator? _updates; private StreamingChatCompletionUpdate? 
_current; @@ -82,7 +83,7 @@ public bool MoveNext() if (_events.MoveNext()) { - if (_events.Current.Data == _terminalData) + if (_events.Current.Data.AsSpan().SequenceEqual(TerminalData)) { _current = default; return false; @@ -103,7 +104,7 @@ public bool MoveNext() return false; } - private IEnumerator CreateEventEnumerator() + private IEnumerator> CreateEventEnumerator() { ClientResult result = _getResult(); PipelineResponse response = result.GetRawResponse(); @@ -114,7 +115,7 @@ private IEnumerator CreateEventEnumerator() throw new InvalidOperationException("Unable to create result from response with null ContentStream"); } - ServerSentEventEnumerable enumerable = new(response.ContentStream); + IEnumerable> enumerable = SseParser.Create(response.ContentStream, (_, bytes) => bytes.ToArray()).Enumerate(); return enumerable.GetEnumerator(); } diff --git a/src/OpenAI.csproj b/src/OpenAI.csproj index 4dd14c0f..6da549b4 100644 --- a/src/OpenAI.csproj +++ b/src/OpenAI.csproj @@ -41,6 +41,13 @@ $(NoWarn),0169 + + + true + + true diff --git a/src/Utility/AsyncServerSentEventEnumerable.cs b/src/Utility/AsyncServerSentEventEnumerable.cs deleted file mode 100644 index 98ba2501..00000000 --- a/src/Utility/AsyncServerSentEventEnumerable.cs +++ /dev/null @@ -1,82 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -#nullable enable - -namespace OpenAI; - -/// -/// Represents a collection of SSE events that can be enumerated as a C# async stream. 
-/// -internal class AsyncServerSentEventEnumerable : IAsyncEnumerable -{ - private readonly Stream _contentStream; - - public AsyncServerSentEventEnumerable(Stream contentStream) - { - Argument.AssertNotNull(contentStream, nameof(contentStream)); - - _contentStream = contentStream; - - LastEventId = string.Empty; - ReconnectionInterval = Timeout.InfiniteTimeSpan; - } - - public string LastEventId { get; private set; } - - public TimeSpan ReconnectionInterval { get; private set; } - - public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) - { - return new AsyncServerSentEventEnumerator(_contentStream, this, cancellationToken); - } - - private sealed class AsyncServerSentEventEnumerator : IAsyncEnumerator - { - private readonly ServerSentEventReader _reader; - private readonly AsyncServerSentEventEnumerable _enumerable; - private readonly CancellationToken _cancellationToken; - - public ServerSentEvent Current { get; private set; } - - public AsyncServerSentEventEnumerator(Stream contentStream, - AsyncServerSentEventEnumerable enumerable, - CancellationToken cancellationToken = default) - { - _reader = new(contentStream); - _enumerable = enumerable; - _cancellationToken = cancellationToken; - } - - public async ValueTask MoveNextAsync() - { - ServerSentEvent? nextEvent = await _reader.TryGetNextEventAsync(_cancellationToken).ConfigureAwait(false); - _enumerable.LastEventId = _reader.LastEventId; - _enumerable.ReconnectionInterval = _reader.ReconnectionInterval; - - if (nextEvent.HasValue) - { - Current = nextEvent.Value; - return true; - } - - Current = default; - return false; - } - - public ValueTask DisposeAsync() - { - // The creator of the enumerable has responsibility for disposing - // the content stream passed to the enumerable constructor. 
- -#if NET6_0_OR_GREATER - return ValueTask.CompletedTask; -#else - return new ValueTask(); -#endif - } - } -} diff --git a/src/Utility/ServerSentEvent.cs b/src/Utility/ServerSentEvent.cs deleted file mode 100644 index 1ef24963..00000000 --- a/src/Utility/ServerSentEvent.cs +++ /dev/null @@ -1,24 +0,0 @@ -#nullable enable - -namespace OpenAI; - -/// -/// Represents an SSE event. -/// See SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html -/// -internal readonly struct ServerSentEvent -{ - // Gets the value of the SSE "event type" buffer, used to distinguish - // between event kinds. - public string EventType { get; } - - // Gets the value of the SSE "data" buffer, which holds the payload of the - // server-sent event. - public string Data { get; } - - public ServerSentEvent(string type, string data) - { - EventType = type; - Data = data; - } -} diff --git a/src/Utility/ServerSentEventEnumerable.cs b/src/Utility/ServerSentEventEnumerable.cs deleted file mode 100644 index 6f5c061c..00000000 --- a/src/Utility/ServerSentEventEnumerable.cs +++ /dev/null @@ -1,81 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.IO; -using System.Threading; - -#nullable enable - -namespace OpenAI; - -/// -/// Represents a collection of SSE events that can be enumerated as a C# collection. 
-/// -internal class ServerSentEventEnumerable : IEnumerable -{ - private readonly Stream _contentStream; - - public ServerSentEventEnumerable(Stream contentStream) - { - Argument.AssertNotNull(contentStream, nameof(contentStream)); - - _contentStream = contentStream; - - LastEventId = string.Empty; - ReconnectionInterval = Timeout.InfiniteTimeSpan; - } - - public string LastEventId { get; private set; } - - public TimeSpan ReconnectionInterval { get; private set; } - - public IEnumerator GetEnumerator() - { - return new ServerSentEventEnumerator(_contentStream, this); - } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - private sealed class ServerSentEventEnumerator : IEnumerator - { - private readonly ServerSentEventReader _reader; - private readonly ServerSentEventEnumerable _enumerable; - - public ServerSentEventEnumerator(Stream contentStream, ServerSentEventEnumerable enumerable) - { - _reader = new(contentStream); - _enumerable = enumerable; - } - - public ServerSentEvent Current { get; private set; } - - object IEnumerator.Current => Current; - - public bool MoveNext() - { - ServerSentEvent? nextEvent = _reader.TryGetNextEvent(); - _enumerable.LastEventId = _reader.LastEventId; - _enumerable.ReconnectionInterval = _reader.ReconnectionInterval; - - if (nextEvent.HasValue) - { - Current = nextEvent.Value; - return true; - } - - Current = default; - return false; - } - - public void Reset() - { - throw new NotSupportedException("Cannot seek back in an SSE stream."); - } - - public void Dispose() - { - // The creator of the enumerable has responsibility for disposing - // the content stream passed to the enumerable constructor. 
- } - } -} diff --git a/src/Utility/ServerSentEventField.cs b/src/Utility/ServerSentEventField.cs deleted file mode 100644 index ddea207b..00000000 --- a/src/Utility/ServerSentEventField.cs +++ /dev/null @@ -1,55 +0,0 @@ -using System; - -#nullable enable - -namespace OpenAI; - -/// -/// Represents a field that can be composed into an SSE event. -/// See SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html -/// -internal readonly struct ServerSentEventField -{ - private static readonly ReadOnlyMemory s_eventFieldName = "event".AsMemory(); - private static readonly ReadOnlyMemory s_dataFieldName = "data".AsMemory(); - private static readonly ReadOnlyMemory s_lastEventIdFieldName = "id".AsMemory(); - private static readonly ReadOnlyMemory s_retryFieldName = "retry".AsMemory(); - - public ServerSentEventFieldKind FieldType { get; } - - // Note: we don't plan to expose UTF16 publicly - public ReadOnlyMemory Value { get; } - - internal ServerSentEventField(string line) - { - int colonIndex = line.AsSpan().IndexOf(':'); - - ReadOnlyMemory fieldName = colonIndex < 0 ? - line.AsMemory() : - line.AsMemory(0, colonIndex); - - FieldType = fieldName.Span switch - { - var x when x.SequenceEqual(s_eventFieldName.Span) => ServerSentEventFieldKind.Event, - var x when x.SequenceEqual(s_dataFieldName.Span) => ServerSentEventFieldKind.Data, - var x when x.SequenceEqual(s_lastEventIdFieldName.Span) => ServerSentEventFieldKind.Id, - var x when x.SequenceEqual(s_retryFieldName.Span) => ServerSentEventFieldKind.Retry, - _ => ServerSentEventFieldKind.Ignore, - }; - - if (colonIndex < 0) - { - Value = ReadOnlyMemory.Empty; - } - else - { - Value = line.AsMemory(colonIndex + 1); - - // Per spec, remove a leading space if present. 
- if (Value.Length > 0 && Value.Span[0] == ' ') - { - Value = Value.Slice(1); - } - } - } -} diff --git a/src/Utility/ServerSentEventFieldKind.cs b/src/Utility/ServerSentEventFieldKind.cs deleted file mode 100644 index a1162695..00000000 --- a/src/Utility/ServerSentEventFieldKind.cs +++ /dev/null @@ -1,14 +0,0 @@ -namespace OpenAI; - -/// -/// The kind of line or field received over an SSE stream. -/// See SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html -/// -internal enum ServerSentEventFieldKind -{ - Ignore, - Event, - Data, - Id, - Retry, -} diff --git a/src/Utility/ServerSentEventReader.cs b/src/Utility/ServerSentEventReader.cs deleted file mode 100644 index 1f59022e..00000000 --- a/src/Utility/ServerSentEventReader.cs +++ /dev/null @@ -1,192 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -#nullable enable - -namespace OpenAI; - -/// -/// An SSE event reader that reads lines from an SSE stream and composes them -/// into SSE events. -/// See SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html -/// -internal sealed class ServerSentEventReader -{ - private readonly StreamReader _reader; - - public ServerSentEventReader(Stream stream) - { - Argument.AssertNotNull(stream, nameof(stream)); - - // The creator of the reader has responsibility for disposing the - // stream passed to the reader's constructor. - _reader = new StreamReader(stream); - - LastEventId = string.Empty; - ReconnectionInterval = Timeout.InfiniteTimeSpan; - } - - public string LastEventId { get; private set; } - - public TimeSpan ReconnectionInterval { get; private set; } - - /// - /// Synchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is - /// available and returning null once no further data is present on the stream. 
- /// - /// An optional cancellation token that can abort subsequent reads. - /// - /// The next in the stream, or null once no more data can be read from the stream. - /// - public ServerSentEvent? TryGetNextEvent(CancellationToken cancellationToken = default) - { - PendingEvent pending = default; - while (true) - { - cancellationToken.ThrowIfCancellationRequested(); - - // Note: would be nice to have polyfill that takes cancellation token, - // but may become moot if we shift to all UTF-8. - string? line = _reader.ReadLine(); - - if (line is null) - { - // A null line indicates end of input - return null; - } - - ProcessLine(line, ref pending, out bool dispatch); - - if (dispatch) - { - return pending.ToEvent(); - } - } - } - - /// - /// Asynchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is - /// available and returning null once no further data is present on the stream. - /// - /// An optional cancellation token that can abort subsequent reads. - /// - /// The next in the stream, or null once no more data can be read from the stream. - /// - public async Task TryGetNextEventAsync(CancellationToken cancellationToken = default) - { - PendingEvent pending = default; - while (true) - { - cancellationToken.ThrowIfCancellationRequested(); - - // Note: would be nice to have polyfill that takes cancellation token, - // but may become moot if we shift to all UTF-8. - string? line = await _reader.ReadLineAsync().ConfigureAwait(false); - - if (line is null) - { - // A null line indicates end of input - return null; - } - - ProcessLine(line, ref pending, out bool dispatch); - - if (dispatch) - { - return pending.ToEvent(); - } - } - } - - private void ProcessLine(string line, ref PendingEvent pending, out bool dispatch) - { - dispatch = false; - - if (line.Length == 0) - { - if (pending.DataLength == 0) - { - // Per spec, if there's no data, don't dispatch an event. 
- pending = default; - } - else - { - dispatch = true; - } - } - else if (line[0] != ':') - { - // Per spec, ignore comment lines (i.e. that begin with ':'). - // If we got this far, process the field + value and accumulate - // it for the next dispatched event. - ServerSentEventField field = new(line); - switch (field.FieldType) - { - case ServerSentEventFieldKind.Event: - pending.EventTypeField = field; - break; - case ServerSentEventFieldKind.Data: - // Per spec, we'll append \n when we concatenate the data lines. - pending.DataLength += field.Value.Length + 1; - pending.DataFields.Add(field); - break; - case ServerSentEventFieldKind.Id: - LastEventId = field.Value.ToString(); - break; - case ServerSentEventFieldKind.Retry: - if (field.Value.Length > 0 && int.TryParse(field.Value.ToString(), out int retry)) - { - ReconnectionInterval = TimeSpan.FromMilliseconds(retry); - } - break; - default: - // Ignore - break; - } - } - } - - private struct PendingEvent - { - private const char LF = '\n'; - - private List? _dataFields; - - public int DataLength { get; set; } - public List DataFields => _dataFields ??= new(); - public ServerSentEventField? EventTypeField { get; set; } - - public ServerSentEvent ToEvent() - { - Debug.Assert(DataLength > 0); - - // Per spec, if event type buffer is empty, set event.type to "message". - string type = EventTypeField.HasValue ? - EventTypeField.Value.Value.ToString() : - "message"; - - Memory buffer = new(new char[DataLength]); - - int curr = 0; - foreach (ServerSentEventField field in DataFields) - { - Debug.Assert(field.FieldType == ServerSentEventFieldKind.Data); - - field.Value.Span.CopyTo(buffer.Span.Slice(curr)); - - // Per spec, append trailing LF to each data field value. - buffer.Span[curr + field.Value.Length] = LF; - curr += field.Value.Length + 1; - } - - // Per spec, remove trailing LF from concatenated data fields. 
- string data = buffer.Slice(0, buffer.Length - 1).ToString(); - - return new ServerSentEvent(type, data); - } - } -} diff --git a/src/Utility/System.Net.ServerSentEvents.cs b/src/Utility/System.Net.ServerSentEvents.cs new file mode 100644 index 00000000..ea0b4f81 --- /dev/null +++ b/src/Utility/System.Net.ServerSentEvents.cs @@ -0,0 +1,623 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// This file contains a source copy of: +// https://github.com/dotnet/runtime/tree/2bd15868f12ace7cee9999af61d5c130b2603f04/src/libraries/System.Net.ServerSentEvents/src/System/Net/ServerSentEvents +// Once the System.Net.ServerSentEvents package is available, this file should be removed and replaced with a package reference. +// +// The only changes made to this code from the original are: +// - Enabled nullable reference types at file scope, and use a few null suppression operators to work around the lack of [NotNull] +// - Put into a single file for ease of management (it should not be edited in this repo). +// - Changed public types to be internal. +// - Removed a use of a [NotNull] attribute to assist in netstandard2.0 compilation. +// - Replaced a reference to a .resx string with an inline constant. + +#nullable enable + +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; +using System.Threading; + +namespace System.Net.ServerSentEvents +{ + /// Represents a server-sent event. + /// Specifies the type of data payload in the event. + internal readonly struct SseItem + { + /// Initializes the server-sent event. + /// The event's payload. + /// The event's type. + public SseItem(T data, string eventType) + { + Data = data; + EventType = eventType; + } + + /// Gets the event's payload. 
+ public T Data { get; } + + /// Gets the event's type. + public string EventType { get; } + } + + /// Encapsulates a method for parsing the bytes payload of a server-sent event. + /// Specifies the type of the return value of the parser. + /// The event's type. + /// The event's payload bytes. + /// The parsed . + internal delegate T SseItemParser(string eventType, ReadOnlySpan data); + + /// Provides a parser for parsing server-sent events. + internal static class SseParser + { + /// The default ("message") for an event that did not explicitly specify a type. + public const string EventTypeDefault = "message"; + + /// Creates a parser for parsing a of server-sent events into a sequence of values. + /// The stream containing the data to parse. + /// + /// The enumerable of strings, which may be enumerated synchronously or asynchronously. The strings + /// are decoded from the UTF8-encoded bytes of the payload of each event. + /// + /// is null. + /// + /// This overload has behavior equivalent to calling with a delegate + /// that decodes the data of each event using 's GetString method. + /// + public static SseParser Create(Stream sseStream) => + Create(sseStream, static (_, bytes) => Utf8GetString(bytes)); + + /// Creates a parser for parsing a of server-sent events into a sequence of values. + /// Specifies the type of data in each event. + /// The stream containing the data to parse. + /// The parser to use to transform each payload of bytes into a data element. + /// The enumerable, which may be enumerated synchronously or asynchronously. + /// is null. + /// is null. + public static SseParser Create(Stream sseStream, SseItemParser itemParser) => + new SseParser( + sseStream ?? throw new ArgumentNullException(nameof(sseStream)), + itemParser ?? 
throw new ArgumentNullException(nameof(itemParser))); + + /// Encoding.UTF8.GetString(bytes) + internal static string Utf8GetString(ReadOnlySpan bytes) + { +#if NET + return Encoding.UTF8.GetString(bytes); +#else + unsafe + { + fixed (byte* ptr = bytes) + { + return ptr is null ? + string.Empty : + Encoding.UTF8.GetString(ptr, bytes.Length); + } + } +#endif + } + } + + /// Provides a parser for server-sent events information. + /// Specifies the type of data parsed from an event. + internal sealed class SseParser + { + // For reference: + // Specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#server-sent-events + + /// Carriage Return. + private const byte CR = (byte)'\r'; + /// Line Feed. + private const byte LF = (byte)'\n'; + /// Carriage Return Line Feed. + private static ReadOnlySpan CRLF => "\r\n"u8; + + /// The default size of an ArrayPool buffer to rent. + /// Larger size used by default to minimize number of reads. Smaller size used in debug to stress growth/shifting logic. + private const int DefaultArrayPoolRentSize = +#if DEBUG + 16; +#else + 1024; +#endif + + /// The stream to be parsed. + private readonly Stream _stream; + /// The parser delegate used to transform bytes into a . + private readonly SseItemParser _itemParser; + + /// Indicates whether the enumerable has already been used for enumeration. + private int _used; + + /// Buffer, either empty or rented, containing the data being read from the stream while looking for the next line. + private byte[] _lineBuffer = []; + /// The starting offset of valid data in . + private int _lineOffset; + /// The length of valid data in , starting from . + private int _lineLength; + /// The index in where a newline ('\r', '\n', or "\r\n") was found. + private int _newlineIndex; + /// The index in of characters already checked for newlines. + /// + /// This is to avoid O(LineLength^2) behavior in the rare case where we have long lines that are built-up over multiple reads. 
+ /// We want to avoid re-checking the same characters we've already checked over and over again. + /// + private int _lastSearchedForNewline; + /// Set when eof has been reached in the stream. + private bool _eof; + + /// Rented buffer containing buffered data for the next event. + private byte[]? _dataBuffer; + /// The length of valid data in , starting from index 0. + private int _dataLength; + /// Whether data has been appended to . + /// This can be different than != 0 if empty data was appended. + private bool _dataAppended; + + /// The event type for the next event. + private string _eventType = SseParser.EventTypeDefault; + + /// Initialize the enumerable. + /// The stream to parse. + /// The function to use to parse payload bytes into a . + internal SseParser(Stream stream, SseItemParser itemParser) + { + _stream = stream; + _itemParser = itemParser; + } + + /// Gets an enumerable of the server-sent events from this parser. + /// The parser has already been enumerated. Such an exception may propagate out of a call to . + public IEnumerable> Enumerate() + { + // Validate that the parser is only used for one enumeration. + ThrowIfNotFirstEnumeration(); + + // Rent a line buffer. This will grow as needed. The line buffer is what's passed to the stream, + // so we want it to be large enough to reduce the number of reads we need to do when data is + // arriving quickly. (In debug, we use a smaller buffer to stress the growth and shifting logic.) + _lineBuffer = ArrayPool.Shared.Rent(DefaultArrayPoolRentSize); + try + { + // Spec: "Event streams in this format must always be encoded as UTF-8". + // Skip a UTF8 BOM if it exists at the beginning of the stream. (The BOM is defined as optional in the SSE grammar.) + while (FillLineBuffer() != 0 && _lineLength < Utf8Bom.Length) ; + SkipBomIfPresent(); + + // Process all events in the stream. + while (true) + { + // See if there's a complete line in data already read from the stream. 
Lines are permitted to + // end with CR, LF, or CRLF. Look for all of them and if we find one, process the line. However, + // if we only find a CR and it's at the end of the read data, don't process it now, as we want + // to process it together with an LF that might immediately follow, rather than treating them + // as two separate characters, in which case we'd incorrectly process the CR as a line by itself. + GetNextSearchOffsetAndLength(out int searchOffset, out int searchLength); + _newlineIndex = _lineBuffer.AsSpan(searchOffset, searchLength).IndexOfAny(CR, LF); + if (_newlineIndex >= 0) + { + _lastSearchedForNewline = -1; + _newlineIndex += searchOffset; + if (_lineBuffer[_newlineIndex] is LF || // the newline is LF + _newlineIndex - _lineOffset + 1 < _lineLength || // we must have CR and we have whatever comes after it + _eof) // if we get here, we know we have a CR at the end of the buffer, so it's definitely the whole newline if we've hit EOF + { + // Process the line. + if (ProcessLine(out SseItem sseItem, out int advance)) + { + yield return sseItem; + } + + // Move past the line. + _lineOffset += advance; + _lineLength -= advance; + continue; + } + } + else + { + // Record the last position searched for a newline. The next time we search, + // we'll search from here rather than from _lineOffset, in order to avoid searching + // the same characters again. + _lastSearchedForNewline = _lineOffset + _lineLength; + } + + // We've processed everything in the buffer we currently can, so if we've already read EOF, we're done. + if (_eof) + { + // Spec: "Once the end of the file is reached, any pending data must be discarded. (If the file ends in the middle of an + // event, before the final empty line, the incomplete event is not dispatched.)" + break; + } + + // Read more data into the buffer. 
// NOTE(review): the statements down to the first column-0 closing brace finish a method that begins above this chunk.
            FillLineBuffer();
        }
    }
    finally
    {
        // Hand the rented buffers back to the pool however the enumeration ends.
        ArrayPool<byte>.Shared.Return(_lineBuffer);
        if (_dataBuffer is not null)
        {
            ArrayPool<byte>.Shared.Return(_dataBuffer);
        }
    }
}

/// <summary>Gets an asynchronous enumerable of the server-sent events from this parser.</summary>
/// <param name="cancellationToken">The cancellation token to use to cancel the enumeration.</param>
/// <returns>The parsed items, yielded one at a time as complete events are read from the stream.</returns>
/// <exception cref="InvalidOperationException">
/// The parser has already been enumerated. Because this is an iterator, the exception surfaces from
/// <see cref="IAsyncEnumerator{T}.MoveNextAsync"/> rather than from this call.
/// </exception>
/// <exception cref="OperationCanceledException">
/// The enumeration was canceled. The token is handed to the stream read, so the exception surfaces from
/// <see cref="IAsyncEnumerator{T}.MoveNextAsync"/>.
/// </exception>
// NOTE(review): generic arguments were lost in extraction and have been restored; `T` is assumed to be the
// enclosing parser's type parameter - confirm against the class declaration.
public async IAsyncEnumerable<SseItem<T>> EnumerateAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Validate that the parser is only used for one enumeration.
    ThrowIfNotFirstEnumeration();

    // Rent a line buffer. This will grow as needed. The line buffer is what's passed to the stream,
    // so we want it to be large enough to reduce the number of reads we need to do when data is
    // arriving quickly.
    _lineBuffer = ArrayPool<byte>.Shared.Rent(DefaultArrayPoolRentSize);
    try
    {
        // Spec: "Event streams in this format must always be encoded as UTF-8".
        // Skip a UTF8 BOM if it exists at the beginning of the stream. (The BOM is defined as optional in the SSE grammar.)
        while (await FillLineBufferAsync(cancellationToken).ConfigureAwait(false) != 0 && _lineLength < Utf8Bom.Length)
        {
            // Keep reading until enough bytes are buffered to test for a BOM, or the stream ends.
        }

        SkipBomIfPresent();

        // Process all events in the stream.
        while (true)
        {
            // See if there's a complete line in data already read from the stream. Lines are permitted to
            // end with CR, LF, or CRLF. Look for all of them and if we find one, process the line. However,
            // if we only find a CR and it's at the end of the read data, don't process it now, as we want
            // to process it together with an LF that might immediately follow, rather than treating them
            // as two separate characters, in which case we'd incorrectly process the CR as a line by itself.
            GetNextSearchOffsetAndLength(out int searchOffset, out int searchLength);
            _newlineIndex = _lineBuffer.AsSpan(searchOffset, searchLength).IndexOfAny(CR, LF);
            if (_newlineIndex >= 0)
            {
                _lastSearchedForNewline = -1;
                _newlineIndex += searchOffset; // convert from span-relative to buffer-relative
                if (_lineBuffer[_newlineIndex] is LF || // newline is LF
                    _newlineIndex - _lineOffset + 1 < _lineLength || // newline is CR, and we have whatever comes after it
                    _eof) // CR is the last buffered byte; at EOF nothing can follow, so it is the whole newline
                {
                    // Process the line.
                    if (ProcessLine(out SseItem<T> sseItem, out int advance))
                    {
                        yield return sseItem;
                    }

                    // Move past the line.
                    _lineOffset += advance;
                    _lineLength -= advance;
                    continue;
                }
            }
            else
            {
                // Record the last position searched for a newline. The next time we search,
                // we'll search from here rather than from _lineOffset, in order to avoid searching
                // the same characters again.
                _lastSearchedForNewline = searchOffset + searchLength;
            }

            // We've processed everything in the buffer we currently can, so if we've already read EOF, we're done.
            if (_eof)
            {
                // Spec: "Once the end of the file is reached, any pending data must be discarded. (If the file ends in the middle of an
                // event, before the final empty line, the incomplete event is not dispatched.)"
                break;
            }

            // Read more data into the buffer.
            await FillLineBufferAsync(cancellationToken).ConfigureAwait(false);
        }
    }
    finally
    {
        // Hand the rented buffers back to the pool however the enumeration ends.
        ArrayPool<byte>.Shared.Return(_lineBuffer);
        if (_dataBuffer is not null)
        {
            ArrayPool<byte>.Shared.Return(_dataBuffer);
        }
    }
}

/// <summary>Gets the next index and length with which to perform a newline search.</summary>
private void GetNextSearchOffsetAndLength(out int searchOffset, out int searchLength)
{
    // Outputs the window of the line buffer that still needs scanning for CR/LF. If an earlier scan
    // already covered a prefix of the pending bytes without finding a newline, resume after that prefix.
    if (_lastSearchedForNewline > _lineOffset)
    {
        searchOffset = _lastSearchedForNewline;
        searchLength = _lineLength - (_lastSearchedForNewline - _lineOffset);
    }
    else
    {
        searchOffset = _lineOffset;
        searchLength = _lineLength;
    }

    Debug.Assert(searchOffset >= _lineOffset, $"{searchOffset}, {_lineOffset}");
    Debug.Assert(searchOffset <= _lineOffset + _lineLength, $"{searchOffset}, {_lineOffset}, {_lineLength}");
    Debug.Assert(searchOffset <= _lineBuffer.Length, $"{searchOffset}, {_lineBuffer.Length}");

    Debug.Assert(searchLength >= 0, $"{searchLength}");
    Debug.Assert(searchLength <= _lineLength, $"{searchLength}, {_lineLength}");
}

/// <summary>Gets the length of the newline sequence located at <c>_newlineIndex</c>.</summary>
/// <returns>2 if the buffered bytes at that position start with CRLF; otherwise, 1.</returns>
private int GetNewLineLength()
{
    Debug.Assert(_newlineIndex - _lineOffset < _lineLength, "Expected to be positioned at a non-empty newline");
    return _lineBuffer.AsSpan(_newlineIndex, _lineLength - (_newlineIndex - _lineOffset)).StartsWith(CRLF) ? 2 : 1;
}

/// <summary>
/// If there's no room remaining in the line buffer, either shifts the contents
/// left or grows the buffer in order to make room for the next read.
/// </summary>
private void ShiftOrGrowLineBufferIfNecessary()
{
    // If data we've read is butting up against the end of the buffer and
    // it's not taking up the entire buffer, slide what's there down to
    // the beginning, making room to read more data into the buffer (since
    // there's no newline in the data that's there). Otherwise, if the whole
    // buffer is full, grow the buffer to accommodate more data, since, again,
    // what's there doesn't contain a newline and thus a line is longer than
    // the current buffer accommodates.
    if (_lineOffset + _lineLength == _lineBuffer.Length)
    {
        if (_lineOffset != 0)
        {
            _lineBuffer.AsSpan(_lineOffset, _lineLength).CopyTo(_lineBuffer);
            if (_lastSearchedForNewline >= 0)
            {
                // Keep the resume position pointing at the same byte after the shift.
                _lastSearchedForNewline -= _lineOffset;
            }

            _lineOffset = 0;
        }
        else if (_lineLength == _lineBuffer.Length)
        {
            GrowBuffer(ref _lineBuffer!, _lineBuffer.Length * 2);
        }
    }
}

/// <summary>Processes a complete line from the SSE stream.</summary>
/// <param name="sseItem">The parsed item if the method returns true.</param>
/// <param name="advance">How many characters to advance in the line buffer.</param>
/// <returns>true if an SSE item was successfully parsed; otherwise, false.</returns>
// NOTE(review): generic arguments were lost in extraction and have been restored; `T` is assumed to be the
// enclosing parser's type parameter - confirm against the class declaration.
private bool ProcessLine(out SseItem<T> sseItem, out int advance)
{
    ReadOnlySpan<byte> line = _lineBuffer.AsSpan(_lineOffset, _newlineIndex - _lineOffset);

    // Spec: "If the line is empty (a blank line) Dispatch the event"
    if (line.IsEmpty)
    {
        advance = GetNewLineLength();

        if (_dataAppended)
        {
            sseItem = new SseItem<T>(_itemParser(_eventType, _dataBuffer.AsSpan(0, _dataLength)), _eventType);
            _eventType = SseParser.EventTypeDefault;
            _dataLength = 0;
            _dataAppended = false;
            return true;
        }

        sseItem = default;
        return false;
    }

    // Find the colon separating the field name and value.
    int colonPos = line.IndexOf((byte)':');
    ReadOnlySpan<byte> fieldName;
    ReadOnlySpan<byte> fieldValue;
    if (colonPos >= 0)
    {
        // Spec: "Collect the characters on the line before the first U+003A COLON character (:), and let field be that string."
        fieldName = line.Slice(0, colonPos);

        // Spec: "Collect the characters on the line after the first U+003A COLON character (:), and let value be that string.
        // If value starts with a U+0020 SPACE character, remove it from value."
        fieldValue = line.Slice(colonPos + 1);
        if (!fieldValue.IsEmpty && fieldValue[0] == (byte)' ')
        {
            fieldValue = fieldValue.Slice(1);
        }
    }
    else
    {
        // Spec: "using the whole line as the field name, and the empty string as the field value."
        fieldName = line;
        fieldValue = [];
    }

    if (fieldName.SequenceEqual("data"u8))
    {
        // Spec: "Append the field value to the data buffer, then append a single U+000A LINE FEED (LF) character to the data buffer."
        // Spec: "If the data buffer's last character is a U+000A LINE FEED (LF) character, then remove the last character from the data buffer."

        // If there's nothing currently in the data buffer and we can easily detect that this line is immediately followed by
        // an empty line, we can optimize it to just handle the data directly from the line buffer, rather than first copying
        // into the data buffer and dispatching from there.
        if (!_dataAppended)
        {
            int newlineLength = GetNewLineLength();
            ReadOnlySpan<byte> remainder = _lineBuffer.AsSpan(_newlineIndex + newlineLength, _lineLength - line.Length - newlineLength);
            if (!remainder.IsEmpty &&
                (remainder[0] is LF || (remainder[0] is CR && remainder.Length > 1)))
            {
                // Consume the data line, its newline, and the blank line that terminates the event.
                advance = line.Length + newlineLength + (remainder.StartsWith(CRLF) ? 2 : 1);
                sseItem = new SseItem<T>(_itemParser(_eventType, fieldValue), _eventType);
                _eventType = SseParser.EventTypeDefault;
                return true;
            }
        }

        // We need to copy the data from the line buffer to the data buffer. Make sure there's enough room.
        if (_dataBuffer is null || _dataLength + _lineLength + 1 > _dataBuffer.Length)
        {
            GrowBuffer(ref _dataBuffer, _dataLength + _lineLength + 1);
        }

        // Append a newline if there's already content in the buffer.
        // Then copy the field value to the data buffer.
        if (_dataAppended)
        {
            _dataBuffer![_dataLength++] = LF;
        }

        fieldValue.CopyTo(_dataBuffer.AsSpan(_dataLength));
        _dataLength += fieldValue.Length;
        _dataAppended = true;
    }
    else if (fieldName.SequenceEqual("event"u8))
    {
        // Spec: "Set the event type buffer to field value."
        _eventType = SseParser.Utf8GetString(fieldValue);
    }
    else if (fieldName.SequenceEqual("id"u8))
    {
        // Spec: "If the field value does not contain U+0000 NULL, then set the last event ID buffer to the field value. Otherwise, ignore the field."
        if (fieldValue.IndexOf((byte)'\0') < 0)
        {
            // Note that fieldValue might be empty, in which case LastEventId will naturally be reset to the empty string. This is per spec.
            LastEventId = SseParser.Utf8GetString(fieldValue);
        }
    }
    else if (fieldName.SequenceEqual("retry"u8))
    {
        // Largest whole number of milliseconds a TimeSpan can represent.
        const long MaxReconnectionMilliseconds = long.MaxValue / TimeSpan.TicksPerMillisecond;

        // Spec: "If the field value consists of only ASCII digits, then interpret the field value as an integer in base ten,
        // and set the event stream's reconnection time to that integer. Otherwise, ignore the field."
        // NumberStyles.None accepts digits only, so a parsed value is never negative. A value too large for a
        // TimeSpan is ignored as well, so that a server-supplied number cannot make the parser throw OverflowException.
        // The UTF-8 overload of long.TryParse only exists on .NET 8 and later.
        if (long.TryParse(
#if NET8_0_OR_GREATER
                fieldValue,
#else
                SseParser.Utf8GetString(fieldValue),
#endif
                NumberStyles.None, CultureInfo.InvariantCulture, out long milliseconds) &&
            milliseconds <= MaxReconnectionMilliseconds)
        {
            // Exact integer conversion; cannot overflow thanks to the range check above.
            ReconnectionInterval = TimeSpan.FromTicks(milliseconds * TimeSpan.TicksPerMillisecond);
        }
    }
    else
    {
        // We'll end up here if the line starts with a colon, producing an empty field name, or if the field name is otherwise unrecognized.
        // Spec: "If the line starts with a U+003A COLON character (:) Ignore the line."
        // Spec: "Otherwise, The field is ignored"
    }

    advance = line.Length + GetNewLineLength();
    sseItem = default;
    return false;
}

/// <summary>Gets the last event ID.</summary>
/// <remarks>This value is updated any time a new last event ID is parsed. It is not reset between SSE items.</remarks>
public string LastEventId { get; private set; } = string.Empty; // Spec: "must be initialized to the empty string"

/// <summary>Gets the reconnection interval.</summary>
///
/// If no retry event was received, this defaults to <see cref="Timeout.InfiniteTimeSpan"/>, and it will only
/// ever be <see cref="Timeout.InfiniteTimeSpan"/> in that situation. If a client wishes to retry, the server-sent
/// events specification states that the interval may then be decided by the client implementation and should be a
/// few seconds.
///
public TimeSpan ReconnectionInterval { get; private set; } = Timeout.InfiniteTimeSpan;

/// <summary>Transitions the object to a used state, throwing if it's already been used.</summary>
/// <exception cref="InvalidOperationException">The parser was already marked as used by an earlier call.</exception>
private void ThrowIfNotFirstEnumeration()
{
    // Atomically set the flag; a non-zero previous value means some earlier call already claimed the parser.
    if (Interlocked.Exchange(ref _used, 1) != 0)
    {
        throw new InvalidOperationException("The enumerable may be enumerated only once.");
    }
}

/// <summary>Reads data from the stream into the line buffer.</summary>
/// <returns>The number of bytes appended to the line buffer; 0 once the stream reports end of data.</returns>
private int FillLineBuffer()
{
    ShiftOrGrowLineBufferIfNecessary();

    // Append after the bytes that are still pending in the buffer.
    int offset = _lineOffset + _lineLength;
    int bytesRead = _stream.Read(
#if NET
        _lineBuffer.AsSpan(offset));
#else
        _lineBuffer, offset, _lineBuffer.Length - offset);
#endif

    if (bytesRead > 0)
    {
        _lineLength += bytesRead;
    }
    else
    {
        // A non-positive result is treated as end of stream and normalized to 0 for the caller.
        _eof = true;
        bytesRead = 0;
    }

    return bytesRead;
}

/// <summary>Reads data asynchronously from the stream into the line buffer.</summary>
/// <param name="cancellationToken">The token passed through to the stream read.</param>
/// <returns>The number of bytes appended to the line buffer; 0 once the stream reports end of data.</returns>
private async ValueTask<int> FillLineBufferAsync(CancellationToken cancellationToken)
{
    ShiftOrGrowLineBufferIfNecessary();

    // Append after the bytes that are still pending in the buffer.
    int offset = _lineOffset + _lineLength;
    int bytesRead = await
#if NET
        _stream.ReadAsync(_lineBuffer.AsMemory(offset), cancellationToken)
#else
        new ValueTask<int>(_stream.ReadAsync(_lineBuffer, offset, _lineBuffer.Length - offset, cancellationToken))
#endif
        .ConfigureAwait(false);

    if (bytesRead > 0)
    {
        _lineLength += bytesRead;
    }
    else
    {
        // A non-positive result is treated as end of stream and normalized to 0 for the caller.
        _eof = true;
        bytesRead = 0;
    }

    return bytesRead;
}

/// <summary>Gets the UTF8 BOM.</summary>
private static ReadOnlySpan<byte> Utf8Bom => [0xEF, 0xBB, 0xBF];

/// <summary>Called at the beginning of processing to skip over an optional UTF8 byte order mark.</summary>
private void SkipBomIfPresent()
{
    Debug.Assert(_lineOffset == 0, $"Expected _lineOffset == 0, got {_lineOffset}");

    if (_lineBuffer.AsSpan(0, _lineLength).StartsWith(Utf8Bom))
    {
        // Step over the mark so it is never seen as part of the first line.
        _lineOffset += Utf8Bom.Length;
        _lineLength -= Utf8Bom.Length;
    }
}

/// <summary>Grows the buffer, returning the existing one to the ArrayPool and renting an ArrayPool replacement.</summary>
private static void GrowBuffer(ref byte[]? buffer, int minimumLength)
{
    // Rent the replacement first: at least minimumLength bytes, and never less than the default rent size.
    // A null incoming buffer simply becomes a freshly rented array.
    byte[]? toReturn = buffer;
    buffer = ArrayPool<byte>.Shared.Rent(Math.Max(minimumLength, DefaultArrayPoolRentSize));
    if (toReturn is not null)
    {
        // The whole old array is copied, so the replacement must be at least as large as the old one.
        Debug.Assert(buffer.Length >= toReturn.Length, "Expected the replacement buffer to be at least as large as the one it replaces");
        Array.Copy(toReturn, buffer, toReturn.Length);
        ArrayPool<byte>.Shared.Return(toReturn);
    }
}
} // presumably closes the enclosing parser type, which is declared above this chunk
} // presumably closes the enclosing namespace