Skip to content

Commit caacded

Browse files
committed
Add LateFunctionBinding declaration and fix constant folding
Adds the declaration for LateFunctionBindings, which can be used to indicate to the Runtime and the Optimizers that the function will be bound at runtime through the Activation. This lets the constant folding optimizer know that the function potentially has side effects and cannot be folded. Without this the optimization will fail with an error for late bound functions where all arguments are constants. The implementation for late bound functions will be added in a subsequent commit.
1 parent f5ea07b commit caacded

File tree

5 files changed

+256
-3
lines changed

5 files changed

+256
-3
lines changed

‎cel/decls.go‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,12 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
330330
return decls.FunctionBinding(binding)
331331
}
332332

333+
// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
334+
// This is useful for functions which have side-effects or are not deterministically computable.
335+
func LateFunctionBinding() OverloadOpt {
336+
return decls.LateFunctionBinding()
337+
}
338+
333339
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
334340
//
335341
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.

‎cel/folding.go‎

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"fmt"
1919

2020
"github.com/google/cel-go/common/ast"
21+
"github.com/google/cel-go/common/decls"
2122
"github.com/google/cel-go/common/operators"
2223
"github.com/google/cel-go/common/overloads"
2324
"github.com/google/cel-go/common/types"
@@ -68,7 +69,8 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
6869
// Walk the list of foldable expression and continue to fold until there are no more folds left.
6970
// All of the fold candidates returned by the constantExprMatcher should succeed unless there's
7071
// a logic bug with the selection of expressions.
71-
foldableExprs := ast.MatchDescendants(root, constantExprMatcher)
72+
constantExprMatcherCapture := func(e ast.NavigableExpr) bool { return constantExprMatcher(ctx, a, e) }
73+
foldableExprs := ast.MatchDescendants(root, constantExprMatcherCapture)
7274
foldCount := 0
7375
for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
7476
for _, fold := range foldableExprs {
@@ -77,6 +79,10 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
7779
if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) {
7880
continue
7981
}
82+
// Late-bound function calls cannot be folded.
83+
if fold.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, fold) {
84+
continue
85+
}
8086
// Otherwise, assume all context is needed to evaluate the expression.
8187
err := tryFold(ctx, a, fold)
8288
if err != nil {
@@ -85,7 +91,7 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
8591
}
8692
}
8793
foldCount++
88-
foldableExprs = ast.MatchDescendants(root, constantExprMatcher)
94+
foldableExprs = ast.MatchDescendants(root, constantExprMatcherCapture)
8995
}
9096
// Once all of the constants have been folded, try to run through the remaining comprehensions
9197
// one last time. In this case, there's no guarantee they'll run, so we only update the
@@ -139,6 +145,32 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
139145
return nil
140146
}
141147

