Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for viewer-based credentials for Databricks & Cortex #183

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Imports:
Suggests:
base64enc,
bslib,
connectcreds,
curl (>= 6.0.1),
gitcreds,
knitr,
Expand All @@ -39,7 +40,8 @@ VignetteBuilder:
knitr
Remotes:
r-lib/httr2,
jcheng5/shinychat
jcheng5/shinychat,
posit-dev/connectcreds
Config/Needs/website: tidyverse/tidytemplate, rmarkdown
Config/testthat/edition: 3
Config/testthat/parallel: true
Expand Down
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ export(chat_perplexity)
export(content_image_file)
export(content_image_plot)
export(content_image_url)
export(cortex_credentials)
export(create_tool_def)
export(interpolate)
export(interpolate_file)
Expand Down
76 changes: 51 additions & 25 deletions R/provider-cortex.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ NULL
#' previous messages. Nor does it support registering tools, and attempting to
#' do so will result in an error.
#'
#' `chat_cortex()` picks up the following ambient Snowflake credentials:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it's worth having a standard ## Auth header?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to do that.

#'
#' - A static OAuth token defined via the `SNOWFLAKE_TOKEN` environment
#' variable.
#' - Key-pair authentication credentials defined via the `SNOWFLAKE_USER` and
#' `SNOWFLAKE_PRIVATE_KEY` (which can be a PEM-encoded private key or a path
#' to one) environment variables.
#' - Posit Workbench-managed Snowflake credentials for the corresponding
#' `account`.
#'
#' @param account A Snowflake [account identifier](https://docs.snowflake.com/en/user-guide/admin-account-identifier),
#' e.g. `"testorg-test_account"`.
#' @param credentials A list of authentication headers to pass into
Expand All @@ -32,6 +42,7 @@ NULL
#' @param model_file Path to a semantic model file stored in a Snowflake Stage,
#' or `NULL` when using `model_spec` instead.
#' @inheritParams chat_openai
#' @inheritParams chat_databricks
#' @inherit chat_openai return
#' @family chatbots
#' @examplesIf elmer:::cortex_credentials_exist()
Expand All @@ -41,40 +52,53 @@ NULL
#' chat$chat("What questions can I ask?")
#' @export
chat_cortex <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT"),
credentials = cortex_credentials,
credentials = NULL,
model_spec = NULL,
model_file = NULL,
api_args = list(),
echo = c("none", "text", "all")) {
echo = c("none", "text", "all"),
session = NULL) {
check_string(account, allow_empty = FALSE)
check_string(model_spec, allow_empty = FALSE, allow_null = TRUE)
check_string(model_file, allow_empty = FALSE, allow_null = TRUE)
check_exclusive(model_spec, model_file)
echo <- check_echo(echo)
if (!is.null(session)) {
check_installed("connectcreds", "for viewer-based authentication")
if (!connectcreds::has_viewer_token(session, snowflake_url(account))) {
session <- NULL
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this just null session out with no warning/error?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

has_viewer_token() emits a standard warning. I don't love this design, since I think it's non-obvious for the caller, but it does ensure that the warning is consistent across packages and we can limit how often it is shown in one place.

Does that approach make sense to you? Or should I add some sort of connectcreds::show_ignore_message() API for this and have has_viewer_token() keep quiet?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about something more like session <- connectcreds::update_session_token(...)? (and drop the if)

}
}

if (is_list(credentials)) {
static_credentials <- force(credentials)
credentials <- function(account) static_credentials
}
check_function(credentials)
check_function(credentials, allow_null = TRUE)

provider <- ProviderCortex(
account = account,
credentials = credentials,
model_spec = model_spec,
model_file = model_file,
extra_args = api_args
extra_args = api_args,
session = session
)

Chat$new(provider = provider, turns = NULL, echo = echo)
}

snowflake_url <- function(account) {
paste0("https://", account, ".snowflakecomputing.com")
}

