Skip to content

Commit

Permalink
Revert "Refactor external auth providers to re-generate headers on de…
Browse files Browse the repository at this point in the history
…mand (#6687)"

This reverts commit 03c93f9.
  • Loading branch information
umpox committed Jan 24, 2025
1 parent 4c8dd28 commit b001a7c
Show file tree
Hide file tree
Showing 20 changed files with 127 additions and 267 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ sealed class AuthenticationError {
"invalid-access-token" -> context.deserialize<InvalidAccessTokenError>(element, InvalidAccessTokenError::class.java)
"enterprise-user-logged-into-dotcom" -> context.deserialize<EnterpriseUserDotComError>(element, EnterpriseUserDotComError::class.java)
"auth-config-error" -> context.deserialize<AuthConfigError>(element, AuthConfigError::class.java)
"external-auth-provider-error" -> context.deserialize<ExternalAuthProviderError>(element, ExternalAuthProviderError::class.java)
else -> throw Exception("Unknown discriminator ${element}")
}
}
Expand Down Expand Up @@ -53,22 +52,13 @@ data class EnterpriseUserDotComError(
}

data class AuthConfigError(
val type: TypeEnum, // Oneof: auth-config-error
val title: String? = null,
val message: String,
val type: TypeEnum, // Oneof: auth-config-error
) : AuthenticationError() {

enum class TypeEnum {
@SerializedName("auth-config-error") `Auth-config-error`,
}
}

