[WIP] Re-adding image support #6902

Draft · wants to merge 1 commit into main
9 changes: 5 additions & 4 deletions lib/shared/src/chat/chat.ts
@@ -57,12 +57,13 @@ export class ChatClient {

// We only want to send up the speaker and prompt text, regardless of whatever other fields
// might be on the messages objects (`file`, `displayText`, `contextFiles`, etc.).
const messagesToSend = augmentedMessages.map(({ speaker, text }) => ({
const messagesToSend = augmentedMessages.map<Message>(({ speaker, text, content }) => ({
text,
speaker,
content,
}))

const completionParams = {
const completionParams: CompletionParameters = {
...DEFAULT_CHAT_COMPLETION_PARAMETERS,
...params,
messages: messagesToSend,
@@ -107,8 +108,8 @@ export function sanitizeMessages(messages: Message[]): Message[] {
// the next one
const nextMessage = sanitizedMessages[index + 1]
if (
(nextMessage.speaker === 'assistant' && !nextMessage.text?.length) ||
(message.speaker === 'assistant' && !message.text?.length)
(nextMessage.speaker === 'assistant' && !nextMessage.text?.length && !nextMessage.content) ||
(message.speaker === 'assistant' && !message.text?.length && !message.content)
) {
return false
}
1 change: 1 addition & 0 deletions lib/shared/src/chat/transcript/messages.ts
@@ -22,6 +22,7 @@ export interface SubMessage {

export interface ChatMessage extends Message {
contextFiles?: ContextItem[]
base64Image?: string

contextAlternatives?: RankedContext[]

18 changes: 10 additions & 8 deletions lib/shared/src/sourcegraph-api/completions/types.ts
@@ -5,27 +5,28 @@ interface DoneEvent {
type: 'done'
}

interface CompletionEvent extends CompletionResponse {
type: 'completion'
}
// interface CompletionEvent extends CompletionResponse {
// type: 'completion'
// }

interface ErrorEvent {
type: 'error'
error: string
}

export type Event = DoneEvent | CompletionEvent | ErrorEvent
export type Event = DoneEvent | ErrorEvent

export interface Message {
// Note: The unified API only supports one system message passed as the first message
speaker: 'human' | 'assistant' | 'system'
text?: PromptString
content?: string | MessagePart[]
base64Image?: string
}

export interface CompletionResponse {
completion: string
stopReason?: string
}
type MessagePart =
| { type: 'text'; text: string } // a normal text message
| { type: 'image_url'; image_url: { url: string } } // image message, per https://platform.openai.com/docs/guides/vision

export interface CompletionParameters {
fast?: boolean
@@ -45,6 +46,7 @@ export interface CompletionParameters {
type: 'content'
content: string
}
base64Image?: string
}

export interface SerializedCompletionParameters extends Omit<CompletionParameters, 'messages'> {
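The new content field follows the OpenAI vision-style message shape linked in the comment above: plain text parts and image_url parts carrying a data URL. A rough sketch of a multimodal human turn built from these types (illustrative only; the import path and payload are assumptions, not taken from this diff):

// Sketch: a human message that pairs text with an inline image, using the new
// `content` field instead of `text`. Message and MessagePart are the types
// declared above; the relative import path is assumed for illustration.
import type { Message } from './types'

const base64Jpeg = '<base64-encoded JPEG bytes>' // placeholder payload

const multimodalTurn: Message = {
    speaker: 'human',
    content: [
        { type: 'text', text: 'What does this screenshot show?' },
        { type: 'image_url', image_url: { url: `data:image/jpeg;base64,${base64Jpeg}` } },
    ],
}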
27 changes: 26 additions & 1 deletion vscode/src/chat/chat-view/ChatBuilder.ts
@@ -157,7 +157,7 @@ export class ChatBuilder {
if (this.messages.at(-1)?.speaker === 'human') {
throw new Error('Cannot add a user message after a user message')
}
this.messages.push({ ...message, speaker: 'human' })
this.messages.push({ ...message, speaker: 'human', base64Image: this.getAndResetImage() })
this.changeNotifications.next()
}

@@ -322,6 +322,31 @@
}
return result
}

/**
* Store the base64-encoded image uploaded by user to a multi-modal model.
* Requires vision support in the model, added in the PR
* https://github.com/sourcegraph/sourcegraph/pull/546
*/
private image: string | undefined = undefined

/**
* Sets the base64-encoded image for the chat model.
* @param base64Image - The base64-encoded image data to set.
*/
public setImage(base64Image: string): void {
this.image = base64Image
}

/**
* Gets the base64-encoded image for the chat model and resets the internal image property to undefined.
* @returns The base64-encoded image, or undefined if no image has been set.
*/
public getAndResetImage(): string | undefined {
const image = this.image
this.image = undefined
return image
}
}

function messageToSerializedChatInteraction(
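The setImage/getAndResetImage pair implements a set-then-consume pattern: an uploaded image is stored once and attached to at most one subsequent human message. A small usage sketch (illustrative only, not part of the diff):

// Sketch of the set-then-consume behaviour added above.
declare const chatBuilder: ChatBuilder

chatBuilder.setImage('aGVsbG8=') // controller stores the uploaded image
chatBuilder.getAndResetImage()   // => 'aGVsbG8=' (and clears the stored value)
chatBuilder.getAndResetImage()   // => undefined (already consumed)

In practice the consuming call happens inside addHumanMessage (see the first hunk), so the image rides along on the next human turn and is not re-sent afterwards.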
5 changes: 5 additions & 0 deletions vscode/src/chat/chat-view/ChatController.ts
@@ -513,6 +513,11 @@ export class ChatController implements vscode.Disposable, vscode.WebviewViewProv
}
break
}

case 'chat/upload-file': {
this.chatBuilder.setImage(message.base64)
break
}
case 'log': {
const logger = message.level === 'debug' ? logDebug : logError
logger(message.filterLabel, message.message)
3 changes: 3 additions & 0 deletions vscode/src/chat/chat-view/prompt.ts
@@ -89,6 +89,9 @@ export class DefaultPrompter {
`Ignored ${messagesIgnored} chat messages due to context limit`
)
}
for (const message of reverseTranscript) {
promptBuilder.tryAddImage(message.base64Image)
}
// Counter for context items categorized by source
const ignoredContext = { user: 0, corpus: 0, transcript: 0 }

2 changes: 2 additions & 0 deletions vscode/src/chat/protocol.ts
@@ -156,6 +156,8 @@ export type WebviewMessage =
selectedFilters: NLSSearchDynamicFilter[]
}
| { command: 'action/confirmation'; id: string; response: boolean }
| { command: 'log'; level: 'debug' | 'error'; filterLabel: string; message: string }
| { command: 'chat/upload-file'; base64: string }

export interface SmartApplyResult {
taskId: FixupTaskID
8 changes: 6 additions & 2 deletions vscode/src/completions/nodeClient.ts
@@ -9,7 +9,6 @@ import {
type CompletionCallbacks,
type CompletionParameters,
type CompletionRequestParameters,
type CompletionResponse,
NetworkError,
RateLimitError,
SourcegraphCompletionsClient,
@@ -21,6 +20,7 @@
getTraceparentHeaders,
globalAgentRef,
isError,
logDebug,
logError,
onAbort,
parseEvents,
@@ -38,6 +38,10 @@ export class SourcegraphNodeCompletionsClient extends SourcegraphCompletionsClie
signal?: AbortSignal
): Promise<void> {
const { apiVersion, interactionId } = requestParams
for (const message of params.messages) {
logDebug('apiVersion', JSON.stringify(apiVersion, null, 2))
logDebug('base64Image', JSON.stringify(message, null, 2))
}

const url = new URL(await this.completionsEndpoint())
if (apiVersion >= 1) {
@@ -326,7 +330,7 @@
getActiveTraceAndSpanId()?.traceId
)
}
const json = (await response.json()) as CompletionResponse
const json = await response.json()
if (typeof json?.completion === 'string') {
cb.onChange(json.completion)
cb.onComplete()
27 changes: 26 additions & 1 deletion vscode/src/prompt-builder/index.ts
@@ -33,6 +33,7 @@ export class PromptBuilder {
* A list of context items that are used to build context messages.
*/
public contextItems: ContextItem[] = []
public images: string[] = []

/**
* Convenience constructor because loading the tokenizer is async due to its large size.
@@ -47,10 +48,28 @@
if (this.contextItems.length > 0) {
this.buildContextMessages()
}

this.buildImageMessages()
return this.prefixMessages.concat([...this.reverseMessages].reverse())
}

private buildImageMessages(): void {
for (const image of this.images) {
const imageMessage: Message = {
speaker: 'human',
content: [
{
type: 'image_url',
image_url: {
// TODO: Handle PNG/JPEG, don't hardcode to JPEG
url: `data:image/jpeg;base64,${image}`,
},
},
],
}
this.reverseMessages.push(...[ASSISTANT_MESSAGE, imageMessage])
}
}

private buildContextMessages(): void {
for (const item of this.contextItems) {
// Create context messages for each context item, where
@@ -108,6 +127,12 @@
return undefined
}

public tryAddImage(base64Image: string | undefined): void {
if (base64Image) {
this.images.push(base64Image)
}
}

public async tryAddContext(
type: ContextTokenUsageType | 'history',
contextItems: ContextItem[]
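On the TODO above about the hardcoded JPEG data URL: one possible direction (a sketch only, assuming the MIME type were threaded through next to the base64 payload, which this diff does not yet do) would be a small helper like:

// Hypothetical helper: build the data URL from a base64 payload plus a MIME type,
// falling back to JPEG to match the current hardcoded behaviour.
const SUPPORTED_IMAGE_MIME_TYPES = new Set(['image/png', 'image/jpeg', 'image/webp'])

function imageDataUrl(base64: string, mimeType?: string): string {
    const type = mimeType && SUPPORTED_IMAGE_MIME_TYPES.has(mimeType) ? mimeType : 'image/jpeg'
    return `data:${type};base64,${base64}`
}

buildImageMessages could then call imageDataUrl(image, mimeType) instead of interpolating the JPEG prefix directly.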
@@ -31,6 +31,7 @@ import {
import type { UserAccountInfo } from '../../../../../Chat'
import { type ClientActionListener, useClientActionListener } from '../../../../../client/clientState'
import { promptModeToIntent } from '../../../../../prompts/PromptsTab'
import { getVSCodeAPI } from '../../../../../utils/VSCodeApi'
import { useTelemetryRecorder } from '../../../../../utils/telemetry'
import { useFeatureFlag } from '../../../../../utils/useFeatureFlags'
import { useLinkOpener } from '../../../../../utils/useLinkOpener'
@@ -99,6 +100,8 @@
}) => {
const telemetryRecorder = useTelemetryRecorder()

const [imageFile, setImageFile] = useState<File | undefined>(undefined)

const editorRef = useRef<PromptEditorRefAPI>(null)
useImperativeHandle(parentEditorRef, (): PromptEditorRefAPI | null => editorRef.current, [])

@@ -126,7 +129,7 @@
const experimentalPromptEditorEnabled = useFeatureFlag(FeatureFlag.CodyExperimentalPromptEditor)

const onSubmitClick = useCallback(
(intent?: ChatMessage['intent'], forceSubmit?: boolean): void => {
async (intent?: ChatMessage['intent'], forceSubmit?: boolean): Promise<void> => {
if (!forceSubmit && submitState === 'emptyEditorValue') {
return
}
@@ -142,6 +145,28 @@

const value = editorRef.current.getSerializedValue()
parentOnSubmit(intent)
if (imageFile) {
const readFileGetBase64String = (file: File): Promise<string> => {
return new Promise((resolve, reject) => {
const reader = new FileReader()
reader.onload = () => {
const base64 = reader.result
if (base64 && typeof base64 === 'string') {
resolve(base64.split(',')[1])
} else {
reject(new Error('Failed to read file'))
}
}
reader.onerror = () => reject(new Error('Failed to read file'))
reader.readAsDataURL(file)
})
}

const base64 = await readFileGetBase64String(imageFile)
getVSCodeAPI().postMessage({ command: 'chat/upload-file', base64 })
setImageFile(undefined)
}
parentOnSubmit(intent)

telemetryRecorder.recordEvent('cody.humanMessageEditor', 'submit', {
metadata: {
@@ -157,7 +182,15 @@
},
})
},
[submitState, parentOnSubmit, onStop, telemetryRecorder.recordEvent, isFirstMessage, isSent]
[
submitState,
parentOnSubmit,
onStop,
telemetryRecorder.recordEvent,
isFirstMessage,
isSent,
imageFile,
]
)

const onEditorEnterKey = useCallback(
@@ -423,6 +456,7 @@
)

const Editor = experimentalPromptEditorEnabled ? PromptEditorV2 : PromptEditor
const experimentalOneBoxEnabled = useFeatureFlag(FeatureFlag.CodyExperimentalOneBoxDebug)

return (
// biome-ignore lint/a11y/useKeyWithClickEvents: only relevant to click areas
@@ -470,6 +504,9 @@
hidden={!focused && isSent}
className={styles.toolbar}
intent={intent}
imageFile={imageFile}
setImageFile={setImageFile}
experimentalOneBoxEnabled={experimentalOneBoxEnabled}
/>
)}
</div>
@@ -10,6 +10,7 @@ import { useActionSelect } from '../../../../../../prompts/PromptsTab'
import { useClientConfig } from '../../../../../../utils/useClientConfig'
import { AddContextButton } from './AddContextButton'
import { SubmitButton, type SubmitButtonState } from './SubmitButton'
import { UploadImageButton } from './UploadImageButton'

/**
* The toolbar for the human message editor.
@@ -35,6 +36,10 @@
intent?: ChatMessage['intent']

manuallySelectIntent: (intent: ChatMessage['intent']) => void
experimentalOneBoxEnabled?: boolean

imageFile?: File
setImageFile: (file: File | undefined) => void
}> = ({
userInfo,
isEditorFocused,
@@ -48,6 +53,9 @@
models,
intent,
manuallySelectIntent,
experimentalOneBoxEnabled,
imageFile,
setImageFile,
}) => {
/**
* If the user clicks in a gap or on the toolbar outside of any of its buttons, report back to
@@ -88,6 +96,14 @@
/>
)}
<PromptSelectFieldToolbarItem focusEditor={focusEditor} className="tw-ml-1 tw-mr-1" />
{
<UploadImageButton
className="tw-opacity-60"
imageFile={imageFile}
onClick={setImageFile}
/>
}
<PromptSelectFieldToolbarItem focusEditor={focusEditor} className="tw-ml-1 tw-mr-1" />
<ModelSelectFieldToolbarItem
models={models}
userInfo={userInfo}
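The UploadImageButton component used in the toolbar above is not included in this excerpt. A minimal sketch of what such a button might look like (hypothetical implementation, not the PR's actual component: a hidden file input restricted to images that hands the selected File to the onClick/setImageFile callback; the props mirror the usage shown in the toolbar diff):

import { type FunctionComponent, useRef } from 'react'

// Hypothetical sketch only.
export const UploadImageButton: FunctionComponent<{
    className?: string
    imageFile?: File
    onClick: (file: File | undefined) => void
}> = ({ className, imageFile, onClick }) => {
    const inputRef = useRef<HTMLInputElement>(null)
    return (
        <div className={className}>
            <input
                ref={inputRef}
                type="file"
                accept="image/png,image/jpeg"
                hidden={true}
                onChange={event => onClick(event.target.files?.[0])}
            />
            <button type="button" onClick={() => inputRef.current?.click()}>
                {imageFile ? imageFile.name : 'Attach image'}
            </button>
        </div>
    )
}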