Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for viewer-based credentials for the Cortex chatbot #150

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ Suggests:
rmarkdown,
shiny,
shinychat (>= 0.0.0.9000),
snowflakeauth (>= 0.0.0.9000),
testthat (>= 3.0.0),
withr
VignetteBuilder:
knitr
Remotes:
jcheng5/shinychat
jcheng5/shinychat,
atheriel/snowflakeauth
Config/Needs/website: tidyverse/tidytemplate, rmarkdown
Config/testthat/edition: 3
Config/testthat/parallel: true
Expand Down Expand Up @@ -79,4 +81,5 @@ Collate:
'utils-cat.R'
'utils-merge.R'
'utils.R'
'viewer-based-credentials.R'
'zzz.R'
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ export(chat_perplexity)
export(content_image_file)
export(content_image_plot)
export(content_image_url)
export(cortex_credentials)
export(create_tool_def)
export(interpolate)
export(interpolate_file)
Expand Down
204 changes: 72 additions & 132 deletions R/provider-cortex.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,49 +20,61 @@ NULL
#' previous messages. Nor does it support registering tools, and attempting to
#' do so will result in an error.
#'
#' @param account A Snowflake [account identifier](https://docs.snowflake.com/en/user-guide/admin-account-identifier),
#' e.g. `"testorg-test_account"`.
#' @param credentials A list of authentication headers to pass into
#' [`httr2::req_headers()`] or a function that returns them when passed
#' `account` as a parameter. The default [`cortex_credentials()`] function
#' picks up ambient Snowflake OAuth and key-pair authentication credentials
#' and handles refreshing them automatically.
#' By default we pick up on Snowflake connection parameters defined in the same
#' `connections.toml` file used by the [Python Connector for
#' Snowflake](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect)
#' and the [Snowflake
#' CLI](https://docs.snowflake.com/en/developer-guide/snowflake-cli/connecting/configure-connections),
#' though connection parameters can be passed manually to
#' [snowflakeauth::snowflake_connection()], too. Keep in mind that Cortex
#' itself only supports OAuth and key-pair authentication.
#'
#' @param model_spec A semantic model specification, or `NULL` when
#' using `model_file` instead.
#' @param model_file Path to a semantic model file stored in a Snowflake Stage,
#' or `NULL` when using `model_spec` instead.
#' @param ... Further arguments passed to [snowflakeauth::snowflake_connection()].
#' @param session A Shiny session object, when using viewer-based credentials on
#' Posit Connect.
#' @inheritParams chat_openai
#' @inherit chat_openai return
#' @family chatbots
#' @examplesIf elmer:::cortex_credentials_exist()
#' @examplesIf FALSE
#' # Authenticate with Snowflake using an existing connections.toml file.
#' chat <- chat_cortex(
#' model_file = "@my_db.my_schema.my_stage/model.yaml"
#' )
#' chat$chat("What questions can I ask?")
#'
#' # Or pass connection parameters manually. For example, to use key-pair
#' # authentication:
#' chat <- chat_cortex(
#' model_file = "@my_db.my_schema.my_stage/model.yaml",
#' account = "myaccount",
#' user = "me",
#' private_key = "rsa_key.p8"
#' )
#' @export
chat_cortex <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT"),
credentials = cortex_credentials,
model_spec = NULL,
chat_cortex <- function(model_spec = NULL,
model_file = NULL,
api_args = list(),
echo = c("none", "text", "all")) {
check_string(account, allow_empty = FALSE)
echo = c("none", "text", "all"),
...,
session = NULL) {
check_installed("snowflakeauth", "for Snowflake authentication")
check_string(model_spec, allow_empty = FALSE, allow_null = TRUE)
check_string(model_file, allow_empty = FALSE, allow_null = TRUE)
check_exclusive(model_spec, model_file)
echo <- check_echo(echo)
check_shiny_session(session, allow_null = TRUE, call = call)

if (is_list(credentials)) {
static_credentials <- force(credentials)
credentials <- function(account) static_credentials
}
check_function(credentials)
connection <- snowflakeauth::snowflake_connection(..., .call = current_env())

provider <- ProviderCortex(
account = account,
credentials = credentials,
connection = connection,
model_spec = model_spec,
model_file = model_file,
session = session,
extra_args = api_args
)

Expand All @@ -73,27 +85,58 @@ ProviderCortex <- new_class(
"ProviderCortex",
parent = Provider,
package = "elmer",
constructor = function(account, credentials, model_spec = NULL,
model_file = NULL, extra_args = list()) {
base_url <- paste0("https://", account, ".snowflakecomputing.com")
constructor = function(connection,
model_spec = NULL,
model_file = NULL,
session = NULL,
extra_args = list()) {
extra_args <- compact(list2(
semantic_model = model_spec,
semantic_model_file = model_file,
!!!extra_args
))
if (!is.null(session)) {
# If viewer-based authentication is enabled, check whether we can actually
# get credentials. If we can, then make sure the authenticator is OAuth.
access_token <- connect_viewer_token(session, snowflake_url(connection))
if (!is.null(access_token)) {
connection$authenticator <- "oauth"
connection$user <- "placeholder"
connection$token <- access_token
} else {
session <- NULL
}
}
new_object(
Provider(base_url = base_url, extra_args = extra_args),
account = account,
credentials = credentials
Provider(base_url = snowflake_url(connection), extra_args = extra_args),
connection = connection,
session = session
)
},
properties = list(
account = prop_string(),
credentials = class_function,
connection = class_list,
session = class_list | NULL,
credentials = new_property(class_list, getter = function(self) {
if (!is.null(self@session)) {
# TODO: Right now we ask Connect for an up-to-date token before each
# Cortex request. Instead, we should request a new token only when the
# cached one has expired -- but right now there is no way to know when
# this occurs.
self@connection$token <- connect_viewer_token(
self@session,
snowflake_url(self@connection)
)
}
snowflakeauth::snowflake_credentials(self@connection)
}),
extra_args = class_list
)
)

# Builds the base URL for a Snowflake account: the `account` element of
# `connection` becomes the subdomain of "snowflakecomputing.com", served over
# HTTPS.
snowflake_url <- function(connection) {
  host <- paste0(connection$account, ".snowflakecomputing.com")
  paste0("https://", host)
}

# See: https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/reference/cortex-analyst
# https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-analyst/tutorials/tutorial-1#step-3-create-a-streamlit-app-to-talk-to-your-data-through-cortex-analyst
method(chat_request, ProviderCortex) <- function(provider,
Expand All @@ -112,7 +155,7 @@ method(chat_request, ProviderCortex) <- function(provider,
req <- request(provider@base_url)
req <- req_url_path_append(req, "/api/v2/cortex/analyst/message")
req <- httr2::req_headers(req,
!!!provider@credentials(provider@account), .redact = "Authorization"
!!!provider@credentials, .redact = "Authorization"
)
req <- req_retry(req, max_tries = 2)
req <- req_timeout(req, 60)
Expand Down Expand Up @@ -355,106 +398,3 @@ cortex_message_to_turn <- function(message, error_call = caller_env()) {
})
)
}

# Credential handling ----------------------------------------------------------

# Reports whether `cortex_credentials()` yields a list of credentials when
# called with the given arguments. Any error raised while looking them up is
# treated as "no credentials" and results in `FALSE`.
cortex_credentials_exist <- function(...) {
  tryCatch(
    {
      found <- cortex_credentials(...)
      is_list(found)
    },
    error = function(cnd) {
      FALSE
    }
  )
}

#' @details
#' `cortex_credentials()` picks up the following ambient Snowflake credentials,
#' trying each in turn and returning the first one that is available:
#'
#' - A static OAuth token defined via the `SNOWFLAKE_TOKEN` environment
#'   variable.
#' - Key-pair authentication credentials defined via the `SNOWFLAKE_USER` and
#'   `SNOWFLAKE_PRIVATE_KEY` (which can be a PEM-encoded private key or a path
#'   to one) environment variables.
#' - Posit Workbench-managed Snowflake credentials for the corresponding
#'   `account`.
#'
#' @inheritParams chat_cortex
#' @export
#' @rdname chat_cortex
cortex_credentials <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT")) {
  # Every mechanism below ends up with a bearer token plus a header that tells
  # Snowflake which kind of token it is being handed.
  bearer_headers <- function(token, type) {
    list(
      Authorization = paste("Bearer", token),
      `X-Snowflake-Authorization-Token-Type` = type
    )
  }

  # 1. A static OAuth token.
  # See: https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/authentication#using-oauth
  static_token <- Sys.getenv("SNOWFLAKE_TOKEN")
  if (nzchar(static_token)) {
    return(bearer_headers(static_token, "OAUTH"))
  }

  # 2. Snowflake key-pair authentication, which needs both a user and a key.
  # See: https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/authentication#generate-a-jwt-token
  user <- Sys.getenv("SNOWFLAKE_USER")
  private_key <- Sys.getenv("SNOWFLAKE_PRIVATE_KEY")
  if (nzchar(user) && nzchar(private_key)) {
    check_installed("jose", "for key-pair authentication")
    key <- openssl::read_key(private_key)
    # openssl::fingerprint() uses a different algorithm, so the fingerprint is
    # derived by hand: base64(SHA-256(DER-encoded public key)).
    der <- openssl::write_der(key$pubkey)
    fp <- openssl::base64_encode(openssl::sha256(der))
    sub <- toupper(paste0(account, ".", user))
    iss <- paste0(sub, ".SHA256:", fp)
    # Note: Snowflake employs a malformed issuer claim, so a dummy issuer goes
    # through jose's validation and the real one is injected afterwards.
    claim <- httr2::jwt_claim("dummy", sub)
    claim$iss <- iss
    jwt <- httr2::jwt_encode_sig(claim, key)
    return(bearer_headers(jwt, "KEYPAIR_JWT"))
  }

  # 3. Workbench-managed credentials, recognised by the SNOWFLAKE_HOME path.
  sf_home <- Sys.getenv("SNOWFLAKE_HOME")
  workbench_token <- if (grepl("posit-workbench", sf_home, fixed = TRUE)) {
    workbench_snowflake_token(account, sf_home)
  } else {
    NULL
  }
  if (!is.null(workbench_token)) {
    return(bearer_headers(workbench_token, "OAUTH"))
  }

  # Nothing was found: skip when running under testthat, otherwise fail.
  if (is_testing()) {
    testthat::skip("no Snowflake credentials available")
  }

  cli::cli_abort("No Snowflake credentials are available.")
}

