Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,37 @@ cc_test(
],
)

cc_library(
name = "type_spec_resolver",
srcs = ["type_spec_resolver.cc"],
hdrs = ["type_spec_resolver.h"],
deps = [
":ast",
":type",
"//internal:status_macros",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_protobuf//:protobuf",
],
)

cc_test(
name = "type_spec_resolver_test",
srcs = ["type_spec_resolver_test.cc"],
deps = [
":ast",
":type",
":type_spec_resolver",
"//internal:testing",
"//internal:testing_descriptor_pool",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_protobuf//:protobuf",
],
)

cc_library(
name = "expr",
srcs = ["expr.cc"],
Expand Down
7 changes: 6 additions & 1 deletion common/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,16 @@ cc_library(
srcs = ["signature.cc"],
hdrs = ["signature.h"],
deps = [
"//common:ast",
"//common:type",
"//common:type_kind",
"//common:type_spec_resolver",
"//internal:status_macros",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_protobuf//:protobuf",
],
)

Expand All @@ -159,11 +161,14 @@ cc_test(
srcs = ["signature_test.cc"],
deps = [
":signature",
"//common:ast",
"//common:type",
"//common:type_kind",
"//internal:testing",
"//internal:testing_descriptor_pool",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
"@com_google_protobuf//:protobuf",
],
Expand Down
298 changes: 298 additions & 0 deletions common/internal/signature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,25 @@
#include "common/internal/signature.h"

#include <cstddef>
#include <cstring>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "common/ast.h"
#include "common/type.h"
#include "common/type_kind.h"
#include "common/type_spec_resolver.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"

namespace cel::common_internal {

Expand Down Expand Up @@ -208,4 +216,294 @@ absl::StatusOr<std::string> MakeOverloadSignature(

return result;
}

namespace {

std::string Unescape(std::string_view str) {
size_t first_escape = str.find('\\');
if (first_escape == std::string_view::npos) {
return std::string(str);
}
std::string result;
result.reserve(str.size());
result.append(str.substr(0, first_escape));
bool escaped = false;
for (size_t i = first_escape; i < str.size(); ++i) {
char c = str[i];
if (escaped) {
result.push_back(c);
escaped = false;
} else if (c == '\\') {
escaped = true;
} else {
result.push_back(c);
}
}
if (escaped) {
result.push_back('\\');
}
return result;
}

class SignatureScanner {
public:
explicit SignatureScanner(std::string_view input,
std::string_view error_prefix = "Invalid signature")
: input_(input), error_prefix_(error_prefix) {}

absl::StatusOr<size_t> FindTopLevelChar(char target, bool find_last = false) {
size_t found_idx = std::string_view::npos;
int nesting = 0;
bool escaped = false;
for (size_t i = 0; i < input_.size(); ++i) {
char c = input_[i];
if (escaped) {
escaped = false;
continue;
}
if (c == '\\') {
escaped = true;
continue;
}

if (c == target && nesting == 0) {
if (find_last) {
found_idx = i;
} else {
return i;
}
}

if (c == '<') {
nesting++;
} else if (c == '>') {
nesting--;
if (nesting < 0) {
return absl::InvalidArgumentError(
absl::StrCat(error_prefix_, ": extra closing >"));
}
}
}
if (nesting != 0) {
return absl::InvalidArgumentError(
absl::StrCat(error_prefix_, ": mismatched brackets"));
}
return found_idx;
}

absl::StatusOr<std::vector<std::string_view>> SplitTopLevel(char delimiter) {
std::vector<std::string_view> result;
int nesting = 0;
bool escaped = false;
size_t start = 0;
for (size_t i = 0; i < input_.size(); ++i) {
char c = input_[i];
if (escaped) {
escaped = false;
continue;
}
if (c == '\\') {
escaped = true;
} else if (c == '<') {
nesting++;
} else if (c == '>') {
nesting--;
if (nesting < 0) {
return absl::InvalidArgumentError(
absl::StrCat(error_prefix_, ": extra closing >"));
}
} else if (c == delimiter && nesting == 0) {
result.push_back(input_.substr(start, i - start));
start = i + 1;
}
}
if (nesting != 0) {
return absl::InvalidArgumentError(
absl::StrCat(error_prefix_, ": mismatched brackets"));
}
if (start < input_.size()) {
result.push_back(input_.substr(start));
}
return result;
}

private:
std::string_view input_;
std::string_view error_prefix_;
};

absl::StatusOr<std::vector<std::string_view>> SplitTypeList(
std::string_view params) {
return SignatureScanner(params, "Invalid type signature").SplitTopLevel(',');
}

absl::StatusOr<TypeSpec> ParseTypeSignature(std::string_view signature) {
if (signature.empty()) {
return absl::InvalidArgumentError("Empty type signature");
}

if (signature[0] == '~') {
return TypeSpec(ParamTypeSpec(Unescape(signature.substr(1))));
}

CEL_ASSIGN_OR_RETURN(size_t less_idx,
SignatureScanner(signature, "Invalid type signature")
.FindTopLevelChar('<', /*find_last=*/false));

std::string name_str;
std::vector<TypeSpec> params;

if (less_idx != std::string_view::npos) {
// If the signature contains a '<', it must also contain a matching '>'.
if (signature.back() != '>') {
return absl::InvalidArgumentError(
"Invalid type signature: missing closing >");
}
name_str = Unescape(signature.substr(0, less_idx));
std::string_view params_str =
signature.substr(less_idx + 1, signature.size() - less_idx - 2);
CEL_ASSIGN_OR_RETURN(auto param_list, SplitTypeList(params_str));
for (std::string_view param_str : param_list) {
CEL_ASSIGN_OR_RETURN(auto param, ParseTypeSignature(param_str));
params.push_back(std::move(param));
}
} else {
name_str = Unescape(signature);
}

if (name_str == "null") return TypeSpec(NullTypeSpec());
if (name_str == "bool") return TypeSpec(PrimitiveType::kBool);
if (name_str == "int") return TypeSpec(PrimitiveType::kInt64);
if (name_str == "uint") return TypeSpec(PrimitiveType::kUint64);
if (name_str == "double") return TypeSpec(PrimitiveType::kDouble);
if (name_str == "string") return TypeSpec(PrimitiveType::kString);
if (name_str == "bytes") return TypeSpec(PrimitiveType::kBytes);
if (name_str == "any") return TypeSpec(WellKnownTypeSpec::kAny);
if (name_str == "timestamp") return TypeSpec(WellKnownTypeSpec::kTimestamp);
if (name_str == "duration") return TypeSpec(WellKnownTypeSpec::kDuration);
if (name_str == "dyn") return TypeSpec(DynTypeSpec());

if (name_str == "bool_wrapper")
return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool));
if (name_str == "int_wrapper")
return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64));
if (name_str == "uint_wrapper")
return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64));
if (name_str == "double_wrapper")
return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble));
if (name_str == "string_wrapper")
return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString));
if (name_str == "bytes_wrapper")
return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes));

