Skip to content

Commit

Permalink
feat: add OpenAI's new structured output API (#180)
Browse files Browse the repository at this point in the history
* feat: add OpenAI's new structured output API

* test(openai): add test cases for response_format

* fix: clean up tests

* chore: revert env comment from testing
  • Loading branch information
monotykamary authored Sep 23, 2024
1 parent 5b392d5 commit 72f93a6
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 24 deletions.
20 changes: 16 additions & 4 deletions lib/chat_models/chat_open_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
# How many chat completion choices to generate for each input message.
field :n, :integer, default: 1
field :json_response, :boolean, default: false
field :json_schema, :map, default: nil
field :stream, :boolean, default: false
field :max_tokens, :integer, default: nil
# Options for streaming response. Only set this when you set `stream: true`
Expand Down Expand Up @@ -153,6 +154,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
:stream,
:receive_timeout,
:json_response,
:json_schema,
:max_tokens,
:stream_options,
:user,
Expand Down Expand Up @@ -263,11 +265,20 @@ defmodule LangChain.ChatModels.ChatOpenAI do
%{"include_usage" => Map.get(data, :include_usage, Map.get(data, "include_usage"))}
end

defp set_response_format(%ChatOpenAI{json_response: true}),
do: %{"type" => "json_object"}
defp set_response_format(%ChatOpenAI{json_response: true, json_schema: json_schema}) when not is_nil(json_schema) do
%{
"type" => "json_schema",
"json_schema" => json_schema
}
end

defp set_response_format(%ChatOpenAI{json_response: false}),
do: %{"type" => "text"}
defp set_response_format(%ChatOpenAI{json_response: true}) do
%{"type" => "json_object"}
end

defp set_response_format(%ChatOpenAI{json_response: false}) do
%{"type" => "text"}
end