# Reads Posit Workbench-managed Snowflake credentials from a
# $SNOWFLAKE_HOME/connections.toml file, as used by the Snowflake Connector for
# Python implementation. The file will look as follows:
#
# [workbench]
# account = "account-id"
# token = "token"
# authenticator = "oauth"
#
# `account` is the Snowflake account identifier the caller wants a token for and
# `sf_home` is the directory containing `connections.toml`. Returns the token
# as a string, or `NULL` when the file is missing, does not mention `account`,
# or does not contain a non-empty `token` entry.
workbench_snowflake_token <- function(account, sf_home) {
  path <- file.path(sf_home, "connections.toml")
  if (!file.exists(path)) {
    # No file means no managed credentials, rather than a low-level read error.
    return(NULL)
  }
  cfg <- readLines(path, warn = FALSE)
  # We don't attempt a full parse of the TOML syntax, instead relying on the
  # fact that this file will always contain only one section.
  if (!any(grepl(account, cfg, fixed = TRUE))) {
    # The configuration doesn't actually apply to this account.
    return(NULL)
  }
  # Only lines whose key is exactly `token` qualify; anchoring avoids matching
  # keys that merely end in "token" (e.g. `refresh_token`).
  key_pattern <- "^[[:space:]]*token[[:space:]]*=[[:space:]]*"
  is_token <- grepl(key_pattern, cfg)
  if (!any(is_token)) {
    return(NULL)
  }
  # Use the first match so the value is always a single string.
  value <- trimws(sub(key_pattern, "", cfg[is_token][[1]]))
  # Drop enclosing quotes.
  token <- gsub("\"", "", value, fixed = TRUE)
  if (!nzchar(token)) {
    return(NULL)
  }
  token
}
74 changes: 74 additions & 0 deletions R/viewer-based-credentials.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Request an OAuth access token for the given resource from Posit Connect. The
# OAuth token will belong to the user owning the given Shiny session.
#
# `session` is the Shiny session whose request headers carry the Connect user
# session token; `resource` is sent as the `resource` field of the token
# exchange request.
#
# Returns the `access_token` field of Connect's JSON response, or `NULL` (after
# informing the user) when not running on Connect. Aborts when the session
# token or the `CONNECT_SERVER` environment variable is unavailable, and
# propagates HTTP errors from the credentials endpoint.
connect_viewer_token <- function(session, resource) {
  if (!running_on_connect()) {
    # NOTE(review): this message is emitted on every call that reaches this
    # branch; consider rate-limiting it if that proves noisy.
    cli::cli_inform(c(
      "!" = "Ignoring the {.arg session} parameter.",
      "i" = "Viewer-based credentials are only available when running on Connect."
    ))
    return(NULL)
  }

  # Older versions or certain configurations of Connect might not supply a user
  # session token.
  server_url <- Sys.getenv("CONNECT_SERVER")
  token <- session$request$HTTP_POSIT_CONNECT_USER_SESSION_TOKEN
  if (is.null(token) || !nzchar(server_url)) {
    cli::cli_abort(
      "Viewer-based credentials are not supported by this version of Connect."
    )
  }

  # See: https://docs.posit.co/connect/api/#post-/v1/oauth/integrations/credentials
  req <- httr2::request(server_url)
  req <- httr2::req_url_path_append(
    req, "__api__/v1/oauth/integrations/credentials"
  )
  req <- httr2::req_headers(req,
    Authorization = paste("Key", Sys.getenv("CONNECT_API_KEY")),
    .redact = "Authorization"
  )
  req <- httr2::req_body_form(
    req,
    grant_type = "urn:ietf:params:oauth:grant-type:token-exchange",
    subject_token_type = "urn:posit:connect:user-session-token",
    subject_token = token,
    resource = resource
  )

  # TODO: Do we need more precise error handling?
  # Error responses are not guaranteed to be JSON (e.g. an HTML page from a
  # proxy). If the body cannot be parsed, fall back to httr2's default message
  # instead of masking the HTTP error with a parsing failure.
  req <- httr2::req_error(
    req,
    body = function(resp) {
      tryCatch(httr2::resp_body_json(resp)$error, error = function(cnd) NULL)
    }
  )

  resp <- httr2::resp_body_json(httr2::req_perform(req))
  resp$access_token
}

# Reports whether the current process runs on Posit Connect, judged by the
# `RSTUDIO_PRODUCT` environment variable being exactly "CONNECT".
running_on_connect <- function() {
  product <- Sys.getenv("RSTUDIO_PRODUCT")
  identical(product, "CONNECT")
}

# Validates that `x` is a Shiny session object, i.e. inherits from
# "ShinySession". With `allow_null = TRUE`, `NULL` is accepted as well. On
# success returns `NULL` invisibly; otherwise (including when `x` is missing)
# the input-type error is raised via `stop_input_type()`, reported against
# `arg` and `call`.
check_shiny_session <- function(x,
                                ...,
                                allow_null = FALSE,
                                arg = caller_arg(x),
                                call = caller_env()) {
  acceptable <- !missing(x) &&
    (inherits(x, "ShinySession") || (allow_null && is_null(x)))

  if (!acceptable) {
    stop_input_type(
      x,
      "a Shiny session object",
      ...,
      allow_null = allow_null,
      arg = arg,
      call = call
    )
  }

  invisible(NULL)
}
Loading
Loading