From 4f7f2f249529750f23a78c3bc240eeff49d2e522 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Mon, 15 Jun 2026 20:47:26 -0700 Subject: [PATCH] Cost library for ext/network --- ext/network.go | 203 +++++++++++- ext/network_test.go | 610 +++++++++++++++++++++++++++++++++++++ interpreter/runtimecost.go | 41 ++- 3 files changed, 835 insertions(+), 19 deletions(-) diff --git a/ext/network.go b/ext/network.go index affe59e2d..bca065707 100644 --- a/ext/network.go +++ b/ext/network.go @@ -16,13 +16,16 @@ package ext import ( "fmt" + "math" "net/netip" "reflect" "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/interpreter" ) const ( @@ -182,7 +185,11 @@ const ( var ( // Definitions for the Opaque Types - IPType = types.NewOpaqueType("net.IP") + + // IPType represents a network IP address. + IPType = types.NewOpaqueType("net.IP") + + // CIDRType represents a CIDR-format network range. CIDRType = types.NewOpaqueType("net.CIDR") ) @@ -196,13 +203,11 @@ func (*networkLib) LibraryName() string { func (*networkLib) CompileOptions() []cel.EnvOption { return []cel.EnvOption{ - // 1. Register Types cel.Types( IPType, CIDRType, ), - // 2. Register Functions cel.Function(cidrFunc, // K8s Parity: Following the pattern, this is "string_to_cidr" cel.Overload("string_to_cidr", []*cel.Type{cel.StringType}, CIDRType, @@ -288,11 +293,58 @@ func (*networkLib) CompileOptions() []cel.EnvOption { networkFormatValidator{funcName: ipFunc, argNum: 0, check: checkIP}, networkFormatValidator{funcName: cidrFunc, argNum: 0, check: checkCIDR}, ), + cel.CostEstimatorOptions( + checker.OverloadCostEstimate("string_to_cidr", estimateNetworkParseCost), + checker.OverloadCostEstimate("cidr_to_string", estimateNetworkNominalStringCost), + checker.OverloadCostEstimate("cidr_contains_cidr", estimateNetworkContainsCIDRCIDRCost), + checker.OverloadCostEstimate("cidr_contains_cidr_string", estimateNetworkContainsCIDRStringCost), + checker.OverloadCostEstimate("cidr_contains_ip_ip", estimateNetworkContainsIPIPCost), + checker.OverloadCostEstimate("cidr_contains_ip_string", estimateNetworkContainsIPStringCost), + checker.OverloadCostEstimate("ip_family", estimateNetworkNominalCost), + checker.OverloadCostEstimate("string_to_ip", estimateNetworkParseCost), + checker.OverloadCostEstimate("cidr_ip", estimateNetworkNominalOpaqueCost), + checker.OverloadCostEstimate("ip_to_string", estimateNetworkNominalStringCost), + checker.OverloadCostEstimate("ip_is_canonical", estimateIPIsCanonicalCost), + checker.OverloadCostEstimate("is_cidr", estimateNetworkParseBoolCost), + checker.OverloadCostEstimate("ip_is_global_unicast", estimateNetworkNominalCost), + checker.OverloadCostEstimate("is_ip", estimateNetworkParseBoolCost), + checker.OverloadCostEstimate("ip_is_link_local_multicast", estimateNetworkNominalCost), + checker.OverloadCostEstimate("ip_is_link_local_unicast", estimateNetworkNominalCost), + checker.OverloadCostEstimate("ip_is_loopback", estimateNetworkNominalCost), + checker.OverloadCostEstimate("cidr_is_mask", estimateNetworkNominalCost), + checker.OverloadCostEstimate("ip_is_unspecified", estimateNetworkNominalCost), + checker.OverloadCostEstimate("cidr_masked", estimateNetworkNominalOpaqueCost), + checker.OverloadCostEstimate("cidr_prefix_length", estimateNetworkNominalCost), + ), } } func (*networkLib) ProgramOptions() []cel.ProgramOption { - return []cel.ProgramOption{} + return []cel.ProgramOption{ + cel.CostTrackerOptions( + interpreter.OverloadCostTracker("string_to_cidr", trackNetworkParseCost), + interpreter.OverloadCostTracker("cidr_to_string", trackNetworkNominalCost), + interpreter.OverloadCostTracker("cidr_contains_cidr", trackNetworkContainsCIDRCIDRCost), + interpreter.OverloadCostTracker("cidr_contains_cidr_string", trackNetworkContainsCIDRStringCost), + interpreter.OverloadCostTracker("cidr_contains_ip_ip", trackNetworkContainsIPIPCost), + interpreter.OverloadCostTracker("cidr_contains_ip_string", trackNetworkContainsIPStringCost), + interpreter.OverloadCostTracker("ip_family", trackNetworkNominalCost), + interpreter.OverloadCostTracker("string_to_ip", trackNetworkParseCost), + interpreter.OverloadCostTracker("cidr_ip", trackNetworkNominalCost), + interpreter.OverloadCostTracker("ip_to_string", trackNetworkNominalCost), + interpreter.OverloadCostTracker("ip_is_canonical", trackIPIsCanonicalCost), + interpreter.OverloadCostTracker("is_cidr", trackNetworkParseCost), + interpreter.OverloadCostTracker("ip_is_global_unicast", trackNetworkNominalCost), + interpreter.OverloadCostTracker("is_ip", trackNetworkParseCost), + interpreter.OverloadCostTracker("ip_is_link_local_multicast", trackNetworkNominalCost), + interpreter.OverloadCostTracker("ip_is_link_local_unicast", trackNetworkNominalCost), + interpreter.OverloadCostTracker("ip_is_loopback", trackNetworkNominalCost), + interpreter.OverloadCostTracker("cidr_is_mask", trackNetworkNominalCost), + interpreter.OverloadCostTracker("ip_is_unspecified", trackNetworkNominalCost), + interpreter.OverloadCostTracker("cidr_masked", trackNetworkNominalCost), + interpreter.OverloadCostTracker("cidr_prefix_length", trackNetworkNominalCost), + ), + } } // networkAdapter adapts netip types while preserving existing adapters. @@ -478,8 +530,7 @@ func parseIPAddr(raw string) (netip.Addr, error) { return addr, nil } -// --- Opaque Type Wrappers --- - +// IP represents an IP address type. type IP struct { netip.Addr } @@ -527,6 +578,13 @@ func (i IP) Value() any { return i.Addr } +// Size returns the size of the IP address in bytes. +// /Used in the size estimation of the runtime cost. +func (i IP) Size() ref.Val { + return types.Int(int64(math.Ceil(float64(i.Addr.BitLen()) / 8))) +} + +// CIDR represents the CIDR network mask format. type CIDR struct { netip.Prefix } @@ -574,6 +632,12 @@ func (c CIDR) Value() any { return c.Prefix } +// Size returns the size of the CIDR prefix address in bytes. +// Used in the size estimation of the runtime cost. +func (c CIDR) Size() ref.Val { + return types.Int(int64(math.Ceil(float64(c.Prefix.Bits()) / 8))) +} + // --- Static Validators --- type argChecker func(e *cel.Env, call, arg ast.Expr) error @@ -617,3 +681,130 @@ func checkCIDR(e *cel.Env, call, arg ast.Expr) error { _, err := parseCIDR(pattern) return err } + +// Cost estimation functions for network extensions. + +func estimateNetworkParseCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if len(args) < 1 { + return nil + } + sz := estimateSize(estimator, args[0]) + resultSize := rangedSizeEstimate(4, 16) + return callEstimate(sz.MultiplyByCostFactor(stringCostFactor), &resultSize) +} + +func estimateNetworkParseBoolCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if len(args) < 1 { + return nil + } + sz := estimateSize(estimator, args[0]) + return callEstimate(sz.MultiplyByCostFactor(stringCostFactor), nil) +} + +func estimateIPIsCanonicalCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if len(args) < 1 { + return nil + } + sz := estimateSize(estimator, args[0]) + return callEstimate(sz.MultiplyByCostFactor(2*stringCostFactor), nil) +} + +func estimateNetworkNominalCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + return callEstimate(callCostEstimate, nil) +} + +func estimateNetworkNominalOpaqueCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + resultSize := rangedSizeEstimate(4, 16) + return callEstimate(callCostEstimate, &resultSize) +} + +func estimateNetworkNominalStringCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + resultSize := rangedSizeEstimate(3, 45) + return callEstimate(callCostEstimate, &resultSize) +} + +func estimateNetworkContainsIPIPCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + sz := rangedSizeEstimate(4, 16) + ipCompCost := sz.Add(sz).MultiplyByCostFactor(stringCostFactor) + return callEstimate(ipCompCost, nil) +} + +func estimateNetworkContainsIPStringCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if len(args) < 1 { + return nil + } + sz := rangedSizeEstimate(4, 16) + ipCompCost := sz.Add(sz).MultiplyByCostFactor(stringCostFactor) + argSz := estimateSize(estimator, args[0]) + ipCompCost = ipCompCost.Add(argSz.MultiplyByCostFactor(stringCostFactor)) + return callEstimate(ipCompCost, nil) +} + +func estimateNetworkContainsCIDRCIDRCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + sz := rangedSizeEstimate(4, 16) + ipCompCost := sz.Add(sz).MultiplyByCostFactor(stringCostFactor) + ipCompCost = ipCompCost.Add(sz.MultiplyByCostFactor(stringCostFactor)) + // K8s adds one for the extra IP traversal + ipCompCost = ipCompCost.Add(callCostEstimate) + return callEstimate(ipCompCost, nil) +} + +func estimateNetworkContainsCIDRStringCost(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + if len(args) < 1 { + return nil + } + sz := rangedSizeEstimate(4, 16) + ipCompCost := sz.Add(sz).MultiplyByCostFactor(stringCostFactor) + ipCompCost = ipCompCost.Add(sz.MultiplyByCostFactor(stringCostFactor)) + argSz := estimateSize(estimator, args[0]) + ipCompCost = ipCompCost.Add(argSz.MultiplyByCostFactor(stringCostFactor)) + // K8s adds one for the extra IP traversal + ipCompCost = ipCompCost.Add(callCostEstimate) + return callEstimate(ipCompCost, nil) +} + +// Runtime cost tracking functions for network extensions. + +func trackNetworkParseCost(args []ref.Val, result ref.Val) *uint64 { + cost := uint64(math.Ceil(float64(actualSize(args[0])) * stringCostFactor)) + return &cost +} + +func trackIPIsCanonicalCost(args []ref.Val, result ref.Val) *uint64 { + cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * stringCostFactor)) + return &cost +} + +func trackNetworkNominalCost(args []ref.Val, result ref.Val) *uint64 { + return &callCost +} + +func trackNetworkContainsIPIPCost(args []ref.Val, result ref.Val) *uint64 { + cidrSize := actualSize(args[0]) + cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * stringCostFactor)) + return &cost +} + +func trackNetworkContainsIPStringCost(args []ref.Val, result ref.Val) *uint64 { + cidrSize := actualSize(args[0]) + otherSize := actualSize(args[1]) + cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * stringCostFactor)) + cost = safeAdd(cost, uint64(math.Ceil(float64(otherSize)*stringCostFactor))) + return &cost +} + +func trackNetworkContainsCIDRCIDRCost(args []ref.Val, result ref.Val) *uint64 { + cidrSize := actualSize(args[0]) + cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * stringCostFactor)) + cost = safeAdd(cost, uint64(math.Ceil(float64(cidrSize)*stringCostFactor)), 1) + return &cost +} + +func trackNetworkContainsCIDRStringCost(args []ref.Val, result ref.Val) *uint64 { + cidrSize := actualSize(args[0]) + otherSize := actualSize(args[1]) + cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * stringCostFactor)) + cost = safeAdd(cost, uint64(math.Ceil(float64(cidrSize)*stringCostFactor)), 1) + cost = safeAdd(cost, uint64(math.Ceil(float64(otherSize)*stringCostFactor))) + return &cost +} diff --git a/ext/network_test.go b/ext/network_test.go index a0e777ffd..592c07494 100644 --- a/ext/network_test.go +++ b/ext/network_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" "github.com/google/cel-go/common/types" ) @@ -278,6 +279,148 @@ func TestNetwork_Success(t *testing.T) { expr: "cidr('2001:db8::1/32').isMask()", out: false, }, + // IP success cases from K8s ip_test.go + { + name: "ipv4 isUnspecified false", + expr: "ip('127.0.0.1').isUnspecified()", + out: false, + }, + { + name: "ipv4 isLoopback false", + expr: "ip('1.2.3.4').isLoopback()", + out: false, + }, + { + name: "ipv4 isLinkLocalMulticast true", + expr: "ip('224.0.0.1').isLinkLocalMulticast()", + out: true, + }, + { + name: "ipv4 isLinkLocalMulticast false", + expr: "ip('224.0.1.1').isLinkLocalMulticast()", + out: false, + }, + { + name: "ipv4 isLinkLocalUnicast true", + expr: "ip('169.254.169.254').isLinkLocalUnicast()", + out: true, + }, + { + name: "ipv4 isLinkLocalUnicast false", + expr: "ip('192.168.0.1').isLinkLocalUnicast()", + out: false, + }, + { + name: "ipv4 isGlobalUnicast false", + expr: "ip('255.255.255.255').isGlobalUnicast()", + out: false, + }, + { + name: "ipv6 isUnspecified false", + expr: "ip('::1').isUnspecified()", + out: false, + }, + { + name: "ipv6 isLoopback false", + expr: "ip('2001:db8::abcd').isLoopback()", + out: false, + }, + { + name: "ipv6 isLinkLocalMulticast false", + expr: "ip('fd00::1').isLinkLocalMulticast()", + out: false, + }, + { + name: "ipv6 isLinkLocalUnicast true", + expr: "ip('fe80::1').isLinkLocalUnicast()", + out: true, + }, + { + name: "ipv6 isLinkLocalUnicast false", + expr: "ip('fd80::1').isLinkLocalUnicast()", + out: false, + }, + { + name: "ipv6 isGlobalUnicast true", + expr: "ip('2001:db8::abcd').isGlobalUnicast()", + out: true, + }, + { + name: "ipv6 isGlobalUnicast false", + expr: "ip('ff00::1').isGlobalUnicast()", + out: false, + }, + { + name: "type of IP is net.IP", + expr: "type(ip('192.168.0.1')) == net.IP", + out: true, + }, + // CIDR success cases from K8s cidr_test.go + { + name: "contains IP ipv6 (IP)", + expr: "cidr('2001:db8::/32').containsIP(ip('2001:db8::1'))", + out: true, + }, + { + name: "does not contain IP ipv6 (IP)", + expr: "cidr('2001:db8::/32').containsIP(ip('2001:dc8::1'))", + out: false, + }, + { + name: "contains IP ipv6 (string)", + expr: "cidr('2001:db8::/32').containsIP('2001:db8::1')", + out: true, + }, + { + name: "does not contain IP ipv6 (string)", + expr: "cidr('2001:db8::/32').containsIP('2001:dc8::1')", + out: false, + }, + { + name: "contains CIDR ipv6 (CIDR)", + expr: "cidr('2001:db8::/32').containsCIDR(cidr('2001:db8::/33'))", + out: true, + }, + { + name: "does not contain CIDR ipv6 (CIDR)", + expr: "cidr('2001:db8::/32').containsCIDR(cidr('2001:db8::/31'))", + out: false, + }, + { + name: "contains CIDR ipv6 (string)", + expr: "cidr('2001:db8::/32').containsCIDR('2001:db8::/33')", + out: true, + }, + { + name: "does not contain CIDR ipv6 (string)", + expr: "cidr('2001:db8::/32').containsCIDR('2001:db8::/31')", + out: false, + }, + { + name: "returns IP ipv6", + expr: "cidr('2001:db8::/32').ip() == ip('2001:db8::')", + out: true, + }, + { + name: "masks masked ipv6", + expr: "cidr('2001:db8::/32').masked() == cidr('2001:db8::/32')", + out: true, + }, + { + name: "masks unmasked ipv6", + expr: "cidr('2001:db8:1::/32').masked() == cidr('2001:db8::/32')", + out: true, + }, + { + name: "returns prefix length ipv6", + expr: "cidr('2001:db8::/32').prefixLength()", + out: int64(32), + }, + { + name: "type of CIDR is net.CIDR", + expr: "type(cidr('192.168.0.0/24')) == net.CIDR", + out: true, + }, } // Initialize the environment with the Network extension @@ -332,6 +475,16 @@ func TestNetwork_RuntimeErrors(t *testing.T) { expr: "cidr('10.0.0.0/8').containsCIDR('not-a-cidr')", errContains: "parse error", }, + { + name: "ip.isCanonical invalid ipv4", + expr: "ip.isCanonical('127.0.0.1.0')", + errContains: "parse error", + }, + { + name: "ip.isCanonical invalid ipv6", + expr: "ip.isCanonical('2001:db8:::68')", + errContains: "parse error", + }, } env, err := cel.NewEnv(Network()) @@ -530,6 +683,21 @@ func TestNetwork_CompileErrors(t *testing.T) { expr: "cidr('10.0.0.0/8')", errContains: "", }, + { + name: "passing cidr into isIP returns compile error", + expr: "isIP(cidr('192.168.0.0/24'))", + errContains: "found no matching overload for 'isIP'", + }, + { + name: "cidr parse invalid ipv4", + expr: "cidr('192.168.0.0/')", + errContains: "invalid cidr argument", + }, + { + name: "cidr parse invalid ipv6", + expr: "cidr('2001:db8::/')", + errContains: "invalid cidr argument", + }, } env, err := cel.NewEnv(Network()) @@ -565,3 +733,445 @@ func TestNetwork_CompileErrors(t *testing.T) { }) } } + +func TestNetworkCost(t *testing.T) { + tests := []struct { + name string + expr string + estimatedCost checker.CostEstimate + runtimeCost uint64 + }{ + { + name: "ip parse", + expr: "ip('192.168.0.1')", + estimatedCost: checker.FixedCostEstimate(2), + runtimeCost: 2, + }, + { + name: "isIP parse", + expr: "isIP('192.168.0.1')", + estimatedCost: checker.FixedCostEstimate(2), + runtimeCost: 2, + }, + { + name: "cidr parse", + expr: "cidr('192.168.0.0/16')", + estimatedCost: checker.FixedCostEstimate(2), + runtimeCost: 2, + }, + { + name: "isCIDR parse", + expr: "isCIDR('192.168.0.0/16')", + estimatedCost: checker.FixedCostEstimate(2), + runtimeCost: 2, + }, + { + name: "ip.isCanonical", + expr: "ip.isCanonical('192.168.0.1')", + estimatedCost: checker.FixedCostEstimate(3), + runtimeCost: 3, + }, + { + name: "cidr containsIP ip", + expr: "cidr('192.168.0.0/16').containsIP(ip('192.169.0.1'))", + estimatedCost: checker.CostEstimate{Min: 5, Max: 8}, + runtimeCost: 5, + }, + { + name: "cidr containsIP string", + expr: "cidr('192.168.0.0/16').containsIP('192.0.0.1')", + estimatedCost: checker.CostEstimate{Min: 4, Max: 7}, + runtimeCost: 4, + }, + { + name: "cidr containsCIDR cidr", + expr: "cidr('192.168.0.0/16').containsCIDR(cidr('192.0.0.0/30'))", + estimatedCost: checker.CostEstimate{Min: 7, Max: 11}, + runtimeCost: 7, + }, + { + name: "cidr containsCIDR string", + expr: "cidr('192.168.0.0/16').containsCIDR('192.0.0.0/30')", + estimatedCost: checker.CostEstimate{Min: 7, Max: 11}, + runtimeCost: 7, + }, + { + name: "ip family", + expr: "ip('192.168.0.1').family()", + estimatedCost: checker.FixedCostEstimate(3), + runtimeCost: 3, + }, + { + name: "ip unspecified", + expr: "ip('192.168.0.1').isUnspecified()", + estimatedCost: checker.FixedCostEstimate(3), + runtimeCost: 3, + }, + { + name: "ip isLoopback", + expr: "ip('192.168.0.1').isLoopback()", + estimatedCost: checker.FixedCostEstimate(3), + runtimeCost: 3, + }, + { + name: "ip isLinkLocalMulticast", + expr: "ip('192.168.0.1').isLinkLocalMulticast()", + estimatedCost: checker.FixedCostEstimate(3), + runtimeCost: 3, + }, + { + name: "ip isLinkLocalUnicast", + expr: "ip('192.168.0.1').isLinkLocalUnicast()", + estimatedCost: checker.FixedCostEstimate(3), + runtimeCost: 3, + }, + { + name: "ip isGlobalUnicast", + expr: "ip('192.168.0.1').isGlobalUnicast()", + estimatedCost: checker.FixedCostEstimate(3), + runtimeCost: 3, + }, + { + name: "ipv6 family", + expr: "ip('2001:db8:3333:4444:5555:6666:7777:8888').family()", + estimatedCost: checker.FixedCostEstimate(5), + runtimeCost: 5, + }, + { + name: "ipv6 unspecified", + expr: "ip('2001:db8:3333:4444:5555:6666:7777:8888').isUnspecified()", + estimatedCost: checker.FixedCostEstimate(5), + runtimeCost: 5, + }, + { + name: "ipv6 isLoopback", + expr: "ip('2001:db8:3333:4444:5555:6666:7777:8888').isLoopback()", + estimatedCost: checker.FixedCostEstimate(5), + runtimeCost: 5, + }, + { + name: "ipv6 isLinkLocalMulticast", + expr: "ip('2001:db8:3333:4444:5555:6666:7777:8888').isLinkLocalMulticast()", + estimatedCost: checker.FixedCostEstimate(5), + runtimeCost: 5, + }, + { + name: "ipv6 isLinkLocalUnicast", + expr: "ip('2001:db8:3333:4444:5555:6666:7777:8888').isLinkLocalUnicast()", + estimatedCost: checker.FixedCostEstimate(5), + runtimeCost: 5, + }, + { + name: "ipv6 isGlobalUnicast", + expr: "ip('2001:db8:3333:4444:5555:6666:7777:8888').isGlobalUnicast()", + estimatedCost: checker.FixedCostEstimate(5), + runtimeCost: 5, + }, + { + name: "cidr ip extraction", + expr: "cidr('2001:db8::/32').ip()", + estimatedCost: checker.FixedCostEstimate(3), + runtimeCost: 3, + }, + { + name: "cidr prefixLength", + expr: "cidr('2001:db8::/32').prefixLength()", + estimatedCost: checker.FixedCostEstimate(3), + runtimeCost: 3, + }, + { + name: "cidr masked", + expr: "cidr('2001:db8::/32').masked()", + estimatedCost: checker.FixedCostEstimate(3), + runtimeCost: 3, + }, + } + + for _, tst := range tests { + t.Run(tst.name, func(t *testing.T) { + testCost(t, tst.expr, tst.estimatedCost, tst.runtimeCost) + }) + } +} + +func testCost(t *testing.T, expr string, estimatedCost checker.CostEstimate, runtimeCost uint64) { + t.Helper() + env, err := cel.NewEnv(Network()) + if err != nil { + t.Fatalf("cel.NewEnv(Network()) failed: %v", err) + } + parsedAst, iss := env.Parse(expr) + if iss.Err() != nil { + t.Fatalf("env.Parse(%q) failed: %v", expr, iss.Err()) + } + checkedAst, iss := env.Check(parsedAst) + if iss.Err() != nil { + t.Fatalf("env.Check(%q) failed: %v", expr, iss.Err()) + } + + actualEst, err := env.EstimateCost(checkedAst, &noopCostEstimator{}) + if err != nil { + t.Fatalf("env.EstimateCost(%q) failed: %v", expr, err) + } + if actualEst.Min != estimatedCost.Min || actualEst.Max != estimatedCost.Max { + t.Errorf("expected estimated cost %v, got %v for expr %q", estimatedCost, actualEst, expr) + } + + program, err := env.Program(checkedAst, cel.CostTracking(&noopCostEstimator{})) + if err != nil { + t.Fatalf("env.Program(%q) failed: %v", expr, err) + } + _, evalDetails, err := program.Eval(cel.NoVars()) + if err != nil { + t.Fatalf("program.Eval(%q) failed: %v", expr, err) + } + if evalDetails == nil || evalDetails.ActualCost() == nil { + t.Fatalf("evalDetails or actualCost is nil for %q", expr) + } + if *evalDetails.ActualCost() != runtimeCost { + t.Errorf("expected runtime cost %d, got %d for expr %q", runtimeCost, *evalDetails.ActualCost(), expr) + } +} + +func TestIPCost(t *testing.T) { + ipv4 := "ip('192.168.0.1')" + ipv4BaseEstimatedCost := checker.FixedCostEstimate(2) + ipv4BaseRuntimeCost := uint64(2) + + ipv6 := "ip('2001:db8:3333:4444:5555:6666:7777:8888')" + ipv6BaseEstimatedCost := checker.FixedCostEstimate(4) + ipv6BaseRuntimeCost := uint64(4) + + testCases := []struct { + ops []string + expectEsimatedCost func(checker.CostEstimate) checker.CostEstimate + expectRuntimeCost func(uint64) uint64 + }{ + { + // For just parsing the IP, the cost is expected to be the base. + ops: []string{""}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { return c }, + expectRuntimeCost: func(c uint64) uint64 { return c }, + }, + { + ops: []string{".family()", ".isUnspecified()", ".isLoopback()", ".isLinkLocalMulticast()", ".isLinkLocalUnicast()", ".isGlobalUnicast()"}, + // For most other operations, the cost is expected to be the base + 1. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 1, Max: c.Max + 1} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 1 }, + }, + { + ops: []string{" == ip('192.168.0.1')"}, + // For most other operations, the cost is expected to be the base + 1. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return c.Add(ipv4BaseEstimatedCost).Add(checker.CostEstimate{Min: 1, Max: 2}) + }, + expectRuntimeCost: func(c uint64) uint64 { return c + ipv4BaseRuntimeCost + 1 }, + }, + } + + for _, tc := range testCases { + for _, op := range tc.ops { + t.Run(ipv4+op, func(t *testing.T) { + testCost(t, ipv4+op, tc.expectEsimatedCost(ipv4BaseEstimatedCost), tc.expectRuntimeCost(ipv4BaseRuntimeCost)) + }) + + t.Run(ipv6+op, func(t *testing.T) { + testCost(t, ipv6+op, tc.expectEsimatedCost(ipv6BaseEstimatedCost), tc.expectRuntimeCost(ipv6BaseRuntimeCost)) + }) + } + } +} + +func TestCIDRCost(t *testing.T) { + ipv4 := "cidr('192.168.0.0/16')" + ipv4BaseEstimatedCost := checker.CostEstimate{Min: 2, Max: 2} + ipv4BaseRuntimeCost := uint64(2) + + ipv6 := "cidr('2001:db8::/32')" + ipv6BaseEstimatedCost := checker.CostEstimate{Min: 2, Max: 2} + ipv6BaseRuntimeCost := uint64(2) + + type testCase struct { + ops []string + expectEsimatedCost func(checker.CostEstimate) checker.CostEstimate + expectRuntimeCost func(uint64) uint64 + } + + cases := []testCase{ + { + // For just parsing the IP, the cost is expected to be the base. + ops: []string{""}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { return c }, + expectRuntimeCost: func(c uint64) uint64 { return c }, + }, + { + ops: []string{".ip()", ".prefixLength()", ".masked()"}, + // For most other operations, the cost is expected to be the base + 1. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 1, Max: c.Max + 1} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 1 }, + }, + { + ops: []string{" == cidr('2001:db8::/32')"}, + // For most other operations, the cost is expected to be the base + 1. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return c.Add(ipv6BaseEstimatedCost).Add(checker.CostEstimate{Min: 1, Max: 2}) + }, + expectRuntimeCost: func(c uint64) uint64 { return c + ipv6BaseRuntimeCost + 1 }, + }, + } + + //nolint:gocritic + ipv4Cases := append(cases, []testCase{ + { + ops: []string{".containsCIDR(cidr('192.0.0.0/30'))"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsCIDR(cidr('192.168.0.0/16'))"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsCIDR('192.0.0.0/30')"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsCIDR('192.168.0.0/16')"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsIP(ip('192.0.0.1'))"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 2, Max: c.Max + 5} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 2 }, + }, + { + ops: []string{".containsIP(ip('192.169.0.1'))"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 3 }, + }, + { + ops: []string{".containsIP(ip('192.169.169.250'))"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 3 }, + }, + { + ops: []string{".containsIP('192.0.0.1')"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 2, Max: c.Max + 5} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 2 }, + }, + { + ops: []string{".containsIP('192.169.0.1')"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 3 }, + }, + }...) + + //nolint:gocritic + ipv6Cases := append(cases, []testCase{ + { + ops: []string{".containsCIDR(cidr('2001:db8::/126'))"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsCIDR(cidr('2001:db8::/32'))"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsCIDR('2001:db8::/126')"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsCIDR('2001:db8::/32')"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsIP(ip('2001:db8:3333:4444:5555:6666:7777:8888'))"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 8} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsIP(ip('2001:db8::1'))"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 3 }, + }, + { + ops: []string{".containsIP('2001:db8:3333:4444:5555:6666:7777:8888')"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 8} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsIP('2001:db8::1')"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 3 }, + }, + }...) + + for _, tc := range ipv4Cases { + for _, op := range tc.ops { + t.Run(ipv4+op, func(t *testing.T) { + testCost(t, ipv4+op, tc.expectEsimatedCost(ipv4BaseEstimatedCost), tc.expectRuntimeCost(ipv4BaseRuntimeCost)) + }) + } + } + + for _, tc := range ipv6Cases { + for _, op := range tc.ops { + t.Run(ipv6+op, func(t *testing.T) { + testCost(t, ipv6+op, tc.expectEsimatedCost(ipv6BaseEstimatedCost), tc.expectRuntimeCost(ipv6BaseRuntimeCost)) + }) + } + } +} diff --git a/interpreter/runtimecost.go b/interpreter/runtimecost.go index 6c44cd798..38a53fc36 100644 --- a/interpreter/runtimecost.go +++ b/interpreter/runtimecost.go @@ -276,7 +276,7 @@ func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result re if tracker, found := c.overloadTrackers[call.OverloadID()]; found { callCost := tracker(args, result) if callCost != nil { - cost += *callCost + cost = safeAdd(cost, *callCost) return cost } } @@ -284,7 +284,7 @@ func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result re if c.Estimator != nil { callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), args, result) if callCost != nil { - cost += *callCost + cost = safeAdd(cost, *callCost) return cost } } @@ -293,11 +293,11 @@ func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result re switch call.OverloadID() { // O(n) functions case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString: - cost += uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor)) + cost = safeAdd(cost, uint64(math.Ceil(float64(actualSize(args[0]))*common.StringTraversalCostFactor))) case overloads.InList: // If a list is composed entirely of constant values this is O(1), but we don't account for that here. // We just assume all list containment checks are O(n). - cost += actualSize(args[1]) + cost = safeAdd(cost, actualSize(args[1])) // O(min(m, n)) functions case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString, overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes, @@ -307,15 +307,12 @@ func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result re // of 1. lhsSize := actualSize(args[0]) rhsSize := actualSize(args[1]) - minSize := lhsSize - if rhsSize < minSize { - minSize = rhsSize - } - cost += uint64(math.Ceil(float64(minSize) * common.StringTraversalCostFactor)) + minSize := min(rhsSize, lhsSize) + cost = safeAdd(cost, uint64(math.Ceil(float64(minSize)*common.StringTraversalCostFactor))) // O(m+n) functions case overloads.AddString, overloads.AddBytes: // In the worst case scenario, we would need to reallocate a new backing store and copy both operands over. - cost += uint64(math.Ceil(float64(actualSize(args[0])+actualSize(args[1])) * common.StringTraversalCostFactor)) + cost = safeAdd(cost, uint64(math.Ceil(float64(actualSize(args[0])+actualSize(args[1]))*common.StringTraversalCostFactor))) // O(nm) functions case overloads.MatchesString: // https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL @@ -328,11 +325,11 @@ func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result re // For now, we're making a guess that each expression in a regex is typically at least 4 chars // in length. regexCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.RegexStringLengthCostFactor)) - cost += strCost * regexCost + cost = safeAdd(cost, strCost*regexCost) case overloads.ContainsString: strCost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor)) substrCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.StringTraversalCostFactor)) - cost += strCost * substrCost + cost = safeAdd(cost, strCost*substrCost) default: // The following operations are assumed to have O(1) complexity. @@ -342,7 +339,7 @@ func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result re // - Computing the size of strings, byte sequences, lists and maps. // - Logical operations and all operators on fixed width scalars (comparisons, equality) // - Any functions that don't have a declared cost either here or in provided ActualCostEstimator. - cost++ + cost = safeAdd(cost, 1) } return cost @@ -413,3 +410,21 @@ argloop: } return result, true } + +func safeAdd(x, y uint64, rest ...uint64) uint64 { + if y > 0 && x > math.MaxUint64-y { + return math.MaxUint64 + } + next := x + y + if len(rest) == 0 { + return next + } + return safeAdd(next, rest[0], rest[1:]...) +} + +func safeMul(x, y uint64) uint64 { + if y != 0 && x > math.MaxUint64/y { + return math.MaxUint64 + } + return x * y +}