
Commit a76a62b

feat (ai): add experimental prepareStep callback to generateText (#5985) (#5991)
## Background

For many agentic use cases, it is helpful to be able to control the different settings at each step.

## Summary

Add `experimental_prepareStep` option to `generateText` that allows modifying `model`, `toolChoice`, and `experimental_activeTools` for each step.

## Future Work

- Add to `streamText`.
- Add more inputs and outputs to `experimental_prepareStep`.

## Related Issues

#4954 #3944 #5478
1 parent 52f30f8 commit a76a62b
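
For orientation, here is a minimal sketch of how the new option can be used. This is not code from the PR: the model ids, the `search` tool, and the prompt are made up for illustration. The callback receives `model`, `maxSteps`, `stepNumber`, and `steps`, and can return per-step overrides for `model`, `toolChoice`, and `experimental_activeTools`:

```tsx
import { generateText, tool } from 'ai';
import { openai } from '@ai-sdk/openai';
import { z } from 'zod';

// hypothetical tool, defined only so the sketch is self-contained:
const searchTool = tool({
  description: 'Search the knowledge base',
  parameters: z.object({ query: z.string() }),
  execute: async ({ query }) => `results for ${query}`,
});

const result = await generateText({
  model: openai('gpt-4o-mini'), // hypothetical default model
  tools: { search: searchTool },
  maxSteps: 5,
  prompt: 'Answer the question, searching the knowledge base first.',
  experimental_prepareStep: async ({ model, stepNumber, maxSteps, steps }) => {
    // step 0: force the search tool and hide all other tools
    if (stepNumber === 0) {
      return {
        toolChoice: { type: 'tool', toolName: 'search' },
        experimental_activeTools: ['search'],
      };
    }

    // later steps: once a tool result exists, switch to a stronger model
    if (steps.some(step => step.toolResults.length > 0)) {
      return { model: openai('gpt-4o') };
    }

    // returning nothing keeps the settings passed to generateText
  },
});

console.log(result.text);
```

Returning nothing from the callback leaves the settings that were passed to `generateText` unchanged for that step.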

File tree

5 files changed: +639 -12 lines changed

.changeset/ten-students-yell.md

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
---
'ai': patch
---

feat (ai): add experimental prepareStep callback to generateText

content/docs/03-ai-sdk-core/15-tools-and-tool-calling.mdx

Lines changed: 40 additions & 0 deletions

@@ -142,6 +142,46 @@ const result = await generateText({
});
```

### `experimental_prepareStep` callback

<Note type="warning">
  The `experimental_prepareStep` callback is experimental and may change in the
  future. It is only available in the `generateText` function.
</Note>

The `experimental_prepareStep` callback is called before a step is started.

It is called with the following parameters:

- `model`: The model that was passed into `generateText`.
- `maxSteps`: The maximum number of steps that was passed into `generateText`.
- `stepNumber`: The number of the step that is being executed.
- `steps`: The steps that have been executed so far.

You can use it to provide different settings for a step.

```tsx highlight="5-7"
import { generateText } from 'ai';

const result = await generateText({
  // ...
  experimental_prepareStep: async ({ model, stepNumber, maxSteps, steps }) => {
    if (stepNumber === 0) {
      return {
        // use a different model for this step:
        model: modelForThisParticularStep,
        // force a tool choice for this step:
        toolChoice: { type: 'tool', toolName: 'tool1' },
        // limit the tools that are available for this step:
        experimental_activeTools: ['tool1'],
      };
    }

    // when nothing is returned, the default settings are used
  },
});
```
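
The example above pins settings for the first step. Because the callback also receives `maxSteps`, the same mechanism can turn tool calling off for the final step so the model has to answer in plain text. This is a sketch rather than part of the diff; it assumes the string `toolChoice` values accepted by `generateText` ('auto', 'none', 'required') are also accepted when returned from `experimental_prepareStep`:

```tsx
import { generateText } from 'ai';

const result = await generateText({
  // ... model, tools, and prompt as in the example above ...
  maxSteps: 3,
  experimental_prepareStep: async ({ stepNumber, maxSteps }) => {
    // last allowed step: disable tool calls so the model must produce text
    // (assumption: the 'none' tool choice works here as it does on generateText)
    if (stepNumber === maxSteps - 1) {
      return { toolChoice: 'none' };
    }
    // other steps keep the default settings
  },
});
```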
## Response Messages

Adding the generated assistant and tool messages to your conversation history is a common task,

packages/ai/core/generate-text/__snapshots__/generate-text.test.ts.snap

Lines changed: 333 additions & 0 deletions

@@ -333,6 +333,339 @@ exports[`options.maxSteps > 2 steps: initial, tool-result > result.steps should
]
`;

exports[`options.maxSteps > 2 steps: initial, tool-result with prepareStep > onStepFinish should be called for each step 1`] = `
[
  {
    "files": [],
    "finishReason": "tool-calls",
    "isContinued": false,
    "providerMetadata": undefined,
    "reasoning": [],
    "reasoningText": undefined,
    "request": {},
    "response": {
      "body": undefined,
      "headers": undefined,
      "id": "test-id-1-from-model",
      "messages": [
        {
          "content": [
            {
              "args": {
                "value": "value",
              },
              "toolCallId": "call-1",
              "toolName": "tool1",
              "type": "tool-call",
            },
          ],
          "id": "msg-0",
          "role": "assistant",
        },
        {
          "content": [
            {
              "result": "result1",
              "toolCallId": "call-1",
              "toolName": "tool1",
              "type": "tool-result",
            },
          ],
          "id": "msg-1",
          "role": "tool",
        },
      ],
      "modelId": "test-response-model-id",
      "timestamp": 1970-01-01T00:00:00.000Z,
    },
    "sources": [],
    "stepType": "initial",
    "text": "",
    "toolCalls": [
      {
        "args": {
          "value": "value",
        },
        "toolCallId": "call-1",
        "toolName": "tool1",
        "type": "tool-call",
      },
    ],
    "toolResults": [
      {
        "args": {
          "value": "value",
        },
        "result": "result1",
        "toolCallId": "call-1",
        "toolName": "tool1",
        "type": "tool-result",
      },
    ],
    "usage": {
      "completionTokens": 5,
      "promptTokens": 10,
      "totalTokens": 15,
    },
    "warnings": [],
  },
  {
    "files": [],
    "finishReason": "stop",
    "isContinued": false,
    "providerMetadata": undefined,
    "reasoning": [],
    "reasoningText": undefined,
    "request": {},
    "response": {
      "body": undefined,
      "headers": {
        "custom-response-header": "response-header-value",
      },
      "id": "test-id-2-from-model",
      "messages": [
        {
          "content": [
            {
              "args": {
                "value": "value",
              },
              "toolCallId": "call-1",
              "toolName": "tool1",
              "type": "tool-call",
            },
          ],
          "id": "msg-0",
          "role": "assistant",
        },
        {
          "content": [
            {
              "result": "result1",
              "toolCallId": "call-1",
              "toolName": "tool1",
              "type": "tool-result",
            },
          ],
          "id": "msg-1",
          "role": "tool",
        },
        {
          "content": [
            {
              "text": "Hello, world!",
              "type": "text",
            },
          ],
          "id": "msg-2",
          "role": "assistant",
        },
      ],
      "modelId": "test-response-model-id",
      "timestamp": 1970-01-01T00:00:10.000Z,
    },
    "sources": [],
    "stepType": "tool-result",
    "text": "Hello, world!",
    "toolCalls": [],
    "toolResults": [],
    "usage": {
      "completionTokens": 20,
      "promptTokens": 10,
      "totalTokens": 30,
    },
    "warnings": [],
  },
]
`;

exports[`options.maxSteps > 2 steps: initial, tool-result with prepareStep > result.response.messages should contain response messages from all steps 1`] = `
[
  {
    "content": [
      {
        "args": {
          "value": "value",
        },
        "toolCallId": "call-1",
        "toolName": "tool1",
        "type": "tool-call",
      },
    ],
    "id": "msg-0",
    "role": "assistant",
  },
  {
    "content": [
      {
        "result": "result1",
        "toolCallId": "call-1",
        "toolName": "tool1",
        "type": "tool-result",
      },
    ],
    "id": "msg-1",
    "role": "tool",
  },
  {
    "content": [
      {
        "text": "Hello, world!",
        "type": "text",
      },
    ],
    "id": "msg-2",
    "role": "assistant",
  },
]
`;

exports[`options.maxSteps > 2 steps: initial, tool-result with prepareStep > result.steps should contain all steps 1`] = `
[
  {
    "files": [],
    "finishReason": "tool-calls",
    "isContinued": false,
    "providerMetadata": undefined,
    "reasoning": [],
    "reasoningText": undefined,
    "request": {},
    "response": {
      "body": undefined,
      "headers": undefined,
      "id": "test-id-1-from-model",
      "messages": [
        {
          "content": [
            {
              "args": {
                "value": "value",
              },
              "toolCallId": "call-1",
              "toolName": "tool1",
              "type": "tool-call",
            },
          ],
          "id": "msg-0",
          "role": "assistant",
        },
        {
          "content": [
            {
              "result": "result1",
              "toolCallId": "call-1",
              "toolName": "tool1",
              "type": "tool-result",
            },
          ],
          "id": "msg-1",
          "role": "tool",
        },
      ],
      "modelId": "test-response-model-id",
      "timestamp": 1970-01-01T00:00:00.000Z,
    },
    "sources": [],
    "stepType": "initial",
    "text": "",
    "toolCalls": [
      {
        "args": {
          "value": "value",
        },
        "toolCallId": "call-1",
        "toolName": "tool1",
        "type": "tool-call",
      },
    ],
    "toolResults": [
      {
        "args": {
          "value": "value",
        },
        "result": "result1",
        "toolCallId": "call-1",
        "toolName": "tool1",
        "type": "tool-result",
      },
    ],
    "usage": {
      "completionTokens": 5,
      "promptTokens": 10,
      "totalTokens": 15,
    },
    "warnings": [],
  },
  {
    "files": [],
    "finishReason": "stop",
    "isContinued": false,
    "providerMetadata": undefined,
    "reasoning": [],
    "reasoningText": undefined,
    "request": {},
    "response": {
      "body": undefined,
      "headers": {
        "custom-response-header": "response-header-value",
      },
      "id": "test-id-2-from-model",
      "messages": [
        {
          "content": [
            {
              "args": {
                "value": "value",
              },
              "toolCallId": "call-1",
              "toolName": "tool1",
              "type": "tool-call",
            },
          ],
          "id": "msg-0",
          "role": "assistant",
        },
        {
          "content": [
            {
              "result": "result1",
              "toolCallId": "call-1",
              "toolName": "tool1",
              "type": "tool-result",
            },
          ],
          "id": "msg-1",
          "role": "tool",
        },
        {
          "content": [
            {
              "text": "Hello, world!",
              "type": "text",
            },
          ],
          "id": "msg-2",
          "role": "assistant",
        },
      ],
      "modelId": "test-response-model-id",
      "timestamp": 1970-01-01T00:00:10.000Z,
    },
    "sources": [],
    "stepType": "tool-result",
    "text": "Hello, world!",
    "toolCalls": [],
    "toolResults": [],
    "usage": {
      "completionTokens": 20,
      "promptTokens": 10,
      "totalTokens": 30,
    },
    "warnings": [],
  },
]
`;

exports[`options.maxSteps > 4 steps: initial, continue, continue, continue > onStepFinish should be called for each step 1`] = `
[
  {