if (name_str == "list") {
if (params.size() > 1) {
return absl::InvalidArgumentError(
"Invalid type signature: list expects at most 1 parameter");
}
auto elem = std::make_unique<TypeSpec>(DynTypeSpec());
if (!params.empty()) {
*elem = std::move(params[0]);
}
return TypeSpec(ListTypeSpec(std::move(elem)));
}

if (name_str == "map") {
if (!params.empty() && params.size() != 2) {
return absl::InvalidArgumentError(
"Invalid type signature: map expects 0 or 2 parameters");
}
auto key = std::make_unique<TypeSpec>(DynTypeSpec());
auto value = std::make_unique<TypeSpec>(DynTypeSpec());
if (!params.empty()) {
*key = std::move(params[0]);
}
if (params.size() > 1) {
*value = std::move(params[1]);
}
return TypeSpec(MapTypeSpec(std::move(key), std::move(value)));
}

if (name_str == "function") {
auto result_type = std::make_unique<TypeSpec>(DynTypeSpec());
std::vector<TypeSpec> arg_types;
if (!params.empty()) {
result_type = std::make_unique<TypeSpec>(std::move(params[0]));
for (size_t i = 1; i < params.size(); ++i) {
arg_types.push_back(std::move(params[i]));
}
}
return TypeSpec(
FunctionTypeSpec(std::move(result_type), std::move(arg_types)));
}

return TypeSpec(AbstractType(name_str, std::move(params)));
}

} // namespace

absl::StatusOr<ParsedFunctionOverload> ParseFunctionSignature(
std::string_view signature) {
if (signature.empty()) {
return absl::InvalidArgumentError("Empty function signature");
}

CEL_ASSIGN_OR_RETURN(size_t paren_idx,
SignatureScanner(signature, "Invalid function signature")
.FindTopLevelChar('(', /*find_last=*/false));

if (paren_idx == std::string_view::npos || signature.back() != ')') {
return absl::InvalidArgumentError("Invalid function signature");
}

std::string_view prefix = signature.substr(0, paren_idx);
std::string_view args_str =
signature.substr(paren_idx + 1, signature.size() - paren_idx - 2);

std::vector<TypeSpec> arg_types;
ParsedFunctionOverload out;

CEL_ASSIGN_OR_RETURN(size_t dot_idx,
SignatureScanner(prefix, "Invalid function signature")
.FindTopLevelChar('.', /*find_last=*/true));

if (dot_idx != std::string_view::npos) {
out.is_member = true;
std::string_view receiver_str = prefix.substr(0, dot_idx);
std::string_view func_str = prefix.substr(dot_idx + 1);

CEL_ASSIGN_OR_RETURN(auto receiver_param, ParseTypeSignature(receiver_str));
arg_types.push_back(std::move(receiver_param));
out.function_name = Unescape(func_str);
} else {
out.is_member = false;
out.function_name = Unescape(prefix);
}

if (out.function_name.empty()) {
return absl::InvalidArgumentError(
"Invalid function signature: empty function name");
}

if (!args_str.empty()) {
CEL_ASSIGN_OR_RETURN(auto arg_list, SplitTypeList(args_str));
for (std::string_view arg_str : arg_list) {
CEL_ASSIGN_OR_RETURN(auto arg_param, ParseTypeSignature(arg_str));
arg_types.push_back(std::move(arg_param));
}
}

auto result_type = std::make_unique<TypeSpec>(DynTypeSpec());
out.signature_type =
TypeSpec(FunctionTypeSpec(std::move(result_type), std::move(arg_types)));

return out;
}

absl::StatusOr<Type> ParseType(std::string_view signature, google::protobuf::Arena* arena,
const google::protobuf::DescriptorPool& pool) {
CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSignature(signature));
return cel::ConvertTypeSpecToType(type_spec, arena, pool);
}

} // namespace cel::common_internal
Loading