Skip to content

Commit 75ff49b

Browse files
authored
fix(go/plugins/compat_oai): prevent message duplication when using media parts (#3773)
1 parent 72af9c3 commit 75ff49b

File tree

2 files changed

+66
-3
lines changed

2 files changed

+66
-3
lines changed

‎go/plugins/compat_oai/generate.go‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
7979
case ai.RoleSystem:
8080
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
8181
case ai.RoleModel:
82-
8382
am := openai.ChatCompletionAssistantMessageParam{}
8483
am.Content.OfString = param.NewOpt(content)
8584
toolCalls, err := convertToolCalls(msg.Content)
@@ -113,10 +112,11 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
113112
oaiMessages = append(oaiMessages, tm)
114113
}
115114
case ai.RoleUser:
116-
oaiMessages = append(oaiMessages, openai.UserMessage(content))
117-
118115
parts := []openai.ChatCompletionContentPartUnionParam{}
119116
for _, p := range msg.Content {
117+
if p.IsText() {
118+
parts = append(parts, openai.TextContentPart(p.Text))
119+
}
120120
if p.IsMedia() {
121121
part := openai.ImageContentPart(
122122
openai.ChatCompletionContentPartImageImageURLParam{

‎go/plugins/compat_oai/openai/openai_live_test.go‎

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ package openai_test
1616

1717
import (
1818
"context"
19+
"encoding/base64"
20+
"io"
1921
"math"
22+
"net/http"
2023
"os"
2124
"strings"
2225
"testing"
@@ -279,4 +282,64 @@ func TestPlugin(t *testing.T) {
279282
t.Fatalf("expecting 2 user messages, got: %d", userMsgCount)
280283
}
281284
})
285+
t.Run("image", func(t *testing.T) {
286+
image, err := fetchImgAsBase64()
287+
if err != nil {
288+
t.Fatalf("failed to fetch image: %v", err)
289+
}
290+
resp, err := genkit.Generate(ctx, g,
291+
ai.WithModelName("openai/gpt-4.1-nano"),
292+
ai.WithMessages(
293+
ai.NewUserMessage(
294+
ai.NewMediaPart("image/jpeg", "data:image/jpeg;base64,"+image),
295+
ai.NewTextPart("What's in the image?."),
296+
),
297+
),
298+
)
299+
if err != nil {
300+
t.Fatalf("failed to generate: %v", err)
301+
}
302+
if !strings.Contains(resp.Text(), "cat") {
303+
t.Fatalf("image detection failed, want: cat, got: %s", resp.Text())
304+
}
305+
mediaMessages := 0
306+
textMessages := 0
307+
for _, m := range resp.Request.Messages {
308+
for _, p := range m.Content {
309+
if p.IsText() {
310+
textMessages += 1
311+
}
312+
if p.IsMedia() {
313+
mediaMessages += 1
314+
}
315+
}
316+
}
317+
if mediaMessages > 1 {
318+
t.Fatalf("unwanted media message, want: 1, got: %d", mediaMessages)
319+
}
320+
if textMessages > 1 {
321+
t.Fatalf("unwanted text message, want: 1, got %d", textMessages)
322+
}
323+
})
324+
}
325+
326+
// imageStatusError reports that the test image download was answered with a
// status other than 200 OK.
type imageStatusError struct {
	url    string // URL that was requested
	status string // status line reported by the server, e.g. "404 Not Found"
}

// Error implements the error interface.
func (e *imageStatusError) Error() string {
	return "fetching " + e.url + ": unexpected status " + e.status
}

// fetchImgAsBase64 downloads the test image over HTTP and returns its bytes
// encoded with standard base64.
//
// It returns a non-nil error when the request fails, when the server replies
// with a status other than 200 OK, or when the response body cannot be read.
// On error the returned string is empty.
func fetchImgAsBase64() (string, error) {
	// CC0 license image
	imgURL := "https://pd.w.org/2025/07/896686fbbcd9990c9.84605288-2048x1365.jpg"
	resp, err := http.Get(imgURL)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// http.Get reports non-2xx responses with a nil error, so err cannot be
	// reused here: returning it would hand the caller an empty image together
	// with a nil error.
	if resp.StatusCode != http.StatusOK {
		return "", &imageStatusError{url: imgURL, status: resp.Status}
	}

	imageBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}

	return base64.StdEncoding.EncodeToString(imageBytes), nil
}

0 commit comments

Comments
 (0)