Skip to content

Commit

Permalink
Minor update on cost order (#1119)
Browse files Browse the repository at this point in the history
* slight reordering in cost computation to preserve legacy compatibility
* Additional tests to attempt to catch ordering issues with cost estimation
  • Loading branch information
TristonianJones authored Jan 31, 2025
1 parent fb3fe56 commit 1bf2472
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
6 changes: 3 additions & 3 deletions checker/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,9 @@ func (c *coster) computeSize(e ast.Expr) *SizeEstimate {
if size, ok := c.computedSizes[e.ID()]; ok {
return &size
}
if size := computeExprSize(e); size != nil {
return size
}
// Ensure size estimates are computed first as users may choose to override the costs that
// CEL would otherwise ascribe to the type.
node := astNode{expr: e, path: c.getPath(e), t: c.getType(e)}
Expand All @@ -938,9 +941,6 @@ func (c *coster) computeSize(e ast.Expr) *SizeEstimate {
c.computedSizes[e.ID()] = *size
return size
}
if size := computeExprSize(e); size != nil {
return size
}
if size := computeTypeSize(c.getType(e)); size != nil {
return size
}
Expand Down
65 changes: 65 additions & 0 deletions checker/cost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,31 @@ func TestCost(t *testing.T) {
expr: `self.val1 == 1.0`,
wanted: FixedCostEstimate(3),
},
{
name: "bytes list max",
expr: "[bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901')].max()",
options: []CostOption{
OverloadCostEstimate("list_bytes_max",
func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate {
if target != nil {
// Charge 1 cost for comparing each element in the list
elCost := CostEstimate{Min: 1, Max: 1}
// If the list contains strings or bytes, add the cost of traversing all the strings/bytes as a way
// of estimating the additional comparison cost.
if elNode := listElementNode(*target); elNode != nil {
k := elNode.Type().Kind()
if k == types.StringKind || k == types.BytesKind {
sz := sizeEstimate(estimator, elNode)
elCost = elCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor))
}
return &CallEstimate{CostEstimate: sizeEstimate(estimator, *target).MultiplyByCost(elCost)}
}
}
return nil
}),
},
wanted: CostEstimate{Min: 25, Max: 35},
},
}

for _, tst := range cases {
Expand Down Expand Up @@ -745,6 +770,14 @@ func TestCost(t *testing.T) {
if err != nil {
t.Fatalf("environment creation error: %v", err)
}
maxFunc, _ := decls.NewFunction("max",
decls.MemberOverload("list_bytes_max",
[]*types.Type{types.NewListType(types.BytesType)},
types.BytesType))
err = e.AddFunctions(maxFunc)
if err != nil {
t.Fatalf("environment creation error: %v", err)
}
err = e.AddIdents(tc.vars...)
if err != nil {
t.Fatalf("environment creation error: %s\n", err)
Expand Down Expand Up @@ -773,6 +806,9 @@ func (tc testCostEstimator) EstimateSize(element AstNode) *SizeEstimate {
if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok {
return &SizeEstimate{Min: 0, Max: l}
}
if element.Type() == types.BytesType {
return &SizeEstimate{Min: 0, Max: 12}
}
return nil
}

Expand All @@ -793,3 +829,32 @@ func estimateSize(estimator CostEstimator, node AstNode) SizeEstimate {
}
return SizeEstimate{Min: 0, Max: math.MaxUint64}
}

func listElementNode(list AstNode) AstNode {
if params := list.Type().Parameters(); len(params) > 0 {
lt := params[0]
nodePath := list.Path()
if nodePath != nil {
// Provide path if we have it so that a OpenAPIv3 maxLength validation can be looked up, if it exists
// for this node.
path := make([]string, len(nodePath)+1)
copy(path, nodePath)
path[len(nodePath)] = "@items"
return &astNode{path: path, t: lt, expr: nil}
} else {
// Provide just the type if no path is available so that worst case size can be looked up based on type.
return &astNode{t: lt, expr: nil}
}
}
return nil
}

func sizeEstimate(estimator CostEstimator, t AstNode) SizeEstimate {
if sz := t.ComputedSize(); sz != nil {
return *sz
}
if sz := estimator.EstimateSize(t); sz != nil {
return *sz
}
return SizeEstimate{Min: 0, Max: math.MaxUint64}
}

0 comments on commit 1bf2472

Please sign in to comment.