Skip to content

Commit 667692d

Browse files
authored
Merge pull request #970 from dgageot/better-evals
Better evals
2 parents 0ec22dc + b33f1c8 commit 667692d

File tree

1 file changed

+62
-27
lines changed

1 file changed

+62
-27
lines changed

‎pkg/evaluation/evaluation.go‎

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"os"
77
"path/filepath"
88

9+
"golang.org/x/sync/errgroup"
10+
911
"github.com/docker/cagent/pkg/chat"
1012
"github.com/docker/cagent/pkg/config"
1113
"github.com/docker/cagent/pkg/runtime"
@@ -20,8 +22,9 @@ type Score struct {
2022
}
2123

2224
type Result struct {
23-
Score Score
24-
EvalFile string
25+
FirstMessage string
26+
Score Score
27+
EvalFile string
2528
}
2629

2730
type Printer interface {
@@ -40,14 +43,51 @@ func Evaluate(ctx context.Context, out Printer, agentFilename, evalsDir string,
4043
}
4144

4245
_, err = runEvaluations(ctx, agents, evalsDir, func(result Result) {
46+
out.Printf("---\n")
47+
out.Printf("First message: %s\n", result.FirstMessage)
4348
out.Printf("Eval file: %s\n", result.EvalFile)
4449
out.Printf("Tool trajectory score: %f\n", result.Score.ToolTrajectoryScore)
4550
out.Printf("Rouge-1 score: %f\n", result.Score.Rouge1Score)
51+
out.Printf("\n")
4652
})
4753
return err
4854
}
4955

5056
func runEvaluations(ctx context.Context, t *team.Team, evalsDir string, onResult func(Result)) ([]Result, error) {
57+
evals, err := loadEvalSessions(ctx, evalsDir)
58+
if err != nil {
59+
return nil, err
60+
}
61+
62+
// Each eval gets a channel; results print in order as they complete.
63+
chans := make([]chan Result, len(evals))
64+
errs, ctx := errgroup.WithContext(ctx)
65+
errs.SetLimit(4)
66+
67+
for i := range evals {
68+
chans[i] = make(chan Result, 1)
69+
errs.Go(func() error {
70+
result, err := runSingleEvaluation(ctx, t, &evals[i])
71+
if err == nil {
72+
chans[i] <- result
73+
}
74+
return err
75+
})
76+
}
77+
78+
var results []Result
79+
for _, ch := range chans {
80+
if result, ok := <-ch; ok {
81+
results = append(results, result)
82+
onResult(result)
83+
}
84+
}
85+
86+
return results, errs.Wait()
87+
}
88+
89+
// loadEvalSessions reads all evaluation session files from the given directory.
90+
func loadEvalSessions(ctx context.Context, evalsDir string) ([]session.Session, error) {
5191
evalFiles, err := os.ReadDir(evalsDir)
5292
if err != nil {
5393
return nil, err
@@ -59,46 +99,41 @@ func runEvaluations(ctx context.Context, t *team.Team, evalsDir string, onResult
5999
return nil, ctx.Err()
60100
}
61101

62-
evalFile, err := os.ReadFile(filepath.Join(evalsDir, evalFile.Name()))
102+
data, err := os.ReadFile(filepath.Join(evalsDir, evalFile.Name()))
63103
if err != nil {
64104
return nil, err
65105
}
66106

67107
var sess session.Session
68-
if err := json.Unmarshal(evalFile, &sess); err != nil {
108+
if err := json.Unmarshal(data, &sess); err != nil {
69109
return nil, err
70110
}
71111

72112
evals = append(evals, sess)
73113
}
74114

75-
var results []Result
76-
for i := range evals {
77-
if ctx.Err() != nil {
78-
return nil, ctx.Err()
79-
}
80-
81-
rt, err := runtime.New(t)
82-
if err != nil {
83-
return nil, err
84-
}
85-
86-
actualMessages, err := runLoop(ctx, rt, &evals[i])
87-
if err != nil {
88-
return nil, err
89-
}
115+
return evals, nil
116+
}
90117

91-
score := score(evals[i].GetAllMessages(), actualMessages)
92-
result := Result{
93-
Score: score,
94-
EvalFile: evals[i].ID,
95-
}
96-
onResult(result)
118+
// runSingleEvaluation runs a single evaluation and returns the result.
119+
func runSingleEvaluation(ctx context.Context, t *team.Team, eval *session.Session) (Result, error) {
120+
rt, err := runtime.New(t)
121+
if err != nil {
122+
return Result{}, err
123+
}
97124

98-
results = append(results, result)
125+
actualMessages, err := runLoop(ctx, rt, eval)
126+
if err != nil {
127+
return Result{}, err
99128
}
100129

101-
return results, nil
130+
evalMessages := eval.GetAllMessages()
131+
132+
return Result{
133+
FirstMessage: evalMessages[0].Message.Content,
134+
Score: score(evalMessages, actualMessages),
135+
EvalFile: eval.ID,
136+
}, nil
102137
}
103138

104139
func runLoop(ctx context.Context, rt *runtime.LocalRuntime, eval *session.Session) ([]session.Message, error) {

0 commit comments

Comments
 (0)