Skip to content

Commit 5d7f09b

Browse files
committed
Add LateFunctionBinding declaration and fix constant folding
Adds the declaration for LateFunctionBindings, which can be used to indicate to the Runtime and the Optimizers that the function will be bound at runtime through the Activation. This lets the constant folding optimizer know that the function potentially has side effects and cannot be folded. Without this the optimization will fail with an error for late bound functions where all arguments are constants. The implementation for late bound functions will be added in a subsequent commit.
1 parent f5ea07b commit 5d7f09b

File tree

5 files changed

+250
-3
lines changed

5 files changed

+250
-3
lines changed

‎cel/decls.go‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,12 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
330330
return decls.FunctionBinding(binding)
331331
}
332332

333+
// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
334+
// This is useful for functions which have side-effects or are not deterministically computable.
335+
func LateFunctionBinding() OverloadOpt {
336+
return decls.LateFunctionBinding()
337+
}
338+
333339
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
334340
//
335341
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.

‎cel/folding.go‎

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
6868
// Walk the list of foldable expression and continue to fold until there are no more folds left.
6969
// All of the fold candidates returned by the constantExprMatcher should succeed unless there's
7070
// a logic bug with the selection of expressions.
71-
foldableExprs := ast.MatchDescendants(root, constantExprMatcher)
71+
constantExprMatcherCapture := func(e ast.NavigableExpr) bool { return constantExprMatcher(ctx, a, e) }
72+
foldableExprs := ast.MatchDescendants(root, constantExprMatcherCapture)
7273
foldCount := 0
7374
for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
7475
for _, fold := range foldableExprs {
@@ -77,6 +78,10 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
7778
if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) {
7879
continue
7980
}
81+
// Late-bound function calls cannot be folded.
82+
if fold.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, fold) {
83+
continue
84+
}
8085
// Otherwise, assume all context is needed to evaluate the expression.
8186
err := tryFold(ctx, a, fold)
8287
if err != nil {
@@ -85,7 +90,7 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
8590
}
8691
}
8792
foldCount++
88-
foldableExprs = ast.MatchDescendants(root, constantExprMatcher)
93+
foldableExprs = ast.MatchDescendants(root, constantExprMatcherCapture)
8994
}
9095
// Once all of the constants have been folded, try to run through the remaining comprehensions
9196
// one last time. In this case, there's no guarantee they'll run, so we only update the
@@ -139,6 +144,27 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
139144
return nil
140145
}
141146

