Skip to content

Commit e5d292c

Browse files
committed
Add LateFunctionBinding declaration and fix constant folding
Adds the declaration for LateFunctionBindings, which can be used to indicate to the Runtime and the Optimizers that the function will be bound at runtime through the Activation. This lets the constant folding optimizer know that the function potentially has side effects and cannot be folded. Without this the optimization will fail with an error for late bound functions where all arguments are constants. The implementation for late bound functions will be added in a subsequent commit.
1 parent a6fbac9 commit e5d292c

File tree

5 files changed

+251
-3
lines changed

5 files changed

+251
-3
lines changed

‎cel/decls.go‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,12 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
330330
return decls.FunctionBinding(binding)
331331
}
332332

333+
// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
334+
// This is useful for functions which have side-effects or are not deterministically computable.
335+
func LateFunctionBinding() OverloadOpt {
336+
return decls.LateFunctionBinding()
337+
}
338+
333339
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
334340
//
335341
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.

‎cel/folding.go‎

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
6868
// Walk the list of foldable expression and continue to fold until there are no more folds left.
6969
// All of the fold candidates returned by the constantExprMatcher should succeed unless there's
7070
// a logic bug with the selection of expressions.
71-
foldableExprs := ast.MatchDescendants(root, constantExprMatcher)
71+
constantExprMatcherCapture := func(e ast.NavigableExpr) bool { return constantExprMatcher(ctx, a, e) }
72+
foldableExprs := ast.MatchDescendants(root, constantExprMatcherCapture)
7273
foldCount := 0
7374
for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
7475
for _, fold := range foldableExprs {
@@ -77,6 +78,10 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
7778
if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) {
7879
continue
7980
}
81+
// Late-bound function calls cannot be folded.
82+
if fold.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, fold) {
83+
continue
84+
}
8085
// Otherwise, assume all context is needed to evaluate the expression.
8186
err := tryFold(ctx, a, fold)
8287
if err != nil {
@@ -85,7 +90,7 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
8590
}
8691
}
8792
foldCount++
88-
foldableExprs = ast.MatchDescendants(root, constantExprMatcher)
93+
foldableExprs = ast.MatchDescendants(root, constantExprMatcherCapture)
8994
}
9095
// Once all of the constants have been folded, try to run through the remaining comprehensions
9196
// one last time. In this case, there's no guarantee they'll run, so we only update the
@@ -139,6 +144,15 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
139144
return nil
140145
}
141146