data class ExternalAuthProviderError(
val type: TypeEnum, // Oneof: external-auth-provider-error
val message: String,
) : AuthenticationError() {

enum class TypeEnum {
@SerializedName("external-auth-provider-error") `External-auth-provider-error`,
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ object Constants {
const val `enterprise-user-logged-into-dotcom` = "enterprise-user-logged-into-dotcom"
const val error = "error"
const val experimental = "experimental"
const val `external-auth-provider-error` = "external-auth-provider-error"
const val file = "file"
const val free = "free"
const val function = "function"
Expand Down
2 changes: 1 addition & 1 deletion agent/scripts/simple-external-auth-provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime

def generate_credentials():
current_epoch = int(time.time()) + 30
current_epoch = int(time.time()) + 100

credentials = {
"headers": {
Expand Down
19 changes: 2 additions & 17 deletions lib/shared/src/auth/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,15 @@ export interface EnterpriseUserDotComError {
enterprise: string
}

export interface AuthConfigError {
export interface AuthConfigError extends AuthenticationErrorMessage {
type: 'auth-config-error'
message: string
}

export interface ExternalAuthProviderError {
type: 'external-auth-provider-error'
message: string
}

export type AuthenticationError =
| NetworkAuthError
| InvalidAccessTokenError
| EnterpriseUserDotComError
| AuthConfigError
| ExternalAuthProviderError

export interface AuthenticationErrorMessage {
title?: string
Expand Down Expand Up @@ -106,15 +99,7 @@ export function getAuthErrorMessage(error: AuthenticationError): AuthenticationE
'please contact your Sourcegraph admin.',
}
case 'auth-config-error':
return {
title: 'Auth Config Error',
message: error.message,
}
case 'external-auth-provider-error':
return {
title: 'External Auth Provider Error',
message: error.message,
}
return error
}
}

Expand Down
3 changes: 2 additions & 1 deletion lib/shared/src/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ export interface AuthCredentials {

export interface HeaderCredential {
// We use function instead of property to prevent accidential top level serialization - we never want to store this data
getHeaders(): Promise<Record<string, string>>
getHeaders(): Record<string, string>
expiration: number | undefined
}

export interface TokenCredential {
Expand Down
13 changes: 6 additions & 7 deletions lib/shared/src/configuration/auth-resolver.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ describe('auth-resolver', () => {
expect(auth.serverEndpoint).toBe('https://my-server.com/')

const headerCredential = auth.credentials as HeaderCredential
expect(await headerCredential.getHeaders()).toStrictEqual({
expect(headerCredential.expiration).toBe(futureEpoch)
expect(headerCredential.getHeaders()).toStrictEqual({
Authorization: 'token X',
})

Expand Down Expand Up @@ -122,9 +123,8 @@ describe('auth-resolver', () => {

expect(auth.serverEndpoint).toBe('https://my-server.com/')

const headerCredential = auth.credentials as HeaderCredential
expect(headerCredential.getHeaders).toBeInstanceOf(Function)
expect(headerCredential.getHeaders()).rejects.toThrowError('Unexpected token')
expect(auth.credentials).toBe(undefined)
expect(auth.error.message).toContain('Failed to execute external auth command: Unexpected token')
})

test('resolve custom auth provider error handling - bad expiration', async () => {
Expand Down Expand Up @@ -158,9 +158,8 @@ describe('auth-resolver', () => {

expect(auth.serverEndpoint).toBe('https://my-server.com/')

const headerCredential = auth.credentials as HeaderCredential
expect(headerCredential.getHeaders).toBeInstanceOf(Function)
expect(headerCredential.getHeaders()).rejects.toThrowError(
expect(auth.credentials).toBe(undefined)
expect(auth.error.message).toContain(
'Credentials expiration cannot be set to a date in the past'
)
})
Expand Down
142 changes: 51 additions & 91 deletions lib/shared/src/configuration/auth-resolver.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
import { Subject } from 'observable-fns'
import type {
AuthCredentials,
ClientConfiguration,
ExternalAuthCommand,
ExternalAuthProvider,
} from '../configuration'
import { logError } from '../logger'
import { ExternalProviderAuthError } from '../sourcegraph-api/errors'
import type { ClientSecrets } from './resolver'

export const externalAuthRefresh = new Subject<void>()

export function normalizeServerEndpointURL(url: string): string {
return url.endsWith('/') ? url : `${url}/`
}
Expand Down Expand Up @@ -42,88 +37,20 @@ interface HeaderCredentialResult {
expiration?: number | undefined
}

let _headersCache: Promise<HeaderCredentialResult> | undefined = undefined

function hasExpired(expiration: number | undefined): boolean {
return expiration !== undefined && expiration * 1000 < Date.now()
}

async function getExternalProviderHeaders(
externalProvider: ExternalAuthProvider
): Promise<HeaderCredentialResult> {
const result = await executeCommand(externalProvider.executable).catch(error => {
throw new Error(`Failed to execute external auth command: ${error.message || error}`)
})

const credentials = JSON.parse(result) as HeaderCredentialResult

if (!credentials?.headers) {
throw new Error(`Output of the external auth command is invalid: ${result}`)
}

if (hasExpired(credentials.expiration)) {
throw new Error(
'Credentials expiration cannot be set to a date in the past: ' +
`${new Date(credentials.expiration! * 1000)} (${credentials.expiration})`
)
}

return credentials
}

async function createTokenCredentials(
clientSecrets: ClientSecrets,
serverEndpoint: string
): Promise<AuthCredentials> {
const token = await clientSecrets.getToken(serverEndpoint).catch(error => {
throw new Error(
`Failed to get access token for endpoint ${serverEndpoint}: ${error.message || error}`
)
})

return {
credentials: token
? { token, source: await clientSecrets.getTokenSource(serverEndpoint) }
: undefined,
serverEndpoint,
async function getExternalProviderAuthResult(
serverEndpoint: string,
authExternalProviders: readonly ExternalAuthProvider[]
): Promise<HeaderCredentialResult | undefined> {
const externalProvider = authExternalProviders.find(
provider => normalizeServerEndpointURL(provider.endpoint) === serverEndpoint
)

if (externalProvider) {
const result = await executeCommand(externalProvider.executable)
return JSON.parse(result)
}
}

function createHeaderCredentials(
externalProvider: ExternalAuthProvider,
serverEndpoint: string
): AuthCredentials {
// Needed in case of account switch so we reset the cache.
// We could also set it to undefined but there is no harm in pre-loading the cache.
_headersCache = getExternalProviderHeaders(externalProvider)

return {
credentials: {
async getHeaders() {
try {
while (true) {
let observed = _headersCache
if (!observed || hasExpired((await observed)?.expiration)) {
if (observed !== _headersCache) {
continue // cache already changed, retry
}
observed = _headersCache = getExternalProviderHeaders(externalProvider)
}
return (await observed).headers
}
} catch (error) {
_headersCache = undefined
externalAuthRefresh.next()

logError('resolveAuth', `External Auth Provider Error: ${error}`)
throw new ExternalProviderAuthError(
error instanceof Error ? error.message : String(error)
)
}
},
},
serverEndpoint,
}
return undefined
}

export async function resolveAuth(
Expand All @@ -142,13 +69,46 @@ export async function resolveAuth(
return { credentials: { token: overrideAuthToken }, serverEndpoint }
}

const externalProvider = authExternalProviders.find(
provider => normalizeServerEndpointURL(provider.endpoint) === serverEndpoint
)
const credentials = await getExternalProviderAuthResult(
serverEndpoint,
authExternalProviders
).catch(error => {
throw new Error(`Failed to execute external auth command: ${error.message || error}`)
})

if (credentials) {
if (credentials?.expiration) {
const expirationMs = credentials?.expiration * 1000
if (expirationMs < Date.now()) {
throw new Error(
'Credentials expiration cannot be set to a date in the past: ' +
`${new Date(expirationMs)} (${credentials.expiration})`
)
}
}
return {
credentials: {
expiration: credentials?.expiration,
getHeaders() {
return credentials.headers
},
},
serverEndpoint,
}
}

const token = await clientSecrets.getToken(serverEndpoint).catch(error => {
throw new Error(
`Failed to get access token for endpoint ${serverEndpoint}: ${error.message || error}`
)
})

return externalProvider
? createHeaderCredentials(externalProvider, serverEndpoint)
: createTokenCredentials(clientSecrets, serverEndpoint)
return {
credentials: token
? { token, source: await clientSecrets.getTokenSource(serverEndpoint) }
: undefined,
serverEndpoint,
}
} catch (error) {
return {
credentials: undefined,
Expand Down
15 changes: 12 additions & 3 deletions lib/shared/src/configuration/resolver.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import { Observable, map } from 'observable-fns'
import { Observable, Subject, map } from 'observable-fns'
import type { AuthCredentials, ClientConfiguration, TokenSource } from '../configuration'
import { logError } from '../logger'
import {
combineLatest,
distinctUntilChanged,
firstValueFrom,
fromLateSetSource,
promiseToObservable,
startWith,
} from '../misc/observable'
import { skipPendingOperation, switchMapReplayOperation } from '../misc/observableOperation'
import type { DefaultsAndUserPreferencesByEndpoint } from '../models/modelsService'
Expand Down Expand Up @@ -92,6 +94,11 @@ async function resolveConfiguration({

try {
const auth = await resolveAuth(serverEndpoint, clientConfiguration, clientSecrets)
const cred = auth.credentials
if (cred !== undefined && 'expiration' in cred && cred.expiration !== undefined) {
const expireInMs = cred.expiration * 1000 - Date.now()
setInterval(() => _refreshConfigRequests.next(), expireInMs)
}
return { configuration: clientConfiguration, clientState, auth, isReinstall }
} catch (error) {
// We don't want to throw here, because that would cause the observable to terminate and
Expand All @@ -108,14 +115,16 @@ async function resolveConfiguration({

const _resolvedConfig = fromLateSetSource<ResolvedConfiguration>()

const _refreshConfigRequests = new Subject<void>()

/**
* Set the observable that will be used to provide the global {@link resolvedConfig}. This should be
* set exactly once (except in tests).
*/
export function setResolvedConfigurationObservable(input: Observable<ConfigurationInput>): void {
_resolvedConfig.setSource(
input.pipe(
switchMapReplayOperation(input => promiseToObservable(resolveConfiguration(input))),
combineLatest(input, _refreshConfigRequests.pipe(startWith(undefined))).pipe(
switchMapReplayOperation(([input]) => promiseToObservable(resolveConfiguration(input))),
skipPendingOperation(),
map(value => {
if (isError(value)) {
Expand Down
13 changes: 2 additions & 11 deletions lib/shared/src/sourcegraph-api/completions/browserClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,9 @@ export class SourcegraphBrowserCompletionsClient extends SourcegraphCompletionsC
...requestParams.customHeaders,
} as HeadersInit)
addCodyClientIdentificationHeaders(headersInstance)
addAuthHeaders(config.auth, headersInstance, url)
headersInstance.set('Content-Type', 'application/json; charset=utf-8')

try {
await addAuthHeaders(config.auth, headersInstance, url)
} catch (error: any) {
cb.onError(error.message)
abort.abort()
console.error(error)
return
}

const parameters = new URLSearchParams(globalThis.location.search)
const trace = parameters.get('trace')
if (trace) {
Expand Down Expand Up @@ -140,13 +132,12 @@ export class SourcegraphBrowserCompletionsClient extends SourcegraphCompletionsC
...requestParams.customHeaders,
})
addCodyClientIdentificationHeaders(headersInstance)
addAuthHeaders(auth, headersInstance, url)

if (new URLSearchParams(globalThis.location.search).get('trace')) {
headersInstance.set('X-Sourcegraph-Should-Trace', 'true')
}
try {
await addAuthHeaders(auth, headersInstance, url)

const response = await fetch(url.toString(), {
method: 'POST',
headers: headersInstance,
Expand Down
9 changes: 0 additions & 9 deletions lib/shared/src/sourcegraph-api/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,3 @@ export function isNetworkLikeError(error: Error): boolean {
message.includes('SELF_SIGNED_CERT_IN_CHAIN')
)
}

export class ExternalProviderAuthError extends Error {
// Added to make TypeScript understand that ExternalProviderAuthError is not the same as Error.
public readonly isExternalProviderAuthError = true
}

export function isExternalProviderAuthError(error: unknown): error is ExternalProviderAuthError {
return error instanceof ExternalProviderAuthError
}
Loading

0 comments on commit b001a7c

Please sign in to comment.