Skip to content

Commit 9b4d074

Browse files
authored
feat(streamObject): add enum support (#6028)
## Background We were missing enum on the streamObject method
1 parent 5e5b5d5 commit 9b4d074

File tree

7 files changed

+283
-27
lines changed

7 files changed

+283
-27
lines changed

‎.changeset/funny-cows-sin.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'ai': patch
3+
---
4+
5+
feat(streamObject): add enum support

‎content/docs/03-ai-sdk-core/10-generating-structured-data.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ However, you need to manually provide schemas and then validate the generated da
1414
The AI SDK standardises structured object generation across model providers
1515
with the [`generateObject`](/docs/reference/ai-sdk-core/generate-object)
1616
and [`streamObject`](/docs/reference/ai-sdk-core/stream-object) functions.
17-
You can use both functions with different output strategies, e.g. `array`, `object`, or `no-schema`,
17+
You can use both functions with different output strategies, e.g. `array`, `object`, `enum`, or `no-schema`,
1818
and with different generation modes, e.g. `auto`, `tool`, or `json`.
1919
You can use [Zod schemas](/docs/reference/ai-sdk-core/zod-schema), [Valibot](/docs/reference/ai-sdk-core/valibot-schema), or [JSON schemas](/docs/reference/ai-sdk-core/json-schema) to specify the shape of the data that you want,
2020
and the AI model will generate data that conforms to that structure.
@@ -110,7 +110,7 @@ const result = streamObject({
110110

111111
## Output Strategy
112112

113-
You can use both functions with different output strategies, e.g. `array`, `object`, or `no-schema`.
113+
You can use both functions with different output strategies, e.g. `array`, `object`, `enum`, or `no-schema`.
114114

115115
### Object
116116

‎content/docs/07-reference/01-ai-sdk-core/04-stream-object.mdx

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,25 @@ for await (const partialObject of partialObjectStream) {
8080
}
8181
```
8282

83+
#### Example: generate an enum
84+
85+
When you want to generate a specific enum value, you can set the output strategy to `enum`
86+
and provide the list of possible values in the `enum` parameter.
87+
88+
```ts highlight="5-6"
89+
import { streamObject } from 'ai';
90+
91+
const { partialObjectStream } = streamObject({
92+
model: yourModel,
93+
output: 'enum',
94+
enum: ['action', 'comedy', 'drama', 'horror', 'sci-fi'],
95+
prompt:
96+
'Classify the genre of this movie plot: ' +
97+
'"A group of astronauts travel through a wormhole in search of a ' +
98+
'new habitable planet for humanity."',
99+
});
100+
```
101+
83102
To see `streamObject` in action, check out the [additional examples](#more-examples).
84103

85104
## Import
@@ -99,7 +118,7 @@ To see `streamObject` in action, check out the [additional examples](#more-examp
99118
},
100119
{
101120
name: 'output',
102-
type: "'object' | 'array' | 'no-schema' | undefined",
121+
type: "'object' | 'array' | 'enum' | 'no-schema' | undefined",
103122
description: "The type of output to generate. Defaults to 'object'.",
104123
},
105124
{
@@ -118,23 +137,23 @@ To see `streamObject` in action, check out the [additional examples](#more-examp
118137
It is sent to the model to generate the object and used to validate the output. \
119138
You can either pass in a Zod schema or a JSON schema (using the `jsonSchema` function). \
120139
In 'array' mode, the schema is used to describe an array element. \
121-
Not available with 'no-schema' output.",
140+
Not available with 'no-schema' or 'enum' output.",
122141
},
123142
{
124143
name: 'schemaName',
125144
type: 'string | undefined',
126145
description:
127146
"Optional name of the output that should be generated. \
128147
Used by some providers for additional LLM guidance, e.g. via tool or schema name. \
129-
Not available with 'no-schema' output.",
148+
Not available with 'no-schema' or 'enum' output.",
130149
},
131150
{
132151
name: 'schemaDescription',
133152
type: 'string | undefined',
134153
description:
135154
"Optional description of the output that should be generated. \
136155
Used by some providers for additional LLM guidance, e.g. via tool or schema name. \
137-
Not available with 'no-schema' output.",
156+
Not available with 'no-schema' or 'enum' output.",
138157
},
139158
{
140159
name: 'system',

‎packages/ai/core/generate-object/output-strategy.ts

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ const arrayOutputStrategy = <ELEMENT>(
290290

291291
const enumOutputStrategy = <ENUM extends string>(
292292
enumValues: Array<ENUM>,
293-
): OutputStrategy<ENUM, ENUM, never> => {
293+
): OutputStrategy<string, ENUM, never> => {
294294
return {
295295
type: 'enum',
296296

@@ -335,11 +335,41 @@ const enumOutputStrategy = <ENUM extends string>(
335335
};
336336
},
337337

338-
validatePartialResult() {
339-
// no streaming in enum mode
340-
throw new UnsupportedFunctionalityError({
341-
functionality: 'partial results in enum mode',
342-
});
338+
async validatePartialResult({ value, textDelta }) {
339+
if (!isJSONObject(value) || typeof value.result !== 'string') {
340+
return {
341+
success: false,
342+
error: new TypeValidationError({
343+
value,
344+
cause:
345+
'value must be an object that contains a string in the "result" property.',
346+
}),
347+
};
348+
}
349+
350+
const result = value.result as string;
351+
const possibleEnumValues = enumValues.filter(enumValue =>
352+
enumValue.startsWith(result),
353+
);
354+
355+
if (value.result.length === 0 || possibleEnumValues.length === 0) {
356+
return {
357+
success: false,
358+
error: new TypeValidationError({
359+
value,
360+
cause: 'value must be a string in the enum',
361+
}),
362+
};
363+
}
364+
365+
return {
366+
success: true,
367+
value: {
368+
partial:
369+
possibleEnumValues.length > 1 ? result : possibleEnumValues[0],
370+
textDelta,
371+
},
372+
};
343373
},
344374

345375
createElementStream() {

‎packages/ai/core/generate-object/stream-object.test-d.ts

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,26 @@
11
import { expectTypeOf } from 'vitest';
2-
import { generateObject } from './generate-object';
32
import { z } from 'zod';
43
import { JSONValue } from '@ai-sdk/provider';
54
import { streamObject } from './stream-object';
65
import { AsyncIterableStream } from '../util/async-iterable-stream';
76

87
describe('streamObject', () => {
8+
it('should support enum types', async () => {
9+
const result = await streamObject({
10+
output: 'enum',
11+
enum: ['a', 'b', 'c'] as const,
12+
model: undefined!,
13+
});
14+
15+
expectTypeOf<typeof result.object>().toEqualTypeOf<
16+
Promise<'a' | 'b' | 'c'>
17+
>;
18+
19+
for await (const text of result.partialObjectStream) {
20+
expectTypeOf(text).toEqualTypeOf<string>();
21+
}
22+
});
23+
924
it('should support schema types', async () => {
1025
const result = streamObject({
1126
schema: z.object({ number: z.number() }),

‎packages/ai/core/generate-object/stream-object.test.ts

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,168 @@ describe('streamObject', () => {
11391139
});
11401140
});
11411141

1142+
describe('output = "enum"', () => {
1143+
it('should stream an enum value', async () => {
1144+
const mockModel = new MockLanguageModelV2({
1145+
doStream: {
1146+
stream: convertArrayToReadableStream([
1147+
{ type: 'text', text: '{ ' },
1148+
{ type: 'text', text: '"result": ' },
1149+
{ type: 'text', text: `"su` },
1150+
{ type: 'text', text: `nny` },
1151+
{ type: 'text', text: `"` },
1152+
{ type: 'text', text: ' }' },
1153+
{
1154+
type: 'finish',
1155+
finishReason: 'stop',
1156+
usage: { inputTokens: 3, outputTokens: 10 },
1157+
},
1158+
]),
1159+
},
1160+
});
1161+
1162+
const result = streamObject({
1163+
model: mockModel,
1164+
output: 'enum',
1165+
enum: ['sunny', 'rainy', 'snowy'],
1166+
prompt: 'prompt',
1167+
});
1168+
1169+
expect(await convertAsyncIterableToArray(result.partialObjectStream))
1170+
.toMatchInlineSnapshot(`
1171+
[
1172+
"sunny",
1173+
]
1174+
`);
1175+
1176+
expect(mockModel.doStreamCalls[0].responseFormat).toMatchInlineSnapshot(`
1177+
{
1178+
"description": undefined,
1179+
"name": undefined,
1180+
"schema": {
1181+
"$schema": "http://json-schema.org/draft-07/schema#",
1182+
"additionalProperties": false,
1183+
"properties": {
1184+
"result": {
1185+
"enum": [
1186+
"sunny",
1187+
"rainy",
1188+
"snowy",
1189+
],
1190+
"type": "string",
1191+
},
1192+
},
1193+
"required": [
1194+
"result",
1195+
],
1196+
"type": "object",
1197+
},
1198+
"type": "json",
1199+
}
1200+
`);
1201+
});
1202+
1203+
it('should not stream incorrect values', async () => {
1204+
const mockModel = new MockLanguageModelV2({
1205+
doStream: {
1206+
stream: convertArrayToReadableStream([
1207+
{ type: 'text', text: '{ ' },
1208+
{ type: 'text', text: '"result": ' },
1209+
{ type: 'text', text: `"foo` },
1210+
{ type: 'text', text: `bar` },
1211+
{ type: 'text', text: `"` },
1212+
{ type: 'text', text: ' }' },
1213+
{
1214+
type: 'finish',
1215+
finishReason: 'stop',
1216+
usage: { inputTokens: 3, outputTokens: 10 },
1217+
},
1218+
]),
1219+
},
1220+
});
1221+
1222+
const result = streamObject({
1223+
model: mockModel,
1224+
output: 'enum',
1225+
enum: ['sunny', 'rainy', 'snowy'],
1226+
prompt: 'prompt',
1227+
});
1228+
1229+
expect(
1230+
await convertAsyncIterableToArray(result.partialObjectStream),
1231+
).toMatchInlineSnapshot(`[]`);
1232+
});
1233+
1234+
it('should handle ambiguous values', async () => {
1235+
const mockModel = new MockLanguageModelV2({
1236+
doStream: {
1237+
stream: convertArrayToReadableStream([
1238+
{ type: 'text', text: '{ ' },
1239+
{ type: 'text', text: '"result": ' },
1240+
{ type: 'text', text: `"foo` },
1241+
{ type: 'text', text: `bar` },
1242+
{ type: 'text', text: `"` },
1243+
{ type: 'text', text: ' }' },
1244+
{
1245+
type: 'finish',
1246+
finishReason: 'stop',
1247+
usage: { inputTokens: 3, outputTokens: 10 },
1248+
},
1249+
]),
1250+
},
1251+
});
1252+
1253+
const result = streamObject({
1254+
model: mockModel,
1255+
output: 'enum',
1256+
enum: ['foobar', 'foobar2'],
1257+
prompt: 'prompt',
1258+
});
1259+
1260+
expect(await convertAsyncIterableToArray(result.partialObjectStream))
1261+
.toMatchInlineSnapshot(`
1262+
[
1263+
"foo",
1264+
"foobar",
1265+
]
1266+
`);
1267+
});
1268+
1269+
it('should handle non-ambiguous values', async () => {
1270+
const mockModel = new MockLanguageModelV2({
1271+
doStream: {
1272+
stream: convertArrayToReadableStream([
1273+
{ type: 'text', text: '{ ' },
1274+
{ type: 'text', text: '"result": ' },
1275+
{ type: 'text', text: `"foo` },
1276+
{ type: 'text', text: `bar` },
1277+
{ type: 'text', text: `"` },
1278+
{ type: 'text', text: ' }' },
1279+
{
1280+
type: 'finish',
1281+
finishReason: 'stop',
1282+
usage: { inputTokens: 3, outputTokens: 10 },
1283+
},
1284+
]),
1285+
},
1286+
});
1287+
1288+
const result = streamObject({
1289+
model: mockModel,
1290+
output: 'enum',
1291+
enum: ['foobar', 'barfoo'],
1292+
prompt: 'prompt',
1293+
});
1294+
1295+
expect(await convertAsyncIterableToArray(result.partialObjectStream))
1296+
.toMatchInlineSnapshot(`
1297+
[
1298+
"foobar",
1299+
]
1300+
`);
1301+
});
1302+
});
1303+
11421304
describe('output = "no-schema"', () => {
11431305
it('should send object deltas', async () => {
11441306
const mockModel = new MockLanguageModelV2({

0 commit comments

Comments
 (0)