@doc """
Convert a LangChain structure to the expected map of data for the OpenAI API.
Expand Down Expand Up @@ -908,6 +919,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
:seed,
:n,
:json_response,
:json_schema,
:stream,
:max_tokens,
:stream_options
Expand Down
30 changes: 16 additions & 14 deletions test/chains/data_extraction_chain_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ defmodule LangChain.Chains.DataExtractionChainTest do
FunctionParam.new!(%{name: "person_name", type: :string}),
FunctionParam.new!(%{name: "person_age", type: :number}),
FunctionParam.new!(%{name: "person_hair_color", type: :string}),
FunctionParam.new!(%{name: "dog_name", type: :string}),
FunctionParam.new!(%{name: "dog_breed", type: :string})
FunctionParam.new!(%{name: "pet_dog_name", type: :string}),
FunctionParam.new!(%{name: "pet_dog_breed", type: :string})
]
|> FunctionParam.to_parameters_schema()

Expand All @@ -31,8 +31,8 @@ defmodule LangChain.Chains.DataExtractionChainTest do
items: %{
"type" => "object",
"properties" => %{
"dog_breed" => %{"type" => "string"},
"dog_name" => %{"type" => "string"},
"pet_dog_breed" => %{"type" => "string"},
"pet_dog_name" => %{"type" => "string"},
"person_age" => %{"type" => "number"},
"person_hair_color" => %{"type" => "string"},
"person_name" => %{"type" => "string"}
Expand All @@ -55,32 +55,34 @@ defmodule LangChain.Chains.DataExtractionChainTest do
FunctionParam.new!(%{name: "person_name", type: :string}),
FunctionParam.new!(%{name: "person_age", type: :number}),
FunctionParam.new!(%{name: "person_hair_color", type: :string}),
FunctionParam.new!(%{name: "dog_name", type: :string}),
FunctionParam.new!(%{name: "dog_breed", type: :string})
FunctionParam.new!(%{name: "pet_dog_name", type: :string}),
FunctionParam.new!(%{name: "pet_dog_breed", type: :string})
]
|> FunctionParam.to_parameters_schema()

# Model setup - specify the model and seed
{:ok, chat} = ChatOpenAI.new(%{model: "gpt-4o", temperature: 0, seed: 0, stream: false})
{:ok, chat} = ChatOpenAI.new(%{model: "gpt-4o-mini-2024-07-18", temperature: 0, seed: 0, stream: false})

# run the chain, chain.run(prompt to extract data from)
data_prompt =
"Alex is 5 feet tall. Claudia is 4 feet taller than Alex and jumps higher than him.
Claudia is a brunette and Alex is blonde. Alex's dog Frosty is a labrador and likes to play hide and seek. Identify each person and their relevant information."
data_prompt = """
Alex is 5 feet tall. Claudia is 4 feet taller than Alex and jumps higher than him.
Claudia is a brunette and Alex is blonde.
Alex's dog Frosty is a labrador and likes to play hide and seek. Identify each person and their relevant information.
"""

{:ok, result} = DataExtractionChain.run(chat, schema_parameters, data_prompt, verbose: true)

assert result == [
%{
"dog_breed" => "labrador",
"dog_name" => "Frosty",
"pet_dog_breed" => "labrador",
"pet_dog_name" => "Frosty",
"person_age" => nil,
"person_hair_color" => "blonde",
"person_name" => "Alex"
},
%{
"dog_breed" => nil,
"dog_name" => nil,
"pet_dog_breed" => nil,
"pet_dog_name" => nil,
"person_age" => nil,
"person_hair_color" => "brunette",
"person_name" => "Claudia"
Expand Down
100 changes: 96 additions & 4 deletions test/chat_models/chat_open_ai_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
alias LangChain.Message.ToolCall
alias LangChain.Message.ToolResult

@test_model "gpt-3.5-turbo"
@test_model "gpt-4o-mini-2024-07-18"
@gpt4 "gpt-4-1106-preview"

defp hello_world(_args, _context) do
Expand Down Expand Up @@ -73,6 +73,25 @@ defmodule LangChain.ChatModels.ChatOpenAITest do

assert model.endpoint == override_url
end

test "supports setting json_response and json_schema" do
json_schema = %{
"type" => "object",
"properties" => %{
"name" => %{"type" => "string"},
"age" => %{"type" => "integer"}
}
}

{:ok, openai} = ChatOpenAI.new(%{
"model" => @test_model,
"json_response" => true,
"json_schema" => json_schema
})

assert openai.json_response == true
assert openai.json_schema == json_schema
end
end

describe "for_api/3" do
Expand Down Expand Up @@ -108,6 +127,34 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
assert data.response_format == %{"type" => "json_object"}
end

test "generates a map for an API call with JSON response and schema" do
json_schema = %{
"type" => "object",
"properties" => %{
"name" => %{"type" => "string"},
"age" => %{"type" => "integer"}
}
}

{:ok, openai} =
ChatOpenAI.new(%{
"model" => @test_model,
"temperature" => 1,
"frequency_penalty" => 0.5,
"json_response" => true,
"json_schema" => json_schema
})

data = ChatOpenAI.for_api(openai, [], [])
assert data.model == @test_model
assert data.temperature == 1
assert data.frequency_penalty == 0.5
assert data.response_format == %{
"type" => "json_schema",
"json_schema" => json_schema
}
end

test "generates a map for an API call with max_tokens set" do
{:ok, openai} =
ChatOpenAI.new(%{
Expand Down Expand Up @@ -419,7 +466,7 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
"description" => nil,
"enum" => ["yellow", "red", "green"],
"type" => "string"
}
}
},
"required" => ["p1"]
}
Expand Down Expand Up @@ -789,7 +836,7 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
@tag live_call: true, live_open_ai: true
test "handles when request is too large" do
{:ok, chat} =
ChatOpenAI.new(%{model: "gpt-3.5-turbo-0301", seed: 0, stream: false, temperature: 1})
ChatOpenAI.new(%{model: "gpt-4-0613", seed: 0, stream: false, temperature: 1})

{:error, reason} = ChatOpenAI.call(chat, [too_large_user_request()])
assert reason =~ "maximum context length"
Expand Down Expand Up @@ -1330,7 +1377,7 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
@tag live_call: true, live_open_ai: true
test "supports multi-modal user message with image prompt" do
# https://platform.openai.com/docs/guides/vision
{:ok, chat} = ChatOpenAI.new(%{model: "gpt-4-vision-preview", seed: 0})
{:ok, chat} = ChatOpenAI.new(%{model: "gpt-4o-2024-08-06", seed: 0})

url =
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
Expand Down Expand Up @@ -1891,8 +1938,53 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
"stream_options" => %{"include_usage" => true},
"temperature" => 0.0,
"version" => 1,
"json_schema" => nil,
"module" => "Elixir.LangChain.ChatModels.ChatOpenAI"
}
end
end

describe "set_response_format/1" do
test "generates a map for an API call with text format when json_response is false" do
{:ok, openai} = ChatOpenAI.new(%{
model: @test_model,
json_response: false
})
data = ChatOpenAI.for_api(openai, [], [])

assert data.response_format == %{"type" => "text"}
end

test "generates a map for an API call with json_object format when json_response is true and no schema" do
{:ok, openai} = ChatOpenAI.new(%{
model: @test_model,
json_response: true
})
data = ChatOpenAI.for_api(openai, [], [])

assert data.response_format == %{"type" => "json_object"}
end

test "generates a map for an API call with json_schema format when json_response is true and schema is provided" do
json_schema = %{
"type" => "object",
"properties" => %{
"name" => %{"type" => "string"},
"age" => %{"type" => "integer"}
}
}

{:ok, openai} = ChatOpenAI.new(%{
model: @test_model,
json_response: true,
json_schema: json_schema
})
data = ChatOpenAI.for_api(openai, [], [])

assert data.response_format == %{
"type" => "json_schema",
"json_schema" => json_schema
}
end
end
end
2 changes: 1 addition & 1 deletion test/message_delta_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ defmodule LangChain.MessageDeltaTest do
status: :incomplete,
type: :function,
call_id: "toolu_123",
name: "get_codeget_codeget_codeget_codeget_code",
name: "get_code",
arguments: "{\"code\": \"def my_function(x):\n return x + 1\"}",
index: 1
}
Expand Down
2 changes: 1 addition & 1 deletion test/support/fixtures.ex
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ defmodule LangChain.Fixtures do
end

def too_large_user_request() do
Message.new_user!("Analyze the following text: \n\n" <> text_chunks(8))
Message.new_user!("Analyze the following text: \n\n" <> text_chunks(16))
end

def results_in_too_long_response() do
Expand Down

0 comments on commit 72f93a6

Please sign in to comment.