147+
func isLateBoundFunctionCall(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) bool {
148+
call := expr.AsCall()
149+
function := ctx.Functions()[call.FunctionName()]
150+
if function == nil {
151+
return false
152+
}
153+
return function.HasLateBinding()
154+
}
155+
142156
// maybePruneBranches inspects the non-strict call expression to determine whether
143157
// a branch can be removed. Evaluation will naturally prune logical and / or calls,
144158
// but conditional will not be pruned cleanly, so this is one small area where the
@@ -455,7 +469,7 @@ func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) {
455469
// Only comprehensions which are not nested are included as possible constant folds, and only
456470
// if all variables referenced in the comprehension stack exist are only iteration or
457471
// accumulation variables.
458-
func constantExprMatcher(e ast.NavigableExpr) bool {
472+
func constantExprMatcher(ctx *OptimizerContext, a *ast.AST, e ast.NavigableExpr) bool {
459473
switch e.Kind() {
460474
case ast.CallKind:
461475
return constantCallMatcher(e)
@@ -477,6 +491,10 @@ func constantExprMatcher(e ast.NavigableExpr) bool {
477491
if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] {
478492
constantExprs = false
479493
}
494+
// Late-bound function calls cannot be folded.
495+
if e.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, e) {
496+
constantExprs = false
497+
}
480498
})
481499
ast.PreOrderVisit(e, visitor)
482500
return constantExprs

‎cel/folding_test.go‎

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ package cel
1717
import (
1818
"reflect"
1919
"sort"
20+
"strings"
2021
"testing"
2122

2223
"google.golang.org/protobuf/encoding/prototext"
2324
"google.golang.org/protobuf/proto"
2425

2526
"github.com/google/cel-go/common/ast"
27+
"github.com/google/cel-go/common/types/ref"
2628

2729
proto3pb "github.com/google/cel-go/test/proto3pb"
2830
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@@ -313,6 +315,89 @@ func TestConstantFoldingOptimizer(t *testing.T) {
313315
}
314316
}
315317

318+
func TestConstantFoldingCallsWithSideEffects(t *testing.T) {
319+
tests := []struct {
320+
expr string
321+
folded string
322+
error string
323+
}{
324+
{
325+
expr: `noSideEffect(3)`,
326+
folded: `3`,
327+
},
328+
{
329+
expr: `withSideEffect(3)`,
330+
folded: `withSideEffect(3)`,
331+
},
332+
{
333+
expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && withSideEffect(i.b) == 1)`,
334+
folded: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && withSideEffect(i.b) == 1)`,
335+
},
336+
{
337+
expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && noSideEffect(i.b) == 2)`,
338+
folded: `true`,
339+
},
340+
{
341+
expr: `noImpl(3)`,
342+
error: `constant-folding evaluation failed: no such overload: noImpl`,
343+
},
344+
}
345+
e, err := NewEnv(
346+
OptionalTypes(),
347+
EnableMacroCallTracking(),
348+
Function("noSideEffect",
349+
Overload("noSideEffect_int_int",
350+
[]*Type{IntType},
351+
IntType, FunctionBinding(func(args ...ref.Val) ref.Val {
352+
return args[0]
353+
}))),
354+
Function("withSideEffect",
355+
Overload("withSideEffect_int_int",
356+
[]*Type{IntType},
357+
IntType, LateFunctionBinding())),
358+
Function("noImpl",
359+
Overload("noImpl_int_int",
360+
[]*Type{IntType},
361+
IntType)),
362+
)
363+
if err != nil {
364+
t.Fatalf("NewEnv() failed: %v", err)
365+
}
366+
for _, tst := range tests {
367+
tc := tst
368+
t.Run(tc.expr, func(t *testing.T) {
369+
checked, iss := e.Compile(tc.expr)
370+
if iss.Err() != nil {
371+
t.Fatalf("Compile() failed: %v", iss.Err())
372+
}
373+
folder, err := NewConstantFoldingOptimizer()
374+
if err != nil {
375+
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
376+
}
377+
opt := NewStaticOptimizer(folder)
378+
optimized, iss := opt.Optimize(e, checked)
379+
if tc.error != "" {
380+
if iss.Err() == nil {
381+
t.Errorf("got nil, wanted error containing %q", tc.error)
382+
} else if !strings.Contains(iss.Err().Error(), tc.error) {
383+
t.Errorf("got %q, wanted error containing %q", iss.Err().Error(), tc.error)
384+
}
385+
return
386+
}
387+
if iss.Err() != nil {
388+
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
389+
}
390+
folded, err := AstToString(optimized)
391+
if err != nil {
392+
t.Fatalf("AstToString() failed: %v", err)
393+
}
394+
if folded != tc.folded {
395+
t.Errorf("got %q, wanted %q", folded, tc.folded)
396+
}
397+
})
398+
}
399+
}
400+
316401
func TestConstantFoldingOptimizerMacroElimination(t *testing.T) {
317402
tests := []struct {
318403
expr string

‎common/decls/decls.go‎

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,9 @@ func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
232232
}
233233
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.Name(), oID)
234234
}
235+
if overload.HasLateBinding() != o.HasLateBinding() {
236+
return fmt.Errorf("overload with late binding cannot be added to function %s: cannot mix late and non-late bindings", f.Name())
237+
}
235238
}
236239
f.overloadOrdinals = append(f.overloadOrdinals, overload.ID())
237240
f.overloads[overload.ID()] = overload
@@ -250,15 +253,30 @@ func (f *FunctionDecl) OverloadDecls() []*OverloadDecl {
250253
return overloads
251254
}
252255

256+
// Returns true if the function has late bindings. A function cannot mix late bindings with other bindings.
257+
func (f *FunctionDecl) HasLateBinding() bool {
258+
if f == nil {
259+
return false
260+
}
261+
for _, oID := range f.overloadOrdinals {
262+
if f.overloads[oID].HasLateBinding() {
263+
return true
264+
}
265+
}
266+
return false
267+
}
268+
253269
// Bindings produces a set of function bindings, if any are defined.
254270
func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
255271
if f == nil {
256272
return []*functions.Overload{}, nil
257273
}
258274
overloads := []*functions.Overload{}
259275
nonStrict := false
276+
hasLateBinding := false
260277
for _, oID := range f.overloadOrdinals {
261278
o := f.overloads[oID]
279+
hasLateBinding = hasLateBinding || o.HasLateBinding()
262280
if o.hasBinding() {
263281
overload := &functions.Overload{
264282
Operator: o.ID(),
@@ -276,6 +294,9 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
276294
if len(overloads) != 0 {
277295
return nil, fmt.Errorf("singleton function incompatible with specialized overloads: %s", f.Name())
278296
}
297+
if hasLateBinding {
298+
return nil, fmt.Errorf("singleton function incompatible with late bindings: %s", f.Name())
299+
}
279300
overloads = []*functions.Overload{
280301
{
281302
Operator: f.Name(),
@@ -516,6 +537,9 @@ type OverloadDecl struct {
516537
argTypes []*types.Type
517538
resultType *types.Type
518539
isMemberFunction bool
540+
// hasLateBinding indicates that the function has a binding which is not known at compile time.
541+
// This is useful for functions which have side-effects or are not deterministically computable.
542+
hasLateBinding bool
519543
// nonStrict indicates that the function will accept error and unknown arguments as inputs.
520544
nonStrict bool
521545
// operandTrait indicates whether the member argument should have a specific type-trait.
@@ -571,6 +595,14 @@ func (o *OverloadDecl) IsNonStrict() bool {
571595
return o.nonStrict
572596
}
573597

598+
// HasLateBinding returns whether the overload has a binding which is not known at compile time.
599+
func (o *OverloadDecl) HasLateBinding() bool {
600+
if o == nil {
601+
return false
602+
}
603+
return o.hasLateBinding
604+
}
605+
574606
// OperandTrait returns the trait mask of the first operand to the overload call, e.g.
575607
// `traits.Indexer`
576608
func (o *OverloadDecl) OperandTrait() int {
@@ -739,6 +771,9 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
739771
if len(o.ArgTypes()) != 1 {
740772
return nil, fmt.Errorf("unary function bound to non-unary overload: %s", o.ID())
741773
}
774+
if o.hasLateBinding {
775+
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
776+
}
742777
o.unaryOp = binding
743778
return o, nil
744779
}
@@ -754,6 +789,9 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
754789
if len(o.ArgTypes()) != 2 {
755790
return nil, fmt.Errorf("binary function bound to non-binary overload: %s", o.ID())
756791
}
792+
if o.hasLateBinding {
793+
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
794+
}
757795
o.binaryOp = binding
758796
return o, nil
759797
}
@@ -766,11 +804,26 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
766804
if o.hasBinding() {
767805
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
768806
}
807+
if o.hasLateBinding {
808+
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
809+
}
769810
o.functionOp = binding
770811
return o, nil
771812
}
772813
}
773814

815+
// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
816+
// This is useful for functions which have side-effects or are not deterministically computable.
817+
func LateFunctionBinding() OverloadOpt {
818+
return func(o *OverloadDecl) (*OverloadDecl, error) {
819+
if o.hasBinding() {
820+
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
821+
}
822+
o.hasLateBinding = true
823+
return o, nil
824+
}
825+
}
826+
774827
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
775828
//
776829
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.

0 commit comments

Comments
 (0)