diff --git a/checker/optional.cc b/checker/optional.cc index d41e68aa1..f26bd3922 100644 --- a/checker/optional.cc +++ b/checker/optional.cc @@ -71,6 +71,13 @@ Type OptionalMapOfKV() { return *kInstance; } +Type ListOfOptionalV() { + static const absl::NoDestructor kInstance( + checker_internal::BuiltinsArena(), OptionalOfV()); + + return *kInstance; +} + class OptionalNames { public: static constexpr char kOptionalType[] = "optional_type"; @@ -85,6 +92,8 @@ class OptionalNames { static constexpr char kOptionalIndex[] = "_[?_]"; static constexpr char kOptionalFirst[] = "first"; static constexpr char kOptionalLast[] = "last"; + static constexpr char kOptionalUnwrap[] = "optional.unwrap"; + static constexpr char kOptionalUnwrapOpt[] = "unwrapOpt"; }; class OptionalOverloads { @@ -114,6 +123,9 @@ class OptionalOverloads { // Syntactic sugar for chained indexing. static constexpr char kOptionalListIndexInt[] = "optional_list_index_int"; static constexpr char kOptionalMapIndexValue[] = "optional_map_index_value"; + // Unwrapping + static constexpr char kOptionalUnwrapList[] = "optional_unwrap_list"; + static constexpr char kOptionalUnwrapOptList[] = "optional_unwrapOpt_list"; }; absl::Status RegisterOptionalDecls(TypeCheckerBuilder& builder, int version) { @@ -207,6 +219,17 @@ absl::Status RegisterOptionalDecls(TypeCheckerBuilder& builder, int version) { OptionalOfV(), OptionalMapOfKV(), TypeParamType("K")))); + CEL_ASSIGN_OR_RETURN( + auto unwrap, + MakeFunctionDecl( + OptionalNames::kOptionalUnwrap, + MakeOverloadDecl(OptionalOverloads::kOptionalUnwrapList, ListOfV(), ListOfOptionalV()))); + CEL_ASSIGN_OR_RETURN( + auto unwrap_opt, + MakeFunctionDecl( + OptionalNames::kOptionalUnwrapOpt, + MakeMemberOverloadDecl(OptionalOverloads::kOptionalUnwrapOptList, ListOfV(), ListOfOptionalV()))); + CEL_RETURN_IF_ERROR(builder.AddVariable( MakeVariableDecl(OptionalNames::kOptionalType, TypeOfOptionalOfV()))); @@ -220,6 +243,8 @@ absl::Status RegisterOptionalDecls(TypeCheckerBuilder& builder, int version) { CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(opt_index))); CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(select))); CEL_RETURN_IF_ERROR(builder.MergeFunction(std::move(index))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(unwrap))); + CEL_RETURN_IF_ERROR(builder.AddFunction(std::move(unwrap_opt))); if (version == 0 || version == 1) { return absl::OkStatus(); diff --git a/checker/standard_library.cc b/checker/standard_library.cc index 744a171ef..80ef08ba0 100644 --- a/checker/standard_library.cc +++ b/checker/standard_library.cc @@ -285,6 +285,8 @@ absl::Status AddTypeConversions(TypeCheckerBuilder& builder) { // Int FunctionDecl to_int; to_int.set_name(StandardFunctions::kInt); + CEL_RETURN_IF_ERROR(to_int.AddOverload( + MakeOverloadDecl(StandardOverloadIds::kBoolToInt, IntType(), BoolType()))); CEL_RETURN_IF_ERROR(to_int.AddOverload( MakeOverloadDecl(StandardOverloadIds::kIntToInt, IntType(), IntType()))); CEL_RETURN_IF_ERROR(to_int.AddOverload(MakeOverloadDecl( diff --git a/common/BUILD b/common/BUILD index 0426c0827..83a361461 100644 --- a/common/BUILD +++ b/common/BUILD @@ -1028,7 +1028,9 @@ cc_library( ], deps = [ ":kind", + ":type", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], diff --git a/common/function_descriptor.cc b/common/function_descriptor.cc index be32e8616..01149c3d1 100644 --- a/common/function_descriptor.cc +++ b/common/function_descriptor.cc @@ -18,13 +18,152 @@ #include #include "absl/base/macros.h" +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "common/kind.h" +#include "common/type.h" namespace cel { +namespace { + +// Recursive type matching with TypeParam binding support. +// Returns true if types match (i.e., they conflict and cannot coexist). +// +// Matching semantics: +// - dyn/any act as wildcards that match any type +// - TypeParam requires consistent binding across all arguments +// - Container types (list, map, opaque) are matched recursively +// +// The bindings map tracks TypeParam name -> bound Type mappings to ensure +// consistent binding. For example, (A, A) matches (int, int) but not +// (int, string) because A cannot bind to both int and string. +bool TypeMatches(const Type& a, const Type& b, + absl::flat_hash_map& bindings) { + // Wildcard types match anything (consistent with type checker's IsWildCardType) + // - dyn: dynamic type, defers type checking to runtime + // - any: legacy wildcard type + // - error: error type propagates through expressions + if (a.IsDyn() || b.IsDyn()) { + return true; + } + if (a.IsAny() || b.IsAny()) { + return true; + } + if (a.IsError() || b.IsError()) { + return true; + } + + // TypeParam handling - requires consistent binding + if (a.IsTypeParam()) { + std::string name(a.GetTypeParam().name()); + auto it = bindings.find(name); + if (it != bindings.end()) { + // Already bound - check consistency with the bound type + return TypeMatches(it->second, b, bindings); + } else { + // Not yet bound - bind to b and return match + bindings[name] = b; + return true; + } + } + if (b.IsTypeParam()) { + std::string name(b.GetTypeParam().name()); + auto it = bindings.find(name); + if (it != bindings.end()) { + // Already bound - check consistency with the bound type + return TypeMatches(a, it->second, bindings); + } else { + // Not yet bound - bind to a and return match + bindings[name] = a; + return true; + } + } + + // Different kinds don't match (e.g., int vs string) + TypeKind a_kind = a.kind(); + TypeKind b_kind = b.kind(); + if (a_kind != b_kind) { + return false; + } + + // Same kind - check type-specific details + switch (a_kind) { + case TypeKind::kList: { + // Recursively check element types + return TypeMatches(a.GetList().GetElement(), + b.GetList().GetElement(), bindings); + } + case TypeKind::kMap: { + // Recursively check key and value types + return TypeMatches(a.GetMap().GetKey(), b.GetMap().GetKey(), + bindings) && + TypeMatches(a.GetMap().GetValue(), b.GetMap().GetValue(), + bindings); + } + case TypeKind::kStruct: { + // Empty StructType acts as wildcard for any struct + StructType struct_a = a.GetStruct(); + StructType struct_b = b.GetStruct(); + if (!struct_a || !struct_b) { + // Empty struct matches any struct + return true; + } + // Both have names - must match exactly + return struct_a.name() == struct_b.name(); + } + case TypeKind::kOpaque: { + // Opaque types (including Optional) - must have same name and + // recursively matching parameters + OpaqueType opaque_a = a.GetOpaque(); + OpaqueType opaque_b = b.GetOpaque(); + // Empty OpaqueType acts as wildcard for any opaque + if (!opaque_a || !opaque_b) { + return true; + } + if (opaque_a.name() != opaque_b.name()) { + return false; + } + TypeParameters params_a = opaque_a.GetParameters(); + TypeParameters params_b = opaque_b.GetParameters(); + if (params_a.size() != params_b.size()) { + return false; + } + for (size_t i = 0; i < params_a.size(); i++) { + if (!TypeMatches(params_a[i], params_b[i], bindings)) { + return false; + } + } + return true; + } + default: + // Basic types with same kind always match + return true; + } +} + +// Converts a span of Kinds to a vector of Types. +// Uses the implicit Type(Kind) constructor for conversion. +std::vector KindsToTypes(absl::Span kinds) { + // Uses implicit Type(Kind) constructor for conversion. + std::vector types(kinds.begin(), kinds.end()); + return types; +} + +} // namespace + +const std::vector& FunctionDescriptor::kinds() const { + return impl_->kinds; +} + bool FunctionDescriptor::ShapeMatches(bool receiver_style, absl::Span types) const { + // Convert Kinds to Types and delegate to the Type-based version. + return ShapeMatches(receiver_style, KindsToTypes(types)); +} + +bool FunctionDescriptor::ShapeMatches(bool receiver_style, + absl::Span types) const { if (this->receiver_style() != receiver_style) { return false; } @@ -33,11 +172,13 @@ bool FunctionDescriptor::ShapeMatches(bool receiver_style, return false; } + // Use type-level matching with consistent TypeParam binding. + // The binding context is shared across all arguments to ensure + // TypeParam consistency (e.g., (A, A) requires both args to be same type). + absl::flat_hash_map bindings; + for (size_t i = 0; i < this->types().size(); i++) { - Kind this_type = this->types()[i]; - Kind other_type = types[i]; - if (this_type != Kind::kAny && other_type != Kind::kAny && - this_type != other_type) { + if (!TypeMatches(this->types()[i], types[i], bindings)) { return false; } } @@ -45,11 +186,19 @@ bool FunctionDescriptor::ShapeMatches(bool receiver_style, } bool FunctionDescriptor::operator==(const FunctionDescriptor& other) const { - return impl_.get() == other.impl_.get() || - (name() == other.name() && - receiver_style() == other.receiver_style() && - types().size() == other.types().size() && - std::equal(types().begin(), types().end(), other.types().begin())); + if (impl_.get() == other.impl_.get()) { + return true; + } + if (name() != other.name() || + receiver_style() != other.receiver_style() || + types().size() != other.types().size()) { + return false; + } + // Compare using Kind for backward compatibility. + // Full Type comparison can be added later if needed. + const auto& lhs_types = types(); + const auto& rhs_types = other.types(); + return std::equal(lhs_types.begin(), lhs_types.end(), rhs_types.begin()); } bool FunctionDescriptor::operator<(const FunctionDescriptor& other) const { @@ -73,7 +222,9 @@ bool FunctionDescriptor::operator<(const FunctionDescriptor& other) const { auto rhs_begin = other.types().begin(); auto rhs_end = other.types().end(); while (lhs_begin != lhs_end && rhs_begin != rhs_end) { - if (*lhs_begin < *rhs_begin) { + // Compare types lexicographically using DebugString as stable ordering + // This ensures consistent ordering for all Type variants + if (lhs_begin->DebugString() < rhs_begin->DebugString()) { return true; } if (!(*lhs_begin == *rhs_begin)) { diff --git a/common/function_descriptor.h b/common/function_descriptor.h index 75c61e13a..4970285a4 100644 --- a/common/function_descriptor.h +++ b/common/function_descriptor.h @@ -23,6 +23,7 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "common/kind.h" +#include "common/type.h" namespace cel { @@ -43,8 +44,12 @@ struct FunctionDescriptorOptions { // Coarsely describes a function for the purpose of runtime resolution of // overloads. +// +// Internally stores full Type information to enable type-level overload +// resolution. class FunctionDescriptor final { public: + // Constructors for legacy Kind-based types. FunctionDescriptor(absl::string_view name, bool receiver_style, std::vector types, bool is_strict) : impl_(std::make_shared( @@ -65,16 +70,48 @@ class FunctionDescriptor final { : impl_(std::make_shared(name, std::move(types), is_receiver_style, options)) {} + // Constructors for type-level descriptor. + // - overload_id for static analysis matching + // - types for runtime type-level matching + FunctionDescriptor(absl::string_view name, absl::string_view overload_id, + bool receiver_style, std::vector types, + bool is_strict) + : impl_(std::make_shared( + name, overload_id, std::move(types), receiver_style, + FunctionDescriptorOptions{is_strict, + /*is_contextual=*/false})) {} + + FunctionDescriptor(absl::string_view name, absl::string_view overload_id, + bool receiver_style, std::vector types, + bool is_strict, bool is_contextual) + : impl_(std::make_shared( + name, overload_id, std::move(types), receiver_style, + FunctionDescriptorOptions{is_strict, is_contextual})) {} + + FunctionDescriptor(absl::string_view name, absl::string_view overload_id, + bool is_receiver_style, std::vector types, + FunctionDescriptorOptions options = {}) + : impl_(std::make_shared( + name, overload_id, std::move(types), is_receiver_style, options)) {} + // Function name. const std::string& name() const { return impl_->name; } + // Overload ID for precise overload resolution (empty string if not set). + const std::string& overload_id() const { return impl_->overload_id; } + + // Whether this descriptor has an overload ID. + bool has_overload_id() const { return !impl_->overload_id.empty(); } + // Whether function is receiver style i.e. true means arg0.name(args[1:]...). bool receiver_style() const { return impl_->is_receiver_style; } - // The argument types the function accepts. - // - // TODO(uncreated-issue/17): make this kinds - const std::vector& types() const { return impl_->types; } + // The argument types the function accepts, with full Type information. + // This includes element types for containers (e.g., list vs list). + const std::vector& types() const { return impl_->types; } + + // The argument types the function accepts, as Kinds. + const std::vector& kinds() const; // if true (strict, default), error or unknown arguments are propagated // instead of calling the function. if false (non-strict), the function may @@ -96,6 +133,7 @@ class FunctionDescriptor final { return ShapeMatches(other.receiver_style(), other.types()); } bool ShapeMatches(bool receiver_style, absl::Span types) const; + bool ShapeMatches(bool receiver_style, absl::Span types) const; bool operator==(const FunctionDescriptor& other) const; @@ -103,15 +141,42 @@ class FunctionDescriptor final { private: struct Impl final { - Impl(absl::string_view name, std::vector types, + // Kind-based constructor (legacy API, no overload_id) + Impl(absl::string_view name, std::vector kinds, + bool is_receiver_style, FunctionDescriptorOptions options) + : name(name), + kinds(std::move(kinds)), + is_receiver_style(is_receiver_style), + options(options) { + // Convert kinds to types + types.reserve(this->kinds.size()); + for (const Kind& kind : this->kinds) { + types.push_back(Type(kind)); + } + } + + // Type-based constructor with overload_id (new API) + Impl(absl::string_view name, std::string_view overload_id, + std::vector types, bool is_receiver_style, FunctionDescriptorOptions options) : name(name), + overload_id(overload_id), types(std::move(types)), is_receiver_style(is_receiver_style), - options(options) {} + options(options) { + // Derive kinds from types for backward compatibility + kinds.reserve(this->types.size()); + for (const Type& type : this->types) { + kinds.push_back(TypeKindToKind(type.kind())); + } + } std::string name; - std::vector types; + std::string overload_id; + std::vector types; + // Backward compatibility field: Kind-level type representation. + // Derived from types during construction for legacy APIs. + std::vector kinds; bool is_receiver_style; FunctionDescriptorOptions options; }; diff --git a/common/standard_definitions.h b/common/standard_definitions.h index eea185f6b..6b803efeb 100644 --- a/common/standard_definitions.h +++ b/common/standard_definitions.h @@ -304,6 +304,7 @@ struct StandardOverloadIds { static constexpr absl::string_view kIntToUint = "int64_to_uint64"; static constexpr absl::string_view kStringToUint = "string_to_uint64"; // to_int + static constexpr absl::string_view kBoolToInt = "bool_to_int64"; static constexpr absl::string_view kUintToInt = "uint64_to_int64"; static constexpr absl::string_view kDoubleToInt = "double_to_int64"; static constexpr absl::string_view kIntToInt = "int64_to_int64"; diff --git a/common/type.cc b/common/type.cc index 9ea85954c..79be68413 100644 --- a/common/type.cc +++ b/common/type.cc @@ -82,6 +82,105 @@ Type Type::Enum(const google::protobuf::EnumDescriptor* absl_nonnull descriptor) return EnumType(descriptor); } +Type::Type(Kind kind) { + switch (kind) { + case Kind::kNull: + *this = NullType(); + break; + case Kind::kBool: + *this = BoolType(); + break; + case Kind::kInt: + *this = IntType(); + break; + case Kind::kUint: + *this = UintType(); + break; + case Kind::kDouble: + *this = DoubleType(); + break; + case Kind::kString: + *this = StringType(); + break; + case Kind::kBytes: + *this = BytesType(); + break; + case Kind::kStruct: + // Empty MessageType represents any struct (wildcard for matching) + // Note: We directly assign to variant_ to avoid StructType::ToTypeVariant() + // which converts empty StructType to DynType. + variant_ = MessageType(); + break; + case Kind::kDuration: + *this = DurationType(); + break; + case Kind::kTimestamp: + *this = TimestampType(); + break; + case Kind::kList: + // List without element type info defaults to list + *this = ListType(); + break; + case Kind::kMap: + // Map without key/value type info defaults to map + *this = MapType(); + break; + case Kind::kUnknown: + *this = UnknownType(); + break; + case Kind::kType: + *this = TypeType(); + break; + case Kind::kError: + *this = ErrorType(); + break; + case Kind::kAny: + *this = AnyType(); + break; + + case Kind::kDyn: + *this = DynType(); + break; + case Kind::kOpaque: + // Empty OpaqueType represents any opaque (wildcard) + *this = OpaqueType(); + break; + + case Kind::kBoolWrapper: + *this = BoolWrapperType(); + break; + case Kind::kIntWrapper: + *this = IntWrapperType(); + break; + case Kind::kUintWrapper: + *this = UintWrapperType(); + break; + case Kind::kDoubleWrapper: + *this = DoubleWrapperType(); + break; + case Kind::kStringWrapper: + *this = StringWrapperType(); + break; + case Kind::kBytesWrapper: + *this = BytesWrapperType(); + break; + + case Kind::kTypeParam: + *this = TypeParamType(); + break; + case Kind::kFunction: + *this = FunctionType(); + break; + case Kind::kEnum: + *this = EnumType(); + break; + + default: + *this = DynType(); + break; + } +} + namespace { static constexpr std::array kTypeToKindArray = { diff --git a/common/type.h b/common/type.h index c8851dd4e..e5cd78de1 100644 --- a/common/type.h +++ b/common/type.h @@ -35,6 +35,7 @@ #include "absl/types/span.h" #include "absl/types/variant.h" #include "absl/utility/utility.h" +#include "common/kind.h" #include "common/type_kind.h" #include "common/types/any_type.h" // IWYU pragma: export #include "common/types/bool_type.h" // IWYU pragma: export @@ -109,6 +110,11 @@ class Type final { Type& operator=(const Type&) = default; Type& operator=(Type&&) = default; + // Implicit conversion from Kind to Type. + // This allows std::vector to be implicitly converted to std::vector. + // NOLINTNEXTLINE(google-explicit-constructor) + Type(Kind kind); + template >>> diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index f7300cb58..3a32972f3 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -103,6 +103,7 @@ cc_library( "//base:data", "//common:allocator", "//common:ast", + "//common/ast:metadata", "//common:ast_traverse", "//common:ast_visitor", "//common:constant", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index aa9a8858c..5b38ec493 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -53,6 +53,7 @@ #include "base/type_provider.h" #include "common/allocator.h" #include "common/ast.h" +#include "common/ast/metadata.h" #include "common/ast_traverse.h" #include "common/ast_visitor.h" #include "common/constant.h" @@ -520,6 +521,7 @@ class FlatExprVisitor : public cel::AstVisitor { bool enable_optional_types) : resolver_(resolver), type_provider_(type_provider), + reference_map_(reference_map), progress_status_(absl::OkStatus()), resolved_select_expr_(nullptr), options_(options), @@ -1637,25 +1639,149 @@ class FlatExprVisitor : public cel::AstVisitor { bool receiver_style = call_expr->has_target(); size_t num_args = call_expr->args().size() + (receiver_style ? 1 : 0); - // First, search for lazily defined function overloads. - // Lazy functions shadow eager functions with the same signature. - auto lazy_overloads = resolver_.FindLazyOverloads( - function, call_expr->has_target(), num_args, expr->id()); + // Try to extract overload IDs from reference_map + // Checked expressions will have overload_ids, parse-only won't + std::vector overload_ids; + auto ref_it = reference_map_.find(expr->id()); + if (ref_it != reference_map_.end()) { + overload_ids = ref_it->second.overload_id(); + } + + // Branch 1: Checked expression with overload IDs + // Direct lookup by overload ID - O(k) where k = number of overload_ids + if (!overload_ids.empty()) { + // Try lazy functions first (they shadow static functions) + std::vector lazy_overloads; + for (const auto& id : overload_ids) { + auto lazy = resolver_.FindLazyOverloadById(function, id); + if (lazy.has_value()) { + lazy_overloads.push_back(*lazy); + } + } + + if (!lazy_overloads.empty()) { + size_t checked_overloads_count = lazy_overloads.size(); + + // Merge arity-matched lazy overloads as eval-time fallback candidates. + // When dyn() casts cause runtime types to differ from checker-inferred + // types, the overload_id-resolved entries may not match. Arity-matched + // overloads ensure the correct overload can still be found at eval + // time. + auto lazy_fallback = resolver_.FindLazyOverloads( + function, receiver_style, num_args, expr->id()); + for (auto& ovl : lazy_fallback) { + bool duplicate = false; + for (const auto& existing : lazy_overloads) { + if (existing.descriptor == ovl.descriptor) { + duplicate = true; + break; + } + } + if (!duplicate) { + lazy_overloads.push_back(ovl); + } + } + + auto depth = RecursionEligible(); + if (depth.has_value()) { + auto args = + program_builder_.current()->ExtractRecursiveDependencies(); + SetRecursiveStep( + CreateDirectLazyFunctionStep( + expr->id(), *call_expr, std::move(args), + std::move(lazy_overloads), checked_overloads_count, + options_.enable_type_level_overload), + *depth + 1); + } else { + AddStep(CreateFunctionStep( + *call_expr, expr->id(), std::move(lazy_overloads), + checked_overloads_count, options_.enable_type_level_overload)); + } + return; + } + + // Try static functions + std::vector static_overloads; + for (const auto& id : overload_ids) { + auto static_func = resolver_.FindOverloadById(function, id); + if (static_func.has_value()) { + static_overloads.push_back(*static_func); + } + } + + if (!static_overloads.empty()) { + size_t checked_overloads_count = static_overloads.size(); + + // Merge arity-matched overloads as eval-time fallback candidates. + // When a dyn() cast causes the runtime argument types to differ from + // what the checker inferred, the overload_id-resolved overloads may + // not match. Arity-matched overloads ensure the correct overload can + // still be found at eval time via + // ArgumentKindsMatch/ArgumentTypesMatch. + auto fallback = resolver_.FindOverloads(function, receiver_style, + num_args, expr->id()); + for (auto& ovl : fallback) { + bool duplicate = false; + for (const auto& existing : static_overloads) { + if (&existing.implementation == &ovl.implementation) { + duplicate = true; + break; + } + } + if (!duplicate) { + static_overloads.push_back(ovl); + } + } + + auto recursion_depth = RecursionEligible(); + if (recursion_depth.has_value()) { + ABSL_DCHECK(program_builder_.current() != nullptr); + auto args = + program_builder_.current()->ExtractRecursiveDependencies(); + SetRecursiveStep( + CreateDirectFunctionStep(expr->id(), *call_expr, std::move(args), + std::move(static_overloads), + checked_overloads_count, + options_.enable_type_level_overload), + *recursion_depth + 1); + } else { + AddStep(CreateFunctionStep( + *call_expr, expr->id(), std::move(static_overloads), + checked_overloads_count, options_.enable_type_level_overload)); + } + return; + } + + // Fallback: overload IDs provided but no functions found with those IDs. + // This can happen when functions are registered without overload_id + // (legacy). Fall through to arity-based matching as a compatibility + // layer. + } + + // Branch 2: Parse-only expression - use arity-based matching + + // Try lazy functions first + auto lazy_overloads = resolver_.FindLazyOverloads(function, receiver_style, + num_args, expr->id()); if (!lazy_overloads.empty()) { if (auto depth = RecursionEligible(); depth.has_value()) { auto args = program_builder_.current()->ExtractRecursiveDependencies(); SetRecursiveStep(CreateDirectLazyFunctionStep( expr->id(), *call_expr, std::move(args), - std::move(lazy_overloads)), + std::move(lazy_overloads), + /*checked_overloads_count=*/0, + options_.enable_type_level_overload), *depth + 1); - return; + } else { + AddStep(CreateFunctionStep(*call_expr, expr->id(), + std::move(lazy_overloads), + /*checked_overloads_count=*/0, + options_.enable_type_level_overload)); } - AddStep(CreateFunctionStep(*call_expr, expr->id(), - std::move(lazy_overloads))); return; } - // Second, search for eagerly defined function overloads. + // Try static functions auto overloads = resolver_.FindOverloads(function, receiver_style, num_args, expr->id()); if (overloads.empty()) { @@ -1664,8 +1790,8 @@ class FlatExprVisitor : public cel::AstVisitor { // CelExpression creation or an inspectable warning for use within runtime // logging. auto status = issue_collector_.AddIssue(RuntimeIssue::CreateWarning( - absl::InvalidArgumentError( - "No overloads provided for FunctionStep creation"), + absl::InvalidArgumentError(absl::StrCat( + "No overloads provided for FunctionStep creation: ", function)), RuntimeIssue::ErrorCode::kNoMatchingOverload)); if (!status.ok()) { SetProgressStatusIfError(status); @@ -1681,11 +1807,15 @@ class FlatExprVisitor : public cel::AstVisitor { auto args = program_builder_.current()->ExtractRecursiveDependencies(); SetRecursiveStep( CreateDirectFunctionStep(expr->id(), *call_expr, std::move(args), - std::move(overloads)), + std::move(overloads), + /*checked_overloads_count=*/0, + options_.enable_type_level_overload), *recursion_depth + 1); return; } - AddStep(CreateFunctionStep(*call_expr, expr->id(), std::move(overloads))); + AddStep(CreateFunctionStep(*call_expr, expr->id(), std::move(overloads), + /*checked_overloads_count=*/0, + options_.enable_type_level_overload)); } // Add a step to the program, taking ownership. If successful, returns the @@ -1909,6 +2039,7 @@ class FlatExprVisitor : public cel::AstVisitor { const Resolver& resolver_; const cel::TypeProvider& type_provider_; + const absl::flat_hash_map& reference_map_; absl::Status progress_status_; absl::flat_hash_map call_handlers_; diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index cca72964a..20b943c14 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -152,6 +152,21 @@ std::vector Resolver::FindOverloads( return funcs; } +std::vector Resolver::FindOverloadsByTypes( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id) const { + std::vector funcs; + auto names = FullyQualifiedNames(name, expr_id); + for (auto it = names.begin(); it != names.end(); it++) { + funcs = function_registry_.FindStaticOverloadsByTypes(*it, receiver_style, + types); + if (!funcs.empty()) { + return funcs; + } + } + return funcs; +} + std::vector Resolver::FindOverloads( absl::string_view name, bool receiver_style, size_t arity, int64_t expr_id) const { @@ -173,6 +188,21 @@ std::vector Resolver::FindOverloads( return funcs; } +absl::optional Resolver::FindOverloadById( + absl::string_view name, absl::string_view overload_id) const { + // Try with namespace prefixes + auto prefixes = GetPrefixesFor(name); + for (const auto& prefix : prefixes) { + std::string qualified_name = absl::StrCat(prefix, name); + auto result = + function_registry_.FindStaticOverloadById(qualified_name, overload_id); + if (result.has_value()) { + return result; + } + } + return absl::nullopt; +} + std::vector Resolver::FindLazyOverloads( absl::string_view name, bool receiver_style, const std::vector& types, int64_t expr_id) const { @@ -189,6 +219,22 @@ std::vector Resolver::FindLazyOverloads( return funcs; } +std::vector +Resolver::FindLazyOverloadsByTypes(absl::string_view name, bool receiver_style, + const std::vector& types, + int64_t expr_id) const { + std::vector funcs; + auto names = FullyQualifiedNames(name, expr_id); + for (const auto& name : names) { + funcs = function_registry_.FindLazyOverloadsByTypes(name, receiver_style, + types); + if (!funcs.empty()) { + return funcs; + } + } + return funcs; +} + std::vector Resolver::FindLazyOverloads( absl::string_view name, bool receiver_style, size_t arity, int64_t expr_id) const { @@ -219,4 +265,20 @@ Resolver::FindType(absl::string_view name, int64_t expr_id) const { return std::nullopt; } +absl::optional +Resolver::FindLazyOverloadById(absl::string_view name, + absl::string_view overload_id) const { + // Try with namespace prefixes + auto prefixes = GetPrefixesFor(name); + for (const auto& prefix : prefixes) { + std::string qualified_name = absl::StrCat(prefix, name); + auto result = + function_registry_.FindLazyOverloadById(qualified_name, overload_id); + if (result.has_value()) { + return result; + } + } + return absl::nullopt; +} + } // namespace google::api::expr::runtime diff --git a/eval/compiler/resolver.h b/eval/compiler/resolver.h index de7b22f26..227e804f1 100644 --- a/eval/compiler/resolver.h +++ b/eval/compiler/resolver.h @@ -78,20 +78,40 @@ class Resolver { absl::string_view name, bool receiver_style, const std::vector& types, int64_t expr_id = -1) const; + std::vector FindLazyOverloadsByTypes( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id = -1) const; + std::vector FindLazyOverloads( absl::string_view name, bool receiver_style, size_t arity, int64_t expr_id = -1) const; + // Find a specific lazy overload by name and overload_id. + // Returns empty optional if not found. + // Considers namespace prefixes when searching. + absl::optional FindLazyOverloadById( + absl::string_view name, absl::string_view overload_id) const; + // FindOverloads returns the set, possibly empty, of eager function overloads // matching the given function signature. std::vector FindOverloads( absl::string_view name, bool receiver_style, const std::vector& types, int64_t expr_id = -1) const; + std::vector FindOverloadsByTypes( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id = -1) const; + std::vector FindOverloads( absl::string_view name, bool receiver_style, size_t arity, int64_t expr_id = -1) const; + // Find a specific static overload by name and overload_id. + // Returns empty optional if not found. + // Considers namespace prefixes when searching. + absl::optional FindOverloadById( + absl::string_view name, absl::string_view overload_id) const; + // FullyQualifiedNames returns the set of fully qualified names which may be // derived from the base_name within the specified expression container. std::vector FullyQualifiedNames(absl::string_view base_name, diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 44c7ded79..19b7701c8 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -282,6 +282,7 @@ cc_library( "//common:expr", "//common:function_descriptor", "//common:kind", + "//common:type", "//common:value", "//common:value_kind", "//eval/internal:errors", @@ -292,6 +293,7 @@ cc_library( "//runtime:function_provider", "//runtime:function_registry", "//runtime/internal:errors", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index 12c5af8a7..e9febe561 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -1,6 +1,5 @@ #include "eval/eval/function_step.h" -#include #include #include #include @@ -8,6 +7,7 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -19,6 +19,7 @@ #include "common/expr.h" #include "common/function_descriptor.h" #include "common/kind.h" +#include "common/type.h" #include "common/value.h" #include "common/value_kind.h" #include "eval/eval/attribute_trail.h" @@ -67,7 +68,7 @@ bool ArgumentKindsMatch(const cel::FunctionDescriptor& descriptor, for (size_t i = 0; i < types_size; i++) { const auto& arg = arguments[i]; - cel::Kind param_kind = descriptor.types()[i]; + cel::Kind param_kind = descriptor.kinds()[i]; if (arg->kind() != param_kind && param_kind != cel::Kind::kAny) { return false; } @@ -76,6 +77,143 @@ bool ArgumentKindsMatch(const cel::FunctionDescriptor& descriptor, return true; } +// Forward declarations for recursive functions +bool ValuesMatch(cel::ValueIterator& iterator, const cel::Type& value_expected, + const absl::optional& key_expected, + absl::flat_hash_map& bindings, + const ExecutionFrameBase& frame); + +// Check if a single value matches expected type. +// Returns true if value's type is compatible with expected. +bool ValueMatches(const cel::Value& value, const cel::Type& expected, + absl::flat_hash_map& bindings, + const ExecutionFrameBase& frame) { + // Wildcard: expected accepts any type + if (expected.IsDyn() || expected.IsAny()) { + return true; + } + + // TypeParam handling + if (expected.IsTypeParam()) { + std::string name(expected.GetTypeParam().name()); + auto it = bindings.find(name); + if (it != bindings.end()) { + // Already bound: verify value matches the bound type + return ValueMatches(value, it->second, bindings, frame); + } + // Not yet bound: bind to value's runtime type + bindings[name] = value.GetRuntimeType(); + return true; + } + + // Get actual type from value + cel::Type actual = value.GetRuntimeType(); + + // Container handling - recursive verification + if (expected.IsList()) { + if (!actual.IsList()) { + return false; + } + const cel::ListValue& list = value.GetList(); + auto iter_result = list.NewIterator(); + if (!iter_result.ok()) { + return false; + } + cel::Type elem_expected = expected.GetList().GetElement(); + return ValuesMatch(**iter_result, elem_expected, absl::nullopt, bindings, + frame); + } + + if (expected.IsMap()) { + if (!actual.IsMap()) { + return false; + } + const cel::MapValue& map = value.GetMap(); + auto iter_result = map.NewIterator(); + if (!iter_result.ok()) { + return false; + } + cel::Type key_expected = expected.GetMap().GetKey(); + cel::Type val_expected = expected.GetMap().GetValue(); + return ValuesMatch(**iter_result, val_expected, key_expected, bindings, + frame); + } + + if (expected.IsOptional()) { + if (!actual.IsOptional()) { + return false; + } + const cel::OptionalValue& opt = value.GetOptional(); + if (!opt.HasValue()) { + return true; // Empty optional matches any optional + } + cel::OptionalType opt_type = expected.GetOptional(); + cel::TypeParameters params = opt_type.GetParameters(); + if (params.empty()) { + return true; + } + return ValueMatches(opt.Value(), params[0], bindings, frame); + } + + // Non-container types: direct type comparison + return expected == actual; +} + +// Check if all values from iterator match expected types. +// For lists: key_expected should be nullopt +// For maps: key_expected contains expected key type +bool ValuesMatch(cel::ValueIterator& iterator, const cel::Type& value_expected, + const absl::optional& key_expected, + absl::flat_hash_map& bindings, + const ExecutionFrameBase& frame) { + while (iterator.HasNext()) { + cel::Value key, val; + auto next_result = + iterator.Next2(frame.descriptor_pool(), frame.message_factory(), + frame.arena(), &key, &val); + if (!next_result.ok() || !*next_result) { + return false; // Error during iteration + } + + // Check key if expected + if (key_expected.has_value()) { + if (!ValueMatches(key, *key_expected, bindings, frame)) { + return false; + } + } + + // Check value + if (!ValueMatches(val, value_expected, bindings, frame)) { + return false; + } + } + return true; // All elements match (or empty) +} + +// Type-level argument matching with value-based verification. +// Directly uses Value.GetRuntimeType() for each argument and compares +// against descriptor.types() with full recursive verification. +// Supports TypeParam bindings and container type checking. +bool ArgumentTypesMatch(const cel::FunctionDescriptor& descriptor, + absl::Span arguments, + const ExecutionFrameBase& frame) { + const auto& expected_types = descriptor.types(); + + if (expected_types.size() != arguments.size()) { + return false; + } + + // Shared binding context for TypeParam consistency across all arguments + absl::flat_hash_map bindings; + + for (size_t i = 0; i < expected_types.size(); i++) { + if (!ValueMatches(arguments[i], expected_types[i], bindings, frame)) { + return false; + } + } + return true; +} + // Adjust new type names to legacy equivalent. int -> int64. // Temporary fix to migrate value types without breaking clients. // TODO(uncreated-issue/46): Update client tests that depend on this value. @@ -285,12 +423,37 @@ absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { absl::StatusOr ResolveStatic( absl::Span input_args, - absl::Span overloads) { + absl::Span overloads, + const ExecutionFrameBase& frame, size_t checked_overloads_count, + bool type_level_overload) { + // Fast path: Checked expression with a single overload resolved by + // overload_id. Still verify runtime arguments because checked expressions can + // contain dyn values or be evaluated with activation values that differ from + // the checker environment. + if (checked_overloads_count == 1) { + bool matches = + type_level_overload + ? ArgumentTypesMatch(overloads[0].descriptor, input_args, frame) + : ArgumentKindsMatch(overloads[0].descriptor, input_args); + if (matches) { + return overloads[0]; + } + // Runtime mismatch -- fall through to full resolution which includes both + // checked overloads and arity-matched fallbacks. + } + + // Full resolution path over all candidates. for (const auto& overload : overloads) { - if (ArgumentKindsMatch(overload.descriptor, input_args)) { + bool matches = + type_level_overload + ? ArgumentTypesMatch(overload.descriptor, input_args, frame) + : ArgumentKindsMatch(overload.descriptor, input_args); + + if (matches) { return overload; } } + return std::nullopt; } @@ -298,70 +461,89 @@ absl::StatusOr ResolveLazy( absl::Span input_args, absl::string_view name, bool receiver_style, absl::Span providers, - const ExecutionFrameBase& frame) { - ResolveResult result = std::nullopt; - - std::vector arg_types(input_args.size()); + const ExecutionFrameBase& frame, size_t checked_overloads_count, + bool type_level_overload) { + const cel::ActivationInterface& activation = frame.activation(); - std::transform( - input_args.begin(), input_args.end(), arg_types.begin(), - [](const cel::Value& value) { return ValueKindToKind(value->kind()); }); + // Phase A: Collect overloads from providers. + std::vector overloads; + + if (checked_overloads_count > 0) { + // Checked path: use provider.descriptor (which carries overload_id) for + // precise lookup. + for (const auto& provider : providers) { + CEL_ASSIGN_OR_RETURN(auto overload, provider.provider.GetFunction( + provider.descriptor, activation)); + if (overload.has_value()) { + overloads.push_back(overload.value()); + } + } + } else { + // Parse-only: pre-filter by ArgumentKindsMatch, then collect overloads. + for (const auto& provider : providers) { + if (!ArgumentKindsMatch(provider.descriptor, input_args)) { + continue; + } - cel::FunctionDescriptor matcher{name, receiver_style, std::move(arg_types)}; + CEL_ASSIGN_OR_RETURN(auto overload, provider.provider.GetFunction( + provider.descriptor, activation)); - const cel::ActivationInterface& activation = frame.activation(); - for (auto provider : providers) { - // The LazyFunctionStep has so far only resolved by function shape, check - // that the runtime argument kinds agree with the specific descriptor for - // the provider candidates. - if (!ArgumentKindsMatch(provider.descriptor, input_args)) { - continue; + if (overload.has_value()) { + overloads.push_back(overload.value()); + } } + } - CEL_ASSIGN_OR_RETURN(auto overload, - provider.provider.GetFunction(matcher, activation)); - if (overload.has_value()) { - // More than one overload matches our arguments. - if (result.has_value()) { - return absl::Status(absl::StatusCode::kInternal, - "Cannot resolve overloads"); - } + // Phase B: Match arguments and select overload. + for (const auto& overload : overloads) { + bool matches = + type_level_overload + ? ArgumentTypesMatch(overload.descriptor, input_args, frame) + : ArgumentKindsMatch(overload.descriptor, input_args); - result.emplace(overload.value()); + if (matches) { + return overload; } } - return result; + return std::nullopt; } class EagerFunctionStep : public AbstractFunctionStep { public: EagerFunctionStep(std::vector overloads, const std::string& name, size_t num_args, - bool receiver_style, int64_t expr_id) + bool receiver_style, int64_t expr_id, + size_t checked_overloads_count, bool type_level_overload) : AbstractFunctionStep(name, num_args, receiver_style, expr_id), - overloads_(std::move(overloads)) {} + overloads_(std::move(overloads)), + checked_overloads_count_(checked_overloads_count), + type_level_overload_(type_level_overload) {} absl::StatusOr ResolveFunction( absl::Span input_args, const ExecutionFrame* frame) const override { - return ResolveStatic(input_args, overloads_); + return ResolveStatic(input_args, overloads_, *frame, + checked_overloads_count_, type_level_overload_); } private: std::vector overloads_; + size_t checked_overloads_count_; + bool type_level_overload_; }; class LazyFunctionStep : public AbstractFunctionStep { public: - // Constructs LazyFunctionStep that attempts to lookup function implementation - // at runtime. LazyFunctionStep(const std::string& name, size_t num_args, bool receiver_style, std::vector providers, - int64_t expr_id) + int64_t expr_id, size_t checked_overloads_count, + bool type_level_overload) : AbstractFunctionStep(name, num_args, receiver_style, expr_id), - providers_(std::move(providers)) {} + providers_(std::move(providers)), + checked_overloads_count_(checked_overloads_count), + type_level_overload_(type_level_overload) {} absl::StatusOr ResolveFunction( absl::Span input_args, @@ -369,46 +551,64 @@ class LazyFunctionStep : public AbstractFunctionStep { private: std::vector providers_; + size_t checked_overloads_count_; + bool type_level_overload_; }; absl::StatusOr LazyFunctionStep::ResolveFunction( absl::Span input_args, const ExecutionFrame* frame) const { - return ResolveLazy(input_args, name_, receiver_style_, providers_, *frame); + return ResolveLazy(input_args, name_, receiver_style_, providers_, *frame, + checked_overloads_count_, type_level_overload_); } class StaticResolver { public: - explicit StaticResolver(std::vector overloads) - : overloads_(std::move(overloads)) {} + using ResolveResult = absl::optional; + + StaticResolver(std::vector overloads, + size_t checked_overloads_count, bool type_level_overload) + : overloads_(std::move(overloads)), + checked_overloads_count_(checked_overloads_count), + type_level_overload_(type_level_overload) {} absl::StatusOr Resolve(ExecutionFrameBase& frame, absl::Span input) const { - return ResolveStatic(input, overloads_); + return ResolveStatic(input, overloads_, frame, checked_overloads_count_, + type_level_overload_); } private: std::vector overloads_; + size_t checked_overloads_count_; + bool type_level_overload_; }; class LazyResolver { public: - explicit LazyResolver( - std::vector providers, - std::string name, bool receiver_style) + using ResolveResult = absl::optional; + + LazyResolver(std::vector providers, + std::string name, bool receiver_style, + size_t checked_overloads_count, bool type_level_overload) : providers_(std::move(providers)), name_(std::move(name)), - receiver_style_(receiver_style) {} + receiver_style_(receiver_style), + checked_overloads_count_(checked_overloads_count), + type_level_overload_(type_level_overload) {} absl::StatusOr Resolve(ExecutionFrameBase& frame, absl::Span input) const { - return ResolveLazy(input, name_, receiver_style_, providers_, frame); + return ResolveLazy(input, name_, receiver_style_, providers_, frame, + checked_overloads_count_, type_level_overload_); } private: std::vector providers_; std::string name_; bool receiver_style_; + size_t checked_overloads_count_; + bool type_level_overload_; }; template @@ -491,39 +691,47 @@ class DirectFunctionStepImpl : public DirectExpressionStep { std::unique_ptr CreateDirectFunctionStep( int64_t expr_id, const cel::CallExpr& call, std::vector> deps, - std::vector overloads) { + std::vector overloads, + size_t checked_overloads_count, bool type_level_overload) { return std::make_unique>( expr_id, call.function(), std::move(deps), call.has_target(), - StaticResolver(std::move(overloads))); + StaticResolver(std::move(overloads), checked_overloads_count, + type_level_overload)); } std::unique_ptr CreateDirectLazyFunctionStep( int64_t expr_id, const cel::CallExpr& call, std::vector> deps, - std::vector providers) { + std::vector providers, + size_t checked_overloads_count, bool type_level_overload) { return std::make_unique>( expr_id, call.function(), std::move(deps), call.has_target(), - LazyResolver(std::move(providers), call.function(), call.has_target())); + LazyResolver(std::move(providers), call.function(), call.has_target(), + checked_overloads_count, type_level_overload)); } absl::StatusOr> CreateFunctionStep( const cel::CallExpr& call_expr, int64_t expr_id, - std::vector lazy_overloads) { + std::vector lazy_overloads, + size_t checked_overloads_count, bool type_level_overload) { bool receiver_style = call_expr.has_target(); size_t num_args = call_expr.args().size() + (receiver_style ? 1 : 0); const std::string& name = call_expr.function(); - return std::make_unique(name, num_args, receiver_style, - std::move(lazy_overloads), expr_id); + return std::make_unique( + name, num_args, receiver_style, std::move(lazy_overloads), expr_id, + checked_overloads_count, type_level_overload); } absl::StatusOr> CreateFunctionStep( const cel::CallExpr& call_expr, int64_t expr_id, - std::vector overloads) { + std::vector overloads, + size_t checked_overloads_count, bool type_level_overload) { bool receiver_style = call_expr.has_target(); size_t num_args = call_expr.args().size() + (receiver_style ? 1 : 0); const std::string& name = call_expr.function(); - return std::make_unique(std::move(overloads), name, - num_args, receiver_style, expr_id); + return std::make_unique( + std::move(overloads), name, num_args, receiver_style, expr_id, + checked_overloads_count, type_level_overload); } } // namespace google::api::expr::runtime diff --git a/eval/eval/function_step.h b/eval/eval/function_step.h index 9f664dc09..2bf3c80d8 100644 --- a/eval/eval/function_step.h +++ b/eval/eval/function_step.h @@ -1,6 +1,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_FUNCTION_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_FUNCTION_STEP_H_ +#include #include #include #include @@ -17,10 +18,18 @@ namespace google::api::expr::runtime { // Factory method for Call-based execution step where the function has been // statically resolved from a set of eagerly functions configured in the // CelFunctionRegistry. +// +// checked_overloads_count: number of overloads at the front of the vector that +// were resolved via overload_id from the checked expression's reference_map. +// 0 means parse-only (no overload_id info available). +// Remaining overloads (if any) are arity-matched fallback candidates for +// runtime type mismatch scenarios (e.g. dyn() casts). std::unique_ptr CreateDirectFunctionStep( int64_t expr_id, const cel::CallExpr& call, std::vector> deps, - std::vector overloads); + std::vector overloads, + size_t checked_overloads_count = 0, + bool type_level_overload = false); // Factory method for Call-based execution step where the function has been // statically resolved from a set of lazy functions configured in the @@ -28,20 +37,26 @@ std::unique_ptr CreateDirectFunctionStep( std::unique_ptr CreateDirectLazyFunctionStep( int64_t expr_id, const cel::CallExpr& call, std::vector> deps, - std::vector providers); + std::vector providers, + size_t checked_overloads_count = 0, + bool type_level_overload = false); // Factory method for Call-based execution step where the function will be // resolved at runtime (lazily) from an input Activation. absl::StatusOr> CreateFunctionStep( const cel::CallExpr& call, int64_t expr_id, - std::vector lazy_overloads); + std::vector lazy_overloads, + size_t checked_overloads_count = 0, + bool type_level_overload = false); // Factory method for Call-based execution step where the function has been // statically resolved from a set of eagerly functions configured in the // CelFunctionRegistry. absl::StatusOr> CreateFunctionStep( const cel::CallExpr& call, int64_t expr_id, - std::vector overloads); + std::vector overloads, + size_t checked_overloads_count = 0, + bool type_level_overload = false); } // namespace google::api::expr::runtime diff --git a/eval/public/cel_function.cc b/eval/public/cel_function.cc index 9b760d1ec..a95f030d2 100644 --- a/eval/public/cel_function.cc +++ b/eval/public/cel_function.cc @@ -24,7 +24,7 @@ bool CelFunction::MatchArguments(absl::Span arguments) const { } for (size_t i = 0; i < types_size; i++) { const auto& value = arguments[i]; - CelValue::Type arg_type = descriptor().types()[i]; + CelValue::Type arg_type = descriptor().kinds()[i]; if (value.type() != arg_type && arg_type != CelValue::Type::kAny) { return false; } @@ -41,7 +41,7 @@ bool CelFunction::MatchArguments(absl::Span arguments) const { } for (size_t i = 0; i < types_size; i++) { const auto& value = arguments[i]; - CelValue::Type arg_type = descriptor().types()[i]; + CelValue::Type arg_type = descriptor().kinds()[i]; if (value->kind() != arg_type && arg_type != CelValue::Type::kAny) { return false; } diff --git a/eval/public/cel_function_adapter_test.cc b/eval/public/cel_function_adapter_test.cc index 29d27e5af..ac7b97b7d 100644 --- a/eval/public/cel_function_adapter_test.cc +++ b/eval/public/cel_function_adapter_test.cc @@ -119,18 +119,18 @@ TEST(CelFunctionAdapterTest, TestTypeDeductionForCelValueBasicTypes) { EXPECT_EQ(descriptor.name(), "dummy_func"); int pos = 0; - ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kBool); - ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kInt64); - ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kUint64); - ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kDouble); - ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kString); - ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kBytes); - ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kMessage); - ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kDuration); - ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kTimestamp); - ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kList); - ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kMap); - ASSERT_EQ(descriptor.types()[pos++], CelValue::Type::kError); + ASSERT_EQ(descriptor.kinds()[pos++], CelValue::Type::kBool); + ASSERT_EQ(descriptor.kinds()[pos++], CelValue::Type::kInt64); + ASSERT_EQ(descriptor.kinds()[pos++], CelValue::Type::kUint64); + ASSERT_EQ(descriptor.kinds()[pos++], CelValue::Type::kDouble); + ASSERT_EQ(descriptor.kinds()[pos++], CelValue::Type::kString); + ASSERT_EQ(descriptor.kinds()[pos++], CelValue::Type::kBytes); + ASSERT_EQ(descriptor.kinds()[pos++], CelValue::Type::kMessage); + ASSERT_EQ(descriptor.kinds()[pos++], CelValue::Type::kDuration); + ASSERT_EQ(descriptor.kinds()[pos++], CelValue::Type::kTimestamp); + ASSERT_EQ(descriptor.kinds()[pos++], CelValue::Type::kList); + ASSERT_EQ(descriptor.kinds()[pos++], CelValue::Type::kMap); + ASSERT_EQ(descriptor.kinds()[pos++], CelValue::Type::kError); } TEST(CelFunctionAdapterTest, TestAdapterStatusOrMessage) { diff --git a/eval/public/extension_func_test.cc b/eval/public/extension_func_test.cc index 2e2497d7d..63bff1273 100644 --- a/eval/public/extension_func_test.cc +++ b/eval/public/extension_func_test.cc @@ -69,13 +69,13 @@ class ExtensionTest : public ::testing::Test { void TestStringStartsWith(const std::string& test_string, const std::string& prefix, bool result) { - TestStringInclusion("startsWith", {true, false}, test_string, prefix, + TestStringInclusion("startsWith", {true}, test_string, prefix, result); } void TestStringEndsWith(const std::string& test_string, const std::string& prefix, bool result) { - TestStringInclusion("endsWith", {true, false}, test_string, prefix, result); + TestStringInclusion("endsWith", {true}, test_string, prefix, result); } // Helper method to test timestamp() function diff --git a/extensions/BUILD b/extensions/BUILD index 05104a4a5..13fbc95a8 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -505,6 +505,7 @@ cc_library( deps = [ "//base:function_adapter", "//checker:type_checker_builder", + "//checker/internal:builtins_arena", "//common:decl", "//common:type", "//common:value", diff --git a/extensions/comprehensions_v2_functions.cc b/extensions/comprehensions_v2_functions.cc index bf23780c0..9fda37ce7 100644 --- a/extensions/comprehensions_v2_functions.cc +++ b/extensions/comprehensions_v2_functions.cc @@ -120,16 +120,22 @@ absl::StatusOr MapInsertMap( absl::Status RegisterComprehensionsV2Functions(FunctionRegistry& registry, const RuntimeOptions& options) { + // Overload IDs from comprehensions_v2.cc + static constexpr absl::string_view kMapInsertMapKeyValue = + "@mapInsert_map_key_value"; + static constexpr absl::string_view kMapInsertMapMap = "@mapInsert_map_map"; + CEL_RETURN_IF_ERROR(registry.Register( TernaryFunctionAdapter, MapValue, Value, Value>::CreateDescriptor("cel.@mapInsert", + kMapInsertMapKeyValue, /*receiver_style=*/false), TernaryFunctionAdapter, MapValue, Value, Value>::WrapFunction(&MapInsertKeyValue))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, MapValue, MapValue>:: - CreateDescriptor("cel.@mapInsert", + CreateDescriptor("cel.@mapInsert", kMapInsertMapMap, /*receiver_style=*/false), BinaryFunctionAdapter, MapValue, MapValue>::WrapFunction(&MapInsertMap))); diff --git a/extensions/encoders.cc b/extensions/encoders.cc index 66431b30b..a0eeabbdd 100644 --- a/extensions/encoders.cc +++ b/extensions/encoders.cc @@ -86,15 +86,21 @@ absl::Status RegisterEncodersDecls(TypeCheckerBuilder& builder) { absl::Status RegisterEncodersFunctions(FunctionRegistry& registry, const RuntimeOptions&) { + // Overload IDs from decls + static constexpr absl::string_view kBase64DecodeString = + "base64_decode_string"; + static constexpr absl::string_view kBase64EncodeBytes = "base64_encode_bytes"; + CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, StringValue>::CreateDescriptor("base64.decode", + kBase64DecodeString, false), UnaryFunctionAdapter, StringValue>::WrapFunction( &Base64Decode))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, BytesValue>::CreateDescriptor( - "base64.encode", false), + "base64.encode", kBase64EncodeBytes, false), UnaryFunctionAdapter, BytesValue>::WrapFunction( &Base64Encode))); return absl::OkStatus(); diff --git a/extensions/formatting.cc b/extensions/formatting.cc index 252fdc7bd..a781d9904 100644 --- a/extensions/formatting.cc +++ b/extensions/formatting.cc @@ -551,19 +551,20 @@ absl::Status RegisterStringFormattingFunctions( StringsExtensionFormatOptions format_options) { const int max_precision = std::clamp(format_options.max_precision, 0, kMaxPrecision); + static constexpr absl::string_view kStringFormat = "string_format"; CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, ListValue>:: - CreateDescriptor("format", /*receiver_style=*/true), + CreateDescriptor("format", kStringFormat, /*receiver_style=*/true), BinaryFunctionAdapter, StringValue, ListValue>:: WrapFunction( [max_precision]( - const StringValue& format, const ListValue& args, - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, - google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) { - return Format(format, args, max_precision, descriptor_pool, - message_factory, arena); - }))); + const StringValue& format, const ListValue& args, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return Format(format, args, max_precision, descriptor_pool, + message_factory, arena); + }))); return absl::OkStatus(); } diff --git a/extensions/lists_functions.cc b/extensions/lists_functions.cc index bfe05d887..20362ae32 100644 --- a/extensions/lists_functions.cc +++ b/extensions/lists_functions.cc @@ -487,20 +487,25 @@ absl::StatusOr ListSort( } absl::Status RegisterListDistinctFunction(FunctionRegistry& registry) { + static constexpr absl::string_view kListDistinct = "list_distinct"; return UnaryFunctionAdapter, const ListValue&>:: - RegisterMemberOverload("distinct", &ListDistinct, registry); + RegisterMemberOverload("distinct", kListDistinct, &ListDistinct, registry); } absl::Status RegisterListFlattenFunction(FunctionRegistry& registry) { + static constexpr absl::string_view kListFlattenInt = "list_flatten_int"; + static constexpr absl::string_view kListFlatten = "list_flatten"; CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter, const ListValue&, int64_t>::RegisterMemberOverload("flatten", + kListFlattenInt, &ListFlatten, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter, const ListValue&>:: RegisterMemberOverload( "flatten", + kListFlatten, [](const ListValue& list, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, @@ -513,26 +518,34 @@ absl::Status RegisterListFlattenFunction(FunctionRegistry& registry) { } absl::Status RegisterListRangeFunction(FunctionRegistry& registry) { + static constexpr absl::string_view kListRange = "list_range"; return UnaryFunctionAdapter, int64_t>::RegisterGlobalOverload("lists.range", + kListRange, &ListRange, registry); } absl::Status RegisterListReverseFunction(FunctionRegistry& registry) { + static constexpr absl::string_view kListReverse = "list_reverse"; return UnaryFunctionAdapter, const ListValue&>:: - RegisterMemberOverload("reverse", &ListReverse, registry); + RegisterMemberOverload("reverse", kListReverse, &ListReverse, registry); } absl::Status RegisterListSliceFunction(FunctionRegistry& registry) { + static constexpr absl::string_view kListSlice = "list_slice"; return TernaryFunctionAdapter, const ListValue&, int64_t, int64_t>::RegisterMemberOverload("slice", + kListSlice, &ListSlice, registry); } absl::Status RegisterListSortFunction(FunctionRegistry& registry) { + // Note: checker declares multiple overloads (list_int_sort, list_double_sort, etc.) + // but runtime has a single generic implementation. This is hybrid mode where + // N decls map to 1 impl, so we don't use overload_id (runtime matching). CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter, const ListValue&>:: RegisterMemberOverload("sort", &ListSort, registry))); diff --git a/extensions/math_ext.cc b/extensions/math_ext.cc index a7773da19..8aec5de09 100644 --- a/extensions/math_ext.cc +++ b/extensions/math_ext.cc @@ -172,27 +172,31 @@ absl::StatusOr MaxList( } template -absl::Status RegisterCrossNumericMin(FunctionRegistry& registry) { +absl::Status RegisterCrossNumericMin(absl::string_view overload_id_tu, + absl::string_view overload_id_ut, + FunctionRegistry& registry) { CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - kMathMin, Min, registry))); + kMathMin, overload_id_tu, Min, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - kMathMin, Min, registry))); + kMathMin, overload_id_ut, Min, registry))); return absl::OkStatus(); } template -absl::Status RegisterCrossNumericMax(FunctionRegistry& registry) { +absl::Status RegisterCrossNumericMax(absl::string_view overload_id_tu, + absl::string_view overload_id_ut, + FunctionRegistry& registry) { CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - kMathMax, Max, registry))); + kMathMax, overload_id_tu, Max, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - kMathMax, Max, registry))); + kMathMax, overload_id_ut, Max, registry))); return absl::OkStatus(); } @@ -314,27 +318,101 @@ Value BitShiftRightUint(uint64_t lhs, int64_t rhs) { absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, const RuntimeOptions& options, int version) { + // Overload IDs matching math_ext_decls.cc + static constexpr absl::string_view kMathMinInt = "math_@min_int"; + static constexpr absl::string_view kMathMinDouble = "math_@min_double"; + static constexpr absl::string_view kMathMinUint = "math_@min_uint"; + static constexpr absl::string_view kMathMinIntInt = "math_@min_int_int"; + static constexpr absl::string_view kMathMinIntUint = "math_@min_int_uint"; + static constexpr absl::string_view kMathMinIntDouble = "math_@min_int_double"; + static constexpr absl::string_view kMathMinUintInt = "math_@min_uint_int"; + static constexpr absl::string_view kMathMinUintUint = "math_@min_uint_uint"; + static constexpr absl::string_view kMathMinUintDouble = + "math_@min_uint_double"; + static constexpr absl::string_view kMathMinDoubleInt = "math_@min_double_int"; + static constexpr absl::string_view kMathMinDoubleUint = + "math_@min_double_uint"; + static constexpr absl::string_view kMathMinDoubleDouble = + "math_@min_double_double"; + + static constexpr absl::string_view kMathMaxInt = "math_@max_int"; + static constexpr absl::string_view kMathMaxDouble = "math_@max_double"; + static constexpr absl::string_view kMathMaxUint = "math_@max_uint"; + static constexpr absl::string_view kMathMaxIntInt = "math_@max_int_int"; + static constexpr absl::string_view kMathMaxIntUint = "math_@max_int_uint"; + static constexpr absl::string_view kMathMaxIntDouble = "math_@max_int_double"; + static constexpr absl::string_view kMathMaxUintInt = "math_@max_uint_int"; + static constexpr absl::string_view kMathMaxUintUint = "math_@max_uint_uint"; + static constexpr absl::string_view kMathMaxUintDouble = + "math_@max_uint_double"; + static constexpr absl::string_view kMathMaxDoubleInt = "math_@max_double_int"; + static constexpr absl::string_view kMathMaxDoubleUint = + "math_@max_double_uint"; + static constexpr absl::string_view kMathMaxDoubleDouble = + "math_@max_double_double"; + + static constexpr absl::string_view kMathCeilDouble = "math_ceil_double"; + static constexpr absl::string_view kMathFloorDouble = "math_floor_double"; + static constexpr absl::string_view kMathRoundDouble = "math_round_double"; + static constexpr absl::string_view kMathTruncDouble = "math_trunc_double"; + static constexpr absl::string_view kMathSqrtInt = "math_sqrt_int"; + static constexpr absl::string_view kMathSqrtUint = "math_sqrt_uint"; + static constexpr absl::string_view kMathSqrtDouble = "math_sqrt_double"; + static constexpr absl::string_view kMathIsInfDouble = "math_isInf_double"; + static constexpr absl::string_view kMathIsNaNDouble = "math_isNaN_double"; + static constexpr absl::string_view kMathIsFiniteDouble = + "math_isFinite_double"; + static constexpr absl::string_view kMathAbsInt = "math_abs_int"; + static constexpr absl::string_view kMathAbsUint = "math_abs_uint"; + static constexpr absl::string_view kMathAbsDouble = "math_abs_double"; + static constexpr absl::string_view kMathSignInt = "math_sign_int"; + static constexpr absl::string_view kMathSignUint = "math_sign_uint"; + static constexpr absl::string_view kMathSignDouble = "math_sign_double"; + static constexpr absl::string_view kMathBitAndIntInt = "math_bitAnd_int_int"; + static constexpr absl::string_view kMathBitAndUintUint = + "math_bitAnd_uint_uint"; + static constexpr absl::string_view kMathBitOrIntInt = "math_bitOr_int_int"; + static constexpr absl::string_view kMathBitOrUintUint = + "math_bitOr_uint_uint"; + static constexpr absl::string_view kMathBitXorIntInt = "math_bitXor_int_int"; + static constexpr absl::string_view kMathBitXorUintUint = + "math_bitXor_uint_uint"; + static constexpr absl::string_view kMathBitNotIntInt = "math_bitNot_int_int"; + static constexpr absl::string_view kMathBitNotUintUint = + "math_bitNot_uint_uint"; + static constexpr absl::string_view kMathBitShiftLeftIntInt = + "math_bitShiftLeft_int_int"; + static constexpr absl::string_view kMathBitShiftLeftUintInt = + "math_bitShiftLeft_uint_int"; + static constexpr absl::string_view kMathBitShiftRightIntInt = + "math_bitShiftRight_int_int"; + static constexpr absl::string_view kMathBitShiftRightUintInt = + "math_bitShiftRight_uint_int"; + CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - kMathMin, Identity, registry))); + kMathMin, kMathMinInt, Identity, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - kMathMin, Identity, registry))); + kMathMin, kMathMinDouble, Identity, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - kMathMin, Identity, registry))); + kMathMin, kMathMinUint, Identity, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - kMathMin, Min, registry))); + kMathMin, kMathMinIntInt, Min, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - kMathMin, Min, registry))); + kMathMin, kMathMinDoubleDouble, Min, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - kMathMin, Min, registry))); - CEL_RETURN_IF_ERROR((RegisterCrossNumericMin(registry))); - CEL_RETURN_IF_ERROR((RegisterCrossNumericMin(registry))); - CEL_RETURN_IF_ERROR((RegisterCrossNumericMin(registry))); + kMathMin, kMathMinUintUint, Min, registry))); + CEL_RETURN_IF_ERROR((RegisterCrossNumericMin( + kMathMinIntUint, kMathMinUintInt, registry))); + CEL_RETURN_IF_ERROR((RegisterCrossNumericMin( + kMathMinIntDouble, kMathMinDoubleInt, registry))); + CEL_RETURN_IF_ERROR((RegisterCrossNumericMin( + kMathMinDoubleUint, kMathMinUintDouble, registry))); CEL_RETURN_IF_ERROR(( UnaryFunctionAdapter, ListValue>::RegisterGlobalOverload(kMathMin, MinList, @@ -342,25 +420,28 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - kMathMax, Identity, registry))); + kMathMax, kMathMaxInt, Identity, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - kMathMax, Identity, registry))); + kMathMax, kMathMaxDouble, Identity, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - kMathMax, Identity, registry))); + kMathMax, kMathMaxUint, Identity, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - kMathMax, Max, registry))); + kMathMax, kMathMaxIntInt, Max, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - kMathMax, Max, registry))); + kMathMax, kMathMaxDoubleDouble, Max, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - kMathMax, Max, registry))); - CEL_RETURN_IF_ERROR((RegisterCrossNumericMax(registry))); - CEL_RETURN_IF_ERROR((RegisterCrossNumericMax(registry))); - CEL_RETURN_IF_ERROR((RegisterCrossNumericMax(registry))); + kMathMax, kMathMaxUintUint, Max, registry))); + CEL_RETURN_IF_ERROR((RegisterCrossNumericMax( + kMathMaxIntUint, kMathMaxUintInt, registry))); + CEL_RETURN_IF_ERROR((RegisterCrossNumericMax( + kMathMaxIntDouble, kMathMaxDoubleInt, registry))); + CEL_RETURN_IF_ERROR((RegisterCrossNumericMax( + kMathMaxDoubleUint, kMathMaxUintDouble, registry))); CEL_RETURN_IF_ERROR(( UnaryFunctionAdapter, ListValue>::RegisterGlobalOverload(kMathMax, MaxList, @@ -371,86 +452,87 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.ceil", CeilDouble, registry))); + "math.ceil", kMathCeilDouble, CeilDouble, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.floor", FloorDouble, registry))); + "math.floor", kMathFloorDouble, FloorDouble, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.round", RoundDouble, registry))); + "math.round", kMathRoundDouble, RoundDouble, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.trunc", TruncDouble, registry))); + "math.trunc", kMathTruncDouble, TruncDouble, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.isInf", IsInfDouble, registry))); + "math.isInf", kMathIsInfDouble, IsInfDouble, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.isNaN", IsNaNDouble, registry))); + "math.isNaN", kMathIsNaNDouble, IsNaNDouble, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.isFinite", IsFiniteDouble, registry))); + "math.isFinite", kMathIsFiniteDouble, IsFiniteDouble, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.abs", AbsDouble, registry))); + "math.abs", kMathAbsDouble, AbsDouble, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.abs", AbsInt, registry))); + "math.abs", kMathAbsInt, AbsInt, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.abs", AbsUint, registry))); + "math.abs", kMathAbsUint, AbsUint, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sign", SignDouble, registry))); + "math.sign", kMathSignDouble, SignDouble, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sign", SignInt, registry))); + "math.sign", kMathSignInt, SignInt, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sign", SignUint, registry))); + "math.sign", kMathSignUint, SignUint, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - "math.bitAnd", BitAndInt, registry))); + "math.bitAnd", kMathBitAndIntInt, BitAndInt, registry))); CEL_RETURN_IF_ERROR( - (BinaryFunctionAdapter::RegisterGlobalOverload("math.bitAnd", - BitAndUint, - registry))); + (BinaryFunctionAdapter:: + RegisterGlobalOverload("math.bitAnd", kMathBitAndUintUint, + BitAndUint, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - "math.bitOr", BitOrInt, registry))); + "math.bitOr", kMathBitOrIntInt, BitOrInt, registry))); CEL_RETURN_IF_ERROR( - (BinaryFunctionAdapter::RegisterGlobalOverload("math.bitOr", - BitOrUint, - registry))); + (BinaryFunctionAdapter:: + RegisterGlobalOverload("math.bitOr", kMathBitOrUintUint, BitOrUint, + registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - "math.bitXor", BitXorInt, registry))); + "math.bitXor", kMathBitXorIntInt, BitXorInt, registry))); CEL_RETURN_IF_ERROR( - (BinaryFunctionAdapter::RegisterGlobalOverload("math.bitXor", - BitXorUint, - registry))); + (BinaryFunctionAdapter:: + RegisterGlobalOverload("math.bitXor", kMathBitXorUintUint, + BitXorUint, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.bitNot", BitNotInt, registry))); + "math.bitNot", kMathBitNotIntInt, BitNotInt, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.bitNot", BitNotUint, registry))); + "math.bitNot", kMathBitNotUintUint, BitNotUint, registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - "math.bitShiftLeft", BitShiftLeftInt, registry))); + "math.bitShiftLeft", kMathBitShiftLeftIntInt, BitShiftLeftInt, + registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - "math.bitShiftLeft", BitShiftLeftUint, registry))); + "math.bitShiftLeft", kMathBitShiftLeftUintInt, BitShiftLeftUint, + registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - "math.bitShiftRight", BitShiftRightInt, registry))); + "math.bitShiftRight", kMathBitShiftRightIntInt, BitShiftRightInt, + registry))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterGlobalOverload( - "math.bitShiftRight", BitShiftRightUint, registry))); + "math.bitShiftRight", kMathBitShiftRightUintInt, BitShiftRightUint, + registry))); if (version == 1) { return absl::OkStatus(); @@ -458,13 +540,13 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtDouble, registry))); + "math.sqrt", kMathSqrtDouble, SqrtDouble, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtInt, registry))); + "math.sqrt", kMathSqrtInt, SqrtInt, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtUint, registry))); + "math.sqrt", kMathSqrtUint, SqrtUint, registry))); return absl::OkStatus(); } diff --git a/extensions/regex_ext.cc b/extensions/regex_ext.cc index 9c06d90c2..0559b6eb5 100644 --- a/extensions/regex_ext.cc +++ b/extensions/regex_ext.cc @@ -233,23 +233,32 @@ Value ReplaceN(int regex_max_program_size, const StringValue& target, absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, bool disable_extract, int regex_max_program_size) { + // Overload IDs from decls + static constexpr absl::string_view kRegexExtractStringString = "regex_extract_string_string"; + static constexpr absl::string_view kRegexExtractAllStringString = "regex_extractAll_string_string"; + static constexpr absl::string_view kRegexReplaceStringStringString = "regex_replace_string_string_string"; + static constexpr absl::string_view kRegexReplaceStringStringStringInt = "regex_replace_string_string_string_int"; + if (!disable_extract) { CEL_RETURN_IF_ERROR(( BinaryFunctionAdapter, StringValue, StringValue>:: RegisterGlobalOverload( "regex.extract", + kRegexExtractStringString, absl::bind_front(&Extract, regex_max_program_size), registry))); } CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter, StringValue, StringValue>:: RegisterGlobalOverload( "regex.extractAll", + kRegexExtractAllStringString, absl::bind_front(&ExtractAll, regex_max_program_size), registry))); CEL_RETURN_IF_ERROR( (TernaryFunctionAdapter< absl::StatusOr, StringValue, StringValue, StringValue>::RegisterGlobalOverload("regex.replace", + kRegexReplaceStringStringString, absl::bind_front( &ReplaceAll, regex_max_program_size), @@ -259,6 +268,7 @@ absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, StringValue, StringValue, int64_t>:: RegisterGlobalOverload( "regex.replace", + kRegexReplaceStringStringStringInt, absl::bind_front(&ReplaceN, regex_max_program_size), registry))); return absl::OkStatus(); } diff --git a/extensions/regex_functions.cc b/extensions/regex_functions.cc index 005987ae4..221c58620 100644 --- a/extensions/regex_functions.cc +++ b/extensions/regex_functions.cc @@ -151,11 +151,17 @@ absl::StatusOr CaptureStringN( absl::Status RegisterRegexFunctions(FunctionRegistry& registry, int max_regex_program_size) { + // Overload IDs from decls + static constexpr absl::string_view kReExtractStringStringString = "re_extract_string_string_string"; + static constexpr absl::string_view kReCaptureStringString = "re_capture_string_string"; + static constexpr absl::string_view kReCaptureNStringString = "re_captureN_string_string"; + // Register Regex Extract Function CEL_RETURN_IF_ERROR( (TernaryFunctionAdapter< absl::StatusOr, StringValue, StringValue, StringValue>::RegisterGlobalOverload(kRegexExtract, + kReExtractStringStringString, absl::bind_front( &ExtractString, max_regex_program_size), @@ -166,6 +172,7 @@ absl::Status RegisterRegexFunctions(FunctionRegistry& registry, (BinaryFunctionAdapter, StringValue, StringValue>:: RegisterGlobalOverload( kRegexCapture, + kReCaptureStringString, absl::bind_front(&CaptureString, max_regex_program_size), registry))); @@ -174,6 +181,7 @@ absl::Status RegisterRegexFunctions(FunctionRegistry& registry, (BinaryFunctionAdapter, StringValue, StringValue>:: RegisterGlobalOverload( kRegexCaptureN, + kReCaptureNStringString, absl::bind_front(&CaptureStringN, max_regex_program_size), registry))); return absl::OkStatus(); diff --git a/extensions/sets_functions.cc b/extensions/sets_functions.cc index ebe163550..29fc8410a 100644 --- a/extensions/sets_functions.cc +++ b/extensions/sets_functions.cc @@ -97,30 +97,36 @@ absl::StatusOr SetsEquivalent( } absl::Status RegisterSetsContainsFunction(FunctionRegistry& registry) { + static constexpr absl::string_view kListSetsContainsList = "list_sets_contains_list"; return registry.Register( BinaryFunctionAdapter< absl::StatusOr, const ListValue&, const ListValue&>::CreateDescriptor("sets.contains", + kListSetsContainsList, /*receiver_style=*/false), BinaryFunctionAdapter, const ListValue&, const ListValue&>::WrapFunction(SetsContains)); } absl::Status RegisterSetsIntersectsFunction(FunctionRegistry& registry) { + static constexpr absl::string_view kListSetsIntersectsList = "list_sets_intersects_list"; return registry.Register( BinaryFunctionAdapter< absl::StatusOr, const ListValue&, const ListValue&>::CreateDescriptor("sets.intersects", + kListSetsIntersectsList, /*receiver_style=*/false), BinaryFunctionAdapter, const ListValue&, const ListValue&>::WrapFunction(SetsIntersects)); } absl::Status RegisterSetsEquivalentFunction(FunctionRegistry& registry) { + static constexpr absl::string_view kListSetsEquivalentList = "list_sets_equivalent_list"; return registry.Register( BinaryFunctionAdapter< absl::StatusOr, const ListValue&, const ListValue&>::CreateDescriptor("sets.equivalent", + kListSetsEquivalentList, /*receiver_style=*/false), BinaryFunctionAdapter, const ListValue&, const ListValue&>::WrapFunction(SetsEquivalent)); diff --git a/extensions/strings.cc b/extensions/strings.cc index 54fda20d6..9a2a8b5ea 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -309,76 +309,106 @@ absl::Status RegisterStringsFunctions( FunctionRegistry& registry, const RuntimeOptions& options, const StringsExtensionOptions& extension_options) { const int version = extension_options.version; + // Overload IDs from decls + static constexpr absl::string_view kListJoin = "list_join"; + static constexpr absl::string_view kListJoinString = "list_join_string"; + static constexpr absl::string_view kStringSplitString = "string_split_string"; + static constexpr absl::string_view kStringSplitStringInt = + "string_split_string_int"; + static constexpr absl::string_view kStringLowerAscii = "string_lower_ascii"; + static constexpr absl::string_view kStringReplaceStringString = + "string_replace_string_string"; + static constexpr absl::string_view kStringReplaceStringStringInt = + "string_replace_string_string_int"; + static constexpr absl::string_view kStringCharAtInt = "string_char_at_int"; + static constexpr absl::string_view kStringIndexOfString = + "string_index_of_string"; + static constexpr absl::string_view kStringIndexOfStringInt = + "string_index_of_string_int"; + static constexpr absl::string_view kStringLastIndexOfString = + "string_last_index_of_string"; + static constexpr absl::string_view kStringLastIndexOfStringInt = + "string_last_index_of_string_int"; + static constexpr absl::string_view kStringSubstringInt = + "string_substring_int"; + static constexpr absl::string_view kStringSubstringIntInt = + "string_substring_int_int"; + static constexpr absl::string_view kStringUpperAscii = "string_upper_ascii"; + static constexpr absl::string_view kStringsQuote = "strings_quote"; + static constexpr absl::string_view kStringReverse = "string_reverse"; + static constexpr absl::string_view kStringTrim = "string_trim"; CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, StringValue>:: - CreateDescriptor("split", /*receiver_style=*/true), + CreateDescriptor("split", kStringSplitString, + /*receiver_style=*/true), BinaryFunctionAdapter, StringValue, StringValue>::WrapFunction(Split2))); CEL_RETURN_IF_ERROR(registry.Register( TernaryFunctionAdapter< absl::StatusOr, StringValue, StringValue, - int64_t>::CreateDescriptor("split", /*receiver_style=*/true), + int64_t>::CreateDescriptor("split", kStringSplitStringInt, + /*receiver_style=*/true), TernaryFunctionAdapter, StringValue, StringValue, int64_t>::WrapFunction(Split3))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, StringValue>:: - CreateDescriptor("lowerAscii", /*receiver_style=*/true), + CreateDescriptor("lowerAscii", kStringLowerAscii, + /*receiver_style=*/true), UnaryFunctionAdapter, StringValue>::WrapFunction( LowerAscii))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, StringValue>:: - CreateDescriptor("upperAscii", /*receiver_style=*/true), + CreateDescriptor("upperAscii", kStringUpperAscii, + /*receiver_style=*/true), UnaryFunctionAdapter, StringValue>::WrapFunction( UpperAscii))); CEL_RETURN_IF_ERROR(registry.Register( TernaryFunctionAdapter< absl::StatusOr, StringValue, StringValue, - StringValue>::CreateDescriptor("replace", /*receiver_style=*/true), + StringValue>::CreateDescriptor("replace", kStringReplaceStringString, + /*receiver_style=*/true), TernaryFunctionAdapter, StringValue, StringValue, StringValue>::WrapFunction(Replace1))); CEL_RETURN_IF_ERROR(registry.Register( QuaternaryFunctionAdapter< absl::StatusOr, StringValue, StringValue, StringValue, - int64_t>::CreateDescriptor("replace", /*receiver_style=*/true), + int64_t>::CreateDescriptor("replace", kStringReplaceStringStringInt, + /*receiver_style=*/true), QuaternaryFunctionAdapter, StringValue, StringValue, StringValue, int64_t>::WrapFunction(Replace2))); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterMemberOverload("charAt", &CharAt, + int64_t>::RegisterMemberOverload("charAt", + kStringCharAtInt, + &CharAt, registry))); CEL_RETURN_IF_ERROR( - (BinaryFunctionAdapter::RegisterMemberOverload("indexOf", - &IndexOf2, - registry))); + (BinaryFunctionAdapter:: + RegisterMemberOverload("indexOf", kStringIndexOfString, &IndexOf2, + registry))); CEL_RETURN_IF_ERROR( - (TernaryFunctionAdapter::RegisterMemberOverload("indexOf", - &IndexOf3, - registry))); + (TernaryFunctionAdapter:: + RegisterMemberOverload("indexOf", kStringIndexOfStringInt, &IndexOf3, + registry))); CEL_RETURN_IF_ERROR( - (BinaryFunctionAdapter::RegisterMemberOverload("lastIndexOf", - &LastIndexOf2, - registry))); + (BinaryFunctionAdapter:: + RegisterMemberOverload("lastIndexOf", kStringLastIndexOfString, + &LastIndexOf2, registry))); CEL_RETURN_IF_ERROR( - (TernaryFunctionAdapter::RegisterMemberOverload("lastIndexOf", - &LastIndexOf3, - registry))); + (TernaryFunctionAdapter:: + RegisterMemberOverload("lastIndexOf", kStringLastIndexOfStringInt, + &LastIndexOf3, registry))); CEL_RETURN_IF_ERROR( - (BinaryFunctionAdapter::RegisterMemberOverload("substring", - &Substring2, - registry))); + (BinaryFunctionAdapter:: + RegisterMemberOverload("substring", kStringSubstringInt, &Substring2, + registry))); CEL_RETURN_IF_ERROR( - (TernaryFunctionAdapter::RegisterMemberOverload("substring", - &Substring3, - registry))); + (TernaryFunctionAdapter:: + RegisterMemberOverload("substring", kStringSubstringIntInt, + &Substring3, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterMemberOverload( - "trim", &Trim, registry))); + "trim", kStringTrim, &Trim, registry))); if (version == 0) { return absl::OkStatus(); } @@ -387,19 +417,19 @@ absl::Status RegisterStringsFunctions( registry, options, {extension_options.max_precision})); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - "strings.quote", &Quote, registry))); + "strings.quote", kStringsQuote, &Quote, registry))); if (version == 1) { return absl::OkStatus(); } CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, ListValue>::CreateDescriptor( - "join", /*receiver_style=*/true), + "join", kListJoin, /*receiver_style=*/true), UnaryFunctionAdapter, ListValue>::WrapFunction( Join1))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, ListValue, StringValue>:: - CreateDescriptor("join", /*receiver_style=*/true), + CreateDescriptor("join", kListJoinString, /*receiver_style=*/true), BinaryFunctionAdapter, ListValue, StringValue>::WrapFunction(Join2))); if (version == 2) { @@ -408,7 +438,7 @@ absl::Status RegisterStringsFunctions( CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterMemberOverload( - "reverse", &Reverse, registry))); + "reverse", kStringReverse, &Reverse, registry))); return absl::OkStatus(); } diff --git a/runtime/activation.cc b/runtime/activation.cc index e999f7a02..3d9efaa80 100644 --- a/runtime/activation.cc +++ b/runtime/activation.cc @@ -91,15 +91,37 @@ std::vector Activation::FindFunctionOverloads( std::vector result; auto iter = functions_.find(name); if (iter != functions_.end()) { - const std::vector& overloads = iter->second; - result.reserve(overloads.size()); - for (const auto& overload : overloads) { + const FunctionEntry& entry = iter->second; + result.reserve(entry.overloads.size()); + for (const auto& overload : entry.overloads) { result.push_back({*overload.descriptor, *overload.implementation}); } } return result; } +absl::optional Activation::FindFunctionOverloadById( + absl::string_view name, absl::string_view overload_id) const { + if (overload_id.empty()) { + return absl::nullopt; + } + + auto functions_it = functions_.find(name); + if (functions_it == functions_.end()) { + return absl::nullopt; + } + + const FunctionEntry& entry = functions_it->second; + auto it = entry.overloads_by_id.find(overload_id); + if (it == entry.overloads_by_id.end()) { + return absl::nullopt; + } + + const OverloadEntry* overload_entry = it->second; + return FunctionOverloadReference{*overload_entry->descriptor, + *overload_entry->implementation}; +} + bool Activation::InsertOrAssignValue(absl::string_view name, Value value) { return values_ .insert_or_assign(name, ValueEntry{std::move(value), absl::nullopt}) @@ -115,14 +137,25 @@ bool Activation::InsertOrAssignValueProvider(absl::string_view name, bool Activation::InsertFunction(const cel::FunctionDescriptor& descriptor, std::unique_ptr impl) { - auto& overloads = functions_[descriptor.name()]; - for (auto& overload : overloads) { + auto& entry = functions_[descriptor.name()]; + + // Check for duplicate shape + for (auto& overload : entry.overloads) { if (overload.descriptor->ShapeMatches(descriptor)) { return false; } } - overloads.push_back( + + // Add to overloads vector + entry.overloads.push_back( {std::make_unique(descriptor), std::move(impl)}); + + // Add to overload ID index if overload_id is present + if (descriptor.has_overload_id()) { + OverloadEntry* overload_entry = &entry.overloads.back(); + entry.overloads_by_id[descriptor.overload_id()] = overload_entry; + } + return true; } diff --git a/runtime/activation.h b/runtime/activation.h index 8c4fb4073..3b899a05b 100644 --- a/runtime/activation.h +++ b/runtime/activation.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_ACTIVATION_H_ +#include #include #include #include @@ -75,6 +76,9 @@ class Activation final : public ActivationInterface { std::vector FindFunctionOverloads( absl::string_view name) const override; + absl::optional FindFunctionOverloadById( + absl::string_view name, absl::string_view overload_id) const override; + absl::Span GetUnknownAttributes() const override { return unknown_patterns_; @@ -126,11 +130,17 @@ class Activation final : public ActivationInterface { absl::optional provider; }; - struct FunctionEntry { + struct OverloadEntry { std::unique_ptr descriptor; std::unique_ptr implementation; }; + struct FunctionEntry { + std::deque overloads; + // O(1) lookup by overload ID within this function name. + absl::flat_hash_map overloads_by_id; + }; + friend class runtime_internal::ActivationAttributeMatcherAccess; void SetAttributeMatcher(const runtime_internal::AttributeMatcher* matcher) { @@ -176,7 +186,7 @@ class Activation final : public ActivationInterface { std::unique_ptr owned_attribute_matcher_; - absl::flat_hash_map> functions_; + absl::flat_hash_map functions_; }; } // namespace cel diff --git a/runtime/activation_interface.h b/runtime/activation_interface.h index c589468de..c1acb7bf5 100644 --- a/runtime/activation_interface.h +++ b/runtime/activation_interface.h @@ -72,6 +72,23 @@ class ActivationInterface { virtual std::vector FindFunctionOverloads( absl::string_view name) const = 0; + // Find a single context function overload by name and overload ID. + // Returns nullopt if not found or if overload_id is empty. + virtual absl::optional FindFunctionOverloadById( + absl::string_view name, absl::string_view overload_id) const { + if (overload_id.empty()) { + return absl::nullopt; + } + + auto overloads = FindFunctionOverloads(name); + for (const auto& overload : overloads) { + if (overload.descriptor.overload_id() == overload_id) { + return overload; + } + } + return absl::nullopt; + } + // Return a list of unknown attribute patterns. // // If an attribute (select path) encountered during evaluation matches any of diff --git a/runtime/activation_test.cc b/runtime/activation_test.cc index 4303116a3..92f9bd87c 100644 --- a/runtime/activation_test.cc +++ b/runtime/activation_test.cc @@ -287,11 +287,11 @@ TEST_F(ActivationTest, InsertFunctionOk) { UnorderedElementsAre( Truly([](const FunctionOverloadReference& ref) { return ref.descriptor.name() == "Fn" && - ref.descriptor.types() == std::vector{Kind::kUint}; + ref.descriptor.kinds() == std::vector{Kind::kUint}; }), Truly([](const FunctionOverloadReference& ref) { return ref.descriptor.name() == "Fn" && - ref.descriptor.types() == std::vector{Kind::kInt}; + ref.descriptor.kinds() == std::vector{Kind::kInt}; }))) << "expected overloads Fn(int), Fn(uint)"; } @@ -309,7 +309,7 @@ TEST_F(ActivationTest, InsertFunctionFails) { EXPECT_THAT(activation.FindFunctionOverloads("Fn"), ElementsAre(Truly([](const FunctionOverloadReference& ref) { return ref.descriptor.name() == "Fn" && - ref.descriptor.types() == std::vector{Kind::kAny}; + ref.descriptor.kinds() == std::vector{Kind::kAny}; }))) << "expected overload Fn(any)"; } diff --git a/runtime/function_adapter.h b/runtime/function_adapter.h index 62932a027..8645cb5ae 100644 --- a/runtime/function_adapter.h +++ b/runtime/function_adapter.h @@ -230,6 +230,14 @@ class NullaryFunctionAdapter return FunctionDescriptor(name, receiver_style, {}, options); } + // Descriptor with overload_id + static FunctionDescriptor CreateDescriptor( + absl::string_view name, absl::string_view overload_id, + bool receiver_style, FunctionDescriptorOptions options = {}) { + return FunctionDescriptor(name, overload_id, receiver_style, {}, + options); + } + private: class UnaryFunctionImpl : public Function { public: @@ -329,6 +337,15 @@ class UnaryFunctionAdapter : public RegisterHelper> { {runtime_internal::AdaptedKind()}, options); } + // Descriptor with overload_id + static FunctionDescriptor CreateDescriptor( + absl::string_view name, absl::string_view overload_id, + bool receiver_style, FunctionDescriptorOptions options = {}) { + return FunctionDescriptor(name, overload_id, receiver_style, + {runtime_internal::AdaptedKind()}, + options); + } + private: class UnaryFunctionImpl : public Function { public: @@ -480,6 +497,16 @@ class BinaryFunctionAdapter options); } + // Descriptor with overload_id + static FunctionDescriptor CreateDescriptor( + absl::string_view name, absl::string_view overload_id, + bool receiver_style, FunctionDescriptorOptions options = {}) { + return FunctionDescriptor(name, overload_id, receiver_style, + {runtime_internal::AdaptedKind(), + runtime_internal::AdaptedKind()}, + options); + } + private: class BinaryFunctionImpl : public Function { public: @@ -571,6 +598,17 @@ class TernaryFunctionAdapter options); } + // Descriptor with overload_id + static FunctionDescriptor CreateDescriptor( + absl::string_view name, absl::string_view overload_id, + bool receiver_style, FunctionDescriptorOptions options = {}) { + return FunctionDescriptor( + name, overload_id, receiver_style, + {runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind(), + runtime_internal::AdaptedKind()}, + options); + } + private: class TernaryFunctionImpl : public Function { public: @@ -667,6 +705,18 @@ class QuaternaryFunctionAdapter options); } + // Descriptor with overload_id + static FunctionDescriptor CreateDescriptor( + absl::string_view name, absl::string_view overload_id, + bool receiver_style, FunctionDescriptorOptions options = {}) { + return FunctionDescriptor( + name, overload_id, receiver_style, + {runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind(), + runtime_internal::AdaptedKind(), + runtime_internal::AdaptedKind()}, + options); + } + private: class QuaternaryFunctionImpl : public Function { public: diff --git a/runtime/function_adapter_test.cc b/runtime/function_adapter_test.cc index 910020fdf..a6513a468 100644 --- a/runtime/function_adapter_test.cc +++ b/runtime/function_adapter_test.cc @@ -282,7 +282,7 @@ TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorInt) { EXPECT_EQ(desc.name(), "Increment"); EXPECT_TRUE(desc.is_strict()); EXPECT_FALSE(desc.receiver_style()); - EXPECT_THAT(desc.types(), ElementsAre(Kind::kInt64)); + EXPECT_THAT(desc.kinds(), ElementsAre(Kind::kInt64)); } TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorDouble) { @@ -293,7 +293,7 @@ TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorDouble) { EXPECT_EQ(desc.name(), "Mult2"); EXPECT_TRUE(desc.is_strict()); EXPECT_TRUE(desc.receiver_style()); - EXPECT_THAT(desc.types(), ElementsAre(Kind::kDouble)); + EXPECT_THAT(desc.kinds(), ElementsAre(Kind::kDouble)); } TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorUint) { @@ -304,7 +304,7 @@ TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorUint) { EXPECT_EQ(desc.name(), "Increment"); EXPECT_TRUE(desc.is_strict()); EXPECT_FALSE(desc.receiver_style()); - EXPECT_THAT(desc.types(), ElementsAre(Kind::kUint64)); + EXPECT_THAT(desc.kinds(), ElementsAre(Kind::kUint64)); } TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorBool) { @@ -315,7 +315,7 @@ TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorBool) { EXPECT_EQ(desc.name(), "Not"); EXPECT_TRUE(desc.is_strict()); EXPECT_FALSE(desc.receiver_style()); - EXPECT_THAT(desc.types(), ElementsAre(Kind::kBool)); + EXPECT_THAT(desc.kinds(), ElementsAre(Kind::kBool)); } TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorTimestamp) { @@ -326,7 +326,7 @@ TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorTimestamp) { EXPECT_EQ(desc.name(), "AddMinute"); EXPECT_TRUE(desc.is_strict()); EXPECT_FALSE(desc.receiver_style()); - EXPECT_THAT(desc.types(), ElementsAre(Kind::kTimestamp)); + EXPECT_THAT(desc.kinds(), ElementsAre(Kind::kTimestamp)); } TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorDuration) { @@ -338,7 +338,7 @@ TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorDuration) { EXPECT_EQ(desc.name(), "AddFiveSeconds"); EXPECT_TRUE(desc.is_strict()); EXPECT_FALSE(desc.receiver_style()); - EXPECT_THAT(desc.types(), ElementsAre(Kind::kDuration)); + EXPECT_THAT(desc.kinds(), ElementsAre(Kind::kDuration)); } TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorString) { @@ -349,7 +349,7 @@ TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorString) { EXPECT_EQ(desc.name(), "Prepend"); EXPECT_TRUE(desc.is_strict()); EXPECT_FALSE(desc.receiver_style()); - EXPECT_THAT(desc.types(), ElementsAre(Kind::kString)); + EXPECT_THAT(desc.kinds(), ElementsAre(Kind::kString)); } TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorBytes) { @@ -360,7 +360,7 @@ TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorBytes) { EXPECT_EQ(desc.name(), "Prepend"); EXPECT_TRUE(desc.is_strict()); EXPECT_FALSE(desc.receiver_style()); - EXPECT_THAT(desc.types(), ElementsAre(Kind::kBytes)); + EXPECT_THAT(desc.kinds(), ElementsAre(Kind::kBytes)); } TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorAny) { @@ -371,7 +371,7 @@ TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorAny) { EXPECT_EQ(desc.name(), "Increment"); EXPECT_TRUE(desc.is_strict()); EXPECT_FALSE(desc.receiver_style()); - EXPECT_THAT(desc.types(), ElementsAre(Kind::kAny)); + EXPECT_THAT(desc.kinds(), ElementsAre(Kind::kAny)); } TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorNonStrict) { @@ -383,7 +383,7 @@ TEST_F(FunctionAdapterTest, UnaryFunctionAdapterCreateDescriptorNonStrict) { EXPECT_EQ(desc.name(), "Increment"); EXPECT_FALSE(desc.is_strict()); EXPECT_FALSE(desc.receiver_style()); - EXPECT_THAT(desc.types(), ElementsAre(Kind::kAny)); + EXPECT_THAT(desc.kinds(), ElementsAre(Kind::kAny)); } TEST_F(FunctionAdapterTest, BinaryFunctionAdapterWrapFunctionInt) { diff --git a/runtime/function_registry.cc b/runtime/function_registry.cc index 59f267255..c68cd7bc2 100644 --- a/runtime/function_registry.cc +++ b/runtime/function_registry.cc @@ -24,11 +24,13 @@ #include "absl/container/node_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "common/function_descriptor.h" #include "common/kind.h" +#include "common/type.h" #include "runtime/activation_interface.h" #include "runtime/function.h" #include "runtime/function_overload_reference.h" @@ -47,6 +49,13 @@ class ActivationFunctionProviderImpl absl::StatusOr> GetFunction( const cel::FunctionDescriptor& descriptor, const cel::ActivationInterface& activation) const override { + // Branch 1: If descriptor has overload_id, use O(1) precise lookup + if (descriptor.has_overload_id()) { + return activation.FindFunctionOverloadById(descriptor.name(), + descriptor.overload_id()); + } + + // Branch 2: No overload_id, fallback to signature matching (legacy logic) std::vector overloads = activation.FindFunctionOverloads(descriptor.name()); @@ -82,16 +91,34 @@ absl::Status FunctionRegistry::Register( if (DescriptorRegistered(descriptor)) { return absl::Status( absl::StatusCode::kAlreadyExists, - "CelFunction with specified parameters already registered"); + absl::StrCat("CelFunction with specified parameters already registered: ", + descriptor.name())); } if (!ValidateNonStrictOverload(descriptor)) { return absl::Status(absl::StatusCode::kAlreadyExists, - "Only one overload is allowed for non-strict function"); + absl::StrCat("Only one overload is allowed for non-strict function: ", + descriptor.name())); + } + + // Check for overload ID conflicts within this function name + if (descriptor.has_overload_id() && + IsOverloadIdConflict(descriptor.name(), descriptor.overload_id())) { + return absl::Status( + absl::StatusCode::kAlreadyExists, + absl::StrCat("Overload ID already registered: ", + descriptor.overload_id())); } auto& overloads = functions_[descriptor.name()]; overloads.static_overloads.push_back( StaticFunctionEntry(descriptor, std::move(implementation))); + + // Add to overload ID index if overload_id is present + if (descriptor.has_overload_id()) { + StaticFunctionEntry* entry = &overloads.static_overloads.back(); + overloads.static_overloads_by_id[descriptor.overload_id()] = entry; + } + return absl::OkStatus(); } @@ -100,24 +127,54 @@ absl::Status FunctionRegistry::RegisterLazyFunction( if (DescriptorRegistered(descriptor)) { return absl::Status( absl::StatusCode::kAlreadyExists, - "CelFunction with specified parameters already registered"); + absl::StrCat("CelFunction with specified parameters already registered: ", + descriptor.name())); } if (!ValidateNonStrictOverload(descriptor)) { return absl::Status(absl::StatusCode::kAlreadyExists, - "Only one overload is allowed for non-strict function"); + absl::StrCat("Only one overload is allowed for non-strict function: ", + descriptor.name())); + } + + // Check for overload ID conflicts within this function name + if (descriptor.has_overload_id() && + IsOverloadIdConflict(descriptor.name(), descriptor.overload_id())) { + return absl::Status( + absl::StatusCode::kAlreadyExists, + absl::StrCat("Overload ID already registered: ", + descriptor.overload_id())); } + auto& overloads = functions_[descriptor.name()]; overloads.lazy_overloads.push_back( LazyFunctionEntry(descriptor, CreateActivationFunctionProvider())); + // Add to overload ID index if overload_id is present + if (descriptor.has_overload_id()) { + LazyFunctionEntry* entry = &overloads.lazy_overloads.back(); + overloads.lazy_overloads_by_id[descriptor.overload_id()] = entry; + } + return absl::OkStatus(); } std::vector FunctionRegistry::FindStaticOverloads(absl::string_view name, bool receiver_style, - absl::Span types) const { + absl::Span kinds) const { + std::vector types; + types.reserve(kinds.size()); + for (const auto& kind : kinds) { + types.push_back(cel::Type(kind)); + } + return FindStaticOverloadsByTypes(name, receiver_style, types); +} + +std::vector +FunctionRegistry::FindStaticOverloadsByTypes(absl::string_view name, + bool receiver_style, + absl::Span types) const { std::vector matched_funcs; auto overloads = functions_.find(name); @@ -134,6 +191,26 @@ FunctionRegistry::FindStaticOverloads(absl::string_view name, return matched_funcs; } +absl::optional +FunctionRegistry::FindStaticOverloadById(absl::string_view name, + absl::string_view overload_id) const { + auto functions_it = functions_.find(name); + if (functions_it == functions_.end()) { + return absl::nullopt; + } + + const RegistryEntry& entry = functions_it->second; + auto it = entry.static_overloads_by_id.find(overload_id); + if (it == entry.static_overloads_by_id.end()) { + return absl::nullopt; + } + + const StaticFunctionEntry* func_entry = it->second; + return cel::FunctionOverloadReference{*func_entry->descriptor, + *func_entry->implementation}; +} + + std::vector FunctionRegistry::FindStaticOverloadsByArity(absl::string_view name, bool receiver_style, @@ -157,7 +234,19 @@ FunctionRegistry::FindStaticOverloadsByArity(absl::string_view name, std::vector FunctionRegistry::FindLazyOverloads( absl::string_view name, bool receiver_style, - absl::Span types) const { + absl::Span kinds) const { + std::vector types; + types.reserve(kinds.size()); + for (const auto& kind : kinds) { + types.push_back(cel::Type(kind)); + } + return FindLazyOverloadsByTypes(name, receiver_style, types); +} + +std::vector FunctionRegistry::FindLazyOverloadsByTypes( + absl::string_view name, + bool receiver_style, + absl::Span types) const { std::vector matched_funcs; auto overloads = functions_.find(name); @@ -174,6 +263,24 @@ std::vector FunctionRegistry::FindLazyOverloads( return matched_funcs; } +absl::optional +FunctionRegistry::FindLazyOverloadById(absl::string_view name, + absl::string_view overload_id) const { + auto functions_it = functions_.find(name); + if (functions_it == functions_.end()) { + return absl::nullopt; + } + + const RegistryEntry& entry = functions_it->second; + auto it = entry.lazy_overloads_by_id.find(overload_id); + if (it == entry.lazy_overloads_by_id.end()) { + return absl::nullopt; + } + + const LazyFunctionEntry* func_entry = it->second; + return LazyOverload{*func_entry->descriptor, *func_entry->function_provider}; +} + std::vector FunctionRegistry::FindLazyOverloadsByArity(absl::string_view name, bool receiver_style, @@ -260,4 +367,15 @@ bool FunctionRegistry::ValidateNonStrictOverload( entry.lazy_overloads[0].descriptor->is_strict()); } +bool FunctionRegistry::IsOverloadIdConflict( + absl::string_view name, absl::string_view overload_id) const { + auto it = functions_.find(name); + if (it == functions_.end()) { + return false; + } + const RegistryEntry& entry = it->second; + return entry.static_overloads_by_id.contains(overload_id) || + entry.lazy_overloads_by_id.contains(overload_id); +} + } // namespace cel diff --git a/runtime/function_registry.h b/runtime/function_registry.h index 6a227978d..04b193d9c 100644 --- a/runtime/function_registry.h +++ b/runtime/function_registry.h @@ -28,6 +28,7 @@ #include "absl/types/span.h" #include "common/function_descriptor.h" #include "common/kind.h" +#include "common/type.h" #include "runtime/function.h" #include "runtime/function_overload_reference.h" #include "runtime/function_provider.h" @@ -84,6 +85,15 @@ class FunctionRegistry { absl::string_view name, bool receiver_style, absl::Span types) const; + std::vector FindStaticOverloadsByTypes( + absl::string_view name, bool receiver_style, + absl::Span types) const; + + // Find a static function by overload ID (O(1) lookup). + // Returns nullopt if not found or if the overload ID doesn't match the name. + absl::optional FindStaticOverloadById( + absl::string_view name, absl::string_view overload_id) const; + std::vector FindStaticOverloadsByArity( absl::string_view name, bool receiver_style, size_t arity) const; @@ -102,6 +112,15 @@ class FunctionRegistry { absl::string_view name, bool receiver_style, absl::Span types) const; + std::vector FindLazyOverloadsByTypes( + absl::string_view name, bool receiver_style, + absl::Span types) const; + + // Find a lazy function by overload ID (O(1) lookup). + // Returns nullopt if not found or if the overload ID doesn't match the name. + absl::optional FindLazyOverloadById( + absl::string_view name, absl::string_view overload_id) const; + std::vector FindLazyOverloadsByArity(absl::string_view name, bool receiver_style, size_t arity) const; @@ -138,8 +157,16 @@ class FunctionRegistry { }; struct RegistryEntry { - std::vector static_overloads; - std::vector lazy_overloads; + // Use deque instead of vector to guarantee pointer stability. + // This allows safe storage of pointers to entries in the by-ID indexes. + std::deque static_overloads; + std::deque lazy_overloads; + + // O(1) lookup by overload ID within this function name. + // Points to entries in static_overloads/lazy_overloads. + absl::flat_hash_map + static_overloads_by_id; + absl::flat_hash_map lazy_overloads_by_id; }; // Returns whether the descriptor is registered either as a lazy function or @@ -151,6 +178,11 @@ class FunctionRegistry { bool ValidateNonStrictOverload( const cel::FunctionDescriptor& descriptor) const; + // Check if an overload ID conflicts with already registered static or lazy + // functions under the given function name. + bool IsOverloadIdConflict(absl::string_view name, + absl::string_view overload_id) const; + // indexed by function name (not type checker overload id). absl::flat_hash_map functions_; }; diff --git a/runtime/function_registry_test.cc b/runtime/function_registry_test.cc index 53916777a..ecec16de7 100644 --- a/runtime/function_registry_test.cc +++ b/runtime/function_registry_test.cc @@ -24,7 +24,9 @@ #include "absl/types/span.h" #include "common/function_descriptor.h" #include "common/kind.h" +#include "common/type.h" #include "common/value.h" +#include "google/protobuf/arena.h" #include "internal/testing.h" #include "runtime/activation.h" #include "runtime/function.h" @@ -152,7 +154,7 @@ TEST(FunctionRegistryTest, DefaultLazyProviderReturnsImpl) { ASSERT_TRUE(func.has_value()); EXPECT_EQ(func->descriptor.name(), "LazyFunction"); - EXPECT_EQ(func->descriptor.types(), std::vector{cel::Kind::kInt64}); + EXPECT_EQ(func->descriptor.kinds(), std::vector{cel::Kind::kInt64}); } TEST(FunctionRegistryTest, DefaultLazyProviderAmbiguousOverload) { @@ -297,6 +299,206 @@ INSTANTIATE_TEST_SUITE_P(NonStrictRegistrationFailTest, NonStrictRegistrationFailTest, testing::Combine(testing::Bool(), testing::Bool())); +// Test type-level overload resolution: distinguish list vs list +TEST(FunctionRegistryTest, TypeLevelOverloadResolution_ListTypes) { + google::protobuf::Arena arena; + FunctionRegistry registry; + + // Register fn(list) + ListType list_int_type(&arena, IntType{}); + cel::FunctionDescriptor desc_list_int("fn", "fn_int", false, + std::vector{list_int_type}, true); + + class ListIntFunction : public cel::Function { + public: + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { + return IntValue(1); // Return 1 for list + } + }; + + ASSERT_OK( + registry.Register(desc_list_int, std::make_unique())); + + // Register fn(list) + ListType list_string_type(&arena, StringType{}); + cel::FunctionDescriptor desc_list_string( + "fn", "fn_string", false, std::vector{list_string_type}, true); + + class ListStringFunction : public cel::Function { + public: + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { + return IntValue(2); // Return 2 for list + } + }; + + ASSERT_OK(registry.Register(desc_list_string, + std::make_unique())); + + // Verify both overloads are registered using Kind-based lookup + auto overloads = registry.FindStaticOverloads("fn", false, {Kind::kList}); + EXPECT_THAT(overloads, SizeIs(2)); + + // Verify type-level lookup can distinguish them + auto list_int_overloads = + registry.FindStaticOverloadsByTypes("fn", false, {list_int_type}); + EXPECT_THAT(list_int_overloads, SizeIs(1)); + EXPECT_EQ(list_int_overloads[0].descriptor.types()[0].DebugString(), + "list"); + + auto list_string_overloads = + registry.FindStaticOverloadsByTypes("fn", false, {list_string_type}); + EXPECT_THAT(list_string_overloads, SizeIs(1)); + EXPECT_EQ(list_string_overloads[0].descriptor.types()[0].DebugString(), + "list"); +} + +// Test type-level overload resolution: map vs map +TEST(FunctionRegistryTest, TypeLevelOverloadResolution_MapTypes) { + google::protobuf::Arena arena; + FunctionRegistry registry; + + // Register fn(map) + MapType map_string_int_type(&arena, StringType{}, IntType{}); + cel::FunctionDescriptor desc_map_int( + "fn", "fn_map_int", false, std::vector{map_string_int_type}, true); + + class MapIntFunction : public cel::Function { + public: + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { + return IntValue(1); + } + }; + + ASSERT_OK( + registry.Register(desc_map_int, std::make_unique())); + + // Register fn(map) + MapType map_string_string_type(&arena, StringType{}, StringType{}); + cel::FunctionDescriptor desc_map_string( + "fn", "fn_map_string", false, std::vector{map_string_string_type}, + true); + + class MapStringFunction : public cel::Function { + public: + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { + return IntValue(2); + } + }; + + ASSERT_OK(registry.Register(desc_map_string, + std::make_unique())); + + // Verify both are registered + auto overloads = registry.FindStaticOverloads("fn", false, {Kind::kMap}); + EXPECT_THAT(overloads, SizeIs(2)); + + // Verify type-level lookup + auto map_int_overloads = + registry.FindStaticOverloadsByTypes("fn", false, {map_string_int_type}); + EXPECT_THAT(map_int_overloads, SizeIs(1)); + EXPECT_EQ(map_int_overloads[0].descriptor.types()[0].DebugString(), + "map"); + + auto map_string_overloads = registry.FindStaticOverloadsByTypes( + "fn", false, {map_string_string_type}); + EXPECT_THAT(map_string_overloads, SizeIs(1)); + EXPECT_EQ(map_string_overloads[0].descriptor.types()[0].DebugString(), + "map"); +} + +// Test wildcard types: dyn, any, error +TEST(FunctionRegistryTest, TypeLevelOverloadResolution_WildcardTypes) { + google::protobuf::Arena arena; + FunctionRegistry registry; + + // Register fn(dyn) - should match anything + cel::FunctionDescriptor desc_dyn("fn", "fn_dyn", false, + std::vector{DynType{}}, true); + + class DynFunction : public cel::Function { + public: + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { + return IntValue(1); + } + }; + + ASSERT_OK(registry.Register(desc_dyn, std::make_unique())); + + // Verify it matches various types + auto int_overloads = + registry.FindStaticOverloadsByTypes("fn", false, {IntType{}}); + EXPECT_THAT(int_overloads, SizeIs(1)); + + auto string_overloads = + registry.FindStaticOverloadsByTypes("fn", false, {StringType{}}); + EXPECT_THAT(string_overloads, SizeIs(1)); + + ListType list_type(&arena, IntType{}); + auto list_overloads = + registry.FindStaticOverloadsByTypes("fn", false, {list_type}); + EXPECT_THAT(list_overloads, SizeIs(1)); +} + +// Test that list acts as wildcard for any list +TEST(FunctionRegistryTest, TypeLevelOverloadResolution_ListDynWildcard) { + google::protobuf::Arena arena; + FunctionRegistry registry; + + // Register fn(list) + ListType list_dyn_type(&arena, DynType{}); + cel::FunctionDescriptor desc("fn", "fn_list_dyn", false, + std::vector{list_dyn_type}, true); + + class ListDynFunction : public cel::Function { + public: + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { + return IntValue(1); + } + }; + + ASSERT_OK(registry.Register(desc, std::make_unique())); + + // Should match list, list, etc. + ListType list_int(&arena, IntType{}); + auto list_int_overloads = + registry.FindStaticOverloadsByTypes("fn", false, {list_int}); + EXPECT_THAT(list_int_overloads, SizeIs(1)); + + ListType list_string(&arena, StringType{}); + auto list_string_overloads = + registry.FindStaticOverloadsByTypes("fn", false, {list_string}); + EXPECT_THAT(list_string_overloads, SizeIs(1)); +} + +// Test empty struct type acts as wildcard +TEST(FunctionRegistryTest, TypeLevelOverloadResolution_StructWildcard) { + FunctionRegistry registry; + + // Register fn(struct) - empty MessageType acts as wildcard + cel::FunctionDescriptor desc("fn", "fn_struct", false, + std::vector{MessageType{}}, true); + + class StructFunction : public cel::Function { + public: + absl::StatusOr Invoke(absl::Span args, + const InvokeContext& context) const override { + return IntValue(1); + } + }; + + ASSERT_OK(registry.Register(desc, std::make_unique())); + + // Should match any struct type + auto overloads = registry.FindStaticOverloads("fn", false, {Kind::kStruct}); + EXPECT_THAT(overloads, SizeIs(1)); +} + } // namespace } // namespace cel diff --git a/runtime/optional_types.cc b/runtime/optional_types.cc index 6678a05ed..2143364eb 100644 --- a/runtime/optional_types.cc +++ b/runtime/optional_types.cc @@ -81,6 +81,41 @@ absl::StatusOr OptionalHasValue(const OpaqueValue& opaque_value) { runtime_internal::CreateNoMatchingOverloadError("hasValue")}; } +absl::StatusOr OptionalOr( + const OpaqueValue& opaque_value1, const OpaqueValue& opaque_value2, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (auto optional_value1 = opaque_value1.AsOptional(); optional_value1) { + if (optional_value1->HasValue()) { + return optional_value1->Value(); + } + if (auto optional_value2 = opaque_value2.AsOptional(); optional_value2) { + if (optional_value2->HasValue()) { + return optional_value2->Value(); + } + return OptionalValue::None(); + } + return OptionalValue::None(); + } + return ErrorValue{runtime_internal::CreateNoMatchingOverloadError("or")}; +} + +absl::StatusOr OptionalOrValue( + const OpaqueValue& opaque_value, const Value& value, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + if (auto optional_value = opaque_value.AsOptional(); optional_value) { + if (optional_value->HasValue()) { + return optional_value->Value(); + } else { + return value; + } + } + return ErrorValue{runtime_internal::CreateNoMatchingOverloadError("orValue")}; +} + absl::StatusOr SelectOptionalFieldStruct( const StructValue& struct_value, const StringValue& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -230,9 +265,9 @@ absl::StatusOr OptionalOptIndexOptionalValue( } absl::StatusOr ListFirst(const cel::ListValue& list, - const google::protobuf::DescriptorPool* descriptor_pool, - google::protobuf::MessageFactory* message_factory, - google::protobuf::Arena* arena) { + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, + google::protobuf::Arena* arena) { CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); if (size == 0) { return Value(OptionalValue::None()); @@ -243,9 +278,9 @@ absl::StatusOr ListFirst(const cel::ListValue& list, } absl::StatusOr ListLast(const cel::ListValue& list, - const google::protobuf::DescriptorPool* descriptor_pool, - google::protobuf::MessageFactory* message_factory, - google::protobuf::Arena* arena) { + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, + google::protobuf::Arena* arena) { CEL_ASSIGN_OR_RETURN(size_t size, list.Size()); if (size == 0) { return Value(OptionalValue::None()); @@ -297,26 +332,53 @@ absl::Status RegisterOptionalTypeFunctions(FunctionRegistry& registry, return absl::FailedPreconditionError( "optional_type requires RuntimeOptions.enable_heterogeneous_equality"); } + + // Optional overload IDs from checker/optional.cc + static constexpr absl::string_view kOptionalOf = "optional_of"; + static constexpr absl::string_view kOptionalOfNonZeroValue = + "optional_ofNonZeroValue"; + static constexpr absl::string_view kOptionalNone = "optional_none"; + static constexpr absl::string_view kOptionalValue = "optional_value"; + static constexpr absl::string_view kOptionalHasValue = "optional_hasValue"; + static constexpr absl::string_view kOptionalOr = "optional_or_optional"; + static constexpr absl::string_view kOptionalOrValue = + "optional_orValue_value"; + static constexpr absl::string_view kMapOptionalIndexValue = + "map_optindex_optional_value"; + static constexpr absl::string_view kListOptionalIndexInt = + "list_optindex_optional_int"; + static constexpr absl::string_view kListFirst = "list_first"; + static constexpr absl::string_view kListLast = "list_last"; + static constexpr absl::string_view kOptionalUnwrapList = + "optional_unwrap_list"; + static constexpr absl::string_view kOptionalUnwrapOptList = + "optional_unwrapOpt_list"; + CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter::CreateDescriptor("optional.of", + kOptionalOf, false), UnaryFunctionAdapter::WrapFunction(&OptionalOf))); - CEL_RETURN_IF_ERROR( - registry.Register(UnaryFunctionAdapter::CreateDescriptor( - "optional.ofNonZeroValue", false), - UnaryFunctionAdapter::WrapFunction( - &OptionalOfNonZeroValue))); CEL_RETURN_IF_ERROR(registry.Register( - NullaryFunctionAdapter::CreateDescriptor("optional.none", false), + UnaryFunctionAdapter::CreateDescriptor( + "optional.ofNonZeroValue", kOptionalOfNonZeroValue, false), + UnaryFunctionAdapter::WrapFunction( + &OptionalOfNonZeroValue))); + CEL_RETURN_IF_ERROR(registry.Register( + NullaryFunctionAdapter::CreateDescriptor("optional.none", + kOptionalNone, false), NullaryFunctionAdapter::WrapFunction(&OptionalNone))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, - OpaqueValue>::CreateDescriptor("value", true), + OpaqueValue>::CreateDescriptor("value", + kOptionalValue, true), UnaryFunctionAdapter, OpaqueValue>::WrapFunction( &OptionalGetValue))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, - OpaqueValue>::CreateDescriptor("hasValue", true), + OpaqueValue>::CreateDescriptor("hasValue", + kOptionalHasValue, + true), UnaryFunctionAdapter, OpaqueValue>::WrapFunction( &OptionalHasValue))); CEL_RETURN_IF_ERROR(registry.Register( @@ -336,12 +398,16 @@ absl::Status RegisterOptionalTypeFunctions(FunctionRegistry& registry, StringValue>::WrapFunction(&SelectOptionalField))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, MapValue, - Value>::CreateDescriptor("_[?_]", false), + Value>::CreateDescriptor("_[?_]", + kMapOptionalIndexValue, + false), BinaryFunctionAdapter, MapValue, Value>::WrapFunction(&MapOptIndexOptionalValue))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, ListValue, - int64_t>::CreateDescriptor("_[?_]", false), + int64_t>::CreateDescriptor("_[?_]", + kListOptionalIndexInt, + false), BinaryFunctionAdapter, ListValue, int64_t>::WrapFunction(&ListOptIndexOptionalInt))); CEL_RETURN_IF_ERROR(registry.Register( @@ -351,24 +417,36 @@ absl::Status RegisterOptionalTypeFunctions(FunctionRegistry& registry, WrapFunction(&OptionalOptIndexOptionalValue))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, ListValue>::CreateDescriptor( - "optional.unwrap", false), + "optional.unwrap", kOptionalUnwrapList, false), UnaryFunctionAdapter, ListValue>::WrapFunction( &ListUnwrapOpt))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, ListValue>::CreateDescriptor( - "unwrapOpt", true), + "unwrapOpt", kOptionalUnwrapOptList, true), UnaryFunctionAdapter, ListValue>::WrapFunction( &ListUnwrapOpt))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, ListValue>::CreateDescriptor( - "first", true), + "first", kListFirst, true), UnaryFunctionAdapter, ListValue>::WrapFunction( &ListFirst))); CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter, ListValue>::CreateDescriptor( - "last", true), + "last", kListLast, true), UnaryFunctionAdapter, ListValue>::WrapFunction( &ListLast))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, OpaqueValue, + OpaqueValue>::CreateDescriptor("or", kOptionalOr, + true), + BinaryFunctionAdapter, OpaqueValue, + OpaqueValue>::WrapFunction(&OptionalOr))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, OpaqueValue, + Value>::CreateDescriptor("orValue", + kOptionalOrValue, true), + BinaryFunctionAdapter, OpaqueValue, + Value>::WrapFunction(&OptionalOrValue))); return absl::OkStatus(); } diff --git a/runtime/optional_types_test.cc b/runtime/optional_types_test.cc index 455e51988..d5ad0d45d 100644 --- a/runtime/optional_types_test.cc +++ b/runtime/optional_types_test.cc @@ -65,33 +65,33 @@ using ::testing::TestWithParam; MATCHER_P(MatchesOptionalReceiver1, name, "") { const FunctionDescriptor& descriptor = arg.descriptor; - std::vector types{Kind::kOpaque}; + std::vector kinds{Kind::kOpaque}; return descriptor.name() == name && descriptor.receiver_style() == true && - descriptor.types() == types; + descriptor.kinds() == kinds; } MATCHER_P2(MatchesOptionalReceiver2, name, kind, "") { const FunctionDescriptor& descriptor = arg.descriptor; - std::vector types{Kind::kOpaque, kind}; + std::vector kinds{Kind::kOpaque, kind}; return descriptor.name() == name && descriptor.receiver_style() == true && - descriptor.types() == types; + descriptor.kinds() == kinds; } MATCHER_P2(MatchesOptionalSelect, kind1, kind2, "") { const FunctionDescriptor& descriptor = arg.descriptor; - std::vector types{kind1, kind2}; + std::vector kinds{kind1, kind2}; return descriptor.name() == "_?._" && descriptor.receiver_style() == false && - descriptor.types() == types; + descriptor.kinds() == kinds; } MATCHER_P2(MatchesOptionalIndex, kind1, kind2, "") { const FunctionDescriptor& descriptor = arg.descriptor; - std::vector types{kind1, kind2}; + std::vector kinds{kind1, kind2}; return descriptor.name() == "_[?_]" && descriptor.receiver_style() == false && - descriptor.types() == types; + descriptor.kinds() == kinds; } TEST(EnableOptionalTypes, HeterogeneousEqualityRequired) { diff --git a/runtime/register_function_helper.h b/runtime/register_function_helper.h index 8cc133abc..2f21a4fbd 100644 --- a/runtime/register_function_helper.h +++ b/runtime/register_function_helper.h @@ -60,6 +60,17 @@ class RegisterHelper { AdapterT::WrapFunction(std::forward(fn))); } + // Generic registration with overload_id + template + static absl::Status Register(absl::string_view name, absl::string_view overload_id, + bool receiver_style, + FunctionT&& fn, FunctionRegistry& registry, + FunctionDescriptorOptions options = {}) { + return registry.Register( + AdapterT::CreateDescriptor(name, overload_id, receiver_style, options), + AdapterT::WrapFunction(std::forward(fn))); + } + // Registers a global overload (.e.g. size() ) template static absl::Status RegisterGlobalOverload(absl::string_view name, @@ -69,6 +80,17 @@ class RegisterHelper { registry); } + // Registers a global overload with overload_id + template + static absl::Status RegisterGlobalOverload(absl::string_view name, + absl::string_view overload_id, + FunctionT&& fn, + FunctionRegistry& registry) { + return registry.Register( + AdapterT::CreateDescriptor(name, overload_id, /*receiver_style=*/false), + AdapterT::WrapFunction(std::forward(fn))); + } + // Registers a member overload (.e.g. .size()) template static absl::Status RegisterMemberOverload(absl::string_view name, @@ -78,6 +100,17 @@ class RegisterHelper { registry); } + // Registers a member overload with overload_id + template + static absl::Status RegisterMemberOverload(absl::string_view name, + absl::string_view overload_id, + FunctionT&& fn, + FunctionRegistry& registry) { + return registry.Register( + AdapterT::CreateDescriptor(name, overload_id, /*receiver_style=*/true), + AdapterT::WrapFunction(std::forward(fn))); + } + // Registers a non-strict overload. // // Non-strict functions may receive errors or unknown values as arguments, @@ -92,6 +125,19 @@ class RegisterHelper { return Register(name, /*receiver_style=*/false, std::forward(fn), registry, /*strict=*/false); } + + // Registers a non-strict overload with overload_id + template + static absl::Status RegisterNonStrictOverload(absl::string_view name, + absl::string_view overload_id, + FunctionT&& fn, + FunctionRegistry& registry) { + return registry.Register( + AdapterT::CreateDescriptor(name, overload_id, /*receiver_style=*/false, + FunctionDescriptorOptions{ + /*is_strict=*/false}), + AdapterT::WrapFunction(std::forward(fn))); + } }; } // namespace cel diff --git a/runtime/runtime_options.h b/runtime/runtime_options.h index 7a61208a0..f40947184 100644 --- a/runtime/runtime_options.h +++ b/runtime/runtime_options.h @@ -188,6 +188,15 @@ struct RuntimeOptions { // // If disabled, will use the legacy behavior of rounding to 6 decimal places. bool enable_precision_preserving_double_format = true; + + // Enable type-level function overload resolution. + // + // When true, function overload resolution uses ArgumentTypesMatch with + // Value.GetRuntimeType() for type-level verification (including container + // element types and TypeParam bindings). + // + // When false (default), uses ArgumentKindsMatch for Kind-level matching only. + bool enable_type_level_overload = false; }; // LINT.ThenChange(//depot/google3/eval/public/cel_options.h) diff --git a/runtime/standard/BUILD b/runtime/standard/BUILD index 02a23ef1b..349e2ea47 100644 --- a/runtime/standard/BUILD +++ b/runtime/standard/BUILD @@ -35,6 +35,7 @@ cc_library( "//base:builtins", "//base:function_adapter", "//common:value", + "//common:standard_definitions", "//internal:number", "//internal:status_macros", "//runtime:function_registry", @@ -71,6 +72,7 @@ cc_library( "//base:builtins", "//base:function_adapter", "//common:value", + "//common:standard_definitions", "//internal:number", "//internal:status_macros", "//runtime:function_registry", @@ -111,6 +113,7 @@ cc_library( "//base:function_adapter", "//common:value", "//common:value_kind", + "//common:standard_definitions", "//internal:number", "//internal:status_macros", "//runtime:function_registry", @@ -158,6 +161,7 @@ cc_library( "//base:builtins", "//base:function_adapter", "//common:value", + "//common:standard_definitions", "//internal:status_macros", "//runtime:function_registry", "//runtime:register_function_helper", @@ -204,6 +208,7 @@ cc_library( "//base:builtins", "//base:function_adapter", "//common:value", + "//common:standard_definitions", "//internal:status_macros", "//runtime:function_registry", "//runtime:runtime_options", @@ -236,6 +241,7 @@ cc_library( "//base:builtins", "//base:function_adapter", "//common:value", + "//common:standard_definitions", "//internal:overflow", "//internal:status_macros", "//internal:time", @@ -274,6 +280,7 @@ cc_library( "//base:builtins", "//base:function_adapter", "//common:value", + "//common:standard_definitions", "//internal:overflow", "//internal:status_macros", "//runtime:function_registry", @@ -305,6 +312,7 @@ cc_library( "//base:builtins", "//base:function_adapter", "//common:value", + "//common:standard_definitions", "//internal:overflow", "//internal:status_macros", "//runtime:function_registry", @@ -338,6 +346,7 @@ cc_library( "//base:builtins", "//base:function_adapter", "//common:value", + "//common:standard_definitions", "//internal:status_macros", "//runtime:function_registry", "//runtime:runtime_options", @@ -371,6 +380,7 @@ cc_library( "//base:builtins", "//base:function_adapter", "//common:value", + "//common:standard_definitions", "//internal:re2_options", "//internal:status_macros", "//runtime:function_registry", diff --git a/runtime/standard/arithmetic_functions.cc b/runtime/standard/arithmetic_functions.cc index a851ceb39..587f3a66c 100644 --- a/runtime/standard/arithmetic_functions.cc +++ b/runtime/standard/arithmetic_functions.cc @@ -21,6 +21,7 @@ #include "absl/strings/string_view.h" #include "base/builtins.h" #include "base/function_adapter.h" +#include "common/standard_definitions.h" #include "common/value.h" #include "internal/overflow.h" #include "internal/status_macros.h" @@ -170,22 +171,30 @@ Value Modulo(uint64_t v0, uint64_t v1) { // Helper method // Registers all arithmetic functions for template parameter type. template -absl::Status RegisterArithmeticFunctionsForType(FunctionRegistry& registry) { +absl::Status RegisterArithmeticFunctionsForType(FunctionRegistry& registry, + absl::string_view add_overload_id, + absl::string_view subtract_overload_id, + absl::string_view multiply_overload_id, + absl::string_view divide_overload_id) { using FunctionAdapter = cel::BinaryFunctionAdapter; - CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kAdd, false), - FunctionAdapter::WrapFunction(&Add))); + CEL_RETURN_IF_ERROR( + registry.Register(FunctionAdapter::CreateDescriptor( + cel::builtin::kAdd, add_overload_id, false), + FunctionAdapter::WrapFunction(&Add))); CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kSubtract, false), + FunctionAdapter::CreateDescriptor(cel::builtin::kSubtract, + subtract_overload_id, false), FunctionAdapter::WrapFunction(&Sub))); CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kMultiply, false), + FunctionAdapter::CreateDescriptor(cel::builtin::kMultiply, + multiply_overload_id, false), FunctionAdapter::WrapFunction(&Mul))); return registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kDivide, false), + FunctionAdapter::CreateDescriptor(cel::builtin::kDivide, + divide_overload_id, false), FunctionAdapter::WrapFunction(&Div)); } @@ -193,39 +202,49 @@ absl::Status RegisterArithmeticFunctionsForType(FunctionRegistry& registry) { absl::Status RegisterArithmeticFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { - CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType(registry)); - CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType(registry)); - CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType(registry)); + using cel::StandardOverloadIds; + + CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType(registry, + StandardOverloadIds::kAddInt, StandardOverloadIds::kSubtractInt, + StandardOverloadIds::kMultiplyInt, StandardOverloadIds::kDivideInt)); + + CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType(registry, + StandardOverloadIds::kAddUint, StandardOverloadIds::kSubtractUint, + StandardOverloadIds::kMultiplyUint, StandardOverloadIds::kDivideUint)); + + CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType(registry, + StandardOverloadIds::kAddDouble, StandardOverloadIds::kSubtractDouble, + StandardOverloadIds::kMultiplyDouble, StandardOverloadIds::kDivideDouble)); // Modulo CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter::CreateDescriptor( - cel::builtin::kModulo, false), + cel::builtin::kModulo, StandardOverloadIds::kModuloInt, false), BinaryFunctionAdapter::WrapFunction( &Modulo))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter::CreateDescriptor( - cel::builtin::kModulo, false), + cel::builtin::kModulo, StandardOverloadIds::kModuloUint, false), BinaryFunctionAdapter::WrapFunction( &Modulo))); // Negation group - CEL_RETURN_IF_ERROR( - registry.Register(UnaryFunctionAdapter::CreateDescriptor( - cel::builtin::kNeg, false), - UnaryFunctionAdapter::WrapFunction( - [](int64_t value) -> Value { - auto inv = cel::internal::CheckedNegation(value); - if (!inv.ok()) { - return ErrorValue(inv.status()); - } - return IntValue(*inv); - }))); + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter::CreateDescriptor( + cel::builtin::kNeg, StandardOverloadIds::kNegateInt, false), + UnaryFunctionAdapter::WrapFunction( + [](int64_t value) -> Value { + auto inv = cel::internal::CheckedNegation(value); + if (!inv.ok()) { + return ErrorValue(inv.status()); + } + return IntValue(*inv); + }))); return registry.Register( UnaryFunctionAdapter::CreateDescriptor(cel::builtin::kNeg, - false), + StandardOverloadIds::kNegateDouble, false), UnaryFunctionAdapter::WrapFunction( [](double value) -> double { return -value; })); } diff --git a/runtime/standard/arithmetic_functions_test.cc b/runtime/standard/arithmetic_functions_test.cc index f1da74fd2..af5e770f2 100644 --- a/runtime/standard/arithmetic_functions_test.cc +++ b/runtime/standard/arithmetic_functions_test.cc @@ -27,16 +27,16 @@ using ::testing::UnorderedElementsAre; MATCHER_P2(MatchesOperatorDescriptor, name, expected_kind, "") { const FunctionDescriptor& descriptor = arg.descriptor; - std::vector types{expected_kind, expected_kind}; + std::vector kinds{expected_kind, expected_kind}; return descriptor.name() == name && descriptor.receiver_style() == false && - descriptor.types() == types; + descriptor.kinds() == kinds; } MATCHER_P(MatchesNegationDescriptor, expected_kind, "") { const FunctionDescriptor& descriptor = arg.descriptor; - std::vector types{expected_kind}; + std::vector kinds{expected_kind}; return descriptor.name() == builtin::kNeg && - descriptor.receiver_style() == false && descriptor.types() == types; + descriptor.receiver_style() == false && descriptor.kinds() == kinds; } TEST(RegisterArithmeticFunctions, Registered) { diff --git a/runtime/standard/comparison_functions.cc b/runtime/standard/comparison_functions.cc index bddd1efe9..23103945f 100644 --- a/runtime/standard/comparison_functions.cc +++ b/runtime/standard/comparison_functions.cc @@ -20,6 +20,7 @@ #include "absl/time/time.h" #include "base/builtins.h" #include "base/function_adapter.h" +#include "common/standard_definitions.h" #include "common/value.h" #include "internal/number.h" #include "internal/status_macros.h" @@ -159,22 +160,31 @@ bool CrossNumericGreaterOrEqualTo(T t, U u) { template absl::Status RegisterComparisonFunctionsForType( - cel::FunctionRegistry& registry) { + cel::FunctionRegistry& registry, + absl::string_view less_overload_id, + absl::string_view less_or_equal_overload_id, + absl::string_view greater_overload_id, + absl::string_view greater_or_equal_overload_id) { using FunctionAdapter = BinaryFunctionAdapter; - CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kLess, false), - FunctionAdapter::WrapFunction(LessThan))); + CEL_RETURN_IF_ERROR( + registry.Register(FunctionAdapter::CreateDescriptor(cel::builtin::kLess, + less_overload_id, + false), + FunctionAdapter::WrapFunction(LessThan))); CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual, false), + FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual, + less_or_equal_overload_id, false), FunctionAdapter::WrapFunction(LessThanOrEqual))); CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kGreater, false), + FunctionAdapter::CreateDescriptor(cel::builtin::kGreater, + greater_overload_id, false), FunctionAdapter::WrapFunction(GreaterThan))); CEL_RETURN_IF_ERROR(registry.Register( - FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual, false), + FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual, + greater_or_equal_overload_id, false), FunctionAdapter::WrapFunction(GreaterThanOrEqual))); return absl::OkStatus(); @@ -182,45 +192,81 @@ absl::Status RegisterComparisonFunctionsForType( absl::Status RegisterHomogenousComparisonFunctions( cel::FunctionRegistry& registry) { - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - - CEL_RETURN_IF_ERROR( - RegisterComparisonFunctionsForType(registry)); - - CEL_RETURN_IF_ERROR( - RegisterComparisonFunctionsForType(registry)); - - CEL_RETURN_IF_ERROR( - RegisterComparisonFunctionsForType(registry)); - - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + using cel::StandardOverloadIds; + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType( + registry, StandardOverloadIds::kLessBool, + StandardOverloadIds::kLessEqualsBool, StandardOverloadIds::kGreaterBool, + StandardOverloadIds::kGreaterEqualsBool)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType( + registry, StandardOverloadIds::kLessInt, + StandardOverloadIds::kLessEqualsInt, StandardOverloadIds::kGreaterInt, + StandardOverloadIds::kGreaterEqualsInt)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType( + registry, StandardOverloadIds::kLessUint, + StandardOverloadIds::kLessEqualsUint, StandardOverloadIds::kGreaterUint, + StandardOverloadIds::kGreaterEqualsUint)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType( + registry, StandardOverloadIds::kLessDouble, + StandardOverloadIds::kLessEqualsDouble, + StandardOverloadIds::kGreaterDouble, + StandardOverloadIds::kGreaterEqualsDouble)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType( + registry, StandardOverloadIds::kLessString, + StandardOverloadIds::kLessEqualsString, + StandardOverloadIds::kGreaterString, + StandardOverloadIds::kGreaterEqualsString)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType( + registry, StandardOverloadIds::kLessBytes, + StandardOverloadIds::kLessEqualsBytes, + StandardOverloadIds::kGreaterBytes, + StandardOverloadIds::kGreaterEqualsBytes)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType( + registry, StandardOverloadIds::kLessDuration, + StandardOverloadIds::kLessEqualsDuration, + StandardOverloadIds::kGreaterDuration, + StandardOverloadIds::kGreaterEqualsDuration)); + + CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType( + registry, StandardOverloadIds::kLessTimestamp, + StandardOverloadIds::kLessEqualsTimestamp, + StandardOverloadIds::kGreaterTimestamp, + StandardOverloadIds::kGreaterEqualsTimestamp)); return absl::OkStatus(); } template -absl::Status RegisterCrossNumericComparisons(cel::FunctionRegistry& registry) { +absl::Status RegisterCrossNumericComparisons(cel::FunctionRegistry& registry, + absl::string_view less_overload_id, + absl::string_view greater_overload_id, + absl::string_view greater_or_equal_overload_id, + absl::string_view less_or_equal_overload_id) { using FunctionAdapter = BinaryFunctionAdapter; CEL_RETURN_IF_ERROR(registry.Register( FunctionAdapter::CreateDescriptor(cel::builtin::kLess, + less_overload_id, /*receiver_style=*/false), - FunctionAdapter::WrapFunction(&CrossNumericLessThan))); + FunctionAdapter::WrapFunction(&CrossNumericLessThan))); CEL_RETURN_IF_ERROR(registry.Register( FunctionAdapter::CreateDescriptor(cel::builtin::kGreater, + greater_overload_id, /*receiver_style=*/false), FunctionAdapter::WrapFunction(&CrossNumericGreaterThan))); CEL_RETURN_IF_ERROR(registry.Register( FunctionAdapter::CreateDescriptor(cel::builtin::kGreaterOrEqual, + greater_or_equal_overload_id, /*receiver_style=*/false), FunctionAdapter::WrapFunction(&CrossNumericGreaterOrEqualTo))); CEL_RETURN_IF_ERROR(registry.Register( FunctionAdapter::CreateDescriptor(cel::builtin::kLessOrEqual, + less_or_equal_overload_id, /*receiver_style=*/false), FunctionAdapter::WrapFunction(&CrossNumericLessOrEqualTo))); return absl::OkStatus(); @@ -228,32 +274,105 @@ absl::Status RegisterCrossNumericComparisons(cel::FunctionRegistry& registry) { absl::Status RegisterHeterogeneousComparisonFunctions( cel::FunctionRegistry& registry) { + using cel::StandardOverloadIds; + + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry, + StandardOverloadIds::kLessDoubleInt, + StandardOverloadIds::kGreaterDoubleInt, + StandardOverloadIds::kGreaterEqualsDoubleInt, + StandardOverloadIds::kLessEqualsDoubleInt))); + + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry, + StandardOverloadIds::kLessDoubleUint, + StandardOverloadIds::kGreaterDoubleUint, + StandardOverloadIds::kGreaterEqualsDoubleUint, + StandardOverloadIds::kLessEqualsDoubleUint))); + + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry, + StandardOverloadIds::kLessUintDouble, + StandardOverloadIds::kGreaterUintDouble, + StandardOverloadIds::kGreaterEqualsUintDouble, + StandardOverloadIds::kLessEqualsUintDouble))); + + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry, + StandardOverloadIds::kLessUintInt, + StandardOverloadIds::kGreaterUintInt, + StandardOverloadIds::kGreaterEqualsUintInt, + StandardOverloadIds::kLessEqualsUintInt))); + + CEL_RETURN_IF_ERROR( + (RegisterCrossNumericComparisons(registry, + StandardOverloadIds::kLessIntDouble, + StandardOverloadIds::kGreaterIntDouble, + StandardOverloadIds::kGreaterEqualsIntDouble, + StandardOverloadIds::kLessEqualsIntDouble))); + CEL_RETURN_IF_ERROR( - (RegisterCrossNumericComparisons(registry))); + (RegisterCrossNumericComparisons(registry, + StandardOverloadIds::kLessIntUint, + StandardOverloadIds::kGreaterIntUint, + StandardOverloadIds::kGreaterEqualsIntUint, + StandardOverloadIds::kLessEqualsIntUint))); + CEL_RETURN_IF_ERROR( - (RegisterCrossNumericComparisons(registry))); + RegisterComparisonFunctionsForType(registry, + StandardOverloadIds::kLessBool, + StandardOverloadIds::kLessEqualsBool, + StandardOverloadIds::kGreaterBool, + StandardOverloadIds::kGreaterEqualsBool)); CEL_RETURN_IF_ERROR( - (RegisterCrossNumericComparisons(registry))); + RegisterComparisonFunctionsForType(registry, + StandardOverloadIds::kLessInt, + StandardOverloadIds::kLessEqualsInt, + StandardOverloadIds::kGreaterInt, + StandardOverloadIds::kGreaterEqualsInt)); + CEL_RETURN_IF_ERROR( - (RegisterCrossNumericComparisons(registry))); + RegisterComparisonFunctionsForType(registry, + StandardOverloadIds::kLessUint, + StandardOverloadIds::kLessEqualsUint, + StandardOverloadIds::kGreaterUint, + StandardOverloadIds::kGreaterEqualsUint)); CEL_RETURN_IF_ERROR( - (RegisterCrossNumericComparisons(registry))); + RegisterComparisonFunctionsForType(registry, + StandardOverloadIds::kLessDouble, + StandardOverloadIds::kLessEqualsDouble, + StandardOverloadIds::kGreaterDouble, + StandardOverloadIds::kGreaterEqualsDouble)); + CEL_RETURN_IF_ERROR( - (RegisterCrossNumericComparisons(registry))); + RegisterComparisonFunctionsForType(registry, + StandardOverloadIds::kLessString, + StandardOverloadIds::kLessEqualsString, + StandardOverloadIds::kGreaterString, + StandardOverloadIds::kGreaterEqualsString)); - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); CEL_RETURN_IF_ERROR( - RegisterComparisonFunctionsForType(registry)); + RegisterComparisonFunctionsForType(registry, + StandardOverloadIds::kLessBytes, + StandardOverloadIds::kLessEqualsBytes, + StandardOverloadIds::kGreaterBytes, + StandardOverloadIds::kGreaterEqualsBytes)); + CEL_RETURN_IF_ERROR( - RegisterComparisonFunctionsForType(registry)); + RegisterComparisonFunctionsForType(registry, + StandardOverloadIds::kLessDuration, + StandardOverloadIds::kLessEqualsDuration, + StandardOverloadIds::kGreaterDuration, + StandardOverloadIds::kGreaterEqualsDuration)); + CEL_RETURN_IF_ERROR( - RegisterComparisonFunctionsForType(registry)); - CEL_RETURN_IF_ERROR(RegisterComparisonFunctionsForType(registry)); + RegisterComparisonFunctionsForType(registry, + StandardOverloadIds::kLessTimestamp, + StandardOverloadIds::kLessEqualsTimestamp, + StandardOverloadIds::kGreaterTimestamp, + StandardOverloadIds::kGreaterEqualsTimestamp)); return absl::OkStatus(); } diff --git a/runtime/standard/container_functions.cc b/runtime/standard/container_functions.cc index c81dc7596..d7685d6dd 100644 --- a/runtime/standard/container_functions.cc +++ b/runtime/standard/container_functions.cc @@ -23,6 +23,7 @@ #include "absl/status/statusor.h" #include "base/builtins.h" #include "base/function_adapter.h" +#include "common/standard_definitions.h" #include "common/value.h" #include "common/values/list_value_builder.h" #include "internal/status_macros.h" @@ -99,31 +100,53 @@ absl::StatusOr AppendList(ListValue value1, const Value& value2) { absl::Status RegisterContainerFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { + using cel::StandardOverloadIds; + // receiver style = true/false // Support both the global and receiver style size() for lists and maps. - for (bool receiver_style : {true, false}) { - CEL_RETURN_IF_ERROR(registry.Register( - cel::UnaryFunctionAdapter, const ListValue&>:: - CreateDescriptor(cel::builtin::kSize, receiver_style), - UnaryFunctionAdapter, - const ListValue&>::WrapFunction(ListSizeImpl))); - - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter, const MapValue&>:: - CreateDescriptor(cel::builtin::kSize, receiver_style), - UnaryFunctionAdapter, - const MapValue&>::WrapFunction(MapSizeImpl))); - } + CEL_RETURN_IF_ERROR(registry.Register( + cel::UnaryFunctionAdapter, const ListValue&>:: + CreateDescriptor(cel::builtin::kSize, + StandardOverloadIds::kSizeList, + /*receiver_style=*/false), + UnaryFunctionAdapter, + const ListValue&>::WrapFunction(ListSizeImpl))); + + CEL_RETURN_IF_ERROR(registry.Register( + cel::UnaryFunctionAdapter, const ListValue&>:: + CreateDescriptor(cel::builtin::kSize, + StandardOverloadIds::kSizeListMember, + /*receiver_style=*/true), + UnaryFunctionAdapter, + const ListValue&>::WrapFunction(ListSizeImpl))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, const MapValue&>:: + CreateDescriptor(cel::builtin::kSize, + StandardOverloadIds::kSizeMap, + /*receiver_style=*/false), + UnaryFunctionAdapter, + const MapValue&>::WrapFunction(MapSizeImpl))); + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, const MapValue&>:: + CreateDescriptor(cel::builtin::kSize, + StandardOverloadIds::kSizeMapMember, + /*receiver_style=*/true), + UnaryFunctionAdapter, + const MapValue&>::WrapFunction(MapSizeImpl))); if (options.enable_list_concat) { CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter< - absl::StatusOr, const ListValue&, - const ListValue&>::CreateDescriptor(cel::builtin::kAdd, false), + BinaryFunctionAdapter, const ListValue&, + const ListValue&>:: + CreateDescriptor(cel::builtin::kAdd, + StandardOverloadIds::kAddList, false), BinaryFunctionAdapter, const ListValue&, const ListValue&>::WrapFunction(ConcatList))); } + // Internal runtime function, no overload_id return registry.Register( BinaryFunctionAdapter< absl::StatusOr, ListValue, diff --git a/runtime/standard/container_functions_test.cc b/runtime/standard/container_functions_test.cc index 3e4838bc2..ec239f2c7 100644 --- a/runtime/standard/container_functions_test.cc +++ b/runtime/standard/container_functions_test.cc @@ -28,9 +28,9 @@ using ::testing::UnorderedElementsAre; MATCHER_P3(MatchesDescriptor, name, receiver, expected_kinds, "") { const FunctionDescriptor& descriptor = arg.descriptor; - const std::vector& types = expected_kinds; + const std::vector& kinds = expected_kinds; return descriptor.name() == name && descriptor.receiver_style() == receiver && - descriptor.types() == types; + descriptor.kinds() == kinds; } TEST(RegisterContainerFunctions, RegistersSizeFunctions) { diff --git a/runtime/standard/container_membership_functions.cc b/runtime/standard/container_membership_functions.cc index cc0638429..70591b457 100644 --- a/runtime/standard/container_membership_functions.cc +++ b/runtime/standard/container_membership_functions.cc @@ -24,6 +24,7 @@ #include "absl/strings/string_view.h" #include "base/builtins.h" #include "base/function_adapter.h" +#include "common/standard_definitions.h" #include "common/value.h" #include "internal/number.h" #include "internal/status_macros.h" @@ -132,8 +133,10 @@ absl::Status RegisterListMembershipFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR( (RegisterHelper, const Value&, const ListValue&>>:: - RegisterGlobalOverload(op, &HeterogeneousEqualityIn, registry))); + RegisterGlobalOverload(op, cel::StandardOverloadIds::kInList, + &HeterogeneousEqualityIn, registry))); } else { + // Generic function: no overload_id, fallback to runtime matching CEL_RETURN_IF_ERROR( (RegisterHelper, bool, const ListValue&>>:: diff --git a/runtime/standard/container_membership_functions_test.cc b/runtime/standard/container_membership_functions_test.cc index 9c90136d9..98cf8e7fa 100644 --- a/runtime/standard/container_membership_functions_test.cc +++ b/runtime/standard/container_membership_functions_test.cc @@ -32,9 +32,9 @@ using ::testing::UnorderedElementsAre; MATCHER_P3(MatchesDescriptor, name, receiver, expected_kinds, "") { const FunctionDescriptor& descriptor = *arg; - const std::vector& types = expected_kinds; + const std::vector& kinds = expected_kinds; return descriptor.name() == name && descriptor.receiver_style() == receiver && - descriptor.types() == types; + descriptor.kinds() == kinds; } static constexpr std::array kInOperators = { diff --git a/runtime/standard/equality_functions.cc b/runtime/standard/equality_functions.cc index 6546db16c..4862fff30 100644 --- a/runtime/standard/equality_functions.cc +++ b/runtime/standard/equality_functions.cc @@ -29,6 +29,7 @@ #include "absl/types/optional.h" #include "base/builtins.h" #include "base/function_adapter.h" +#include "common/standard_definitions.h" #include "common/value.h" #include "common/value_kind.h" #include "internal/number.h" @@ -522,10 +523,12 @@ absl::Status RegisterHeterogeneousEqualityFunctions( using Adapter = cel::RegisterHelper< BinaryFunctionAdapter, const Value&, const Value&>>; CEL_RETURN_IF_ERROR( - Adapter::RegisterGlobalOverload(kEqual, &EqualOverloadImpl, registry)); + Adapter::RegisterGlobalOverload(kEqual, cel::StandardOverloadIds::kEquals, + &EqualOverloadImpl, registry)); CEL_RETURN_IF_ERROR(Adapter::RegisterGlobalOverload( - kInequal, &InequalOverloadImpl, registry)); + kInequal, cel::StandardOverloadIds::kNotEquals, + &InequalOverloadImpl, registry)); return absl::OkStatus(); } diff --git a/runtime/standard/equality_functions_test.cc b/runtime/standard/equality_functions_test.cc index 605c66d82..d591704cb 100644 --- a/runtime/standard/equality_functions_test.cc +++ b/runtime/standard/equality_functions_test.cc @@ -33,9 +33,9 @@ using ::testing::UnorderedElementsAre; MATCHER_P3(MatchesDescriptor, name, receiver, expected_kinds, "") { const FunctionDescriptor& descriptor = *arg; - const std::vector& types = expected_kinds; + const std::vector& kinds = expected_kinds; return descriptor.name() == name && descriptor.receiver_style() == receiver && - descriptor.types() == types; + descriptor.kinds() == kinds; } TEST(RegisterEqualityFunctionsHomogeneous, RegistersEqualOperators) { diff --git a/runtime/standard/logical_functions.cc b/runtime/standard/logical_functions.cc index cd3dd3cb5..2d53e02c1 100644 --- a/runtime/standard/logical_functions.cc +++ b/runtime/standard/logical_functions.cc @@ -18,6 +18,7 @@ #include "absl/strings/string_view.h" #include "base/builtins.h" #include "base/function_adapter.h" +#include "common/standard_definitions.h" #include "common/value.h" #include "internal/status_macros.h" #include "runtime/function_registry.h" @@ -47,18 +48,24 @@ Value NotStrictlyFalseImpl(const Value& value) { absl::Status RegisterLogicalFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { + using cel::StandardOverloadIds; + // logical NOT CEL_RETURN_IF_ERROR( (RegisterHelper>::RegisterGlobalOverload( - builtin::kNot, [](bool value) -> bool { return !value; }, registry))); + builtin::kNot, StandardOverloadIds::kNot, + [](bool value) -> bool { return !value; }, registry))); // Strictness using StrictnessHelper = RegisterHelper>; CEL_RETURN_IF_ERROR(StrictnessHelper::RegisterNonStrictOverload( - builtin::kNotStrictlyFalse, &NotStrictlyFalseImpl, registry)); + builtin::kNotStrictlyFalse, StandardOverloadIds::kNotStrictlyFalse, + &NotStrictlyFalseImpl, registry)); CEL_RETURN_IF_ERROR(StrictnessHelper::RegisterNonStrictOverload( - builtin::kNotStrictlyFalseDeprecated, &NotStrictlyFalseImpl, registry)); + builtin::kNotStrictlyFalseDeprecated, + StandardOverloadIds::kNotStrictlyFalseDeprecated, &NotStrictlyFalseImpl, + registry)); return absl::OkStatus(); } diff --git a/runtime/standard/regex_functions.cc b/runtime/standard/regex_functions.cc index 6833f7804..f4fa971a0 100644 --- a/runtime/standard/regex_functions.cc +++ b/runtime/standard/regex_functions.cc @@ -17,6 +17,7 @@ #include "absl/strings/string_view.h" #include "base/builtins.h" #include "base/function_adapter.h" +#include "common/standard_definitions.h" #include "common/value.h" #include "internal/re2_options.h" #include "internal/status_macros.h" @@ -29,6 +30,8 @@ namespace {} // namespace absl::Status RegisterRegexFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { + using cel::StandardOverloadIds; + if (options.enable_regex) { auto regex_matches = [max_size = options.regex_max_program_size]( const StringValue& target, @@ -40,14 +43,20 @@ absl::Status RegisterRegexFunctions(FunctionRegistry& registry, }; // bind str.matches(re) and matches(str, re) - for (bool receiver_style : {true, false}) { - using MatchFnAdapter = - BinaryFunctionAdapter; - CEL_RETURN_IF_ERROR( - registry.Register(MatchFnAdapter::CreateDescriptor( - cel::builtin::kRegexMatch, receiver_style), - MatchFnAdapter::WrapFunction(regex_matches))); - } + using MatchFnAdapter = + BinaryFunctionAdapter; + + CEL_RETURN_IF_ERROR(registry.Register( + MatchFnAdapter::CreateDescriptor(cel::builtin::kRegexMatch, + StandardOverloadIds::kMatches, + /*receiver_style=*/false), + MatchFnAdapter::WrapFunction(regex_matches))); + + CEL_RETURN_IF_ERROR(registry.Register( + MatchFnAdapter::CreateDescriptor(cel::builtin::kRegexMatch, + StandardOverloadIds::kMatchesMember, + /*receiver_style=*/true), + MatchFnAdapter::WrapFunction(regex_matches))); } // if options.enable_regex return absl::OkStatus(); diff --git a/runtime/standard/regex_functions_test.cc b/runtime/standard/regex_functions_test.cc index 59bbe9abf..53806480e 100644 --- a/runtime/standard/regex_functions_test.cc +++ b/runtime/standard/regex_functions_test.cc @@ -38,10 +38,10 @@ MATCHER_P2(MatchesDescriptor, name, call_style, "") { break; } const FunctionDescriptor& descriptor = *arg; - std::vector types{Kind::kString, Kind::kString}; + std::vector kinds{Kind::kString, Kind::kString}; return descriptor.name() == name && descriptor.receiver_style() == receiver_style && - descriptor.types() == types; + descriptor.kinds() == kinds; } TEST(RegisterRegexFunctions, Registered) { diff --git a/runtime/standard/string_functions.cc b/runtime/standard/string_functions.cc index 2bcfe185c..60a6525c7 100644 --- a/runtime/standard/string_functions.cc +++ b/runtime/standard/string_functions.cc @@ -19,10 +19,10 @@ #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "base/builtins.h" #include "base/function_adapter.h" +#include "common/standard_definitions.h" #include "common/value.h" #include "internal/status_macros.h" #include "runtime/function_registry.h" @@ -63,6 +63,8 @@ bool StringStartsWith(const StringValue& value, const StringValue& prefix) { } absl::Status RegisterSizeFunctions(FunctionRegistry& registry) { + using cel::StandardOverloadIds; + // String size auto size_func = [](const StringValue& value) -> int64_t { return value.Size(); @@ -71,10 +73,12 @@ absl::Status RegisterSizeFunctions(FunctionRegistry& registry) { // Support global and receiver style size() operations on strings. using StrSizeFnAdapter = UnaryFunctionAdapter; CEL_RETURN_IF_ERROR(StrSizeFnAdapter::RegisterGlobalOverload( - cel::builtin::kSize, size_func, registry)); + cel::builtin::kSize, StandardOverloadIds::kSizeString, size_func, + registry)); CEL_RETURN_IF_ERROR(StrSizeFnAdapter::RegisterMemberOverload( - cel::builtin::kSize, size_func, registry)); + cel::builtin::kSize, StandardOverloadIds::kSizeStringMember, size_func, + registry)); // Bytes size auto bytes_size_func = [](const BytesValue& value) -> int64_t { @@ -84,50 +88,74 @@ absl::Status RegisterSizeFunctions(FunctionRegistry& registry) { // Support global and receiver style size() operations on bytes. using BytesSizeFnAdapter = UnaryFunctionAdapter; CEL_RETURN_IF_ERROR(BytesSizeFnAdapter::RegisterGlobalOverload( - cel::builtin::kSize, bytes_size_func, registry)); + cel::builtin::kSize, StandardOverloadIds::kSizeBytes, bytes_size_func, + registry)); - return BytesSizeFnAdapter::RegisterMemberOverload(cel::builtin::kSize, - bytes_size_func, registry); + return BytesSizeFnAdapter::RegisterMemberOverload( + cel::builtin::kSize, StandardOverloadIds::kSizeBytesMember, + bytes_size_func, registry); } absl::Status RegisterConcatFunctions(FunctionRegistry& registry) { + using cel::StandardOverloadIds; + using StrCatFnAdapter = BinaryFunctionAdapter, const StringValue&, const StringValue&>; CEL_RETURN_IF_ERROR(StrCatFnAdapter::RegisterGlobalOverload( - cel::builtin::kAdd, &ConcatString, registry)); + cel::builtin::kAdd, StandardOverloadIds::kAddString, &ConcatString, + registry)); using BytesCatFnAdapter = BinaryFunctionAdapter, const BytesValue&, const BytesValue&>; - return BytesCatFnAdapter::RegisterGlobalOverload(cel::builtin::kAdd, - &ConcatBytes, registry); + return BytesCatFnAdapter::RegisterGlobalOverload( + cel::builtin::kAdd, StandardOverloadIds::kAddBytes, &ConcatBytes, + registry); } } // namespace absl::Status RegisterStringFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { + using cel::StandardOverloadIds; + // Basic substring tests (contains, startsWith, endsWith) - for (bool receiver_style : {true, false}) { - auto status = - BinaryFunctionAdapter:: - Register(cel::builtin::kStringContains, receiver_style, - StringContains, registry); - CEL_RETURN_IF_ERROR(status); - - status = - BinaryFunctionAdapter:: - Register(cel::builtin::kStringEndsWith, receiver_style, - StringEndsWith, registry); - CEL_RETURN_IF_ERROR(status); - - status = - BinaryFunctionAdapter:: - Register(cel::builtin::kStringStartsWith, receiver_style, - StringStartsWith, registry); - CEL_RETURN_IF_ERROR(status); - } + auto status = + BinaryFunctionAdapter:: + Register(cel::builtin::kStringContains, StandardOverloadIds::kContainsString, true, + StringContains, registry); + CEL_RETURN_IF_ERROR(status); + + status = + BinaryFunctionAdapter:: + Register(cel::builtin::kStringEndsWith, StandardOverloadIds::kEndsWithString, + true, StringEndsWith, registry); + CEL_RETURN_IF_ERROR(status); + + status = + BinaryFunctionAdapter:: + Register(cel::builtin::kStringStartsWith, StandardOverloadIds::kStartsWithString, + true, StringStartsWith, registry); + CEL_RETURN_IF_ERROR(status); + + // Global overloads for string contains, startsWith, endsWith + // This is used for backward compatibility with stored expressions + status = + BinaryFunctionAdapter:: + Register(cel::builtin::kStringContains, false, + StringContains, registry); + CEL_RETURN_IF_ERROR(status); + status = + BinaryFunctionAdapter:: + Register(cel::builtin::kStringEndsWith, false, + StringEndsWith, registry); + CEL_RETURN_IF_ERROR(status); + status = + BinaryFunctionAdapter:: + Register(cel::builtin::kStringStartsWith, false, + StringStartsWith, registry); + CEL_RETURN_IF_ERROR(status); // string concatenation if enabled if (options.enable_string_concat) { diff --git a/runtime/standard/string_functions_test.cc b/runtime/standard/string_functions_test.cc index d520b3577..8f8b14275 100644 --- a/runtime/standard/string_functions_test.cc +++ b/runtime/standard/string_functions_test.cc @@ -38,10 +38,10 @@ MATCHER_P3(MatchesDescriptor, name, call_style, expected_kinds, "") { break; } const FunctionDescriptor& descriptor = *arg; - const std::vector& types = expected_kinds; + const std::vector& kinds = expected_kinds; return descriptor.name() == name && descriptor.receiver_style() == receiver_style && - descriptor.types() == types; + descriptor.kinds() == kinds; } TEST(RegisterStringFunctions, FunctionsRegistered) { diff --git a/runtime/standard/time_functions.cc b/runtime/standard/time_functions.cc index a0ec5377c..0afc8878d 100644 --- a/runtime/standard/time_functions.cc +++ b/runtime/standard/time_functions.cc @@ -27,6 +27,7 @@ #include "absl/time/time.h" #include "base/builtins.h" #include "base/function_adapter.h" +#include "common/standard_definitions.h" #include "common/value.h" #include "internal/overflow.h" #include "internal/status_macros.h" @@ -162,9 +163,12 @@ Value GetMilliseconds(absl::Time timestamp, absl::string_view tz) { absl::Status RegisterTimestampFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { + using cel::StandardOverloadIds; + CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kFullYear, true), + CreateDescriptor(builtin::kFullYear, + StandardOverloadIds::kTimestampToYearWithTz, true), BinaryFunctionAdapter:: WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { return GetFullYear(ts, tz.ToString()); @@ -172,27 +176,30 @@ absl::Status RegisterTimestampFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter::CreateDescriptor( - builtin::kFullYear, true), + builtin::kFullYear, StandardOverloadIds::kTimestampToYear, true), UnaryFunctionAdapter::WrapFunction( [](absl::Time ts) -> Value { return GetFullYear(ts, ""); }))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kMonth, true), + CreateDescriptor(builtin::kMonth, + StandardOverloadIds::kTimestampToMonthWithTz, true), BinaryFunctionAdapter:: WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { return GetMonth(ts, tz.ToString()); }))); CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor(builtin::kMonth, - true), + UnaryFunctionAdapter::CreateDescriptor( + builtin::kMonth, StandardOverloadIds::kTimestampToMonth, true), UnaryFunctionAdapter::WrapFunction( [](absl::Time ts) -> Value { return GetMonth(ts, ""); }))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kDayOfYear, true), + CreateDescriptor(builtin::kDayOfYear, + StandardOverloadIds::kTimestampToDayOfYearWithTz, + true), BinaryFunctionAdapter:: WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { return GetDayOfYear(ts, tz.ToString()); @@ -200,13 +207,16 @@ absl::Status RegisterTimestampFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter::CreateDescriptor( - builtin::kDayOfYear, true), + builtin::kDayOfYear, StandardOverloadIds::kTimestampToDayOfYear, + true), UnaryFunctionAdapter::WrapFunction( [](absl::Time ts) -> Value { return GetDayOfYear(ts, ""); }))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kDayOfMonth, true), + CreateDescriptor(builtin::kDayOfMonth, + StandardOverloadIds::kTimestampToDayOfMonthWithTz, + true), BinaryFunctionAdapter:: WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { return GetDayOfMonth(ts, tz.ToString()); @@ -214,27 +224,31 @@ absl::Status RegisterTimestampFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter::CreateDescriptor( - builtin::kDayOfMonth, true), + builtin::kDayOfMonth, StandardOverloadIds::kTimestampToDayOfMonth, + true), UnaryFunctionAdapter::WrapFunction( [](absl::Time ts) -> Value { return GetDayOfMonth(ts, ""); }))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kDate, true), + CreateDescriptor(builtin::kDate, + StandardOverloadIds::kTimestampToDateWithTz, true), BinaryFunctionAdapter:: WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { return GetDate(ts, tz.ToString()); }))); CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor(builtin::kDate, - true), + UnaryFunctionAdapter::CreateDescriptor( + builtin::kDate, StandardOverloadIds::kTimestampToDate, true), UnaryFunctionAdapter::WrapFunction( [](absl::Time ts) -> Value { return GetDate(ts, ""); }))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kDayOfWeek, true), + CreateDescriptor(builtin::kDayOfWeek, + StandardOverloadIds::kTimestampToDayOfWeekWithTz, + true), BinaryFunctionAdapter:: WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { return GetDayOfWeek(ts, tz.ToString()); @@ -242,27 +256,31 @@ absl::Status RegisterTimestampFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter::CreateDescriptor( - builtin::kDayOfWeek, true), + builtin::kDayOfWeek, StandardOverloadIds::kTimestampToDayOfWeek, + true), UnaryFunctionAdapter::WrapFunction( [](absl::Time ts) -> Value { return GetDayOfWeek(ts, ""); }))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kHours, true), + CreateDescriptor(builtin::kHours, + StandardOverloadIds::kTimestampToHoursWithTz, true), BinaryFunctionAdapter:: WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { return GetHours(ts, tz.ToString()); }))); CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter::CreateDescriptor(builtin::kHours, - true), + UnaryFunctionAdapter::CreateDescriptor( + builtin::kHours, StandardOverloadIds::kTimestampToHours, true), UnaryFunctionAdapter::WrapFunction( [](absl::Time ts) -> Value { return GetHours(ts, ""); }))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kMinutes, true), + CreateDescriptor(builtin::kMinutes, + StandardOverloadIds::kTimestampToMinutesWithTz, + true), BinaryFunctionAdapter:: WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { return GetMinutes(ts, tz.ToString()); @@ -270,13 +288,15 @@ absl::Status RegisterTimestampFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter::CreateDescriptor( - builtin::kMinutes, true), + builtin::kMinutes, StandardOverloadIds::kTimestampToMinutes, true), UnaryFunctionAdapter::WrapFunction( [](absl::Time ts) -> Value { return GetMinutes(ts, ""); }))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kSeconds, true), + CreateDescriptor(builtin::kSeconds, + StandardOverloadIds::kTimestampToSecondsWithTz, + true), BinaryFunctionAdapter:: WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { return GetSeconds(ts, tz.ToString()); @@ -284,13 +304,15 @@ absl::Status RegisterTimestampFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR(registry.Register( UnaryFunctionAdapter::CreateDescriptor( - builtin::kSeconds, true), + builtin::kSeconds, StandardOverloadIds::kTimestampToSeconds, true), UnaryFunctionAdapter::WrapFunction( [](absl::Time ts) -> Value { return GetSeconds(ts, ""); }))); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kMilliseconds, true), + CreateDescriptor( + builtin::kMilliseconds, + StandardOverloadIds::kTimestampToMillisecondsWithTz, true), BinaryFunctionAdapter:: WrapFunction([](absl::Time ts, const StringValue& tz) -> Value { return GetMilliseconds(ts, tz.ToString()); @@ -298,17 +320,20 @@ absl::Status RegisterTimestampFunctions(FunctionRegistry& registry, return registry.Register( UnaryFunctionAdapter::CreateDescriptor( - builtin::kMilliseconds, true), + builtin::kMilliseconds, + StandardOverloadIds::kTimestampToMilliseconds, true), UnaryFunctionAdapter::WrapFunction( [](absl::Time ts) -> Value { return GetMilliseconds(ts, ""); })); } absl::Status RegisterCheckedTimeArithmeticFunctions( FunctionRegistry& registry) { + using cel::StandardOverloadIds; + CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor(builtin::kAdd, - false), + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kAdd, + StandardOverloadIds::kAddTimestampDuration, false), BinaryFunctionAdapter, absl::Time, absl::Duration>:: WrapFunction( [](absl::Time t1, absl::Duration d2) -> absl::StatusOr { @@ -320,8 +345,9 @@ absl::Status RegisterCheckedTimeArithmeticFunctions( }))); CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter, absl::Duration, - absl::Time>::CreateDescriptor(builtin::kAdd, false), + BinaryFunctionAdapter, absl::Duration, absl::Time>:: + CreateDescriptor(builtin::kAdd, + StandardOverloadIds::kAddDurationTimestamp, false), BinaryFunctionAdapter, absl::Duration, absl::Time>:: WrapFunction( [](absl::Duration d2, absl::Time t1) -> absl::StatusOr { @@ -334,8 +360,9 @@ absl::Status RegisterCheckedTimeArithmeticFunctions( CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, absl::Duration, - absl::Duration>::CreateDescriptor(builtin::kAdd, - false), + absl::Duration>:: + CreateDescriptor(builtin::kAdd, + StandardOverloadIds::kAddDurationDuration, false), BinaryFunctionAdapter< absl::StatusOr, absl::Duration, absl::Duration>::WrapFunction([](absl::Duration d1, absl::Duration d2) @@ -349,7 +376,9 @@ absl::Status RegisterCheckedTimeArithmeticFunctions( CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, absl::Time, absl::Duration>:: - CreateDescriptor(builtin::kSubtract, false), + CreateDescriptor(builtin::kSubtract, + StandardOverloadIds::kSubtractTimestampDuration, + false), BinaryFunctionAdapter, absl::Time, absl::Duration>:: WrapFunction( [](absl::Time t1, absl::Duration d2) -> absl::StatusOr { @@ -361,9 +390,10 @@ absl::Status RegisterCheckedTimeArithmeticFunctions( }))); CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter, absl::Time, - absl::Time>::CreateDescriptor(builtin::kSubtract, - false), + BinaryFunctionAdapter, absl::Time, absl::Time>:: + CreateDescriptor(builtin::kSubtract, + StandardOverloadIds::kSubtractTimestampTimestamp, + false), BinaryFunctionAdapter, absl::Time, absl::Time>:: WrapFunction( [](absl::Time t1, absl::Time t2) -> absl::StatusOr { @@ -377,7 +407,10 @@ absl::Status RegisterCheckedTimeArithmeticFunctions( CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter< absl::StatusOr, absl::Duration, - absl::Duration>::CreateDescriptor(builtin::kSubtract, false), + absl::Duration>::CreateDescriptor(builtin::kSubtract, + StandardOverloadIds:: + kSubtractDurationDuration, + false), BinaryFunctionAdapter< absl::StatusOr, absl::Duration, absl::Duration>::WrapFunction([](absl::Duration d1, absl::Duration d2) @@ -394,27 +427,30 @@ absl::Status RegisterCheckedTimeArithmeticFunctions( absl::Status RegisterUncheckedTimeArithmeticFunctions( FunctionRegistry& registry) { + using cel::StandardOverloadIds; + CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor(builtin::kAdd, - false), + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kAdd, + StandardOverloadIds::kAddTimestampDuration, false), BinaryFunctionAdapter::WrapFunction( [](absl::Time t1, absl::Duration d2) -> Value { return UnsafeTimestampValue(t1 + d2); }))); CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor(builtin::kAdd, false), + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kAdd, + StandardOverloadIds::kAddDurationTimestamp, false), BinaryFunctionAdapter::WrapFunction( [](absl::Duration d2, absl::Time t1) -> Value { return UnsafeTimestampValue(t1 + d2); }))); CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter::CreateDescriptor(builtin::kAdd, - false), + BinaryFunctionAdapter:: + CreateDescriptor(builtin::kAdd, + StandardOverloadIds::kAddDurationDuration, false), BinaryFunctionAdapter:: WrapFunction([](absl::Duration d1, absl::Duration d2) -> Value { return UnsafeDurationValue(d1 + d2); @@ -422,7 +458,9 @@ absl::Status RegisterUncheckedTimeArithmeticFunctions( CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kSubtract, false), + CreateDescriptor(builtin::kSubtract, + StandardOverloadIds::kSubtractTimestampDuration, + false), BinaryFunctionAdapter::WrapFunction( @@ -432,7 +470,8 @@ absl::Status RegisterUncheckedTimeArithmeticFunctions( CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter::CreateDescriptor( - builtin::kSubtract, false), + builtin::kSubtract, StandardOverloadIds::kSubtractTimestampTimestamp, + false), BinaryFunctionAdapter::WrapFunction( [](absl::Time t1, absl::Time t2) -> Value { @@ -441,7 +480,9 @@ absl::Status RegisterUncheckedTimeArithmeticFunctions( CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter:: - CreateDescriptor(builtin::kSubtract, false), + CreateDescriptor(builtin::kSubtract, + StandardOverloadIds::kSubtractDurationDuration, + false), BinaryFunctionAdapter:: WrapFunction([](absl::Duration d1, absl::Duration d2) -> Value { return UnsafeDurationValue(d1 - d2); @@ -451,28 +492,35 @@ absl::Status RegisterUncheckedTimeArithmeticFunctions( } absl::Status RegisterDurationFunctions(FunctionRegistry& registry) { + using cel::StandardOverloadIds; + // duration breakdown accessor functions using DurationAccessorFunction = UnaryFunctionAdapter; CEL_RETURN_IF_ERROR(registry.Register( - DurationAccessorFunction::CreateDescriptor(builtin::kHours, true), + DurationAccessorFunction::CreateDescriptor( + builtin::kHours, StandardOverloadIds::kDurationToHours, true), DurationAccessorFunction::WrapFunction( [](absl::Duration d) -> int64_t { return absl::ToInt64Hours(d); }))); CEL_RETURN_IF_ERROR(registry.Register( - DurationAccessorFunction::CreateDescriptor(builtin::kMinutes, true), + DurationAccessorFunction::CreateDescriptor( + builtin::kMinutes, StandardOverloadIds::kDurationToMinutes, true), DurationAccessorFunction::WrapFunction([](absl::Duration d) -> int64_t { return absl::ToInt64Minutes(d); }))); CEL_RETURN_IF_ERROR(registry.Register( - DurationAccessorFunction::CreateDescriptor(builtin::kSeconds, true), + DurationAccessorFunction::CreateDescriptor( + builtin::kSeconds, StandardOverloadIds::kDurationToSeconds, true), DurationAccessorFunction::WrapFunction([](absl::Duration d) -> int64_t { return absl::ToInt64Seconds(d); }))); return registry.Register( - DurationAccessorFunction::CreateDescriptor(builtin::kMilliseconds, true), + DurationAccessorFunction::CreateDescriptor( + builtin::kMilliseconds, StandardOverloadIds::kDurationToMilliseconds, + true), DurationAccessorFunction::WrapFunction([](absl::Duration d) -> int64_t { constexpr int64_t millis_per_second = 1000L; return absl::ToInt64Milliseconds(d) % millis_per_second; diff --git a/runtime/standard/time_functions_test.cc b/runtime/standard/time_functions_test.cc index f578a1023..67b6aec91 100644 --- a/runtime/standard/time_functions_test.cc +++ b/runtime/standard/time_functions_test.cc @@ -28,25 +28,25 @@ using ::testing::UnorderedElementsAre; MATCHER_P3(MatchesOperatorDescriptor, name, expected_kind1, expected_kind2, "") { const FunctionDescriptor& descriptor = *arg; - std::vector types{expected_kind1, expected_kind2}; + std::vector kinds{expected_kind1, expected_kind2}; return descriptor.name() == name && descriptor.receiver_style() == false && - descriptor.types() == types; + descriptor.kinds() == kinds; } MATCHER_P2(MatchesTimeAccessor, name, kind, "") { const FunctionDescriptor& descriptor = *arg; - std::vector types{kind}; + std::vector kinds{kind}; return descriptor.name() == name && descriptor.receiver_style() == true && - descriptor.types() == types; + descriptor.kinds() == kinds; } MATCHER_P2(MatchesTimezoneTimeAccessor, name, kind, "") { const FunctionDescriptor& descriptor = *arg; - std::vector types{kind, Kind::kString}; + std::vector kinds{kind, Kind::kString}; return descriptor.name() == name && descriptor.receiver_style() == true && - descriptor.types() == types; + descriptor.kinds() == kinds; } TEST(RegisterTimeFunctions, MathOperatorsRegistered) { diff --git a/runtime/standard/type_conversion_functions.cc b/runtime/standard/type_conversion_functions.cc index 76e95751b..2665cee72 100644 --- a/runtime/standard/type_conversion_functions.cc +++ b/runtime/standard/type_conversion_functions.cc @@ -28,6 +28,7 @@ #include "absl/time/time.h" #include "base/builtins.h" #include "base/function_adapter.h" +#include "common/standard_definitions.h" #include "common/value.h" #include "internal/overflow.h" #include "internal/status_macros.h" @@ -80,15 +81,18 @@ Value LegacyFormatDouble(double v, const Function::InvokeContext& context) { absl::Status RegisterBoolConversionFunctions(FunctionRegistry& registry, const RuntimeOptions&) { + using cel::StandardOverloadIds; + // bool -> bool absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kBool, [](bool v) { return v; }, registry); + cel::builtin::kBool, StandardOverloadIds::kBoolToBool, + [](bool v) { return v; }, registry); CEL_RETURN_IF_ERROR(status); // string -> bool return UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kBool, + cel::builtin::kBool, StandardOverloadIds::kStringToBool, [](const StringValue& v) -> Value { if ((v == "true") || (v == "True") || (v == "TRUE") || (v == "t") || (v == "1")) { @@ -106,16 +110,19 @@ absl::Status RegisterBoolConversionFunctions(FunctionRegistry& registry, absl::Status RegisterIntConversionFunctions(FunctionRegistry& registry, const RuntimeOptions&) { + using cel::StandardOverloadIds; + // bool -> int absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kInt, [](bool v) { return static_cast(v); }, + cel::builtin::kInt, StandardOverloadIds::kBoolToInt, + [](bool v) { return static_cast(v); }, registry); CEL_RETURN_IF_ERROR(status); // double -> int status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kInt, + cel::builtin::kInt, StandardOverloadIds::kDoubleToInt, [](double v) -> Value { auto conv = cel::internal::CheckedDoubleToInt64(v); if (!conv.ok()) { @@ -128,13 +135,14 @@ absl::Status RegisterIntConversionFunctions(FunctionRegistry& registry, // int -> int status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kInt, [](int64_t v) { return v; }, registry); + cel::builtin::kInt, StandardOverloadIds::kIntToInt, + [](int64_t v) { return v; }, registry); CEL_RETURN_IF_ERROR(status); // string -> int status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kInt, + cel::builtin::kInt, StandardOverloadIds::kStringToInt, [](const StringValue& s) -> Value { int64_t result; if (!absl::SimpleAtoi(s.ToString(), &result)) { @@ -148,13 +156,13 @@ absl::Status RegisterIntConversionFunctions(FunctionRegistry& registry, // time -> int status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kInt, [](absl::Time t) { return absl::ToUnixSeconds(t); }, - registry); + cel::builtin::kInt, StandardOverloadIds::kTimestampToInt, + [](absl::Time t) { return absl::ToUnixSeconds(t); }, registry); CEL_RETURN_IF_ERROR(status); // uint -> int return UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kInt, + cel::builtin::kInt, StandardOverloadIds::kUintToInt, [](uint64_t v) -> Value { auto conv = cel::internal::CheckedUint64ToInt64(v); if (!conv.ok()) { @@ -167,6 +175,8 @@ absl::Status RegisterIntConversionFunctions(FunctionRegistry& registry, absl::Status RegisterStringConversionFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { + using cel::StandardOverloadIds; + // May be optionally disabled to reduce potential allocs. if (!options.enable_string_conversion) { return absl::OkStatus(); @@ -174,8 +184,7 @@ absl::Status RegisterStringConversionFunctions(FunctionRegistry& registry, absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kString, - + cel::builtin::kString, StandardOverloadIds::kBytesToString, [](const BytesValue& value) -> Value { auto valid = value.NativeValue([](const auto& value) -> bool { return internal::Utf8IsValid(value); @@ -191,7 +200,7 @@ absl::Status RegisterStringConversionFunctions(FunctionRegistry& registry, // bool -> string status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kString, + cel::builtin::kString, StandardOverloadIds::kBoolToString, [](bool value) -> StringValue { return StringValue(value ? "true" : "false"); }, @@ -200,7 +209,7 @@ absl::Status RegisterStringConversionFunctions(FunctionRegistry& registry, // double -> string status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kString, + cel::builtin::kString, StandardOverloadIds::kDoubleToString, (options.enable_precision_preserving_double_format ? &FormatDouble : &LegacyFormatDouble), registry); @@ -208,7 +217,7 @@ absl::Status RegisterStringConversionFunctions(FunctionRegistry& registry, // int -> string status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kString, + cel::builtin::kString, StandardOverloadIds::kIntToString, [](int64_t value) -> StringValue { return StringValue(absl::StrCat(value)); }, @@ -218,13 +227,13 @@ absl::Status RegisterStringConversionFunctions(FunctionRegistry& registry, // string -> string status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kString, + cel::builtin::kString, StandardOverloadIds::kStringToString, [](StringValue value) -> StringValue { return value; }, registry); CEL_RETURN_IF_ERROR(status); // uint -> string status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kString, + cel::builtin::kString, StandardOverloadIds::kUintToString, [](uint64_t value) -> StringValue { return StringValue(absl::StrCat(value)); }, @@ -233,7 +242,7 @@ absl::Status RegisterStringConversionFunctions(FunctionRegistry& registry, // duration -> string status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kString, + cel::builtin::kString, StandardOverloadIds::kDurationToString, [](absl::Duration value) -> Value { auto encode = EncodeDurationToJson(value); if (!encode.ok()) { @@ -246,7 +255,7 @@ absl::Status RegisterStringConversionFunctions(FunctionRegistry& registry, // timestamp -> string return UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kString, + cel::builtin::kString, StandardOverloadIds::kTimestampToString, [](absl::Time value) -> Value { auto encode = EncodeTimestampToJson(value); if (!encode.ok()) { @@ -259,10 +268,12 @@ absl::Status RegisterStringConversionFunctions(FunctionRegistry& registry, absl::Status RegisterUintConversionFunctions(FunctionRegistry& registry, const RuntimeOptions&) { + using cel::StandardOverloadIds; + // double -> uint absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kUint, + cel::builtin::kUint, StandardOverloadIds::kDoubleToUint, [](double v) -> Value { auto conv = cel::internal::CheckedDoubleToUint64(v); if (!conv.ok()) { @@ -275,7 +286,7 @@ absl::Status RegisterUintConversionFunctions(FunctionRegistry& registry, // int -> uint status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kUint, + cel::builtin::kUint, StandardOverloadIds::kIntToUint, [](int64_t v) -> Value { auto conv = cel::internal::CheckedInt64ToUint64(v); if (!conv.ok()) { @@ -289,7 +300,7 @@ absl::Status RegisterUintConversionFunctions(FunctionRegistry& registry, // string -> uint status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kUint, + cel::builtin::kUint, StandardOverloadIds::kStringToUint, [](const StringValue& s) -> Value { uint64_t result; if (!absl::SimpleAtoi(s.ToString(), &result)) { @@ -303,45 +314,50 @@ absl::Status RegisterUintConversionFunctions(FunctionRegistry& registry, // uint -> uint return UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kUint, [](uint64_t v) { return v; }, registry); + cel::builtin::kUint, StandardOverloadIds::kUintToUint, + [](uint64_t v) { return v; }, registry); } absl::Status RegisterBytesConversionFunctions(FunctionRegistry& registry, const RuntimeOptions&) { + using cel::StandardOverloadIds; + // bytes -> bytes absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kBytes, - + cel::builtin::kBytes, StandardOverloadIds::kBytesToBytes, [](BytesValue value) -> BytesValue { return value; }, registry); CEL_RETURN_IF_ERROR(status); // string -> bytes return UnaryFunctionAdapter, const StringValue&>:: RegisterGlobalOverload( - cel::builtin::kBytes, + cel::builtin::kBytes, StandardOverloadIds::kStringToBytes, [](const StringValue& value) { return BytesValue(value.ToString()); }, registry); } absl::Status RegisterDoubleConversionFunctions(FunctionRegistry& registry, const RuntimeOptions&) { + using cel::StandardOverloadIds; + // double -> double absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kDouble, [](double v) { return v; }, registry); + cel::builtin::kDouble, StandardOverloadIds::kDoubleToDouble, + [](double v) { return v; }, registry); CEL_RETURN_IF_ERROR(status); // int -> double status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kDouble, [](int64_t v) { return static_cast(v); }, - registry); + cel::builtin::kDouble, StandardOverloadIds::kIntToDouble, + [](int64_t v) { return static_cast(v); }, registry); CEL_RETURN_IF_ERROR(status); // string -> double status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kDouble, + cel::builtin::kDouble, StandardOverloadIds::kStringToDouble, [](const StringValue& s) -> Value { double result; if (absl::SimpleAtod(s.ToString(), &result)) { @@ -356,8 +372,8 @@ absl::Status RegisterDoubleConversionFunctions(FunctionRegistry& registry, // uint -> double return UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kDouble, [](uint64_t v) { return static_cast(v); }, - registry); + cel::builtin::kDouble, StandardOverloadIds::kUintToDouble, + [](uint64_t v) { return static_cast(v); }, registry); } Value CreateDurationFromString(const StringValue& dur_str) { @@ -376,10 +392,13 @@ Value CreateDurationFromString(const StringValue& dur_str) { absl::Status RegisterTimeConversionFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { + using cel::StandardOverloadIds; + // duration() conversion from string. CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kDuration, CreateDurationFromString, registry))); + cel::builtin::kDuration, StandardOverloadIds::kStringToDuration, + CreateDurationFromString, registry))); bool enable_timestamp_duration_overflow_errors = options.enable_timestamp_duration_overflow_errors; @@ -387,7 +406,7 @@ absl::Status RegisterTimeConversionFunctions(FunctionRegistry& registry, // timestamp conversion from int. CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kTimestamp, + cel::builtin::kTimestamp, StandardOverloadIds::kIntToTimestamp, [=](int64_t epoch_seconds) -> Value { absl::Time ts = absl::FromUnixSeconds(epoch_seconds); if (enable_timestamp_duration_overflow_errors) { @@ -402,21 +421,21 @@ absl::Status RegisterTimeConversionFunctions(FunctionRegistry& registry, // timestamp -> timestamp CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kTimestamp, + cel::builtin::kTimestamp, StandardOverloadIds::kTimestampToTimestamp, [](absl::Time value) -> Value { return TimestampValue(value); }, registry))); // duration -> duration CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kDuration, + cel::builtin::kDuration, StandardOverloadIds::kDurationToDuration, [](absl::Duration value) -> Value { return DurationValue(value); }, registry))); // timestamp() conversion from string. return UnaryFunctionAdapter:: RegisterGlobalOverload( - cel::builtin::kTimestamp, + cel::builtin::kTimestamp, StandardOverloadIds::kStringToTimestamp, [=](const StringValue& time_str) -> Value { absl::Time ts; if (!absl::ParseTime(absl::RFC3339_full, time_str.ToString(), &ts, @@ -438,6 +457,8 @@ absl::Status RegisterTimeConversionFunctions(FunctionRegistry& registry, absl::Status RegisterTypeConversionFunctions(FunctionRegistry& registry, const RuntimeOptions& options) { + using cel::StandardOverloadIds; + CEL_RETURN_IF_ERROR(RegisterBoolConversionFunctions(registry, options)); CEL_RETURN_IF_ERROR(RegisterBytesConversionFunctions(registry, options)); @@ -456,13 +477,13 @@ absl::Status RegisterTypeConversionFunctions(FunctionRegistry& registry, // TODO(issues/102): strip dyn() function references at type-check time. absl::Status status = UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kDyn, [](const Value& value) -> Value { return value; }, - registry); + cel::builtin::kDyn, StandardOverloadIds::kToDyn, + [](const Value& value) -> Value { return value; }, registry); CEL_RETURN_IF_ERROR(status); // type(dyn) -> type return UnaryFunctionAdapter::RegisterGlobalOverload( - cel::builtin::kType, + cel::builtin::kType, StandardOverloadIds::kToType, [](const Value& value) { return TypeValue(value.GetRuntimeType()); }, registry); } diff --git a/runtime/standard/type_conversion_functions_test.cc b/runtime/standard/type_conversion_functions_test.cc index ece8d454f..a7db149ee 100644 --- a/runtime/standard/type_conversion_functions_test.cc +++ b/runtime/standard/type_conversion_functions_test.cc @@ -28,9 +28,9 @@ using ::testing::UnorderedElementsAre; MATCHER_P3(MatchesUnaryDescriptor, name, receiver, expected_kind, "") { const FunctionDescriptor& descriptor = arg.descriptor; - std::vector types{expected_kind}; + std::vector kinds{expected_kind}; return descriptor.name() == name && descriptor.receiver_style() == receiver && - descriptor.types() == types; + descriptor.kinds() == kinds; } TEST(RegisterTypeConversionFunctions, RegisterBoolConversionFunctions) {