Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cel/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,12 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
return decls.FunctionBinding(binding)
}

// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
// This is useful for functions which have side-effects or are not deterministically computable.
func LateFunctionBinding() OverloadOpt {
return decls.LateFunctionBinding()
}

// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
//
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.
Expand Down
24 changes: 21 additions & 3 deletions cel/folding.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
// Walk the list of foldable expression and continue to fold until there are no more folds left.
// All of the fold candidates returned by the constantExprMatcher should succeed unless there's
// a logic bug with the selection of expressions.
foldableExprs := ast.MatchDescendants(root, constantExprMatcher)
constantExprMatcherCapture := func(e ast.NavigableExpr) bool { return constantExprMatcher(ctx, a, e) }
foldableExprs := ast.MatchDescendants(root, constantExprMatcherCapture)
foldCount := 0
for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
for _, fold := range foldableExprs {
Expand All @@ -77,6 +78,10 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) {
continue
}
// Late-bound function calls cannot be folded.
if fold.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, fold) {
continue
}
// Otherwise, assume all context is needed to evaluate the expression.
err := tryFold(ctx, a, fold)
if err != nil {
Expand All @@ -85,7 +90,7 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
}
}
foldCount++
foldableExprs = ast.MatchDescendants(root, constantExprMatcher)
foldableExprs = ast.MatchDescendants(root, constantExprMatcherCapture)
}
// Once all of the constants have been folded, try to run through the remaining comprehensions
// one last time. In this case, there's no guarantee they'll run, so we only update the
Expand Down Expand Up @@ -139,6 +144,15 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
return nil
}

func isLateBoundFunctionCall(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) bool {
call := expr.AsCall()
function := ctx.Functions()[call.FunctionName()]
if function == nil {
return false
}
return function.HasLateBinding()
}

// maybePruneBranches inspects the non-strict call expression to determine whether
// a branch can be removed. Evaluation will naturally prune logical and / or calls,
// but conditional will not be pruned cleanly, so this is one small area where the
Expand Down Expand Up @@ -455,7 +469,7 @@ func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) {
// Only comprehensions which are not nested are included as possible constant folds, and only
// if all variables referenced in the comprehension stack exist are only iteration or
// accumulation variables.
func constantExprMatcher(e ast.NavigableExpr) bool {
func constantExprMatcher(ctx *OptimizerContext, a *ast.AST, e ast.NavigableExpr) bool {
switch e.Kind() {
case ast.CallKind:
return constantCallMatcher(e)
Expand All @@ -477,6 +491,10 @@ func constantExprMatcher(e ast.NavigableExpr) bool {
if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] {
constantExprs = false
}
// Late-bound function calls cannot be folded.
if e.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, e) {
constantExprs = false
}
})
ast.PreOrderVisit(e, visitor)
return constantExprs
Expand Down
85 changes: 85 additions & 0 deletions cel/folding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ package cel
import (
"reflect"
"sort"
"strings"
"testing"

"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/proto"

"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types/ref"

proto3pb "github.com/google/cel-go/test/proto3pb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
Expand Down Expand Up @@ -313,6 +315,89 @@ func TestConstantFoldingOptimizer(t *testing.T) {
}
}

func TestConstantFoldingCallsWithSideEffects(t *testing.T) {
tests := []struct {
expr string
folded string
error string
}{
{
expr: `noSideEffect(3)`,
folded: `3`,
},
{
expr: `withSideEffect(3)`,
folded: `withSideEffect(3)`,
},
{
expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && withSideEffect(i.b) == 1)`,
folded: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && withSideEffect(i.b) == 1)`,
},
{
expr: `[{}, {"a": 1}, {"b": 2}].exists(i, has(i.b) && noSideEffect(i.b) == 2)`,
folded: `true`,
},
{
expr: `noImpl(3)`,
error: `constant-folding evaluation failed: no such overload: noImpl`,
},
}
e, err := NewEnv(
OptionalTypes(),
EnableMacroCallTracking(),
Function("noSideEffect",
Overload("noSideEffect_int_int",
[]*Type{IntType},
IntType, FunctionBinding(func(args ...ref.Val) ref.Val {
return args[0]
}))),
Function("withSideEffect",
Overload("withSideEffect_int_int",
[]*Type{IntType},
IntType, LateFunctionBinding())),
Function("noImpl",
Overload("noImpl_int_int",
[]*Type{IntType},
IntType)),
)
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
for _, tst := range tests {
tc := tst
t.Run(tc.expr, func(t *testing.T) {
checked, iss := e.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("Compile() failed: %v", iss.Err())
}
folder, err := NewConstantFoldingOptimizer()
if err != nil {
t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err)
}
opt := NewStaticOptimizer(folder)
optimized, iss := opt.Optimize(e, checked)
if tc.error != "" {
if iss.Err() == nil {
t.Errorf("got nil, wanted error containing %q", tc.error)
} else if !strings.Contains(iss.Err().Error(), tc.error) {
t.Errorf("got %q, wanted error containing %q", iss.Err().Error(), tc.error)
}
return
}
if iss.Err() != nil {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
}
folded, err := AstToString(optimized)
if err != nil {
t.Fatalf("AstToString() failed: %v", err)
}
if folded != tc.folded {
t.Errorf("got %q, wanted %q", folded, tc.folded)
}
})
}
}

