From 89f8662f2cb6a06cc533da5fb6a8c1dfb8c11bf8 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Tue, 12 May 2026 15:45:31 -0700 Subject: [PATCH] Add functions to parse type and function signatures into cel types. PiperOrigin-RevId: 914517275 --- common/BUILD | 31 ++++ common/internal/BUILD | 7 +- common/internal/signature.cc | 298 ++++++++++++++++++++++++++++++ common/internal/signature.h | 26 +++ common/internal/signature_test.cc | 298 ++++++++++++++++++++++++++++++ common/type_spec_resolver.cc | 185 +++++++++++++++++++ common/type_spec_resolver.h | 38 ++++ common/type_spec_resolver_test.cc | 235 +++++++++++++++++++++++ 8 files changed, 1117 insertions(+), 1 deletion(-) create mode 100644 common/type_spec_resolver.cc create mode 100644 common/type_spec_resolver.h create mode 100644 common/type_spec_resolver_test.cc diff --git a/common/BUILD b/common/BUILD index 0ead8b15a..42d7f67aa 100644 --- a/common/BUILD +++ b/common/BUILD @@ -46,6 +46,37 @@ cc_test( ], ) +cc_library( + name = "type_spec_resolver", + srcs = ["type_spec_resolver.cc"], + hdrs = ["type_spec_resolver.h"], + deps = [ + ":ast", + ":type", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_spec_resolver_test", + srcs = ["type_spec_resolver_test.cc"], + deps = [ + ":ast", + ":type", + ":type_spec_resolver", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "expr", srcs = ["expr.cc"], diff --git a/common/internal/BUILD b/common/internal/BUILD index 10084b685..ec39e689d 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -143,14 +143,16 @@ cc_library( srcs = ["signature.cc"], hdrs = ["signature.h"], deps = [ + "//common:ast", "//common:type", "//common:type_kind", + "//common:type_spec_resolver", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", ], ) @@ -159,11 +161,14 @@ cc_test( srcs = ["signature_test.cc"], deps = [ ":signature", + "//common:ast", "//common:type", + "//common:type_kind", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], diff --git a/common/internal/signature.cc b/common/internal/signature.cc index f63049878..e9a49910d 100644 --- a/common/internal/signature.cc +++ b/common/internal/signature.cc @@ -15,17 +15,25 @@ #include "common/internal/signature.h" #include +#include +#include #include #include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" +#include "common/type_spec_resolver.h" #include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel::common_internal { @@ -208,4 +216,294 @@ absl::StatusOr MakeOverloadSignature( return result; } + +namespace { + +std::string Unescape(std::string_view str) { + size_t first_escape = str.find('\\'); + if (first_escape == std::string_view::npos) { + return std::string(str); + } + std::string result; + result.reserve(str.size()); + result.append(str.substr(0, first_escape)); + bool escaped = false; + for (size_t i = first_escape; i < str.size(); ++i) { + char c = str[i]; + if (escaped) { + result.push_back(c); + escaped = false; + } else if (c == '\\') { + escaped = true; + } else { + result.push_back(c); + } + } + if (escaped) { + result.push_back('\\'); + } + return result; +} + +class SignatureScanner { + public: + explicit SignatureScanner(std::string_view input, + std::string_view error_prefix = "Invalid signature") + : input_(input), error_prefix_(error_prefix) {} + + absl::StatusOr FindTopLevelChar(char target, bool find_last = false) { + size_t found_idx = std::string_view::npos; + int nesting = 0; + bool escaped = false; + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + + if (c == target && nesting == 0) { + if (find_last) { + found_idx = i; + } else { + return i; + } + } + + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": extra closing >")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + return found_idx; + } + + absl::StatusOr> SplitTopLevel(char delimiter) { + std::vector result; + int nesting = 0; + bool escaped = false; + size_t start = 0; + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + } else if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": extra closing >")); + } + } else if (c == delimiter && nesting == 0) { + result.push_back(input_.substr(start, i - start)); + start = i + 1; + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + if (start < input_.size()) { + result.push_back(input_.substr(start)); + } + return result; + } + + private: + std::string_view input_; + std::string_view error_prefix_; +}; + +absl::StatusOr> SplitTypeList( + std::string_view params) { + return SignatureScanner(params, "Invalid type signature").SplitTopLevel(','); +} + +absl::StatusOr ParseTypeSignature(std::string_view signature) { + if (signature.empty()) { + return absl::InvalidArgumentError("Empty type signature"); + } + + if (signature[0] == '~') { + return TypeSpec(ParamTypeSpec(Unescape(signature.substr(1)))); + } + + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(signature, "Invalid type signature") + .FindTopLevelChar('<', /*find_last=*/false)); + + std::string name_str; + std::vector params; + + if (less_idx != std::string_view::npos) { + // If the signature contains a '<', it must also contain a matching '>'. + if (signature.back() != '>') { + return absl::InvalidArgumentError( + "Invalid type signature: missing closing >"); + } + name_str = Unescape(signature.substr(0, less_idx)); + std::string_view params_str = + signature.substr(less_idx + 1, signature.size() - less_idx - 2); + CEL_ASSIGN_OR_RETURN(auto param_list, SplitTypeList(params_str)); + for (std::string_view param_str : param_list) { + CEL_ASSIGN_OR_RETURN(auto param, ParseTypeSignature(param_str)); + params.push_back(std::move(param)); + } + } else { + name_str = Unescape(signature); + } + + if (name_str == "null") return TypeSpec(NullTypeSpec()); + if (name_str == "bool") return TypeSpec(PrimitiveType::kBool); + if (name_str == "int") return TypeSpec(PrimitiveType::kInt64); + if (name_str == "uint") return TypeSpec(PrimitiveType::kUint64); + if (name_str == "double") return TypeSpec(PrimitiveType::kDouble); + if (name_str == "string") return TypeSpec(PrimitiveType::kString); + if (name_str == "bytes") return TypeSpec(PrimitiveType::kBytes); + if (name_str == "any") return TypeSpec(WellKnownTypeSpec::kAny); + if (name_str == "timestamp") return TypeSpec(WellKnownTypeSpec::kTimestamp); + if (name_str == "duration") return TypeSpec(WellKnownTypeSpec::kDuration); + if (name_str == "dyn") return TypeSpec(DynTypeSpec()); + + if (name_str == "bool_wrapper") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + if (name_str == "int_wrapper") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + if (name_str == "uint_wrapper") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + if (name_str == "double_wrapper") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + if (name_str == "string_wrapper") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + if (name_str == "bytes_wrapper") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + + if (name_str == "list") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: list expects at most 1 parameter"); + } + auto elem = std::make_unique(DynTypeSpec()); + if (!params.empty()) { + *elem = std::move(params[0]); + } + return TypeSpec(ListTypeSpec(std::move(elem))); + } + + if (name_str == "map") { + if (!params.empty() && params.size() != 2) { + return absl::InvalidArgumentError( + "Invalid type signature: map expects 0 or 2 parameters"); + } + auto key = std::make_unique(DynTypeSpec()); + auto value = std::make_unique(DynTypeSpec()); + if (!params.empty()) { + *key = std::move(params[0]); + } + if (params.size() > 1) { + *value = std::move(params[1]); + } + return TypeSpec(MapTypeSpec(std::move(key), std::move(value))); + } + + if (name_str == "function") { + auto result_type = std::make_unique(DynTypeSpec()); + std::vector arg_types; + if (!params.empty()) { + result_type = std::make_unique(std::move(params[0])); + for (size_t i = 1; i < params.size(); ++i) { + arg_types.push_back(std::move(params[i])); + } + } + return TypeSpec( + FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + } + + return TypeSpec(AbstractType(name_str, std::move(params))); +} + +} // namespace + +absl::StatusOr ParseFunctionSignature( + std::string_view signature) { + if (signature.empty()) { + return absl::InvalidArgumentError("Empty function signature"); + } + + CEL_ASSIGN_OR_RETURN(size_t paren_idx, + SignatureScanner(signature, "Invalid function signature") + .FindTopLevelChar('(', /*find_last=*/false)); + + if (paren_idx == std::string_view::npos || signature.back() != ')') { + return absl::InvalidArgumentError("Invalid function signature"); + } + + std::string_view prefix = signature.substr(0, paren_idx); + std::string_view args_str = + signature.substr(paren_idx + 1, signature.size() - paren_idx - 2); + + std::vector arg_types; + ParsedFunctionOverload out; + + CEL_ASSIGN_OR_RETURN(size_t dot_idx, + SignatureScanner(prefix, "Invalid function signature") + .FindTopLevelChar('.', /*find_last=*/true)); + + if (dot_idx != std::string_view::npos) { + out.is_member = true; + std::string_view receiver_str = prefix.substr(0, dot_idx); + std::string_view func_str = prefix.substr(dot_idx + 1); + + CEL_ASSIGN_OR_RETURN(auto receiver_param, ParseTypeSignature(receiver_str)); + arg_types.push_back(std::move(receiver_param)); + out.function_name = Unescape(func_str); + } else { + out.is_member = false; + out.function_name = Unescape(prefix); + } + + if (out.function_name.empty()) { + return absl::InvalidArgumentError( + "Invalid function signature: empty function name"); + } + + if (!args_str.empty()) { + CEL_ASSIGN_OR_RETURN(auto arg_list, SplitTypeList(args_str)); + for (std::string_view arg_str : arg_list) { + CEL_ASSIGN_OR_RETURN(auto arg_param, ParseTypeSignature(arg_str)); + arg_types.push_back(std::move(arg_param)); + } + } + + auto result_type = std::make_unique(DynTypeSpec()); + out.signature_type = + TypeSpec(FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + + return out; +} + +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSignature(signature)); + return cel::ConvertTypeSpecToType(type_spec, arena, pool); +} + } // namespace cel::common_internal diff --git a/common/internal/signature.h b/common/internal/signature.h index 3f31d8fd1..5b91f0b5c 100644 --- a/common/internal/signature.h +++ b/common/internal/signature.h @@ -20,7 +20,10 @@ #include #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel::common_internal { @@ -56,6 +59,29 @@ absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member); +// Parses a string type signature directly into a `cel::Type`. +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +// A parsed function overload signature with the function name, flag for member +// function, and the function signature type. +struct ParsedFunctionOverload { + std::string function_name; + bool is_member = false; + TypeSpec signature_type; + + bool operator==(const ParsedFunctionOverload& other) const { + return function_name == other.function_name && + is_member == other.is_member && + signature_type == other.signature_type; + } +}; + +// Parses a string function overload signature directly into a +// `cel::TypeSpec` configured as a `FunctionTypeSpec`. +absl::StatusOr ParseFunctionSignature( + std::string_view signature); + } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ diff --git a/common/internal/signature_test.cc b/common/internal/signature_test.cc index 8e41c70fb..445b9443c 100644 --- a/common/internal/signature_test.cc +++ b/common/internal/signature_test.cc @@ -13,13 +13,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include "absl/base/no_destructor.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" +#include "common/type_kind.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" @@ -77,6 +81,10 @@ std::vector GetTypeSignatureTestCases() { .type = ListType(GetTestArena(), TypeParamType("A")), .expected_signature = "list<~A>", }, + { + .type = ListType(GetTestArena(), TypeParamType("A GetTypeSignatureTestCases() { .type = TimestampType{}, .expected_signature = "timestamp", }, + { + .type = BoolWrapperType{}, + .expected_signature = "bool_wrapper", + }, { .type = IntWrapperType{}, .expected_signature = "int_wrapper", }, + { + .type = UintWrapperType{}, + .expected_signature = "uint_wrapper", + }, { .type = MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( "cel.expr.conformance.proto3.TestAllTypes")), @@ -202,10 +218,18 @@ std::vector GetOverloadSignatureTestCases() { .args = {TimestampType{}}, .expected_signature = "hello(timestamp)", }, + { + .args = {BoolWrapperType{}}, + .expected_signature = "hello(bool_wrapper)", + }, { .args = {IntWrapperType{}}, .expected_signature = "hello(int_wrapper)", }, + { + .args = {UintWrapperType{}}, + .expected_signature = "hello(uint_wrapper)", + }, { .args = {MessageType( GetTestingDescriptorPool()->FindMessageTypeByName( @@ -231,6 +255,13 @@ std::vector GetOverloadSignatureTestCases() { .is_member = true, .expected_signature = "string.hello(bool,dyn)", }, + { + .function_name = "hello", + .args = {OpaqueType(GetTestArena(), "bar", + {TypeParamType("dummy.type")})}, + .is_member = true, + .expected_signature = R"(bar<~dummy\.type>.hello())", + }, { .function_name = R"(h.(e),l\o)", .args = {StringType{}, @@ -245,5 +276,272 @@ std::vector GetOverloadSignatureTestCases() { INSTANTIATE_TEST_SUITE_P(OverloadIdTest, OverloadSignatureTest, ValuesIn(GetOverloadSignatureTestCases())); +TEST(ParseSignatureTest, ProtoParsing) { + auto t1 = ParseType("int", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t1, ::absl_testing::IsOk()); + EXPECT_TRUE(t1->IsInt()); + + auto t2 = ParseType("list<~A>", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t2, ::absl_testing::IsOk()); + EXPECT_TRUE(t2->IsList()); + + auto t3 = ParseType(R"(~abc\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t3, ::absl_testing::IsOk()); + EXPECT_TRUE(t3->IsTypeParam()); + EXPECT_EQ(t3->GetTypeParam().name(), R"(abc\)"); +} + +TEST(ParseSignatureTest, FunctionParsing) { + auto f1 = ParseFunctionSignature("hello(string)"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_TRUE(f1->signature_type.has_function()); + EXPECT_EQ(f1->signature_type.function().arg_types().size(), 1); + + auto f2 = ParseFunctionSignature("a.b.c()"); + ASSERT_THAT(f2, ::absl_testing::IsOk()); + EXPECT_TRUE(f2->is_member); + EXPECT_EQ(f2->function_name, "c"); +} + +TEST(ParseSignatureTest, ParsingErrors) { + EXPECT_THAT( + ParseType("list>", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("extra closing >"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseType("list><", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("extra closing >"))); + + EXPECT_THAT(ParseFunctionSignature("hello(list>)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("extra closing >"))); + EXPECT_THAT(ParseFunctionSignature("hello(list)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("foo", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("list expects at most 1 parameter"))); + EXPECT_THAT( + ParseType("map", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + + EXPECT_THAT(ParseFunctionSignature("()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + EXPECT_THAT(ParseFunctionSignature("string.()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + + EXPECT_THAT( + ParseType("listfoo", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("missing closing >"))); + + EXPECT_THAT(ParseFunctionSignature("hello>(string)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("extra closing >"))); + + EXPECT_THAT( + ParseType("list<", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); +} + +TEST(ParseSignatureTest, MessageTypeWithParamsError) { + EXPECT_THAT(ParseType("cel.expr.conformance.proto3.TestAllTypes", + GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(ParseSignatureTest, MissingClosingParenthesisError) { + EXPECT_THAT(ParseFunctionSignature("hello(string"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); + EXPECT_THAT(ParseFunctionSignature("hello)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); +} + +TEST(ParseSignatureTest, NestedDotsNonMember) { + auto f1 = ParseFunctionSignature( + "my_opaque()"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_FALSE(f1->is_member); + EXPECT_EQ(f1->function_name, + "my_opaque"); +} + +TEST(ParseSignatureTest, HighlyTrickyGenerics) { + auto t1 = ParseType("map>,map>>", + GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t1, ::absl_testing::IsOk()); + EXPECT_TRUE(t1->IsMap()); + + auto t2 = ParseType(R"(~abc\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t2, ::absl_testing::IsOk()); + EXPECT_TRUE(t2->IsTypeParam()); + EXPECT_EQ(t2->GetTypeParam().name(), R"(abc\)"); + + auto t3 = + ParseType(R"(~abc\\\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t3, ::absl_testing::IsOk()); + EXPECT_TRUE(t3->IsTypeParam()); + EXPECT_EQ(t3->GetTypeParam().name(), R"(abc\\)"); +} + +TEST(ParseSignatureTest, ComplexMemberScan) { + auto f1 = ParseFunctionSignature( + "bar>,map>.func(string)"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_TRUE(f1->is_member); + EXPECT_EQ(f1->function_name, "func"); + EXPECT_TRUE(f1->signature_type.has_function()); + EXPECT_EQ(f1->signature_type.function().arg_types().size(), 2); +} + +TEST(ParseSignatureTest, EmptyOrWhitespaceErrors) { + EXPECT_THAT(ParseType("", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + EXPECT_THAT(ParseFunctionSignature(""), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty function signature"))); + EXPECT_THAT(ParseType("list>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); +} + +void VerifyParsedMatchesType(const TypeSpec& parsed, const Type& original) { + switch (original.kind()) { + case TypeKind::kDyn: + EXPECT_TRUE(parsed.has_dyn()); + break; + case TypeKind::kNull: + EXPECT_TRUE(parsed.has_null()); + break; + case TypeKind::kBool: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kBool); + break; + case TypeKind::kInt: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kInt64); + break; + case TypeKind::kUint: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kUint64); + break; + case TypeKind::kDouble: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kDouble); + break; + case TypeKind::kString: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kString); + break; + case TypeKind::kBytes: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kBytes); + break; + case TypeKind::kAny: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kAny); + break; + case TypeKind::kTimestamp: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kTimestamp); + break; + case TypeKind::kDuration: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kDuration); + break; + case TypeKind::kList: + EXPECT_TRUE(parsed.has_list_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.list_type().elem_type(), + original.GetParameters()[0]); + } + break; + case TypeKind::kMap: + EXPECT_TRUE(parsed.has_map_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.map_type().key_type(), + original.GetParameters()[0]); + } + if (original.GetParameters().size() > 1) { + VerifyParsedMatchesType(parsed.map_type().value_type(), + original.GetParameters()[1]); + } + break; + case TypeKind::kBoolWrapper: + case TypeKind::kIntWrapper: + case TypeKind::kUintWrapper: + case TypeKind::kDoubleWrapper: + case TypeKind::kStringWrapper: + case TypeKind::kBytesWrapper: + EXPECT_TRUE(parsed.has_wrapper()); + break; + case TypeKind::kTypeParam: + EXPECT_TRUE(parsed.has_type_param()); + break; + default: + EXPECT_TRUE(parsed.has_abstract_type()); + break; + } +} + +void VerifyTypesEqual(const Type& lhs, const Type& rhs) { + EXPECT_EQ(lhs.kind(), rhs.kind()); + if (lhs.kind() != rhs.kind()) return; + + if (lhs.kind() == TypeKind::kOpaque || lhs.kind() == TypeKind::kTypeParam) { + EXPECT_EQ(lhs.name(), rhs.name()); + } + + const auto& lhs_params = lhs.GetParameters(); + const auto& rhs_params = rhs.GetParameters(); + EXPECT_EQ(lhs_params.size(), rhs_params.size()); + if (lhs_params.size() == rhs_params.size()) { + for (size_t i = 0; i < lhs_params.size(); ++i) { + VerifyTypesEqual(lhs_params[i], rhs_params[i]); + } + } +} + +TEST_P(TypeSignatureTest, ParseTypeCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty() && param.expected_error.empty()) { + auto parsed = ParseType(param.expected_signature, GetTestArena(), + *GetTestingDescriptorPool()); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + VerifyTypesEqual(*parsed, param.type); + } +} + +TEST_P(OverloadSignatureTest, ExhaustiveFunctionParseCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty()) { + auto parsed = ParseFunctionSignature(param.expected_signature); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + EXPECT_EQ(parsed->function_name, param.function_name); + EXPECT_EQ(parsed->is_member, param.is_member); + EXPECT_TRUE(parsed->signature_type.has_function()); + const auto& func = parsed->signature_type.function(); + for (size_t i = 0; i < param.args.size(); ++i) { + VerifyParsedMatchesType(func.arg_types()[i], param.args[i]); + } + } +} + } // namespace } // namespace cel::common_internal diff --git a/common/type_spec_resolver.cc b/common/type_spec_resolver.cc new file mode 100644 index 000000000..4df38ba89 --- /dev/null +++ b/common/type_spec_resolver.cc @@ -0,0 +1,185 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_spec_resolver.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/type.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + if (type_spec.has_null()) return Type(NullType{}); + if (type_spec.has_dyn()) return Type(DynType{}); + + if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + return Type(BoolType{}); + case PrimitiveType::kInt64: + return Type(IntType{}); + case PrimitiveType::kUint64: + return Type(UintType{}); + case PrimitiveType::kDouble: + return Type(DoubleType{}); + case PrimitiveType::kString: + return Type(StringType{}); + case PrimitiveType::kBytes: + return Type(BytesType{}); + default: + return absl::InvalidArgumentError("Unsupported primitive type"); + } + } + + if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + return Type(AnyType{}); + case WellKnownTypeSpec::kTimestamp: + return Type(TimestampType{}); + case WellKnownTypeSpec::kDuration: + return Type(DurationType{}); + default: + return absl::InvalidArgumentError("Unsupported well-known type"); + } + } + + if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + return Type(BoolWrapperType{}); + case PrimitiveType::kInt64: + return Type(IntWrapperType{}); + case PrimitiveType::kUint64: + return Type(UintWrapperType{}); + case PrimitiveType::kDouble: + return Type(DoubleWrapperType{}); + case PrimitiveType::kString: + return Type(StringWrapperType{}); + case PrimitiveType::kBytes: + return Type(BytesWrapperType{}); + default: + return absl::InvalidArgumentError("Unsupported wrapper type"); + } + } + + if (type_spec.has_list_type()) { + CEL_ASSIGN_OR_RETURN( + auto elem_type, + ConvertTypeSpecToType(type_spec.list_type().elem_type(), arena, pool)); + return Type(ListType(arena, elem_type)); + } + + if (type_spec.has_map_type()) { + CEL_ASSIGN_OR_RETURN( + auto key_type, + ConvertTypeSpecToType(type_spec.map_type().key_type(), arena, pool)); + CEL_ASSIGN_OR_RETURN( + auto value_type, + ConvertTypeSpecToType(type_spec.map_type().value_type(), arena, pool)); + return Type(MapType(arena, key_type, value_type)); + } + + if (type_spec.has_function()) { + const auto& func_spec = type_spec.function(); + CEL_ASSIGN_OR_RETURN( + auto result_type, + ConvertTypeSpecToType(func_spec.result_type(), arena, pool)); + std::vector arg_types; + for (const auto& arg_spec : func_spec.arg_types()) { + CEL_ASSIGN_OR_RETURN(auto arg_type, + ConvertTypeSpecToType(arg_spec, arena, pool)); + arg_types.push_back(std::move(arg_type)); + } + return Type(FunctionType(arena, result_type, arg_types)); + } + + if (type_spec.has_type_param()) { + const std::string& name = type_spec.type_param().type(); + auto* allocated_name = google::protobuf::Arena::Create(arena, name); + return Type(TypeParamType(absl::string_view(*allocated_name))); + } + + if (type_spec.has_message_type()) { + const std::string& name = type_spec.message_type().type(); + const google::protobuf::Descriptor* descriptor = pool.FindMessageTypeByName(name); + if (descriptor == nullptr) { + return absl::InvalidArgumentError(absl::StrCat( + "Message type '", name, "' not found in descriptor pool")); + } + return Type::Message(descriptor); + } + + if (type_spec.has_abstract_type()) { + const std::string& name = type_spec.abstract_type().name(); + + // Check if it's a message type in the pool + const google::protobuf::Descriptor* descriptor = pool.FindMessageTypeByName(name); + if (descriptor != nullptr) { + if (!type_spec.abstract_type().parameter_types().empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Message type '", name, "' cannot have type parameters")); + } + return Type::Message(descriptor); + } + + // Check if it's an enum type in the pool + const google::protobuf::EnumDescriptor* enum_descriptor = + pool.FindEnumTypeByName(name); + if (enum_descriptor != nullptr) { + if (!type_spec.abstract_type().parameter_types().empty()) { + return absl::InvalidArgumentError( + absl::StrCat("Enum type '", name, "' cannot have type parameters")); + } + return Type::Enum(enum_descriptor); + } + + // Otherwise fallback to OpaqueType + std::vector params; + for (const auto& param_spec : type_spec.abstract_type().parameter_types()) { + CEL_ASSIGN_OR_RETURN(auto param, + ConvertTypeSpecToType(param_spec, arena, pool)); + params.push_back(std::move(param)); + } + auto* allocated_name = google::protobuf::Arena::Create(arena, name); + return Type(OpaqueType(arena, absl::string_view(*allocated_name), params)); + } + + if (type_spec.has_type()) { + CEL_ASSIGN_OR_RETURN(auto contained_type, + ConvertTypeSpecToType(type_spec.type(), arena, pool)); + return Type(TypeType(arena, contained_type)); + } + + if (type_spec.has_error()) { + return Type(ErrorType{}); + } + + return absl::InvalidArgumentError("Unknown TypeSpec kind"); +} + +} // namespace cel diff --git a/common/type_spec_resolver.h b/common/type_spec_resolver.h new file mode 100644 index 000000000..9de7f01f0 --- /dev/null +++ b/common/type_spec_resolver.h @@ -0,0 +1,38 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Converts a native C++ intermediate AST node `cel::TypeSpec` into a +// `cel::Type`. +// +// This acts as a general translation layer between static analysis/checking +// type metadata schemas (`cel::TypeSpec`) and the runtime composed `cel::Type` +// definitions. +absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ diff --git a/common/type_spec_resolver_test.cc b/common/type_spec_resolver_test.cc new file mode 100644 index 000000000..93aa73cb8 --- /dev/null +++ b/common/type_spec_resolver_test.cc @@ -0,0 +1,235 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_spec_resolver.h" + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/ast.h" +#include "common/type.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::HasSubstr; + +google::protobuf::Arena* GetTestArena() { + static absl::NoDestructor arena; + return &*arena; +} + +TEST(TypeSpecResolverTest, NullTypeSpec) { + TypeSpec spec(NullTypeSpec{}); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsNull()); +} + +TEST(TypeSpecResolverTest, DynTypeSpec) { + TypeSpec spec(DynTypeSpec{}); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsDyn()); +} + +TEST(TypeSpecResolverTest, PrimitiveTypes) { + const std::vector primitives = { + PrimitiveType::kBool, PrimitiveType::kInt64, PrimitiveType::kUint64, + PrimitiveType::kDouble, PrimitiveType::kString, PrimitiveType::kBytes}; + for (auto p : primitives) { + TypeSpec spec(p); + auto t = ConvertTypeSpecToType(spec, GetTestArena(), + *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsBool() || t->IsInt() || t->IsUint() || t->IsDouble() || + t->IsString() || t->IsBytes()); + } +} + +TEST(TypeSpecResolverTest, WellKnownTypes) { + const std::vector well_knowns = { + WellKnownTypeSpec::kAny, WellKnownTypeSpec::kTimestamp, + WellKnownTypeSpec::kDuration}; + for (auto wk : well_knowns) { + TypeSpec spec(wk); + auto t = ConvertTypeSpecToType(spec, GetTestArena(), + *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsAny() || t->IsTimestamp() || t->IsDuration()); + } +} + +TEST(TypeSpecResolverTest, PrimitiveTypeWrappers) { + const std::vector primitives = { + PrimitiveType::kBool, PrimitiveType::kInt64, PrimitiveType::kUint64, + PrimitiveType::kDouble, PrimitiveType::kString, PrimitiveType::kBytes}; + for (auto p : primitives) { + TypeSpec spec{PrimitiveTypeWrapper(p)}; + auto t = ConvertTypeSpecToType(spec, GetTestArena(), + *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsWrapper()); + } +} + +TEST(TypeSpecResolverTest, ListTypeConversion) { + auto elem = std::make_unique(PrimitiveType::kInt64); + TypeSpec spec(ListTypeSpec(std::move(elem))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsList()); + EXPECT_TRUE(t->GetList().element().IsInt()); +} + +TEST(TypeSpecResolverTest, MapTypeConversion) { + auto key = std::make_unique(PrimitiveType::kString); + auto val = std::make_unique(PrimitiveType::kBytes); + TypeSpec spec(MapTypeSpec(std::move(key), std::move(val))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMap()); + EXPECT_TRUE(t->GetMap().key().IsString()); + EXPECT_TRUE(t->GetMap().value().IsBytes()); +} + +TEST(TypeSpecResolverTest, FunctionTypeConversion) { + auto result = std::make_unique(PrimitiveType::kBool); + std::vector args; + args.push_back(TypeSpec(PrimitiveType::kString)); + TypeSpec spec(FunctionTypeSpec(std::move(result), std::move(args))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsFunction()); + EXPECT_EQ(t->GetFunction().args().size(), 1); + EXPECT_TRUE(t->GetFunction().result().IsBool()); +} + +TEST(TypeSpecResolverTest, TypeParamConversion) { + TypeSpec spec(ParamTypeSpec("T")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsTypeParam()); + EXPECT_EQ(t->GetTypeParam().name(), "T"); +} + +TEST(TypeSpecResolverTest, MessageTypeConversion) { + TypeSpec spec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", /*params=*/{})); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMessage()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(TypeSpecResolverTest, MessageTypeWithParamsError) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("cel.expr.conformance.proto3.TestAllTypes", + std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(TypeSpecResolverTest, UnresolvedAbstractTypeFallbackToOpaque) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("my.custom.OpaqueType", std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsOpaque()); + EXPECT_EQ(t->name(), "my.custom.OpaqueType"); + EXPECT_EQ(t->GetParameters().size(), 1); +} + +TEST(TypeSpecResolverTest, TypeTypeConversion) { + auto nested = std::make_unique(PrimitiveType::kInt64); + TypeSpec spec(std::move(nested)); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsType()); + EXPECT_TRUE(t->GetType().GetType().IsInt()); +} + +TEST(TypeSpecResolverTest, ErrorTypeConversion) { + TypeSpec spec(ErrorTypeSpec::kValue); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsError()); +} + +TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { + TypeSpec spec(MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMessage()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(TypeSpecResolverTest, MessageTypeSpecNotFoundError) { + TypeSpec spec(MessageTypeSpec("cel.expr.conformance.proto3.NonExistentType")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("not found in descriptor pool"))); +} + +TEST(TypeSpecResolverTest, EnumTypeConversion) { + TypeSpec spec(AbstractType( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum", /*params=*/{})); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsEnum()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"); +} + +TEST(TypeSpecResolverTest, EnumTypeWithParamsError) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes.NestedEnum", + std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +} // namespace +} // namespace cel