Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion go/ai/background_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ func NewBackgroundModel(name string, opts *BackgroundModelOptions, startFn Start
simulateSystemPrompt(&opts.ModelOptions, nil),
augmentWithContext(&opts.ModelOptions, nil),
validateSupport(name, &opts.ModelOptions),
addAutomaticTelemetry(),
}
fn := core.ChainMiddleware(mws...)(backgroundModelToModelFn(startFn))

Expand Down
58 changes: 31 additions & 27 deletions go/ai/model_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,36 @@ type DownloadMediaOptions struct {
Filter func(part *Part) bool // Filter to apply to parts that are media URLs.
}

// CalculateInputOutputUsage fills in the character and media counts on
// resp.Usage based on the contents of the request and response messages.
// Counts that are already non-zero (e.g. reported directly by the model
// provider) are left untouched; only zero-valued fields are computed.
//
// A nil resp is a no-op. A nil req is tolerated — callers such as
// background-model check functions may no longer have the original
// request — in which case only the output-side counts are calculated.
func CalculateInputOutputUsage(req *ModelRequest, resp *ModelResponse) {
	// Nothing to annotate without a response.
	if resp == nil {
		return
	}
	if resp.Usage == nil {
		resp.Usage = &GenerationUsage{}
	}

	// Output-side counts depend only on the response.
	if resp.Usage.OutputCharacters == 0 {
		resp.Usage.OutputCharacters = countOutputCharacters(resp)
	}
	if resp.Usage.OutputImages == 0 {
		resp.Usage.OutputImages = countOutputParts(resp, func(part *Part) bool { return part.IsImage() })
	}
	if resp.Usage.OutputVideos == 0 {
		resp.Usage.OutputVideos = countOutputParts(resp, func(part *Part) bool { return part.IsVideo() })
	}
	if resp.Usage.OutputAudioFiles == 0 {
		resp.Usage.OutputAudioFiles = countOutputParts(resp, func(part *Part) bool { return part.IsAudio() })
	}

	// Input-side counts need the original request; skip them when it is
	// unavailable rather than relying on the count helpers being nil-safe.
	if req == nil {
		return
	}
	if resp.Usage.InputCharacters == 0 {
		resp.Usage.InputCharacters = countInputCharacters(req)
	}
	if resp.Usage.InputImages == 0 {
		resp.Usage.InputImages = countInputParts(req, func(part *Part) bool { return part.IsImage() })
	}
	if resp.Usage.InputVideos == 0 {
		resp.Usage.InputVideos = countInputParts(req, func(part *Part) bool { return part.IsVideo() })
	}
	if resp.Usage.InputAudioFiles == 0 {
		resp.Usage.InputAudioFiles = countInputParts(req, func(part *Part) bool { return part.IsAudio() })
	}
}

// addAutomaticTelemetry creates middleware that automatically measures latency and calculates character and media counts.
func addAutomaticTelemetry() ModelMiddleware {
return func(fn ModelFunc) ModelFunc {
Expand All @@ -66,33 +96,7 @@ func addAutomaticTelemetry() ModelMiddleware {
resp.LatencyMs = latencyMs
}

if resp.Usage == nil {
resp.Usage = &GenerationUsage{}
}
if resp.Usage.InputCharacters == 0 {
resp.Usage.InputCharacters = countInputCharacters(req)
}
if resp.Usage.OutputCharacters == 0 {
resp.Usage.OutputCharacters = countOutputCharacters(resp)
}
if resp.Usage.InputImages == 0 {
resp.Usage.InputImages = countInputParts(req, func(part *Part) bool { return part.IsImage() })
}
if resp.Usage.OutputImages == 0 {
resp.Usage.OutputImages = countOutputParts(resp, func(part *Part) bool { return part.IsImage() })
}
if resp.Usage.InputVideos == 0 {
resp.Usage.InputVideos = countInputParts(req, func(part *Part) bool { return part.IsVideo() })
}
if resp.Usage.OutputVideos == 0 {
resp.Usage.OutputVideos = countOutputParts(resp, func(part *Part) bool { return part.IsVideo() })
}
if resp.Usage.InputAudioFiles == 0 {
resp.Usage.InputAudioFiles = countInputParts(req, func(part *Part) bool { return part.IsAudio() })
}
if resp.Usage.OutputAudioFiles == 0 {
resp.Usage.OutputAudioFiles = countOutputParts(resp, func(part *Part) bool { return part.IsAudio() })
}
CalculateInputOutputUsage(req, resp)

return resp, nil
}
Expand Down
42 changes: 40 additions & 2 deletions go/plugins/googlegenai/veo.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package googlegenai
import (
"context"
"fmt"
"time"

"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/core"
Expand Down Expand Up @@ -55,7 +56,15 @@ func newVeoModel(
return nil, fmt.Errorf("veo video generation failed: %w", err)
}

return fromVeoOperation(operation), nil
op := fromVeoOperation(operation)

if op.Metadata == nil {
op.Metadata = make(map[string]any)
}
op.Metadata["inputRequest"] = req
op.Metadata["startTime"] = time.Now()

return op, nil
}

checkFunc := func(ctx context.Context, op *ai.ModelOperation) (*ai.ModelOperation, error) {
Expand All @@ -64,7 +73,36 @@ func newVeoModel(
return nil, fmt.Errorf("veo operation status check failed: %w", err)
}

return fromVeoOperation(veoOp), nil
updatedOp := fromVeoOperation(veoOp)

// Restore metadata from the original operation
if op.Metadata != nil {
if updatedOp.Metadata == nil {
updatedOp.Metadata = make(map[string]any)
}
for k, v := range op.Metadata {
updatedOp.Metadata[k] = v
}
}

// Add telemetry metrics when operation completes
if updatedOp.Done && updatedOp.Output != nil {
if req, ok := updatedOp.Metadata["inputRequest"].(*ai.ModelRequest); ok {
ai.CalculateInputOutputUsage(req, updatedOp.Output)
} else {
ai.CalculateInputOutputUsage(nil, updatedOp.Output)
}

// Calculate latency if startTime is available
if startTime, ok := updatedOp.Metadata["startTime"].(time.Time); ok {
latencyMs := float64(time.Since(startTime).Nanoseconds()) / 1e6
if updatedOp.Output.LatencyMs == 0 {
updatedOp.Output.LatencyMs = latencyMs
}
}
}

return updatedOp, nil
}

return ai.NewBackgroundModel(name, &ai.BackgroundModelOptions{ModelOptions: info}, startFunc, checkFunc)
Expand Down
Loading