Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fine tuning tests #44695

Merged
merged 15 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sdk/openai/Azure.AI.OpenAI/src/Azure.AI.OpenAI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,10 @@
<PackageReference Include="OpenAI" VersionOverride="2.0.0-*" />
</ItemGroup>

<ItemGroup>
<InternalsVisibleTo Include="Azure.AI.OpenAI.Tests">
<PublicKey>0024000004800000940000000602000000240000525341310004000001000100d15ddcb29688295338af4b7686603fe614abd555e09efba8fb88ee09e1f7b1ccaeed2e8f823fa9eef3fdd60217fc012ea67d2479751a0b8c087a4185541b851bd8b16f8d91b840e51b1cb0ba6fe647997e57429265e85ef62d565db50a69ae1647d54d7bd855e4db3d8a91510e5bcbd0edfbbecaa20a7bd9ae74593daa7b11b4</PublicKey>
</InternalsVisibleTo>
</ItemGroup>

</Project>
6 changes: 3 additions & 3 deletions sdk/openai/Azure.AI.OpenAI/src/Custom/AzureOpenAIClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ private static PipelinePolicy CreateAddUserAgentHeaderPolicy(AzureOpenAIClientOp
private static readonly string s_aoaiEndpointEnvironmentVariable = "AZURE_OPENAI_ENDPOINT";
private static readonly string s_aoaiApiKeyEnvironmentVariable = "AZURE_OPENAI_API_KEY";
private static readonly string s_userAgentHeaderKey = "User-Agent";
// Cache for the classifier exposed by PipelineMessageClassifier200; stays null until first use.
private static PipelineMessageClassifier _pipelineMessageClassifier200;

/// <summary>
/// Gets a cached classifier created from the single status code 200, building it on first access.
/// </summary>
internal static PipelineMessageClassifier PipelineMessageClassifier200
{
    get => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 });
}
// Cache for the classifier exposed by the PipelineMessageClassifier property; stays null until first use.
private static PipelineMessageClassifier _pipelineMessageClassifier;

/// <summary>
/// Gets a cached classifier created from the status codes 200 and 201, building it on first access.
/// </summary>
internal static PipelineMessageClassifier PipelineMessageClassifier
{
    get
    {
        return _pipelineMessageClassifier ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200, 201 });
    }
}
}
45 changes: 45 additions & 0 deletions sdk/openai/Azure.AI.OpenAI/src/Custom/Files/AzureFileClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using OpenAI.Files;
using System.ClientModel;
using System.ClientModel.Primitives;

