Skip to content

Commit

Permalink
Helper methods for subsetting function declaration overloads (#1120)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Feb 3, 2025
1 parent 1bf2472 commit 2a85bb6
Show file tree
Hide file tree
Showing 3 changed files with 268 additions and 0 deletions.
32 changes: 32 additions & 0 deletions cel/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,38 @@ func Function(name string, opts ...FunctionOpt) EnvOption {
}
}

// OverloadSelector selects an overload associated with a given function when it returns true.
//
// Used in combination with the FunctionDecl.Subset method.
type OverloadSelector = decls.OverloadSelector

// IncludeOverloads defines an OverloadSelector which allow-lists a set of overloads by their ids.
func IncludeOverloads(overloadIDs ...string) OverloadSelector {
return decls.IncludeOverloads(overloadIDs...)
}

// ExcludeOverloads defines an OverloadSelector which deny-lists a set of overloads by their ids.
func ExcludeOverloads(overloadIDs ...string) OverloadSelector {
return decls.ExcludeOverloads(overloadIDs...)
}

// FunctionDecls provides one or more fully formed function declaration to be added to the environment.
func FunctionDecls(funcs ...*decls.FunctionDecl) EnvOption {
return func(e *Env) (*Env, error) {
var err error
for _, fn := range funcs {
if existing, found := e.functions[fn.Name()]; found {
fn, err = existing.Merge(fn)
if err != nil {
return nil, err
}
}
e.functions[fn.Name()] = fn
}
return e, nil
}
}

// FunctionOpt defines a functional option for configuring a function declaration.
type FunctionOpt = decls.FunctionOpt

Expand Down
182 changes: 182 additions & 0 deletions cel/decls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
Expand Down Expand Up @@ -780,6 +781,187 @@ func TestExprDeclToDeclarationInvalid(t *testing.T) {
}
}

func TestFunctionDeclExcludeOverloads(t *testing.T) {
funcs := []*decls.FunctionDecl{}
for _, fn := range stdlib.Functions() {
if fn.Name() == operators.Add {
fn = fn.Subset(ExcludeOverloads(overloads.AddList, overloads.AddBytes, overloads.AddString))
}
funcs = append(funcs, fn)
}
env, err := NewCustomEnv(FunctionDecls(funcs...))
if err != nil {
t.Fatalf("NewCustomEnv() failed: %v", err)
}

successTests := []struct {
name string
expr string
want ref.Val
}{
{
name: "ints",
expr: "1 + 1",
want: types.Int(2),
},
{
name: "doubles",
expr: "1.5 + 1.5",
want: types.Double(3.0),
},
{
name: "uints",
expr: "1u + 2u",
want: types.Uint(3),
},
{
name: "timestamp plus duration",
expr: "timestamp('2001-01-01T00:00:00Z') + duration('1h') == timestamp('2001-01-01T01:00:00Z')",
want: types.True,
},
{
name: "durations",
expr: "duration('1h') + duration('1m') == duration('1h1m')",
want: types.True,
},
}
for _, tst := range successTests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
ast, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(NoVars())
if err != nil {
t.Fatalf("prg.Eval() errored: %v", err)
}
if out.Equal(tc.want) != types.True {
t.Errorf("Eval() got %v, wanted %v", out, tc.want)
}
})
}
failureTests := []struct {
name string
expr string
}{
{
name: "strings",
expr: "'a' + 'b'",
},
{
name: "bytes",
expr: "b'123' + b'456'",
},
{
name: "lists",
expr: "[1] + [2, 3]",
},
}
for _, tst := range failureTests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
_, iss := env.Compile(tc.expr)
if iss.Err() == nil {
t.Error("env.Compile() got ast, wanted error")
}
})
}
}

func TestFunctionDeclIncludeOverloads(t *testing.T) {
funcs := []*decls.FunctionDecl{}
for _, fn := range stdlib.Functions() {
if fn.Name() == operators.Add {
fn = fn.Subset(IncludeOverloads(overloads.AddInt64, overloads.AddDouble))
}
funcs = append(funcs, fn)
}
env, err := NewCustomEnv(FunctionDecls(funcs...))
if err != nil {
t.Fatalf("NewCustomEnv() failed: %v", err)
}

successTests := []struct {
name string
expr string
want ref.Val
}{
{
name: "ints",
expr: "1 + 1",
want: types.Int(2),
},
{
name: "doubles",
expr: "1.5 + 1.5",
want: types.Double(3.0),
},
}
for _, tst := range successTests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
ast, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(NoVars())
if err != nil {
t.Fatalf("prg.Eval() errored: %v", err)
}
if out.Equal(tc.want) != types.True {
t.Errorf("Eval() got %v, wanted %v", out, tc.want)
}
})
}
failureTests := []struct {
name string
expr string
}{
{
name: "strings",
expr: "'a' + 'b'",
},
{
name: "bytes",
expr: "b'123' + b'456'",
},
{
name: "lists",
expr: "[1] + [2, 3]",
},
{
name: "uints",
expr: "1u + 2u",
},
{
name: "timestamp plus duration",
expr: "timestamp('2001-01-01T00:00:00Z') + duration('1h') == timestamp('2001-01-01T01:00:00Z')",
},
{
name: "durations",
expr: "duration('1h') + duration('1m') == duration('1h1m')",
},
}
for _, tst := range failureTests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
_, iss := env.Compile(tc.expr)
if iss.Err() == nil {
t.Error("env.Compile() got ast, wanted error")
}
})
}
}

func testParse(t testing.TB, env *Env, expr string, want any) {
t.Helper()
ast, iss := env.Parse(expr)
Expand Down
54 changes: 54 additions & 0 deletions common/decls/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,60 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) {
return merged, nil
}

// OverloadSelector selects an overload associated with a given function when it returns true.
//
// Used in combination with the Subset method.
type OverloadSelector func(overload *OverloadDecl) bool

// IncludeOverloads defines an OverloadSelector which allow-lists a set of overloads by their ids.
func IncludeOverloads(overloadIDs ...string) OverloadSelector {
return func(overload *OverloadDecl) bool {
for _, oID := range overloadIDs {
if overload.id == oID {
return true
}
}
return false
}
}

// ExcludeOverloads defines an OverloadSelector which deny-lists a set of overloads by their ids.
func ExcludeOverloads(overloadIDs ...string) OverloadSelector {
return func(overload *OverloadDecl) bool {
for _, oID := range overloadIDs {
if overload.id == oID {
return false
}
}
return true
}
}

// Subset returns a new function declaration which contains only the overloads with the specified IDs.
func (f *FunctionDecl) Subset(selector OverloadSelector) *FunctionDecl {
if f == nil {
return nil
}
overloads := make(map[string]*OverloadDecl)
overloadOrdinals := make([]string, 0, len(f.overloadOrdinals))
for _, oID := range f.overloadOrdinals {
overload := f.overloads[oID]
if selector(overload) {
overloads[oID] = overload
overloadOrdinals = append(overloadOrdinals, oID)
}
}
subset := &FunctionDecl{
name: f.Name(),
overloads: overloads,
singleton: f.singleton,
disableTypeGuards: f.disableTypeGuards,
state: f.state,
overloadOrdinals: overloadOrdinals,
}
return subset
}

// AddOverload ensures that the new overload does not collide with an existing overload signature;
// however, if the function signatures are identical, the implementation may be rewritten as its
// difficult to compare functions by object identity.
Expand Down

0 comments on commit 2a85bb6

Please sign in to comment.