147+
func isLateBoundFunctionCall(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) bool {
148+
call := expr.AsCall()
149+
overloadIds := a.GetOverloadIDs(expr.ID())
150+
for _, f := range ctx.Functions() {
151+
if f.Name() != call.FunctionName() {
152+
continue
153+
}
154+
155+
for _, o := range f.OverloadDecls() {
156+
for _, id := range overloadIds {
157+
if id == o.ID() {
158+
if o.HasLateBinding() {
159+
return true
160+
}
161+
}
162+
}
163+
}
164+
}
165+
return false
166+
}
167+
142168
// maybePruneBranches inspects the non-strict call expression to determine whether
143169
// a branch can be removed. Evaluation will naturally prune logical and / or calls,
144170
// but conditional will not be pruned cleanly, so this is one small area where the
@@ -455,7 +481,7 @@ func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) {
455481
// Only comprehensions which are not nested are included as possible constant folds, and only
456482
// if all variables referenced in the comprehension stack exist are only iteration or
457483
// accumulation variables.
458-
func constantExprMatcher(e ast.NavigableExpr) bool {
484+
func constantExprMatcher(ctx *OptimizerContext, a *ast.AST, e ast.NavigableExpr) bool {
459485
switch e.Kind() {
460486
case ast.CallKind:
461487
return constantCallMatcher(e)
@@ -477,6 +503,10 @@ func constantExprMatcher(e ast.NavigableExpr) bool {
477503
if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] {
478504
constantExprs = false
479505
}
506+
// Late-bound function calls cannot be folded.
507+
if e.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, e) {
508+
constantExprs = false
509+
}
480510
})
481511
ast.PreOrderVisit(e, visitor)
482512
return constantExprs

‎cel/folding_test.go‎

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ package cel
1717
import (
1818
"reflect"
1919
"sort"
20+
"strings"
2021
"testing"
2122

2223
"google.golang.org/protobuf/encoding/prototext"
2324
"google.golang.org/protobuf/proto"
2425

2526
"github.com/google/cel-go/common/ast"
27+
"github.com/google/cel-go/common/types/ref"
2628

2729
proto3pb "github.com/google/cel-go/test/proto3pb"
2830
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@@ -313,6 +315,89 @@ func TestConstantFoldingOptimizer(t *testing.T) {
313315
}
314316
}
315317

318+
func TestConstantFoldingCallsWithSideEffects(t *testing.T) {
319+
tests := []struct {
320+
expr string
321+
folded string
322+
error string
323+
}{
324+
{
325+
expr: `noSideEffect(3)`,
326+
folded: `3`,
327+
},
328+
{
329+
expr: `withSideEffect(3)`,
330+
folded: `withSideEffect(3)`,
331+
},
332+
{
333+
expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && withSideEffect(i.b) == 1)`,
334+
folded: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && withSideEffect(i.b) == 1)`,
335+
},
336+
{
337+
expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && noSideEffect(i.b) == 2)`,
338+
folded: `true`,
339+
},
340+
{
341+
expr: `noImpl(3)`,
342+
error: `constant-folding evaluation failed: no such overload: noImpl`,
343+
},
344+
}
345+
e, err := NewEnv(
346+
OptionalTypes(),
347+
EnableMacroCallTracking(),
348+
Function("noSideEffect",
349+
Overload("noSideEffect_int_int",
350+
[]*Type{IntType},
351+
IntType, FunctionBinding(func(args ...ref.Val) ref.Val {
352+
return args[0]
353+
}))),
354+
Function("withSideEffect",
355+
Overload("withSideEffect_int_int",
356+
[]*Type{IntType},
357+
IntType, LateFunctionBinding())),
358+
Function("noImpl",
359+
Overload("noImpl_int_int",
360+
[]*Type{IntType},
361+
IntType)),
362+
)
363+
if err != nil {
364+
t.Fatalf("NewEnv() failed: %v", err)
365+
}
366+
for _, tst := range tests {
367+
tc := tst
368+
t.Run(tc.expr, func(t *testing.T) {
369+
checked, iss := e.Compile(tc.expr)
370+
if iss.Err() != nil {
371+
t.Fatalf("Compile() failed: %v", iss.Err())
372+
}
373+
folder, err := NewConstantFoldingOptimizer()
374+
if err != nil {
375+
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
376+
}
377+
opt := NewStaticOptimizer(folder)
378+
optimized, iss := opt.Optimize(e, checked)
379+
if tc.error != "" {
380+
if iss.Err() == nil {
381+
t.Errorf("got nil, wanted error containing %q", tc.error)
382+
} else if !strings.Contains(iss.Err().Error(), tc.error) {
383+
t.Errorf("got %q, wanted error containing %q", iss.Err().Error(), tc.error)
384+
}
385+
return
386+
}
387+
if iss.Err() != nil {
388+
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
389+
}
390+
folded, err := AstToString(optimized)
391+
if err != nil {
392+
t.Fatalf("AstToString() failed: %v", err)
393+
}
394+
if folded != tc.folded {
395+
t.Errorf("got %q, wanted %q", folded, tc.folded)
396+
}
397+
})
398+
}
399+
}
400+
316401
func TestConstantFoldingOptimizerMacroElimination(t *testing.T) {
317402
tests := []struct {
318403
expr string

‎common/decls/decls.go‎

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,9 @@ func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
232232
}
233233
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.Name(), oID)
234234
}
235+
if overload.HasLateBinding() != o.HasLateBinding() {
236+
return fmt.Errorf("overload with late binding cannot be added to function %s: cannot mix late and non-late bindings", f.Name())
237+
}
235238
}
236239
f.overloadOrdinals = append(f.overloadOrdinals, overload.ID())
237240
f.overloads[overload.ID()] = overload
@@ -257,8 +260,10 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
257260
}
258261
overloads := []*functions.Overload{}
259262
nonStrict := false
263+
hasLateBinding := false
260264
for _, oID := range f.overloadOrdinals {
261265
o := f.overloads[oID]
266+
hasLateBinding = hasLateBinding || o.HasLateBinding()
262267
if o.hasBinding() {
263268
overload := &functions.Overload{
264269
Operator: o.ID(),
@@ -276,6 +281,9 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
276281
if len(overloads) != 0 {
277282
return nil, fmt.Errorf("singleton function incompatible with specialized overloads: %s", f.Name())
278283
}
284+
if hasLateBinding {
285+
return nil, fmt.Errorf("singleton function incompatible with late bindings: %s", f.Name())
286+
}
279287
overloads = []*functions.Overload{
280288
{
281289
Operator: f.Name(),
@@ -516,6 +524,9 @@ type OverloadDecl struct {
516524
argTypes []*types.Type
517525
resultType *types.Type
518526
isMemberFunction bool
527+
// hasLateBinding indicates that the function has a binding which is not known at compile time.
528+
// This is useful for functions which have side-effects or are not deterministically computable.
529+
hasLateBinding bool
519530
// nonStrict indicates that the function will accept error and unknown arguments as inputs.
520531
nonStrict bool
521532
// operandTrait indicates whether the member argument should have a specific type-trait.
@@ -571,6 +582,14 @@ func (o *OverloadDecl) IsNonStrict() bool {
571582
return o.nonStrict
572583
}
573584

585+
// HasLateBinding returns whether the overload has a binding which is not known at compile time.
586+
func (o *OverloadDecl) HasLateBinding() bool {
587+
if o == nil {
588+
return false
589+
}
590+
return o.hasLateBinding
591+
}
592+
574593
// OperandTrait returns the trait mask of the first operand to the overload call, e.g.
575594
// `traits.Indexer`
576595
func (o *OverloadDecl) OperandTrait() int {
@@ -739,6 +758,9 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
739758
if len(o.ArgTypes()) != 1 {
740759
return nil, fmt.Errorf("unary function bound to non-unary overload: %s", o.ID())
741760
}
761+
if o.hasLateBinding {
762+
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
763+
}
742764
o.unaryOp = binding
743765
return o, nil
744766
}
@@ -754,6 +776,9 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
754776
if len(o.ArgTypes()) != 2 {
755777
return nil, fmt.Errorf("binary function bound to non-binary overload: %s", o.ID())
756778
}
779+
if o.hasLateBinding {
780+
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
781+
}
757782
o.binaryOp = binding
758783
return o, nil
759784
}
@@ -766,11 +791,26 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
766791
if o.hasBinding() {
767792
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
768793
}
794+
if o.hasLateBinding {
795+
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
796+
}
769797
o.functionOp = binding
770798
return o, nil
771799
}
772800
}
773801

802+
// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
803+
// This is useful for functions which have side-effects or are not deterministically computable.
804+
func LateFunctionBinding() OverloadOpt {
805+
return func(o *OverloadDecl) (*OverloadDecl, error) {
806+
if o.hasBinding() {
807+
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
808+
}
809+
o.hasLateBinding = true
810+
return o, nil
811+
}
812+
}
813+
774814
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
775815
//
776816
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.

‎common/decls/decls_test.go‎

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,24 @@ func TestSingletonOverloadCollision(t *testing.T) {
473473
}
474474
}
475475

476+
func TestSingletonOverloadLateBindingCollision(t *testing.T) {
477+
fn, err := NewFunction("id",
478+
Overload("id_any", []*types.Type{types.AnyType}, types.AnyType,
479+
LateFunctionBinding(),
480+
),
481+
SingletonUnaryBinding(func(arg ref.Val) ref.Val {
482+
return arg
483+
}),
484+
)
485+
if err != nil {
486+
t.Fatalf("NewFunction() failed: %v", err)
487+
}
488+
_, err = fn.Bindings()
489+
if err == nil || !strings.Contains(err.Error(), "incompatible with late bindings") {
490+
t.Errorf("NewFunction() got %v, wanted incompatible with late bindings", err)
491+
}
492+
}
493+
476494
func TestSingletonUnaryBindingRedefinition(t *testing.T) {
477495
_, err := NewFunction("id",
478496
Overload("id_any", []*types.Type{types.AnyType}, types.AnyType),
@@ -592,6 +610,74 @@ func TestOverloadFunctionBindingRedefinition(t *testing.T) {
592610
}
593611
}
594612

613+
func TestOverloadFunctionLateBinding(t *testing.T) {
614+
function, err := NewFunction("id",
615+
Overload("id_bool", []*types.Type{types.BoolType}, types.AnyType, LateFunctionBinding(), LateFunctionBinding()),
616+
)
617+
if err != nil {
618+
t.Fatalf("NewFunction() failed: %v", err)
619+
}
620+
if len(function.OverloadDecls()) != 1 {
621+
t.Fatalf("NewFunction() got %v, wanted 1 overload", function.OverloadDecls())
622+
}
623+
if !function.OverloadDecls()[0].HasLateBinding() {
624+
t.Errorf("overload %v did not have a late binding", function.OverloadDecls()[0])
625+
}
626+
}
627+
628+
func TestOverloadFunctionMixLateAndNonLateBinding(t *testing.T) {
629+
_, err := NewFunction("id",
630+
Overload("id_bool", []*types.Type{types.BoolType}, types.AnyType, LateFunctionBinding()),
631+
Overload("id_int", []*types.Type{types.IntType}, types.AnyType),
632+
)
633+
if err == nil || !strings.Contains(err.Error(), "cannot mix late and non-late bindings") {
634+
t.Errorf("NewCustomEnv() got %v, wanted cannot mix late and non-late bindings", err)
635+
}
636+
}
637+
638+
func TestOverloadFunctionBindingWithLateBinding(t *testing.T) {
639+
_, err := NewFunction("id",
640+
Overload("id_bool", []*types.Type{types.BoolType}, types.AnyType, FunctionBinding(func(args ...ref.Val) ref.Val {
641+
return args[0]
642+
}), LateFunctionBinding()),
643+
)
644+
if err == nil || !strings.Contains(err.Error(), "already has a binding") {
645+
t.Errorf("NewCustomEnv() got %v, wanted already has a binding", err)
646+
}
647+
}
648+
649+
func TestOverloadFunctionLateBindingWithBinding(t *testing.T) {
650+
_, err := NewFunction("id",
651+
Overload("id_bool", []*types.Type{types.BoolType}, types.AnyType, LateFunctionBinding(),
652+
FunctionBinding(func(args ...ref.Val) ref.Val {
653+
return args[0]
654+
})),
655+
)
656+
if err == nil || !strings.Contains(err.Error(), "already has a late binding") {
657+
t.Errorf("NewCustomEnv() got %v, wanted already has a late binding", err)
658+
}
659+
660+
_, err = NewFunction("id",
661+
Overload("id_bool", []*types.Type{types.BoolType}, types.AnyType, LateFunctionBinding(),
662+
UnaryBinding(func(arg ref.Val) ref.Val {
663+
return arg
664+
})),
665+
)
666+
if err == nil || !strings.Contains(err.Error(), "already has a late binding") {
667+
t.Errorf("NewCustomEnv() got %v, wanted already has a late binding", err)
668+
}
669+
670+
_, err = NewFunction("id",
671+
Overload("id_bool", []*types.Type{types.BoolType, types.BoolType}, types.AnyType, LateFunctionBinding(),
672+
BinaryBinding(func(arg1 ref.Val, arg2 ref.Val) ref.Val {
673+
return arg1
674+
})),
675+
)
676+
if err == nil || !strings.Contains(err.Error(), "already has a late binding") {
677+
t.Errorf("NewCustomEnv() got %v, wanted already has a late binding", err)
678+
}
679+
}
680+
595681
func TestOverloadIsNonStrict(t *testing.T) {
596682
fn, err := NewFunction("getOrDefault",
597683
MemberOverload("get",

0 commit comments

Comments
 (0)