namespace Azure.AI.OpenAI.Files;
Expand Down Expand Up @@ -30,4 +31,48 @@ internal AzureFileClient(

/// <summary>
/// Parameterless constructor reachable only from derived types; it performs no initialization of its own.
/// </summary>
protected AzureFileClient()
{
}

/// <inheritdoc />
/// <remarks>
/// Uploads whose purpose is not <see cref="FileUploadPurpose.FineTune"/> are forwarded unchanged to the base
/// implementation. Fine tuning uploads are sent as multipart form data whose file part carries the
/// "text/plain" content type.
/// </remarks>
public override ClientResult<OpenAIFileInfo> UploadFile(Stream file, string filename, FileUploadPurpose purpose, CancellationToken cancellationToken = default)
{
    if (purpose != FileUploadPurpose.FineTune)
    {
        return base.UploadFile(file, filename, purpose, cancellationToken);
    }

    // need to set the content type for fine tuning file uploads in Azure OpenAI
    Argument.AssertNotNull(file, nameof(file));
    Argument.AssertNotNullOrEmpty(filename, nameof(filename));

    // The multipart body is disposed as soon as the request has completed.
    using MultipartFormDataBinaryContent content = CreateMultiPartContentWithMimeType(file, filename, purpose);
    ClientResult clientResult = UploadFile(content, content.ContentType, new() { CancellationToken = cancellationToken });

    // Convert the untyped result into the strongly typed file info while keeping the raw response attached.
    PipelineResponse response = clientResult.GetRawResponse();
    return ClientResult.FromValue(OpenAIFileInfo.FromResponse(response), response);
}

/// <inheritdoc />
public override async Task<ClientResult<OpenAIFileInfo>> UploadFileAsync(Stream file, string filename, FileUploadPurpose purpose, CancellationToken cancellationToken = default)
{
    // Only fine tuning uploads are handled here; every other purpose goes straight to the base client.
    if (purpose != FileUploadPurpose.FineTune)
    {
        ClientResult<OpenAIFileInfo> baseResult = await base.UploadFileAsync(file, filename, purpose, cancellationToken).ConfigureAwait(false);
        return baseResult;
    }

    Argument.AssertNotNull(file, nameof(file));
    Argument.AssertNotNullOrEmpty(filename, nameof(filename));

    // Azure OpenAI needs the content type set on fine tuning file uploads, so the multipart body is built here.
    using MultipartFormDataBinaryContent multipart = CreateMultiPartContentWithMimeType(file, filename, purpose);
    ClientResult protocolResult = await UploadFileAsync(multipart, multipart.ContentType, new() { CancellationToken = cancellationToken }).ConfigureAwait(false);

    PipelineResponse response = protocolResult.GetRawResponse();
    return ClientResult.FromValue(OpenAIFileInfo.FromResponse(response), response);
}

// Builds the multipart form body for an upload: a "file" part tagged as "text/plain" followed by a
// "purpose" part holding the string form of the purpose. The caller owns (and disposes) the result.
private MultipartFormDataBinaryContent CreateMultiPartContentWithMimeType(Stream file, string filename, FileUploadPurpose purpose)
{
    var form = new MultipartFormDataBinaryContent();

    form.Add(file, "file", filename, "text/plain");
    form.Add(purpose.ToString(), "purpose");

    return form;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#nullable enable

using System.ClientModel;
using System.ClientModel.Primitives;
using System.Diagnostics.CodeAnalysis;
using OpenAI.FineTuning;

namespace Azure.AI.OpenAI.FineTuning;

/// <summary>
/// Extension methods that expose Azure-only operations on <see cref="FineTuningClient"/> instances.
/// </summary>
internal static class AzureFineTuningClientExtensions
{
    /// <summary>
    /// Deletes an Azure OpenAI fine tuning job.
    /// </summary>
    /// <param name="client">The fine tuning client; it must be an Azure OpenAI fine tuning client.</param>
    /// <param name="jobId">The identifier of the fine tuning job to delete.</param>
    /// <param name="options">(Optional) The request options to use.</param>
    /// <returns>The request result.</returns>
    /// <remarks>The Azure OpenAI service will always return a success (HTTP 204) regardless of whether or not
    /// the job you are trying to delete exists.</remarks>
    [Experimental("AOAI001")]
    public static ClientResult DeleteJob(this FineTuningClient client, string jobId, RequestOptions? options = null)
    {
        // The job id is validated before the client so the order of argument failures stays stable.
        Argument.AssertNotNullOrEmpty(jobId, nameof(jobId));

        AzureFineTuningClient azureClient = Cast(client);
        return azureClient.DeleteJob(jobId, options);
    }

    /// <summary>
    /// Asynchronously deletes an Azure OpenAI fine tuning job.
    /// </summary>
    /// <param name="client">The fine tuning client; it must be an Azure OpenAI fine tuning client.</param>
    /// <param name="jobId">The identifier of the fine tuning job to delete.</param>
    /// <param name="options">(Optional) The request options to use.</param>
    /// <returns>A task that yields the request result.</returns>
    /// <remarks>The Azure OpenAI service will always return a success (HTTP 204) regardless of whether or not
    /// the job you are trying to delete exists.</remarks>
    [Experimental("AOAI001")]
    public static Task<ClientResult> DeleteJobAsync(this FineTuningClient client, string jobId, RequestOptions? options = null)
    {
        // The job id is validated before the client so the order of argument failures stays stable.
        Argument.AssertNotNullOrEmpty(jobId, nameof(jobId));

        AzureFineTuningClient azureClient = Cast(client);
        return azureClient.DeleteJobAsync(jobId, options);
    }

    // Narrows a general fine tuning client to the Azure implementation. A null client is rejected by the
    // argument check; any other client type results in an InvalidOperationException.
    private static AzureFineTuningClient Cast(FineTuningClient? client)
    {
        Argument.AssertNotNull(client, nameof(client));

        return client as AzureFineTuningClient
            ?? throw new InvalidOperationException("Only supported on Azure OpenAI fine tuning clients");
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@

using System.ClientModel;
using System.ClientModel.Primitives;
using System.Diagnostics.CodeAnalysis;
using OpenAI.FineTuning;

namespace Azure.AI.OpenAI.FineTuning;

internal partial class AzureFineTuningClient : FineTuningClient
{
// Classifier attached to delete-job requests; it is created from the single status code 204.
// It is built from a constant and carries no per-client state, so one static instance is shared
// instead of allocating a new classifier for every client object.
private static readonly PipelineMessageClassifier DeleteJobClassifier = PipelineMessageClassifier.Create(stackalloc ushort[] { 204 });

public override ClientResult CreateJob(BinaryContent content, RequestOptions options = null)
{
using PipelineMessage message = CreateCreateJobRequestMessage(content, options);
Expand Down Expand Up @@ -61,6 +64,19 @@ public override async Task<ClientResult> GetJobEventsAsync(string fineTuningJobI
return ClientResult.FromResponse(response);
}

/// <summary>
/// Builds the job-checkpoints request, sends it through the pipeline, and wraps the response.
/// </summary>
/// <param name="fineTuningJobId">Identifier of the fine tuning job whose checkpoints are requested.</param>
/// <param name="after">Value for the optional "after" query parameter.</param>
/// <param name="limit">Value for the optional "limit" query parameter.</param>
/// <param name="options">Request options applied to the message and to pipeline processing.</param>
/// <returns>The result created from the pipeline response.</returns>
public override ClientResult GetJobCheckpoints(string fineTuningJobId, string after, int? limit, RequestOptions options)
{
    using PipelineMessage request = CreateGetJobCheckpointsRequestMessage(fineTuningJobId, after, limit, options);

    PipelineResponse response = Pipeline.ProcessMessage(request, options);
    return ClientResult.FromResponse(response);
}

/// <summary>
/// Asynchronously builds the job-checkpoints request, sends it through the pipeline, and wraps the response.
/// </summary>
/// <param name="fineTuningJobId">Identifier of the fine tuning job whose checkpoints are requested.</param>
/// <param name="after">Value for the optional "after" query parameter.</param>
/// <param name="limit">Value for the optional "limit" query parameter.</param>
/// <param name="options">Request options applied to the message and to pipeline processing.</param>
/// <returns>A task that yields the result created from the pipeline response.</returns>
public override async Task<ClientResult> GetJobCheckpointsAsync(string fineTuningJobId, string after, int? limit, RequestOptions options)
{
    using PipelineMessage request = CreateGetJobCheckpointsRequestMessage(fineTuningJobId, after, limit, options);

    return ClientResult.FromResponse(
        await Pipeline.ProcessMessageAsync(request, options).ConfigureAwait(false));
}

public override ClientResult CancelJob(string fineTuningJobId, RequestOptions options)
{
using PipelineMessage message = CreateCancelJobRequestMessage(fineTuningJobId, options);
Expand All @@ -74,10 +90,25 @@ public override async Task<ClientResult> CancelJobAsync(string fineTuningJobId,
return ClientResult.FromResponse(response);
}

/// <summary>
/// Sends the delete request for a fine tuning job and wraps the pipeline response.
/// </summary>
/// <param name="jobId">Identifier of the fine tuning job to delete.</param>
/// <param name="options">(Optional) Request options applied to the message and to pipeline processing.</param>
/// <returns>The result created from the pipeline response.</returns>
[Experimental("AOAI001")]
public virtual ClientResult DeleteJob(string jobId, RequestOptions options = null)
{
    using PipelineMessage request = CreateDeleteJobRequestMessage(jobId, options);

    PipelineResponse response = Pipeline.ProcessMessage(request, options);
    return ClientResult.FromResponse(response);
}

/// <summary>
/// Asynchronously sends the delete request for a fine tuning job and wraps the pipeline response.
/// </summary>
/// <param name="jobId">Identifier of the fine tuning job to delete.</param>
/// <param name="options">(Optional) Request options applied to the message and to pipeline processing.</param>
/// <returns>A task that yields the result created from the pipeline response.</returns>
[Experimental("AOAI001")]
public virtual async Task<ClientResult> DeleteJobAsync(string jobId, RequestOptions options = null)
{
    using PipelineMessage request = CreateDeleteJobRequestMessage(jobId, options);

    return ClientResult.FromResponse(
        await Pipeline.ProcessMessageAsync(request, options).ConfigureAwait(false));
}

private PipelineMessage CreateCreateJobRequestMessage(BinaryContent content, RequestOptions options)
=> new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion)
.WithMethod("POST")
.WithPath("fine_tuning")
.WithPath("fine_tuning", "jobs")
.WithContent(content, "application/json")
.WithAccept("application/json")
.WithOptions(options)
Expand All @@ -86,7 +117,7 @@ private PipelineMessage CreateCreateJobRequestMessage(BinaryContent content, Req
private PipelineMessage CreateGetJobsRequestMessage(string after, int? limit, RequestOptions options)
=> new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion)
.WithMethod("GET")
.WithPath("fine_tuning")
.WithPath("fine_tuning", "jobs")
.WithOptionalQueryParameter("after", after)
.WithOptionalQueryParameter("limit", limit)
.WithAccept("application/json")
Expand All @@ -96,15 +127,25 @@ private PipelineMessage CreateGetJobsRequestMessage(string after, int? limit, Re
// Builds the GET message for a single fine tuning job. The path is configured exactly once, from the
// segments "fine_tuning", "jobs", and the job id; the superseded two-segment path call is not repeated.
private PipelineMessage CreateGetJobRequestMessage(string jobId, RequestOptions options)
    => new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion)
        .WithMethod("GET")
        .WithPath("fine_tuning", "jobs", jobId)
        .WithAccept("application/json")
        .WithOptions(options)
        .Build();

// Builds the GET message for a fine tuning job's events. The path is configured exactly once, from the
// segments "fine_tuning", "jobs", the job id, and "events"; "after" and "limit" are passed as optional
// query parameters.
private PipelineMessage CreateGetJobEventsRequestMessage(string jobId, string after, int? limit, RequestOptions options)
    => new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion)
        .WithMethod("GET")
        .WithPath("fine_tuning", "jobs", jobId, "events")
        .WithOptionalQueryParameter("after", after)
        .WithOptionalQueryParameter("limit", limit)
        .WithAccept("application/json")
        .WithOptions(options)
        .Build();

private PipelineMessage CreateGetJobCheckpointsRequestMessage(string jobId, string after, int? limit, RequestOptions options)
=> new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion)
.WithMethod("GET")
.WithPath("fine_tuning", "jobs", jobId, "checkpoints")
.WithOptionalQueryParameter("after", after)
.WithOptionalQueryParameter("limit", limit)
.WithAccept("application/json")
Expand All @@ -114,8 +155,17 @@ private PipelineMessage CreateGetJobEventsRequestMessage(string jobId, string af
// Builds the POST message that cancels a fine tuning job. The path is configured exactly once, from the
// segments "fine_tuning", "jobs", the job id, and "cancel".
private PipelineMessage CreateCancelJobRequestMessage(string jobId, RequestOptions options)
    => new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion)
        .WithMethod("POST")
        .WithPath("fine_tuning", "jobs", jobId, "cancel")
        .WithAccept("application/json")
        .WithOptions(options)
        .Build();

// Builds the DELETE message for a fine tuning job, addressed by the path segments "fine_tuning", "jobs",
// and the job id. The delete-specific classifier is attached before the request options are applied.
private PipelineMessage CreateDeleteJobRequestMessage(string jobId, RequestOptions options)
{
    var configuredBuilder = new AzureOpenAIPipelineMessageBuilder(Pipeline, _endpoint, _apiVersion)
        .WithMethod("DELETE")
        .WithPath("fine_tuning", "jobs", jobId)
        .WithAccept("application/json")
        .WithClassifier(DeleteJobClassifier)
        .WithOptions(options);

    return configuredBuilder.Build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ internal class AzureOpenAIPipelineMessageBuilder
private string _method;
private BinaryContent _content;
private readonly Dictionary<string, string> _headers = [];
private PipelineMessageClassifier _classifier;
private RequestOptions _options;
private bool? _bufferResponse;

Expand Down Expand Up @@ -90,12 +91,18 @@ public AzureOpenAIPipelineMessageBuilder WithResponseContentBuffering(bool? shou
return this;
}

/// <summary>
/// Records the response classifier to assign to the message being built.
/// </summary>
/// <param name="classifier">The classifier to store on this builder.</param>
/// <returns>This builder, so that further calls can be chained.</returns>
public AzureOpenAIPipelineMessageBuilder WithClassifier(PipelineMessageClassifier classifier)
{
    this._classifier = classifier;

    return this;
}

public PipelineMessage Build()
{
Argument.AssertNotNullOrWhiteSpace(_method, nameof(_method));

PipelineMessage message = _pipeline.CreateMessage();
message.ResponseClassifier = AzureOpenAIClient.PipelineMessageClassifier200;
message.ResponseClassifier = _classifier ?? AzureOpenAIClient.PipelineMessageClassifier;
if (_bufferResponse.HasValue)
{
message.BufferResponse = _bufferResponse.Value;
Expand Down
11 changes: 11 additions & 0 deletions sdk/openai/Azure.AI.OpenAI/tests/Assets/fine_tuning.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Sources: Colts grant RB Taylor OK to seek trade"},{"role":"assistant","content":"{\n \"player\": \"Jonathan Taylor\",\n \"team\": \"Indianapolis Colts\",\n \"sport\": \"football\",\n \"gender\": \"male\"\n}"}]}
{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"OSU 'split down middle' on starting QB battle"},{"role":"assistant","content":"{\n \"team\": \"Ohio State University\",\n \"sport\": \"football\"\n}"}]}
{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Christian McCaffrey given 2 year extension by 49ers"},{"role":"assistant","content":"{\n \"player\": \"Christian McCaffrey\",\n \"team\": \"San Francisco 49ers\",\n \"sport\": \"football\",\n \"gender\": \"male\"\n}"}]}
{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Tucupita Marcano banned for life by MLB for betting on baseball"},{"role":"assistant","content":"{\n \"player\": \"Tucupita Marcano\",\n \"sport\": \"baseball\",\n \"gender\": \"male\"\n}"}]}
{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Who will win the 2024 NBA finals? Predictions about the Celtics vs. Mavericks matchup"},{"role":"assistant","content":"[\n {\n \"team\": \"Celtics\",\n \"sport\": \"basketball\"\n },\n {\n \"team\": \"Mavericks\",\n \"sport\": \"basketball\"\n }\n]"}]}
{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Pavelski will not play in 2024-2025 season"},{"role":"assistant","content":"{\n \"player\": \"Pavelski\",\n \"sport\": \"hockey\",\n \"gender\": \"male\"\n}"}]}
{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Charges against Scottie Scheffler have been dropped after arrest"},{"role":"assistant","content":"{\n \"player\": \"Scottie Scheffler\",\n \"sport\": \"golf\",\n \"gender\": \"male\"\n}"}]}
{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Perez picked by Red Bull for F1 2025 deal"},{"role":"assistant","content":"{\n \"player\": \"Perez\",\n \"team\": \"Red Bull\",\n \"sport\": \"F1\",\n \"gender\": \"male\"\n}"}]}
{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"DL Johnson III waived by 49ers"},{"role":"assistant","content":"{\n \"player\": \"DL Johnson III\",\n \"team\": \"49ers\",\n \"sport\": \"football\",\n \"gender\": \"male\"\n}"}]}
{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Trevor Williams placed on injured list"},{"role":"assistant","content":"{\n \"player\": \"Trevor Williams\",\n \"sport\": \"baseball\",\n \"gender\": \"male\"\n}"}]}
{"messages":[{"role":"system","content":"Given a sports headline, provide the following fields in a JSON dictionary, where applicable: \"player\" (full name), \"team\", \"sport\", and \"gender\". In the case there is more than one team return an array of that dictionary. Do not include any markdown characters such as ```json and ```"},{"role":"user","content":"Coco Gauff, and Iga Swiatek will meet in French Open semis"},{"role":"assistant","content":"[\n {\n \"player\": \"Coco Gauff\",\n \"sport\": \"tennis\",\n \"gender\": \"female\"\n },\n {\n \"player\": \"Iga Swiatek\",\n \"sport\": \"tennis\",\n \"gender\": \"female\"\n }\n]"}]}
Loading