Skip to content

Commit

Permalink
updating OpenAI Assistant stream to allow responses with files
Browse files Browse the repository at this point in the history
  • Loading branch information
OvidijusParsiunas committed Apr 12, 2024
1 parent b120a8c commit fae09a3
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 3 deletions.
9 changes: 8 additions & 1 deletion component/src/services/openAI/openAIAssistantIO.ts
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,14 @@ export class OpenAIAssistantIO extends DirectServiceIO {
return {makingAnotherRequest: true};
}

// prettier-ignore
private async parseStreamResult(result: OpenAIAssistantInitReqResult) {
if (result.content && result.content.length > 0 && this.messages) {
const downloadCb = OpenAIAssistantUtils.getFilesAndText.bind(this,
this, {role: 'assistant', content: result.content}, result.content[0]);
this.messageStream?.endStreamAfterFileDownloaded(this.messages, downloadCb);
return {text: ''};
}
if (result.delta?.content) {
if (!this.streamedMessageId) {
this.streamedMessageId = result.id;
Expand All @@ -254,7 +261,7 @@ export class OpenAIAssistantIO extends DirectServiceIO {
this.messageStream?.newMessage();
}
if (result.delta.content.length > 1) {
const messages = await OpenAIAssistantUtils.processSteamMessages(this, result.delta.content);
const messages = await OpenAIAssistantUtils.processStreamMessages(this, result.delta.content);
return {text: messages[0].text, files: messages[1].files};
}
return {text: result.delta.content[0].text?.value};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ export class OpenAIAssistantUtils {
return Promise.all(parsedContent);
}

public static async processSteamMessages(io: DirectServiceIO, content: OpenAIAssistantContent[]) {
public static async processStreamMessages(io: DirectServiceIO, content: OpenAIAssistantContent[]) {
return OpenAIAssistantUtils.parseMessages(io, [{content, role: 'assistant'}]);
}

Expand Down
1 change: 1 addition & 0 deletions component/src/types/openAIResult.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export type OpenAIAssistantInitReqResult = OpenAIRunResult & {
};
// this is used exclusively for streams
file_ids?: string[];
content?: OpenAIAssistantContent[];
};

export interface OpenAINewAssistantResult {
Expand Down
16 changes: 15 additions & 1 deletion component/src/views/chat/messages/stream/messageStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ import {ErrorMessages} from '../../../../utils/errorMessages/errorMessages';
import {ElementUtils} from '../../../../utils/element/elementUtils';
import {MessageContentI} from '../../../../types/messagesInternal';
import {TextToSpeech} from '../textToSpeech/textToSpeech';
import {MessageFile} from '../../../../types/messageFile';
import {MessageElements, Messages} from '../messages';
import {Response} from '../../../../types/response';
import {HTMLMessages} from '../html/htmlMessages';
import {MessageUtils} from '../messageUtils';
import {MessagesBase} from '../messagesBase';
import {HTMLUtils} from '../html/htmlUtils';
import {MessageElements} from '../messages';

export class MessageStream {
static readonly MESSAGE_CLASS = 'streamed-message';
Expand All @@ -19,6 +20,7 @@ export class MessageStream {
private _activeMessageRole?: string;
private _message?: MessageContentI;
private readonly _messages: MessagesBase;
private _endStreamAfterOperation?: boolean;
private static readonly HTML_CONTENT_PLACEHOLDER = 'htmlplaceholder'; // used for extracting at end and for isStreaming

constructor(messages: MessagesBase) {
Expand Down Expand Up @@ -85,6 +87,7 @@ export class MessageStream {

public finaliseStreamedMessage() {
const {textElementsToText} = this._messages;
if (this._endStreamAfterOperation) return;
if (this._fileAdded && !this._elements) return;
if (!this._elements) throw Error(ErrorMessages.NO_VALID_STREAM_EVENTS_SENT);
if (!this._elements.bubbleElement?.classList.contains(MessageStream.MESSAGE_CLASS)) return;
Expand Down Expand Up @@ -117,4 +120,15 @@ export class MessageStream {
this._hasStreamEnded = false;
this._activeMessageRole = undefined;
}

// prettier-ignore
public async endStreamAfterFileDownloaded(
messages: Messages, downloadCb: () => Promise<{files?: MessageFile[]; text?: string}>) {
this._endStreamAfterOperation = true;
const {text, files} = await downloadCb();
if (text) this.updateBasedOnType(text, 'text', this._elements?.bubbleElement as HTMLElement, true);
this._endStreamAfterOperation = false;
this.finaliseStreamedMessage();
if (files) messages.addNewMessage({files}); // adding later to trigger event later
}
}

0 comments on commit fae09a3

Please sign in to comment.