From f358bb9bb668fd5fed9d41e37e4c3bb8914bcc56 Mon Sep 17 00:00:00 2001 From: Emily Date: Wed, 11 Dec 2024 18:32:22 +0100 Subject: [PATCH] Add streaming to AI --- dashboard/pages/analyst.vue | 75 ++++++-- dashboard/server/api/ai/[chat_id]/status.ts | 17 ++ dashboard/server/api/ai/send_message.post.ts | 42 ++++- dashboard/server/services/AiService.ts | 184 +++++++++++++++---- shared/schema/ai/AiChatSchema.ts | 4 + 5 files changed, 271 insertions(+), 51 deletions(-) create mode 100644 dashboard/server/api/ai/[chat_id]/status.ts diff --git a/dashboard/pages/analyst.vue b/dashboard/pages/analyst.vue index 4ec1a08..c1b5ef1 100644 --- a/dashboard/pages/analyst.vue +++ b/dashboard/pages/analyst.vue @@ -22,9 +22,38 @@ const loading = ref(false); const currentChatId = ref(""); const currentChatMessages = ref<{ role: string, content: string, charts?: any[] }[]>([]); +const currentChatMessageDelta = ref(''); const scroller = ref(null); + + +async function pollSendMessageStatus(chat_id: string, times: number, updateStatus: (status: string) => any) { + + if (times > 20) return; + + const res = await $fetch(`/api/ai/${chat_id}/status`, { + headers: useComputedHeaders({ + useSnapshotDates: false, + }).value + }); + if (!res) throw Error('Error during status request'); + + updateStatus(res.status); + + if (res.completed === false) { + setTimeout(() => pollSendMessageStatus(chat_id, times + 1, updateStatus), 200); + } else { + currentChatMessages.value.push({ + role: 'assistant', + content: currentChatMessageDelta.value.replace(/\[data:.*?\]/g,''), + }); + currentChatMessageDelta.value = ''; + + } + +} + async function sendMessage() { if (loading.value) return; @@ -43,21 +72,21 @@ async function sendMessage() { try { - const res = await $fetch(`/api/ai/send_message`, { - method: 'POST', - body: JSON.stringify(body), - headers: useComputedHeaders({ - useSnapshotDates: false, - custom: { 'Content-Type': 'application/json' } - }).value - }); + const res = await $fetch<{ chat_id: string }>(`/api/ai/send_message`, { method: 'POST', body: JSON.stringify(body), headers: useComputedHeaders({ useSnapshotDates: false, custom: { 'Content-Type': 'application/json' } }).value }); + currentChatId.value = res.chat_id; - currentChatMessages.value.push({ role: 'assistant', content: res.content || 'nocontent', charts: res.charts.map(e => JSON.parse(e)) }); + currentChatMessages.value.push({ role: 'assistant', content: '', charts: [] }); await reloadChatsRemaining(); await reloadChatsList(); - currentChatId.value = chatsList.value?.at(-1)?._id.toString() || ''; + await new Promise(e => setTimeout(e, 200)); + + await pollSendMessageStatus(res.chat_id, 0, status => { + if (!status) return; + if (status.length > 0) loading.value = false; + currentChatMessageDelta.value = status; + }); } catch (ex: any) { @@ -80,10 +109,6 @@ async function sendMessage() { setTimeout(() => scrollToBottom(), 1); - - loading.value = false; - - } async function openChat(chat_id?: string) { @@ -91,6 +116,7 @@ async function openChat(chat_id?: string) { if (!project.value) return; currentChatMessages.value = []; + currentChatMessageDelta.value = ''; if (!chat_id) { currentChatId.value = ''; @@ -139,6 +165,7 @@ async function deleteChat(chat_id: string) { if (currentChatId.value === chat_id) { currentChatId.value = ""; currentChatMessages.value = []; + currentChatMessageDelta.value = ''; } await $fetch(`/api/ai/${chat_id}/delete`, { headers: useComputedHeaders({ useSnapshotDates: false }).value @@ -184,6 +211,7 @@ const { visible: pricingDrawerVisible } 
= usePricingDrawer() {{ message.content }} +
@@ -195,9 +223,9 @@ const { visible: pricingDrawerVisible } = usePricingDrawer() breaks: true, }" />
-
+
@@ -209,6 +237,23 @@ const { visible: pricingDrawerVisible } = usePricingDrawer()
+ + + +
+
+ +
+
+ {{ currentChatMessageDelta.replace(/\[(data:(.*?))\]/g, 'Processing: $2\n') }} +
+
+ + + + +
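
The client-side flow above replaces the old blocking request: sendMessage() now only expects a { chat_id } back from /api/ai/send_message, pushes an empty assistant message, and then polls GET /api/ai/{chat_id}/status every 200 ms (up to 20 attempts), rendering the streamed status text until completed is true, at which point the [data:...] progress markers are stripped from the final message. A minimal, framework-agnostic sketch of that polling contract, using plain fetch instead of Nuxt's $fetch/useComputedHeaders and a hypothetical authHeaders placeholder, could look like this:

// Sketch only: framework-agnostic poller for the status endpoint added in this patch.
// `authHeaders` is a hypothetical stand-in for useComputedHeaders().value.
type ChatStatus = { status: string; completed: boolean };

async function pollChatStatus(
    chatId: string,
    authHeaders: Record<string, string>,
    onStatus: (status: string) => void,
    maxAttempts = 20,
    intervalMs = 200
): Promise<string | null> {
    for (let attempt = 0; attempt < maxAttempts; attempt++) {
        const res = await fetch(`/api/ai/${chatId}/status`, { headers: authHeaders });
        if (!res.ok) throw new Error('Error during status request');

        const body = (await res.json()) as ChatStatus;
        onStatus(body.status); // raw status still carries [data:...] progress markers

        if (body.completed) {
            // Final message: strip the [data:...] markers, as the component does.
            return body.status.replace(/\[data:.*?\]/g, '');
        }
        await new Promise(resolve => setTimeout(resolve, intervalMs));
    }
    return null; // gave up, mirroring the `times > 20` guard in pollSendMessageStatus
}

The fixed 200 ms interval and 20-attempt cap mirror the guards in pollSendMessageStatus above; a production variant would likely want backoff or server-sent events instead of fixed-interval polling.
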
diff --git a/dashboard/server/api/ai/[chat_id]/status.ts b/dashboard/server/api/ai/[chat_id]/status.ts
new file mode 100644
index 0000000..5218f23
--- /dev/null
+++ b/dashboard/server/api/ai/[chat_id]/status.ts
@@ -0,0 +1,17 @@
+
+import { AiChatModel } from "@schema/ai/AiChatSchema";
+
+export default defineEventHandler(async event => {
+    const data = await getRequestData(event);
+    if (!data) return;
+
+    const { project_id } = data;
+
+    if (!event.context.params) return;
+    const chat_id = event.context.params['chat_id'];
+
+    const chat = await AiChatModel.findOne({ _id: chat_id, project_id }, { status: 1, completed: 1 });
+    if (!chat) return;
+
+    return { status: chat.status, completed: chat.completed || false }
+});
\ No newline at end of file
diff --git a/dashboard/server/api/ai/send_message.post.ts b/dashboard/server/api/ai/send_message.post.ts
index fa199e3..408d7f4 100644
--- a/dashboard/server/api/ai/send_message.post.ts
+++ b/dashboard/server/api/ai/send_message.post.ts
@@ -1,4 +1,4 @@
-import { sendMessageOnChat } from "~/server/services/AiService";
+import { sendMessageOnChat, updateChatStatus } from "~/server/services/AiService";
 
 import { getAiChatRemainings } from "./chats_remaining";
 
@@ -15,7 +15,43 @@ export default defineEventHandler(async event => {
     const chatsRemaining = await getAiChatRemainings(pid);
     if (chatsRemaining <= 0) return setResponseStatus(event, 400, 'CHAT_LIMIT_REACHED');
 
-    const response = await sendMessageOnChat(text, pid, chat_id);
+    const currentStatus: string[] = [];
+
+    let responseSent = false;
+
+    let targetChatId = '';
+
+    await sendMessageOnChat(text, pid, chat_id, {
+        onChatId: async chat_id => {
+            if (!responseSent) {
+                event.node.res.setHeader('Content-Type', 'application/json');
+                event.node.res.end(JSON.stringify({ chat_id }));
+                targetChatId = chat_id;
+                responseSent = true;
+            }
+        },
+        onDelta: async text => {
+            currentStatus.push(text);
+            await updateChatStatus(targetChatId, currentStatus.join(''), false);
+        },
+        onFunctionName: async name => {
+            currentStatus.push('[data:FunctionName]');
+            await updateChatStatus(targetChatId, currentStatus.join(''), false);
+        },
+        onFunctionCall: async name => {
+            currentStatus.push('[data:FunctionCall]');
+            await updateChatStatus(targetChatId, currentStatus.join(''), false);
+        },
+        onFunctionResult: async (name, result) => {
+            currentStatus.push('[data:FunctionResult]');
+            await updateChatStatus(targetChatId, currentStatus.join(''), false);
+        },
+        onFinish: async calls => {
+            currentStatus.push('[data:FunctionFinish]');
+            await updateChatStatus(targetChatId, currentStatus.join(''), false);
+        }
+    });
+
+    await updateChatStatus(targetChatId, currentStatus.join(''), true);
 
-    return response;
 });
\ No newline at end of file
diff --git a/dashboard/server/services/AiService.ts b/dashboard/server/services/AiService.ts
index 14dcef4..f054fe5 100644
--- a/dashboard/server/services/AiService.ts
+++ b/dashboard/server/services/AiService.ts
@@ -9,15 +9,11 @@ import { AiEventsInstance } from '../ai/functions/AI_Events';
 import { AiVisitsInstance } from '../ai/functions/AI_Visits';
 import { AiComposableChartInstance } from '../ai/functions/AI_ComposableChart';
 
-const { AI_ORG, AI_PROJECT, AI_KEY } = useRuntimeConfig();
+const { AI_KEY, AI_ORG, AI_PROJECT } = useRuntimeConfig();
 
 const OPENAI_MODEL: OpenAI.Chat.ChatModel = 'gpt-4o-mini';
 
-const openai = new OpenAI({
-    organization: AI_ORG,
-    project: AI_PROJECT,
-    apiKey: AI_KEY
-});
+const openai = new OpenAI({ apiKey: AI_KEY, organization: AI_ORG, project: AI_PROJECT });
 
 const tools: OpenAI.Chat.Completions.ChatCompletionTool[] = [
     ...AiVisitsInstance.getTools(),
@@ -57,6 +53,13 @@ async function setChatTitle(title: string, chat_id?: string) {
     await AiChatModel.updateOne({ _id: chat_id }, { title });
 }
 
+export async function updateChatStatus(chat_id: string, status: string, completed: boolean) {
+    await AiChatModel.updateOne({ _id: chat_id }, {
+        status,
+        completed
+    });
+}
+
 export function getChartsInMessage(message: OpenAI.Chat.Completions.ChatCompletionMessageParam) {
 
     if (message.role != 'assistant') return [];
@@ -65,13 +68,117 @@ export function getChartsInMessage(message: OpenAI.Chat.Completi
     return message.tool_calls.filter(e => e.function.name === 'createComposableChart').map(e => e.function.arguments);
 }
 
-export async function sendMessageOnChat(text: string, pid: string, initial_chat_id?: string) {
+
+
+type FunctionCall = { name: string, argsRaw: string[], collecting: boolean, result: any, tool_call_id: string }
+
+type DeltaCallback = (text: string) => any;
+type FinishCallback = (functionsCount: number) => any;
+type FunctionNameCallback = (name: string) => any;
+type FunctionCallCallback = (name: string) => any;
+type FunctionResultCallback = (name: string, result: any) => any;
+
+type ElaborateResponseCallbacks = {
+    onDelta?: DeltaCallback,
+    onFunctionName?: FunctionNameCallback,
+    onFunctionCall?: FunctionCallCallback,
+    onFunctionResult?: FunctionResultCallback,
+    onFinish?: FinishCallback,
+    onChatId?: (chat_id: string) => any
+}
+
+async function elaborateResponse(messages: OpenAI.Chat.Completions.ChatCompletionMessageParam[], pid: string, chat_id: string, callbacks?: ElaborateResponseCallbacks) {
+
+    const responseStream = await openai.beta.chat.completions.stream({ model: OPENAI_MODEL, messages, n: 1, tools });
+
+    const functionCalls: FunctionCall[] = [];
+
+    let lastFinishReason: "length" | "tool_calls" | "function_call" | "stop" | "content_filter" | null = null;
+
+    for await (const part of responseStream) {
+
+        const delta = part.choices[0].delta;
+        const finishReason = part.choices[0].finish_reason;
+
+        if (delta.content) await callbacks?.onDelta?.(delta.content);
+
+        if (delta.tool_calls) {
+            const toolCall = delta.tool_calls[0];
+            if (!toolCall.function) throw Error('Cannot get function from tool_calls');
+
+            const functionName = toolCall.function.name;
+
+            const functionCall: FunctionCall = functionName ?
+                { name: functionName, argsRaw: [], collecting: true, result: null, tool_call_id: toolCall.id as string } :
+                functionCalls.at(-1) as FunctionCall;
+
+            if (functionName) functionCalls.push(functionCall);
+
+            if (functionName) await callbacks?.onFunctionName?.(functionName);
+
+            if (toolCall.function.arguments) functionCall.argsRaw.push(toolCall.function.arguments);
+
+        }
+
+        if (finishReason === "tool_calls" && functionCalls.at(-1)?.collecting) {
+            const functionCall: FunctionCall = functionCalls.at(-1) as FunctionCall;
+            await callbacks?.onFunctionCall?.(functionCall.name);
+            const args = JSON.parse(functionCall.argsRaw.join(''));
+            const functionResult = await functions[functionCall.name]({ project_id: pid, ...args });
+            functionCall.result = functionResult;
+            await callbacks?.onFunctionResult?.(functionCall.name, functionResult);
+
+            addMessageToChat({
+                role: 'assistant',
+                content: delta.content,
+                refusal: delta.refusal,
+                tool_calls: [
+                    {
+                        id: functionCall.tool_call_id,
+                        type: 'function',
+                        function: {
+                            name: functionCall.name,
+                            arguments: functionCall.argsRaw.join('')
+                        }
+                    }
+                ],
+                parsed: null
+            }, chat_id);
+
+            addMessageToChat({
+                tool_call_id: functionCall.tool_call_id,
+                role: 'tool',
+                content: JSON.stringify(functionResult)
+            }, chat_id);
+
+
+            functionCall.collecting = false;
+            lastFinishReason = finishReason;
+        }
+
+    }
+
+    await callbacks?.onFinish?.(functionCalls.length);
+
+    const toolResponseMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = functionCalls.map(e => {
+        return { tool_call_id: e.tool_call_id, role: "tool", content: JSON.stringify(e.result) }
+    });
+
+    if (lastFinishReason == 'tool_calls') return await elaborateResponse([...responseStream.messages, ...toolResponseMessages], pid, chat_id, callbacks);
+
+    return responseStream;
+}
+
+
+export async function sendMessageOnChat(text: string, pid: string, initial_chat_id?: string, callbacks?: ElaborateResponseCallbacks) {
 
     const messages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = []
 
     const chat_id = await createChatIfNotExist(pid, initial_chat_id);
     const chatMessages = await getMessagesFromChatId(chat_id);
 
+    await callbacks?.onChatId?.(chat_id);
+
     if (chatMessages && chatMessages.length > 0) {
         messages.push(...chatMessages);
     } else {
@@ -89,32 +196,43 @@ export async function sendMessageOnChat(text: string, pid: string, initial_chat_
     messages.push(userMessage);
     await addMessageToChat(userMessage, chat_id);
 
-    let response = await openai.chat.completions.create({ model: OPENAI_MODEL, messages, n: 1, tools });
-
-    const chartsData: string[][] = [];
-
-    while ((response.choices[0].message.tool_calls?.length || 0) > 0) {
-        await addMessageToChat(response.choices[0].message, chat_id);
-        messages.push(response.choices[0].message);
-        if (response.choices[0].message.tool_calls) {
-
-            console.log('Tools to call', response.choices[0].message.tool_calls.length);
-            chartsData.push(getChartsInMessage(response.choices[0].message));
-
-            for (const toolCall of response.choices[0].message.tool_calls) {
-                const functionName = toolCall.function.name;
-                console.log('Calling tool function', functionName);
-                const functionToCall = functions[functionName];
-                const functionArgs = JSON.parse(toolCall.function.arguments);
-                const functionResponse = await functionToCall({ project_id: pid, ...functionArgs });
-                messages.push({ tool_call_id: toolCall.id, role: "tool", content: JSON.stringify(functionResponse) });
-                await addMessageToChat({ tool_call_id: toolCall.id, role: "tool", content: JSON.stringify(functionResponse) }, chat_id);
-            }
-        }
-        response = await openai.chat.completions.create({ model: OPENAI_MODEL, messages, n: 1, tools });
+    try {
+        const streamResponse = await elaborateResponse(messages, pid, chat_id, callbacks);
+        await addMessageToChat({ role: 'assistant', refusal: null, content: await streamResponse.finalContent() }, chat_id);
+        return { content: '', charts: [] };
+    } catch (ex: any) {
+        console.error(ex);
+        return { content: ex.message, charts: [] };
     }
-    await addMessageToChat(response.choices[0].message, chat_id);
-    await ProjectLimitModel.updateOne({ project_id: pid }, { $inc: { ai_messages: 1 } })
-    return { content: response.choices[0].message.content, charts: chartsData.filter(e => e.length > 0).flat() };
+
+    // let response = await openai.chat.completions.create({ model: OPENAI_MODEL, messages, n: 1, tools });
+
+    // const chartsData: string[][] = [];
+
+    // while ((response.choices[0].message.tool_calls?.length || 0) > 0) {
+    //     await addMessageToChat(response.choices[0].message, chat_id);
+    //     messages.push(response.choices[0].message);
+    //     if (response.choices[0].message.tool_calls) {
+
+    //         console.log('Tools to call', response.choices[0].message.tool_calls.length);
+    //         chartsData.push(getChartsInMessage(response.choices[0].message));
+
+    //         for (const toolCall of response.choices[0].message.tool_calls) {
+    //             const functionName = toolCall.function.name;
+    //             console.log('Calling tool function', functionName);
+    //             const functionToCall = functions[functionName];
+    //             const functionArgs = JSON.parse(toolCall.function.arguments);
+    //             const functionResponse = await functionToCall({ project_id: pid, ...functionArgs });
+    //             messages.push({ tool_call_id: toolCall.id, role: "tool", content: JSON.stringify(functionResponse) });
+    //             await addMessageToChat({ tool_call_id: toolCall.id, role: "tool", content: JSON.stringify(functionResponse) }, chat_id);
+    //         }
+    //     }
+    //     response = await openai.chat.completions.create({ model: OPENAI_MODEL, messages, n: 1, tools });
+    // }
+    // await addMessageToChat(response.choices[0].message, chat_id);
+    // await ProjectLimitModel.updateOne({ project_id: pid }, { $inc: { ai_messages: 1 } })
+    // return { content: response.choices[0].message.content, charts: chartsData.filter(e => e.length > 0).flat() };
+
+
 }
diff --git a/shared/schema/ai/AiChatSchema.ts b/shared/schema/ai/AiChatSchema.ts
index 7a0336d..a66afe6 100644
--- a/shared/schema/ai/AiChatSchema.ts
+++ b/shared/schema/ai/AiChatSchema.ts
@@ -4,6 +4,8 @@ export type TAiChatSchema = {
     _id: Schema.Types.ObjectId,
     project_id: Schema.Types.ObjectId,
     messages: any[],
+    status: string,
+    completed: boolean,
     title: string,
     created_at: Date,
     updated_at: Date
@@ -11,6 +13,8 @@
 
 const AiChatSchema = new Schema({
     project_id: { type: Schema.Types.ObjectId, index: 1 },
+    status: { type: String },
+    completed: { type: Boolean },
    messages: [{ _id: false, type: Schema.Types.Mixed }],
    title: { type: String, required: true },
    created_at: { type: Date, default: () => Date.now() },
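
On the server side, the handler in send_message.post.ts answers the HTTP request immediately with { chat_id } (via onChatId) and then keeps appending streamed deltas and [data:...] markers to the chat's new status field with completed = false, writing completed = true only once sendMessageOnChat resolves; those are the two fields the status endpoint reads and the client polls. A condensed sketch of that lifecycle, assuming a hypothetical chatId and illustrative delta text, using the updateChatStatus helper added to AiService.ts in this patch:

import { updateChatStatus } from '~/server/services/AiService';

// Sketch only: the status lifecycle that send_message.post.ts drives for one chat.
// chatId and the sample delta text are illustrative placeholders.
async function exampleStatusLifecycle(chatId: string) {
    const progress: string[] = [];

    // While the completion streams: append deltas and progress markers, persisting them
    // with completed = false so GET /api/ai/[chat_id]/status keeps reporting progress.
    progress.push('Here is a summary of last week ');   // an onDelta chunk (illustrative)
    await updateChatStatus(chatId, progress.join(''), false);

    progress.push('[data:FunctionCall]');                // an onFunctionCall marker
    await updateChatStatus(chatId, progress.join(''), false);

    // When sendMessageOnChat resolves: write the final status once with completed = true,
    // which is the signal the polling client waits for before rendering the message.
    await updateChatStatus(chatId, progress.join(''), true);
}

Because status and completed carry no defaults in AiChatSchema, chats created before this patch simply have the fields missing; the status endpoint accounts for that with chat.completed || false.
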