ProviderCortex <- new_class(
"ProviderCortex",
parent = Provider,
constructor = function(account, credentials, model_spec = NULL,
model_file = NULL, extra_args = list()) {
base_url <- paste0("https://", account, ".snowflakecomputing.com")
model_file = NULL, extra_args = list(),
session = NULL) {
base_url <- snowflake_url(account)
extra_args <- compact(list2(
semantic_model = model_spec,
semantic_model_file = model_file,
Expand All @@ -88,8 +112,9 @@ ProviderCortex <- new_class(
},
properties = list(
account = prop_string(),
credentials = class_function,
extra_args = class_list
credentials = class_function | NULL,
extra_args = class_list,
session = class_list | NULL
)
)

Expand All @@ -110,9 +135,12 @@ method(chat_request, ProviderCortex) <- function(provider,

req <- request(provider@base_url)
req <- req_url_path_append(req, "/api/v2/cortex/analyst/message")
req <- httr2::req_headers(req,
!!!provider@credentials(provider@account), .redact = "Authorization"
creds <- cortex_credentials(
provider@account,
provider@credentials,
provider@session
)
req <- httr2::req_headers(req, !!!creds, .redact = "Authorization")
req <- req_retry(req, max_tries = 2)
req <- req_timeout(req, 60)

Expand Down Expand Up @@ -348,21 +376,19 @@ cortex_credentials_exist <- function(...) {
tryCatch(is_list(cortex_credentials(...)), error = function(e) FALSE)
}

#' @details
#' `cortex_credentials()` picks up the following ambient Snowflake credentials:
#'
#' - A static OAuth token defined via the `SNOWFLAKE_TOKEN` environment
#' variable.
#' - Key-pair authentication credentials defined via the `SNOWFLAKE_USER` and
#' `SNOWFLAKE_PRIVATE_KEY` (which can be a PEM-encoded private key or a path
#' to one) environment variables.
#' - Posit Workbench-managed Snowflake credentials for the corresponding
#' `account`.
#'
#' @inheritParams chat_cortex
#' @export
#' @rdname chat_cortex
cortex_credentials <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT")) {
cortex_credentials <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT"),
credentials = NULL,
session = NULL) {
# Session credentials take precedence over static credentials.
if (!is.null(session)) {
return(connectcreds::connect_viewer_token(session, snowflake_url(account)))
}

# User-supplied credentials.
if (!is.null(credentials)) {
return(credentials(account))
}

token <- Sys.getenv("SNOWFLAKE_TOKEN")
if (nchar(token) != 0) {
return(
Expand Down
30 changes: 25 additions & 5 deletions R/provider-databricks.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
#' - `databricks-meta-llama-3-1-405b-instruct`
#' @param token An authentication token for the Databricks workspace, or
#' `NULL` to use ambient credentials.
#' @param session A Shiny session object, when using viewer-based credentials on
#' Posit Connect.
#' @inheritParams chat_openai
#' @inherit chat_openai return
#' @export
Expand All @@ -44,28 +46,39 @@ chat_databricks <- function(workspace = databricks_workspace(),
model = NULL,
token = NULL,
api_args = list(),
echo = c("none", "text", "all")) {
echo = c("none", "text", "all"),
session = NULL) {
check_string(workspace, allow_empty = FALSE)
check_string(token, allow_empty = FALSE, allow_null = TRUE)
model <- set_default(model, "databricks-dbrx-instruct")
turns <- normalize_turns(turns, system_prompt)
echo <- check_echo(echo)
if (!is.null(session)) {
check_installed("connectcreds", "for viewer-based authentication")
if (!connectcreds::has_viewer_token(session, workspace)) {
session <- NULL
}
}
provider <- ProviderDatabricks(
base_url = workspace,
model = model,
extra_args = api_args,
token = token,
# Databricks APIs use bearer tokens, not API keys, but we need to pass an
# empty string here anyway to make S7::validate() happy.
api_key = ""
api_key = "",
session = session
)
Chat$new(provider = provider, turns = turns, echo = echo)
}

ProviderDatabricks <- new_class(
"ProviderDatabricks",
parent = ProviderOpenAI,
properties = list(token = prop_string(allow_null = TRUE))
properties = list(
token = prop_string(allow_null = TRUE),
session = class_list | NULL
)
)

method(chat_request, ProviderDatabricks) <- function(provider,
Expand All @@ -80,7 +93,7 @@ method(chat_request, ProviderDatabricks) <- function(provider,
# `/serving-endpoints/<model>/invocations`.
req <- req_url_path_append(req, "/serving-endpoints/chat/completions")
req <- req_auth_bearer_token(req,
databricks_token(provider@base_url, provider@token)
databricks_token(provider@base_url, provider@token, provider@session)
)
req <- req_retry(req, max_tries = 2)
req <- req_error(req, body = function(resp) {
Expand Down Expand Up @@ -165,9 +178,16 @@ databricks_workspace <- function() {

# Try various ways to get Databricks credentials. This implements a subset of
# the "Databricks client unified authentication" model.
databricks_token <- function(workspace = databricks_workspace(), token = NULL) {
databricks_token <- function(workspace = databricks_workspace(),
token = NULL,
session = NULL) {
host <- gsub("https://|/$", "", workspace)

# Session credentials take precedence over static credentials.
if (!is.null(session)) {
return(connectcreds::connect_viewer_token(session, workspace))
}

# An explicit bearer token takes precedence over everything else.
token <- token %||% Sys.getenv("DATABRICKS_TOKEN")
if (nchar(token)) {
Expand Down
16 changes: 8 additions & 8 deletions man/chat_cortex.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion man/chat_databricks.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions tests/testthat/_snaps/provider-cortex.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,54 @@
[1] "@my_db.my_schema.my_stage/model.yaml"


# the session parameter is ignored when not on Connect

Code
. <- chat_cortex("testorg-test_account", model_file = "model.yaml", session = session)
Message
! Ignoring the `session` parameter.
i Viewer-based credentials are only available when running on Connect.

# missing viewer credentials generate errors on Connect

Code
. <- chat_cortex("testorg-test_account", model_file = "model.yaml", session = session)
Condition
Error in `connectcreds::has_viewer_token()`:
! Cannot fetch viewer-based credentials for the current Shiny session.
Caused by error in `connect_viewer_token()`:
! Viewer-based credentials are not supported by this version of Connect.

# token exchange requests to Connect look correct

Code
list(url = req$url, headers = req$headers, body = req$body$data)
Output
$url
[1] "localhost:3030/__api__/v1/oauth/integrations/credentials"

$headers
$headers$Authorization
[1] "Key key"

$headers$Accept
[1] "application/json"

attr(,"redact")
[1] "Authorization"

$body
$body$grant_type
[1] "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange"

$body$subject_token
[1] "user-token"

$body$subject_token_type
[1] "urn%3Aposit%3Aconnect%3Auser-session-token"

$body$resource
[1] "https%3A%2F%2Ftestorg-test_account.snowflakecomputing.com"



Loading
Loading