func TestConstantFoldingOptimizerMacroElimination(t *testing.T) {
tests := []struct {
expr string
Expand Down
53 changes: 53 additions & 0 deletions common/decls/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
}
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.Name(), oID)
}
if overload.HasLateBinding() != o.HasLateBinding() {
return fmt.Errorf("overload with late binding cannot be added to function %s: cannot mix late and non-late bindings", f.Name())
}
}
f.overloadOrdinals = append(f.overloadOrdinals, overload.ID())
f.overloads[overload.ID()] = overload
Expand All @@ -250,15 +253,30 @@ func (f *FunctionDecl) OverloadDecls() []*OverloadDecl {
return overloads
}

// Returns true if the function has late bindings. A function cannot mix late bindings with other bindings.
func (f *FunctionDecl) HasLateBinding() bool {
if f == nil {
return false
}
for _, oID := range f.overloadOrdinals {
if f.overloads[oID].HasLateBinding() {
return true
}
}
return false
}

// Bindings produces a set of function bindings, if any are defined.
func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
if f == nil {
return []*functions.Overload{}, nil
}
overloads := []*functions.Overload{}
nonStrict := false
hasLateBinding := false
for _, oID := range f.overloadOrdinals {
o := f.overloads[oID]
hasLateBinding = hasLateBinding || o.HasLateBinding()
if o.hasBinding() {
overload := &functions.Overload{
Operator: o.ID(),
Expand All @@ -276,6 +294,9 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
if len(overloads) != 0 {
return nil, fmt.Errorf("singleton function incompatible with specialized overloads: %s", f.Name())
}
if hasLateBinding {
return nil, fmt.Errorf("singleton function incompatible with late bindings: %s", f.Name())
}
overloads = []*functions.Overload{
{
Operator: f.Name(),
Expand Down Expand Up @@ -516,6 +537,9 @@ type OverloadDecl struct {
argTypes []*types.Type
resultType *types.Type
isMemberFunction bool
// hasLateBinding indicates that the function has a binding which is not known at compile time.
// This is useful for functions which have side-effects or are not deterministically computable.
hasLateBinding bool
// nonStrict indicates that the function will accept error and unknown arguments as inputs.
nonStrict bool
// operandTrait indicates whether the member argument should have a specific type-trait.
Expand Down Expand Up @@ -571,6 +595,14 @@ func (o *OverloadDecl) IsNonStrict() bool {
return o.nonStrict
}

// HasLateBinding returns whether the overload has a binding which is not known at compile time.
func (o *OverloadDecl) HasLateBinding() bool {
if o == nil {
return false
}
return o.hasLateBinding
}

// OperandTrait returns the trait mask of the first operand to the overload call, e.g.
// `traits.Indexer`
func (o *OverloadDecl) OperandTrait() int {
Expand Down Expand Up @@ -739,6 +771,9 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
if len(o.ArgTypes()) != 1 {
return nil, fmt.Errorf("unary function bound to non-unary overload: %s", o.ID())
}
if o.hasLateBinding {
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
}
o.unaryOp = binding
return o, nil
}
Expand All @@ -754,6 +789,9 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
if len(o.ArgTypes()) != 2 {
return nil, fmt.Errorf("binary function bound to non-binary overload: %s", o.ID())
}
if o.hasLateBinding {
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
}
o.binaryOp = binding
return o, nil
}
Expand All @@ -766,11 +804,26 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
}
if o.hasLateBinding {
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
}
o.functionOp = binding
return o, nil
}
}

// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
// This is useful for functions which have side-effects or are not deterministically computable.
func LateFunctionBinding() OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
}
o.hasLateBinding = true
return o, nil
}
}

// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
//
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.
Expand Down
Loading