Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added CancellationToken parameters to the most complete convenience overloads. #53

Merged
merged 13 commits into from
Jun 14, 2024
252 changes: 157 additions & 95 deletions src/Custom/Assistants/AssistantClient.cs

Large diffs are not rendered by default.

31 changes: 19 additions & 12 deletions src/Custom/Audio/AudioClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.ClientModel;
using System.ClientModel.Primitives;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace OpenAI.Audio;
Expand Down Expand Up @@ -85,16 +86,17 @@ protected internal AudioClient(ClientPipeline pipeline, string model, Uri endpoi
/// <param name="text"> The text for the voice to speak. </param>
/// <param name="voice"> The voice to use. </param>
/// <param name="options"> Additional options to tailor the text-to-speech request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <returns> The generated audio in the specified output format. </returns>
public virtual async Task<ClientResult<BinaryData>> GenerateSpeechFromTextAsync(string text, GeneratedSpeechVoice voice, SpeechGenerationOptions options = null)
public virtual async Task<ClientResult<BinaryData>> GenerateSpeechFromTextAsync(string text, GeneratedSpeechVoice voice, SpeechGenerationOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(text, nameof(text));

options ??= new();
CreateSpeechGenerationOptions(text, voice, ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = await GenerateSpeechFromTextAsync(content, null).ConfigureAwait(false);
ClientResult result = await GenerateSpeechFromTextAsync(content, cancellationToken.ToRequestOptions()).ConfigureAwait(false);
return ClientResult.FromValue(result.GetRawResponse().Content, result.GetRawResponse());
}

Expand All @@ -108,16 +110,17 @@ public virtual async Task<ClientResult<BinaryData>> GenerateSpeechFromTextAsync(
/// <param name="text"> The text for the voice to speak. </param>
/// <param name="voice"> The voice to use. </param>
/// <param name="options"> Additional options to tailor the text-to-speech request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <returns> The generated audio in the specified output format. </returns>
public virtual ClientResult<BinaryData> GenerateSpeechFromText(string text, GeneratedSpeechVoice voice, SpeechGenerationOptions options = null)
public virtual ClientResult<BinaryData> GenerateSpeechFromText(string text, GeneratedSpeechVoice voice, SpeechGenerationOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(text, nameof(text));

options ??= new();
CreateSpeechGenerationOptions(text, voice, ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = GenerateSpeechFromText(content, (RequestOptions)null);
ClientResult result = GenerateSpeechFromText(content, cancellationToken.ToRequestOptions()); ;
return ClientResult.FromValue(result.GetRawResponse().Content, result.GetRawResponse());
}

Expand All @@ -135,10 +138,11 @@ public virtual ClientResult<BinaryData> GenerateSpeechFromText(string text, Gene
/// not match.
/// </param>
/// <param name="options"> Additional options to tailor the audio transcription request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <exception cref="ArgumentNullException"> <paramref name="audio"/> or <paramref name="audioFilename"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="audioFilename"/> is an empty string, and was expected to be non-empty. </exception>
/// <returns> The audio transcription. </returns>
public virtual async Task<ClientResult<AudioTranscription>> TranscribeAudioAsync(Stream audio, string audioFilename, AudioTranscriptionOptions options = null)
public virtual async Task<ClientResult<AudioTranscription>> TranscribeAudioAsync(Stream audio, string audioFilename, AudioTranscriptionOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(audio, nameof(audio));
Argument.AssertNotNullOrEmpty(audioFilename, nameof(audioFilename));
Expand All @@ -147,7 +151,7 @@ public virtual async Task<ClientResult<AudioTranscription>> TranscribeAudioAsync
CreateAudioTranscriptionOptions(audio, audioFilename, ref options);

using MultipartFormDataBinaryContent content = options.ToMultipartContent(audio, audioFilename);
ClientResult result = await TranscribeAudioAsync(content, content.ContentType).ConfigureAwait(false);
ClientResult result = await TranscribeAudioAsync(content, content.ContentType, cancellationToken.ToRequestOptions()).ConfigureAwait(false);
return ClientResult.FromValue(AudioTranscription.FromResponse(result.GetRawResponse()), result.GetRawResponse());
}

Expand All @@ -161,10 +165,11 @@ public virtual async Task<ClientResult<AudioTranscription>> TranscribeAudioAsync
/// not match.
/// </param>
/// <param name="options"> Additional options to tailor the audio transcription request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <exception cref="ArgumentNullException"> <paramref name="audio"/> or <paramref name="audioFilename"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="audioFilename"/> is an empty string, and was expected to be non-empty. </exception>
/// <returns> The audio transcription. </returns>
public virtual ClientResult<AudioTranscription> TranscribeAudio(Stream audio, string audioFilename, AudioTranscriptionOptions options = null)
public virtual ClientResult<AudioTranscription> TranscribeAudio(Stream audio, string audioFilename, AudioTranscriptionOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(audio, nameof(audio));
Argument.AssertNotNullOrEmpty(audioFilename, nameof(audioFilename));
Expand All @@ -173,7 +178,7 @@ public virtual ClientResult<AudioTranscription> TranscribeAudio(Stream audio, st
CreateAudioTranscriptionOptions(audio, audioFilename, ref options);

using MultipartFormDataBinaryContent content = options.ToMultipartContent(audio, audioFilename);
ClientResult result = TranscribeAudio(content, content.ContentType);
ClientResult result = TranscribeAudio(content, content.ContentType, cancellationToken.ToRequestOptions());
return ClientResult.FromValue(AudioTranscription.FromResponse(result.GetRawResponse()), result.GetRawResponse());
}

Expand Down Expand Up @@ -229,10 +234,11 @@ public virtual ClientResult<AudioTranscription> TranscribeAudio(string audioFile
/// not match.
/// </param>
/// <param name="options"> Additional options to tailor the audio translation request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <exception cref="ArgumentNullException"> <paramref name="audio"/> or <paramref name="audioFilename"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="audioFilename"/> is an empty string, and was expected to be non-empty. </exception>
/// <returns> The audio translation. </returns>
public virtual async Task<ClientResult<AudioTranslation>> TranslateAudioAsync(Stream audio, string audioFilename, AudioTranslationOptions options = null)
public virtual async Task<ClientResult<AudioTranslation>> TranslateAudioAsync(Stream audio, string audioFilename, AudioTranslationOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(audio, nameof(audio));
Argument.AssertNotNullOrEmpty(audioFilename, nameof(audioFilename));
Expand All @@ -241,7 +247,7 @@ public virtual async Task<ClientResult<AudioTranslation>> TranslateAudioAsync(St
CreateAudioTranslationOptions(audio, audioFilename, ref options);

using MultipartFormDataBinaryContent content = options.ToMultipartContent(audio, audioFilename);
ClientResult result = await TranslateAudioAsync(content, content.ContentType).ConfigureAwait(false);
ClientResult result = await TranslateAudioAsync(content, content.ContentType, cancellationToken.ToRequestOptions()).ConfigureAwait(false);
return ClientResult.FromValue(AudioTranslation.FromResponse(result.GetRawResponse()), result.GetRawResponse());
}

Expand All @@ -253,10 +259,11 @@ public virtual async Task<ClientResult<AudioTranslation>> TranslateAudioAsync(St
/// not match.
/// </param>
/// <param name="options"> Additional options to tailor the audio translation request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <exception cref="ArgumentNullException"> <paramref name="audio"/> or <paramref name="audioFilename"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="audioFilename"/> is an empty string, and was expected to be non-empty. </exception>
/// <returns> The audio translation. </returns>
public virtual ClientResult<AudioTranslation> TranslateAudio(Stream audio, string audioFilename, AudioTranslationOptions options = null)
public virtual ClientResult<AudioTranslation> TranslateAudio(Stream audio, string audioFilename, AudioTranslationOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(audio, nameof(audio));
Argument.AssertNotNullOrEmpty(audioFilename, nameof(audioFilename));
Expand All @@ -265,7 +272,7 @@ public virtual ClientResult<AudioTranslation> TranslateAudio(Stream audio, strin
CreateAudioTranslationOptions(audio, audioFilename, ref options);

using MultipartFormDataBinaryContent content = options.ToMultipartContent(audio, audioFilename);
ClientResult result = TranslateAudio(content, content.ContentType);
ClientResult result = TranslateAudio(content, content.ContentType, cancellationToken.ToRequestOptions());
return ClientResult.FromValue(AudioTranslation.FromResponse(result.GetRawResponse()), result.GetRawResponse());
}

Expand Down
25 changes: 15 additions & 10 deletions src/Custom/Chat/ChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace OpenAI.Chat;
Expand Down Expand Up @@ -68,16 +69,18 @@ protected internal ChatClient(ClientPipeline pipeline, string model, Uri endpoin
/// </summary>
/// <param name="messages"> The messages to provide as input and history for chat completion. </param>
/// <param name="options"> Additional options for the chat completion request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <returns> A result for a single chat completion. </returns>
public virtual async Task<ClientResult<ChatCompletion>> CompleteChatAsync(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null)
public virtual async Task<ClientResult<ChatCompletion>> CompleteChatAsync(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(messages, nameof(messages));

options ??= new();
CreateChatCompletionOptions(messages, ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = await CompleteChatAsync(content, null).ConfigureAwait(false);

ClientResult result = await CompleteChatAsync(content, cancellationToken.ToRequestOptions()).ConfigureAwait(false);
return ClientResult.FromValue(ChatCompletion.FromResponse(result.GetRawResponse()), result.GetRawResponse());
}

Expand All @@ -94,16 +97,17 @@ public virtual async Task<ClientResult<ChatCompletion>> CompleteChatAsync(params
/// </summary>
/// <param name="messages"> The messages to provide as input and history for chat completion. </param>
/// <param name="options"> Additional options for the chat completion request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <returns> A result for a single chat completion. </returns>
public virtual ClientResult<ChatCompletion> CompleteChat(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null)
public virtual ClientResult<ChatCompletion> CompleteChat(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(messages, nameof(messages));

options ??= new();
CreateChatCompletionOptions(messages, ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = CompleteChat(content, null);
ClientResult result = CompleteChat(content, cancellationToken.ToRequestOptions());
return ClientResult.FromValue(ChatCompletion.FromResponse(result.GetRawResponse()), result.GetRawResponse());

}
Expand All @@ -126,18 +130,19 @@ public virtual ClientResult<ChatCompletion> CompleteChat(params ChatMessage[] me
/// </remarks>
/// <param name="messages"> The messages to provide as input for chat completion. </param>
/// <param name="options"> Additional options for the chat completion request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <returns> A streaming result with incremental chat completion updates. </returns>
public virtual AsyncResultCollection<StreamingChatCompletionUpdate> CompleteChatStreamingAsync(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null)
public virtual AsyncResultCollection<StreamingChatCompletionUpdate> CompleteChatStreamingAsync(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(messages, nameof(messages));

options ??= new();
CreateChatCompletionOptions(messages, ref options, stream: true);

using BinaryContent content = options.ToBinaryContent();
RequestOptions requestOptions = new() { BufferResponse = false };

async Task<ClientResult> getResultAsync() =>
await CompleteChatAsync(content, requestOptions).ConfigureAwait(false);
await CompleteChatAsync(content, cancellationToken.ToRequestOptions(streaming: true)).ConfigureAwait(false);
return new AsyncStreamingChatCompletionUpdateCollection(getResultAsync);
}

Expand All @@ -164,17 +169,17 @@ public virtual AsyncResultCollection<StreamingChatCompletionUpdate> CompleteChat
/// </remarks>
/// <param name="messages"> The messages to provide as input for chat completion. </param>
/// <param name="options"> Additional options for the chat completion request. </param>
/// <param name="cancellationToken">A token that can be used to cancel this method call.</param>
/// <returns> A streaming result with incremental chat completion updates. </returns>
public virtual ResultCollection<StreamingChatCompletionUpdate> CompleteChatStreaming(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null)
public virtual ResultCollection<StreamingChatCompletionUpdate> CompleteChatStreaming(IEnumerable<ChatMessage> messages, ChatCompletionOptions options = null, CancellationToken cancellationToken = default)
{
Argument.AssertNotNull(messages, nameof(messages));

options ??= new();
CreateChatCompletionOptions(messages, ref options, stream: true);

using BinaryContent content = options.ToBinaryContent();
RequestOptions requestOptions = new() { BufferResponse = false };
ClientResult getResult() => CompleteChat(content, requestOptions);
ClientResult getResult() => CompleteChat(content, cancellationToken.ToRequestOptions(streaming: true));
return new StreamingChatCompletionUpdateCollection(getResult);
}

Expand Down
22 changes: 22 additions & 0 deletions src/Custom/Common/CancellationTokenExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System.ClientModel.Primitives;
using System.Threading;

internal static class CancellationTokenExtensions
{
public static RequestOptions ToRequestOptions(this CancellationToken cancellationToken, bool streaming = false)
{
if (cancellationToken == default)
{
if (!streaming) return null;
return StreamRequestOptions;
}

return new RequestOptions() {
CancellationToken = cancellationToken,
BufferResponse = !streaming,
};
}

private static RequestOptions StreamRequestOptions => _streamRequestOptions ??= new() { BufferResponse = false };
private static RequestOptions _streamRequestOptions;
}
Loading
Loading