Skip to content

Commit

Permalink
support for o1 OpenAI model (#234)
Browse files Browse the repository at this point in the history
- adds new "reasoning_mode". Sends "developer" role instead of "system" role.
- adds support for "reasoning_effort" API option
- refactors many of the for_api functions to include the model. The reasoning model changes how things are sent
  • Loading branch information
brainlid authored Jan 15, 2025
1 parent 9ee95a7 commit 45708db
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 117 deletions.
7 changes: 6 additions & 1 deletion lib/chat_models/chat_google_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,12 @@ defmodule LangChain.ChatModels.ChatGoogleAI do
end

def for_api(%Function{} = function) do
encoded = ChatOpenAI.for_api(function)
encoded =
%{
"name" => function.name,
"parameters" => ChatOpenAI.get_parameters(function)
}
|> Utils.conditionally_add_to_map("description", function.description)

# For functions with no parameters, Google AI needs the parameters field removed, otherwise it will error
# with "* GenerateContentRequest.tools[0].function_declarations[0].parameters.properties: should be non-empty for OBJECT type\n"
Expand Down
2 changes: 1 addition & 1 deletion lib/chat_models/chat_mistral_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ defmodule LangChain.ChatModels.ChatMistralAI do
top_p: mistral.top_p,
safe_prompt: mistral.safe_prompt,
stream: mistral.stream,
messages: Enum.map(messages, &ChatOpenAI.for_api/1)
messages: Enum.map(messages, &ChatOpenAI.for_api(mistral, &1))
}
|> Utils.conditionally_add_to_map(:random_seed, mistral.random_seed)
|> Utils.conditionally_add_to_map(:max_tokens, mistral.max_tokens)
Expand Down
2 changes: 1 addition & 1 deletion lib/chat_models/chat_ollama_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ defmodule LangChain.ChatModels.ChatOllamaAI do
%{
model: model.model,
temperature: model.temperature,
messages: messages |> Enum.map(&ChatOpenAI.for_api/1),
messages: messages |> Enum.map(&ChatOpenAI.for_api(model, &1)),
stream: model.stream,
seed: model.seed,
num_ctx: model.num_ctx,
Expand Down
157 changes: 114 additions & 43 deletions lib/chat_models/chat_open_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,26 @@ defmodule LangChain.ChatModels.ChatOpenAI do
`https://some-subdomain.cognitiveservices.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-08-01-preview"`
## Reasoning Model Support
OpenAI made some significant API changes with the introduction of their "reasoning" models. This includes the `o1` and `o1-mini` models.
To enable this mode, set `:reasoning_mode` to `true`:
model = ChatOpenAI.new!(%{reasoning_mode: true})
Setting `reasoning_mode` to `true` does the following two things:
- Sets `:developer` as the `role` for system messages. The OpenAI documentation says API calls to `o1` and newer models must use `role: :developer` instead of `role: :system`, and the API returns an error when the role is not set correctly.
- Includes the `:reasoning_effort` option in LLM requests. This setting is only permitted on a reasoning model. The supported `:reasoning_effort` values are "low", "medium" (default), and "high", as specified in the OpenAI documentation. This instructs the LLM on how much time, and how many tokens, should be spent thinking through and reasoning about the request and the response.
"""
use Ecto.Schema
require Logger
import Ecto.Changeset
alias __MODULE__
alias LangChain.Config
alias LangChain.ChatModels.ChatModel
alias LangChain.PromptTemplate
alias LangChain.Message
alias LangChain.Message.ContentPart
alias LangChain.Message.ToolCall
Expand Down Expand Up @@ -155,6 +168,19 @@ defmodule LangChain.ChatModels.ChatOpenAI do
# their existing frequency in the text so far, decreasing the model's
# likelihood to repeat the same line verbatim.
field :frequency_penalty, :float, default: 0.0

# Used when working with a reasoning model like `o1` and newer. This setting
# is required when working with those models as the API behavior needs to
# change.
field :reasoning_mode, :boolean, default: false

# o1 models only
#
# Constrains effort on reasoning for reasoning models. Currently supported
# values are `low`, `medium`, and `high`. Reducing reasoning effort can result in
# faster responses and fewer tokens used on reasoning in a response.
field :reasoning_effort, :string, default: "medium"

# Duration in seconds for the response to be received. When streaming a very
# lengthy response, a longer time limit may be required. However, when it
# goes on too long by itself, it tends to hallucinate more.
Expand Down Expand Up @@ -198,6 +224,8 @@ defmodule LangChain.ChatModels.ChatOpenAI do
:seed,
:n,
:stream,
:reasoning_mode,
:reasoning_effort,
:receive_timeout,
:json_response,
:json_schema,
Expand Down Expand Up @@ -275,7 +303,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
messages:
messages
|> Enum.reduce([], fn m, acc ->
case for_api(m) do
case for_api(openai, m) do
%{} = data ->
[data | acc]

Expand All @@ -287,22 +315,26 @@ defmodule LangChain.ChatModels.ChatOpenAI do
response_format: set_response_format(openai),
user: openai.user
}
|> Utils.conditionally_add_to_map(
:reasoning_effort,
if(openai.reasoning_mode, do: openai.reasoning_effort, else: nil)
)
|> Utils.conditionally_add_to_map(:max_tokens, openai.max_tokens)
|> Utils.conditionally_add_to_map(:seed, openai.seed)
|> Utils.conditionally_add_to_map(
:stream_options,
get_stream_options_for_api(openai.stream_options)
)
|> Utils.conditionally_add_to_map(:tools, get_tools_for_api(tools))
|> Utils.conditionally_add_to_map(:tools, get_tools_for_api(openai, tools))
|> Utils.conditionally_add_to_map(:tool_choice, get_tool_choice(openai))
end

defp get_tools_for_api(nil), do: []
defp get_tools_for_api(%_{} = _model, nil), do: []

defp get_tools_for_api(tools) do
# Convert the list of LangChain `Function` structs into OpenAI's tool format:
# each function becomes %{"type" => "function", "function" => ...}. The model
# struct is threaded through so `for_api/2` can apply model-specific rules.
# NOTE(review): only %Function{} tools are matched here; other tool shapes
# would raise a FunctionClauseError — presumably intentional, confirm.
defp get_tools_for_api(%_{} = model, tools) do
Enum.map(tools, fn
%Function{} = function ->
%{"type" => "function", "function" => for_api(model, function)}
end)
end
end

Expand Down Expand Up @@ -341,48 +373,44 @@ defmodule LangChain.ChatModels.ChatOpenAI do
defp get_tool_choice(%ChatOpenAI{}), do: nil

@doc """
Convert a LangChain structure to the expected map of data for the OpenAI API.
Convert a LangChain Message-based structure to the expected map of data for
the OpenAI API. This happens within the context of the model configuration as
well. The additional context is needed to correctly convert a role to either
`:system` or `:developer`.
NOTE: The `ChatOpenAI` model's functions are reused in other modules. For this
reason, model is more generally defined as a struct.
"""
@spec for_api(Message.t() | ContentPart.t() | Function.t()) ::
@spec for_api(
struct(),
Message.t()
| PromptTemplate.t()
| ToolCall.t()
| ToolResult.t()
| ContentPart.t()
| Function.t()
) ::
%{String.t() => any()} | [%{String.t() => any()}]
def for_api(%Message{role: :assistant, tool_calls: tool_calls} = msg)
when is_list(tool_calls) do
%{
"role" => :assistant,
"content" => msg.content
}
|> Utils.conditionally_add_to_map("tool_calls", Enum.map(tool_calls, &for_api(&1)))
end

def for_api(%Message{role: :tool, tool_results: tool_results} = _msg)
when is_list(tool_results) do
# ToolResults turn into a list of tool messages for OpenAI
Enum.map(tool_results, fn result ->
%{
"role" => :tool,
"tool_call_id" => result.tool_call_id,
"content" => result.content
}
end)
end
def for_api(%_{} = model, %Message{content: content} = msg) when is_binary(content) do
role = get_message_role(model, msg.role)

def for_api(%Message{content: content} = msg) when is_binary(content) do
%{
"role" => msg.role,
"role" => role,
"content" => msg.content
}
|> Utils.conditionally_add_to_map("name", msg.name)
end

def for_api(%Message{role: :user, content: content} = msg) when is_list(content) do
def for_api(%_{} = model, %Message{role: :user, content: content} = msg)
when is_list(content) do
%{
"role" => msg.role,
"content" => Enum.map(content, &for_api(&1))
"content" => Enum.map(content, &for_api(model, &1))
}
|> Utils.conditionally_add_to_map("name", msg.name)
end

def for_api(%ToolResult{type: :function} = result) do
def for_api(%_{} = _model, %ToolResult{type: :function} = result) do
# a ToolResult becomes a stand-alone %Message{role: :tool} response.
%{
"role" => :tool,
Expand All @@ -391,15 +419,33 @@ defmodule LangChain.ChatModels.ChatOpenAI do
}
end

def for_api(%LangChain.PromptTemplate{} = _template) do
raise LangChain.LangChainError, "PromptTemplates must be converted to messages."
# An assistant message that carries tool calls. The message content is passed
# through as-is (it may be nil for pure tool-call messages — TODO confirm),
# and each ToolCall is converted via for_api/2 before being added under the
# "tool_calls" key by Utils.conditionally_add_to_map/3.
def for_api(%_{} = model, %Message{role: :assistant, tool_calls: tool_calls} = msg)
when is_list(tool_calls) do
%{
"role" => :assistant,
"content" => msg.content
}
|> Utils.conditionally_add_to_map("tool_calls", Enum.map(tool_calls, &for_api(model, &1)))
end

def for_api(%ContentPart{type: :text} = part) do
# A :tool message fans out into a LIST of maps — one per ToolResult — each
# tagged with its originating tool_call_id, which is the shape the OpenAI
# API expects for tool responses.
def for_api(%_{} = _model, %Message{role: :tool, tool_results: tool_results} = _msg)
when is_list(tool_results) do
# ToolResults turn into a list of tool messages for OpenAI
Enum.map(tool_results, fn result ->
%{
"role" => :tool,
"tool_call_id" => result.tool_call_id,
"content" => result.content
}
end)
end

# A text ContentPart maps to OpenAI's %{"type" => "text", "text" => ...} part.
def for_api(%_{} = _model, %ContentPart{type: :text} = part) do
%{"type" => "text", "text" => part.content}
end

def for_api(%ContentPart{type: image} = part) when image in [:image, :image_url] do
def for_api(%_{} = _model, %ContentPart{type: image} = part)
when image in [:image, :image_url] do
media_prefix =
case Keyword.get(part.options || [], :media, nil) do
nil ->
Expand Down Expand Up @@ -437,7 +483,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
end

# ToolCall support
def for_api(%ToolCall{type: :function} = fun) do
def for_api(%_{} = _model, %ToolCall{type: :function} = fun) do
%{
"id" => fun.call_id,
"type" => "function",
Expand All @@ -449,30 +495,41 @@ defmodule LangChain.ChatModels.ChatOpenAI do
end

# Function support
def for_api(%Function{} = fun) do
# Function (tool definition) support: convert a LangChain Function into the
# OpenAI function-declaration map. The "description" key is only included
# when the function actually has a description.
def for_api(%_{} = _model, %Function{} = fun) do
%{
"name" => fun.name,
"parameters" => get_parameters(fun)
}
|> Utils.conditionally_add_to_map("description", fun.description)
end

defp get_parameters(%Function{parameters: [], parameters_schema: nil} = _fun) do
# PromptTemplates cannot be sent to the API directly; they must be rendered
# into Messages first. Raise to surface the programming error early.
def for_api(%_{} = _model, %PromptTemplate{} = _template) do
raise LangChain.LangChainError, "PromptTemplates must be converted to messages."
end

# Build the JSON-schema "parameters" value for a function declaration.
# With no parameters and no explicit schema, emit an empty object schema.
@doc false
def get_parameters(%Function{parameters: [], parameters_schema: nil} = _fun) do
%{
"type" => "object",
"properties" => %{}
}
end

defp get_parameters(%Function{parameters: [], parameters_schema: schema} = _fun)
when is_map(schema) do
# An explicitly provided parameters_schema map is passed through unchanged.
def get_parameters(%Function{parameters: [], parameters_schema: schema} = _fun)
when is_map(schema) do
schema
end

defp get_parameters(%Function{parameters: params} = _fun) do
# Otherwise derive the JSON schema from the list of FunctionParam structs.
def get_parameters(%Function{parameters: params} = _fun) do
FunctionParam.to_parameters_schema(params)
end

# Convert a message role into either `:system` or `:developer` based on the
# model's `reasoning_mode` setting. Reasoning models (`o1` and newer) require
# `role: :developer` in place of `role: :system`. Any non-ChatOpenAI model
# struct passes the role through unchanged.
defp get_message_role(%ChatOpenAI{reasoning_mode: true}, :system), do: :developer
defp get_message_role(%ChatOpenAI{}, role), do: role
defp get_message_role(_model, role), do: role

@doc """
Calls the OpenAI API passing the ChatOpenAI struct with configuration, plus
either a simple message or the list of messages to act as the prompt.
Expand Down Expand Up @@ -889,12 +946,24 @@ defmodule LangChain.ChatModels.ChatOpenAI do
# MS Azure returns numeric error codes. Interpret them when possible to give a computer-friendly reason
#
# https://learn.microsoft.com/en-us/troubleshoot/azure/azure-kubernetes/create-upgrade-delete/429-too-many-requests-errors
def do_process_response(_model, %{"error" => %{"code" => code, "message" => reason}}) do
def do_process_response(_model, %{
"error" => %{"code" => code, "message" => reason} = error_data
}) do
type =
case code do
"429" ->
"rate_limit_exceeded"

"unsupported_value" ->
if String.contains?(reason, "does not support 'system' with this model") do
Logger.error(
"This model requires 'reasoning_mode' to be enabled. Reason: #{inspect(reason)}"
)

# return the API error type as the exception type information
error_data["type"]
end

_other ->
nil
end
Expand Down Expand Up @@ -996,6 +1065,8 @@ defmodule LangChain.ChatModels.ChatOpenAI do
:model,
:temperature,
:frequency_penalty,
:reasoning_mode,
:reasoning_effort,
:receive_timeout,
:seed,
:n,
Expand Down
2 changes: 1 addition & 1 deletion lib/chat_models/chat_vertex_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ defmodule LangChain.ChatModels.ChatVertexAI do
%{
# Google AI functions use an OpenAI compatible format.
# See: https://ai.google.dev/docs/function_calling#how_it_works
"functionDeclarations" => Enum.map(functions, &ChatOpenAI.for_api/1)
"functionDeclarations" => Enum.map(functions, &ChatOpenAI.for_api(vertex_ai, &1))
}
])
else
Expand Down
Loading

0 comments on commit 45708db

Please sign in to comment.