diff --git a/src/cloudflare/internal/ai-api.ts b/src/cloudflare/internal/ai-api.ts index 7077e6765e8..cd7bd96059e 100644 --- a/src/cloudflare/internal/ai-api.ts +++ b/src/cloudflare/internal/ai-api.ts @@ -13,6 +13,7 @@ interface AiError { message: string; name: string; description: string; + errors?: Array<{ code: number; message: string }>; } export type SessionOptions = { @@ -96,10 +97,12 @@ export class Ai { }, }; - const res = await this.fetcher.fetch( - 'https://workers-binding.ai/run?version=3', - fetchOptions - ); + let endpointUrl = 'https://workers-binding.ai/run?version=3'; + if (options.gateway?.id) { + endpointUrl = 'https://workers-binding.ai/ai-gateway/run?version=3'; + } + + const res = await this.fetcher.fetch(endpointUrl, fetchOptions); this.lastRequestId = res.headers.get('cf-ai-req-id'); this.aiGatewayLogId = res.headers.get('cf-aig-log-id'); @@ -138,11 +141,23 @@ export class Ai { try { const parsedContent = JSON.parse(content) as AiError; - this.lastRequestInternalStatusCode = parsedContent.internalCode; - return new InferenceUpstreamError( - `${parsedContent.internalCode}: ${parsedContent.description}`, - parsedContent.name - ); + if (parsedContent.internalCode) { + this.lastRequestInternalStatusCode = parsedContent.internalCode; + return new InferenceUpstreamError( + `${parsedContent.internalCode}: ${parsedContent.description}`, + parsedContent.name + ); + } else if ( + parsedContent.errors && + parsedContent.errors.length > 0 && + parsedContent.errors[0] + ) { + return new InferenceUpstreamError( + `${parsedContent.errors[0].code}: ${parsedContent.errors[0].message}` + ); + } else { + return new InferenceUpstreamError(content); + } } catch { return new InferenceUpstreamError(content); } diff --git a/src/cloudflare/internal/test/ai/ai-api-test.js b/src/cloudflare/internal/test/ai/ai-api-test.js index 49956a4521e..80f791814cd 100644 --- a/src/cloudflare/internal/test/ai/ai-api-test.js +++ b/src/cloudflare/internal/test/ai/ai-api-test.js @@ -91,7 +91,11 @@ export const tests = { // Test raw input const resp = await env.ai.run('rawInputs', { prompt: 'test' }); - assert.deepStrictEqual(resp, { inputs: { prompt: 'test' }, options: {} }); + assert.deepStrictEqual(resp, { + inputs: { prompt: 'test' }, + options: {}, + requestUrl: 'https://workers-binding.ai/run?version=3', + }); } { @@ -105,6 +109,7 @@ export const tests = { assert.deepStrictEqual(resp, { inputs: { prompt: 'test' }, options: { gateway: { id: 'my-gateway', skipCache: true } }, + requestUrl: 'https://workers-binding.ai/ai-gateway/run?version=3', }); } @@ -126,6 +131,7 @@ export const tests = { example: 123, gateway: { id: 'my-gateway', metadata: { employee: 1233 } }, }, + requestUrl: 'https://workers-binding.ai/ai-gateway/run?version=3', }); } }, diff --git a/src/cloudflare/internal/test/ai/ai-mock.js b/src/cloudflare/internal/test/ai/ai-mock.js index da07f77e752..c85c6f23017 100644 --- a/src/cloudflare/internal/test/ai/ai-mock.js +++ b/src/cloudflare/internal/test/ai/ai-mock.js @@ -22,9 +22,15 @@ export default { } if (modelName === 'rawInputs') { - return Response.json(data, { - headers: respHeaders, - }); + return Response.json( + { + ...data, + requestUrl: request.url, + }, + { + headers: respHeaders, + } + ); } if (modelName === 'inputErrorModel') {