Skip to content

Commit f04ffe4

Browse files
authored
feat (ui): add onData callback to Chat (#6984)
## Background Reacting to data parts as they become available is important for data parts that are not message related, e.g. events (which can be send as data parts). ## Summary Add `onData` callback to `Chat`. ## Verification Tested manually with `next-openai`` data parts example. ## Future Work * introduce ephemeral data parts
1 parent 4ee81c2 commit f04ffe4

File tree

5 files changed

+62
-12
lines changed

5 files changed

+62
-12
lines changed

‎.changeset/ninety-seahorses-fetch.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
feat (ui): add onData callback to Chat

‎examples/next-openai/app/use-chat-data-ui-parts/page.tsx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@ export default function Chat() {
2121
transport: new DefaultChatTransport({
2222
api: '/api/use-chat-data-ui-parts',
2323
}),
24+
onData: dataPart => {
25+
console.log('dataPart', JSON.stringify(dataPart, null, 2));
26+
},
2427
});
2528

2629
return (
27-
<div className="flex flex-col w-full max-w-md py-24 mx-auto stretch">
30+
<div className="flex flex-col py-24 mx-auto w-full max-w-md stretch">
2831
{messages.map(message => (
2932
<div key={message.id} className="whitespace-pre-wrap">
3033
{message.role === 'user' ? 'User: ' : 'AI: '}{' '}
@@ -70,7 +73,7 @@ export default function Chat() {
7073
{status === 'submitted' && <div>Loading...</div>}
7174
<button
7275
type="button"
73-
className="px-4 py-2 mt-4 text-blue-500 border border-blue-500 rounded-md"
76+
className="px-4 py-2 mt-4 text-blue-500 rounded-md border border-blue-500"
7477
onClick={stop}
7578
>
7679
Stop
@@ -83,7 +86,7 @@ export default function Chat() {
8386
<div className="text-red-500">An error occurred.</div>
8487
<button
8588
type="button"
86-
className="px-4 py-2 mt-4 text-blue-500 border border-blue-500 rounded-md"
89+
className="px-4 py-2 mt-4 text-blue-500 rounded-md border border-blue-500"
8790
onClick={() => regenerate()}
8891
>
8992
Retry

‎packages/ai/src/ui/chat.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import {
2121
} from './should-resubmit-messages';
2222
import {
2323
isToolUIPart,
24+
type DataUIPart,
2425
type FileUIPart,
2526
type InferUIMessageData,
2627
type InferUIMessageMetadata,
@@ -138,6 +139,13 @@ export interface ChatInit<UI_MESSAGE extends UIMessage> {
138139
* @param message The message that was streamed.
139140
*/
140141
onFinish?: (options: { message: UI_MESSAGE }) => void;
142+
143+
/**
144+
* Optional callback function that is called when a data part is received.
145+
*
146+
* @param data The data part that was received.
147+
*/
148+
onData?: (dataPart: DataUIPart<InferUIMessageData<UI_MESSAGE>>) => void;
141149
}
142150

143151
export abstract class AbstractChat<UI_MESSAGE extends UIMessage> {
@@ -158,6 +166,7 @@ export abstract class AbstractChat<UI_MESSAGE extends UIMessage> {
158166
private onError?: ChatInit<UI_MESSAGE>['onError'];
159167
private onToolCall?: ChatInit<UI_MESSAGE>['onToolCall'];
160168
private onFinish?: ChatInit<UI_MESSAGE>['onFinish'];
169+
private onData?: ChatInit<UI_MESSAGE>['onData'];
161170

162171
private activeResponse: ActiveResponse<UI_MESSAGE> | undefined = undefined;
163172
private jobExecutor = new SerialJobExecutor();
@@ -173,6 +182,7 @@ export abstract class AbstractChat<UI_MESSAGE extends UIMessage> {
173182
onError,
174183
onToolCall,
175184
onFinish,
185+
onData,
176186
}: Omit<ChatInit<UI_MESSAGE>, 'messages'> & {
177187
state: ChatState<UI_MESSAGE>;
178188
}) {
@@ -186,6 +196,7 @@ export abstract class AbstractChat<UI_MESSAGE extends UIMessage> {
186196
this.onError = onError;
187197
this.onToolCall = onToolCall;
188198
this.onFinish = onFinish;
199+
this.onData = onData;
189200
}
190201

191202
/**
@@ -495,6 +506,7 @@ export abstract class AbstractChat<UI_MESSAGE extends UIMessage> {
495506
stream: processUIMessageStream({
496507
stream,
497508
onToolCall: this.onToolCall,
509+
onData: this.onData,
498510
messageMetadataSchema: this.messageMetadataSchema,
499511
dataPartSchemas: this.dataPartSchemas,
500512
runUpdateMessageJob,

‎packages/ai/src/ui/process-ui-message-stream.test.ts

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import {
66
processUIMessageStream,
77
StreamingUIMessageState,
88
} from './process-ui-message-stream';
9-
import { UIMessage } from './ui-messages';
9+
import { InferUIMessageData, UIMessage } from './ui-messages';
1010

1111
function createUIMessageStream(parts: UIMessageStreamPart[]) {
1212
return convertArrayToReadableStream(parts);
@@ -4042,7 +4042,11 @@ describe('processUIMessageStream', () => {
40424042
});
40434043

40444044
describe('data ui parts (single part)', () => {
4045+
let dataCalls: InferUIMessageData<UIMessage>[] = [];
4046+
40454047
beforeEach(async () => {
4048+
dataCalls = [];
4049+
40464050
const stream = createUIMessageStream([
40474051
{ type: 'start', messageId: 'msg-123' },
40484052
{ type: 'start-step' },
@@ -4066,6 +4070,9 @@ describe('processUIMessageStream', () => {
40664070
onError: error => {
40674071
throw error;
40684072
},
4073+
onData: data => {
4074+
dataCalls.push(data);
4075+
},
40694076
}),
40704077
});
40714078
});
@@ -4119,6 +4126,17 @@ describe('processUIMessageStream', () => {
41194126
}
41204127
`);
41214128
});
4129+
4130+
it('should call the onData callback with the correct arguments', async () => {
4131+
expect(dataCalls).toMatchInlineSnapshot(`
4132+
[
4133+
{
4134+
"data": "example-data-can-be-anything",
4135+
"type": "data-test",
4136+
},
4137+
]
4138+
`);
4139+
});
41224140
});
41234141

41244142
describe('data ui parts (single part with id and replacement update)', () => {

‎packages/ai/src/ui/process-ui-message-stream.ts

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import {
66
} from '@ai-sdk/provider-utils';
77
import {
88
InferUIMessageStreamPart,
9+
DataUIMessageStreamPart,
910
isDataUIMessageStreamPart,
1011
UIMessageStreamPart,
1112
} from '../ui-message-stream/ui-message-stream-parts';
@@ -14,6 +15,7 @@ import { mergeObjects } from '../util/merge-objects';
1415
import { parsePartialJson } from '../util/parse-partial-json';
1516
import { UIDataTypesToSchemas } from './chat';
1617
import {
18+
DataUIPart,
1719
getToolName,
1820
InferUIMessageData,
1921
InferUIMessageMetadata,
@@ -69,6 +71,7 @@ export function processUIMessageStream<UI_MESSAGE extends UIMessage>({
6971
dataPartSchemas,
7072
runUpdateMessageJob,
7173
onError,
74+
onData,
7275
}: {
7376
// input stream is not fully typed yet:
7477
stream: ReadableStream<UIMessageStreamPart>;
@@ -79,6 +82,7 @@ export function processUIMessageStream<UI_MESSAGE extends UIMessage>({
7982
onToolCall?: (options: {
8083
toolCall: ToolCall<string, unknown>;
8184
}) => void | Promise<unknown> | unknown;
85+
onData?: (dataPart: DataUIPart<InferUIMessageData<UI_MESSAGE>>) => void;
8286
runUpdateMessageJob: (
8387
job: (options: {
8488
state: StreamingUIMessageState<UI_MESSAGE>;
@@ -468,25 +472,33 @@ export function processUIMessageStream<UI_MESSAGE extends UIMessage>({
468472

469473
default: {
470474
if (isDataUIMessageStreamPart(part)) {
475+
// TODO validate against dataPartSchemas
476+
const dataPart = part as DataUIMessageStreamPart<
477+
InferUIMessageData<UI_MESSAGE>
478+
>;
479+
471480
// TODO improve type safety
472481
const existingPart: any =
473-
part.id != null
482+
dataPart.id != null
474483
? state.message.parts.find(
475484
(partArg: any) =>
476-
part.type === partArg.type && part.id === partArg.id,
485+
dataPart.type === partArg.type &&
486+
dataPart.id === partArg.id,
477487
)
478488
: undefined;
479489

480490
if (existingPart != null) {
481-
// TODO improve type safety
491+
// TODO validate merged data against dataPartSchemas
482492
existingPart.data =
483-
isObject(existingPart.data) && isObject(part.data)
484-
? mergeObjects(existingPart.data, part.data)
485-
: part.data;
493+
isObject(existingPart.data) && isObject(dataPart.data)
494+
? mergeObjects(existingPart.data, dataPart.data)
495+
: dataPart.data;
486496
} else {
487-
// TODO improve type safety
488-
state.message.parts.push(part as any);
497+
state.message.parts.push(dataPart);
489498
}
499+
500+
onData?.(dataPart);
501+
490502
write();
491503
}
492504
}

0 commit comments

Comments
 (0)