148+
func getFunctionOverloadDecls(function *decls.FunctionDecl, overloadIds []string) []*decls.OverloadDecl {
149+
var overloads []*decls.OverloadDecl
150+
if function == nil {
151+
return overloads
152+
}
153+
for _, o := range function.OverloadDecls() {
154+
for _, id := range overloadIds {
155+
if id == o.ID() {
156+
overloads = append(overloads, o)
157+
}
158+
}
159+
}
160+
return overloads
161+
}
162+
163+
func isLateBoundFunctionCall(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) bool {
164+
call := expr.AsCall()
165+
overloadDecls := getFunctionOverloadDecls(ctx.Functions()[call.FunctionName()], a.GetOverloadIDs(expr.ID()))
166+
for _, o := range overloadDecls {
167+
if o.HasLateBinding() {
168+
return true
169+
}
170+
}
171+
return false
172+
}
173+
142174
// maybePruneBranches inspects the non-strict call expression to determine whether
143175
// a branch can be removed. Evaluation will naturally prune logical and / or calls,
144176
// but conditional will not be pruned cleanly, so this is one small area where the
@@ -455,7 +487,7 @@ func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) {
455487
// Only comprehensions which are not nested are included as possible constant folds, and only
456488
// if all variables referenced in the comprehension stack exist are only iteration or
457489
// accumulation variables.
458-
func constantExprMatcher(e ast.NavigableExpr) bool {
490+
func constantExprMatcher(ctx *OptimizerContext, a *ast.AST, e ast.NavigableExpr) bool {
459491
switch e.Kind() {
460492
case ast.CallKind:
461493
return constantCallMatcher(e)
@@ -477,6 +509,10 @@ func constantExprMatcher(e ast.NavigableExpr) bool {
477509
if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] {
478510
constantExprs = false
479511
}
512+
// Late-bound function calls cannot be folded.
513+
if e.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, e) {
514+
constantExprs = false
515+
}
480516
})
481517
ast.PreOrderVisit(e, visitor)
482518
return constantExprs

‎cel/folding_test.go‎

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ package cel
1717
import (
1818
"reflect"
1919
"sort"
20+
"strings"
2021
"testing"
2122

2223
"google.golang.org/protobuf/encoding/prototext"
2324
"google.golang.org/protobuf/proto"
2425

2526
"github.com/google/cel-go/common/ast"
27+
"github.com/google/cel-go/common/types/ref"
2628

2729
proto3pb "github.com/google/cel-go/test/proto3pb"
2830
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@@ -313,6 +315,89 @@ func TestConstantFoldingOptimizer(t *testing.T) {
313315
}
314316
}
315317

318+
func TestConstantFoldingCallsWithSideEffects(t *testing.T) {
319+
tests := []struct {
320+
expr string
321+
folded string
322+
error string
323+
}{
324+
{
325+
expr: `noSideEffect(3)`,
326+
folded: `3`,
327+
},
328+
{
329+
expr: `withSideEffect(3)`,
330+
folded: `withSideEffect(3)`,
331+
},
332+
{
333+
expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && withSideEffect(i.b) == 1)`,
334+
folded: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && withSideEffect(i.b) == 1)`,
335+
},
336+
{
337+
expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && noSideEffect(i.b) == 2)`,
338+
folded: `true`,
339+
},
340+
{
341+
expr: `noImpl(3)`,
342+
error: `constant-folding evaluation failed: no such overload: noImpl`,
343+
},
344+
}
345+
e, err := NewEnv(
346+
OptionalTypes(),
347+
EnableMacroCallTracking(),
348+
Function("noSideEffect",
349+
Overload("noSideEffect_int_int",
350+
[]*Type{IntType},
351+
IntType, FunctionBinding(func(args ...ref.Val) ref.Val {
352+
return args[0]
353+
}))),
354+
Function("withSideEffect",
355+
Overload("withSideEffect_int_int",
356+
[]*Type{IntType},
357+
IntType, LateFunctionBinding())),
358+
Function("noImpl",
359+
Overload("noImpl_int_int",
360+
[]*Type{IntType},
361+
IntType)),
362+
)
363+
if err != nil {
364+
t.Fatalf("NewEnv() failed: %v", err)
365+
}
366+
for _, tst := range tests {
367+
tc := tst
368+
t.Run(tc.expr, func(t *testing.T) {
369+
checked, iss := e.Compile(tc.expr)
370+
if iss.Err() != nil {
371+
t.Fatalf("Compile() failed: %v", iss.Err())
372+
}
373+
folder, err := NewConstantFoldingOptimizer()
374+
if err != nil {
375+
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
376+
}
377+
opt := NewStaticOptimizer(folder)
378+
optimized, iss := opt.Optimize(e, checked)
379+
if tc.error != "" {
380+
if iss.Err() == nil {
381+
t.Errorf("got nil, wanted error containing %q", tc.error)
382+
} else if !strings.Contains(iss.Err().Error(), tc.error) {
383+
t.Errorf("got %q, wanted error containing %q", iss.Err().Error(), tc.error)
384+
}
385+
return
386+
}
387+
if iss.Err() != nil {
388+
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
389+
}
390+
folded, err := AstToString(optimized)
391+
if err != nil {
392+
t.Fatalf("AstToString() failed: %v", err)
393+
}
394+
if folded != tc.folded {
395+
t.Errorf("got %q, wanted %q", folded, tc.folded)
396+
}
397+
})
398+
}
399+
}
400+
316401
func TestConstantFoldingOptimizerMacroElimination(t *testing.T) {
317402
tests := []struct {
318403
expr string

‎common/decls/decls.go‎

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,9 @@ func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
232232
}
233233
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.Name(), oID)
234234
}
235+
if overload.HasLateBinding() != o.HasLateBinding() {
236+
return fmt.Errorf("overload with late binding cannot be added to function %s: cannot mix late and non-late bindings", f.Name())
237+
}
235238
}
236239
f.overloadOrdinals = append(f.overloadOrdinals, overload.ID())
237240
f.overloads[overload.ID()] = overload
@@ -257,8 +260,10 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
257260
}
258261
overloads := []*functions.Overload{}
259262
nonStrict := false
263+
hasLateBinding := false
260264
for _, oID := range f.overloadOrdinals {
261265
o := f.overloads[oID]
266+
hasLateBinding = hasLateBinding || o.HasLateBinding()
262267
if o.hasBinding() {
263268
overload := &functions.Overload{
264269
Operator: o.ID(),
@@ -276,6 +281,9 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
276281
if len(overloads) != 0 {
277282
return nil, fmt.Errorf("singleton function incompatible with specialized overloads: %s", f.Name())
278283
}
284+
if hasLateBinding {
285+
return nil, fmt.Errorf("singleton function incompatible with late bindings: %s", f.Name())
286+
}
279287
overloads = []*functions.Overload{
280288
{
281289
Operator: f.Name(),
@@ -516,6 +524,9 @@ type OverloadDecl struct {
516524
argTypes []*types.Type
517525
resultType *types.Type
518526
isMemberFunction bool
527+
// hasLateBinding indicates that the function has a binding which is not known at compile time.
528+
// This is useful for functions which have side-effects or are not deterministically computable.
529+
hasLateBinding bool
519530
// nonStrict indicates that the function will accept error and unknown arguments as inputs.
520531
nonStrict bool
521532
// operandTrait indicates whether the member argument should have a specific type-trait.
@@ -571,6 +582,14 @@ func (o *OverloadDecl) IsNonStrict() bool {
571582
return o.nonStrict
572583
}
573584

585+
// HasLateBinding returns whether the overload has a binding which is not known at compile time.
586+
func (o *OverloadDecl) HasLateBinding() bool {
587+
if o == nil {
588+
return false
589+
}
590+
return o.hasLateBinding
591+
}
592+
574593
// OperandTrait returns the trait mask of the first operand to the overload call, e.g.
575594
// `traits.Indexer`
576595
func (o *OverloadDecl) OperandTrait() int {
@@ -739,6 +758,9 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
739758
if len(o.ArgTypes()) != 1 {
740759
return nil, fmt.Errorf("unary function bound to non-unary overload: %s", o.ID())
741760
}
761+
if o.hasLateBinding {
762+
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
763+
}
742764
o.unaryOp = binding
743765
return o, nil
744766
}
@@ -754,6 +776,9 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
754776
if len(o.ArgTypes()) != 2 {
755777
return nil, fmt.Errorf("binary function bound to non-binary overload: %s", o.ID())
756778
}
779+
if o.hasLateBinding {
780+
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
781+
}
757782
o.binaryOp = binding
758783
return o, nil
759784
}
@@ -766,11 +791,26 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
766791
if o.hasBinding() {
767792
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
768793
}
794+
if o.hasLateBinding {
795+
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
796+
}
769797
o.functionOp = binding
770798
return o, nil
771799
}
772800
}
773801

802+
// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
803+
// This is useful for functions which have side-effects or are not deterministically computable.
804+
func LateFunctionBinding() OverloadOpt {
805+
return func(o *OverloadDecl) (*OverloadDecl, error) {
806+
if o.hasBinding() {
807+
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
808+
}
809+
o.hasLateBinding = true
810+
return o, nil
811+
}
812+
}
813+
774814
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
775815
//
776816
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.

0 commit comments

Comments
 (0)