diff --git a/.agents/docs-and-formatting.md b/.agents/docs-and-formatting.md index 4a7ebdaa35..acfa09fede 100644 --- a/.agents/docs-and-formatting.md +++ b/.agents/docs-and-formatting.md @@ -37,6 +37,10 @@ Load this file when changing documentation, public APIs, protocol specs, benchma - Python code, including `compiler/`, `benchmarks/`, `integration_tests/`, and `python/`: `python -m ruff format ` and `python -m ruff check --fix ` +- JavaScript/TypeScript under `javascript/`: use the package's ESLint-owned formatting path + (`npm run lint -- --fix` when fixing style, `npm run lint -- --quiet` when checking). Do not run + Prettier on JavaScript or TypeScript files unless that package has an explicit Prettier config or + script; otherwise it creates unrelated formatting churn. - Repo-wide format and lint sweep: `bash ci/format.sh --all` When code changes touch `compiler/` or `benchmarks/`, format those changed source files with the diff --git a/.gitignore b/.gitignore index daaab45bd3..3f202ac650 100644 --- a/.gitignore +++ b/.gitignore @@ -128,6 +128,8 @@ csharp/artifacts/ tasks/ benchmarks/python/proto/ benchmarks/java/dependency-reduced-pom.xml +benchmarks/go/go.test +benchmarks/rust/Cargo.lock docs/superpowers test.md @@ -136,4 +138,4 @@ benchmarks/dart/profile_output integration_tests/idl_tests/dart/.dart_tool/ **/pubspec.lock -**/tmp/* \ No newline at end of file +**/tmp/* diff --git a/AGENTS.md b/AGENTS.md index 15d9e2e69a..33ae3be7e5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -39,6 +39,17 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Keep task boundaries strict. Review tasks do not edit code, analysis-only tasks do not silently turn into implementation, and active-branch fixes must land in the active branch/workspace. 
- For non-trivial multi-step tasks, write the plan and progress into the canonical durable task file (use a matched skill/workflow file if it provides one, otherwise use a file under `tasks/`) and read that file after compaction before continuing. +## Design Integrity Gates + +- Record all core design and decisions in the owning docs when they belong there, especially under `docs/guide/**` or `docs/specification/**`. +- Do not allow implementation drift from the design document. +- Do not compromise design decisions to make implementation easier. +- Do not leave workaround code behind. +- All code must have a clean owner model; the wrong owner model or abstraction is unacceptable. +- Do not leave ugly or temporary code behind. +- Do not leave legacy, dead, useless, or stale code, tests, or docs behind. +- Do not leave avoidable technical debt behind. + ## Repo-Wide Hard Rules - Do not preserve legacy, dead, or useless code, tests, or docs unless the user explicitly requests it. diff --git a/benchmarks/javascript/benchmark.js b/benchmarks/javascript/benchmark.js index 3ff8078c46..b2723f66f4 100644 --- a/benchmarks/javascript/benchmark.js +++ b/benchmarks/javascript/benchmark.js @@ -29,7 +29,7 @@ const core = require(path.join(JS_ROOT, "packages", "core", "dist", "index.js")) const protobuf = require(path.join(JS_ROOT, "node_modules", "protobufjs")); const Fory = core.default; -const { Type } = core; +const { BoolArray, Type } = core; const DEFAULT_DURATION_SECONDS = 3; const SERIALIZER_ORDER = ["fory", "protobuf", "json"]; @@ -664,6 +664,26 @@ function normalizeForyValue(datasetKey, value) { } } +function normalizeForyRoundTripValue(datasetKey, value) { + switch (datasetKey) { + case "sample": + return { + ...value, + boolean_array: value.boolean_array instanceof BoolArray + ? 
Array.from(value.boolean_array) + : value.boolean_array, + }; + case "samplelist": + return { + sample_list: value.sample_list.map((item) => + normalizeForyRoundTripValue("sample", item) + ), + }; + default: + return value; + } +} + function normalizeProtobufValue(datasetKey, value) { switch (datasetKey) { case "sample": @@ -687,7 +707,10 @@ function ensureSerializationWorks(dataset) { const foryValue = normalizeForyValue(dataset.key, value); const foryBytes = dataset.forySerializer.serialize(foryValue); const foryRoundTrip = dataset.forySerializer.deserialize(foryBytes); - assert.deepStrictEqual(foryRoundTrip, foryValue); + assert.deepStrictEqual( + normalizeForyRoundTripValue(dataset.key, foryRoundTrip), + foryValue + ); const protoPayload = dataset.toProto(value); const protoBytes = dataset.protoType.encode(dataset.protoType.create(protoPayload)).finish(); diff --git a/benchmarks/swift/Sources/SwiftBenchmark/BenchmarkRunner.swift b/benchmarks/swift/Sources/SwiftBenchmark/BenchmarkRunner.swift index 5f6d09a6ca..b1530925cb 100644 --- a/benchmarks/swift/Sources/SwiftBenchmark/BenchmarkRunner.swift +++ b/benchmarks/swift/Sources/SwiftBenchmark/BenchmarkRunner.swift @@ -68,7 +68,7 @@ final class BenchmarkSuite { init(config: BenchmarkConfig) { self.config = config - self.fory = Fory(xlang: false, trackRef: false, compatible: true) + self.fory = Fory(xlang: false, ref: false, compatible: true) registerTypes() } diff --git a/cpp/fory/serialization/struct_compatible_test.cc b/cpp/fory/serialization/struct_compatible_test.cc index 5777447cf5..b688ba69f8 100644 --- a/cpp/fory/serialization/struct_compatible_test.cc +++ b/cpp/fory/serialization/struct_compatible_test.cc @@ -33,7 +33,9 @@ #include "fory/serialization/fory.h" #include "gtest/gtest.h" +#include #include +#include #include #include @@ -217,6 +219,44 @@ struct ProductV2 { FORY_STRUCT(ProductV2, name, price, tags, attributes); }; +struct CompatibleListField { + std::vector values; + + 
FORY_STRUCT(CompatibleListField, + (values, fory::F(1).list(fory::T::int32().fixed()))); +}; + +struct CompatibleArrayField { + std::vector values; + + FORY_STRUCT(CompatibleArrayField, + (values, fory::F(1).array(fory::T::int32()))); +}; + +struct CompatibleNullableListField { + std::vector> values; + + FORY_STRUCT(CompatibleNullableListField, + (values, fory::F(1).list(fory::T::fixed()))); +}; + +struct CompatibleNestedListField { + std::map> values; + + FORY_STRUCT(CompatibleNestedListField, + (values, + fory::F(1).map(fory::T::string(), + fory::T::list(fory::T::int32().fixed())))); +}; + +struct CompatibleNestedArrayField { + std::map> values; + + FORY_STRUCT(CompatibleNestedArrayField, + (values, fory::F(1).map(fory::T::string(), + fory::T::array(fory::T::int32())))); +}; + // ============================================================================ // TESTS // ============================================================================ @@ -417,6 +457,85 @@ TEST(SchemaEvolutionTest, CollectionFieldEvolution) { EXPECT_TRUE(prod_v2.attributes.empty()); // Default empty map } +TEST(SchemaEvolutionTest, ImmediateListFieldCanReadIntoArrayCarrier) { + auto writer = Fory::builder().compatible(true).xlang(true).build(); + auto reader = Fory::builder().compatible(true).xlang(true).build(); + + constexpr uint32_t TYPE_ID = 1005; + ASSERT_TRUE(writer.register_struct(TYPE_ID).ok()); + ASSERT_TRUE(reader.register_struct(TYPE_ID).ok()); + + auto bytes = writer.serialize(CompatibleListField{{1, -2, 3}}); + ASSERT_TRUE(bytes.ok()) << bytes.error().to_string(); + auto decoded = reader.deserialize(bytes.value().data(), + bytes.value().size()); + + ASSERT_TRUE(decoded.ok()) << decoded.error().to_string(); + EXPECT_EQ(decoded.value().values, (std::vector{1, -2, 3})); +} + +TEST(SchemaEvolutionTest, ImmediateArrayFieldCanReadIntoListCarrier) { + auto writer = Fory::builder().compatible(true).xlang(true).build(); + auto reader = 
Fory::builder().compatible(true).xlang(true).build(); + + constexpr uint32_t TYPE_ID = 1006; + ASSERT_TRUE(writer.register_struct(TYPE_ID).ok()); + ASSERT_TRUE(reader.register_struct(TYPE_ID).ok()); + + auto bytes = writer.serialize(CompatibleArrayField{{4, 5, 6}}); + ASSERT_TRUE(bytes.ok()) << bytes.error().to_string(); + auto decoded = reader.deserialize(bytes.value().data(), + bytes.value().size()); + + ASSERT_TRUE(decoded.ok()) << decoded.error().to_string(); + EXPECT_EQ(decoded.value().values, (std::vector{4, 5, 6})); +} + +TEST(SchemaEvolutionTest, NullableListElementsCannotReadIntoArrayCarrier) { + auto writer = Fory::builder().compatible(true).xlang(true).build(); + auto reader = Fory::builder().compatible(true).xlang(true).build(); + + constexpr uint32_t TYPE_ID = 1007; + ASSERT_TRUE( + writer.register_struct(TYPE_ID).ok()); + ASSERT_TRUE(reader.register_struct(TYPE_ID).ok()); + + auto bytes = writer.serialize(CompatibleNullableListField{{1, 2}}); + ASSERT_TRUE(bytes.ok()) << bytes.error().to_string(); + std::vector payload = std::move(bytes).value(); + auto decoded = + reader.deserialize(payload.data(), payload.size()); + + ASSERT_TRUE(decoded.ok()) << decoded.error().to_string(); + EXPECT_EQ(decoded.value().values, (std::vector{1, 2})); + + bytes = writer.serialize(CompatibleNullableListField{{1, std::nullopt}}); + ASSERT_TRUE(bytes.ok()) << bytes.error().to_string(); + payload = std::move(bytes).value(); + decoded = + reader.deserialize(payload.data(), payload.size()); + + ASSERT_FALSE(decoded.ok()); +} + +TEST(SchemaEvolutionTest, NestedListArraySchemaPairsAreNotMatched) { + auto writer = Fory::builder().compatible(true).xlang(true).build(); + auto reader = Fory::builder().compatible(true).xlang(true).build(); + + constexpr uint32_t TYPE_ID = 1008; + ASSERT_TRUE(writer.register_struct(TYPE_ID).ok()); + ASSERT_TRUE(reader.register_struct(TYPE_ID).ok()); + + auto bytes = writer.serialize( + CompatibleNestedListField{{{"items", std::vector{7, 8}}}}); + 
ASSERT_TRUE(bytes.ok()) << bytes.error().to_string(); + auto decoded = reader.deserialize( + bytes.value().data(), bytes.value().size()); + + ASSERT_TRUE(decoded.ok()) << decoded.error().to_string(); + EXPECT_TRUE(decoded.value().values.empty()); +} + TEST(SchemaEvolutionTest, RoundtripWithSameVersion) { // Sanity check: V2 -> V2 should work perfectly auto fory_compat = Fory::builder().compatible(true).xlang(true).build(); diff --git a/cpp/fory/serialization/struct_serializer.h b/cpp/fory/serialization/struct_serializer.h index 7b437acd64..679227e0db 100644 --- a/cpp/fory/serialization/struct_serializer.h +++ b/cpp/fory/serialization/struct_serializer.h @@ -97,6 +97,30 @@ inline constexpr bool is_primitive_type_id(TypeId type_id) { type_id == TypeId::TAGGED_UINT64; } +/// Type trait to check if a type is a raw primitive (not a wrapper like +/// optional, shared_ptr, etc.) +template struct is_raw_primitive : std::false_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; +template +inline constexpr bool is_raw_primitive_v = is_raw_primitive::value; + +template +FORY_ALWAYS_INLINE TargetType read_primitive_by_type_id(ReadContext &ctx, + uint32_t type_id, + Error &error); + /// write a primitive value to buffer at given offset WITHOUT updating /// writer_index. 
Returns the number of bytes written. Caller must ensure buffer /// has sufficient capacity. @@ -863,12 +887,12 @@ Container read_configured_list_data(ReadContext &ctx) { using Elem = element_type_t; uint32_t length = ctx.read_var_uint32(ctx.error()); Container result; - if constexpr (has_reserve_v) { - result.reserve(length); - } if (FORY_PREDICT_FALSE(ctx.has_error()) || length == 0) { return result; } + if constexpr (has_reserve_v) { + result.reserve(length); + } uint8_t bitmap = ctx.read_uint8(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -899,6 +923,73 @@ Container read_configured_list_data(ReadContext &ctx) { return result; } +template +FORY_NOINLINE Container read_configured_list_data_as_array_field( + ReadContext &ctx, uint32_t remote_element_type_id) { + using Elem = element_type_t; + uint32_t length = ctx.read_var_uint32(ctx.error()); + Container result; + if (FORY_PREDICT_FALSE(ctx.has_error()) || length == 0) { + return result; + } + if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) { + ctx.set_error( + Error::invalid_data("Collection length exceeds max_collection_size")); + return result; + } + uint8_t bitmap = ctx.read_uint8(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return result; + } + const bool is_decl_type = (bitmap & COLL_DECL_ELEMENT_TYPE) != 0; + const bool is_same_type = (bitmap & COLL_IS_SAME_TYPE) != 0; + const bool track_ref = (bitmap & COLL_TRACKING_REF) != 0; + const bool has_null = (bitmap & COLL_HAS_NULL) != 0; + if (FORY_PREDICT_FALSE(track_ref || has_null)) { + ctx.set_error(Error::invalid_data( + "compatible list to array field requires non-null elements")); + return result; + } + if (FORY_PREDICT_FALSE(!is_same_type)) { + ctx.set_error(Error::invalid_data( + "compatible list to array field requires same-type elements")); + return result; + } + if (FORY_PREDICT_FALSE(!is_decl_type)) { + ctx.set_error(Error::invalid_data( + "compatible list to array field requires 
declared elements")); + return result; + } + if constexpr (has_reserve_v) { + result.reserve(length); + } + for (uint32_t i = 0; i < length; ++i) { + if constexpr (is_raw_primitive_v) { + auto elem = read_primitive_by_type_id(ctx, remote_element_type_id, + ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return result; + } + collection_insert(result, elem); + } else { + ctx.set_error(Error::type_error( + "compatible list to array field requires primitive elements")); + return result; + } + } + return result; +} + +template +FORY_NOINLINE Container +read_configured_array_data_as_list_field(ReadContext &ctx, RefMode ref_mode) { + if (ref_mode == RefMode::None) { + return Serializer::read_data(ctx); + } + return Serializer::read(ctx, ref_mode, false); +} + template void write_configured_map_data(const MapType &map, WriteContext &ctx) { @@ -2728,26 +2819,6 @@ void write_struct_fields_impl(const T &obj, WriteContext &ctx, (write_field_at_sorted_position(obj, ctx, has_generics), ...); } } - -/// Type trait to check if a type is a raw primitive (not a wrapper like -/// optional, shared_ptr, etc.) 
-template struct is_raw_primitive : std::false_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template <> struct is_raw_primitive : std::true_type {}; -template -inline constexpr bool is_raw_primitive_v = is_raw_primitive::value; - /// Read a primitive value based on remote type_id (for compatible mode). /// Returns the value as a uint64_t (or int64_t for signed types). /// The caller must convert to the correct local type. 
@@ -2802,6 +2873,8 @@ FORY_ALWAYS_INLINE TargetType read_primitive_by_type_id(ReadContext &ctx, return static_cast(ctx.read_tagged_uint64(error)); case TypeId::FLOAT16: return static_cast(ctx.read_f16(error).to_float()); + case TypeId::BFLOAT16: + return static_cast(ctx.read_bf16(error).to_float()); case TypeId::FLOAT32: return static_cast(ctx.read_float(error)); case TypeId::FLOAT64: @@ -3205,6 +3278,45 @@ void read_single_field_by_index_compatible(T &obj, ReadContext &ctx, } } + if constexpr (is_vector_v) { + constexpr FieldNodeKind configured_kind = + configured_node_kind(); + constexpr bool configured_as_array = + configured_kind == FieldNodeKind::Array; + constexpr bool configured_as_list = configured_kind == FieldNodeKind::List; + if constexpr (configured_as_array) { + if (remote_type_id == static_cast(TypeId::LIST)) { + if (FORY_PREDICT_FALSE(remote_field_type.generics.size() != 1)) { + ctx.set_error(Error::invalid_data( + "compatible list to array field requires one element schema")); + return; + } + const auto &remote_element_type = remote_field_type.generics[0]; + constexpr int8_t child = configured_node_child(); + FieldType result = read_configured_list_data_as_array_field< + FieldType, T, Index, 0, child>(ctx, remote_element_type.type_id); + if constexpr (is_fory_field_v) { + (obj.*field_ptr).value = std::move(result); + } else { + obj.*field_ptr = std::move(result); + } + return; + } + } else if constexpr (configured_as_list) { + uint32_t element_type_id = 0; + if (primitive_array_element_type_id(remote_type_id, element_type_id)) { + FieldType result = read_configured_array_data_as_list_field( + ctx, remote_ref_mode); + if constexpr (is_fory_field_v) { + (obj.*field_ptr).value = std::move(result); + } else { + obj.*field_ptr = std::move(result); + } + return; + } + } + } + if constexpr (configured_node_has_override()) { FieldType result = read_configured_value( ctx, remote_ref_mode, read_type); diff --git a/cpp/fory/serialization/type_resolver.cc 
b/cpp/fory/serialization/type_resolver.cc index b5ffb5917e..039e7e8264 100644 --- a/cpp/fory/serialization/type_resolver.cc +++ b/cpp/fory/serialization/type_resolver.cc @@ -1104,8 +1104,9 @@ void TypeMeta::assign_field_ids(const TypeMeta *local_type, if (local_uses_field_ids && remote_field.field_id >= 0) { auto id_it = local_field_id_map.find(remote_field.field_id); if (id_it != local_field_id_map.end() && - field_types_compatible(local_fields[id_it->second].field_type, - remote_field.field_type)) { + field_types_compatible_top_level( + local_fields[id_it->second].field_type, + remote_field.field_type)) { local_index = id_it->second; } } @@ -1132,8 +1133,8 @@ void TypeMeta::assign_field_ids(const TypeMeta *local_type, if (it != local_field_index_map.end()) { size_t idx = it->second; const FieldInfo &local_field = local_fields[idx]; - if (field_types_compatible(local_field.field_type, - remote_field.field_type)) { + if (field_types_compatible_top_level(local_field.field_type, + remote_field.field_type)) { local_index = idx; } } @@ -1146,8 +1147,8 @@ void TypeMeta::assign_field_ids(const TypeMeta *local_type, if (used[i]) { continue; } - if (field_types_compatible(local_fields[i].field_type, - remote_field.field_type)) { + if (field_types_compatible_top_level(local_fields[i].field_type, + remote_field.field_type)) { local_index = i; break; } diff --git a/cpp/fory/serialization/type_resolver.h b/cpp/fory/serialization/type_resolver.h index cf42126324..48f73445cf 100644 --- a/cpp/fory/serialization/type_resolver.h +++ b/cpp/fory/serialization/type_resolver.h @@ -246,6 +246,77 @@ inline bool field_types_compatible(const FieldType &local, (local.generics.empty() || remote.generics.empty()); } +inline bool primitive_array_element_type_id(uint32_t array_type_id, + uint32_t &element_type_id) { + switch (static_cast(array_type_id)) { + case TypeId::BOOL_ARRAY: + element_type_id = static_cast(TypeId::BOOL); + return true; + case TypeId::INT8_ARRAY: + element_type_id = 
static_cast(TypeId::INT8); + return true; + case TypeId::INT16_ARRAY: + element_type_id = static_cast(TypeId::INT16); + return true; + case TypeId::INT32_ARRAY: + element_type_id = static_cast(TypeId::VARINT32); + return true; + case TypeId::INT64_ARRAY: + element_type_id = static_cast(TypeId::VARINT64); + return true; + case TypeId::FLOAT16_ARRAY: + element_type_id = static_cast(TypeId::FLOAT16); + return true; + case TypeId::FLOAT32_ARRAY: + element_type_id = static_cast(TypeId::FLOAT32); + return true; + case TypeId::FLOAT64_ARRAY: + element_type_id = static_cast(TypeId::FLOAT64); + return true; + case TypeId::UINT8_ARRAY: + element_type_id = static_cast(TypeId::UINT8); + return true; + case TypeId::UINT16_ARRAY: + element_type_id = static_cast(TypeId::UINT16); + return true; + case TypeId::UINT32_ARRAY: + element_type_id = static_cast(TypeId::UINT32); + return true; + case TypeId::UINT64_ARRAY: + element_type_id = static_cast(TypeId::UINT64); + return true; + case TypeId::BFLOAT16_ARRAY: + element_type_id = static_cast(TypeId::BFLOAT16); + return true; + default: + return false; + } +} + +inline bool field_types_compatible_top_level(const FieldType &local, + const FieldType &remote) { + if (field_types_compatible(local, remote)) { + return true; + } + + uint32_t array_element_type_id = 0; + if (local.type_id == static_cast(TypeId::LIST) && + remote.generics.size() == 0 && + primitive_array_element_type_id(remote.type_id, array_element_type_id) && + local.generics.size() == 1) { + return compatible_fingerprint_type_id(local.generics[0].type_id) == + compatible_fingerprint_type_id(array_element_type_id); + } + if (remote.type_id == static_cast(TypeId::LIST) && + local.generics.size() == 0 && + primitive_array_element_type_id(local.type_id, array_element_type_id) && + remote.generics.size() == 1) { + return compatible_fingerprint_type_id(remote.generics[0].type_id) == + compatible_fingerprint_type_id(array_element_type_id); + } + return false; +} + // 
============================================================================ // FieldInfo - Field metadata (name, type, id) // ============================================================================ diff --git a/cpp/fory/serialization/xlang_test_main.cc b/cpp/fory/serialization/xlang_test_main.cc index 99933ce096..93d78f1760 100644 --- a/cpp/fory/serialization/xlang_test_main.cc +++ b/cpp/fory/serialization/xlang_test_main.cc @@ -308,6 +308,18 @@ struct ReducedPrecisionFloatStruct { (bfloat16_array, fory::F().list(fory::T::bfloat16()))); }; +struct CompatibleInt32ListField { + std::vector values; + FORY_STRUCT(CompatibleInt32ListField, + (values, fory::F(1).list(fory::T::int32().fixed()))); +}; + +struct CompatibleInt32ArrayField { + std::vector values; + FORY_STRUCT(CompatibleInt32ArrayField, + (values, fory::F(1).array(fory::T::int32()))); +}; + enum class TestEnum : int32_t { VALUE_A = 0, VALUE_B = 1, VALUE_C = 2 }; FORY_ENUM(TestEnum, VALUE_A, VALUE_B, VALUE_C); @@ -951,6 +963,10 @@ void run_test_schema_evolution_compatible_reverse(const std::string &data_file); void run_test_reduced_precision_float_struct(const std::string &data_file); void run_test_reduced_precision_float_struct_compatible_skip( const std::string &data_file); +void run_test_list_array_compatible_list_to_array(const std::string &data_file); +void run_test_list_array_compatible_array_to_list(const std::string &data_file); +void run_test_list_array_compatible_nullable_list_to_array_error( + const std::string &data_file); void run_test_one_enum_field_schema(const std::string &data_file); void run_test_one_enum_field_compatible(const std::string &data_file); void run_test_two_enum_field_compatible(const std::string &data_file); @@ -1060,6 +1076,13 @@ int main(int argc, char **argv) { } else if (case_name == "test_reduced_precision_float_struct_compatible_skip") { run_test_reduced_precision_float_struct_compatible_skip(data_file); + } else if (case_name == 
"test_list_array_compatible_list_to_array") { + run_test_list_array_compatible_list_to_array(data_file); + } else if (case_name == "test_list_array_compatible_array_to_list") { + run_test_list_array_compatible_array_to_list(data_file); + } else if (case_name == + "test_list_array_compatible_nullable_list_to_array_error") { + run_test_list_array_compatible_nullable_list_to_array_error(data_file); } else if (case_name == "test_one_enum_field_schema") { run_test_one_enum_field_schema(data_file); } else if (case_name == "test_one_enum_field_compatible") { @@ -2285,6 +2308,51 @@ void run_test_reduced_precision_float_struct_compatible_skip( write_file(data_file, out); } +void run_test_list_array_compatible_list_to_array( + const std::string &data_file) { + auto bytes = read_file(data_file); + auto fory = build_fory(true, true); + ensure_ok(fory.register_struct(901), + "register CompatibleInt32ArrayField"); + + Buffer buffer = make_buffer(bytes); + auto value = read_next(fory, buffer); + + std::vector out; + append_serialized(fory, value, out); + write_file(data_file, out); +} + +void run_test_list_array_compatible_array_to_list( + const std::string &data_file) { + auto bytes = read_file(data_file); + auto fory = build_fory(true, true); + ensure_ok(fory.register_struct(901), + "register CompatibleInt32ListField"); + + Buffer buffer = make_buffer(bytes); + auto value = read_next(fory, buffer); + + std::vector out; + append_serialized(fory, value, out); + write_file(data_file, out); +} + +void run_test_list_array_compatible_nullable_list_to_array_error( + const std::string &data_file) { + auto bytes = read_file(data_file); + auto fory = build_fory(true, true); + ensure_ok(fory.register_struct(901), + "register CompatibleInt32ArrayField nullable-list error"); + + Buffer buffer = make_buffer(bytes); + auto result = fory.deserialize(buffer); + if (result.ok()) { + fail("Expected nullable list payload to fail compatible array read"); + } + write_file(data_file, bytes); +} + // 
============================================================================ // Schema Evolution Tests - Enum Fields // ============================================================================ diff --git a/csharp/src/Fory.Generator/ForyObjectGenerator.cs b/csharp/src/Fory.Generator/ForyObjectGenerator.cs index fb30ed410e..242e277a14 100644 --- a/csharp/src/Fory.Generator/ForyObjectGenerator.cs +++ b/csharp/src/Fory.Generator/ForyObjectGenerator.cs @@ -177,6 +177,8 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) } } + EmitCompatibleFieldCodecMethods(sb, model); + sb.AppendLine(" private static bool __ForyCanReadCompatiblePrimitive(global::Apache.Fory.TypeId typeId)"); sb.AppendLine(" {"); sb.AppendLine(" return typeId switch"); @@ -689,6 +691,194 @@ private static void EmitFieldCodecMethods(StringBuilder sb, MemberModel member) sb.AppendLine(); } + private static void EmitCompatibleFieldCodecMethods(StringBuilder sb, TypeModel model) + { + bool hasCompatibleField = false; + foreach (MemberModel member in model.SortedMembers) + { + if (member.FieldCodec is not null && + TryBuildCompatibleListArrayReadCodec(member.FieldCodec, out _)) + { + hasCompatibleField = true; + break; + } + } + + if (!hasCompatibleField) + { + return; + } + + sb.AppendLine(" private static class __ForyCompatibleFieldReaders"); + sb.AppendLine(" {"); + foreach (MemberModel member in model.SortedMembers) + { + if (member.FieldCodec is not null && + TryBuildCompatibleListArrayReadCodec(member.FieldCodec, out FieldCodecModel? 
alternateCodec)) + { + EmitCompatibleFieldCodecMethod(sb, member, member.FieldCodec, alternateCodec); + } + } + + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void EmitCompatibleFieldCodecMethod( + StringBuilder sb, + MemberModel member, + FieldCodecModel codec, + FieldCodecModel alternateCodec) + { + string memberId = Sanitize(member.Name); + sb.AppendLine(" [global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.NoInlining)]"); + sb.AppendLine( + $" internal static {member.TypeName} Read{memberId}ListArrayBridge(global::Apache.Fory.ReadContext context, global::Apache.Fory.TypeMetaFieldType remoteFieldType, global::Apache.Fory.RefMode refMode)"); + sb.AppendLine(" {"); + sb.AppendLine(" if (remoteFieldType.TypeId == " + alternateCodec.TypeId + ")"); + sb.AppendLine(" {"); + if (codec.Kind == FieldCodecKind.PackedArray) + { + sb.AppendLine(" if (remoteFieldType.Generics.Count != 1)"); + sb.AppendLine(" {"); + sb.AppendLine(" throw new global::Apache.Fory.InvalidDataException(\"compatible list to array field requires one element schema\");"); + sb.AppendLine(" }"); + } + + sb.AppendLine(" if (refMode == global::Apache.Fory.RefMode.NullOnly)"); + sb.AppendLine(" {"); + sb.AppendLine(" sbyte refFlag = context.Reader.ReadInt8();"); + sb.AppendLine(" if (refFlag == (sbyte)global::Apache.Fory.RefFlag.Null)"); + sb.AppendLine(" {"); + sb.AppendLine($" return ({member.TypeName})default!;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" if (refFlag != (sbyte)global::Apache.Fory.RefFlag.NotNullValue)"); + sb.AppendLine(" {"); + sb.AppendLine(" throw new global::Apache.Fory.InvalidDataException($\"invalid nullOnly ref flag {refFlag}\");"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + int id = 0; + string compatibleResultVar = $"__{memberId}CompatibleValue"; + if (codec.Kind == FieldCodecKind.PackedArray && alternateCodec.Kind == FieldCodecKind.List) + { + 
EmitReadCompatibleListArrayPayload(sb, codec, compatibleResultVar, 4, ref id); + } + else + { + EmitReadPayload(sb, alternateCodec, compatibleResultVar, 4, ref id); + } + + sb.AppendLine($" return {compatibleResultVar};"); + sb.AppendLine(" }"); + sb.AppendLine(" throw new global::Apache.Fory.InvalidDataException($\"unsupported compatible field schema pair: local " + codec.TypeId + ", remote {remoteFieldType.TypeId}\");"); + sb.AppendLine(" }"); + } + + private static bool TryBuildCompatibleListArrayReadCodec(FieldCodecModel codec, out FieldCodecModel compatibleCodec) + { + if (codec.Kind == FieldCodecKind.PackedArray) + { + uint elementTypeId = PackedArrayElementTypeId(codec.TypeId); + compatibleCodec = new FieldCodecModel( + FieldCodecKind.List, + 22, + codec.TypeName, + codec.Nullable, + codec.NullableValueType, + codec.CarrierKind, + ImmutableArray.Create(new FieldCodecModel( + FieldCodecKind.Scalar, + elementTypeId, + PackedArrayElementTypeName(codec.TypeId), + false, + false, + CarrierKind.Value, + ImmutableArray.Empty))); + return true; + } + + if (codec.Kind == FieldCodecKind.List && + codec.Generics.Length == 1 && + TryResolveArrayTypeIdForElement(codec.Generics[0].TypeId) is uint arrayTypeId) + { + compatibleCodec = new FieldCodecModel( + FieldCodecKind.PackedArray, + arrayTypeId, + codec.TypeName, + codec.Nullable, + codec.NullableValueType, + codec.CarrierKind, + ImmutableArray.Empty); + return true; + } + + compatibleCodec = codec; + return false; + } + + private static void EmitReadCompatibleListArrayPayload( + StringBuilder sb, + FieldCodecModel codec, + string targetVar, + int indentLevel, + ref int id) + { + string indent = new(' ', indentLevel * 4); + string lengthVar = $"__foryLength{id++}"; + string headerVar = $"__foryHeader{id++}"; + string declaredVar = $"__foryDeclared{id++}"; + string sameTypeVar = $"__forySameType{id++}"; + sb.AppendLine($"{indent}int {lengthVar} = checked((int)context.Reader.ReadVarUInt32());"); + 
sb.AppendLine($"{indent}if ({lengthVar} != 0)"); + sb.AppendLine($"{indent}{{"); + string innerIndent = indent + " "; + sb.AppendLine($"{innerIndent}byte {headerVar} = context.Reader.ReadUInt8();"); + sb.AppendLine($"{innerIndent}if (({headerVar} & 0b0000_0011) != 0)"); + sb.AppendLine($"{innerIndent}{{"); + sb.AppendLine($"{innerIndent} throw new global::Apache.Fory.InvalidDataException(\"compatible list to array field requires non-null elements\");"); + sb.AppendLine($"{innerIndent}}}"); + sb.AppendLine($"{innerIndent}bool {declaredVar} = ({headerVar} & 0b0000_0100) != 0;"); + sb.AppendLine($"{innerIndent}bool {sameTypeVar} = ({headerVar} & 0b0000_1000) != 0;"); + sb.AppendLine($"{innerIndent}if (!{sameTypeVar})"); + sb.AppendLine($"{innerIndent}{{"); + sb.AppendLine($"{innerIndent} throw new global::Apache.Fory.InvalidDataException(\"compatible list to array field requires same-type elements\");"); + sb.AppendLine($"{innerIndent}}}"); + sb.AppendLine($"{innerIndent}if (!{declaredVar})"); + sb.AppendLine($"{innerIndent}{{"); + sb.AppendLine($"{innerIndent} uint __foryWireTypeId = context.Reader.ReadUInt8();"); + sb.AppendLine($"{innerIndent} if (__foryWireTypeId != remoteFieldType.Generics[0].TypeId)"); + sb.AppendLine($"{innerIndent} {{"); + sb.AppendLine($"{innerIndent} throw new global::Apache.Fory.TypeMismatchException(remoteFieldType.Generics[0].TypeId, __foryWireTypeId);"); + sb.AppendLine($"{innerIndent} }}"); + sb.AppendLine($"{innerIndent}}}"); + sb.AppendLine($"{indent}}}"); + string indexVar = $"__foryIndex{id++}"; + string elementTypeName = codec.CarrierKind == CarrierKind.Array ? 
ElementTypeName(codec.TypeName) : PackedArrayElementTypeName(codec.TypeId); + if (codec.CarrierKind == CarrierKind.Array) + { + sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new {ElementTypeName(codec.TypeName)}[{lengthVar}];"); + } + else + { + sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new({lengthVar});"); + } + + sb.AppendLine($"{indent}for (int {indexVar} = 0; {indexVar} < {lengthVar}; {indexVar}++)"); + sb.AppendLine($"{indent}{{"); + sb.AppendLine($"{innerIndent}object __foryItem = __ForyReadCompatiblePrimitivePayload((global::Apache.Fory.TypeId)remoteFieldType.Generics[0].TypeId, context);"); + if (codec.CarrierKind == CarrierKind.Array) + { + sb.AppendLine($"{innerIndent}{targetVar}[{indexVar}] = ({elementTypeName})__foryItem;"); + } + else + { + sb.AppendLine($"{innerIndent}{targetVar}.Add(({elementTypeName})__foryItem);"); + } + + sb.AppendLine($"{indent}}}"); + } + private static void EmitWritePayload( StringBuilder sb, FieldCodecModel codec, @@ -1248,6 +1438,28 @@ private static string ElementTypeName(string arrayTypeName) : "object"; } + private static string PackedArrayElementTypeName(uint typeId) + { + return typeId switch + { + 41 => "byte", + 43 => "bool", + 44 => "sbyte", + 45 => "short", + 46 => "int", + 47 => "long", + 48 => "byte", + 49 => "ushort", + 50 => "uint", + 51 => "ulong", + 53 => "global::System.Half", + 54 => "global::Apache.Fory.BFloat16", + 55 => "float", + 56 => "double", + _ => throw new InvalidOperationException($"unsupported packed array type id {typeId}"), + }; + } + private static int PackedArrayElementWidth(uint typeId) { return typeId switch @@ -1420,8 +1632,26 @@ private static void EmitReadMemberAssignment( if (member.FieldCodec is not null) { - sb.AppendLine( - $"{indent}{assignmentTarget} = __ForyRead{Sanitize(member.Name)}Field(context, {refModeExpr});"); + if (variableSuffix == "Compat" && + TryBuildCompatibleListArrayReadCodec(member.FieldCodec, out _)) + { + sb.AppendLine($"{indent}if 
(remoteField.FieldType.TypeId == {member.FieldCodec.TypeId})"); + sb.AppendLine($"{indent}{{"); + sb.AppendLine( + $"{indent} {assignmentTarget} = __ForyRead{Sanitize(member.Name)}Field(context, {refModeExpr});"); + sb.AppendLine($"{indent}}}"); + sb.AppendLine($"{indent}else"); + sb.AppendLine($"{indent}{{"); + sb.AppendLine( + $"{indent} {assignmentTarget} = __ForyCompatibleFieldReaders.Read{Sanitize(member.Name)}ListArrayBridge(context, remoteField.FieldType, {refModeExpr});"); + sb.AppendLine($"{indent}}}"); + } + else + { + sb.AppendLine( + $"{indent}{assignmentTarget} = __ForyRead{Sanitize(member.Name)}Field(context, {refModeExpr});"); + } + return; } @@ -2753,12 +2983,12 @@ private static bool TryResolveSchemaTypeId(string fullName, out uint typeId, out 1 => 43, 2 => 44, 3 => 45, - 5 => 46, - 7 => 47, + 4 or 5 => 46, + 6 or 7 or 8 => 47, 9 => 48, 10 => 49, - 12 => 50, - 14 => 51, + 11 or 12 => 50, + 13 or 14 or 15 => 51, 17 => 53, 18 => 54, 19 => 55, diff --git a/csharp/src/Fory/TypeMeta.cs b/csharp/src/Fory/TypeMeta.cs index 9073b5f1f4..4312a5af86 100644 --- a/csharp/src/Fory/TypeMeta.cs +++ b/csharp/src/Fory/TypeMeta.cs @@ -754,7 +754,7 @@ public static void AssignFieldIds( if (localIndex >= 0 && localMatch is not null && - IsCompatibleFieldType(remoteField.FieldType, localMatch.FieldType)) + IsCompatibleFieldType(remoteField.FieldType, localMatch.FieldType, topLevel: true)) { remoteField.AssignedFieldId = localIndex; } @@ -765,8 +765,13 @@ localMatch is not null && } } - private static bool IsCompatibleFieldType(TypeMetaFieldType remote, TypeMetaFieldType local) + private static bool IsCompatibleFieldType(TypeMetaFieldType remote, TypeMetaFieldType local, bool topLevel) { + if (topLevel && IsCompatibleListArrayFieldPair(remote, local)) + { + return true; + } + if (NormalizeTypeIdForMatch(remote.TypeId) != NormalizeTypeIdForMatch(local.TypeId)) { return false; @@ -779,7 +784,7 @@ private static bool IsCompatibleFieldType(TypeMetaFieldType remote, 
TypeMetaFiel for (int i = 0; i < remote.Generics.Count; i++) { - if (!IsCompatibleFieldType(remote.Generics[i], local.Generics[i])) + if (!IsCompatibleFieldType(remote.Generics[i], local.Generics[i], topLevel: false)) { return false; } @@ -788,6 +793,66 @@ private static bool IsCompatibleFieldType(TypeMetaFieldType remote, TypeMetaFiel return true; } + private static bool IsCompatibleListArrayFieldPair(TypeMetaFieldType remote, TypeMetaFieldType local) + { + uint? localArrayElementTypeId = TryPackedArrayElementTypeId(local.TypeId); + uint? remoteArrayElementTypeId = TryPackedArrayElementTypeId(remote.TypeId); + bool remoteListLocalArray = remote.TypeId == (uint)global::Apache.Fory.TypeId.List && + localArrayElementTypeId.HasValue && + remote.Generics.Count == 1 && + CompatibleScalarTypeId(localArrayElementTypeId.Value) == + CompatibleScalarTypeId(remote.Generics[0].TypeId); + if (remoteListLocalArray) + { + return true; + } + + return local.TypeId == (uint)global::Apache.Fory.TypeId.List && + remoteArrayElementTypeId.HasValue && + local.Generics.Count == 1 && + CompatibleScalarTypeId(remoteArrayElementTypeId.Value) == + CompatibleScalarTypeId(local.Generics[0].TypeId); + } + + private static uint? 
TryPackedArrayElementTypeId(uint typeId) + { + return typeId switch + { + (uint)global::Apache.Fory.TypeId.BoolArray => (uint)global::Apache.Fory.TypeId.Bool, + (uint)global::Apache.Fory.TypeId.Int8Array => (uint)global::Apache.Fory.TypeId.Int8, + (uint)global::Apache.Fory.TypeId.Int16Array => (uint)global::Apache.Fory.TypeId.Int16, + (uint)global::Apache.Fory.TypeId.Int32Array => (uint)global::Apache.Fory.TypeId.VarInt32, + (uint)global::Apache.Fory.TypeId.Int64Array => (uint)global::Apache.Fory.TypeId.VarInt64, + (uint)global::Apache.Fory.TypeId.Float16Array => (uint)global::Apache.Fory.TypeId.Float16, + (uint)global::Apache.Fory.TypeId.Float32Array => (uint)global::Apache.Fory.TypeId.Float32, + (uint)global::Apache.Fory.TypeId.Float64Array => (uint)global::Apache.Fory.TypeId.Float64, + (uint)global::Apache.Fory.TypeId.UInt8Array => (uint)global::Apache.Fory.TypeId.UInt8, + (uint)global::Apache.Fory.TypeId.UInt16Array => (uint)global::Apache.Fory.TypeId.UInt16, + (uint)global::Apache.Fory.TypeId.UInt32Array => (uint)global::Apache.Fory.TypeId.UInt32, + (uint)global::Apache.Fory.TypeId.UInt64Array => (uint)global::Apache.Fory.TypeId.UInt64, + (uint)global::Apache.Fory.TypeId.BFloat16Array => (uint)global::Apache.Fory.TypeId.BFloat16, + _ => null, + }; + } + + private static uint CompatibleScalarTypeId(uint typeId) + { + return typeId switch + { + (uint)global::Apache.Fory.TypeId.Int32 or + (uint)global::Apache.Fory.TypeId.VarInt32 => (uint)global::Apache.Fory.TypeId.VarInt32, + (uint)global::Apache.Fory.TypeId.Int64 or + (uint)global::Apache.Fory.TypeId.VarInt64 or + (uint)global::Apache.Fory.TypeId.TaggedInt64 => (uint)global::Apache.Fory.TypeId.VarInt64, + (uint)global::Apache.Fory.TypeId.UInt32 or + (uint)global::Apache.Fory.TypeId.VarUInt32 => (uint)global::Apache.Fory.TypeId.VarUInt32, + (uint)global::Apache.Fory.TypeId.UInt64 or + (uint)global::Apache.Fory.TypeId.VarUInt64 or + (uint)global::Apache.Fory.TypeId.TaggedUInt64 => 
(uint)global::Apache.Fory.TypeId.VarUInt64, + _ => NormalizeTypeIdForMatch(typeId), + }; + } + private static uint NormalizeTypeIdForMatch(uint typeId) { return typeId switch diff --git a/csharp/tests/Fory.Tests/ForyRuntimeTests.cs b/csharp/tests/Fory.Tests/ForyRuntimeTests.cs index d56c3e4ef0..7e1b0bc107 100644 --- a/csharp/tests/Fory.Tests/ForyRuntimeTests.cs +++ b/csharp/tests/Fory.Tests/ForyRuntimeTests.cs @@ -115,6 +115,62 @@ public sealed class ExplicitArraySchema public int[] Values { get; set; } = []; } +[ForyObject] +public sealed class CompatibleListSchema +{ + [ForyField(Type = typeof(S.List))] + public List Values { get; set; } = []; +} + +[ForyObject] +public sealed class CompatibleNullableListSchema +{ + [ForyField(Type = typeof(S.List))] + public List Values { get; set; } = []; +} + +[ForyObject] +public sealed class CompatibleArraySchema +{ + [ForyField(Type = typeof(S.Array))] + public int[] Values { get; set; } = []; +} + +[ForyObject] +public sealed class CompatibleUInt32ListSchema +{ + [ForyField(Type = typeof(S.List))] + public List Values { get; set; } = []; +} + +[ForyObject] +public sealed class CompatibleUInt32ArraySchema +{ + [ForyField(Type = typeof(S.Array))] + public uint[] Values { get; set; } = []; +} + +[ForyObject] +public sealed class CompatibleUInt32ArrayListCarrierSchema +{ + [ForyField(Type = typeof(S.Array))] + public List Values { get; set; } = []; +} + +[ForyObject] +public sealed class CompatibleNestedListSchema +{ + [ForyField(Type = typeof(S.Map>))] + public Dictionary> Values { get; set; } = []; +} + +[ForyObject] +public sealed class CompatibleNestedArraySchema +{ + [ForyField(Type = typeof(S.Map>))] + public Dictionary Values { get; set; } = []; +} + [ForyObject] public sealed class SemanticScalarSchema { @@ -954,6 +1010,130 @@ public void ExplicitArrayMarkerRoundTripsDenseArrayPayload() Assert.Equal(value.Values, decoded.Values); } + [Fact] + public void CompatibleReadAllowsImmediateListFieldIntoArrayCarrier() + { + 
ForyRuntime writer = ForyRuntime.Builder().Compatible(true).Build(); + writer.Register(306); + ForyRuntime reader = ForyRuntime.Builder().Compatible(true).Build(); + reader.Register(306); + + CompatibleArraySchema decoded = reader.Deserialize( + writer.Serialize(new CompatibleListSchema { Values = [1, -2, int.MaxValue] })); + + Assert.Equal([1, -2, int.MaxValue], decoded.Values); + } + + [Fact] + public void CompatibleReadAllowsImmediateArrayFieldIntoListCarrier() + { + ForyRuntime writer = ForyRuntime.Builder().Compatible(true).Build(); + writer.Register(307); + ForyRuntime reader = ForyRuntime.Builder().Compatible(true).Build(); + reader.Register(307); + + CompatibleListSchema decoded = reader.Deserialize( + writer.Serialize(new CompatibleArraySchema { Values = [3, 4, 5] })); + + Assert.Equal([3, 4, 5], decoded.Values); + } + + [Fact] + public void CompatibleReadSupportsUInt32ListArrayFieldPairs() + { + ForyRuntime listWriter = ForyRuntime.Builder().Compatible(true).Build(); + listWriter.Register(309); + ForyRuntime arrayReader = ForyRuntime.Builder().Compatible(true).Build(); + arrayReader.Register(309); + + CompatibleUInt32ArraySchema decodedArray = arrayReader.Deserialize( + listWriter.Serialize(new CompatibleUInt32ListSchema { Values = [1u, uint.MaxValue] })); + + Assert.Equal([1u, uint.MaxValue], decodedArray.Values); + + ForyRuntime arrayListCarrierReader = ForyRuntime.Builder().Compatible(true).Build(); + arrayListCarrierReader.Register(309); + CompatibleUInt32ArrayListCarrierSchema decodedArrayListCarrier = + arrayListCarrierReader.Deserialize( + listWriter.Serialize(new CompatibleUInt32ListSchema { Values = [7u, 8u] })); + + Assert.Equal([7u, 8u], decodedArrayListCarrier.Values); + + ForyRuntime arrayWriter = ForyRuntime.Builder().Compatible(true).Build(); + arrayWriter.Register(310); + ForyRuntime listReader = ForyRuntime.Builder().Compatible(true).Build(); + listReader.Register(310); + + CompatibleUInt32ListSchema decodedList = listReader.Deserialize( 
+ arrayWriter.Serialize(new CompatibleUInt32ArraySchema { Values = [9u, uint.MaxValue] })); + + Assert.Equal([9u, uint.MaxValue], decodedList.Values); + } + + [Fact] + public void CompatibleReadRejectsNullableListElementsIntoArrayCarrier() + { + ForyRuntime writer = ForyRuntime.Builder().Compatible(true).Build(); + writer.Register(308); + ForyRuntime reader = ForyRuntime.Builder().Compatible(true).Build(); + reader.Register(308); + + byte[] nonNullPayload = writer.Serialize(new CompatibleNullableListSchema { Values = [1, 2] }); + CompatibleArraySchema decoded = reader.Deserialize(nonNullPayload); + Assert.Equal([1, 2], decoded.Values); + + byte[] payload = writer.Serialize(new CompatibleNullableListSchema { Values = [1, null] }); + InvalidDataException exception = + Assert.Throws(() => reader.Deserialize(payload)); + Assert.Contains("compatible list to array field requires non-null elements", exception.Message); + } + + [Fact] + public void CompatibleReadDoesNotMatchNestedListArraySchemaPairs() + { + List localFields = + [ + new TypeMetaFieldInfo( + null, + "values", + new TypeMetaFieldType( + (uint)TypeId.Map, + false, + false, + [ + new TypeMetaFieldType((uint)TypeId.String, false), + new TypeMetaFieldType((uint)TypeId.Int32Array, false), + ])), + ]; + TypeMeta remoteTypeMeta = new( + (uint)TypeId.CompatibleStruct, + 0, + MetaString.Empty('_', '_'), + new MetaString("remote", MetaStringEncoding.Utf8, '_', '_', "remote"u8.ToArray()), + false, + [ + new TypeMetaFieldInfo( + null, + "values", + new TypeMetaFieldType( + (uint)TypeId.Map, + false, + false, + [ + new TypeMetaFieldType((uint)TypeId.String, false), + new TypeMetaFieldType( + (uint)TypeId.List, + false, + false, + [new TypeMetaFieldType((uint)TypeId.VarInt32, false)]), + ])), + ]); + + TypeMeta.AssignFieldIds(remoteTypeMeta, localFields); + + Assert.Equal(-1, remoteTypeMeta.Fields[0].AssignedFieldId); + } + [Fact] public void NestedSchemaAnnotationControlsMapAndListPayload() { diff --git 
a/csharp/tests/Fory.XlangPeer/Program.cs b/csharp/tests/Fory.XlangPeer/Program.cs index 04f9691882..2fad7dfbf3 100644 --- a/csharp/tests/Fory.XlangPeer/Program.cs +++ b/csharp/tests/Fory.XlangPeer/Program.cs @@ -220,6 +220,9 @@ private static byte[] ExecuteCase(string caseName, byte[] input) "test_schema_evolution_compatible_reverse" => CaseSchemaEvolutionCompatibleReverse(input), "test_reduced_precision_float_struct" => CaseReducedPrecisionFloatStruct(input), "test_reduced_precision_float_struct_compatible_skip" => CaseReducedPrecisionFloatStructCompatibleSkip(input), + "test_list_array_compatible_list_to_array" => CaseListArrayCompatibleListToArray(input), + "test_list_array_compatible_array_to_list" => CaseListArrayCompatibleArrayToList(input), + "test_list_array_compatible_nullable_list_to_array_error" => CaseListArrayCompatibleNullableListToArrayError(input), "test_one_enum_field_schema" => CaseOneEnumFieldSchema(input), "test_one_enum_field_compatible" => CaseOneEnumFieldCompatible(input), "test_two_enum_field_compatible" => CaseTwoEnumFieldCompatible(input), @@ -819,6 +822,37 @@ private static byte[] CaseReducedPrecisionFloatStructCompatibleSkip(byte[] input return RoundTripSingle(input, fory); } + private static byte[] CaseListArrayCompatibleListToArray(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(901); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseListArrayCompatibleArrayToList(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(901); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseListArrayCompatibleNullableListToArrayError(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(901); + + ReadOnlySequence sequence = new(input); + try + { + _ = fory.Deserialize(ref sequence); + } + catch (Apache.Fory.InvalidDataException) + { + return input; + } + throw new InvalidOperationException("Expected nullable list 
payload to fail compatible array read"); + } + private static byte[] CaseOneEnumFieldSchema(byte[] input) { ForyRuntime fory = BuildFory(compatible: false); @@ -1279,6 +1313,27 @@ public sealed class ReducedPrecisionFloatStruct public List BFloat16Array { get; set; } = []; } +[ForyObject] +public sealed class CompatibleInt32ListField +{ + [ForyField(1, Type = typeof(S.List>))] + public List Values { get; set; } = []; +} + +[ForyObject] +public sealed class CompatibleNullableInt32ListField +{ + [ForyField(1, Type = typeof(S.List>))] + public List Values { get; set; } = []; +} + +[ForyObject] +public sealed class CompatibleInt32ArrayField +{ + [ForyField(1, Type = typeof(S.Array))] + public int[] Values { get; set; } = []; +} + [ForyObject] public enum TestEnum { diff --git a/dart/packages/fory-test/lib/entity/xlang_test_models.dart b/dart/packages/fory-test/lib/entity/xlang_test_models.dart index 98073b6412..e25b9292b9 100644 --- a/dart/packages/fory-test/lib/entity/xlang_test_models.dart +++ b/dart/packages/fory-test/lib/entity/xlang_test_models.dart @@ -19,6 +19,8 @@ library; +import 'dart:typed_data'; + import 'package:fory/fory.dart'; import 'xlang_test_manual.dart' as manual; @@ -336,6 +338,33 @@ class ReducedPrecisionFloatStruct { List bfloat16Array = []; } +@ForyStruct() +class CompatibleInt32ListField { + CompatibleInt32ListField(); + + @ListField(id: 1, element: Int32Type(encoding: Encoding.fixed)) + List values = []; +} + +@ForyStruct() +class CompatibleNullableInt32ListField { + CompatibleNullableInt32ListField(); + + @ListField( + id: 1, + element: Int32Type(nullable: true, encoding: Encoding.fixed), + ) + List values = []; +} + +@ForyStruct() +class CompatibleInt32ArrayField { + CompatibleInt32ArrayField(); + + @ArrayField(id: 1, element: Int32Type()) + Int32List values = Int32List(0); +} + @ForyStruct() class OneEnumFieldStruct { OneEnumFieldStruct(); diff --git a/dart/packages/fory-test/test/cross_lang_test/xlang_test_main.dart 
b/dart/packages/fory-test/test/cross_lang_test/xlang_test_main.dart index f10848f352..3e74db8107 100644 --- a/dart/packages/fory-test/test/cross_lang_test/xlang_test_main.dart +++ b/dart/packages/fory-test/test/cross_lang_test/xlang_test_main.dart @@ -401,6 +401,27 @@ void _runCircularRoundTrip({required bool compatible, required int id}) { _writeFile(fory.serialize(value, trackRef: true)); } +void _runListArrayCompatibleRoundTrip(Type type) { + final fory = _newFory(compatible: true); + registerXlangType(fory, type, id: 901); + final value = fory.deserializeFrom(Buffer.wrap(_readFile())); + _writeFile(fory.serialize(value)); +} + +void _runListArrayCompatibleNullableListToArrayError() { + final fory = _newFory(compatible: true); + registerXlangType(fory, CompatibleInt32ArrayField, id: 901); + final bytes = _readFile(); + try { + fory.deserialize(bytes); + } catch (_) { + _writeFile(bytes); + return; + } + throw StateError( + 'Expected nullable list payload to fail compatible array read.'); +} + void _runCase(String caseName) { switch (caseName) { case 'test_buffer': @@ -639,6 +660,15 @@ void _runCase(String caseName) { registerXlangType(fory, EmptyStruct, id: 213); _roundTripFory(fory); return; + case 'test_list_array_compatible_list_to_array': + _runListArrayCompatibleRoundTrip(CompatibleInt32ArrayField); + return; + case 'test_list_array_compatible_array_to_list': + _runListArrayCompatibleRoundTrip(CompatibleInt32ListField); + return; + case 'test_list_array_compatible_nullable_list_to_array_error': + _runListArrayCompatibleNullableListToArrayError(); + return; default: throw UnsupportedError('Unknown Dart xlang case: $caseName'); } diff --git a/dart/packages/fory/lib/src/serializer/collection_flags.dart b/dart/packages/fory/lib/src/serializer/collection_flags.dart new file mode 100644 index 0000000000..7a1e2276c6 --- /dev/null +++ b/dart/packages/fory/lib/src/serializer/collection_flags.dart @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) 
under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +abstract final class CollectionFlags { + static const int trackingRef = 0x01; + static const int hasNull = 0x02; + static const int isDeclaredElementType = 0x04; + static const int isSameType = 0x08; +} diff --git a/dart/packages/fory/lib/src/serializer/collection_serializers.dart b/dart/packages/fory/lib/src/serializer/collection_serializers.dart index ac6a4f8f1a..e6d23674d7 100644 --- a/dart/packages/fory/lib/src/serializer/collection_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/collection_serializers.dart @@ -17,23 +17,26 @@ * under the License. 
*/ +import 'dart:typed_data'; + import 'package:fory/src/context/read_context.dart'; import 'package:fory/src/context/ref_writer.dart'; import 'package:fory/src/context/write_context.dart'; +import 'package:fory/src/meta/field_info.dart'; import 'package:fory/src/meta/field_type.dart'; import 'package:fory/src/meta/type_ids.dart'; import 'package:fory/src/resolver/type_resolver.dart'; +import 'package:fory/src/serializer/collection_flags.dart'; import 'package:fory/src/serializer/primitive_serializers.dart'; import 'package:fory/src/serializer/scalar_serializers.dart'; import 'package:fory/src/serializer/serializer.dart'; +import 'package:fory/src/serializer/serialization_field_info.dart'; import 'package:fory/src/serializer/serializer_support.dart'; - -abstract final class CollectionFlags { - static const int trackingRef = 0x01; - static const int hasNull = 0x02; - static const int isDeclaredElementType = 0x04; - static const int isSameType = 0x08; -} +import 'package:fory/src/types/bfloat16.dart'; +import 'package:fory/src/types/bool_list.dart'; +import 'package:fory/src/types/float16.dart'; +import 'package:fory/src/types/int64.dart'; +import 'package:fory/src/types/uint64.dart'; @pragma('vm:prefer-inline') void _writeDirectTypeInfoValue( @@ -404,6 +407,230 @@ final class SetSerializer extends Serializer { const ListSerializer listSerializer = ListSerializer(); const SetSerializer setSerializer = SetSerializer(); +@pragma('vm:never-inline') +Object? readCompatibleMatchedCollectionArrayField( + ReadContext context, + SerializationFieldInfo localField, + FieldInfo remoteField, +) { + final localType = localField.fieldType; + final remoteType = remoteField.fieldType; + if (isCompatibleArrayType(localType.typeId) && + remoteType.typeId == TypeIds.list) { + final elementType = + remoteType.arguments.isEmpty ? 
null : remoteType.arguments.single; + if (elementType == null || + _arrayElementTypeId(localType.typeId) != + _compatibleArrayElementTypeId(elementType.typeId)) { + throw StateError( + 'Compatible list-to-array field ${localField.name} is unsupported.'); + } + return _readCompatibleListAsArrayField( + context, + elementType, + localType.typeId, + localField.name, + ); + } + if (localType.typeId == TypeIds.list && + isCompatibleArrayType(remoteType.typeId)) { + final localElementType = + localType.arguments.isEmpty ? null : localType.arguments.single; + if (localElementType == null || + _arrayElementTypeId(remoteType.typeId) != + _compatibleArrayElementTypeId(localElementType.typeId)) { + throw StateError( + 'Compatible array-to-list field ${localField.name} is unsupported.'); + } + final raw = readCompatibleField(context, remoteField); + return _arrayToListValue(raw); + } + return readFieldValue(context, localField); +} + +bool isCompatibleArrayType(int typeId) => _arrayElementTypeId(typeId) != null; + +bool isCompatibleCollectionArrayFieldPair( + FieldInfo localField, + FieldInfo remoteField, +) { + return isCompatibleCollectionArrayTypePair( + localField.fieldType, + remoteField.fieldType, + ); +} + +bool isCompatibleCollectionArrayTypePair( + FieldType localType, + FieldType remoteType, +) { + if (isCompatibleArrayType(localType.typeId) && + remoteType.typeId == TypeIds.list) { + return _listElementMatchesArray(remoteType, localType.typeId); + } + if (localType.typeId == TypeIds.list && + isCompatibleArrayType(remoteType.typeId)) { + return _listElementMatchesArray(localType, remoteType.typeId); + } + return false; +} + +bool isCompatibleCollectionArrayRootTypePair( + FieldType localType, + FieldType remoteType, +) { + final localTypeId = localType.typeId; + final remoteTypeId = remoteType.typeId; + return (localTypeId == TypeIds.list && isCompatibleArrayType(remoteTypeId)) || + (isCompatibleArrayType(localTypeId) && remoteTypeId == TypeIds.list); +} + +bool 
_listElementMatchesArray(FieldType listType, int arrayTypeId) { + final elementType = + listType.arguments.isEmpty ? null : listType.arguments.single; + return elementType != null && + _arrayElementTypeId(arrayTypeId) == + _compatibleArrayElementTypeId(elementType.typeId); +} + +Object _readCompatibleListAsArrayField( + ReadContext context, + FieldType elementType, + int arrayTypeId, + String fieldName, +) { + final size = context.buffer.readVarUint32(); + if (size > context.config.maxCollectionSize) { + throw StateError( + 'Collection size $size exceeds ${context.config.maxCollectionSize}.', + ); + } + if (size == 0) { + return _newArrayValue(arrayTypeId, 0); + } + final header = context.buffer.readUint8(); + final trackRef = (header & CollectionFlags.trackingRef) != 0; + final hasNull = (header & CollectionFlags.hasNull) != 0; + final usesDeclaredType = + (header & CollectionFlags.isDeclaredElementType) != 0; + final sameType = (header & CollectionFlags.isSameType) != 0; + if (hasNull || trackRef) { + throw StateError( + 'Compatible list-to-array field $fieldName cannot read nullable or ref-tracked elements.', + ); + } + if (!sameType || !usesDeclaredType) { + throw StateError( + 'Compatible list-to-array field $fieldName requires declared same-type elements.', + ); + } + final elementResolved = context.typeResolver.resolveFieldType(elementType); + final result = _newArrayValue(arrayTypeId, size); + for (var index = 0; index < size; index += 1) { + _setArrayValue( + result, + arrayTypeId, + index, + context.readResolvedValue(elementResolved, elementType), + ); + } + return result; +} + +int? 
_arrayElementTypeId(int typeId) { + return switch (typeId) { + TypeIds.boolArray => TypeIds.boolType, + TypeIds.int8Array => TypeIds.int8, + TypeIds.int16Array => TypeIds.int16, + TypeIds.int32Array => TypeIds.int32, + TypeIds.int64Array => TypeIds.int64, + TypeIds.uint8Array => TypeIds.uint8, + TypeIds.uint16Array => TypeIds.uint16, + TypeIds.uint32Array => TypeIds.uint32, + TypeIds.uint64Array => TypeIds.uint64, + TypeIds.float16Array => TypeIds.float16, + TypeIds.bfloat16Array => TypeIds.bfloat16, + TypeIds.float32Array => TypeIds.float32, + TypeIds.float64Array => TypeIds.float64, + _ => null, + }; +} + +int _compatibleArrayElementTypeId(int typeId) { + return switch (typeId) { + TypeIds.varInt32 => TypeIds.int32, + TypeIds.varInt64 || TypeIds.taggedInt64 => TypeIds.int64, + TypeIds.varUint32 => TypeIds.uint32, + TypeIds.varUint64 || TypeIds.taggedUint64 => TypeIds.uint64, + _ => typeId, + }; +} + +Object _newArrayValue(int arrayTypeId, int length) { + return switch (arrayTypeId) { + TypeIds.boolArray => BoolList(length), + TypeIds.int8Array => Int8List(length), + TypeIds.int16Array => Int16List(length), + TypeIds.int32Array => Int32List(length), + TypeIds.int64Array => Int64List(length), + TypeIds.uint8Array => Uint8List(length), + TypeIds.uint16Array => Uint16List(length), + TypeIds.uint32Array => Uint32List(length), + TypeIds.uint64Array => Uint64List(length), + TypeIds.float16Array => Float16List(length), + TypeIds.bfloat16Array => Bfloat16List(length), + TypeIds.float32Array => Float32List(length), + TypeIds.float64Array => Float64List(length), + _ => + throw StateError('Unsupported compatible array field type $arrayTypeId.'), + }; +} + +void _setArrayValue(Object target, int arrayTypeId, int index, Object? 
value) { + switch (arrayTypeId) { + case TypeIds.boolArray: + (target as BoolList)[index] = value as bool; + case TypeIds.int8Array: + (target as Int8List)[index] = value as int; + case TypeIds.int16Array: + (target as Int16List)[index] = value as int; + case TypeIds.int32Array: + (target as Int32List)[index] = value as int; + case TypeIds.int64Array: + (target as Int64List)[index] = + value is int ? Int64(value) : value as Int64; + case TypeIds.uint8Array: + (target as Uint8List)[index] = value as int; + case TypeIds.uint16Array: + (target as Uint16List)[index] = value as int; + case TypeIds.uint32Array: + (target as Uint32List)[index] = value as int; + case TypeIds.uint64Array: + (target as Uint64List)[index] = + value is int ? Uint64(value) : value as Uint64; + case TypeIds.float16Array: + (target as Float16List)[index] = value as Float16; + case TypeIds.bfloat16Array: + (target as Bfloat16List)[index] = value as Bfloat16; + case TypeIds.float32Array: + (target as Float32List)[index] = (value as num).toDouble(); + case TypeIds.float64Array: + (target as Float64List)[index] = (value as num).toDouble(); + default: + throw StateError('Unsupported compatible array field type $arrayTypeId.'); + } +} + +Object _arrayToListValue(Object? 
raw) { + if (raw is BoolList) { + return raw.toList(); + } + if (raw is Iterable) { + return raw.toList(); + } + throw StateError('Expected compatible array payload.'); +} + @pragma('vm:prefer-inline') List readTypedListPayload( ReadContext context, diff --git a/dart/packages/fory/lib/src/serializer/struct_serializer.dart b/dart/packages/fory/lib/src/serializer/struct_serializer.dart index d7167b57dd..f810749cbd 100644 --- a/dart/packages/fory/lib/src/serializer/struct_serializer.dart +++ b/dart/packages/fory/lib/src/serializer/struct_serializer.dart @@ -20,8 +20,10 @@ import 'package:fory/src/context/read_context.dart'; import 'package:fory/src/context/write_context.dart'; import 'package:fory/src/meta/field_info.dart'; +import 'package:fory/src/meta/field_type.dart'; import 'package:fory/src/meta/type_def.dart'; import 'package:fory/src/resolver/type_resolver.dart'; +import 'package:fory/src/serializer/collection_serializers.dart'; import 'package:fory/src/serializer/compatible_struct_metadata.dart'; import 'package:fory/src/serializer/serializer.dart'; import 'package:fory/src/serializer/serialization_field_info.dart'; @@ -161,6 +163,14 @@ final class StructSerializer extends Serializer { required bool hasCurrentPreservedRef, }) { final layout = _compatibleReadLayoutForResolved(resolved); + if (layout is _CompatibleCollectionArrayReadLayout) { + return _readCompatibleCollectionArray( + context, + resolved, + layout, + hasCurrentPreservedRef: hasCurrentPreservedRef, + ); + } final compatibleFactory = _compatibleFactory; final compatibleReadersBySlot = _compatibleReadersBySlot; if (compatibleFactory != null && compatibleReadersBySlot != null) { @@ -220,34 +230,145 @@ final class StructSerializer extends Serializer { return value; } + Object _readCompatibleCollectionArray( + ReadContext context, + TypeInfo resolved, + _CompatibleCollectionArrayReadLayout layout, { + required bool hasCurrentPreservedRef, + }) { + final compatibleFactory = _compatibleFactory; + final 
compatibleReadersBySlot = _compatibleReadersBySlot; + if (compatibleFactory != null && compatibleReadersBySlot != null) { + final int? sentinelId; + final needsSentinel = resolved.supportsRef && !hasCurrentPreservedRef; + if (needsSentinel) { + sentinelId = context.refReader.preserveRefId(-1); + } else { + sentinelId = null; + } + final value = compatibleFactory(); + context.reference(value); + for (var index = 0; index < layout.fields.length; index += 1) { + final localField = layout.fields[index]; + if (localField == null) { + readCompatibleField(context, layout.remoteFields[index]); + continue; + } + compatibleReadersBySlot[localField.slot]( + context, + value, + layout.topLevelListArrayPairs[index] + ? readCompatibleMatchedCollectionArrayField( + context, + localField, + layout.remoteFields[index], + ) + : readFieldValue(context, localField), + ); + } + if (needsSentinel && + context.refReader.hasPreservedRefId && + context.refReader.lastPreservedRefId == sentinelId) { + context.refReader.reference(null); + } + _rememberRemoteMetadata(context, resolved, value); + return value; + } + final compatibleValues = List.filled(_localFields.length, null); + final presentSlots = List.filled(_localFields.length, false); + for (var index = 0; index < layout.fields.length; index += 1) { + final localField = layout.fields[index]; + if (localField == null) { + readCompatibleField(context, layout.remoteFields[index]); + continue; + } + final slot = localField.slot; + compatibleValues[slot] = layout.topLevelListArrayPairs[index] + ? 
readCompatibleMatchedCollectionArrayField( + context, + localField, + layout.remoteFields[index], + ) + : readFieldValue(context, localField); + presentSlots[slot] = true; + } + final previousCompatibleFields = context.structReadSlots; + context.structReadSlots = StructReadSlots(compatibleValues, presentSlots); + final value = context.readSerializerPayload( + _payloadSerializer, + resolved, + hasCurrentPreservedRef: hasCurrentPreservedRef, + ) as Object; + context.structReadSlots = previousCompatibleFields; + _rememberRemoteMetadata(context, resolved, value); + return value; + } + + @pragma('vm:prefer-inline') _CompatibleReadLayout _compatibleReadLayoutForResolved( TypeInfo resolved, ) { final remoteTypeDef = resolved.remoteTypeDef; if (remoteTypeDef == null) { - return _CompatibleReadLayout(_typeDef.fields, _localFields); + return _CompatibleReadLayout( + _typeDef.fields, + _localFields, + ); } final cached = _compatibleReadLayouts[remoteTypeDef]; if (cached != null) { return cached; } + return _buildCompatibleReadLayout(remoteTypeDef); + } + + _CompatibleReadLayout _buildCompatibleReadLayout(TypeDef remoteTypeDef) { final fields = []; + List? 
topLevelListArrayPairs; + var hasTopLevelListArrayPairs = false; for (final remoteField in remoteTypeDef.fields) { final localField = _localFieldsByIdentifier[remoteField.identifier]; if (localField == null) { fields.add(null); + topLevelListArrayPairs?.add(false); continue; } - final mergedField = _typeResolver.serializationFieldInfo( - mergeCompatibleReadField(localField.field, remoteField), - slot: localField.slot, - ); + final topLevelListArrayPair = + _topLevelListArrayPair(localField.field, remoteField); + if (_hasUnsupportedListArrayMismatch( + localField.field.fieldType, + remoteField.fieldType, + topLevel: true, + )) { + throw StateError( + 'Compatible field ${localField.name} has unsupported list/array schema mismatch.', + ); + } + if (topLevelListArrayPair) { + topLevelListArrayPairs ??= + List.filled(fields.length, false, growable: true); + hasTopLevelListArrayPairs = true; + } + final mergedField = topLevelListArrayPair + ? localField + : _typeResolver.serializationFieldInfo( + mergeCompatibleReadField(localField.field, remoteField), + slot: localField.slot, + ); fields.add(mergedField); + topLevelListArrayPairs?.add(topLevelListArrayPair); } - final layout = _CompatibleReadLayout( - remoteTypeDef.fields, - List.unmodifiable(fields), - ); + final frozenFields = List.unmodifiable(fields); + final layout = hasTopLevelListArrayPairs + ? 
_CompatibleCollectionArrayReadLayout( + remoteTypeDef.fields, + frozenFields, + List.unmodifiable(topLevelListArrayPairs!), + ) + : _CompatibleReadLayout( + remoteTypeDef.fields, + frozenFields, + ); _compatibleReadLayouts[remoteTypeDef] = layout; return layout; } @@ -264,9 +385,48 @@ final class StructSerializer extends Serializer { } } +bool _topLevelListArrayPair(FieldInfo localField, FieldInfo remoteField) { + return isCompatibleCollectionArrayFieldPair(localField, remoteField); +} + +bool _hasUnsupportedListArrayMismatch( + FieldType localType, + FieldType remoteType, { + required bool topLevel, +}) { + if (isCompatibleCollectionArrayRootTypePair(localType, remoteType)) { + return !(topLevel && + isCompatibleCollectionArrayTypePair(localType, remoteType)); + } + if (localType.typeId != remoteType.typeId || + localType.arguments.length != remoteType.arguments.length) { + return false; + } + for (var index = 0; index < localType.arguments.length; index += 1) { + if (_hasUnsupportedListArrayMismatch( + localType.arguments[index], + remoteType.arguments[index], + topLevel: false, + )) { + return true; + } + } + return false; +} + final class _CompatibleReadLayout { final List remoteFields; final List fields; const _CompatibleReadLayout(this.remoteFields, this.fields); } + +final class _CompatibleCollectionArrayReadLayout extends _CompatibleReadLayout { + final List topLevelListArrayPairs; + + const _CompatibleCollectionArrayReadLayout( + super.remoteFields, + super.fields, + this.topLevelListArrayPairs, + ); +} diff --git a/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart b/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart index 5d61240aba..f5374a187b 100644 --- a/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart +++ b/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart @@ -72,6 +72,54 @@ class ExplicitArrayEnvelope { List normalList = []; } +@ForyStruct() +class CompatibleListEnvelope { + 
CompatibleListEnvelope(); + + @ListField(element: Int32Type(encoding: Encoding.fixed)) + List values = []; +} + +@ForyStruct() +class CompatibleArrayEnvelope { + CompatibleArrayEnvelope(); + + @ArrayField(element: Int32Type()) + Int32List values = Int32List(0); +} + +@ForyStruct() +class CompatibleNullableListEnvelope { + CompatibleNullableListEnvelope(); + + @ListField(element: Int32Type(nullable: true, encoding: Encoding.fixed)) + List values = []; +} + +@ForyStruct() +class CompatibleStringListEnvelope { + CompatibleStringListEnvelope(); + + @ListField(element: StringType()) + List values = []; +} + +@ForyStruct() +class CompatibleNestedArrayListEnvelope { + CompatibleNestedArrayListEnvelope(); + + @ListField(element: ArrayType(element: Int32Type())) + List values = []; +} + +@ForyStruct() +class CompatibleNestedListEnvelope { + CompatibleNestedListEnvelope(); + + @ListField(element: ListType(element: Int32Type(encoding: Encoding.fixed))) + List> values = >[]; +} + void _registerScalarTypes(Fory fory) { ScalarAndTypedArraySerializerTestFory.register( fory, @@ -498,6 +546,144 @@ void main() { expect(fieldsByName['normal_list'], equals(TypeIds.list)); }); + test('adapts immediate compatible list and dense array fields', () { + final writer = Fory(); + final reader = Fory(); + ScalarAndTypedArraySerializerTestFory.register( + writer, + CompatibleListEnvelope, + namespace: 'test', + typeName: 'CompatibleListArrayEnvelope', + ); + ScalarAndTypedArraySerializerTestFory.register( + reader, + CompatibleArrayEnvelope, + namespace: 'test', + typeName: 'CompatibleListArrayEnvelope', + ); + + final bytes = writer.serialize( + CompatibleListEnvelope()..values = [1, 2, 3], + ); + final decoded = reader.deserialize(bytes); + + _expectInt32ListEquals( + decoded.values, Int32List.fromList([1, 2, 3])); + }); + + test('adapts immediate compatible dense array and list fields', () { + final writer = Fory(); + final reader = Fory(); + ScalarAndTypedArraySerializerTestFory.register( + 
writer, + CompatibleArrayEnvelope, + namespace: 'test', + typeName: 'CompatibleListArrayEnvelope', + ); + ScalarAndTypedArraySerializerTestFory.register( + reader, + CompatibleListEnvelope, + namespace: 'test', + typeName: 'CompatibleListArrayEnvelope', + ); + + final bytes = writer.serialize( + CompatibleArrayEnvelope()..values = Int32List.fromList([1, 2, 3]), + ); + final decoded = reader.deserialize(bytes); + + expect(decoded.values, orderedEquals([1, 2, 3])); + }); + + test( + 'rejects compatible list payload with nullable elements for dense array fields', + () { + final writer = Fory(); + final reader = Fory(); + ScalarAndTypedArraySerializerTestFory.register( + writer, + CompatibleNullableListEnvelope, + namespace: 'test', + typeName: 'CompatibleNullableListArrayEnvelope', + ); + ScalarAndTypedArraySerializerTestFory.register( + reader, + CompatibleArrayEnvelope, + namespace: 'test', + typeName: 'CompatibleNullableListArrayEnvelope', + ); + + final nonNullBytes = writer.serialize( + CompatibleNullableListEnvelope()..values = [1, 2, 3], + ); + final decoded = reader.deserialize(nonNullBytes); + expect(decoded.values, orderedEquals([1, 2, 3])); + + final nullableBytes = writer.serialize( + CompatibleNullableListEnvelope()..values = [1, null, 3], + ); + + expect( + () => reader.deserialize(nullableBytes), + throwsStateError, + ); + }); + + test('rejects incompatible compatible list and dense array element fields', + () { + final writer = Fory(); + final reader = Fory(); + ScalarAndTypedArraySerializerTestFory.register( + writer, + CompatibleStringListEnvelope, + namespace: 'test', + typeName: 'CompatibleMismatchedListArrayEnvelope', + ); + ScalarAndTypedArraySerializerTestFory.register( + reader, + CompatibleArrayEnvelope, + namespace: 'test', + typeName: 'CompatibleMismatchedListArrayEnvelope', + ); + + final bytes = writer.serialize( + CompatibleStringListEnvelope()..values = ['1', '2'], + ); + expect( + () => reader.deserialize(bytes), + throwsStateError, + 
); + }); + + test('rejects nested compatible list and dense array field positions', () { + final writer = Fory(); + final reader = Fory(); + ScalarAndTypedArraySerializerTestFory.register( + writer, + CompatibleNestedArrayListEnvelope, + namespace: 'test', + typeName: 'CompatibleNestedListArrayEnvelope', + ); + ScalarAndTypedArraySerializerTestFory.register( + reader, + CompatibleNestedListEnvelope, + namespace: 'test', + typeName: 'CompatibleNestedListArrayEnvelope', + ); + + final bytes = writer.serialize( + CompatibleNestedArrayListEnvelope() + ..values = [ + Int32List.fromList([1, 2]) + ], + ); + + expect( + () => reader.deserialize(bytes), + throwsStateError, + ); + }); + test('enforces maxBinarySize on write and read', () { final oversized = Uint8List.fromList([1, 2, 3, 4]); diff --git a/docs/guide/xlang/serialization.md b/docs/guide/xlang/serialization.md index 794293d329..bee4783697 100644 --- a/docs/guide/xlang/serialization.md +++ b/docs/guide/xlang/serialization.md @@ -30,7 +30,9 @@ Reduced-precision floating-point values are also part of the built-in xlang type - `float16` and `array` - `bfloat16` and `array` -Use the language-specific carrier types documented in the type mapping reference. Python uses `pyfory.Float16` and `pyfory.BFloat16` as annotation markers only; scalar values are native Python `float`, and dense reduced-precision arrays use `pyfory.Float16Array` and `pyfory.BFloat16Array`. Go uses the `float16` and `bfloat16` packages for scalar, slice, and array carriers; JavaScript uses `number` / `number[]` for `float16` and `BFloat16` / `BFloat16Array` for `bfloat16`; Java uses `@ArrayType` on supported reduced-precision carriers for `array` / `array` schema, while general object arrays stay on the `list` path; C++, Rust, and C# provide their own dedicated scalar and array carriers. +Use the language-specific carrier types documented in the type mapping reference. 
Python uses `pyfory.Float16` and `pyfory.BFloat16` as annotation markers only; scalar values are native Python `float`, and dense reduced-precision arrays use `pyfory.Float16Array` and `pyfory.BFloat16Array`. Go uses the `float16` and `bfloat16` packages for scalar, slice, and array carriers; JavaScript uses `number` for scalar `float16`, `BFloat16` / `number` for scalar `bfloat16`, and dense array carriers `BoolArray`, `Float16Array`, and `BFloat16Array` for the corresponding `array` schemas. Java uses `@ArrayType` on supported reduced-precision carriers for `array` / `array` schema, while general object arrays stay on the `list` path; C++, Rust, and C# provide their own dedicated scalar and array carriers. + +When `compatible=true`, a direct struct/class field can evolve between `list` and `array` for dense bool/numeric `T`. Integer list element encodings in the same signedness and width domain match the corresponding dense array element domain. This applies only to the immediate matched field schema. It does not apply to nested collection, map, array, union, or generic positions. If a peer `list` payload declares nullable or ref-tracked elements, reading it into a local `array` field raises a compatible-read error. ### Java diff --git a/docs/specification/xlang_serialization_spec.md b/docs/specification/xlang_serialization_spec.md index b1743918a9..851791edbc 100644 --- a/docs/specification/xlang_serialization_spec.md +++ b/docs/specification/xlang_serialization_spec.md @@ -182,6 +182,24 @@ canonical dynamic tags for `array`: `ARRAY (42)` is reserved for a future generic or shaped-array descriptor and is not emitted for dense primitive arrays. +In schema-compatible mode only, a matched struct/class field may read between +direct top-level `list` and direct top-level `array` schemas when `T` +belongs to the valid dense array element domains above. 
Integer list element +encodings in the same signedness and width domain match the corresponding dense +array element domain. This is a read adaptation, not a schema-kind merge: +writers keep emitting their local canonical `list` or `array` payload, and +TypeDef/ClassDef encodings, fingerprints, dynamic root serialization, +schema-consistent mode, and unknown-field skipping continue to treat `list` +and `array` as distinct kinds. + +The adaptation is limited to the immediate schema of the matched compatible +field. It does not apply when `list` or `array` appears inside another +field type, including collection elements, map keys or values, array elements, +union alternatives, or other generic/container positions. When a peer `list` +payload declares nullable or ref-tracked elements, a local matched `array` +field must raise a compatible-read error. Null list elements must not be coerced +to dense-array default values. + Users can also provide meta hints for fields of a type, or the type whole. Here is an example in java which use annotation to provide such information. @@ -1219,7 +1237,9 @@ can be taken as an example. #### primitive array Primitive array are taken as a binary buffer, serialization will just write the length of array size as an unsigned int, -then copy the whole buffer into the stream. +then copy the whole buffer into the stream. Multi-byte element arrays are always encoded in little-endian element order; +runtimes whose native typed-array storage uses another byte order must swap or write elements explicitly instead of +copying native storage bytes unchanged. Such serialization won't compress the array. If users want to compress primitive array, users need to register custom serializers for such types or mark it as list type. 
diff --git a/docs/specification/xlang_type_mapping.md b/docs/specification/xlang_type_mapping.md index acd8743ca1..a05a045bb3 100644 --- a/docs/specification/xlang_type_mapping.md +++ b/docs/specification/xlang_type_mapping.md @@ -93,7 +93,7 @@ FDL spells them as an encoding modifier plus a semantic integer type. | date | 39 | LocalDate | datetime.date | Date | fory::serialization::Date | fory.Date | chrono::NaiveDate | | decimal | 40 | BigDecimal | Decimal | Decimal | / | fory.Decimal | fory::Decimal | | binary | 41 | byte[] | bytes | / | `uint8_t[n]/vector` | `[n]uint8/[]T` | `Vec` | -| `array` (bool_array) | 43 | bool[] | BoolArray / ndarray(np.bool\_) | Type.boolArray() | `bool[n]` | `[n]bool/[]T` | `Vec` | +| `array` (bool_array) | 43 | bool[] | BoolArray / ndarray(np.bool\_) | BoolArray / Type.boolArray() | `bool[n]` | `[n]bool/[]T` | `Vec` | | `array` (int8_array) | 44 | `@Int8Type byte[]` | Int8Array / ndarray(int8) | Type.int8Array() | `int8_t[n]/vector` | `[n]int8/[]T` | `Vec` | | `array` (int16_array) | 45 | short[] | Int16Array / ndarray(int16) | Type.int16Array() | `int16_t[n]/vector` | `[n]int16/[]T` | `Vec` | | `array` (int32_array) | 46 | int[] | Int32Array / ndarray(int32) | Type.int32Array() | `int32_t[n]/vector` | `[n]int32/[]T` | `Vec` | @@ -103,8 +103,8 @@ FDL spells them as an encoding modifier plus a semantic integer type. 
| `array` (uint32_array) | 50 | `@UInt32Type int[]` | UInt32Array / ndarray(uint32) | Type.uint32Array() | `uint32_t[n]/vector` | `[n]uint32/[]T` | `Vec` | | `array` (uint64_array) | 51 | `@UInt64Type long[]` | UInt64Array / ndarray(uint64) | Type.uint64Array() | `uint64_t[n]/vector` | `[n]uint64/[]T` | `Vec` | | `array` (float8_array) | 52 | / | / | / | / | / | / | -| `array` (float16_array) | 53 | `Float16Array` / `@Float16Type short[]` | Float16Array / ndarray(float16) | Type.float16Array() | `fory::float16_t[n]/std::vector` | `[N]float16.Float16` / `[]float16.Float16` | `Vec` / `[Float16; N]` | -| `array` (bfloat16_array) | 54 | `BFloat16Array` / `@BFloat16Type short[]` | BFloat16Array / ndarray(bfloat16) | Type.bfloat16Array() | `fory::bfloat16_t[n]/std::vector` | `[N]bfloat16.BFloat16` / `[]bfloat16.BFloat16` | `Vec` / `[BFloat16; N]` | +| `array` (float16_array) | 53 | `Float16Array` / `@Float16Type short[]` | Float16Array / ndarray(float16) | Float16Array / Type.float16Array() | `fory::float16_t[n]/std::vector` | `[N]float16.Float16` / `[]float16.Float16` | `Vec` / `[Float16; N]` | +| `array` (bfloat16_array) | 54 | `BFloat16Array` / `@BFloat16Type short[]` | BFloat16Array / ndarray(bfloat16) | BFloat16Array / Type.bfloat16Array() | `fory::bfloat16_t[n]/std::vector` | `[N]bfloat16.BFloat16` / `[]bfloat16.BFloat16` | `Vec` / `[BFloat16; N]` | | `array` (float32_array) | 55 | float[] | Float32Array / ndarray(float32) | Type.float32Array() | `float[n]/vector` | `[n]float32/[]T` | `Vec` | | `array` (float64_array) | 56 | double[] | Float64Array / ndarray(float64) | Type.float64Array() | `double[n]/vector` | `[n]float64/[]T` | `Vec` | @@ -112,6 +112,7 @@ Notes: - Python `pyfory.Float16` and `pyfory.BFloat16` are reserved annotation markers; scalar values deserialize as native Python `float`. 
- Python `BoolArray`, `Int8Array`, `Int16Array`, `Int32Array`, `Int64Array`, `UInt8Array`, `UInt16Array`, `UInt32Array`, `UInt64Array`, `Float16Array`, `BFloat16Array`, `Float32Array`, and `Float64Array` are public dense-array wrappers with list-like sequence behavior. +- JavaScript `BoolArray`, fallback `Float16Array`, and `BFloat16Array` are public dense-array wrappers backed by `Uint8Array` or `Uint16Array`. A JavaScript runtime with native `Float16Array` may return that native carrier for `array`. - Java plain `byte[]` maps to `binary`. Numeric byte arrays use type-use annotations: `@Int8Type byte[]` for `array` and `@UInt8Type byte[]` for `array`. - Dart uses `BoolList` for `array`, typed-data lists for integer/float32/float64 arrays, and @@ -123,6 +124,14 @@ Notes: of the current xlang type-mapping surface. - Current xlang uses `*_ARRAY` for one-dimensional primitive arrays and nested `list` for multi-dimensional arrays. +- `list` and `array` remain distinct schema kinds. In schema-compatible struct/class field + matching only, a direct top-level `list` field may be read as a direct top-level `array` + field, and a direct top-level `array` field may be read as a direct top-level `list` field, + when `T` is one of the dense bool/numeric array domains. Integer list element encodings in the + same signedness and width domain match the corresponding dense array element domain. The rule does + not apply inside nested collection, map, array, union, or generic positions. A peer `list` + payload that declares nullable or ref-tracked elements must raise a compatible-read error when the + local matched field is `array`. 
## Type info diff --git a/go/fory/field_info.go b/go/fory/field_info.go index cc34d8b84b..9b45a4e68d 100644 --- a/go/fory/field_info.go +++ b/go/fory/field_info.go @@ -1003,10 +1003,6 @@ func elementTypesCompatible(actual, expected reflect.Type) bool { if expected.Kind() == reflect.Ptr && actual.Kind() != reflect.Ptr { return elementTypesCompatible(actual, expected.Elem()) } - if (actual.Kind() == reflect.Array && expected.Kind() == reflect.Slice) || - (actual.Kind() == reflect.Slice && expected.Kind() == reflect.Array) { - return elementTypesCompatible(actual.Elem(), expected.Elem()) - } return false } diff --git a/go/fory/fory_compatible_test.go b/go/fory/fory_compatible_test.go index 7ffff6ef13..16ccca9cd2 100644 --- a/go/fory/fory_compatible_test.go +++ b/go/fory/fory_compatible_test.go @@ -131,6 +131,30 @@ type ByteFamilyInt8ArrayDataClass struct { Payload []int8 `fory:"type=array(element=int8)"` } +type Int32ListPayloadDataClass struct { + Payload []int32 `fory:"type=list(element=int32(nullable=false,encoding=fixed))"` +} + +type Int32VarintListPayloadDataClass struct { + Payload []int32 `fory:"type=list(element=int32(nullable=false))"` +} + +type NullableInt32ListPayloadDataClass struct { + Payload []*int32 `fory:"type=list(element=int32(nullable=true,encoding=fixed))"` +} + +type Int32ArrayPayloadDataClass struct { + Payload [3]int32 `fory:"type=array(element=int32)"` +} + +type NestedInt32ListPayloadDataClass struct { + Payload [][]int32 `fory:"type=list(element=list(element=int32(nullable=false,encoding=fixed)))"` +} + +type NestedInt32ArrayPayloadDataClass struct { + Payload [][2]int32 `fory:"type=list(element=array(element=int32))"` +} + func TestMetaShareEnabled(t *testing.T) { fory := NewForyWithOptions(WithXlang(true), WithCompatible(true)) @@ -531,6 +555,89 @@ func TestCompatibleSerializationScenarios(t *testing.T) { assert.Nil(t, out.Payload) }, }, + { + name: "Int32ListToArray", + tag: "Int32Sequence", + writeType: Int32ListPayloadDataClass{}, + 
readType: Int32ArrayPayloadDataClass{}, + input: Int32ListPayloadDataClass{ + Payload: []int32{1, 2, 3}, + }, + assertFunc: func(t *testing.T, input any, output any) { + out := output.(Int32ArrayPayloadDataClass) + assert.Equal(t, [3]int32{1, 2, 3}, out.Payload) + }, + }, + { + name: "Int32VarintListToArray", + tag: "Int32Sequence", + writeType: Int32VarintListPayloadDataClass{}, + readType: Int32ArrayPayloadDataClass{}, + input: Int32VarintListPayloadDataClass{ + Payload: []int32{-1, 2, 3}, + }, + assertFunc: func(t *testing.T, input any, output any) { + out := output.(Int32ArrayPayloadDataClass) + assert.Equal(t, [3]int32{-1, 2, 3}, out.Payload) + }, + }, + { + name: "EmptyInt32ListDoesNotMatchFixedArrayLength", + tag: "Int32Sequence", + writeType: Int32ListPayloadDataClass{}, + readType: Int32ArrayPayloadDataClass{}, + input: Int32ListPayloadDataClass{Payload: []int32{}}, + unmarshalErrContains: "array length 3 does not match serialized list length 0", + }, + { + name: "Int32ArrayToList", + tag: "Int32Sequence", + writeType: Int32ArrayPayloadDataClass{}, + readType: Int32ListPayloadDataClass{}, + input: Int32ArrayPayloadDataClass{ + Payload: [3]int32{1, 2, 3}, + }, + assertFunc: func(t *testing.T, input any, output any) { + out := output.(Int32ListPayloadDataClass) + assert.Equal(t, []int32{1, 2, 3}, out.Payload) + }, + }, + { + name: "NullableInt32ListWithoutNullsMatchesArray", + tag: "Int32Sequence", + writeType: NullableInt32ListPayloadDataClass{}, + readType: Int32ArrayPayloadDataClass{}, + input: NullableInt32ListPayloadDataClass{ + Payload: []*int32{ptr(int32(1)), ptr(int32(2)), ptr(int32(3))}, + }, + assertFunc: func(t *testing.T, input any, output any) { + out := output.(Int32ArrayPayloadDataClass) + assert.Equal(t, [3]int32{1, 2, 3}, out.Payload) + }, + }, + { + name: "NullableInt32ListPayloadDoesNotMatchArray", + tag: "Int32Sequence", + writeType: NullableInt32ListPayloadDataClass{}, + readType: Int32ArrayPayloadDataClass{}, + input: 
NullableInt32ListPayloadDataClass{ + Payload: []*int32{ptr(int32(1)), nil, ptr(int32(3))}, + }, + unmarshalErrContains: "compatible list to array field requires non-null elements", + }, + { + name: "NestedListArrayMismatch", + tag: "NestedInt32Sequence", + writeType: NestedInt32ListPayloadDataClass{}, + readType: NestedInt32ArrayPayloadDataClass{}, + input: NestedInt32ListPayloadDataClass{ + Payload: [][]int32{{1, 2}, {3, 4}}, + }, + assertFunc: func(t *testing.T, input any, output any) { + out := output.(NestedInt32ArrayPayloadDataClass) + assert.Nil(t, out.Payload) + }, + }, } for _, tc := range cases { @@ -541,14 +648,15 @@ func TestCompatibleSerializationScenarios(t *testing.T) { } type compatibilityCase struct { - name string - tag string - writeType any - readType any - input any - assertFunc func(t *testing.T, input any, output any) - writerSetup func(*Fory) error - readerSetup func(*Fory) error + name string + tag string + writeType any + readType any + input any + assertFunc func(t *testing.T, input any, output any) + writerSetup func(*Fory) error + readerSetup func(*Fory) error + unmarshalErrContains string } func runCompatibilityCase(t *testing.T, tc compatibilityCase) { @@ -587,6 +695,12 @@ func runCompatibilityCase(t *testing.T, tc compatibilityCase) { assert.NotPanics(t, func() { unmarshalErr = reader.Unmarshal(data, target.Interface()) }) + if tc.unmarshalErrContains != "" { + if assert.Error(t, unmarshalErr) { + assert.Contains(t, unmarshalErr.Error(), tc.unmarshalErrContains) + } + return + } assert.NoError(t, unmarshalErr) tc.assertFunc(t, tc.input, target.Elem().Interface()) diff --git a/go/fory/slice.go b/go/fory/slice.go index 9231aa75f1..8edc005e39 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -205,6 +205,25 @@ func (s *sliceSerializer) writeDataWithGenerics(ctx *WriteContext, value reflect // Serialize elements with ref tracking or nulls handling declaredGenericDispatch := hasGenerics && 
serializerNeedsGenericDispatch(s.elemSerializer) + if !trackRefs && !hasNull { + if declaredGenericDispatch { + for i := 0; i < length; i++ { + s.elemSerializer.Write(ctx, RefModeNone, false, true, value.Index(i)) + if ctx.HasError() { + return + } + } + } else { + for i := 0; i < length; i++ { + s.elemSerializer.WriteData(ctx, value.Index(i)) + if ctx.HasError() { + return + } + } + } + return + } + for i := 0; i < length; i++ { elem := value.Index(i) @@ -325,6 +344,25 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { elemRefMode = RefModeTracking } + if !trackRefs && !hasNull { + if declaredGenericDispatch { + for i := 0; i < length; i++ { + s.elemSerializer.Read(ctx, RefModeNone, false, true, value.Index(i)) + if ctx.HasError() { + return + } + } + } else { + for i := 0; i < length; i++ { + s.elemSerializer.ReadData(ctx, value.Index(i)) + if ctx.HasError() { + return + } + } + } + return + } + // Slow path: general deserialization with ref tracking or nulls for i := 0; i < length; i++ { elem := value.Index(i) diff --git a/go/fory/slice_primitive_list.go b/go/fory/slice_primitive_list.go index d5ad0454f8..aa965d1e45 100644 --- a/go/fory/slice_primitive_list.go +++ b/go/fory/slice_primitive_list.go @@ -27,6 +27,11 @@ type primitiveListSerializer struct { elemTypeID TypeId } +type compatiblePrimitiveListToArraySerializer struct { + arrayType reflect.Type + listReader primitiveListSerializer +} + func newPrimitiveListSerializer(type_ reflect.Type, elemTypeID TypeId) (Serializer, bool) { if type_.Kind() != reflect.Slice { return nil, false @@ -177,6 +182,84 @@ func (s primitiveListSerializer) ReadData(ctx *ReadContext, value reflect.Value) s.readValues(buf, err, value, length, hasNull) } +func (s compatiblePrimitiveListToArraySerializer) WriteData(ctx *WriteContext, value reflect.Value) { + ctx.SetError(SerializationErrorf("compatible list-to-array field serializer is read-only")) +} + +func (s compatiblePrimitiveListToArraySerializer) 
Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + ctx.SetError(SerializationErrorf("compatible list-to-array field serializer is read-only")) +} + +//go:noinline +func (s compatiblePrimitiveListToArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + done, typeID := readSliceRefAndType(ctx, refMode, readType, value) + if done || ctx.HasError() { + return + } + if readType && typeID != uint32(LIST) { + ctx.SetError(DeserializationErrorf("array-compatible list type mismatch: expected LIST (%d), got %d", LIST, typeID)) + return + } + s.ReadData(ctx, value) +} + +func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { + buf := ctx.Buffer() + err := ctx.Err() + length := ctx.ReadCollectionLength() + if ctx.HasError() { + return + } + if value.Kind() != reflect.Slice && length != value.Len() { + ctx.SetError(DeserializationErrorf("array length %d does not match serialized list length %d", value.Len(), length)) + return + } + if length == 0 { + if value.Kind() == reflect.Slice { + value.Set(reflect.MakeSlice(value.Type(), 0, 0)) + } else if value.Len() != 0 { + ctx.SetError(DeserializationErrorf("array-compatible list length %d does not match array length %d", length, value.Len())) + } + return + } + if value.Kind() == reflect.Array && length != value.Len() { + ctx.SetError(DeserializationErrorf("array-compatible list length %d does not match array length %d", length, value.Len())) + return + } + collectFlag := buf.ReadInt8(err) + if (collectFlag & CollectionIsSameType) != 0 { + if (collectFlag & CollectionIsDeclElementType) == 0 { + ctx.TypeResolver().ReadTypeInfo(buf, err) + } + } + if (collectFlag & CollectionTrackingRef) != 0 { + ctx.SetError(DeserializationErrorf("array-compatible list does not support reference-tracked elements")) + return + } + if (collectFlag & CollectionHasNull) != 0 { + 
ctx.SetError(DeserializationErrorf("compatible list to array field requires non-null elements")) + return + } + if (collectFlag & (CollectionIsSameType | CollectionIsDeclElementType)) != (CollectionIsSameType | CollectionIsDeclElementType) { + ctx.SetError(DeserializationErrorf("array-compatible list requires declared same-type elements")) + return + } + if value.Kind() == reflect.Slice { + temp := reflect.New(value.Type()).Elem() + s.listReader.readValues(buf, err, temp, length, false) + if ctx.HasError() { + return + } + value.Set(temp) + return + } + s.listReader.readArrayValues(buf, err, value, length) +} + +func (s compatiblePrimitiveListToArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} + func (s primitiveListSerializer) readValues(buf *ByteBuffer, err *Error, value reflect.Value, length int, hasNull bool) { switch s.type_.Elem().Kind() { case reflect.Bool: @@ -208,6 +291,104 @@ func (s primitiveListSerializer) readValues(buf *ByteBuffer, err *Error, value r } } +func (s primitiveListSerializer) readArrayValues(buf *ByteBuffer, err *Error, value reflect.Value, length int) { + switch s.type_.Elem().Kind() { + case reflect.Bool: + raw := buf.ReadBinary(length, err) + for i := 0; i < length; i++ { + value.Index(i).SetBool(raw[i] != 0) + } + case reflect.Int8: + raw := buf.ReadBinary(length, err) + for i := 0; i < length; i++ { + value.Index(i).SetInt(int64(int8(raw[i]))) + } + case reflect.Uint8: + raw := buf.ReadBinary(length, err) + for i := 0; i < length; i++ { + value.Index(i).SetUint(uint64(raw[i])) + } + case reflect.Int16: + for i := 0; i < length; i++ { + value.Index(i).SetInt(int64(buf.ReadInt16(err))) + } + case reflect.Uint16: + for i := 0; i < length; i++ { + value.Index(i).SetUint(uint64(uint16(buf.ReadInt16(err)))) + } + case reflect.Int32: + for i := 0; i < length; i++ { + if s.elemTypeID == INT32 { + 
value.Index(i).SetInt(int64(buf.ReadInt32(err))) + } else { + value.Index(i).SetInt(int64(buf.ReadVarint32(err))) + } + } + case reflect.Uint32: + for i := 0; i < length; i++ { + if s.elemTypeID == UINT32 { + value.Index(i).SetUint(uint64(uint32(buf.ReadInt32(err)))) + } else { + value.Index(i).SetUint(uint64(buf.ReadVarUint32(err))) + } + } + case reflect.Int64: + for i := 0; i < length; i++ { + switch s.elemTypeID { + case INT64: + value.Index(i).SetInt(buf.ReadInt64(err)) + case TAGGED_INT64: + value.Index(i).SetInt(buf.ReadTaggedInt64(err)) + default: + value.Index(i).SetInt(buf.ReadVarint64(err)) + } + } + case reflect.Uint64: + for i := 0; i < length; i++ { + switch s.elemTypeID { + case UINT64: + value.Index(i).SetUint(uint64(buf.ReadInt64(err))) + case TAGGED_UINT64: + value.Index(i).SetUint(buf.ReadTaggedUint64(err)) + default: + value.Index(i).SetUint(buf.ReadVarUint64(err)) + } + } + case reflect.Int: + for i := 0; i < length; i++ { + if s.elemTypeID == INT32 { + value.Index(i).SetInt(int64(buf.ReadInt32(err))) + } else if s.elemTypeID == INT64 { + value.Index(i).SetInt(buf.ReadInt64(err)) + } else if reflect.TypeOf(int(0)).Size() == 8 { + value.Index(i).SetInt(buf.ReadVarint64(err)) + } else { + value.Index(i).SetInt(int64(buf.ReadVarint32(err))) + } + } + case reflect.Uint: + for i := 0; i < length; i++ { + if s.elemTypeID == UINT32 { + value.Index(i).SetUint(uint64(uint32(buf.ReadInt32(err)))) + } else if s.elemTypeID == UINT64 { + value.Index(i).SetUint(uint64(buf.ReadInt64(err))) + } else if reflect.TypeOf(uint(0)).Size() == 8 { + value.Index(i).SetUint(buf.ReadVarUint64(err)) + } else { + value.Index(i).SetUint(uint64(buf.ReadVarUint32(err))) + } + } + case reflect.Float32: + for i := 0; i < length; i++ { + value.Index(i).SetFloat(float64(buf.ReadFloat32(err))) + } + case reflect.Float64: + for i := 0; i < length; i++ { + value.Index(i).SetFloat(buf.ReadFloat64(err)) + } + } +} + func writeBoolListPayload(buf *ByteBuffer, value []bool) { if 
len(value) > 0 { buf.WriteBinary(unsafe.Slice((*byte)(unsafe.Pointer(&value[0])), len(value))) diff --git a/go/fory/struct_init.go b/go/fory/struct_init.go index e31cef8929..b64ae7d6fe 100644 --- a/go/fory/struct_init.go +++ b/go/fory/struct_init.go @@ -432,6 +432,7 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err if exists { shouldRead := false + usesCompatibleCollectionArrayReader := false isPolymorphicField := def.typeSpec.TypeId() == UNKNOWN defTypeId := def.typeSpec.TypeId() internalDefTypeId := defTypeId @@ -521,13 +522,37 @@ func (s *structSerializer) initFieldsFromTypeDef(typeResolver *TypeResolver) err } else if defTypeId == SET && isSetReflectType(localType) { shouldRead = true fieldType = localType + } else if defTypeId == LIST && localFieldSpec != nil && listFieldCanReadLocalArray( + def.typeSpec, + def.nullable, + def.trackRef, + localFieldSpec.Type, + localNullableByIndex[fieldIndex], + localTrackRefByIndex[fieldIndex], + localType, + ) { + shouldRead = true + fieldType = localType + } else if defTypeId == LIST && localFieldSpec != nil && compatibleListFieldCanReadLocalArray(def.typeSpec, localFieldSpec.Type, localType) { + shouldRead = true + usesCompatibleCollectionArrayReader = true + fieldType = localType + sliceType := reflect.SliceOf(localType.Elem()) + if listReader, ok := newPrimitiveListSerializer(sliceType, def.typeSpec.Element.TypeID); ok { + fieldSerializer = compatiblePrimitiveListToArraySerializer{ + arrayType: localType, + listReader: listReader.(primitiveListSerializer), + } + } + } else if defTypeId == LIST && localType.Kind() == reflect.Array { + shouldRead = false } else if !typeLookupFailed && typesCompatible(localType, remoteType) { shouldRead = true fieldType = localType } if shouldRead { - if localType != nil { + if localType != nil && !usesCompatibleCollectionArrayReader { localSerializer, localErr := serializerForTypeSpec(typeResolver, localType, def.typeSpec) if localErr == nil && 
localSerializer != nil { fieldSerializer = localSerializer @@ -775,6 +800,72 @@ func fieldSpecEqualForDiff(remoteSpec *TypeSpec, remoteNullable bool, remoteTrac return remote.EqualForDiff(local) } +func listFieldCanReadLocalArray(remoteSpec *TypeSpec, remoteNullable bool, remoteTrackRef bool, localSpec *TypeSpec, localNullable bool, localTrackRef bool, localType reflect.Type) bool { + if localType == nil || localType.Kind() != reflect.Array || remoteSpec == nil || localSpec == nil { + return false + } + if remoteSpec.TypeID != LIST || localSpec.TypeID != LIST { + return false + } + return fieldSpecEqualForDiff(remoteSpec, remoteNullable, remoteTrackRef, localSpec, localNullable, localTrackRef) +} + +func compatibleListFieldCanReadLocalArray(remoteSpec *TypeSpec, localSpec *TypeSpec, localType reflect.Type) bool { + if remoteSpec == nil || localSpec == nil || localType == nil { + return false + } + if localType.Kind() != reflect.Array && localType.Kind() != reflect.Slice { + return false + } + if remoteSpec.TypeID != LIST || remoteSpec.Element == nil { + return false + } + if !isPrimitiveArrayType(localSpec.TypeID) { + return false + } + remoteSpec.normalizeChildren() + localSpec.normalizeChildren() + if _, ok := primitiveArrayElementTypeID(localSpec.TypeID); !ok { + return false + } + sliceType := reflect.SliceOf(localType.Elem()) + _, ok := newPrimitiveListSerializer(sliceType, remoteSpec.Element.TypeID) + return ok +} + +func primitiveArrayElementTypeID(arrayTypeID TypeId) (TypeId, bool) { + switch arrayTypeID { + case BOOL_ARRAY: + return BOOL, true + case INT8_ARRAY: + return INT8, true + case INT16_ARRAY: + return INT16, true + case INT32_ARRAY: + return INT32, true + case INT64_ARRAY: + return INT64, true + case UINT8_ARRAY: + return UINT8, true + case UINT16_ARRAY: + return UINT16, true + case UINT32_ARRAY: + return UINT32, true + case UINT64_ARRAY: + return UINT64, true + case FLOAT32_ARRAY: + return FLOAT32, true + case FLOAT64_ARRAY: + return FLOAT64, 
true + case FLOAT16_ARRAY: + return FLOAT16, true + case BFLOAT16_ARRAY: + return BFLOAT16, true + default: + return UNKNOWN, false + } +} + func typeIdEqualForDiff(remoteTypeId TypeId, localTypeId TypeId) bool { if remoteTypeId == localTypeId { return true diff --git a/go/fory/tests/xlang/xlang_test_main.go b/go/fory/tests/xlang/xlang_test_main.go index 9708cb6b35..c6342a9e6b 100644 --- a/go/fory/tests/xlang/xlang_test_main.go +++ b/go/fory/tests/xlang/xlang_test_main.go @@ -319,6 +319,18 @@ type ReducedPrecisionFloatStruct struct { Bfloat16Array []bfloat16.BFloat16 } +type CompatibleInt32ListField struct { + Values []int32 `fory:"id=1,type=list(element=int32(nullable=false,encoding=fixed))"` +} + +type CompatibleNullableInt32ListField struct { + Values []*int32 `fory:"id=1,type=list(element=int32(nullable=true,encoding=fixed))"` +} + +type CompatibleInt32ArrayField struct { + Values []int32 `fory:"id=1,type=array(element=int32)"` +} + type StructWithList struct { Items []string } @@ -1598,6 +1610,59 @@ func testReducedPrecisionFloatStructCompatibleSkip() { writeFile(dataFile, serialized) } +func testListArrayCompatibleListToArray() { + dataFile := getDataFile() + data := readFile(dataFile) + + f := fory.New(fory.WithXlang(true), fory.WithCompatible(true)) + f.RegisterStruct(CompatibleInt32ArrayField{}, 901) + + var result CompatibleInt32ArrayField + if err := f.Deserialize(data, &result); err != nil { + panic(fmt.Sprintf("Failed to deserialize compatible list as array: %v", err)) + } + serialized, err := f.Serialize(&result) + if err != nil { + panic(fmt.Sprintf("Failed to serialize compatible array field: %v", err)) + } + + writeFile(dataFile, serialized) +} + +func testListArrayCompatibleArrayToList() { + dataFile := getDataFile() + data := readFile(dataFile) + + f := fory.New(fory.WithXlang(true), fory.WithCompatible(true)) + f.RegisterStruct(CompatibleInt32ListField{}, 901) + + var result CompatibleInt32ListField + if err := f.Deserialize(data, &result); err 
!= nil { + panic(fmt.Sprintf("Failed to deserialize compatible array as list: %v", err)) + } + serialized, err := f.Serialize(&result) + if err != nil { + panic(fmt.Sprintf("Failed to serialize compatible list field: %v", err)) + } + + writeFile(dataFile, serialized) +} + +func testListArrayCompatibleNullableListToArrayError() { + dataFile := getDataFile() + data := readFile(dataFile) + + f := fory.New(fory.WithXlang(true), fory.WithCompatible(true)) + f.RegisterStruct(CompatibleInt32ArrayField{}, 901) + + var result CompatibleInt32ArrayField + if err := f.Deserialize(data, &result); err == nil { + panic("Expected nullable list payload to fail compatible array read") + } + + writeFile(dataFile, data) +} + // Enum field tests func testOneEnumFieldSchemaConsistent() { dataFile := getDataFile() @@ -2887,6 +2952,12 @@ func main() { testReducedPrecisionFloatStruct() case "test_reduced_precision_float_struct_compatible_skip": testReducedPrecisionFloatStructCompatibleSkip() + case "test_list_array_compatible_list_to_array": + testListArrayCompatibleListToArray() + case "test_list_array_compatible_array_to_list": + testListArrayCompatibleArrayToList() + case "test_list_array_compatible_nullable_list_to_array_error": + testListArrayCompatibleNullableListToArrayError() case "test_one_enum_field_schema": testOneEnumFieldSchemaConsistent() case "test_one_enum_field_compatible": diff --git a/integration_tests/idl_tests/javascript/roundtrip.ts b/integration_tests/idl_tests/javascript/roundtrip.ts index 4b77e8fc41..e832c17b69 100644 --- a/integration_tests/idl_tests/javascript/roundtrip.ts +++ b/integration_tests/idl_tests/javascript/roundtrip.ts @@ -29,7 +29,14 @@ import * as assert from "assert/strict"; import * as fs from "fs"; -import Fory, { BFloat16, Decimal, Type } from "@apache-fory/core"; +import Fory, { + BFloat16, + BFloat16Array, + BoolArray, + Decimal, + ForyFloat16Array, + Type, +} from "@apache-fory/core"; import { AnyHelper } from 
"@apache-fory/core/dist/lib/gen/any"; import { ConfigFlags, @@ -183,6 +190,15 @@ function normalizeAcyclic(value: unknown): unknown { if (value instanceof BFloat16) { return value.toFloat32(); } + if ( + value instanceof BoolArray || + value instanceof ForyFloat16Array || + value instanceof BFloat16Array + ) { + return Array.from(value as Iterable, (item) => + normalizeAcyclic(item), + ); + } if (value instanceof Date) { return { __dateMs: value.getTime() }; } diff --git a/integration_tests/idl_tests/javascript/test/roundtrip.test.ts b/integration_tests/idl_tests/javascript/test/roundtrip.test.ts index 2d3a3b2764..e4fde53575 100644 --- a/integration_tests/idl_tests/javascript/test/roundtrip.test.ts +++ b/integration_tests/idl_tests/javascript/test/roundtrip.test.ts @@ -17,7 +17,14 @@ * under the License. */ -import Fory, { BFloat16, Decimal, Type } from "@apache-fory/core"; +import Fory, { + BFloat16, + BFloat16Array, + BoolArray, + Decimal, + ForyFloat16Array, + Type, +} from "@apache-fory/core"; import { AddressBook, Animal, @@ -128,6 +135,13 @@ function normalize(value: unknown): unknown { if (value instanceof BFloat16) { return value.toFloat32(); } + if ( + value instanceof BoolArray || + value instanceof ForyFloat16Array || + value instanceof BFloat16Array + ) { + return Array.from(value as Iterable, (item) => normalize(item)); + } if (value instanceof Date) { return { __dateMs: value.getTime() }; } diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java index af035f359e..8353520ea4 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java @@ -2163,6 +2163,10 @@ private Expression deserializeForNotNullNoRef( protected Expression deserializeField( Expression buffer, Descriptor descriptor, Function callback) { + if 
(hasCompatibleCollectionArrayRead(descriptor)) { + Expression value = deserializeCompatibleListArrayField(descriptor); + return new ListExpression(value, callback.apply(value)); + } TypeRef typeRef = descriptor.getTypeRef(); boolean nullable = descriptor.isNullable(); // descriptor.isTrackingRef() already includes the needWriteRef check @@ -2213,6 +2217,38 @@ protected Expression deserializeField( } } + protected Expression deserializeCompatibleListArrayField(Descriptor descriptor) { + TypeExtMeta extMeta = descriptor.getTypeRef().getTypeExtMeta(); + boolean nullable = extMeta == null ? descriptor.isNullable() : extMeta.nullable(); + boolean trackingRef = extMeta == null ? descriptor.isTrackingRef() : extMeta.trackingRef(); + Class targetType = + descriptor.getField() == null ? descriptor.getRawType() : descriptor.getField().getType(); + return new StaticInvoke( + MetaSharedSerializer.class, + "readCompatibleCollectionArrayField", + OBJECT_TYPE, + readContextRef(), + Literal.ofBoolean(trackingRef), + Literal.ofBoolean(nullable), + Literal.ofInt( + MetaSharedSerializer.compatibleCollectionArrayReadMode(typeResolver, descriptor)), + Literal.ofInt( + MetaSharedSerializer.compatibleCollectionArrayTypeId(typeResolver, descriptor)), + Literal.ofInt( + MetaSharedSerializer.compatibleCollectionElementTypeId(typeResolver, descriptor)), + Literal.ofClass(targetType)); + } + + protected boolean hasCompatibleCollectionArrayRead(Descriptor descriptor) { + return MetaSharedSerializer.hasCompatibleCollectionArrayRead(typeResolver, descriptor); + } + + protected TypeRef compatibleReadTargetTypeRef(Descriptor descriptor) { + return descriptor.getField() == null + ? 
descriptor.getTypeRef() + : TypeRef.of(descriptor.getField().getGenericType()); + } + private Expression deserializeForNotNullForField( Expression buffer, Descriptor descriptor, Expression serializer) { TypeRef typeRef = descriptor.getTypeRef(); diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java index 0603f97133..05e4dda534 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java @@ -673,6 +673,10 @@ protected Expression deserializeGroup( for (Descriptor d : group) { ExpressionVisitor.ExprHolder exprHolder = ExpressionVisitor.ExprHolder.of("bean", bean); walkPath.add(d.getDeclaringClass() + d.getName()); + TypeRef castTypeRef = + hasCompatibleCollectionArrayRead(d) + ? compatibleReadTargetTypeRef(d) + : d.getTypeRef(); Expression action = deserializeField( buffer, @@ -680,8 +684,7 @@ protected Expression deserializeGroup( // `bean` will be replaced by `Reference` to cut-off expr // dependency. expr -> - setFieldValue( - exprHolder.get("bean"), d, tryInlineCast(expr, d.getTypeRef()))); + setFieldValue(exprHolder.get("bean"), d, tryInlineCast(expr, castTypeRef))); walkPath.removeLast(); groupExpressions.add(action); } @@ -701,8 +704,14 @@ protected Expression deserializeGroupForRecord( // use Reference to cut-off expr dependency. for (Descriptor d : group) { boolean nullable = d.isNullable(); - Expression v = deserializeForNullableField(buffer, d, expr -> expr, nullable); - Expression action = setFieldValue(bean, d, tryInlineCast(v, d.getTypeRef())); + boolean compatibleCollectionArrayRead = hasCompatibleCollectionArrayRead(d); + Expression v = + compatibleCollectionArrayRead + ? 
deserializeCompatibleListArrayField(d) + : deserializeForNullableField(buffer, d, expr -> expr, nullable); + TypeRef castTypeRef = + compatibleCollectionArrayRead ? compatibleReadTargetTypeRef(d) : d.getTypeRef(); + Expression action = setFieldValue(bean, d, tryInlineCast(v, castTypeRef)); groupExpressions.add(action); } return groupExpressions; diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/FieldInfo.java b/java/fory-core/src/main/java/org/apache/fory/meta/FieldInfo.java index 980321b324..70494f2e79 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/FieldInfo.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/FieldInfo.java @@ -29,6 +29,7 @@ import org.apache.fory.serializer.converter.FieldConverters; import org.apache.fory.type.Descriptor; import org.apache.fory.type.DescriptorBuilder; +import org.apache.fory.type.TypeAnnotationUtils; import org.apache.fory.type.Types; import org.apache.fory.util.StringUtils; @@ -103,6 +104,48 @@ Descriptor toDescriptor(TypeResolver resolver, Descriptor descriptor) { boolean remoteTrackingRef = fieldType.trackingRef(); if (descriptor != null) { + Class rawType = typeRef.getRawType(); + DescriptorBuilder builder = + new DescriptorBuilder(descriptor) + .typeName(typeName) + .trackingRef(remoteTrackingRef) + .nullable(remoteNullable) + .typeRef(typeRef) + .type(rawType); + if (isTopLevelListArrayCompatibleReadPair(resolver, descriptor)) { + if (listElementTypeId(fieldType) != Types.UNKNOWN) { + TypeRef peerListTypeRef = fieldType.toTypeToken(resolver, null); + return builder + .typeName(fieldType.getTypeName(resolver, peerListTypeRef)) + .typeRef(peerListTypeRef) + .type(peerListTypeRef.getRawType()) + .build(); + } + return builder.build(); + } + FieldTypes.FieldType localFieldType = + descriptor.getField() == null + ? 
null + : FieldTypes.buildFieldType(resolver, descriptor.getField()); + int peerArrayTypeId = arrayTypeId(fieldType); + if (peerArrayTypeId != Types.UNKNOWN + && peerArrayTypeId == arrayTypeId(localFieldType) + && TypeAnnotationUtils.isArrayType(descriptor)) { + return new DescriptorBuilder(descriptor) + .trackingRef(remoteTrackingRef) + .nullable(remoteNullable) + .build(); + } + if (localFieldType != null && hasListArrayShapeMismatch(fieldType, localFieldType)) { + throw new IllegalArgumentException( + StringUtils.format( + "Unsupported nested list/array compatible field mismatch for field " + + "{}.{}: peer={}, local={}", + definedClass, + fieldName, + fieldType, + localFieldType)); + } if (remoteNullable == descriptor.isNullable() && remoteTrackingRef == descriptor.isTrackingRef() && typeRef.equals(descriptor.getTypeRef())) { @@ -110,7 +153,6 @@ Descriptor toDescriptor(TypeResolver resolver, Descriptor descriptor) { return descriptor; } } - Class rawType = typeRef.getRawType(); if (FieldTypes.useFieldType(rawType, descriptor)) { return new DescriptorBuilder(descriptor) .typeName(typeName) @@ -119,13 +161,6 @@ Descriptor toDescriptor(TypeResolver resolver, Descriptor descriptor) { .typeRef(typeRef) .build(); } - DescriptorBuilder builder = - new DescriptorBuilder(descriptor) - .typeName(typeName) - .trackingRef(remoteTrackingRef) - .nullable(remoteNullable) - .typeRef(typeRef) - .type(rawType); // Local field exists - check if we need to update nullable/trackingRef boolean typeMismatch = !typeRef.equals(declared); if (typeMismatch) { @@ -161,6 +196,129 @@ Descriptor toDescriptor(TypeResolver resolver, Descriptor descriptor) { remoteNullable); } + private boolean isTopLevelListArrayCompatibleReadPair( + TypeResolver resolver, Descriptor localDescriptor) { + Field localField = localDescriptor.getField(); + if (localField == null || !resolver.isCrossLanguage()) { + return false; + } + FieldTypes.FieldType localFieldType = FieldTypes.buildFieldType(resolver, 
localField); + int peerListElementTypeId = listElementTypeId(fieldType); + if (peerListElementTypeId != Types.UNKNOWN) { + int localArrayTypeId = arrayTypeId(localFieldType); + return localArrayTypeId != Types.UNKNOWN + && localArrayTypeId == denseArrayTypeId(peerListElementTypeId); + } + int peerArrayTypeId = arrayTypeId(fieldType); + if (peerArrayTypeId != Types.UNKNOWN) { + int localListElementTypeId = listElementTypeId(localFieldType); + return localListElementTypeId != Types.UNKNOWN + && peerArrayTypeId == denseArrayTypeId(localListElementTypeId); + } + return false; + } + + private static boolean hasListArrayShapeMismatch( + FieldTypes.FieldType peerFieldType, FieldTypes.FieldType localFieldType) { + boolean peerList = isListField(peerFieldType); + boolean localList = isListField(localFieldType); + boolean peerArray = arrayTypeId(peerFieldType) != Types.UNKNOWN; + boolean localArray = arrayTypeId(localFieldType) != Types.UNKNOWN; + if ((peerList && localArray) || (peerArray && localList)) { + return true; + } + if (peerFieldType.getTypeId() != localFieldType.getTypeId()) { + return false; + } + if (peerFieldType instanceof FieldTypes.CollectionFieldType + && localFieldType instanceof FieldTypes.CollectionFieldType) { + return hasListArrayShapeMismatch( + ((FieldTypes.CollectionFieldType) peerFieldType).getElementType(), + ((FieldTypes.CollectionFieldType) localFieldType).getElementType()); + } + if (peerFieldType instanceof FieldTypes.MapFieldType + && localFieldType instanceof FieldTypes.MapFieldType) { + FieldTypes.MapFieldType peerMap = (FieldTypes.MapFieldType) peerFieldType; + FieldTypes.MapFieldType localMap = (FieldTypes.MapFieldType) localFieldType; + return hasListArrayShapeMismatch(peerMap.getKeyType(), localMap.getKeyType()) + || hasListArrayShapeMismatch(peerMap.getValueType(), localMap.getValueType()); + } + if (peerFieldType instanceof FieldTypes.ArrayFieldType + && localFieldType instanceof FieldTypes.ArrayFieldType) { + return 
hasListArrayShapeMismatch( + ((FieldTypes.ArrayFieldType) peerFieldType).getComponentType(), + ((FieldTypes.ArrayFieldType) localFieldType).getComponentType()); + } + return false; + } + + private static boolean isListField(FieldTypes.FieldType fieldType) { + return fieldType instanceof FieldTypes.CollectionFieldType + && fieldType.getTypeId() == Types.LIST; + } + + private static int listElementTypeId(FieldTypes.FieldType fieldType) { + if (!(fieldType instanceof FieldTypes.CollectionFieldType) + || fieldType.getTypeId() != Types.LIST) { + return Types.UNKNOWN; + } + FieldTypes.FieldType elementType = + ((FieldTypes.CollectionFieldType) fieldType).getElementType(); + if (elementType instanceof FieldTypes.RegisteredFieldType) { + return ((FieldTypes.RegisteredFieldType) elementType).getTypeId(); + } + return Types.UNKNOWN; + } + + private static int arrayTypeId(FieldTypes.FieldType fieldType) { + if (fieldType instanceof FieldTypes.RegisteredFieldType) { + int typeId = ((FieldTypes.RegisteredFieldType) fieldType).getTypeId(); + if (Types.isPrimitiveArray(typeId)) { + return typeId; + } + } + return Types.UNKNOWN; + } + + private static int denseArrayTypeId(int elementTypeId) { + switch (elementTypeId) { + case Types.BOOL: + return Types.BOOL_ARRAY; + case Types.INT8: + return Types.INT8_ARRAY; + case Types.UINT8: + return Types.UINT8_ARRAY; + case Types.INT16: + return Types.INT16_ARRAY; + case Types.UINT16: + return Types.UINT16_ARRAY; + case Types.INT32: + case Types.VARINT32: + return Types.INT32_ARRAY; + case Types.UINT32: + case Types.VAR_UINT32: + return Types.UINT32_ARRAY; + case Types.INT64: + case Types.VARINT64: + case Types.TAGGED_INT64: + return Types.INT64_ARRAY; + case Types.UINT64: + case Types.VAR_UINT64: + case Types.TAGGED_UINT64: + return Types.UINT64_ARRAY; + case Types.FLOAT16: + return Types.FLOAT16_ARRAY; + case Types.BFLOAT16: + return Types.BFLOAT16_ARRAY; + case Types.FLOAT32: + return Types.FLOAT32_ARRAY; + case Types.FLOAT64: + return 
Types.FLOAT64_ARRAY; + default: + return Types.UNKNOWN; + } + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java index 39e70a5431..4f4d07eb34 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java @@ -76,6 +76,7 @@ import org.apache.fory.memory.Platform; import org.apache.fory.meta.EncodedMetaString; import org.apache.fory.meta.TypeDef; +import org.apache.fory.meta.TypeExtMeta; import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.serializer.ArraySerializers; import org.apache.fory.serializer.BigIntegerSerializer; @@ -1292,6 +1293,10 @@ public boolean usesPrimitiveFieldOrdering(Descriptor descriptor) { } private byte getInternalTypeId(Descriptor descriptor) { + TypeExtMeta extMeta = descriptor.getTypeRef().getTypeExtMeta(); + if (extMeta != null && extMeta.typeId() != Types.UNKNOWN) { + return (byte) extMeta.typeId(); + } if (TypeAnnotationUtils.isBoxedListArrayType(descriptor.getField())) { return (byte) TypeAnnotationUtils.getBoxedListArrayTypeId(descriptor.getField()); } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java new file mode 100644 index 0000000000..3828d4b47c --- /dev/null +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java @@ -0,0 +1,682 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.serializer; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.List; +import org.apache.fory.Fory; +import org.apache.fory.collection.BFloat16List; +import org.apache.fory.collection.BoolList; +import org.apache.fory.collection.Float16List; +import org.apache.fory.collection.Float32List; +import org.apache.fory.collection.Float64List; +import org.apache.fory.collection.Int16List; +import org.apache.fory.collection.Int32List; +import org.apache.fory.collection.Int64List; +import org.apache.fory.collection.Int8List; +import org.apache.fory.collection.UInt16List; +import org.apache.fory.collection.UInt32List; +import org.apache.fory.collection.UInt64List; +import org.apache.fory.collection.UInt8List; +import org.apache.fory.config.Config; +import org.apache.fory.context.ReadContext; +import org.apache.fory.context.RefReader; +import org.apache.fory.exception.DeserializationException; +import org.apache.fory.memory.MemoryBuffer; +import org.apache.fory.memory.Platform; +import org.apache.fory.meta.FieldTypes; +import org.apache.fory.meta.TypeExtMeta; +import org.apache.fory.reflect.TypeRef; +import org.apache.fory.resolver.RefMode; +import org.apache.fory.resolver.TypeResolver; +import org.apache.fory.serializer.collection.CollectionFlags; +import org.apache.fory.type.BFloat16; +import org.apache.fory.type.BFloat16Array; 
+import org.apache.fory.type.Descriptor; +import org.apache.fory.type.Float16; +import org.apache.fory.type.Float16Array; +import org.apache.fory.type.TypeUtils; +import org.apache.fory.type.Types; + +final class CompatibleCollectionArrayReader { + static final int READ_LIST_TO_ARRAY = 1; + static final int READ_ARRAY_TO_LIST = 2; + + static final class ReadAction { + final int mode; + final int arrayTypeId; + final int elementTypeId; + final Class targetType; + + private ReadAction(int mode, int arrayTypeId, int elementTypeId, Class targetType) { + this.mode = mode; + this.arrayTypeId = arrayTypeId; + this.elementTypeId = elementTypeId; + this.targetType = targetType; + } + } + + private CompatibleCollectionArrayReader() {} + + static ReadAction readAction(TypeResolver resolver, Descriptor descriptor) { + Field field = descriptor.getField(); + if (field == null || !resolver.isCrossLanguage()) { + return null; + } + FieldTypes.FieldType localFieldType = FieldTypes.buildFieldType(resolver, field); + int peerListElementTypeId = listElementTypeId(descriptor.getTypeRef()); + if (peerListElementTypeId != Types.UNKNOWN) { + int localArrayTypeId = arrayTypeId(localFieldType); + if (localArrayTypeId != Types.UNKNOWN + && localArrayTypeId == denseArrayTypeId(peerListElementTypeId)) { + return new ReadAction( + READ_LIST_TO_ARRAY, localArrayTypeId, peerListElementTypeId, field.getType()); + } + return null; + } + int peerArrayTypeId = arrayTypeId(descriptor.getTypeRef()); + if (peerArrayTypeId != Types.UNKNOWN) { + int localListElementTypeId = listElementTypeId(localFieldType); + if (localListElementTypeId != Types.UNKNOWN + && peerArrayTypeId == denseArrayTypeId(localListElementTypeId)) { + return new ReadAction( + READ_ARRAY_TO_LIST, peerArrayTypeId, localListElementTypeId, field.getType()); + } + } + return null; + } + + static Object read(ReadContext readContext, RefMode refMode, ReadAction action) { + return read( + readContext, + refMode, + action.mode, + 
action.arrayTypeId, + action.elementTypeId, + action.targetType); + } + + static Object read( + ReadContext readContext, + RefMode refMode, + int readMode, + int arrayTypeId, + int elementTypeId, + Class targetType) { + switch (refMode) { + case NONE: + return readNotNull(readContext, readMode, arrayTypeId, elementTypeId, targetType); + case NULL_ONLY: + if (readContext.getBuffer().readByte() == Fory.NULL_FLAG) { + return null; + } + return readNotNull(readContext, readMode, arrayTypeId, elementTypeId, targetType); + case TRACKING: + return readTracking(readContext, readMode, arrayTypeId, elementTypeId, targetType); + default: + throw new IllegalStateException("Unknown refMode: " + refMode); + } + } + + private static Object readTracking( + ReadContext readContext, + int readMode, + int arrayTypeId, + int elementTypeId, + Class targetType) { + RefReader refReader = readContext.getRefReader(); + int nextReadRefId = refReader.tryPreserveRefId(readContext.getBuffer()); + if (nextReadRefId >= Fory.NOT_NULL_VALUE_FLAG) { + Object value = readNotNull(readContext, readMode, arrayTypeId, elementTypeId, targetType); + refReader.setReadRef(nextReadRefId, value); + return value; + } + return refReader.getReadRef(); + } + + private static Object readNotNull( + ReadContext readContext, + int readMode, + int arrayTypeId, + int elementTypeId, + Class targetType) { + if (readMode == READ_LIST_TO_ARRAY) { + Object array = readListPayloadAsPrimitiveArray(readContext, arrayTypeId, elementTypeId); + if (array == null) { + return null; + } + return materializeTarget(array, arrayTypeId, targetType); + } + if (readMode == READ_ARRAY_TO_LIST) { + Object array = readDenseArrayPayload(readContext, arrayTypeId); + return materializeTarget(array, arrayTypeId, targetType); + } + throw new IllegalStateException("Unexpected compatible read mode " + readMode); + } + + private static int listElementTypeId(FieldTypes.FieldType fieldType) { + if (!(fieldType instanceof 
FieldTypes.CollectionFieldType) + || fieldType.getTypeId() != Types.LIST) { + return Types.UNKNOWN; + } + FieldTypes.FieldType elementType = + ((FieldTypes.CollectionFieldType) fieldType).getElementType(); + if (elementType instanceof FieldTypes.RegisteredFieldType) { + return ((FieldTypes.RegisteredFieldType) elementType).getTypeId(); + } + return Types.UNKNOWN; + } + + private static int listElementTypeId(TypeRef typeRef) { + TypeExtMeta extMeta = typeRef.getTypeExtMeta(); + if (extMeta == null || extMeta.typeId() != Types.LIST) { + return Types.UNKNOWN; + } + TypeExtMeta elementExtMeta = TypeUtils.getElementType(typeRef).getTypeExtMeta(); + return elementExtMeta == null ? Types.UNKNOWN : elementExtMeta.typeId(); + } + + private static int arrayTypeId(FieldTypes.FieldType fieldType) { + if (fieldType instanceof FieldTypes.RegisteredFieldType) { + int typeId = ((FieldTypes.RegisteredFieldType) fieldType).getTypeId(); + if (Types.isPrimitiveArray(typeId)) { + return typeId; + } + } + return Types.UNKNOWN; + } + + private static int arrayTypeId(TypeRef typeRef) { + TypeExtMeta extMeta = typeRef.getTypeExtMeta(); + if (extMeta != null && Types.isPrimitiveArray(extMeta.typeId())) { + return extMeta.typeId(); + } + return Types.UNKNOWN; + } + + private static int denseArrayTypeId(int elementTypeId) { + switch (elementTypeId) { + case Types.BOOL: + return Types.BOOL_ARRAY; + case Types.INT8: + return Types.INT8_ARRAY; + case Types.UINT8: + return Types.UINT8_ARRAY; + case Types.INT16: + return Types.INT16_ARRAY; + case Types.UINT16: + return Types.UINT16_ARRAY; + case Types.INT32: + case Types.VARINT32: + return Types.INT32_ARRAY; + case Types.UINT32: + case Types.VAR_UINT32: + return Types.UINT32_ARRAY; + case Types.INT64: + case Types.VARINT64: + case Types.TAGGED_INT64: + return Types.INT64_ARRAY; + case Types.UINT64: + case Types.VAR_UINT64: + case Types.TAGGED_UINT64: + return Types.UINT64_ARRAY; + case Types.FLOAT16: + return Types.FLOAT16_ARRAY; + case 
Types.BFLOAT16: + return Types.BFLOAT16_ARRAY; + case Types.FLOAT32: + return Types.FLOAT32_ARRAY; + case Types.FLOAT64: + return Types.FLOAT64_ARRAY; + default: + return Types.UNKNOWN; + } + } + + private static Object readListPayloadAsPrimitiveArray( + ReadContext readContext, int arrayTypeId, int elementTypeId) { + MemoryBuffer buffer = readContext.getBuffer(); + int numElements = buffer.readVarUInt32Small7(); + validateElementCount(readContext.getConfig(), numElements); + validateElementStorageSize(readContext.getConfig(), numElements, elementSize(arrayTypeId)); + if (numElements > 0) { + int flags = buffer.readByte(); + boolean hasNull = (flags & CollectionFlags.HAS_NULL) == CollectionFlags.HAS_NULL; + boolean trackingRef = (flags & CollectionFlags.TRACKING_REF) == CollectionFlags.TRACKING_REF; + boolean sameType = (flags & CollectionFlags.IS_SAME_TYPE) == CollectionFlags.IS_SAME_TYPE; + boolean declared = + (flags & CollectionFlags.IS_DECL_ELEMENT_TYPE) == CollectionFlags.IS_DECL_ELEMENT_TYPE; + if (hasNull || trackingRef) { + throw new DeserializationException( + "Cannot read nullable or ref-tracked peer list payload into local array field"); + } + if (!sameType || !declared) { + throw new DeserializationException( + "Cannot read peer list payload into local array field"); + } + } + return readListPrimitiveElements(buffer, numElements, arrayTypeId, elementTypeId); + } + + private static Object readDenseArrayPayload(ReadContext readContext, int arrayTypeId) { + MemoryBuffer buffer = readContext.getBuffer(); + int byteSize = buffer.readVarUInt32Small7(); + int elemSize = elementSize(arrayTypeId); + validateBinarySize(readContext.getConfig(), buffer, byteSize, elemSize); + return readPrimitiveElements(buffer, byteSize / elemSize, arrayTypeId); + } + + private static Object readPrimitiveElements( + MemoryBuffer buffer, int numElements, int arrayTypeId) { + switch (arrayTypeId) { + case Types.BOOL_ARRAY: + { + boolean[] values = new boolean[numElements]; + 
buffer.readToUnsafe(values, Platform.BOOLEAN_ARRAY_OFFSET, numElements); + return values; + } + case Types.INT8_ARRAY: + case Types.UINT8_ARRAY: + { + byte[] values = new byte[numElements]; + buffer.readToUnsafe(values, Platform.BYTE_ARRAY_OFFSET, numElements); + return values; + } + case Types.INT16_ARRAY: + case Types.UINT16_ARRAY: + case Types.FLOAT16_ARRAY: + case Types.BFLOAT16_ARRAY: + { + short[] values = new short[numElements]; + if (Platform.IS_LITTLE_ENDIAN) { + buffer.readToUnsafe(values, Platform.SHORT_ARRAY_OFFSET, numElements * 2); + } else { + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readInt16(); + } + } + return values; + } + case Types.INT32_ARRAY: + case Types.UINT32_ARRAY: + { + int[] values = new int[numElements]; + if (Platform.IS_LITTLE_ENDIAN) { + buffer.readToUnsafe(values, Platform.INT_ARRAY_OFFSET, numElements * 4); + } else { + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readInt32(); + } + } + return values; + } + case Types.INT64_ARRAY: + case Types.UINT64_ARRAY: + { + long[] values = new long[numElements]; + if (Platform.IS_LITTLE_ENDIAN) { + buffer.readToUnsafe(values, Platform.LONG_ARRAY_OFFSET, numElements * 8); + } else { + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readInt64(); + } + } + return values; + } + case Types.FLOAT32_ARRAY: + { + float[] values = new float[numElements]; + if (Platform.IS_LITTLE_ENDIAN) { + buffer.readToUnsafe(values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4); + } else { + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readFloat32(); + } + } + return values; + } + case Types.FLOAT64_ARRAY: + { + double[] values = new double[numElements]; + if (Platform.IS_LITTLE_ENDIAN) { + buffer.readToUnsafe(values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8); + } else { + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readFloat64(); + } + } + return values; + } + default: + throw new IllegalArgumentException("Unsupported dense 
array type id " + arrayTypeId); + } + } + + private static Object readListPrimitiveElements( + MemoryBuffer buffer, int numElements, int arrayTypeId, int elementTypeId) { + switch (elementTypeId) { + case Types.BOOL: + { + boolean[] values = new boolean[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readBoolean(); + } + return values; + } + case Types.INT8: + case Types.UINT8: + { + byte[] values = new byte[numElements]; + buffer.readBytes(values); + return values; + } + case Types.INT16: + case Types.UINT16: + case Types.FLOAT16: + case Types.BFLOAT16: + { + short[] values = new short[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readInt16(); + } + return values; + } + case Types.INT32: + case Types.UINT32: + { + int[] values = new int[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readInt32(); + } + return values; + } + case Types.VARINT32: + { + int[] values = new int[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readVarInt32(); + } + return values; + } + case Types.VAR_UINT32: + { + int[] values = new int[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readVarUInt32(); + } + return values; + } + case Types.INT64: + case Types.UINT64: + { + long[] values = new long[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readInt64(); + } + return values; + } + case Types.VARINT64: + { + long[] values = new long[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readVarInt64(); + } + return values; + } + case Types.TAGGED_INT64: + { + long[] values = new long[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readTaggedInt64(); + } + return values; + } + case Types.VAR_UINT64: + { + long[] values = new long[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readVarUInt64(); + } + return values; + } + case 
Types.TAGGED_UINT64: + { + long[] values = new long[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readTaggedUInt64(); + } + return values; + } + case Types.FLOAT32: + { + float[] values = new float[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readFloat32(); + } + return values; + } + case Types.FLOAT64: + { + double[] values = new double[numElements]; + for (int i = 0; i < numElements; i++) { + values[i] = buffer.readFloat64(); + } + return values; + } + default: + throw new DeserializationException( + "Unsupported peer list element type id " + + elementTypeId + + " for local array type id " + + arrayTypeId); + } + } + + private static Object materializeTarget(Object array, int arrayTypeId, Class targetType) { + if (targetType.isArray()) { + return array; + } + if (targetType == Float16Array.class) { + return Float16Array.wrapBits((short[]) array); + } + if (targetType == BFloat16Array.class) { + return BFloat16Array.wrapBits((short[]) array); + } + Object primitiveList = materializePrimitiveList(array, arrayTypeId, targetType); + if (primitiveList != null) { + return primitiveList; + } + if (List.class.isAssignableFrom(targetType)) { + return materializeBoxedList(array, arrayTypeId); + } + throw new DeserializationException("Unsupported compatible list/array target " + targetType); + } + + private static Object materializePrimitiveList( + Object array, int arrayTypeId, Class targetType) { + switch (arrayTypeId) { + case Types.BOOL_ARRAY: + return targetType == BoolList.class ? new BoolList((boolean[]) array) : null; + case Types.INT8_ARRAY: + return targetType == Int8List.class ? new Int8List((byte[]) array) : null; + case Types.UINT8_ARRAY: + return targetType == UInt8List.class ? new UInt8List((byte[]) array) : null; + case Types.INT16_ARRAY: + return targetType == Int16List.class ? new Int16List((short[]) array) : null; + case Types.UINT16_ARRAY: + return targetType == UInt16List.class ? 
new UInt16List((short[]) array) : null; + case Types.INT32_ARRAY: + return targetType == Int32List.class ? new Int32List((int[]) array) : null; + case Types.UINT32_ARRAY: + return targetType == UInt32List.class ? new UInt32List((int[]) array) : null; + case Types.INT64_ARRAY: + return targetType == Int64List.class ? new Int64List((long[]) array) : null; + case Types.UINT64_ARRAY: + return targetType == UInt64List.class ? new UInt64List((long[]) array) : null; + case Types.FLOAT16_ARRAY: + return targetType == Float16List.class ? new Float16List((short[]) array) : null; + case Types.BFLOAT16_ARRAY: + return targetType == BFloat16List.class ? new BFloat16List((short[]) array) : null; + case Types.FLOAT32_ARRAY: + return targetType == Float32List.class ? new Float32List((float[]) array) : null; + case Types.FLOAT64_ARRAY: + return targetType == Float64List.class ? new Float64List((double[]) array) : null; + default: + throw new IllegalArgumentException("Unsupported dense array type id " + arrayTypeId); + } + } + + private static List materializeBoxedList(Object array, int arrayTypeId) { + int size = java.lang.reflect.Array.getLength(array); + ArrayList list = new ArrayList<>(size); + switch (arrayTypeId) { + case Types.BOOL_ARRAY: + for (boolean value : (boolean[]) array) { + list.add(value); + } + break; + case Types.INT8_ARRAY: + for (byte value : (byte[]) array) { + list.add(value); + } + break; + case Types.UINT8_ARRAY: + for (byte value : (byte[]) array) { + list.add(value & 0xFF); + } + break; + case Types.INT16_ARRAY: + for (short value : (short[]) array) { + list.add(value); + } + break; + case Types.UINT16_ARRAY: + for (short value : (short[]) array) { + list.add(value & 0xFFFF); + } + break; + case Types.INT32_ARRAY: + for (int value : (int[]) array) { + list.add(value); + } + break; + case Types.UINT32_ARRAY: + for (int value : (int[]) array) { + list.add(Integer.toUnsignedLong(value)); + } + break; + case Types.INT64_ARRAY: + case Types.UINT64_ARRAY: + for 
(long value : (long[]) array) { + list.add(value); + } + break; + case Types.FLOAT16_ARRAY: + for (short value : (short[]) array) { + list.add(Float16.fromBits(value)); + } + break; + case Types.BFLOAT16_ARRAY: + for (short value : (short[]) array) { + list.add(BFloat16.fromBits(value)); + } + break; + case Types.FLOAT32_ARRAY: + for (float value : (float[]) array) { + list.add(value); + } + break; + case Types.FLOAT64_ARRAY: + for (double value : (double[]) array) { + list.add(value); + } + break; + default: + throw new IllegalArgumentException("Unsupported dense array type id " + arrayTypeId); + } + return list; + } + + private static int elementSize(int arrayTypeId) { + switch (arrayTypeId) { + case Types.BOOL_ARRAY: + case Types.INT8_ARRAY: + case Types.UINT8_ARRAY: + return 1; + case Types.INT16_ARRAY: + case Types.UINT16_ARRAY: + case Types.FLOAT16_ARRAY: + case Types.BFLOAT16_ARRAY: + return 2; + case Types.INT32_ARRAY: + case Types.UINT32_ARRAY: + case Types.FLOAT32_ARRAY: + return 4; + case Types.INT64_ARRAY: + case Types.UINT64_ARRAY: + case Types.FLOAT64_ARRAY: + return 8; + default: + throw new IllegalArgumentException("Unsupported dense array type id " + arrayTypeId); + } + } + + private static void validateElementCount(Config config, int numElements) { + if (numElements < 0) { + throw new DeserializationException("Collection size must be non-negative: " + numElements); + } + if (numElements > config.maxCollectionSize()) { + throw new DeserializationException( + "Collection size " + + numElements + + " exceeds max collection size " + + config.maxCollectionSize()); + } + } + + private static void validateElementStorageSize(Config config, int numElements, int elemSize) { + if (numElements > config.maxBinarySize() / elemSize) { + throw new DeserializationException( + "Binary payload size " + + ((long) numElements * elemSize) + + " exceeds max binary size " + + config.maxBinarySize()); + } + } + + private static void validateBinarySize( + Config config, 
MemoryBuffer buffer, int byteSize, int elemSize) { + if (byteSize < 0) { + throw new DeserializationException("Binary payload size must be non-negative: " + byteSize); + } + if (byteSize > config.maxBinarySize()) { + throw new DeserializationException( + "Binary payload size " + byteSize + " exceeds max binary size " + config.maxBinarySize()); + } + if (byteSize % elemSize != 0) { + throw new DeserializationException( + "Binary payload size " + byteSize + " is not aligned to element size " + elemSize); + } + if (byteSize > buffer.remaining()) { + buffer.checkReadableBytes(byteSize); + } + } +} diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/MetaSharedSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/MetaSharedSerializer.java index 404cceb818..c56b8782c8 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/MetaSharedSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/MetaSharedSerializer.java @@ -34,6 +34,7 @@ import org.apache.fory.meta.TypeDef; import org.apache.fory.reflect.FieldAccessor; import org.apache.fory.resolver.ClassResolver; +import org.apache.fory.resolver.RefMode; import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.FieldGroups.SerializationFieldInfo; import org.apache.fory.type.Descriptor; @@ -69,6 +70,10 @@ public class MetaSharedSerializer extends AbstractObjectSerializer { private final SerializationFieldInfo[] buildInFields; private final SerializationFieldInfo[] containerFields; private final SerializationFieldInfo[] otherFields; + private final CompatibleCollectionArrayReader.ReadAction[] buildInCompatibleReadActions; + private final CompatibleCollectionArrayReader.ReadAction[] containerCompatibleReadActions; + private final CompatibleCollectionArrayReader.ReadAction[] otherCompatibleReadActions; + private final boolean hasCompatibleCollectionArrayRead; private final RecordInfo recordInfo; private Serializer serializer; private 
final boolean hasDefaultValues; @@ -106,6 +111,16 @@ public MetaSharedSerializer(TypeResolver typeResolver, Class type, TypeDef ty buildInFields = fieldGroups.buildInFields; containerFields = fieldGroups.containerFields; otherFields = fieldGroups.userTypeFields; + buildInCompatibleReadActions = + buildCompatibleCollectionArrayReadActions(typeResolver, buildInFields); + containerCompatibleReadActions = + buildCompatibleCollectionArrayReadActions(typeResolver, containerFields); + otherCompatibleReadActions = + buildCompatibleCollectionArrayReadActions(typeResolver, otherFields); + hasCompatibleCollectionArrayRead = + buildInCompatibleReadActions != null + || containerCompatibleReadActions != null + || otherCompatibleReadActions != null; if (isRecord) { List fieldNames = descriptorGrouper.getSortedDescriptors().stream() @@ -140,6 +155,80 @@ public MetaSharedSerializer(TypeResolver typeResolver, Class type, TypeDef ty this.defaultValueFields = defaultValueFields; } + /** Used by generated meta-shared serializers for top-level list/array compatible field reads. */ + public static Object readCompatibleCollectionArrayField( + ReadContext readContext, + boolean trackingRef, + boolean nullable, + int readMode, + int arrayTypeId, + int elementTypeId, + Class targetType) { + return CompatibleCollectionArrayReader.read( + readContext, + RefMode.of(trackingRef, nullable), + readMode, + arrayTypeId, + elementTypeId, + targetType); + } + + /** Used by generated meta-shared serializers to cache a top-level list/array read action. */ + public static int compatibleCollectionArrayReadMode( + TypeResolver resolver, Descriptor descriptor) { + return requireCompatibleCollectionArrayReadAction(resolver, descriptor).mode; + } + + /** Used by generated meta-shared serializers to cache the dense array carrier type. 
*/ + public static int compatibleCollectionArrayTypeId(TypeResolver resolver, Descriptor descriptor) { + return requireCompatibleCollectionArrayReadAction(resolver, descriptor).arrayTypeId; + } + + /** Used by generated meta-shared serializers to cache the peer or local element type. */ + public static int compatibleCollectionElementTypeId( + TypeResolver resolver, Descriptor descriptor) { + return requireCompatibleCollectionArrayReadAction(resolver, descriptor).elementTypeId; + } + + /** Returns whether a descriptor has a top-level list/array compatible read action. */ + public static boolean hasCompatibleCollectionArrayRead( + TypeResolver resolver, Descriptor descriptor) { + return CompatibleCollectionArrayReader.readAction(resolver, descriptor) != null; + } + + private static CompatibleCollectionArrayReader.ReadAction + requireCompatibleCollectionArrayReadAction(TypeResolver resolver, Descriptor descriptor) { + CompatibleCollectionArrayReader.ReadAction action = + CompatibleCollectionArrayReader.readAction(resolver, descriptor); + if (action == null) { + throw new IllegalArgumentException( + "Descriptor has no top-level list/array compatible read action: " + descriptor); + } + return action; + } + + private static CompatibleCollectionArrayReader.ReadAction[] + buildCompatibleCollectionArrayReadActions( + TypeResolver resolver, SerializationFieldInfo[] fields) { + CompatibleCollectionArrayReader.ReadAction[] actions = null; + for (int i = 0; i < fields.length; i++) { + CompatibleCollectionArrayReader.ReadAction action = + CompatibleCollectionArrayReader.readAction(resolver, fields[i].descriptor); + if (action != null) { + if (actions == null) { + actions = new CompatibleCollectionArrayReader.ReadAction[fields.length]; + } + actions[i] = action; + } + } + return actions; + } + + private static CompatibleCollectionArrayReader.ReadAction compatibleCollectionArrayReadAction( + CompatibleCollectionArrayReader.ReadAction[] actions, int index) { + return actions == 
null ? null : actions[index]; + } + @Override public void write(WriteContext writeContext, T value) { MemoryBuffer buffer = writeContext.getBuffer(); @@ -164,11 +253,14 @@ private T newInstance() { @Override public T read(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); if (isRecord) { Object[] fieldValues = new Object[buildInFields.length + otherFields.length + containerFields.length]; - readFields(readContext, fieldValues); + if (hasCompatibleCollectionArrayRead) { + readFieldsWithCompatibleCollectionArray(readContext, fieldValues); + } else { + readFields(readContext, fieldValues); + } fieldValues = RecordUtils.remapping(recordInfo, fieldValues); T t = objectCreator.newInstanceWithArguments(fieldValues); Arrays.fill(recordInfo.getRecordComponents(), null); @@ -178,6 +270,16 @@ public T read(ReadContext readContext) { if (readContext.hasPreservedRefId()) { readContext.reference(targetObject); } + if (hasCompatibleCollectionArrayRead) { + readFieldsWithCompatibleCollectionArray(readContext, targetObject); + } else { + readFields(readContext, targetObject); + } + return targetObject; + } + + private void readFields(ReadContext readContext, T targetObject) { + MemoryBuffer buffer = readContext.getBuffer(); RefReader refReader = readContext.getRefReader(); // read order: primitive,boxed,final,other,collection,map for (SerializationFieldInfo fieldInfo : this.buildInFields) { @@ -222,16 +324,6 @@ public T read(ReadContext readContext) { fieldAccessor.putObject(targetObject, fieldValue); } } - return targetObject; - } - - private void compatibleRead( - ReadContext readContext, SerializationFieldInfo fieldInfo, Object obj) { - MemoryBuffer buffer = readContext.getBuffer(); - Object fieldValue = - AbstractObjectSerializer.readBuildInFieldValue( - readContext, typeResolver, readContext.getRefReader(), fieldInfo, buffer); - fieldInfo.fieldConverter.set(obj, fieldValue); } private void readFields(ReadContext readContext, Object[] fields) { @@ 
-277,6 +369,138 @@ private void readFields(ReadContext readContext, Object[] fields) { } } + private void compatibleRead( + ReadContext readContext, SerializationFieldInfo fieldInfo, Object obj) { + MemoryBuffer buffer = readContext.getBuffer(); + Object fieldValue = + AbstractObjectSerializer.readBuildInFieldValue( + readContext, typeResolver, readContext.getRefReader(), fieldInfo, buffer); + fieldInfo.fieldConverter.set(obj, fieldValue); + } + + private void readFieldsWithCompatibleCollectionArray(ReadContext readContext, T targetObject) { + MemoryBuffer buffer = readContext.getBuffer(); + RefReader refReader = readContext.getRefReader(); + for (int i = 0; i < buildInFields.length; i++) { + SerializationFieldInfo fieldInfo = buildInFields[i]; + CompatibleCollectionArrayReader.ReadAction action = + compatibleCollectionArrayReadAction(buildInCompatibleReadActions, i); + if (Utils.DEBUG_OUTPUT_VERBOSE) { + printFieldDebugInfo(fieldInfo, buffer); + } + FieldAccessor fieldAccessor = fieldInfo.fieldAccessor; + if (fieldAccessor != null) { + if (action != null) { + fieldAccessor.putObject( + targetObject, + CompatibleCollectionArrayReader.read(readContext, fieldInfo.refMode, action)); + } else { + AbstractObjectSerializer.readBuildInFieldValue( + readContext, typeResolver, refReader, fieldInfo, buffer, targetObject); + } + } else { + if (fieldInfo.fieldConverter == null) { + // Skip the field value from buffer since it doesn't exist in current class + FieldSkipper.skipField(readContext, typeResolver, refReader, fieldInfo, buffer); + } else { + compatibleRead(readContext, fieldInfo, targetObject); + } + } + } + Generics generics = readContext.getGenerics(); + for (int i = 0; i < containerFields.length; i++) { + SerializationFieldInfo fieldInfo = containerFields[i]; + CompatibleCollectionArrayReader.ReadAction action = + compatibleCollectionArrayReadAction(containerCompatibleReadActions, i); + if (Utils.DEBUG_OUTPUT_VERBOSE) { + printFieldDebugInfo(fieldInfo, buffer); + } 
+ Object fieldValue = + action == null + ? AbstractObjectSerializer.readContainerFieldValue( + readContext, typeResolver, refReader, generics, fieldInfo, buffer) + : CompatibleCollectionArrayReader.read(readContext, fieldInfo.refMode, action); + FieldAccessor fieldAccessor = fieldInfo.fieldAccessor; + if (fieldAccessor != null) { + fieldAccessor.putObject(targetObject, fieldValue); + } + } + for (int i = 0; i < otherFields.length; i++) { + SerializationFieldInfo fieldInfo = otherFields[i]; + CompatibleCollectionArrayReader.ReadAction action = + compatibleCollectionArrayReadAction(otherCompatibleReadActions, i); + if (Utils.DEBUG_OUTPUT_VERBOSE) { + printFieldDebugInfo(fieldInfo, buffer); + } + Object fieldValue = + action == null + ? AbstractObjectSerializer.readField( + readContext, typeResolver, refReader, fieldInfo, buffer) + : CompatibleCollectionArrayReader.read(readContext, fieldInfo.refMode, action); + FieldAccessor fieldAccessor = fieldInfo.fieldAccessor; + if (fieldAccessor != null) { + fieldAccessor.putObject(targetObject, fieldValue); + } + } + } + + private void readFieldsWithCompatibleCollectionArray(ReadContext readContext, Object[] fields) { + MemoryBuffer buffer = readContext.getBuffer(); + int counter = 0; + RefReader refReader = readContext.getRefReader(); + for (int i = 0; i < buildInFields.length; i++) { + SerializationFieldInfo fieldInfo = buildInFields[i]; + CompatibleCollectionArrayReader.ReadAction action = + compatibleCollectionArrayReadAction(buildInCompatibleReadActions, i); + if (Utils.DEBUG_OUTPUT_ENABLED) { + printFieldDebugInfo(fieldInfo, buffer); + } + if (fieldInfo.fieldAccessor != null) { + fields[counter++] = + action == null + ? AbstractObjectSerializer.readBuildInFieldValue( + readContext, typeResolver, refReader, fieldInfo, buffer) + : CompatibleCollectionArrayReader.read(readContext, fieldInfo.refMode, action); + } else { + // Skip the field value from buffer since it doesn't exist in current class. 
+ // For records, fieldConverter can't be used since records are immutable and + // constructed all at once. We just read to advance buffer position. + FieldSkipper.skipField(readContext, typeResolver, refReader, fieldInfo, buffer); + // remapping will handle those extra fields from peers. + fields[counter++] = null; + } + } + Generics generics = readContext.getGenerics(); + for (int i = 0; i < containerFields.length; i++) { + SerializationFieldInfo fieldInfo = containerFields[i]; + CompatibleCollectionArrayReader.ReadAction action = + compatibleCollectionArrayReadAction(containerCompatibleReadActions, i); + if (Utils.DEBUG_OUTPUT_ENABLED) { + printFieldDebugInfo(fieldInfo, buffer); + } + Object fieldValue = + action == null + ? AbstractObjectSerializer.readContainerFieldValue( + readContext, typeResolver, refReader, generics, fieldInfo, buffer) + : CompatibleCollectionArrayReader.read(readContext, fieldInfo.refMode, action); + fields[counter++] = fieldValue; + } + for (int i = 0; i < otherFields.length; i++) { + SerializationFieldInfo fieldInfo = otherFields[i]; + CompatibleCollectionArrayReader.ReadAction action = + compatibleCollectionArrayReadAction(otherCompatibleReadActions, i); + if (Utils.DEBUG_OUTPUT_ENABLED) { + printFieldDebugInfo(fieldInfo, buffer); + } + Object fieldValue = + action == null + ? 
AbstractObjectSerializer.readField( + readContext, typeResolver, refReader, fieldInfo, buffer) + : CompatibleCollectionArrayReader.read(readContext, fieldInfo.refMode, action); + fields[counter++] = fieldValue; + } + } + private void printFieldDebugInfo(SerializationFieldInfo fieldInfo, MemoryBuffer buffer) { LOG.info( "[Java] read field {} of type {}, reader index {}", diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/CPPXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/CPPXlangTest.java index d9697e1c48..050fd5d4a3 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/CPPXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/CPPXlangTest.java @@ -451,4 +451,10 @@ public void testUnsignedSchemaCompatible(boolean enableCodegen) throws java.io.I public void testManualSchemaKindStruct(boolean enableCodegen) throws java.io.IOException { super.testManualSchemaKindStruct(enableCodegen); } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testListArrayCompatibleRead(boolean enableCodegen) throws java.io.IOException { + super.testListArrayCompatibleRead(enableCodegen); + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java index 668dc36e86..c584eaab76 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java @@ -379,4 +379,10 @@ public void testNestedAnnotatedContainerCompatible(boolean enableCodegen) public void testManualSchemaKindStruct(boolean enableCodegen) throws java.io.IOException { super.testManualSchemaKindStruct(enableCodegen); } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testListArrayCompatibleRead(boolean enableCodegen) throws java.io.IOException { + 
super.testListArrayCompatibleRead(enableCodegen); + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/DartXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/DartXlangTest.java index 6e73fb7f87..8c8ac42ec2 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/DartXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/DartXlangTest.java @@ -354,4 +354,10 @@ public void testNestedAnnotatedContainerCompatible(boolean enableCodegen) public void testManualSchemaKindStruct(boolean enableCodegen) throws java.io.IOException { super.testManualSchemaKindStruct(enableCodegen); } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testListArrayCompatibleRead(boolean enableCodegen) throws java.io.IOException { + super.testListArrayCompatibleRead(enableCodegen); + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/GoXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/GoXlangTest.java index 51153c3d20..98f96d9267 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/GoXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/GoXlangTest.java @@ -524,4 +524,10 @@ public void testNestedAnnotatedContainerCompatible(boolean enableCodegen) public void testManualSchemaKindStruct(boolean enableCodegen) throws java.io.IOException { super.testManualSchemaKindStruct(enableCodegen); } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testListArrayCompatibleRead(boolean enableCodegen) throws java.io.IOException { + super.testListArrayCompatibleRead(enableCodegen); + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/JavaScriptXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/JavaScriptXlangTest.java index 7d167368a8..576bb8da0c 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/JavaScriptXlangTest.java +++ 
b/java/fory-core/src/test/java/org/apache/fory/xlang/JavaScriptXlangTest.java @@ -550,4 +550,10 @@ public void testUnsignedSchemaCompatible(boolean enableCodegen) throws java.io.I public void testManualSchemaKindStruct(boolean enableCodegen) throws java.io.IOException { super.testManualSchemaKindStruct(enableCodegen); } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testListArrayCompatibleRead(boolean enableCodegen) throws java.io.IOException { + super.testListArrayCompatibleRead(enableCodegen); + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/MetaSharedXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/MetaSharedXlangTest.java index fa5dc9aaea..1fd546c5b8 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/MetaSharedXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/MetaSharedXlangTest.java @@ -19,9 +19,22 @@ package org.apache.fory.xlang; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; import lombok.Data; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; +import org.apache.fory.annotation.ArrayType; +import org.apache.fory.annotation.Int32Type; +import org.apache.fory.collection.Int32List; +import org.apache.fory.config.Int32Encoding; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.test.bean.BeanB; import org.apache.fory.xlang.PyCrossLanguageTest.Bar; import org.apache.fory.xlang.PyCrossLanguageTest.Foo; @@ -58,4 +71,160 @@ public void testMDArrayField() { s.arr = new int[][] {{1, 2}, {3, 4}}; serDeCheck(fory, s); } + + @Data + static class DirectListField { + @Int32Type(encoding = Int32Encoding.FIXED) + Int32List values; + } + + @Data + static class DirectNullableListField { + List values; + } + + 
@Data + static class DirectArrayField { + int[] values; + } + + @Data + static class DirectAnnotatedArrayField { + @ArrayType List values; + } + + @Data + static class NestedListField { + List> values; + } + + @Data + static class NestedArrayElementField { + List values; + } + + @Data + static class NestedSetListField { + Set> values; + } + + @Data + static class NestedSetArrayElementField { + Set values; + } + + @Test + public void testTopLevelListArrayCompatibleRead() { + Fory listFory = compatibleFory(DirectListField.class); + DirectListField listStruct = new DirectListField(); + listStruct.values = new Int32List(new int[] {1, 2, 3}); + byte[] listBytes = listFory.serialize(listStruct); + + Fory arrayFory = compatibleFory(DirectArrayField.class); + DirectArrayField arrayStruct = (DirectArrayField) arrayFory.deserialize(listBytes); + assertTrue(Arrays.equals(arrayStruct.values, new int[] {1, 2, 3})); + + DirectListField emptyListStruct = new DirectListField(); + emptyListStruct.values = new Int32List(); + DirectArrayField emptyArrayStruct = + (DirectArrayField) arrayFory.deserialize(listFory.serialize(emptyListStruct)); + assertEquals(emptyArrayStruct.values.length, 0); + + DirectArrayField peerArrayStruct = new DirectArrayField(); + peerArrayStruct.values = new int[] {4, 5, 6}; + byte[] arrayBytes = arrayFory.serialize(peerArrayStruct); + DirectListField readListStruct = (DirectListField) listFory.deserialize(arrayBytes); + assertEquals(readListStruct.values, Arrays.asList(4, 5, 6)); + + DirectArrayField emptyPeerArrayStruct = new DirectArrayField(); + emptyPeerArrayStruct.values = new int[0]; + DirectListField emptyReadListStruct = + (DirectListField) listFory.deserialize(arrayFory.serialize(emptyPeerArrayStruct)); + assertEquals(emptyReadListStruct.values, java.util.Collections.emptyList()); + } + + @Test + public void testTopLevelListAnnotatedArrayCompatibleRead() { + Fory listFory = compatibleFory(DirectListField.class); + DirectListField listStruct = new 
DirectListField(); + listStruct.values = new Int32List(new int[] {7, 8}); + + Fory annotatedArrayFory = compatibleFory(DirectAnnotatedArrayField.class); + DirectAnnotatedArrayField annotatedArrayStruct = + (DirectAnnotatedArrayField) annotatedArrayFory.deserialize(listFory.serialize(listStruct)); + assertEquals(annotatedArrayStruct.values, Arrays.asList(7, 8)); + } + + @Test + public void testTopLevelListArrayCompatibleReadWithoutCodegen() { + Fory listFory = compatibleFory(DirectListField.class, false); + DirectListField listStruct = new DirectListField(); + listStruct.values = new Int32List(new int[] {1, 2, 3}); + + Fory arrayFory = compatibleFory(DirectArrayField.class, false); + DirectArrayField arrayStruct = + (DirectArrayField) arrayFory.deserialize(listFory.serialize(listStruct)); + assertTrue(Arrays.equals(arrayStruct.values, new int[] {1, 2, 3})); + } + + @Test + public void testNullableListPayloadRejectedForArrayCompatibleRead() { + for (boolean codegen : new boolean[] {false, true}) { + Fory listFory = compatibleFory(DirectNullableListField.class, codegen); + DirectNullableListField listStruct = new DirectNullableListField(); + listStruct.values = Arrays.asList(1, 2, 3); + byte[] listBytes = listFory.serialize(listStruct); + + Fory arrayFory = compatibleFory(DirectArrayField.class, codegen); + DirectArrayField arrayStruct = (DirectArrayField) arrayFory.deserialize(listBytes); + assertTrue(Arrays.equals(arrayStruct.values, new int[] {1, 2, 3})); + + listStruct.values = Arrays.asList(1, null, 3); + byte[] nullablePayload = listFory.serialize(listStruct); + assertThrows(DeserializationException.class, () -> arrayFory.deserialize(nullablePayload)); + } + } + + @Test + public void testNestedListArrayCompatibleReadUnsupported() { + Fory nestedListFory = compatibleFory(NestedListField.class); + NestedListField nestedListStruct = new NestedListField(); + nestedListStruct.values = Arrays.asList(Arrays.asList(1, 2)); + byte[] nestedListBytes = 
nestedListFory.serialize(nestedListStruct); + + Fory nestedArrayFory = compatibleFory(NestedArrayElementField.class); + assertThrows( + DeserializationException.class, () -> nestedArrayFory.deserialize(nestedListBytes)); + + NestedArrayElementField nestedArrayStruct = new NestedArrayElementField(); + nestedArrayStruct.values = Arrays.asList(new int[] {1, 2}); + byte[] nestedArrayBytes = nestedArrayFory.serialize(nestedArrayStruct); + assertThrows( + DeserializationException.class, () -> nestedListFory.deserialize(nestedArrayBytes)); + + Fory nestedSetListFory = compatibleFory(NestedSetListField.class, false); + NestedSetListField nestedSetListStruct = new NestedSetListField(); + nestedSetListStruct.values = new LinkedHashSet<>(Arrays.asList(Arrays.asList(1, 2))); + byte[] nestedSetListBytes = nestedSetListFory.serialize(nestedSetListStruct); + + Fory nestedSetArrayFory = compatibleFory(NestedSetArrayElementField.class, false); + assertThrows( + DeserializationException.class, () -> nestedSetArrayFory.deserialize(nestedSetListBytes)); + + NestedSetArrayElementField nestedSetArrayStruct = new NestedSetArrayElementField(); + nestedSetArrayStruct.values = new LinkedHashSet<>(Arrays.asList(new int[] {1, 2})); + byte[] nestedSetArrayBytes = nestedSetArrayFory.serialize(nestedSetArrayStruct); + assertThrows( + DeserializationException.class, () -> nestedSetListFory.deserialize(nestedSetArrayBytes)); + } + + private static Fory compatibleFory(Class type) { + return compatibleFory(type, true); + } + + private static Fory compatibleFory(Class type, boolean codegen) { + Fory fory = Fory.builder().withXlang(true).withCompatible(true).withCodegen(codegen).build(); + fory.register(type, "example.list_array_compatible"); + return fory; + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/PythonXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/PythonXlangTest.java index 138f62ba0e..b1bca830b2 100644 --- 
a/java/fory-core/src/test/java/org/apache/fory/xlang/PythonXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/PythonXlangTest.java @@ -392,4 +392,14 @@ public void testUnsignedSchemaCompatible(boolean enableCodegen) throws java.io.I public void testManualSchemaKindStruct(boolean enableCodegen) throws java.io.IOException { super.testManualSchemaKindStruct(enableCodegen); } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testListArrayCompatibleRead(boolean enableCodegen) throws java.io.IOException { + super.testListArrayCompatibleRead(enableCodegen); + assertCompatibleListToArrayPeerCarrier( + "test_list_array_compatible_list_to_ndarray", enableCodegen); + assertCompatibleListToArrayPeerCarrier( + "test_list_array_compatible_list_to_pyarray", enableCodegen); + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java index c38ed9d17e..fd6ceb4dcb 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java @@ -81,7 +81,7 @@ protected CommandContext buildCommandContext(String caseName, Path dataFile) { env.put("RUSTFLAGS", "-Awarnings"); env.put("RUST_BACKTRACE", "1"); env.put("ENABLE_FORY_DEBUG_OUTPUT", "1"); - env.put("FORY_PANIC_ON_ERROR", "1"); + env.put("FORY_PANIC_ON_ERROR", caseName.endsWith("_error") ? 
"0" : "1"); return new CommandContext(command, env, new File("../../rust")); } @@ -347,4 +347,10 @@ public void testUnsignedSchemaCompatible(boolean enableCodegen) throws java.io.I public void testManualSchemaKindStruct(boolean enableCodegen) throws java.io.IOException { super.testManualSchemaKindStruct(enableCodegen); } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testListArrayCompatibleRead(boolean enableCodegen) throws java.io.IOException { + super.testListArrayCompatibleRead(enableCodegen); + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/SwiftXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/SwiftXlangTest.java index 5a98e9d619..f9562025d3 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/SwiftXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/SwiftXlangTest.java @@ -415,4 +415,10 @@ public void testUnsignedSchemaCompatible(boolean enableCodegen) throws java.io.I public void testManualSchemaKindStruct(boolean enableCodegen) throws java.io.IOException { super.testManualSchemaKindStruct(enableCodegen); } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testListArrayCompatibleRead(boolean enableCodegen) throws java.io.IOException { + super.testListArrayCompatibleRead(enableCodegen); + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java b/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java index 120b1d452e..d72094cd2a 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java @@ -50,10 +50,12 @@ import org.apache.fory.annotation.UInt8Type; import org.apache.fory.collection.BFloat16List; import org.apache.fory.collection.Float16List; +import org.apache.fory.collection.Int32List; import org.apache.fory.config.Int32Encoding; import 
org.apache.fory.config.Int64Encoding; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.MemoryUtils; import org.apache.fory.meta.FieldInfo; @@ -1713,6 +1715,147 @@ protected static void assertReducedPrecisionFloatStruct(ReducedPrecisionFloatStr Assert.assertEquals(value.bfloat16Array.getShort(2), (short) 0xBF80); } + @Data + static class XlangCompatibleInt32ListField { + @ForyField(id = 1) + @Int32Type(encoding = Int32Encoding.FIXED) + Int32List values; + } + + @Data + static class XlangCompatibleNullableInt32ListField { + @ForyField(id = 1) + List values; + } + + @Data + static class XlangCompatibleInt32ArrayField { + @ForyField(id = 1) + int[] values; + } + + protected static XlangCompatibleInt32ListField newCompatibleInt32ListField(int... values) { + XlangCompatibleInt32ListField value = new XlangCompatibleInt32ListField(); + value.values = new Int32List(values); + return value; + } + + protected static XlangCompatibleNullableInt32ListField newCompatibleNullableInt32ListField( + Integer... values) { + XlangCompatibleNullableInt32ListField value = new XlangCompatibleNullableInt32ListField(); + value.values = Arrays.asList(values); + return value; + } + + protected static XlangCompatibleInt32ArrayField newCompatibleInt32ArrayField(int... values) { + XlangCompatibleInt32ArrayField value = new XlangCompatibleInt32ArrayField(); + value.values = values; + return value; + } + + protected static Fory compatibleListArrayFory(Class type, boolean enableCodegen) { + Fory fory = + Fory.builder().withXlang(true).withCompatible(true).withCodegen(enableCodegen).build(); + fory.register(type, 901); + return fory; + } + + private static void assertIntArrayEquals(int[] actual, int... 
expected) { + Assert.assertNotNull(actual); + Assert.assertTrue( + Arrays.equals(actual, expected), + "Expected " + Arrays.toString(expected) + ", got " + Arrays.toString(actual)); + } + + protected void assertCompatibleListToArrayPeerCarrier(String caseName, boolean enableCodegen) + throws java.io.IOException { + Fory listFory = compatibleListArrayFory(XlangCompatibleInt32ListField.class, enableCodegen); + Fory arrayFory = compatibleListArrayFory(XlangCompatibleInt32ArrayField.class, enableCodegen); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(256); + listFory.serialize(buffer, newCompatibleInt32ListField(7, 8, 9)); + ExecutionContext ctx = prepareExecution(caseName, buffer.getBytes(0, buffer.writerIndex())); + runPeer(ctx); + XlangCompatibleInt32ArrayField arrayResult = + (XlangCompatibleInt32ArrayField) arrayFory.deserialize(readBuffer(ctx.dataFile())); + assertIntArrayEquals(arrayResult.values, 7, 8, 9); + } + + protected void testListArrayCompatibleRead(boolean enableCodegen) throws java.io.IOException { + Fory listFory = compatibleListArrayFory(XlangCompatibleInt32ListField.class, enableCodegen); + Fory nullableListFory = + compatibleListArrayFory(XlangCompatibleNullableInt32ListField.class, enableCodegen); + Fory arrayFory = compatibleListArrayFory(XlangCompatibleInt32ArrayField.class, enableCodegen); + + XlangCompatibleInt32ListField listValue = newCompatibleInt32ListField(1, -2, 3); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(256); + listFory.serialize(buffer, listValue); + ExecutionContext ctx = + prepareExecution( + "test_list_array_compatible_list_to_array", buffer.getBytes(0, buffer.writerIndex())); + runPeer(ctx); + XlangCompatibleInt32ArrayField arrayResult = + (XlangCompatibleInt32ArrayField) arrayFory.deserialize(readBuffer(ctx.dataFile())); + assertIntArrayEquals(arrayResult.values, 1, -2, 3); + + XlangCompatibleInt32ListField emptyListValue = newCompatibleInt32ListField(); + buffer = MemoryBuffer.newHeapBuffer(128); + 
listFory.serialize(buffer, emptyListValue); + ctx = + prepareExecution( + "test_list_array_compatible_list_to_array", buffer.getBytes(0, buffer.writerIndex())); + runPeer(ctx); + XlangCompatibleInt32ArrayField emptyArrayResult = + (XlangCompatibleInt32ArrayField) arrayFory.deserialize(readBuffer(ctx.dataFile())); + assertIntArrayEquals(emptyArrayResult.values); + + XlangCompatibleInt32ArrayField arrayValue = newCompatibleInt32ArrayField(4, 5, 6); + buffer = MemoryBuffer.newHeapBuffer(256); + arrayFory.serialize(buffer, arrayValue); + ctx = + prepareExecution( + "test_list_array_compatible_array_to_list", buffer.getBytes(0, buffer.writerIndex())); + runPeer(ctx); + XlangCompatibleInt32ListField listResult = + (XlangCompatibleInt32ListField) listFory.deserialize(readBuffer(ctx.dataFile())); + Assert.assertEquals(listResult.values, Arrays.asList(4, 5, 6)); + + XlangCompatibleInt32ArrayField emptyArrayValue = newCompatibleInt32ArrayField(); + buffer = MemoryBuffer.newHeapBuffer(128); + arrayFory.serialize(buffer, emptyArrayValue); + ctx = + prepareExecution( + "test_list_array_compatible_array_to_list", buffer.getBytes(0, buffer.writerIndex())); + runPeer(ctx); + XlangCompatibleInt32ListField emptyListResult = + (XlangCompatibleInt32ListField) listFory.deserialize(readBuffer(ctx.dataFile())); + Assert.assertEquals(emptyListResult.values, Collections.emptyList()); + + XlangCompatibleNullableInt32ListField nullableListWithoutNulls = + newCompatibleNullableInt32ListField(1, 2, 3); + buffer = MemoryBuffer.newHeapBuffer(256); + nullableListFory.serialize(buffer, nullableListWithoutNulls); + ctx = + prepareExecution( + "test_list_array_compatible_list_to_array", buffer.getBytes(0, buffer.writerIndex())); + runPeer(ctx); + XlangCompatibleInt32ArrayField nullableListArrayResult = + (XlangCompatibleInt32ArrayField) arrayFory.deserialize(readBuffer(ctx.dataFile())); + assertIntArrayEquals(nullableListArrayResult.values, 1, 2, 3); + + XlangCompatibleNullableInt32ListField 
nullableListValue = + newCompatibleNullableInt32ListField(1, null, 3); + buffer = MemoryBuffer.newHeapBuffer(256); + nullableListFory.serialize(buffer, nullableListValue); + byte[] nullablePayload = buffer.getBytes(0, buffer.writerIndex()); + Assert.expectThrows( + DeserializationException.class, + () -> arrayFory.deserialize(MemoryUtils.wrap(nullablePayload))); + ctx = + prepareExecution( + "test_list_array_compatible_nullable_list_to_array_error", nullablePayload); + runPeer(ctx); + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testOneStringFieldSchemaConsistent(boolean enableCodegen) throws java.io.IOException { String caseName = "test_one_string_field_schema"; diff --git a/javascript/packages/core/index.ts b/javascript/packages/core/index.ts index 7a70db13f4..4af7a47ac4 100644 --- a/javascript/packages/core/index.ts +++ b/javascript/packages/core/index.ts @@ -27,6 +27,8 @@ import Fory from "./lib/fory"; import { BinaryReader } from "./lib/reader"; import { BinaryWriter } from "./lib/writer"; import { BFloat16, BFloat16Array } from "./lib/types/bfloat16"; +import { BoolArray } from "./lib/types/boolArray"; +import { Float16Array, ForyFloat16Array } from "./lib/types/float16"; import { ReadContext, WriteContext } from "./lib/context"; import { Decimal } from "./lib/types/decimal"; @@ -38,9 +40,12 @@ export { BinaryWriter, Dynamic, BinaryReader, + BoolArray, BFloat16, BFloat16Array, Decimal, + Float16Array, + ForyFloat16Array, ReadContext, WriteContext, }; diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index 53a9cc1820..a811f507e6 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -27,6 +27,7 @@ import { import { InnerFieldInfo, TypeMeta } from "./meta/TypeMeta"; import { Type, TypeInfo } from "./typeInfo"; import { Config, RefFlags, Serializer, TypeId } from "./type"; +import { markCompatibleCollectionArrayRead } from 
"./gen/collection"; type TypeResolverLike = { config: Config; @@ -39,6 +40,163 @@ type TypeResolverLike = { regenerateReadSerializer(typeInfo: TypeInfo): Serializer; }; +function remoteListElementType(fieldInfo: InnerFieldInfo): InnerFieldInfo | undefined { + if (fieldInfo.typeId !== TypeId.LIST) { + return undefined; + } + return fieldInfo.options?.inner; +} + +function denseArrayElementTypeId(typeId: number): number | undefined { + switch (typeId) { + case TypeId.BOOL_ARRAY: + return TypeId.BOOL; + case TypeId.INT8_ARRAY: + return TypeId.INT8; + case TypeId.INT16_ARRAY: + return TypeId.INT16; + case TypeId.INT32_ARRAY: + return TypeId.INT32; + case TypeId.INT64_ARRAY: + return TypeId.INT64; + case TypeId.UINT8_ARRAY: + return TypeId.UINT8; + case TypeId.UINT16_ARRAY: + return TypeId.UINT16; + case TypeId.UINT32_ARRAY: + return TypeId.UINT32; + case TypeId.UINT64_ARRAY: + return TypeId.UINT64; + case TypeId.FLOAT16_ARRAY: + return TypeId.FLOAT16; + case TypeId.BFLOAT16_ARRAY: + return TypeId.BFLOAT16; + case TypeId.FLOAT32_ARRAY: + return TypeId.FLOAT32; + case TypeId.FLOAT64_ARRAY: + return TypeId.FLOAT64; + default: + return undefined; + } +} + +function compatibleArrayElementTypeId(typeId: number): number { + switch (typeId) { + case TypeId.VARINT32: + return TypeId.INT32; + case TypeId.VARINT64: + case TypeId.TAGGED_INT64: + return TypeId.INT64; + case TypeId.VAR_UINT32: + return TypeId.UINT32; + case TypeId.VAR_UINT64: + case TypeId.TAGGED_UINT64: + return TypeId.UINT64; + default: + return typeId; + } +} + +function typeInfoForElementTypeId(typeId: number): TypeInfo { + switch (typeId) { + case TypeId.BOOL: + return Type.bool(); + case TypeId.INT8: + return Type.int8(); + case TypeId.INT16: + return Type.int16(); + case TypeId.INT32: + return Type.int32({ encoding: "fixed" }); + case TypeId.VARINT32: + return Type.int32(); + case TypeId.INT64: + return Type.int64({ encoding: "fixed" }); + case TypeId.VARINT64: + return Type.int64(); + case 
TypeId.TAGGED_INT64: + return Type.int64({ encoding: "tagged" }); + case TypeId.UINT8: + return Type.uint8(); + case TypeId.UINT16: + return Type.uint16(); + case TypeId.UINT32: + return Type.uint32({ encoding: "fixed" }); + case TypeId.VAR_UINT32: + return Type.uint32(); + case TypeId.UINT64: + return Type.uint64({ encoding: "fixed" }); + case TypeId.VAR_UINT64: + return Type.uint64(); + case TypeId.TAGGED_UINT64: + return Type.uint64({ encoding: "tagged" }); + case TypeId.FLOAT16: + return Type.float16(); + case TypeId.BFLOAT16: + return Type.bfloat16(); + case TypeId.FLOAT32: + return Type.float32(); + case TypeId.FLOAT64: + return Type.float64(); + default: + return Type.any(); + } +} + +function typeInfoForDenseArrayElementTypeId(typeId: number): TypeInfo { + switch (typeId) { + case TypeId.BOOL: + return Type.boolArray(); + case TypeId.INT8: + return Type.int8Array(); + case TypeId.INT16: + return Type.int16Array(); + case TypeId.INT32: + return Type.int32Array(); + case TypeId.INT64: + return Type.int64Array(); + case TypeId.UINT8: + return Type.uint8Array(); + case TypeId.UINT16: + return Type.uint16Array(); + case TypeId.UINT32: + return Type.uint32Array(); + case TypeId.UINT64: + return Type.uint64Array(); + case TypeId.FLOAT16: + return Type.float16Array(); + case TypeId.BFLOAT16: + return Type.bfloat16Array(); + case TypeId.FLOAT32: + return Type.float32Array(); + case TypeId.FLOAT64: + return Type.float64Array(); + default: + return Type.any(); + } +} + +function compatibleListToArrayTypeInfo( + remoteElement: InnerFieldInfo, + targetElementTypeId: number, +): TypeInfo { + const elementTypeInfo = typeInfoForElementTypeId(remoteElement.typeId) + .setNullable(remoteElement.nullable === true) + .setTrackingRef(remoteElement.trackingRef === true); + const typeInfo = Type.list(elementTypeInfo); + return markCompatibleCollectionArrayRead(typeInfo, { + target: "array", + elementTypeId: targetElementTypeId, + }); +} + +function 
compatibleArrayToListTypeInfo(elementTypeId: number): TypeInfo { + const typeInfo = typeInfoForDenseArrayElementTypeId(elementTypeId); + return markCompatibleCollectionArrayRead(typeInfo, { + target: "list", + elementTypeId, + }); +} + class MetaStringBytes { dynamicWriteStringId = -1; @@ -500,17 +658,29 @@ export class ReadContext { private fieldInfoToTypeInfo( fieldInfo: InnerFieldInfo, fallbackTypeInfo?: TypeInfo, + topLevel = true, ): TypeInfo { + if (topLevel && fallbackTypeInfo) { + const compatible = this.compatibleFieldTypeInfo(fieldInfo, fallbackTypeInfo); + if (compatible) { + return compatible; + } + } + if (this.hasUnsupportedListArrayMismatch(fieldInfo, fallbackTypeInfo, topLevel)) { + throw new Error("unsupported compatible list/array schema mismatch"); + } switch (fieldInfo.typeId) { case TypeId.MAP: return Type.map( this.fieldInfoToTypeInfo( fieldInfo.options!.key!, fallbackTypeInfo?.options?.key, + false, ), this.fieldInfoToTypeInfo( fieldInfo.options!.value!, fallbackTypeInfo?.options?.value, + false, ), ); case TypeId.LIST: @@ -518,6 +688,7 @@ export class ReadContext { this.fieldInfoToTypeInfo( fieldInfo.options!.inner!, fallbackTypeInfo?.options?.inner, + false, ), ); case TypeId.SET: @@ -525,6 +696,7 @@ export class ReadContext { this.fieldInfoToTypeInfo( fieldInfo.options!.key!, fallbackTypeInfo?.options?.key, + false, ), ); default: { @@ -561,6 +733,82 @@ export class ReadContext { } } + private compatibleFieldTypeInfo( + remote: InnerFieldInfo, + local: TypeInfo, + ): TypeInfo | undefined { + const remoteElement = remoteListElementType(remote); + const localElement = denseArrayElementTypeId(local.typeId); + if (remoteElement !== undefined && localElement !== undefined) { + if (compatibleArrayElementTypeId(remoteElement.typeId) !== localElement) { + return undefined; + } + return compatibleListToArrayTypeInfo(remoteElement, localElement); + } + const remoteArrayElement = denseArrayElementTypeId(remote.typeId); + if ( + remoteArrayElement !== 
undefined + && local.typeId === TypeId.LIST + && local.options?.inner + && compatibleArrayElementTypeId(local.options.inner.typeId) === remoteArrayElement + ) { + return compatibleArrayToListTypeInfo(remoteArrayElement); + } + return undefined; + } + + private hasUnsupportedListArrayMismatch( + remote: InnerFieldInfo, + local: TypeInfo | undefined, + topLevel: boolean, + ): boolean { + if (!local) { + return false; + } + if (this.isListArrayRootPair(remote, local)) { + return !(topLevel && this.compatibleFieldTypeInfo(remote, local)); + } + if (remote.typeId !== local.typeId) { + return false; + } + switch (remote.typeId) { + case TypeId.MAP: + return ( + this.hasUnsupportedListArrayMismatch( + remote.options!.key!, + local.options?.key, + false, + ) + || this.hasUnsupportedListArrayMismatch( + remote.options!.value!, + local.options?.value, + false, + ) + ); + case TypeId.LIST: + return this.hasUnsupportedListArrayMismatch( + remote.options!.inner!, + local.options?.inner, + false, + ); + case TypeId.SET: + return this.hasUnsupportedListArrayMismatch( + remote.options!.key!, + local.options?.key, + false, + ); + default: + return false; + } + } + + private isListArrayRootPair(remote: InnerFieldInfo, local: TypeInfo): boolean { + return ( + (remote.typeId === TypeId.LIST && denseArrayElementTypeId(local.typeId) !== undefined) + || (denseArrayElementTypeId(remote.typeId) !== undefined && local.typeId === TypeId.LIST) + ); + } + genSerializerByTypeMetaRuntime(typeMeta: TypeMeta, original?: Serializer) { const typeId = typeMeta.getTypeId(); if (!TypeId.structType(typeId)) { diff --git a/javascript/packages/core/lib/gen/collection.ts b/javascript/packages/core/lib/gen/collection.ts index 76b1ac73b9..b97876230b 100644 --- a/javascript/packages/core/lib/gen/collection.ts +++ b/javascript/packages/core/lib/gen/collection.ts @@ -17,14 +17,38 @@ * under the License. 
*/ -import { TypeInfo } from "../typeInfo"; +import type { TypeInfo } from "../typeInfo"; import { CodecBuilder } from "./builder"; import { BaseSerializerGenerator, SerializerGenerator } from "./serializer"; import { CodegenRegistry } from "./router"; import { TypeId, RefFlags, Serializer } from "../type"; import { Scope } from "./scope"; import { AnyHelper } from "./any"; -import { ReadContext, WriteContext } from "../context"; +import type { ReadContext, WriteContext } from "../context"; + +export type CompatibleCollectionArrayReadAction = { + target: "array" | "list"; + elementTypeId: number; +}; + +const compatibleCollectionArrayReadActions = new WeakMap< + TypeInfo, + CompatibleCollectionArrayReadAction +>(); + +export function markCompatibleCollectionArrayRead( + typeInfo: TypeInfo, + action: CompatibleCollectionArrayReadAction, +): TypeInfo { + compatibleCollectionArrayReadActions.set(typeInfo, action); + return typeInfo; +} + +export function getCompatibleCollectionArrayReadAction( + typeInfo: TypeInfo, +): CompatibleCollectionArrayReadAction | undefined { + return compatibleCollectionArrayReadActions.get(typeInfo); +} export const CollectionFlags = { /** Whether track elements ref. 
*/ @@ -40,6 +64,56 @@ export const CollectionFlags = { SAME_TYPE: 0b1000, }; +function compatibleArrayCollectionExpr(elementTypeId: number, len: string): string { + switch (elementTypeId) { + case TypeId.BOOL: + return `new external.BoolArray(${len})`; + case TypeId.INT8: + return `new Int8Array(${len})`; + case TypeId.INT16: + return `new Int16Array(${len})`; + case TypeId.INT32: + return `new Int32Array(${len})`; + case TypeId.INT64: + return `new BigInt64Array(${len})`; + case TypeId.UINT8: + return `new Uint8Array(${len})`; + case TypeId.UINT16: + return `new Uint16Array(${len})`; + case TypeId.UINT32: + return `new Uint32Array(${len})`; + case TypeId.UINT64: + return `new BigUint64Array(${len})`; + case TypeId.FLOAT16: + return `external.createFloat16Array(${len})`; + case TypeId.BFLOAT16: + return `new external.BFloat16Array(${len})`; + case TypeId.FLOAT32: + return `new Float32Array(${len})`; + case TypeId.FLOAT64: + return `new Float64Array(${len})`; + default: + return `new Array(${len})`; + } +} + +function compatibleArrayPutAccessor( + elementTypeId: number, + result: string, + item: string, + index: string, +): string { + switch (elementTypeId) { + case TypeId.BOOL: + case TypeId.BFLOAT16: + return `${result}.setValue(${index}, ${item})`; + case TypeId.FLOAT16: + return `typeof ${result}.setValue === "function" ? 
${result}.setValue(${index}, ${item}) : ${result}[${index}] = ${item}`; + default: + return `${result}[${index}] = ${item}`; + } +} + class CollectionAnySerializer { constructor( private writeContext: WriteContext, @@ -370,6 +444,24 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera const elemSerializer = this.scope.uniqueName("elemSerializer"); const anyHelper = this.builder.getExternal(AnyHelper.name); const readContextName = this.builder.getReadContextName(); + const compatibleReadAction = getCompatibleCollectionArrayReadAction(this.typeInfo); + const compatibleListToArray = compatibleReadAction?.target === "array"; + const newCollection = compatibleListToArray + ? compatibleArrayCollectionExpr(compatibleReadAction!.elementTypeId, len) + : this.newCollection(len); + const putAccessor = (item: string, index: string) => compatibleListToArray + ? compatibleArrayPutAccessor(compatibleReadAction!.elementTypeId, result, item, index) + : this.putAccessor(result, item, index); + const rejectCompatiblePayload = compatibleListToArray + ? ` + if (${flags} & (${CollectionFlags.HAS_NULL} | ${CollectionFlags.TRACKING_REF})) { + throw new Error("compatible list-to-array field cannot read nullable or ref-tracked elements"); + } + if ((${flags} & (${CollectionFlags.SAME_TYPE} | ${CollectionFlags.DECL_ELEMENT_TYPE})) !== (${CollectionFlags.SAME_TYPE} | ${CollectionFlags.DECL_ELEMENT_TYPE})) { + throw new Error("compatible list-to-array field requires declared same-type elements"); + } + ` + : ""; // Skip depth tracking for leaf element types (primitives, string, enum, time, typed arrays). 
const innerIsLeaf = TypeId.isLeafTypeId(this.innerGenerator.getTypeId()!); const readInnerElement = ( @@ -383,10 +475,11 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera return ` const ${len} = ${this.builder.reader.readVarUint32Small7()}; ${this.builder.getReadContextName()}.checkCollectionSize(${len}); - const ${result} = ${this.newCollection(len)}; + const ${result} = ${newCollection}; ${this.maybeReference(result, refState)} if (${len} > 0) { const ${flags} = ${this.builder.reader.readUint8()}; + ${rejectCompatiblePayload} let ${elemSerializer} = null; if (!(${flags} & ${CollectionFlags.DECL_ELEMENT_TYPE})) { ${elemSerializer} = ${anyHelper}.detectSerializer(${readContextName}); @@ -399,31 +492,31 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera case ${RefFlags.RefValueFlag}: if (${elemSerializer}) { ${innerIsLeaf ? "" : `${readContextName}.incReadDepth();`} - ${this.putAccessor(result, `${elemSerializer}.read(${refFlag} === ${RefFlags.RefValueFlag})`, idx)} + ${putAccessor(`${elemSerializer}.read(${refFlag} === ${RefFlags.RefValueFlag})`, idx)} ${innerIsLeaf ? 
"" : `${readContextName}.decReadDepth();`} } else { - ${readInnerElement((x: any) => `${this.putAccessor(result, x, idx)}`, `${refFlag} === ${RefFlags.RefValueFlag}`)} + ${readInnerElement((x: any) => `${putAccessor(x, idx)}`, `${refFlag} === ${RefFlags.RefValueFlag}`)} } break; case ${RefFlags.RefFlag}: - ${this.putAccessor(result, this.builder.referenceResolver.getReadRef(this.builder.reader.readVarUInt32()), idx)} + ${putAccessor(this.builder.referenceResolver.getReadRef(this.builder.reader.readVarUInt32()), idx)} break; case ${RefFlags.NullFlag}: - ${this.putAccessor(result, "null", idx)} + ${putAccessor("null", idx)} break; } } } else if (${flags} & ${CollectionFlags.HAS_NULL}) { for (let ${idx} = 0; ${idx} < ${len}; ${idx}++) { if (${this.builder.reader.readInt8()} == ${RefFlags.NullFlag}) { - ${this.putAccessor(result, "null", idx)} + ${putAccessor("null", idx)} } else { if (${elemSerializer}) { ${innerIsLeaf ? "" : `${readContextName}.incReadDepth();`} - ${this.putAccessor(result, `${elemSerializer}.read(false)`, idx)} + ${putAccessor(`${elemSerializer}.read(false)`, idx)} ${innerIsLeaf ? "" : `${readContextName}.decReadDepth();`} } else { - ${readInnerElement((x: any) => `${this.putAccessor(result, x, idx)}`, "false")} + ${readInnerElement((x: any) => `${putAccessor(x, idx)}`, "false")} } } } @@ -431,10 +524,10 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera for (let ${idx} = 0; ${idx} < ${len}; ${idx}++) { if (${elemSerializer}) { ${innerIsLeaf ? "" : `${readContextName}.incReadDepth();`} - ${this.putAccessor(result, `${elemSerializer}.read(false)`, idx)} + ${putAccessor(`${elemSerializer}.read(false)`, idx)} ${innerIsLeaf ? 
"" : `${readContextName}.decReadDepth();`} } else { - ${readInnerElement((x: any) => `${this.putAccessor(result, x, idx)}`, "false")} + ${readInnerElement((x: any) => `${putAccessor(x, idx)}`, "false")} } } } diff --git a/javascript/packages/core/lib/gen/struct.ts b/javascript/packages/core/lib/gen/struct.ts index 9dec622de1..048603b48e 100644 --- a/javascript/packages/core/lib/gen/struct.ts +++ b/javascript/packages/core/lib/gen/struct.ts @@ -24,6 +24,7 @@ import { TypeInfo } from "../typeInfo"; import { CodegenRegistry } from "./router"; import { BaseSerializerGenerator, SerializerGenerator } from "./serializer"; import { TypeMeta } from "../meta/TypeMeta"; +import { getCompatibleCollectionArrayReadAction } from "./collection"; /** * Returns true when a field's read cannot recurse and needs no depth tracking. @@ -46,6 +47,18 @@ function isDepthFreeField(typeInfo: TypeInfo): boolean { return false; } +function compatibleReadTargetExpr(typeInfo: TypeInfo, expr: string): string { + const action = getCompatibleCollectionArrayReadAction(typeInfo); + switch (action?.target) { + case "array": + return expr; + case "list": + return `Array.from(${expr})`; + default: + return expr; + } +} + const sortProps = (typeInfo: TypeInfo, typeResolver: CodecBuilder["resolver"]) => { const names = TypeMeta.fromTypeInfo(typeInfo, typeResolver).getFieldInfo(); const props = typeInfo.options!.props; @@ -109,24 +122,25 @@ class StructSerializerGenerator extends BaseSerializerGenerator { readField(fieldTypeInfo: TypeInfo, assignStmt: (expr: string) => string, embedGenerator: SerializerGenerator) { const { nullable = false, dynamic, trackingRef } = fieldTypeInfo; const refMode = toRefMode(trackingRef, nullable); + const assignCompatible = (expr: string) => assignStmt(compatibleReadTargetExpr(fieldTypeInfo, expr)); let stmt = ""; // polymorphic type if (this.builder.resolver.isMonomorphic(fieldTypeInfo, dynamic)) { if (refMode == RefMode.TRACKING || refMode === RefMode.NULL_ONLY) { stmt = ` 
- ${embedGenerator.readRefWithoutTypeInfo(assignStmt)} + ${embedGenerator.readRefWithoutTypeInfo(assignCompatible)} `; } else if (isDepthFreeField(fieldTypeInfo)) { // Leaf types and collections of leaf types cannot recurse — skip depth tracking. - stmt = embedGenerator.read(assignStmt, "false"); + stmt = embedGenerator.read(assignCompatible, "false"); } else { - stmt = embedGenerator.readWithDepth(assignStmt, "false"); + stmt = embedGenerator.readWithDepth(assignCompatible, "false"); } } else { if (refMode == RefMode.TRACKING || refMode === RefMode.NULL_ONLY) { - stmt = `${embedGenerator.readRef(assignStmt)}`; + stmt = `${embedGenerator.readRef(assignCompatible)}`; } else { - stmt = embedGenerator.readNoRef(assignStmt, "false"); + stmt = embedGenerator.readNoRef(assignCompatible, "false"); } } return stmt; diff --git a/javascript/packages/core/lib/gen/typedArray.ts b/javascript/packages/core/lib/gen/typedArray.ts index 7506e8e97d..bb6d716db7 100644 --- a/javascript/packages/core/lib/gen/typedArray.ts +++ b/javascript/packages/core/lib/gen/typedArray.ts @@ -23,8 +23,48 @@ import { BaseSerializerGenerator } from "./serializer"; import { CodegenRegistry } from "./router"; import { Scope } from "./scope"; import { TypeId } from "../type"; +import { BFloat16Array } from "../types/bfloat16"; +import { BoolArray } from "../types/boolArray"; +import { + createFloat16Array, + createFloat16ArrayFromRaw, + getFloat16Raw, +} from "../types/float16"; -function build(inner: TypeInfo, creator: string, size: number) { +const endianProbe = new Uint16Array([0x00ff]); +const IS_LITTLE_ENDIAN = new Uint8Array(endianProbe.buffer)[0] === 0xff; + +type TypedArrayReadMethod = + | "readFloat32" + | "readFloat64" + | "readInt8" + | "readInt16" + | "readInt32" + | "readInt64" + | "readUint8" + | "readUint16" + | "readUint32" + | "readUint64"; + +type TypedArrayWriteMethod = + | "writeFloat32" + | "writeFloat64" + | "writeInt8" + | "writeInt16" + | "writeInt32" + | "writeInt64" + | 
"writeUint8" + | "writeUint16" + | "writeUint32" + | "writeUint64"; + +function build( + inner: TypeInfo, + creator: string, + size: number, + readMethod: TypedArrayReadMethod, + writeMethod: TypedArrayWriteMethod, +) { return class TypedArraySerializerGenerator extends BaseSerializerGenerator { typeInfo: TypeInfo; @@ -34,6 +74,16 @@ function build(inner: TypeInfo, creator: string, size: number) { } write(accessor: string): string { + if (size !== 1 && !IS_LITTLE_ENDIAN) { + const idx = this.scope.uniqueName("idx"); + return ` + ${this.builder.writer.writeVarUInt32(`${accessor}.length * ${size}`)} + ${this.builder.writer.reserve(`${accessor}.length * ${size}`)}; + for (let ${idx} = 0; ${idx} < ${accessor}.length; ${idx}++) { + ${this.builder.writer[writeMethod](`${accessor}[${idx}]`)} + } + `; + } return ` ${this.builder.writer.writeVarUInt32(`${accessor}.byteLength`)} ${this.builder.writer.arrayBuffer(`${accessor}.buffer`, `${accessor}.byteOffset`, `${accessor}.byteLength`)} @@ -45,6 +95,24 @@ function build(inner: TypeInfo, creator: string, size: number) { const len = this.scope.uniqueName("len"); const copied = this.scope.uniqueName("copied"); + if (size !== 1 && !IS_LITTLE_ENDIAN) { + const rawLen = this.scope.uniqueName("rawLen"); + const idx = this.scope.uniqueName("idx"); + return ` + const ${rawLen} = ${this.builder.reader.readVarUInt32()}; + ${this.builder.getReadContextName()}.checkBinarySize(${rawLen}); + if ((${rawLen} % ${size}) !== 0) { + throw new Error("dense array byte length is not divisible by element size"); + } + const ${len} = ${rawLen} / ${size}; + const ${result} = new ${creator}(${len}); + ${this.maybeReference(result, refState)} + for (let ${idx} = 0; ${idx} < ${len}; ${idx}++) { + ${result}[${idx}] = ${this.builder.reader[readMethod]()}; + } + ${accessor(result)} + `; + } return ` const ${len} = ${this.builder.reader.readVarUInt32()}; ${this.builder.getReadContextName()}.checkBinarySize(${len}); @@ -71,11 +139,18 @@ class 
BoolArraySerializerGenerator extends BaseSerializerGenerator { write(accessor: string): string { const item = this.scope.uniqueName("item"); + const raw = this.scope.uniqueName("raw"); return ` - ${this.builder.writer.writeVarUInt32(`${accessor}.length`)} - ${this.builder.writer.reserve(`${accessor}.length`)}; - for (const ${item} of ${accessor}) { - ${this.builder.writer.writeUint8(`${item} ? 1 : 0`)} + if (${accessor} instanceof external.BoolArray) { + const ${raw} = ${accessor}.raw; + ${this.builder.writer.writeVarUInt32(`${raw}.byteLength`)} + ${this.builder.writer.buffer(raw)} + } else { + ${this.builder.writer.writeVarUInt32(`${accessor}.length`)} + ${this.builder.writer.reserve(`${accessor}.length`)}; + for (const ${item} of ${accessor}) { + ${this.builder.writer.writeUint8(`${item} ? 1 : 0`)} + } } `; } @@ -84,14 +159,45 @@ class BoolArraySerializerGenerator extends BaseSerializerGenerator { const result = this.scope.uniqueName("result"); const len = this.scope.uniqueName("len"); const idx = this.scope.uniqueName("idx"); + const raw = this.scope.uniqueName("raw"); + const bits = this.scope.uniqueName("bits"); + const copied = this.scope.uniqueName("copied"); + const readByte = this.builder.reader.readUint8(); return ` const ${len} = ${this.builder.reader.readVarUInt32()}; ${this.builder.getReadContextName()}.checkCollectionSize(${len}); - const ${result} = new Array(${len}); - ${this.maybeReference(result, refState)} - for (let ${idx} = 0; ${idx} < ${len}; ${idx}++) { - ${result}[${idx}] = ${this.builder.reader.readUint8()} === 1; + let ${result}; + if (${len} <= 4) { + let ${bits}; + switch (${len}) { + case 4: + ${bits} = ${this.builder.reader.readUint32()}; + break; + case 3: + ${bits} = ${readByte} | (${readByte} << 8) | (${readByte} << 16); + break; + case 2: + ${bits} = ${this.builder.reader.readUint16()}; + break; + case 1: + ${bits} = ${readByte}; + break; + default: + ${bits} = 0; + break; + } + ${result} = 
external.BoolArray.fromPackedBytes(${bits}, ${len}); + } else if (${len} <= 32) { + ${result} = new external.BoolArray(${len}); + const ${raw} = ${result}.raw; + for (let ${idx} = 0; ${idx} < ${len}; ${idx}++) { + ${raw}[${idx}] = ${readByte}; + } + } else { + const ${copied} = ${this.builder.reader.buffer(len)} + ${result} = external.BoolArray.fromRaw(${copied}); } + ${this.maybeReference(result, refState)} ${accessor(result)} `; } @@ -111,11 +217,22 @@ class Float16ArraySerializerGenerator extends BaseSerializerGenerator { write(accessor: string): string { const item = this.scope.uniqueName("item"); + const raw = this.scope.uniqueName("raw"); + const idx = this.scope.uniqueName("idx"); return ` - ${this.builder.writer.writeVarUInt32(`${accessor}.length * 2`)} - ${this.builder.writer.reserve(`${accessor}.length * 2`)}; - for (const ${item} of ${accessor}) { - ${this.builder.writer.writeFloat16(item)} + const ${raw} = external.getFloat16Raw(${accessor}); + if (${raw} !== null) { + ${this.builder.writer.writeVarUInt32(`${raw}.byteLength`)} + ${this.builder.writer.reserve(`${raw}.byteLength`)}; + for (let ${idx} = 0; ${idx} < ${raw}.length; ${idx}++) { + ${this.builder.writer.writeUint16(`${raw}[${idx}]`)} + } + } else { + ${this.builder.writer.writeVarUInt32(`${accessor}.length * 2`)} + ${this.builder.writer.reserve(`${accessor}.length * 2`)}; + for (const ${item} of ${accessor}) { + ${this.builder.writer.writeFloat16(item)} + } } `; } @@ -125,15 +242,20 @@ class Float16ArraySerializerGenerator extends BaseSerializerGenerator { const rawLen = this.scope.uniqueName("rawLen"); const len = this.scope.uniqueName("len"); const idx = this.scope.uniqueName("idx"); + const raw = this.scope.uniqueName("raw"); return ` const ${rawLen} = ${this.builder.reader.readVarUInt32()}; ${this.builder.getReadContextName()}.checkBinarySize(${rawLen}); + if ((${rawLen} % 2) !== 0) { + throw new Error("float16 array byte length is not divisible by element size"); + } const ${len} = 
${rawLen} / 2; - const ${result} = new Array(${len}); - ${this.maybeReference(result, refState)} + const ${raw} = new Uint16Array(${len}); for (let ${idx} = 0; ${idx} < ${len}; ${idx}++) { - ${result}[${idx}] = ${this.builder.reader.readFloat16()}; + ${raw}[${idx}] = ${this.builder.reader.readUint16()}; } + const ${result} = external.createFloat16ArrayFromRaw(${raw}); + ${this.maybeReference(result, refState)} ${accessor(result)} `; } @@ -153,11 +275,22 @@ class BFloat16ArraySerializerGenerator extends BaseSerializerGenerator { write(accessor: string): string { const item = this.scope.uniqueName("item"); + const raw = this.scope.uniqueName("raw"); + const idx = this.scope.uniqueName("idx"); return ` - ${this.builder.writer.writeVarUInt32(`${accessor}.length * 2`)} - ${this.builder.writer.reserve(`${accessor}.length * 2`)}; - for (const ${item} of ${accessor}) { - ${this.builder.writer.writeBfloat16(item)} + if (${accessor} instanceof external.BFloat16Array) { + const ${raw} = ${accessor}.raw; + ${this.builder.writer.writeVarUInt32(`${raw}.byteLength`)} + ${this.builder.writer.reserve(`${raw}.byteLength`)}; + for (let ${idx} = 0; ${idx} < ${raw}.length; ${idx}++) { + ${this.builder.writer.writeUint16(`${raw}[${idx}]`)} + } + } else { + ${this.builder.writer.writeVarUInt32(`${accessor}.length * 2`)} + ${this.builder.writer.reserve(`${accessor}.length * 2`)}; + for (const ${item} of ${accessor}) { + ${this.builder.writer.writeBfloat16(item)} + } } `; } @@ -167,15 +300,20 @@ class BFloat16ArraySerializerGenerator extends BaseSerializerGenerator { const rawLen = this.scope.uniqueName("rawLen"); const len = this.scope.uniqueName("len"); const idx = this.scope.uniqueName("idx"); + const raw = this.scope.uniqueName("raw"); return ` const ${rawLen} = ${this.builder.reader.readVarUInt32()}; ${this.builder.getReadContextName()}.checkBinarySize(${rawLen}); + if ((${rawLen} % 2) !== 0) { + throw new Error("bfloat16 array byte length is not divisible by element size"); + } const 
${len} = ${rawLen} / 2; - const ${result} = new Array(${len}); - ${this.maybeReference(result, refState)} + const ${raw} = new Uint16Array(${len}); for (let ${idx} = 0; ${idx} < ${len}; ${idx}++) { - ${result}[${idx}] = ${this.builder.reader.readBfloat16()}; + ${raw}[${idx}] = ${this.builder.reader.readUint16()}; } + const ${result} = external.BFloat16Array.fromRaw(${raw}); + ${this.maybeReference(result, refState)} ${accessor(result)} `; } @@ -186,16 +324,21 @@ class BFloat16ArraySerializerGenerator extends BaseSerializerGenerator { } CodegenRegistry.register(TypeId.BOOL_ARRAY, BoolArraySerializerGenerator); -CodegenRegistry.register(TypeId.BINARY, build(Type.uint8(), `Uint8Array`, 1)); -CodegenRegistry.register(TypeId.INT8_ARRAY, build(Type.int8(), `Int8Array`, 1)); -CodegenRegistry.register(TypeId.INT16_ARRAY, build(Type.int16(), `Int16Array`, 2)); -CodegenRegistry.register(TypeId.INT32_ARRAY, build(Type.int32(), `Int32Array`, 4)); -CodegenRegistry.register(TypeId.INT64_ARRAY, build(Type.int64(), `BigInt64Array`, 8)); -CodegenRegistry.register(TypeId.UINT8_ARRAY, build(Type.uint8(), `Uint8Array`, 1)); -CodegenRegistry.register(TypeId.UINT16_ARRAY, build(Type.uint16(), `Uint16Array`, 2)); -CodegenRegistry.register(TypeId.UINT32_ARRAY, build(Type.uint32(), `Uint32Array`, 4)); -CodegenRegistry.register(TypeId.UINT64_ARRAY, build(Type.uint64(), `BigUint64Array`, 8)); +CodegenRegistry.register(TypeId.BINARY, build(Type.uint8(), `Uint8Array`, 1, "readUint8", "writeUint8")); +CodegenRegistry.register(TypeId.INT8_ARRAY, build(Type.int8(), `Int8Array`, 1, "readInt8", "writeInt8")); +CodegenRegistry.register(TypeId.INT16_ARRAY, build(Type.int16(), `Int16Array`, 2, "readInt16", "writeInt16")); +CodegenRegistry.register(TypeId.INT32_ARRAY, build(Type.int32(), `Int32Array`, 4, "readInt32", "writeInt32")); +CodegenRegistry.register(TypeId.INT64_ARRAY, build(Type.int64(), `BigInt64Array`, 8, "readInt64", "writeInt64")); +CodegenRegistry.register(TypeId.UINT8_ARRAY, 
build(Type.uint8(), `Uint8Array`, 1, "readUint8", "writeUint8")); +CodegenRegistry.register(TypeId.UINT16_ARRAY, build(Type.uint16(), `Uint16Array`, 2, "readUint16", "writeUint16")); +CodegenRegistry.register(TypeId.UINT32_ARRAY, build(Type.uint32(), `Uint32Array`, 4, "readUint32", "writeUint32")); +CodegenRegistry.register(TypeId.UINT64_ARRAY, build(Type.uint64(), `BigUint64Array`, 8, "readUint64", "writeUint64")); CodegenRegistry.register(TypeId.FLOAT16_ARRAY, Float16ArraySerializerGenerator); CodegenRegistry.register(TypeId.BFLOAT16_ARRAY, BFloat16ArraySerializerGenerator); -CodegenRegistry.register(TypeId.FLOAT32_ARRAY, build(Type.float32(), `Float32Array`, 4)); -CodegenRegistry.register(TypeId.FLOAT64_ARRAY, build(Type.float64(), `Float64Array`, 8)); +CodegenRegistry.register(TypeId.FLOAT32_ARRAY, build(Type.float32(), `Float32Array`, 4, "readFloat32", "writeFloat32")); +CodegenRegistry.register(TypeId.FLOAT64_ARRAY, build(Type.float64(), `Float64Array`, 8, "readFloat64", "writeFloat64")); +CodegenRegistry.registerExternal(BFloat16Array); +CodegenRegistry.registerExternal(BoolArray); +CodegenRegistry.registerExternal(createFloat16Array); +CodegenRegistry.registerExternal(createFloat16ArrayFromRaw); +CodegenRegistry.registerExternal(getFloat16Raw); diff --git a/javascript/packages/core/lib/reader/index.ts b/javascript/packages/core/lib/reader/index.ts index daef64375d..fe4263b751 100644 --- a/javascript/packages/core/lib/reader/index.ts +++ b/javascript/packages/core/lib/reader/index.ts @@ -22,6 +22,7 @@ import { isNodeEnv } from "../util"; import { PlatformBuffer, alloc, fromUint8Array } from "../platformBuffer"; import { readLatin1String } from "./string"; import { BFloat16 } from "../types/bfloat16"; +import { fromFloat16Bits } from "../types/float16"; export class BinaryReader { private sliceStringEnable; @@ -549,34 +550,7 @@ export class BinaryReader { } readFloat16() { - const asUint16 = this.readUint16(); - const sign = asUint16 >> 15; - const exponent = 
(asUint16 >> 10) & 0x1f; - const mantissa = asUint16 & 0x3ff; - - // IEEE 754-2008 - if (exponent === 0) { - if (mantissa === 0) { - // +-0 - return sign === 0 ? 0 : -0; - } else { - // Denormalized number - return (sign === 0 ? 1 : -1) * mantissa * 2 ** (1 - 15 - 10); - } - } else if (exponent === 31) { - if (mantissa === 0) { - // Infinity - return sign === 0 ? Infinity : -Infinity; - } else { - // NaN - return NaN; - } - } else { - // Normalized number - return ( - (sign === 0 ? 1 : -1) * (1 + mantissa * 2 ** -10) * 2 ** (exponent - 15) - ); - } + return fromFloat16Bits(this.readUint16()); } readBfloat16(): BFloat16 { diff --git a/javascript/packages/core/lib/typeInfo.ts b/javascript/packages/core/lib/typeInfo.ts index 16c936bed6..9d09b85271 100644 --- a/javascript/packages/core/lib/typeInfo.ts +++ b/javascript/packages/core/lib/typeInfo.ts @@ -19,6 +19,8 @@ import { ForyTypeInfoSymbol, TypeId } from "./type"; import { BFloat16, BFloat16Array } from "./types/bfloat16"; +import { BoolArray } from "./types/boolArray"; +import { Float16Array } from "./types/float16"; import { Decimal } from "./types/decimal"; @@ -554,7 +556,7 @@ export type HintInput = T extends { : T extends { type: typeof TypeId.BOOL_ARRAY; } - ? boolean[] + ? BoolArray | boolean[] : T extends { type: typeof TypeId.INT8_ARRAY; } @@ -590,7 +592,7 @@ export type HintInput = T extends { : T extends { type: typeof TypeId.FLOAT16_ARRAY; } - ? number[] + ? Float16Array | number[] : T extends { type: typeof TypeId.BFLOAT16_ARRAY; } @@ -684,7 +686,7 @@ export type HintResult = T extends never ? any : T extends { : T extends { type: typeof TypeId.BOOL_ARRAY; } - ? boolean[] + ? BoolArray : T extends { type: typeof TypeId.INT8_ARRAY; } @@ -720,11 +722,11 @@ export type HintResult = T extends never ? any : T extends { : T extends { type: typeof TypeId.FLOAT16_ARRAY; } - ? number[] + ? Float16Array : T extends { type: typeof TypeId.BFLOAT16_ARRAY; } - ? BFloat16[] + ? 
BFloat16Array : T extends { type: typeof TypeId.FLOAT32_ARRAY; } diff --git a/javascript/packages/core/lib/typeResolver.ts b/javascript/packages/core/lib/typeResolver.ts index 3bb93acc88..573fd9396c 100644 --- a/javascript/packages/core/lib/typeResolver.ts +++ b/javascript/packages/core/lib/typeResolver.ts @@ -22,6 +22,9 @@ import { Gen } from "./gen"; import { Dynamic, Type, TypeInfo } from "./typeInfo"; import { ReadContext, WriteContext } from "./context"; import { Decimal } from "./types/decimal"; +import { BFloat16Array } from "./types/bfloat16"; +import { BoolArray } from "./types/boolArray"; +import { isFloat16Array } from "./types/float16"; const uninitSerialize = { _initialized: false, @@ -108,6 +111,9 @@ export default class TypeResolver { private int16ArraySerializer: null | Serializer = null; private int32ArraySerializer: null | Serializer = null; private int64ArraySerializer: null | Serializer = null; + private boolArraySerializer: null | Serializer = null; + private float16ArraySerializer: null | Serializer = null; + private bfloat16ArraySerializer: null | Serializer = null; private float32ArraySerializer: null | Serializer = null; private float64ArraySerializer: null | Serializer = null; @@ -231,6 +237,9 @@ export default class TypeResolver { this.int16ArraySerializer = this.getSerializerById(TypeId.INT16_ARRAY); this.int32ArraySerializer = this.getSerializerById(TypeId.INT32_ARRAY); this.int64ArraySerializer = this.getSerializerById(TypeId.INT64_ARRAY); + this.boolArraySerializer = this.getSerializerById(TypeId.BOOL_ARRAY); + this.float16ArraySerializer = this.getSerializerById(TypeId.FLOAT16_ARRAY); + this.bfloat16ArraySerializer = this.getSerializerById(TypeId.BFLOAT16_ARRAY); this.float32ArraySerializer = this.getSerializerById(TypeId.FLOAT32_ARRAY); this.float64ArraySerializer = this.getSerializerById(TypeId.FLOAT64_ARRAY); } @@ -339,6 +348,18 @@ export default class TypeResolver { return this.decimalSerializer; } + if (v instanceof BoolArray) { 
+ return this.boolArraySerializer; + } + + if (isFloat16Array(v)) { + return this.float16ArraySerializer; + } + + if (v instanceof BFloat16Array) { + return this.bfloat16ArraySerializer; + } + if (v instanceof Uint8Array) { return this.uint8ArraySerializer; } diff --git a/javascript/packages/core/lib/types/bfloat16.ts b/javascript/packages/core/lib/types/bfloat16.ts index 0beb46d448..5806de1587 100644 --- a/javascript/packages/core/lib/types/bfloat16.ts +++ b/javascript/packages/core/lib/types/bfloat16.ts @@ -57,24 +57,35 @@ export class BFloat16 { } } +export function toBFloat16Bits(value: BFloat16 | number): number { + return value instanceof BFloat16 ? value.toBits() : BFloat16.fromFloat32(value).toBits(); +} + +export function fromBFloat16Bits(bits: number): number { + float32View[0] = 0; + uint32View[0] = (bits & 0xffff) << 16; + return float32View[0]; +} + export class BFloat16Array { - private readonly _data: Uint16Array; + static readonly BYTES_PER_ELEMENT = 2; + readonly BYTES_PER_ELEMENT = 2; + private _data: Uint16Array; constructor(length: number); - constructor(source: Uint16Array | BFloat16[] | number[]); - constructor(lengthOrSource: number | Uint16Array | BFloat16[] | number[]) { + constructor(source: BFloat16Array | Uint16Array | ArrayLike); + constructor(lengthOrSource: number | BFloat16Array | Uint16Array | ArrayLike) { if (typeof lengthOrSource === "number") { this._data = new Uint16Array(lengthOrSource); + } else if (lengthOrSource instanceof BFloat16Array) { + this._data = new Uint16Array(lengthOrSource.raw); } else if (lengthOrSource instanceof Uint16Array) { this._data = new Uint16Array(lengthOrSource.length); this._data.set(lengthOrSource); } else { - const arr = lengthOrSource as (BFloat16 | number)[]; - this._data = new Uint16Array(arr.length); - for (let i = 0; i < arr.length; i++) { - const v = arr[i]; - this._data[i] - = v instanceof BFloat16 ? 
v.toBits() : BFloat16.fromFloat32(v).toBits(); + this._data = new Uint16Array(lengthOrSource.length); + for (let i = 0; i < lengthOrSource.length; i++) { + this._data[i] = toBFloat16Bits(lengthOrSource[i]); } } } @@ -83,13 +94,47 @@ export class BFloat16Array { return this._data.length; } - get(index: number): BFloat16 { - return BFloat16.fromBits(this._data[index]); + get byteLength(): number { + return this._data.byteLength; + } + + get byteOffset(): number { + return this._data.byteOffset; + } + + get buffer(): ArrayBufferLike { + return this._data.buffer; + } + + get(index: number): number { + return fromBFloat16Bits(this._data[index]); + } + + at(index: number): number | undefined { + const normalized = index < 0 ? this.length + index : index; + if (normalized < 0 || normalized >= this.length) { + return undefined; + } + return this.get(normalized); + } + + set(values: BFloat16Array | ArrayLike, offset = 0): void { + if (values instanceof BFloat16Array) { + this._data.set(values.raw, offset); + return; + } + for (let i = 0; i < values.length; i++) { + this._data[offset + i] = toBFloat16Bits(values[i]); + } + } + + setValue(index: number, value: BFloat16 | number): void { + this._data[index] = toBFloat16Bits(value); } - set(index: number, value: BFloat16 | number): void { - this._data[index] - = value instanceof BFloat16 ? 
value.toBits() : BFloat16.fromFloat32(value).toBits(); + fill(value: BFloat16 | number, start?: number, end?: number): this { + this._data.fill(toBFloat16Bits(value), start, end); + return this; } get raw(): Uint16Array { @@ -97,21 +142,33 @@ export class BFloat16Array { } static fromRaw(data: Uint16Array): BFloat16Array { - const arr = new BFloat16Array(data.length); - arr._data.set(data); - return arr; + const array = Object.create(BFloat16Array.prototype) as BFloat16Array; + array._data = data; + return array; + } + + slice(start?: number, end?: number): BFloat16Array { + return BFloat16Array.fromRaw(this._data.slice(start, end)); + } + + subarray(begin?: number, end?: number): BFloat16Array { + return BFloat16Array.fromRaw(this._data.subarray(begin, end)); + } + + toArray(): number[] { + return Array.from(this); } - [Symbol.iterator](): IterableIterator { + [Symbol.iterator](): IterableIterator { let i = 0; const data = this._data; const len = data.length; return { - next(): IteratorResult { + next(): IteratorResult { if (i < len) { - return { value: BFloat16.fromBits(data[i++]), done: false }; + return { value: fromBFloat16Bits(data[i++]), done: false }; } - return { value: undefined as unknown as BFloat16, done: true }; + return { value: undefined as unknown as number, done: true }; }, [Symbol.iterator]() { return this; diff --git a/javascript/packages/core/lib/types/boolArray.ts b/javascript/packages/core/lib/types/boolArray.ts new file mode 100644 index 0000000000..e7419fecd5 --- /dev/null +++ b/javascript/packages/core/lib/types/boolArray.ts @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export class BoolArray { + static readonly BYTES_PER_ELEMENT = 1; + readonly BYTES_PER_ELEMENT = 1; + private _data: Uint8Array | null; + private _length: number; + private _packedBits: number; + + constructor(length: number); + constructor(source: BoolArray | Uint8Array | ArrayLike); + constructor(lengthOrSource: number | BoolArray | Uint8Array | ArrayLike) { + this._packedBits = 0; + if (typeof lengthOrSource === "number") { + this._data = new Uint8Array(lengthOrSource); + this._length = lengthOrSource; + } else if (lengthOrSource instanceof BoolArray) { + this._data = new Uint8Array(lengthOrSource.raw); + this._length = this._data.length; + } else if (lengthOrSource instanceof Uint8Array) { + this._data = new Uint8Array(lengthOrSource); + this._length = this._data.length; + } else { + this._data = new Uint8Array(lengthOrSource.length); + this._length = lengthOrSource.length; + for (let i = 0; i < lengthOrSource.length; i++) { + this._data[i] = lengthOrSource[i] ? 1 : 0; + } + } + } + + get length(): number { + return this._length; + } + + get byteLength(): number { + return this._length; + } + + get byteOffset(): number { + return this._data === null ? 
0 : this._data.byteOffset; + } + + get buffer(): ArrayBufferLike { + return this.raw.buffer; + } + + get raw(): Uint8Array { + let data = this._data; + if (data !== null) { + return data; + } + data = new Uint8Array(this._length); + let bits = this._packedBits; + for (let i = 0; i < this._length; i++) { + data[i] = bits & 0xFF; + bits >>>= 8; + } + this._data = data; + this._packedBits = 0; + return data; + } + + get(index: number): boolean { + const data = this._data; + if (data !== null) { + return data[index] !== 0; + } + return ((this._packedBits >>> (index * 8)) & 0xFF) !== 0; + } + + at(index: number): boolean | undefined { + const normalized = index < 0 ? this.length + index : index; + if (normalized < 0 || normalized >= this.length) { + return undefined; + } + return this.get(normalized); + } + + set(values: BoolArray | ArrayLike, offset = 0): void { + const data = this.raw; + if (values instanceof BoolArray) { + data.set(values.raw, offset); + return; + } + for (let i = 0; i < values.length; i++) { + data[offset + i] = values[i] ? 1 : 0; + } + } + + setValue(index: number, value: boolean): void { + this.raw[index] = value ? 1 : 0; + } + + fill(value: boolean, start?: number, end?: number): this { + this.raw.fill(value ? 
1 : 0, start, end); + return this; + } + + slice(start?: number, end?: number): BoolArray { + return BoolArray.fromRaw(this.raw.slice(start, end)); + } + + subarray(begin?: number, end?: number): BoolArray { + return BoolArray.fromRaw(this.raw.subarray(begin, end)); + } + + toArray(): boolean[] { + return Array.from(this); + } + + static fromRaw(data: Uint8Array): BoolArray { + const array = Object.create(BoolArray.prototype) as BoolArray; + array._data = data; + array._length = data.length; + array._packedBits = 0; + return array; + } + + static fromPackedBytes(bits: number, length: number): BoolArray { + const array = Object.create(BoolArray.prototype) as BoolArray; + array._data = null; + array._length = length; + array._packedBits = bits; + return array; + } + + [Symbol.iterator](): IterableIterator { + let i = 0; + const data = this._data; + const len = this._length; + const bits = this._packedBits; + return { + next(): IteratorResult { + if (i < len) { + const value = data === null + ? ((bits >>> ((i++) * 8)) & 0xFF) !== 0 + : data[i++] !== 0; + return { value, done: false }; + } + return { value: undefined as unknown as boolean, done: true }; + }, + [Symbol.iterator]() { + return this; + }, + }; + } +} diff --git a/javascript/packages/core/lib/types/float16.ts b/javascript/packages/core/lib/types/float16.ts new file mode 100644 index 0000000000..e346236d5c --- /dev/null +++ b/javascript/packages/core/lib/types/float16.ts @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +const float32View = new Float32Array(1); +const int32View = new Int32Array(float32View.buffer); + +export function toFloat16Bits(value: number) { + float32View[0] = value; + const floatValue = int32View[0]; + const sign = (floatValue >>> 16) & 0x8000; + const exponent = ((floatValue >>> 23) & 0xff) - 127; + const significand = floatValue & 0x7fffff; + + if (exponent === 128) { + return sign | 0x7c00 | (significand !== 0 ? 0x0200 : 0); + } + + if (exponent > 15) { + return sign | 0x7c00; + } + + if (exponent < -14) { + return sign | ((significand | 0x800000) >> Math.min(13 - 14 - exponent, 31)); /* clamp: JS shift counts wrap modulo 32, so tiny inputs must flush to signed zero */ + } + + return sign | ((exponent + 15) << 10) | (significand >> 13); +} + +export function fromFloat16Bits(bits: number): number { + const sign = bits >> 15; + const exponent = (bits >> 10) & 0x1f; + const mantissa = bits & 0x3ff; + + if (exponent === 0) { + if (mantissa === 0) { + return sign === 0 ? 0 : -0; + } + return (sign === 0 ? 1 : -1) * mantissa * 2 ** (1 - 15 - 10); + } + + if (exponent === 31) { + return mantissa === 0 ? (sign === 0 ? Infinity : -Infinity) : NaN; + } + + return (sign === 0 ? 
1 : -1) * (1 + mantissa * 2 ** -10) * 2 ** (exponent - 15); +} + +type Float16Source = ForyFloat16Array | Uint16Array | ArrayLike; + +export class ForyFloat16Array { + static readonly BYTES_PER_ELEMENT = 2; + readonly BYTES_PER_ELEMENT = 2; + private _data: Uint16Array; + + constructor(length: number); + constructor(source: Float16Source); + constructor(lengthOrSource: number | Float16Source) { + if (typeof lengthOrSource === "number") { + this._data = new Uint16Array(lengthOrSource); + } else if (lengthOrSource instanceof ForyFloat16Array) { + this._data = new Uint16Array(lengthOrSource.raw); + } else if (lengthOrSource instanceof Uint16Array) { + this._data = new Uint16Array(lengthOrSource); + } else { + this._data = new Uint16Array(lengthOrSource.length); + for (let i = 0; i < lengthOrSource.length; i++) { + this._data[i] = toFloat16Bits(lengthOrSource[i]); + } + } + } + + get length(): number { + return this._data.length; + } + + get byteLength(): number { + return this._data.byteLength; + } + + get byteOffset(): number { + return this._data.byteOffset; + } + + get buffer(): ArrayBufferLike { + return this._data.buffer; + } + + get raw(): Uint16Array { + return this._data; + } + + get(index: number): number { + return fromFloat16Bits(this._data[index]); + } + + at(index: number): number | undefined { + const normalized = index < 0 ? 
this.length + index : index; + if (normalized < 0 || normalized >= this.length) { + return undefined; + } + return this.get(normalized); + } + + set(values: ForyFloat16Array | ArrayLike, offset = 0): void { + if (values instanceof ForyFloat16Array) { + this._data.set(values.raw, offset); + return; + } + for (let i = 0; i < values.length; i++) { + this._data[offset + i] = toFloat16Bits(values[i]); + } + } + + setValue(index: number, value: number): void { + this._data[index] = toFloat16Bits(value); + } + + fill(value: number, start?: number, end?: number): this { + this._data.fill(toFloat16Bits(value), start, end); + return this; + } + + slice(start?: number, end?: number): ForyFloat16Array { + return ForyFloat16Array.fromRaw(this._data.slice(start, end)); + } + + subarray(begin?: number, end?: number): ForyFloat16Array { + return ForyFloat16Array.fromRaw(this._data.subarray(begin, end)); + } + + toArray(): number[] { + return Array.from(this); + } + + static fromRaw(data: Uint16Array): ForyFloat16Array { + const array = Object.create(ForyFloat16Array.prototype) as ForyFloat16Array; + array._data = data; + return array; + } + + [Symbol.iterator](): IterableIterator { + let i = 0; + const data = this._data; + const len = data.length; + return { + next(): IteratorResult { + if (i < len) { + return { value: fromFloat16Bits(data[i++]), done: false }; + } + return { value: undefined as unknown as number, done: true }; + }, + [Symbol.iterator]() { + return this; + }, + }; + } +} + +type NativeFloat16ArrayConstructor = { + new(length: number): ArrayLike; + new(source: ArrayLike): ArrayLike; + readonly BYTES_PER_ELEMENT: number; +}; + +const NativeFloat16Array = (globalThis as unknown as { Float16Array?: NativeFloat16ArrayConstructor }).Float16Array; + +export type Float16Array = ForyFloat16Array | ArrayLike; + +export const Float16Array = ( + NativeFloat16Array ?? 
ForyFloat16Array +) as NativeFloat16ArrayConstructor | typeof ForyFloat16Array; + +export function isFloat16Array(value: unknown): boolean { + return value instanceof ForyFloat16Array + || (NativeFloat16Array !== undefined && value instanceof NativeFloat16Array); +} + +export function getFloat16Raw(value: unknown): Uint16Array | null { + return value instanceof ForyFloat16Array ? value.raw : null; +} + +export function createFloat16ArrayFromRaw(raw: Uint16Array): Float16Array { + if (NativeFloat16Array !== undefined) { + const result = new NativeFloat16Array(raw.length) as { [index: number]: number }; + for (let i = 0; i < raw.length; i++) { + result[i] = fromFloat16Bits(raw[i]); + } + return result as Float16Array; + } + return ForyFloat16Array.fromRaw(raw); +} + +export function createFloat16Array(source: ArrayLike): Float16Array { + return new Float16Array(source) as Float16Array; +} diff --git a/javascript/packages/core/lib/writer/index.ts b/javascript/packages/core/lib/writer/index.ts index f51c35b8a8..21fd51154d 100644 --- a/javascript/packages/core/lib/writer/index.ts +++ b/javascript/packages/core/lib/writer/index.ts @@ -20,8 +20,8 @@ import { HalfMaxInt32, HalfMinInt32, Hps, LATIN1, UTF16, UTF8 } from "../type"; import { PlatformBuffer, alloc, strByteLength } from "../platformBuffer"; import { OwnershipError } from "../error"; -import { toFloat16, toBFloat16 } from "./number"; -import { BFloat16 } from "../types/bfloat16"; +import { toFloat16Bits } from "../types/float16"; +import { BFloat16, toBFloat16Bits } from "../types/bfloat16"; const MAX_POOL_SIZE = 1024 * 1024 * 3; // 3MB @@ -466,13 +466,11 @@ export class BinaryWriter { } writeFloat16(value: number) { - this.writeUint16(toFloat16(value)); + this.writeUint16(toFloat16Bits(value)); } writeBfloat16(value: BFloat16 | number) { - const bits - = value instanceof BFloat16 ? 
value.toBits() : toBFloat16(value); - this.writeUint16(bits); + this.writeUint16(toBFloat16Bits(value)); } writeGetCursor() { diff --git a/javascript/packages/core/lib/writer/number.ts b/javascript/packages/core/lib/writer/number.ts deleted file mode 100644 index 1c38ce39ee..0000000000 --- a/javascript/packages/core/lib/writer/number.ts +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -const float32View = new Float32Array(1); -const int32View = new Int32Array(float32View.buffer); - -export function toFloat16(value: number) { - float32View[0] = value; - const floatValue = int32View[0]; - const sign = (floatValue >>> 16) & 0x8000; // sign only - const exponent = ((floatValue >>> 23) & 0xff) - 127; // extract exponent from floatValue - const significand = floatValue & 0x7fffff; // extract significand from floatValue - - if (exponent === 128) { // floatValue is NaN or Infinity - return sign | 0x7c00 | (significand !== 0 ? 
0x0200 : 0); - } - - if (exponent > 15) { - return sign | 0x7c00; // return Infinity - } - - if (exponent < -14) { - // subnormal - // shift amount = 13 - 14 - exponent = -1 - exponent - return sign | ((significand | 0x800000) >> (13 - 14 - exponent)); - } - - return sign | ((exponent + 15) << 10) | (significand >> 13); -} - -const float32ViewBf = new Float32Array(1); -const uint32ViewBf = new Uint32Array(float32ViewBf.buffer); - -/** - * Convert float32 to bfloat16 bits (round-to-nearest, ties-to-even). - * BFloat16 layout: 1 sign, 8 exponent, 7 mantissa. - */ -export function toBFloat16(value: number): number { - float32ViewBf[0] = value; - const bits = uint32ViewBf[0]; - const exponent = (bits >> 23) & 0xff; - if (exponent === 255) { - return (bits >> 16) & 0xffff; - } - const remainder = bits & 0x1ffff; - let u = (bits + 0x8000) >> 16; - if (remainder === 0x8000 && (u & 1) !== 0) { - u--; - } - return u & 0xffff; -} diff --git a/javascript/test/array.test.ts b/javascript/test/array.test.ts index 9e4ef64369..ef6d84f709 100644 --- a/javascript/test/array.test.ts +++ b/javascript/test/array.test.ts @@ -17,7 +17,13 @@ * under the License. 
*/ -import Fory, { Type, BFloat16Array } from '../packages/core/index'; +import Fory, { + Type, + BFloat16Array, + BoolArray, + Float16Array, + ForyFloat16Array, +} from '../packages/core/index'; import { TypeId } from '../packages/core/lib/type'; import { describe, expect, test } from '@jest/globals'; import * as beautify from 'js-beautify'; @@ -79,13 +85,12 @@ describe('array', () => { const result = fory.deserialize( input ); - expect(result).toEqual({ - a: [true, false], - a2: new Int16Array([1, 2, 3]), - a3: new Int32Array([3, 5, 76]), - a4: new BigInt64Array([634n, 564n, 76n]), - a6: new Float64Array([234243.555, 55654.679]), - }) + expect(result.a).toBeInstanceOf(BoolArray); + expect(Array.from(result.a)).toEqual([true, false]); + expect(result.a2).toEqual(new Int16Array([1, 2, 3])); + expect(result.a3).toEqual(new Int32Array([3, 5, 76])); + expect(result.a4).toEqual(new BigInt64Array([634n, 564n, 76n])); + expect(result.a6).toEqual(new Float64Array([234243.555, 55654.679])); }); @@ -122,9 +127,10 @@ describe('array', () => { const result = fory.deserialize( input ); - expect(result.a6[0]).toBeCloseTo(1.5, 1) - expect(result.a6[1]).toBeCloseTo(2.5, 1) - expect(result.a6[2]).toBeCloseTo(-4.5, 1) + expect(result.a6).toBeInstanceOf(Float16Array as any); + expect(Array.from(result.a6 as Iterable)[0]).toBeCloseTo(1.5, 1) + expect(Array.from(result.a6 as Iterable)[1]).toBeCloseTo(2.5, 1) + expect(Array.from(result.a6 as Iterable)[2]).toBeCloseTo(-4.5, 1) }); test('should bfloat16Array work', () => { @@ -139,10 +145,11 @@ describe('array', () => { a7: [1.5, 2.5, -4.5], }, serialize); const result = fory.deserialize(input); + expect(result.a7).toBeInstanceOf(BFloat16Array); expect(result.a7).toHaveLength(3); - expect(result.a7[0].toFloat32()).toBeCloseTo(1.5, 2); - expect(result.a7[1].toFloat32()).toBeCloseTo(2.5, 2); - expect(result.a7[2].toFloat32()).toBeCloseTo(-4.5, 2); + expect(result.a7.get(0)).toBeCloseTo(1.5, 2); + expect(result.a7.get(1)).toBeCloseTo(2.5, 
2); + expect(result.a7.get(2)).toBeCloseTo(-4.5, 2); }); test('should bfloat16Array accept BFloat16Array', () => { @@ -156,9 +163,68 @@ describe('array', () => { const arr = new BFloat16Array([1.25, -2.5, 0]); const input = fory.serialize({ a7: arr }, serialize); const result = fory.deserialize(input); + expect(result.a7).toBeInstanceOf(BFloat16Array); expect(result.a7).toHaveLength(3); - expect(result.a7[0].toFloat32()).toBeCloseTo(1.25, 2); - expect(result.a7[1].toFloat32()).toBeCloseTo(-2.5, 2); - expect(result.a7[2].toFloat32()).toBe(0); + expect(result.a7.get(0)).toBeCloseTo(1.25, 2); + expect(result.a7.get(1)).toBeCloseTo(-2.5, 2); + expect(result.a7.get(2)).toBe(0); + }); + + test('should expose bool and reduced-precision array carriers', () => { + const bools = new BoolArray([true, false, true]); + bools.setValue(1, true); + expect(Array.from(bools)).toEqual([true, true, true]); + expect(bools.raw).toEqual(new Uint8Array([1, 1, 1])); + + const f16 = new ForyFloat16Array([1.5, -2]); + expect(f16.get(0)).toBeCloseTo(1.5, 1); + expect(f16.get(1)).toBeCloseTo(-2, 1); + + const bf16 = new BFloat16Array([1.5, -2]); + expect(Array.from(bf16)).toEqual([1.5, -2]); + }); + + test('should write dense array protocol bytes in little-endian order', () => { + const fory = new Fory({ compatible: false, ref: true }); + + const uint16Type = Type.struct({ typeName: 'example.uint16array' }, { + values: Type.uint16Array(), + }); + const uint16Serializer = fory.register(uint16Type).serializer; + const uint16Bytes = fory.serialize({ + values: new Uint16Array([0x1234, 0xabcd]), + }, uint16Serializer); + expect(containsBytes(uint16Bytes, [0x34, 0x12, 0xcd, 0xab])).toBe(true); + + const float16Type = Type.struct({ typeName: 'example.float16array' }, { + values: Type.float16Array(), + }); + const float16Serializer = fory.register(float16Type).serializer; + const float16Bytes = fory.serialize({ + values: new ForyFloat16Array([1, -2]), + }, float16Serializer); + 
expect(containsBytes(float16Bytes, [0x00, 0x3c, 0x00, 0xc0])).toBe(true); + + const bfloat16Type = Type.struct({ typeName: 'example.bfloat16array' }, { + values: Type.bfloat16Array(), + }); + const bfloat16Serializer = fory.register(bfloat16Type).serializer; + const bfloat16Bytes = fory.serialize({ + values: new BFloat16Array([1, -2]), + }, bfloat16Serializer); + expect(containsBytes(bfloat16Bytes, [0x80, 0x3f, 0x00, 0xc0])).toBe(true); }); }); + +function containsBytes(bytes: Uint8Array, needle: number[]) { + outer: + for (let i = 0; i <= bytes.length - needle.length; i++) { + for (let j = 0; j < needle.length; j++) { + if (bytes[i + j] !== needle[j]) { + continue outer; + } + } + return true; + } + return false; +} diff --git a/javascript/test/crossLanguage.test.ts b/javascript/test/crossLanguage.test.ts index 86a7257798..e94b96ea3f 100644 --- a/javascript/test/crossLanguage.test.ts +++ b/javascript/test/crossLanguage.test.ts @@ -1394,6 +1394,43 @@ describe("bool", () => { const serializedData = fory.serialize(deserializedStruct); writeToFile(serializedData as Buffer); }); + + test("test_list_array_compatible_list_to_array", () => { + const fory = new Fory({ compatible: true }); + const serializer = fory.register( + Type.struct(901, { + values: Type.int32Array().setId(1), + }), + ); + + const value = serializer.deserialize(content); + writeToFile(serializer.serialize(value) as Buffer); + }); + + test("test_list_array_compatible_array_to_list", () => { + const fory = new Fory({ compatible: true }); + const serializer = fory.register( + Type.struct(901, { + values: Type.list(Type.int32({ encoding: "fixed" })).setId(1), + }), + ); + + const value = serializer.deserialize(content); + writeToFile(serializer.serialize(value) as Buffer); + }); + + test("test_list_array_compatible_nullable_list_to_array_error", () => { + const fory = new Fory({ compatible: true }); + const serializer = fory.register( + Type.struct(901, { + values: Type.int32Array().setId(1), + }), + ); + 
+ expect(() => serializer.deserialize(content)).toThrow(); + writeToFile(content); + }); + test("test_one_enum_field_schema", () => { const fory = new Fory({ compatible: false, diff --git a/javascript/test/sizeLimit.test.ts b/javascript/test/sizeLimit.test.ts index b347514d92..73a61caaf0 100644 --- a/javascript/test/sizeLimit.test.ts +++ b/javascript/test/sizeLimit.test.ts @@ -17,7 +17,7 @@ * under the License. */ -import Fory, { Type } from '../packages/core/index'; +import Fory, { Type, BoolArray } from '../packages/core/index'; import { describe, expect, test } from '@jest/globals'; describe('size-limit guardrails', () => { @@ -274,7 +274,8 @@ describe('size-limit guardrails', () => { })); const data = { flags: [true, false, true] }; const result = deserialize(serialize(data)); - expect(result!.flags).toEqual([true, false, true]); + expect(result!.flags).toBeInstanceOf(BoolArray); + expect(Array.from(result!.flags)).toEqual([true, false, true]); }); test('should throw when bool array exceeds maxCollectionSize', () => { diff --git a/javascript/test/typemeta.test.ts b/javascript/test/typemeta.test.ts index 473ec38418..c83fff543c 100644 --- a/javascript/test/typemeta.test.ts +++ b/javascript/test/typemeta.test.ts @@ -17,7 +17,12 @@ * under the License. 
*/ -import Fory, { Type } from "../packages/core/index"; +import Fory, { + BFloat16Array, + BoolArray, + Float16Array, + Type, +} from "../packages/core/index"; import { ReadContext } from "../packages/core/lib/context"; import { TypeMeta } from "../packages/core/lib/meta/TypeMeta"; import { BinaryReader } from "../packages/core/lib/reader"; @@ -195,6 +200,139 @@ describe("typemeta", () => { }); }); + test("adapts only immediate compatible list and dense array field pairs", () => { + const writerFory = new Fory({ compatible: true }); + const readerFory = new Fory({ compatible: true }); + + const writerType = Type.struct(7211, { + values: Type.list(Type.int32({ encoding: "fixed" })).setId(1), + }); + const readerType = Type.struct(7211, { + values: Type.int32Array().setId(1), + }); + + const bytes = writerFory.register(writerType).serialize({ + values: [1, 2, 3], + }); + const result = readerFory.register(readerType).deserialize(bytes); + + expect(result.values).toBeInstanceOf(Int32Array); + expect(Array.from(result.values)).toEqual([1, 2, 3]); + }); + + test("adapts compatible list fields to reduced-precision dense array carriers", () => { + const writerFory = new Fory({ compatible: true }); + const readerFory = new Fory({ compatible: true }); + + const writerType = Type.struct(7214, { + bools: Type.list(Type.bool()).setId(1), + float16s: Type.list(Type.float16()).setId(2), + bfloat16s: Type.list(Type.bfloat16()).setId(3), + }); + const readerType = Type.struct(7214, { + bools: Type.boolArray().setId(1), + float16s: Type.float16Array().setId(2), + bfloat16s: Type.bfloat16Array().setId(3), + }); + + const bytes = writerFory.register(writerType).serialize({ + bools: [true, false], + float16s: [1.5, -2], + bfloat16s: [1.5, -2], + }); + const result = readerFory.register(readerType).deserialize(bytes); + + expect(result.bools).toBeInstanceOf(BoolArray); + expect(Array.from(result.bools)).toEqual([true, false]); + expect(result.float16s).toBeInstanceOf(Float16Array as 
any); + expect(Array.from(result.float16s as Iterable)[0]).toBeCloseTo(1.5, 1); + expect(Array.from(result.float16s as Iterable)[1]).toBeCloseTo(-2, 1); + expect(result.bfloat16s).toBeInstanceOf(BFloat16Array); + expect(Array.from(result.bfloat16s as Iterable)).toEqual([1.5, -2]); + }); + + test("adapts compatible dense array field to immediate list field", () => { + const writerFory = new Fory({ compatible: true }); + const readerFory = new Fory({ compatible: true }); + + const writerType = Type.struct(7213, { + values: Type.int32Array().setId(1), + }); + const readerType = Type.struct(7213, { + values: Type.list(Type.int32({ encoding: "fixed" })).setId(1), + }); + + const bytes = writerFory.register(writerType).serialize({ + values: new Int32Array([1, 2, 3]), + }); + const result = readerFory.register(readerType).deserialize(bytes); + + expect(Array.isArray(result.values)).toBe(true); + expect(result).toEqual({ values: [1, 2, 3] }); + }); + + test("rejects compatible list to dense array when payload has nullable elements", () => { + const writerFory = new Fory({ compatible: true }); + const readerFory = new Fory({ compatible: true }); + + const writerType = Type.struct(7212, { + values: Type.list( + Type.int32({ encoding: "fixed" }).setNullable(true), + ).setId(1), + }); + const readerType = Type.struct(7212, { + values: Type.int32Array().setId(1), + }); + + const serializer = writerFory.register(writerType); + const nonNullBytes = serializer.serialize({ + values: [1, 2, 3], + }); + const result = readerFory.register(readerType).deserialize(nonNullBytes); + expect(Array.from(result.values as Int32Array)).toEqual([1, 2, 3]); + + const nullableBytes = serializer.serialize({ + values: [1, null, 3], + }); + expect(() => readerFory.register(readerType).deserialize(nullableBytes)).toThrow(); + }); + + test("rejects incompatible immediate list and dense array element fields", () => { + const writerFory = new Fory({ compatible: true }); + const readerFory = new Fory({ 
compatible: true }); + + const writerType = Type.struct(7215, { + values: Type.list(Type.string()).setId(1), + }); + const readerType = Type.struct(7215, { + values: Type.int32Array().setId(1), + }); + + const bytes = writerFory.register(writerType).serialize({ + values: ["1", "2"], + }); + + expect(() => readerFory.register(readerType).deserialize(bytes)).toThrow(/list\/array/); + }); + + test("rejects nested compatible list and dense array positions", () => { + const writerFory = new Fory({ compatible: true }); + const readerFory = new Fory({ compatible: true }); + + const writerType = Type.struct(7216, { + values: Type.list(Type.int32Array()).setId(1), + }); + const readerType = Type.struct(7216, { + values: Type.list(Type.list(Type.int32({ encoding: "fixed" }))).setId(1), + }); + + const bytes = writerFory.register(writerType).serialize({ + values: [new Int32Array([1, 2])], + }); + + expect(() => readerFory.register(readerType).deserialize(bytes)).toThrow(/list\/array/); + }); + test("keeps compatible named schema evolution working when field count differs", () => { const writerFory = new Fory({ compatible: true }); const readerFory = new Fory({ compatible: true }); diff --git a/python/pyfory/meta/typedef.py b/python/pyfory/meta/typedef.py index e13d11e18d..97551acc4b 100644 --- a/python/pyfory/meta/typedef.py +++ b/python/pyfory/meta/typedef.py @@ -503,7 +503,7 @@ def create_serializer(self, resolver, type_): if declared_root_type: declared_root_type, *extra = declared_root_type elem_type = extra[0] if extra else None - elif type_ and len(type_) >= 2: + elif type_ and not isinstance(type_, ArrayMeta) and len(type_) >= 2: elem_type = type_[1] elem_serializer = self.element_type.create_serializer(resolver, elem_type) elem_override = getattr(self.element_type, "tracking_ref_override", None) @@ -613,42 +613,109 @@ def __repr__(self): ) ) +_ARRAY_ELEMENT_TYPE_IDS = { + TypeId.BOOL_ARRAY: TypeId.BOOL, + TypeId.INT8_ARRAY: TypeId.INT8, + TypeId.INT16_ARRAY: 
TypeId.INT16, + TypeId.INT32_ARRAY: TypeId.INT32, + TypeId.INT64_ARRAY: TypeId.INT64, + TypeId.UINT8_ARRAY: TypeId.UINT8, + TypeId.UINT16_ARRAY: TypeId.UINT16, + TypeId.UINT32_ARRAY: TypeId.UINT32, + TypeId.UINT64_ARRAY: TypeId.UINT64, + TypeId.FLOAT16_ARRAY: TypeId.FLOAT16, + TypeId.BFLOAT16_ARRAY: TypeId.BFLOAT16, + TypeId.FLOAT32_ARRAY: TypeId.FLOAT32, + TypeId.FLOAT64_ARRAY: TypeId.FLOAT64, +} + + +def _list_array_element_type_matches(list_field_type: FieldType, array_field_type: FieldType) -> bool: + array_element_type_id = _ARRAY_ELEMENT_TYPE_IDS.get(array_field_type.type_id) + if array_element_type_id is None: + return False + return list_field_type.type_id == TypeId.LIST and _list_element_type_matches_array_element( + list_field_type.element_type.type_id, array_element_type_id + ) + + +def _list_element_type_matches_array_element(list_element_type_id: TypeId, array_element_type_id: TypeId) -> bool: + if list_element_type_id == array_element_type_id: + return True + list_domain = _INT_TYPE_DOMAINS.get(list_element_type_id) + return list_domain is not None and list_domain == _INT_TYPE_DOMAINS.get(array_element_type_id) + + +def _is_root_list_array_pair(remote_field_type: FieldType, local_field_type: FieldType) -> bool: + if local_field_type is None: + return False + if remote_field_type.type_id == TypeId.LIST and local_field_type.type_id in _ARRAY_TYPE_IDS: + return _list_array_element_type_matches(remote_field_type, local_field_type) + if local_field_type.type_id == TypeId.LIST and remote_field_type.type_id in _ARRAY_TYPE_IDS: + return _list_array_element_type_matches(local_field_type, remote_field_type) + return False + + +def _is_root_list_array_shape_pair(remote_field_type: FieldType, local_field_type: FieldType) -> bool: + if local_field_type is None: + return False + return ( + remote_field_type.type_id == TypeId.LIST + and local_field_type.type_id in _ARRAY_TYPE_IDS + or local_field_type.type_id == TypeId.LIST + and remote_field_type.type_id in 
_ARRAY_TYPE_IDS + ) + -def _payload_shape_matches(remote_field_type: FieldType, local_field_type: FieldType) -> bool: +def _remote_list_to_local_array_allowed(remote_field_type: FieldType, local_field_type: FieldType) -> bool: + return ( + remote_field_type.type_id == TypeId.LIST + and local_field_type.type_id in _ARRAY_TYPE_IDS + and _list_array_element_type_matches(remote_field_type, local_field_type) + ) + + +def _payload_shape_matches(remote_field_type: FieldType, local_field_type: FieldType, top_level: bool = True) -> bool: if local_field_type is None: return False remote_type_id = remote_field_type.type_id local_type_id = local_field_type.type_id if _is_bytes_uint8_array_pair(remote_type_id, local_type_id): return True + if top_level and _is_root_list_array_pair(remote_field_type, local_field_type): + return True if remote_type_id != local_type_id: return False if remote_type_id in (TypeId.LIST, TypeId.SET): - return _payload_shape_matches(remote_field_type.element_type, local_field_type.element_type) + return _payload_shape_matches(remote_field_type.element_type, local_field_type.element_type, False) if remote_type_id == TypeId.MAP: - return _payload_shape_matches(remote_field_type.key_type, local_field_type.key_type) and _payload_shape_matches( + return _payload_shape_matches(remote_field_type.key_type, local_field_type.key_type, False) and _payload_shape_matches( remote_field_type.value_type, local_field_type.value_type, + False, ) return True -def _payload_shape_needs_local_carrier(remote_field_type: FieldType, local_field_type: FieldType) -> bool: +def _payload_shape_needs_local_carrier(remote_field_type: FieldType, local_field_type: FieldType, top_level: bool = True) -> bool: remote_type_id = remote_field_type.type_id local_type_id = local_field_type.type_id if _is_bytes_uint8_array_pair(remote_type_id, local_type_id): return True + if top_level and _is_root_list_array_pair(remote_field_type, local_field_type): + return True if remote_type_id != 
local_type_id: return False if remote_type_id in _ARRAY_TYPE_IDS: return True if remote_type_id in (TypeId.LIST, TypeId.SET): - return _payload_shape_needs_local_carrier(remote_field_type.element_type, local_field_type.element_type) + return _payload_shape_needs_local_carrier(remote_field_type.element_type, local_field_type.element_type, False) if remote_type_id == TypeId.MAP: return _payload_shape_needs_local_carrier( remote_field_type.key_type, local_field_type.key_type, - ) or _payload_shape_needs_local_carrier(remote_field_type.value_type, local_field_type.value_type) + False, + ) or _payload_shape_needs_local_carrier(remote_field_type.value_type, local_field_type.value_type, False) return False @@ -668,6 +735,52 @@ def _create_compatible_field_serializer( local_field_type: typing.Optional[FieldType], local_declared_type, ): + if _is_root_list_array_shape_pair(remote_field_type, local_field_type) and not _is_root_list_array_pair( + remote_field_type, + local_field_type, + ): + from pyfory.error import TypeNotCompatibleError + + raise TypeNotCompatibleError(f"Field {field_name!r} has unsupported list/array schema mismatch") + + if _is_root_list_array_pair(remote_field_type, local_field_type): + from pyfory.serializer import ( + CompatibleArrayToListFieldSerializer, + CompatibleListToArrayFieldSerializer, + ForyArrayFieldSerializer, + fory_array_wrapper_type, + ) + + if remote_field_type.type_id == TypeId.LIST: + if not _remote_list_to_local_array_allowed(remote_field_type, local_field_type): + return remote_field_type.create_serializer(resolver, local_declared_type) + target_serializer = _create_local_typehint_serializer(resolver, field_name, type_hint) + if target_serializer is None: + wrapper_type = fory_array_wrapper_type(local_field_type.type_id) + target_serializer = ForyArrayFieldSerializer( + resolver, + wrapper_type, + local_field_type.type_id, + field_name, + ) + elem_serializer = remote_field_type.element_type.create_serializer(resolver, None) + return 
CompatibleListToArrayFieldSerializer( + resolver, + target_serializer, + elem_serializer, + field_name, + ) + + remote_wrapper_type = fory_array_wrapper_type(remote_field_type.type_id) + remote_serializer = ForyArrayFieldSerializer( + resolver, + remote_wrapper_type, + remote_field_type.type_id, + field_name, + ) + elem_serializer = local_field_type.element_type.create_serializer(resolver, None) + return CompatibleArrayToListFieldSerializer(resolver, remote_serializer, elem_serializer) + if _payload_shape_matches(remote_field_type, local_field_type) and _payload_shape_needs_local_carrier(remote_field_type, local_field_type): serializer = _create_local_typehint_serializer(resolver, field_name, type_hint) if serializer is not None: @@ -701,7 +814,7 @@ def _is_bytes_uint8_array_pair(remote_type_id: int, local_type_id: int) -> bool: ) -def _field_type_assignment(remote_field_type: FieldType, local_field_type: FieldType) -> typing.Tuple[bool, bool]: +def _field_type_assignment(remote_field_type: FieldType, local_field_type: FieldType, top_level: bool = True) -> typing.Tuple[bool, bool]: if local_field_type is None: return False, False needs_validation = _requires_nullable_validation(remote_field_type, local_field_type) @@ -711,12 +824,15 @@ def _field_type_assignment(remote_field_type: FieldType, local_field_type: Field return True, needs_validation if remote_type_id == TypeId.UNKNOWN: return True, True + if top_level and _is_root_list_array_pair(remote_field_type, local_field_type): + return True, True if remote_type_id in (TypeId.LIST, TypeId.SET): if local_type_id != remote_type_id: return False, False child_assignable, child_needs_validation = _field_type_assignment( remote_field_type.element_type, local_field_type.element_type, + False, ) return child_assignable, needs_validation or child_needs_validation if remote_type_id == TypeId.MAP: @@ -725,10 +841,12 @@ def _field_type_assignment(remote_field_type: FieldType, local_field_type: Field key_assignable, 
key_needs_validation = _field_type_assignment( remote_field_type.key_type, local_field_type.key_type, + False, ) value_assignable, value_needs_validation = _field_type_assignment( remote_field_type.value_type, local_field_type.value_type, + False, ) return ( key_assignable and value_assignable, diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index 509f7a1a4c..2126d16c32 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -732,6 +732,87 @@ def read(self, buffer): return self.wrapper_serializer.read(buffer) +class CompatibleArrayToListFieldSerializer(Serializer): + def __init__(self, type_resolver, remote_array_serializer, elem_serializer): + super().__init__(type_resolver, list) + self.remote_array_serializer = remote_array_serializer + self.elem_serializer = elem_serializer + self.need_to_write_ref = False + + def write(self, buffer, value): + raise TypeError("compatible array-to-list field serializer is read-only") + + def read(self, read_context): + return list(self.remote_array_serializer.read(read_context)) + + +class CompatibleListToArrayFieldSerializer(Serializer): + def __init__(self, type_resolver, target_serializer, elem_serializer, field_name=None): + super().__init__(type_resolver, target_serializer.type_) + self.target_serializer = target_serializer + self.elem_serializer = elem_serializer + self.field_name = field_name or "" + self.need_to_write_ref = False + + def write(self, buffer, value): + raise TypeError("compatible list-to-array field serializer is read-only") + + def _empty_target(self): + if isinstance(self.target_serializer, ForyArrayFieldSerializer): + return self.target_serializer.wrapper_type() + if isinstance(self.target_serializer, PyArraySerializer): + return array.array(self.target_serializer.typecode) + if np is not None and isinstance(self.target_serializer, Numpy1DArraySerializer): + return np.empty(0, dtype=self.target_serializer.dtype) + raise TypeError(f"Field {self.field_name!r} 
has unsupported array target serializer {type(self.target_serializer)!r}") + + def _new_target(self, length): + if isinstance(self.target_serializer, ForyArrayFieldSerializer): + return self.target_serializer.wrapper_type() + if isinstance(self.target_serializer, PyArraySerializer): + return array.array(self.target_serializer.typecode) + if np is not None and isinstance(self.target_serializer, Numpy1DArraySerializer): + return np.empty(length, dtype=self.target_serializer.dtype) + raise TypeError(f"Field {self.field_name!r} has unsupported array target serializer {type(self.target_serializer)!r}") + + def read(self, read_context): + from pyfory.collection import ( + COLL_HAS_NULL, + COLL_IS_DECL_ELEMENT_TYPE, + COLL_IS_SAME_TYPE, + COLL_TRACKING_REF, + ) + from pyfory.error import TypeNotCompatibleError + + length = read_context.read_var_uint32() + if length > read_context.max_collection_size: + raise ValueError(f"Collection size {length} exceeds the configured limit of {read_context.max_collection_size}") + if length == 0: + return self._empty_target() + collect_flag = read_context.read_int8() + if (collect_flag & (COLL_HAS_NULL | COLL_TRACKING_REF)) != 0: + raise TypeNotCompatibleError( + f"Field {self.field_name!r} cannot read nullable or ref-tracked list elements as array", + ) + if (collect_flag & (COLL_IS_SAME_TYPE | COLL_IS_DECL_ELEMENT_TYPE)) != (COLL_IS_SAME_TYPE | COLL_IS_DECL_ELEMENT_TYPE): + raise TypeNotCompatibleError( + f"Field {self.field_name!r} requires declared same-type list elements for array compatible read", + ) + + target = self._new_target(length) + append = None if np is not None and isinstance(self.target_serializer, Numpy1DArraySerializer) else target.append + for index in range(length): + item = read_context.read_no_ref(serializer=self.elem_serializer) + try: + if append is None: + target[index] = item + else: + append(item) + except (TypeError, ValueError, OverflowError) as exc: + raise type(exc)(f"{self.field_name}[{index}] invalid 
for {type(target).__name__}: {exc}") from exc + return target + + class DynamicPyArraySerializer(Serializer): """Serializer for dynamic Python arrays that handles any typecode.""" diff --git a/python/pyfory/tests/test_typedef_encoding.py b/python/pyfory/tests/test_typedef_encoding.py index 102e75ed20..ed93d3fead 100644 --- a/python/pyfory/tests/test_typedef_encoding.py +++ b/python/pyfory/tests/test_typedef_encoding.py @@ -21,7 +21,7 @@ import array from dataclasses import dataclass, make_dataclass -from typing import List, Dict +from typing import List, Dict, Optional import pytest @@ -43,8 +43,10 @@ prepend_header, ) from pyfory.meta.typedef_decoder import decode_typedef +from pyfory.serializer import PyArraySerializer from pyfory.types import TypeId from pyfory import Fory +from pyfory.error import TypeNotCompatibleError try: import numpy as np @@ -102,6 +104,51 @@ class UInt8ArrayPayload: payload: pyfory.Array[pyfory.UInt8] +@dataclass +class Int32ListPayload: + payload: List[pyfory.FixedInt32] + + +@dataclass +class Int32VarintListPayload: + payload: List[pyfory.Int32] + + +@dataclass +class NullableInt32ListPayload: + payload: List[Optional[pyfory.FixedInt32]] + + +@dataclass +class StringListPayload: + payload: List[str] + + +@dataclass +class Int32ArrayPayload: + payload: pyfory.Array[pyfory.Int32] + + +@dataclass +class Int32NDArrayPayload: + payload: pyfory.NDArray[pyfory.Int32] + + +@dataclass +class Int32PyArrayPayload: + payload: pyfory.PyArray[pyfory.Int32] + + +@dataclass +class NestedInt32ListPayload: + payload: List[List[pyfory.FixedInt32]] + + +@dataclass +class NestedInt32ArrayPayload: + payload: List[pyfory.Array[pyfory.Int32]] + + def test_collection_field_type(): """Test collection field type creation and serialization.""" element_type = FieldType(TypeId.INT32, True, True, False) @@ -367,6 +414,174 @@ def test_compatible_uint8_array_assigns_to_bytes(): assert decoded.payload == b"\x01\x02\xff" +def _register_int32_payload(fory, cls): + 
fory.register(cls, namespace="example", typename="Int32Sequence") + + +def _pyarray_int32_value(values): + for typecode, (_itemsize, _ftype, type_id) in PyArraySerializer.typecode_dict.items(): + if type_id == TypeId.INT32_ARRAY: + return array.array(typecode, values) + raise AssertionError("No array.array typecode maps to INT32_ARRAY") + + +def test_compatible_int32_list_assigns_to_array(): + writer = Fory(xlang=True, compatible=True) + reader = Fory(xlang=True, compatible=True) + _register_int32_payload(writer, Int32ListPayload) + _register_int32_payload(reader, Int32ArrayPayload) + + decoded = reader.deserialize(writer.serialize(Int32ListPayload(payload=[1, 2, 3]))) + + assert isinstance(decoded, Int32ArrayPayload) + assert isinstance(decoded.payload, pyfory.Int32Array) + assert list(decoded.payload) == [1, 2, 3] + + +@pytest.mark.skipif(np is None, reason="Requires numpy") +def test_compatible_int32_list_assigns_to_ndarray(): + writer = Fory(xlang=True, compatible=True) + reader = Fory(xlang=True, compatible=True) + _register_int32_payload(writer, Int32ListPayload) + _register_int32_payload(reader, Int32NDArrayPayload) + + decoded = reader.deserialize(writer.serialize(Int32ListPayload(payload=[1, 2, 3]))) + + assert isinstance(decoded, Int32NDArrayPayload) + assert isinstance(decoded.payload, np.ndarray) + assert decoded.payload.dtype == np.dtype(np.int32) + np.testing.assert_array_equal(decoded.payload, np.array([1, 2, 3], dtype=np.int32)) + + +@pytest.mark.skipif(np is None, reason="Requires numpy") +def test_compatible_empty_int32_list_assigns_to_ndarray(): + writer = Fory(xlang=True, compatible=True) + reader = Fory(xlang=True, compatible=True) + _register_int32_payload(writer, Int32ListPayload) + _register_int32_payload(reader, Int32NDArrayPayload) + + decoded = reader.deserialize(writer.serialize(Int32ListPayload(payload=[]))) + + assert isinstance(decoded, Int32NDArrayPayload) + assert isinstance(decoded.payload, np.ndarray) + assert 
decoded.payload.dtype == np.dtype(np.int32) + assert decoded.payload.size == 0 + + +def test_compatible_int32_list_assigns_to_pyarray(): + writer = Fory(xlang=True, compatible=True) + reader = Fory(xlang=True, compatible=True) + _register_int32_payload(writer, Int32ListPayload) + _register_int32_payload(reader, Int32PyArrayPayload) + + decoded = reader.deserialize(writer.serialize(Int32ListPayload(payload=[1, 2, 3]))) + + assert isinstance(decoded, Int32PyArrayPayload) + assert isinstance(decoded.payload, array.array) + assert PyArraySerializer.typecode_dict[decoded.payload.typecode][2] == TypeId.INT32_ARRAY + assert decoded.payload.tolist() == [1, 2, 3] + + +def test_compatible_empty_int32_list_assigns_to_pyarray(): + writer = Fory(xlang=True, compatible=True) + reader = Fory(xlang=True, compatible=True) + _register_int32_payload(writer, Int32ListPayload) + _register_int32_payload(reader, Int32PyArrayPayload) + + decoded = reader.deserialize(writer.serialize(Int32ListPayload(payload=[]))) + + assert isinstance(decoded, Int32PyArrayPayload) + assert isinstance(decoded.payload, array.array) + assert PyArraySerializer.typecode_dict[decoded.payload.typecode][2] == TypeId.INT32_ARRAY + assert decoded.payload.tolist() == [] + + +def test_compatible_varint_int32_list_assigns_to_array(): + writer = Fory(xlang=True, compatible=True) + reader = Fory(xlang=True, compatible=True) + _register_int32_payload(writer, Int32VarintListPayload) + _register_int32_payload(reader, Int32ArrayPayload) + + decoded = reader.deserialize(writer.serialize(Int32VarintListPayload(payload=[-1, 2, 3]))) + + assert isinstance(decoded, Int32ArrayPayload) + assert isinstance(decoded.payload, pyfory.Int32Array) + assert list(decoded.payload) == [-1, 2, 3] + + +def test_compatible_int32_array_assigns_to_list(): + writer = Fory(xlang=True, compatible=True) + reader = Fory(xlang=True, compatible=True) + _register_int32_payload(writer, Int32ArrayPayload) + _register_int32_payload(reader, Int32ListPayload) 
+ + decoded = reader.deserialize(writer.serialize(Int32ArrayPayload(payload=pyfory.Int32Array([1, 2, 3])))) + + assert isinstance(decoded, Int32ListPayload) + assert decoded.payload == [1, 2, 3] + + +@pytest.mark.skipif(np is None, reason="Requires numpy") +def test_compatible_int32_ndarray_assigns_to_list(): + writer = Fory(xlang=True, compatible=True) + reader = Fory(xlang=True, compatible=True) + _register_int32_payload(writer, Int32NDArrayPayload) + _register_int32_payload(reader, Int32ListPayload) + + decoded = reader.deserialize(writer.serialize(Int32NDArrayPayload(payload=np.array([1, 2, 3], dtype=np.int32)))) + + assert isinstance(decoded, Int32ListPayload) + assert decoded.payload == [1, 2, 3] + + +def test_compatible_int32_pyarray_assigns_to_list(): + writer = Fory(xlang=True, compatible=True) + reader = Fory(xlang=True, compatible=True) + _register_int32_payload(writer, Int32PyArrayPayload) + _register_int32_payload(reader, Int32ListPayload) + + decoded = reader.deserialize(writer.serialize(Int32PyArrayPayload(payload=_pyarray_int32_value([1, 2, 3])))) + + assert isinstance(decoded, Int32ListPayload) + assert decoded.payload == [1, 2, 3] + + +def test_compatible_nullable_int32_list_payload_rejects_array_read(): + writer = Fory(xlang=True, compatible=True) + reader = Fory(xlang=True, compatible=True) + _register_int32_payload(writer, NullableInt32ListPayload) + _register_int32_payload(reader, Int32ArrayPayload) + + decoded = reader.deserialize(writer.serialize(NullableInt32ListPayload(payload=[1, 2, 3]))) + assert isinstance(decoded, Int32ArrayPayload) + assert list(decoded.payload) == [1, 2, 3] + + with pytest.raises(TypeNotCompatibleError): + reader.deserialize(writer.serialize(NullableInt32ListPayload(payload=[1, None, 3]))) + + +def test_compatible_incompatible_list_array_elements_reject(): + writer = Fory(xlang=True, compatible=True) + reader = Fory(xlang=True, compatible=True) + _register_int32_payload(writer, StringListPayload) + 
_register_int32_payload(reader, Int32ArrayPayload) + + with pytest.raises(TypeNotCompatibleError): + reader.deserialize(writer.serialize(StringListPayload(payload=["1", "2"]))) + + +def test_compatible_nested_list_array_mismatch_not_assigned(): + writer = Fory(xlang=True, compatible=True) + reader = Fory(xlang=True, compatible=True) + _register_int32_payload(writer, NestedInt32ListPayload) + _register_int32_payload(reader, NestedInt32ArrayPayload) + + decoded = reader.deserialize(writer.serialize(NestedInt32ListPayload(payload=[[1, 2], [3]]))) + + assert isinstance(decoded, NestedInt32ArrayPayload) + assert decoded.payload == [] + + if __name__ == "__main__": test_collection_field_type() test_map_field_type() diff --git a/python/pyfory/tests/xlang_test_main.py b/python/pyfory/tests/xlang_test_main.py index 53bd64e50d..d9c189e3ec 100644 --- a/python/pyfory/tests/xlang_test_main.py +++ b/python/pyfory/tests/xlang_test_main.py @@ -28,13 +28,20 @@ import math import os import decimal +import array from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set import pyfory from pyfory import Ref +from pyfory.error import TypeNotCompatibleError from pyfory.meta.meta_compressor import NoOpMetaCompressor +try: + import numpy as np +except ImportError: + np = None + def debug_print(*params): """Print params if debug is needed.""" @@ -71,6 +78,12 @@ def decimal_values() -> List[decimal.Decimal]: ] +def empty_int32_ndarray(): + if np is None: + raise RuntimeError("numpy is required for pyfory.NDArray fields") + return np.empty(0, dtype=np.int32) + + # ============================================================================ # Test Data Classes - Must match XlangTestBase.java definitions # ============================================================================ @@ -196,6 +209,31 @@ class ReducedPrecisionFloatStruct: bfloat16_array: List[pyfory.BFloat16] = None +@dataclass +class CompatibleInt32ListField: + values: List[pyfory.FixedInt32] = 
pyfory.field(1, default_factory=list) + + +@dataclass +class CompatibleNullableInt32ListField: + values: List[Optional[pyfory.FixedInt32]] = pyfory.field(1, default_factory=list) + + +@dataclass +class CompatibleInt32ArrayField: + values: pyfory.Array[pyfory.Int32] = pyfory.field(1, default_factory=pyfory.Int32Array) + + +@dataclass +class CompatibleInt32NDArrayField: + values: pyfory.NDArray[pyfory.Int32] = pyfory.field(1, default_factory=empty_int32_ndarray) + + +@dataclass +class CompatibleInt32PyArrayField: + values: pyfory.PyArray[pyfory.Int32] = pyfory.field(1, default_factory=lambda: array.array("i")) + + class TestEnum(enum.Enum): VALUE_A = 0 VALUE_B = 1 @@ -881,6 +919,53 @@ def test_reduced_precision_float_struct_compatible_skip(): f.write(new_bytes) +def _round_trip_compatible_list_array_field(local_type): + data_file = get_data_file() + with open(data_file, "rb") as f: + data_bytes = f.read() + + fory = pyfory.Fory(xlang=True, compatible=True) + fory.register_type(local_type, type_id=901) + obj = fory.deserialize(data_bytes) + new_bytes = fory.serialize(obj) + with open(data_file, "wb") as f: + f.write(new_bytes) + + +def test_list_array_compatible_list_to_array(): + _round_trip_compatible_list_array_field(CompatibleInt32ArrayField) + + +def test_list_array_compatible_list_to_ndarray(): + _round_trip_compatible_list_array_field(CompatibleInt32NDArrayField) + + +def test_list_array_compatible_list_to_pyarray(): + _round_trip_compatible_list_array_field(CompatibleInt32PyArrayField) + + +def test_list_array_compatible_array_to_list(): + _round_trip_compatible_list_array_field(CompatibleInt32ListField) + + +def test_list_array_compatible_nullable_list_to_array_error(): + data_file = get_data_file() + with open(data_file, "rb") as f: + data_bytes = f.read() + + fory = pyfory.Fory(xlang=True, compatible=True) + fory.register_type(CompatibleInt32ArrayField, type_id=901) + try: + fory.deserialize(data_bytes) + except TypeNotCompatibleError: + pass + else: + 
raise AssertionError("Expected nullable list payload to fail compatible array read") + new_bytes = data_bytes + with open(data_file, "wb") as f: + f.write(new_bytes) + + def test_one_enum_field_schema(): """Test one enum field struct with schema consistent mode.""" data_file = get_data_file() diff --git a/rust/fory-core/src/serializer/codec.rs b/rust/fory-core/src/serializer/codec.rs index bcc343694d..5529525c46 100644 --- a/rust/fory-core/src/serializer/codec.rs +++ b/rust/fory-core/src/serializer/codec.rs @@ -21,6 +21,10 @@ //! Fory-owned building blocks that allow generated code to apply field-local and //! nested collection configuration without creating wrapper value types. +use super::collection::{ + read_primitive_array_vec_compatible_mismatch, read_vec_compatible_mismatch, + CompatibleListArrayElement, +}; use crate::context::{ReadContext, WriteContext}; use crate::error::Error; use crate::meta::FieldType; @@ -34,10 +38,10 @@ use std::marker::PhantomData; use std::rc::Rc; use std::sync::Arc; -const TRACKING_REF: u8 = 0b1; -const HAS_NULL: u8 = 0b10; -const DECL_ELEMENT_TYPE: u8 = 0b100; -const IS_SAME_TYPE: u8 = 0b1000; +pub(super) const TRACKING_REF: u8 = 0b1; +pub(super) const HAS_NULL: u8 = 0b10; +pub(super) const DECL_ELEMENT_TYPE: u8 = 0b100; +pub(super) const IS_SAME_TYPE: u8 = 0b1000; const TRACKING_KEY_REF: u8 = 0b1; const KEY_NULL: u8 = 0b10; @@ -169,7 +173,7 @@ where } #[inline(always)] -fn same_numeric_family(local: u32, remote: u32) -> bool { +pub(super) fn same_numeric_family(local: u32, remote: u32) -> bool { matches!( (local, remote), (type_id::INT32, type_id::INT32 | type_id::VARINT32) @@ -204,7 +208,7 @@ fn same_numeric_family(local: u32, remote: u32) -> bool { } #[inline(always)] -fn collection_type_with_fallback_generics(type_id: u32) -> bool { +pub(super) fn collection_type_with_fallback_generics(type_id: u32) -> bool { type_id == type_id::LIST || type_id == type_id::SET || type_id == type_id::MAP } @@ -223,7 +227,7 @@ pub fn 
field_types_compatible(local: &FieldType, remote: &FieldType) -> bool { } #[inline(always)] -fn generic_field_type<'a>( +pub(super) fn generic_field_type<'a>( field_type: &'a FieldType, index: usize, owner: &str, @@ -275,6 +279,7 @@ pub trait Codec: 'static { fn read_field(context: &mut ReadContext) -> Result; + #[inline(always)] fn read_compatible( context: &mut ReadContext, local_field_type: &FieldType, @@ -1279,6 +1284,24 @@ where Self::read_data(context) } + #[inline(always)] + fn read_compatible( + context: &mut ReadContext, + local_field_type: &FieldType, + remote_field_type: &FieldType, + ) -> Result>, Error> { + if local_field_type.compatible_fingerprint() == remote_field_type.compatible_fingerprint() { + return Self::read_field_with_type(context, remote_field_type).map(Some); + } + if local_field_type.type_id == remote_field_type.type_id + && collection_type_with_fallback_generics(local_field_type.type_id) + && (local_field_type.generics.is_empty() || remote_field_type.generics.is_empty()) + { + return Self::read_field_with_type(context, remote_field_type).map(Some); + } + read_vec_compatible_mismatch::(context, local_field_type, remote_field_type) + } + fn write_data(value: &Vec, context: &mut WriteContext) -> Result<(), Error> { let len = value.len(); context.writer.write_var_u32(len as u32); @@ -1397,10 +1420,12 @@ where "Type inconsistent, target collection element type is not polymorphic", )); } + let owned_element_type; let element_type = if is_declared { - generic_field_type(remote_field_type, 0, "list")?.clone() + generic_field_type(remote_field_type, 0, "list")? } else { - C::read_type_info_as_field_type(context)? 
+ owned_element_type = C::read_type_info_as_field_type(context)?; + &owned_element_type }; let mut vec = Vec::with_capacity(len as usize); if has_null { @@ -1409,12 +1434,12 @@ where if flag == RefFlag::Null as i8 { vec.push(C::default_value()); } else { - vec.push(C::read_data_with_type(context, &element_type)?); + vec.push(C::read_data_with_type(context, element_type)?); } } } else { for _ in 0..len { - vec.push(C::read_data_with_type(context, &element_type)?); + vec.push(C::read_data_with_type(context, element_type)?); } } Ok(vec) @@ -1511,7 +1536,7 @@ pub struct PrimitiveArrayVecCodec Codec> for PrimitiveArrayVecCodec where - T: Serializer + ForyDefault, + T: CompatibleListArrayElement, { #[inline(always)] fn field_type(_: &TypeResolver) -> Result { @@ -1547,6 +1572,21 @@ where Self::read_data(context) } + fn read_compatible( + context: &mut ReadContext, + local_field_type: &FieldType, + remote_field_type: &FieldType, + ) -> Result>, Error> { + if local_field_type.compatible_fingerprint() == remote_field_type.compatible_fingerprint() { + return Self::read_field_with_type(context, remote_field_type).map(Some); + } + read_primitive_array_vec_compatible_mismatch::( + context, + local_field_type, + remote_field_type, + ) + } + #[inline(always)] fn write_data(value: &Vec, context: &mut WriteContext) -> Result<(), Error> { primitive_list::fory_write_data(value, context) diff --git a/rust/fory-core/src/serializer/collection.rs b/rust/fory-core/src/serializer/collection.rs index 1e9fa00119..421a224d6a 100644 --- a/rust/fory-core/src/serializer/collection.rs +++ b/rust/fory-core/src/serializer/collection.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. 
+use super::codec::{field_ref_mode, generic_field_type, same_numeric_family, Codec}; use crate::context::ReadContext; use crate::context::WriteContext; use crate::ensure; use crate::error::Error; +use crate::meta::FieldType; use crate::resolver::{RefFlag, RefMode}; use crate::serializer::{ForyDefault, Serializer}; -use crate::type_id::{need_to_write_type_for_field, PRIMITIVE_ARRAY_TYPES}; +use crate::type_id::{self, need_to_write_type_for_field, TypeId, PRIMITIVE_ARRAY_TYPES}; +use crate::types::{bfloat16::bfloat16, float16::float16}; const TRACKING_REF: u8 = 0b1; @@ -443,3 +446,398 @@ where .collect::>() } } + +fn primitive_array_element_type_id(type_id: u32) -> Option { + match type_id { + type_id::BOOL_ARRAY => Some(type_id::BOOL), + type_id::INT8_ARRAY => Some(type_id::INT8), + type_id::INT16_ARRAY => Some(type_id::INT16), + type_id::INT32_ARRAY => Some(type_id::INT32), + type_id::INT64_ARRAY => Some(type_id::INT64), + type_id::UINT8_ARRAY => Some(type_id::UINT8), + type_id::UINT16_ARRAY => Some(type_id::UINT16), + type_id::UINT32_ARRAY => Some(type_id::UINT32), + type_id::UINT64_ARRAY => Some(type_id::UINT64), + type_id::FLOAT16_ARRAY => Some(type_id::FLOAT16), + type_id::BFLOAT16_ARRAY => Some(type_id::BFLOAT16), + type_id::FLOAT32_ARRAY => Some(type_id::FLOAT32), + type_id::FLOAT64_ARRAY => Some(type_id::FLOAT64), + _ => None, + } +} + +fn primitive_array_element_size(type_id: u32) -> Option { + match type_id { + type_id::BOOL_ARRAY | type_id::INT8_ARRAY | type_id::UINT8_ARRAY => Some(1), + type_id::INT16_ARRAY + | type_id::UINT16_ARRAY + | type_id::FLOAT16_ARRAY + | type_id::BFLOAT16_ARRAY => Some(2), + type_id::INT32_ARRAY | type_id::UINT32_ARRAY | type_id::FLOAT32_ARRAY => Some(4), + type_id::INT64_ARRAY | type_id::UINT64_ARRAY | type_id::FLOAT64_ARRAY => Some(8), + _ => None, + } +} + +fn primitive_array_type_matches_rust_type(type_id: u32) -> bool { + let rust_type = std::any::TypeId::of::(); + match type_id { + type_id::BOOL_ARRAY => rust_type == 
std::any::TypeId::of::(), + type_id::INT8_ARRAY => rust_type == std::any::TypeId::of::(), + type_id::INT16_ARRAY => rust_type == std::any::TypeId::of::(), + type_id::INT32_ARRAY => rust_type == std::any::TypeId::of::(), + type_id::INT64_ARRAY => rust_type == std::any::TypeId::of::(), + type_id::UINT8_ARRAY => rust_type == std::any::TypeId::of::(), + type_id::UINT16_ARRAY => rust_type == std::any::TypeId::of::(), + type_id::UINT32_ARRAY => rust_type == std::any::TypeId::of::(), + type_id::UINT64_ARRAY => rust_type == std::any::TypeId::of::(), + type_id::FLOAT16_ARRAY => rust_type == std::any::TypeId::of::(), + type_id::BFLOAT16_ARRAY => rust_type == std::any::TypeId::of::(), + type_id::FLOAT32_ARRAY => rust_type == std::any::TypeId::of::(), + type_id::FLOAT64_ARRAY => rust_type == std::any::TypeId::of::(), + _ => false, + } +} + +fn read_primitive_array_data_bulk( + context: &mut ReadContext, + type_id: u32, + size_bytes: usize, + len: usize, +) -> Option, Error>> { + if !primitive_array_type_matches_rust_type::(type_id) { + return None; + } + #[cfg(target_endian = "little")] + { + let mut vec: Vec = Vec::with_capacity(len); + let src = match context.reader.read_bytes(size_bytes) { + Ok(src) => src, + Err(error) => return Some(Err(error)), + }; + unsafe { + std::ptr::copy_nonoverlapping(src.as_ptr(), vec.as_mut_ptr() as *mut u8, size_bytes); + vec.set_len(len); + } + Some(Ok(vec)) + } + #[cfg(target_endian = "big")] + { + let _ = (context, size_bytes, len); + None + } +} + +fn list_element_type_matches_array(list: &FieldType, array: &FieldType) -> bool { + primitive_array_element_type_id(array.type_id).is_some_and(|element_type_id| { + list.type_id == type_id::LIST + && list.generics.len() == 1 + && primitive_array_element_type_matches(element_type_id, list.generics[0].type_id) + }) +} + +fn primitive_array_element_type_matches( + array_element_type_id: u32, + list_element_type_id: u32, +) -> bool { + array_element_type_id == list_element_type_id + || 
same_numeric_family(array_element_type_id, list_element_type_id) +} + +fn read_primitive_array_data_with_codec( + context: &mut ReadContext, + remote_field_type: &FieldType, +) -> Result, Error> +where + T: 'static, + C: Codec, +{ + let size_bytes = context.reader.read_var_u32()? as usize; + let elem_size = primitive_array_element_size(remote_field_type.type_id) + .ok_or_else(|| Error::type_error("array-compatible field is not a primitive array"))?; + if size_bytes % elem_size != 0 { + return Err(Error::invalid_data("Invalid data length")); + } + let max = context.max_binary_size() as usize; + if size_bytes > max { + return Err(Error::size_limit_exceeded(format!( + "Binary size {} exceeds limit {}", + size_bytes, max + ))); + } + let remaining = context.reader.slice_after_cursor().len(); + if size_bytes > remaining { + let cursor = context.reader.get_cursor(); + return Err(Error::buffer_out_of_bound( + cursor, + size_bytes, + cursor + remaining, + )); + } + let len = size_bytes / elem_size; + let element_type_id = primitive_array_element_type_id(remote_field_type.type_id) + .ok_or_else(|| Error::type_error("array-compatible field is not a primitive array"))?; + if let Some(result) = + read_primitive_array_data_bulk::(context, remote_field_type.type_id, size_bytes, len) + { + return result; + } + let element_type = FieldType::new(element_type_id, false, Vec::new()); + let mut vec = Vec::with_capacity(len); + for _ in 0..len { + vec.push(C::read_data_with_type(context, &element_type)?); + } + Ok(vec) +} + +pub(super) trait CompatibleListArrayElement: Serializer + ForyDefault { + fn read_list_array_element( + context: &mut ReadContext, + remote_type_id: u32, + ) -> Result; +} + +macro_rules! 
compatible_exact_element { + ($ty:ty, $type_id:expr, $reader:ident) => { + impl CompatibleListArrayElement for $ty { + #[inline(always)] + fn read_list_array_element( + context: &mut ReadContext, + remote_type_id: u32, + ) -> Result { + if remote_type_id == $type_id { + context.reader.$reader() + } else { + Err(Error::type_mismatch( + <$ty as Serializer>::fory_static_type_id() as u32, + remote_type_id, + )) + } + } + } + }; +} + +macro_rules! compatible_integer_element { + ($ty:ty, $fixed_type:expr, $var_type:expr, $fixed_reader:ident, $var_reader:ident) => { + impl CompatibleListArrayElement for $ty { + #[inline(always)] + fn read_list_array_element( + context: &mut ReadContext, + remote_type_id: u32, + ) -> Result { + match remote_type_id { + x if x == $fixed_type => context.reader.$fixed_reader(), + x if x == $var_type => context.reader.$var_reader(), + _ => Err(Error::type_mismatch( + <$ty as Serializer>::fory_static_type_id() as u32, + remote_type_id, + )), + } + } + } + }; +} + +macro_rules! compatible_tagged_integer_element { + ( + $ty:ty, + $fixed_type:expr, + $var_type:expr, + $tagged_type:expr, + $fixed_reader:ident, + $var_reader:ident, + $tagged_reader:ident + ) => { + impl CompatibleListArrayElement for $ty { + #[inline(always)] + fn read_list_array_element( + context: &mut ReadContext, + remote_type_id: u32, + ) -> Result { + match remote_type_id { + x if x == $fixed_type => context.reader.$fixed_reader(), + x if x == $var_type => context.reader.$var_reader(), + x if x == $tagged_type => context.reader.$tagged_reader(), + _ => Err(Error::type_mismatch( + <$ty as Serializer>::fory_static_type_id() as u32, + remote_type_id, + )), + } + } + } + }; +} + +impl CompatibleListArrayElement for bool { + #[inline(always)] + fn read_list_array_element( + context: &mut ReadContext, + remote_type_id: u32, + ) -> Result { + if remote_type_id == type_id::BOOL { + Ok(context.reader.read_u8()? 
== 1) + } else { + Err(Error::type_mismatch(type_id::BOOL, remote_type_id)) + } + } +} + +compatible_exact_element!(i8, type_id::INT8, read_i8); +compatible_exact_element!(i16, type_id::INT16, read_i16); +compatible_integer_element!( + i32, + type_id::INT32, + type_id::VARINT32, + read_i32, + read_var_i32 +); +compatible_tagged_integer_element!( + i64, + type_id::INT64, + type_id::VARINT64, + type_id::TAGGED_INT64, + read_i64, + read_var_i64, + read_tagged_i64 +); +compatible_exact_element!(u8, type_id::UINT8, read_u8); +compatible_exact_element!(u16, type_id::UINT16, read_u16); +compatible_integer_element!( + u32, + type_id::UINT32, + type_id::VAR_UINT32, + read_u32, + read_var_u32 +); +compatible_tagged_integer_element!( + u64, + type_id::UINT64, + type_id::VAR_UINT64, + type_id::TAGGED_UINT64, + read_u64, + read_var_u64, + read_tagged_u64 +); +compatible_exact_element!(float16, type_id::FLOAT16, read_f16); +compatible_exact_element!(bfloat16, type_id::BFLOAT16, read_bf16); +compatible_exact_element!(f32, type_id::FLOAT32, read_f32); +compatible_exact_element!(f64, type_id::FLOAT64, read_f64); +compatible_exact_element!(i128, type_id::INT128, read_i128); +compatible_exact_element!(u128, TypeId::U128 as u32, read_u128); +compatible_exact_element!(isize, type_id::ISIZE, read_isize); +compatible_exact_element!(usize, type_id::USIZE, read_usize); + +fn read_non_nullable_list_data_with_type( + context: &mut ReadContext, + remote_field_type: &FieldType, +) -> Result, Error> +where + T: CompatibleListArrayElement, +{ + let element_type = generic_field_type(remote_field_type, 0, "list")?; + let len = context.reader.read_var_u32()?; + if len == 0 { + return Ok(Vec::new()); + } + let max = context.max_collection_size(); + if len > max { + return Err(Error::size_limit_exceeded(format!( + "Collection size {} exceeds limit {}", + len, max + ))); + } + let header = context.reader.read_u8()?; + if (header & HAS_NULL) != 0 { + return Err(Error::type_error( + "compatible list to 
array field requires non-null elements", + )); + } + if (header & TRACKING_REF) != 0 { + return Err(Error::type_error( + "array-compatible list declares reference-tracked elements", + )); + } + if (header & IS_SAME_TYPE) == 0 { + return Err(Error::type_error( + "array-compatible list must declare same-type elements", + )); + } + if (header & DECL_ELEMENT_TYPE) == 0 { + return Err(Error::type_error( + "array-compatible list must declare element type", + )); + } + let mut vec = Vec::with_capacity(len as usize); + for _ in 0..len { + vec.push(T::read_list_array_element(context, element_type.type_id)?); + } + Ok(vec) +} + +#[cold] +#[inline(never)] +pub(super) fn read_vec_compatible_mismatch( + context: &mut ReadContext, + local_field_type: &FieldType, + remote_field_type: &FieldType, +) -> Result>, Error> +where + T: 'static, + C: Codec, +{ + if local_field_type.type_id == type_id::LIST + && list_element_type_matches_array(local_field_type, remote_field_type) + { + return read_array_data_as_vec_bridge::(context, remote_field_type).map(Some); + } + Ok(None) +} + +fn read_array_data_as_vec_bridge( + context: &mut ReadContext, + remote_field_type: &FieldType, +) -> Result, Error> +where + T: 'static, + C: Codec, +{ + if field_ref_mode(remote_field_type) != RefMode::None { + let ref_flag = context.reader.read_i8()?; + if ref_flag == RefFlag::Null as i8 { + return Ok(Vec::new()); + } + } + if crate::serializer::util::field_need_read_type_info(remote_field_type.type_id) { + let remote = context.reader.read_u8()? 
as u32; + if remote != remote_field_type.type_id { + return Err(Error::type_mismatch(remote_field_type.type_id, remote)); + } + } + read_primitive_array_data_with_codec::(context, remote_field_type) +} + +#[cold] +#[inline(never)] +pub(super) fn read_primitive_array_vec_compatible_mismatch( + context: &mut ReadContext, + local_field_type: &FieldType, + remote_field_type: &FieldType, +) -> Result>, Error> +where + T: CompatibleListArrayElement, +{ + if remote_field_type.type_id == type_id::LIST + && !remote_field_type.generics.is_empty() + && list_element_type_matches_array(remote_field_type, local_field_type) + { + if field_ref_mode(remote_field_type) != RefMode::None { + let ref_flag = context.reader.read_i8()?; + if ref_flag == RefFlag::Null as i8 { + return Ok(Some(Vec::new())); + } + } + return read_non_nullable_list_data_with_type::(context, remote_field_type).map(Some); + } + Ok(None) +} diff --git a/rust/fory-derive/src/object/field_codec.rs b/rust/fory-derive/src/object/field_codec.rs index 2e209581ec..8241c142d2 100644 --- a/rust/fory-derive/src/object/field_codec.rs +++ b/rust/fory-derive/src/object/field_codec.rs @@ -379,12 +379,7 @@ fn field_dispatch_for( ) -> syn::Result { if meta.array { let codec_ty = codec_type_for(ty, meta, nullable, track_ref)?; - return Ok(FieldDispatch::Serializer { - field_type: quote! { - <#codec_ty as fory_core::serializer::codec::Codec<#ty>>::field_type(type_resolver)? 
- }, - has_generics: true, - }); + return Ok(FieldDispatch::Codec { codec_ty }); } if meta.encoding.is_none() && meta.list.is_none() diff --git a/rust/tests/tests/compatible/test_struct.rs b/rust/tests/tests/compatible/test_struct.rs index cd18c947b5..1b5e94dec4 100644 --- a/rust/tests/tests/compatible/test_struct.rs +++ b/rust/tests/tests/compatible/test_struct.rs @@ -69,6 +69,99 @@ fn simple() { assert_eq!(animal.last, obj.last); } +#[test] +fn compatible_list_array_field_pairs() { + #[derive(ForyStruct, Debug)] + struct ListPayload { + payload: Vec, + } + + #[derive(ForyStruct, Debug)] + struct NullableListPayload { + #[fory(list(element(nullable = true)))] + payload: Vec>, + } + + #[derive(ForyStruct, Debug)] + struct ArrayPayload { + #[fory(array)] + payload: Vec, + } + + #[derive(ForyStruct, Debug)] + struct NestedListPayload { + payload: Vec>, + } + + #[derive(ForyStruct, Debug)] + struct NestedArrayPayload { + #[fory(list(element(array)))] + payload: Vec>, + } + + let mut writer = Fory::builder().compatible(true).build(); + let mut reader = Fory::builder().compatible(true).build(); + writer.register::(991).unwrap(); + reader.register::(991).unwrap(); + let bytes = writer + .serialize(&ListPayload { + payload: vec![1, 2, 3], + }) + .unwrap(); + let decoded: ArrayPayload = reader.deserialize(&bytes).unwrap(); + assert_eq!(decoded.payload, vec![1, 2, 3]); + + let mut writer = Fory::builder().compatible(true).build(); + let mut reader = Fory::builder().compatible(true).build(); + writer.register::(992).unwrap(); + reader.register::(992).unwrap(); + let bytes = writer + .serialize(&ArrayPayload { + payload: vec![1, 2, 3], + }) + .unwrap(); + let decoded: ListPayload = reader.deserialize(&bytes).unwrap(); + assert_eq!(decoded.payload, vec![1, 2, 3]); + + let mut writer = Fory::builder().compatible(true).build(); + let mut reader = Fory::builder().compatible(true).build(); + writer.register::(993).unwrap(); + reader.register::(993).unwrap(); + let bytes = writer + 
.serialize(&NullableListPayload { + payload: vec![Some(1), Some(2), Some(3)], + }) + .unwrap(); + let decoded: ArrayPayload = reader.deserialize(&bytes).unwrap(); + assert_eq!(decoded.payload, vec![1, 2, 3]); + + let bytes = writer + .serialize(&NullableListPayload { + payload: vec![Some(1), None, Some(3)], + }) + .unwrap(); + let err = reader + .deserialize::(&bytes) + .expect_err("expected nullable list payload to fail compatible array read"); + assert!( + err.to_string() + .contains("compatible list to array field requires non-null elements"), + "{err}" + ); + + let mut writer = Fory::builder().compatible(true).build(); + let mut reader = Fory::builder().compatible(true).build(); + writer.register::(994).unwrap(); + reader.register::(994).unwrap(); + let bytes = writer + .serialize(&NestedListPayload { + payload: vec![vec![1, 2], vec![3]], + }) + .unwrap(); + let decoded: NestedArrayPayload = reader.deserialize(&bytes).unwrap(); + assert_eq!(decoded.payload, Vec::>::default()); +} + #[test] fn skip_option() { #[derive(ForyStruct, Debug)] diff --git a/rust/tests/tests/test_cross_language.rs b/rust/tests/tests/test_cross_language.rs index 1008f9229a..8660ad493e 100644 --- a/rust/tests/tests/test_cross_language.rs +++ b/rust/tests/tests/test_cross_language.rs @@ -902,6 +902,18 @@ struct ReducedPrecisionFloatStruct { bfloat16_array: Vec, } +#[derive(ForyStruct, Debug, PartialEq)] +struct CompatibleInt32ListField { + #[fory(id = 1, list(element(encoding = fixed)))] + values: Vec, +} + +#[derive(ForyStruct, Debug, PartialEq)] +struct CompatibleInt32ArrayField { + #[fory(id = 1, array)] + values: Vec, +} + #[allow(non_camel_case_types)] #[derive(ForyEnum, Debug, PartialEq, Default, Clone)] enum TestEnum { @@ -1408,6 +1420,50 @@ fn test_reduced_precision_float_struct_compatible_skip() { fs::write(&data_file_path, new_bytes).unwrap(); } +#[test] +#[ignore] +fn test_list_array_compatible_list_to_array() { + let data_file_path = get_data_file(); + let bytes = 
fs::read(&data_file_path).unwrap(); + + let mut fory = Fory::builder().compatible(true).xlang(true).build(); + fory.register::(901).unwrap(); + let value: CompatibleInt32ArrayField = fory.deserialize(&bytes).unwrap(); + let new_bytes = fory.serialize(&value).unwrap(); + fs::write(&data_file_path, new_bytes).unwrap(); +} + +#[test] +#[ignore] +fn test_list_array_compatible_array_to_list() { + let data_file_path = get_data_file(); + let bytes = fs::read(&data_file_path).unwrap(); + + let mut fory = Fory::builder().compatible(true).xlang(true).build(); + fory.register::(901).unwrap(); + let value: CompatibleInt32ListField = fory.deserialize(&bytes).unwrap(); + let new_bytes = fory.serialize(&value).unwrap(); + fs::write(&data_file_path, new_bytes).unwrap(); +} + +#[test] +#[ignore] +fn test_list_array_compatible_nullable_list_to_array_error() { + let data_file_path = get_data_file(); + let bytes = fs::read(&data_file_path).unwrap(); + + let mut fory = Fory::builder().compatible(true).xlang(true).build(); + fory.register::(901).unwrap(); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + fory.deserialize::(&bytes) + })); + assert!( + result.is_err() || result.is_ok_and(|value| value.is_err()), + "Expected nullable list payload to fail compatible array read" + ); + fs::write(&data_file_path, bytes).unwrap(); +} + // ============================================================================ // Schema Evolution Tests - Enum Fields // ============================================================================ diff --git a/swift/Sources/Fory/FieldCodecs.swift b/swift/Sources/Fory/FieldCodecs.swift index eb34c1cafe..9d48cd6e74 100644 --- a/swift/Sources/Fory/FieldCodecs.swift +++ b/swift/Sources/Fory/FieldCodecs.swift @@ -32,6 +32,11 @@ public protocol FieldCodec { static func writeStaticTypeInfo(_ context: WriteContext) throws static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? 
static func withTypeInfo(_ typeInfo: TypeInfo?, _ context: ReadContext, _ body: () throws -> R) rethrows -> R + static func readCompatibleField( + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Value } public extension FieldCodec { @@ -62,6 +67,18 @@ public extension FieldCodec { return try body() } + static func readCompatibleField( + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Value { + try read( + context, + refMode: refMode, + readTypeInfo: TypeId.needsTypeInfoForField(TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) + ) + } + static func write( _ value: Value, _ context: WriteContext, @@ -163,6 +180,21 @@ public extension FieldCodec { } } +private enum FieldCodecDefault { + static func readCompatibleField( + codec _: Codec.Type, + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Codec.Value { + try Codec.read( + context, + refMode: refMode, + readTypeInfo: TypeId.needsTypeInfoForField(TypeId(rawValue: remoteFieldType.typeID) ?? 
.unknown) + ) + } +} + public enum SerializerCodec: FieldCodec { public typealias Value = T @@ -622,6 +654,22 @@ public enum ListFieldCodec: FieldCodec { public static func readPayload(_ context: ReadContext) throws -> Value { return try readCollectionPayload(context, elementCodec: ElementCodec.self) } + + public static func readCompatibleField( + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Value { + if isCompatiblePackedArrayTypeID(remoteFieldType.typeID, elementCodec: ElementCodec.self) { + return try readCompatiblePackedArrayField(context, refMode: refMode, elementCodec: ElementCodec.self) + } + return try FieldCodecDefault.readCompatibleField( + codec: Self.self, + context, + remoteFieldType: remoteFieldType, + refMode: refMode + ) + } } public enum ArrayFieldCodec: FieldCodec { @@ -654,6 +702,30 @@ public enum ArrayFieldCodec: FieldCodec { throw ForyError.invalidData("unsupported array field element codec \(ElementCodec.self)") } + public static func readCompatibleField( + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Value { + if remoteFieldType.typeID == TypeId.list.rawValue, + let element = remoteFieldType.generics.first, + let localArrayTypeID = packedArrayTypeID(for: ElementCodec.self), + TypeId.listElementTypeID(element.typeID, matchesDenseArrayTypeID: localArrayTypeID.rawValue) { + return try readListPayloadAsArray( + context, + refMode: refMode, + elementCodec: ElementCodec.self, + remoteElementTypeID: element.typeID + ) + } + return try FieldCodecDefault.readCompatibleField( + codec: Self.self, + context, + remoteFieldType: remoteFieldType, + refMode: refMode + ) + } + public static func write( _ value: Value, _ context: WriteContext, @@ -1045,6 +1117,12 @@ private func uncheckedPackedArrayCast(_ array: [From], to _: To.Type) return unsafeBitCast(array, to: [To].self) } +@inline(__always) +private func uncheckedScalarCast(_ value: From, to _: To.Type) 
-> To { + assert(From.self == To.self) + return unsafeBitCast(value, to: To.self) +} + private func packedArrayTypeID(for _: ElementCodec.Type) -> TypeId? { if ElementCodec.isNullableType { return nil @@ -1091,6 +1169,13 @@ private func packedArrayTypeID(for _: ElementCodec.Typ return nil } +private func isCompatiblePackedArrayTypeID( + _ typeID: UInt32, + elementCodec _: ElementCodec.Type +) -> Bool { + TypeId.listElementTypeID(ElementCodec.typeId.rawValue, matchesDenseArrayTypeID: typeID) +} + private func writePackedArrayPayload( _ value: [ElementCodec.Value], _ context: WriteContext, @@ -1245,6 +1330,170 @@ private func readUIntArrayPayload(_ context: ReadContext) throws -> [UInt] { return values } +@inline(never) +private func readCompatiblePackedArrayField( + _ context: ReadContext, + refMode: RefMode, + elementCodec _: ElementCodec.Type +) throws -> [ElementCodec.Value] { + switch refMode { + case .none: + return try readCompatiblePackedArrayPayload(context, elementCodec: ElementCodec.self) + case .nullOnly, .tracking: + let rawFlag = try context.buffer.readInt8() + guard rawFlag != RefFlag.null.rawValue else { + return [] + } + if rawFlag == RefFlag.ref.rawValue { + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: [ElementCodec.Value].self) + } + let reservedRefID = (rawFlag == RefFlag.refValue.rawValue && context.trackRef) + ? 
context.refReader.reserveRefID() + : nil + let value = try readCompatiblePackedArrayPayload(context, elementCodec: ElementCodec.self) + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + return value + } +} + +private func readCompatiblePackedArrayPayload( + _ context: ReadContext, + elementCodec _: ElementCodec.Type +) throws -> [ElementCodec.Value] { + if ElementCodec.self == BoolCodec.self { + return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Bool], to: ElementCodec.Value.self) + } + if ElementCodec.self == Int8Codec.self { + return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int8], to: ElementCodec.Value.self) + } + if ElementCodec.self == Int16Codec.self { + return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int16], to: ElementCodec.Value.self) + } + if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { + return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int32], to: ElementCodec.Value.self) + } + if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self || ElementCodec.self == Int64TaggedCodec.self { + return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int64], to: ElementCodec.Value.self) + } + if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self || ElementCodec.self == IntTaggedCodec.self { + return uncheckedPackedArrayCast(try readIntArrayPayload(context), to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt8Codec.self { + return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt8], to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt16Codec.self { + return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt16], to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { + return uncheckedPackedArrayCast(try 
readPrimitiveArray(context) as [UInt32], to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self || ElementCodec.self == UInt64TaggedCodec.self { + return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt64], to: ElementCodec.Value.self) + } + if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self || ElementCodec.self == UIntTaggedCodec.self { + return uncheckedPackedArrayCast(try readUIntArrayPayload(context), to: ElementCodec.Value.self) + } + if ElementCodec.self == Float16Codec.self { + return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Float16], to: ElementCodec.Value.self) + } + if ElementCodec.self == BFloat16Codec.self { + return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [BFloat16], to: ElementCodec.Value.self) + } + if ElementCodec.self == FloatCodec.self { + return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Float], to: ElementCodec.Value.self) + } + if ElementCodec.self == DoubleCodec.self { + return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Double], to: ElementCodec.Value.self) + } + throw ForyError.invalidData("unsupported compatible array-to-list field element codec \(ElementCodec.self)") +} + +private func readCompatibleElementPayload( + _ context: ReadContext, + elementCodec _: ElementCodec.Type, + remoteElementTypeID: UInt32? 
+) throws -> ElementCodec.Value { + guard let remoteElementTypeID, + remoteElementTypeID != ElementCodec.typeId.rawValue, + let remoteTypeID = TypeId(rawValue: remoteElementTypeID) + else { + return try ElementCodec.readPayload(context) + } + + if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { + switch remoteTypeID { + case .int32: + return uncheckedScalarCast(try context.buffer.readInt32() as Int32, to: ElementCodec.Value.self) + case .varint32: + return uncheckedScalarCast(try context.buffer.readVarInt32() as Int32, to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self || ElementCodec.self == Int64TaggedCodec.self { + switch remoteTypeID { + case .int64: + return uncheckedScalarCast(try context.buffer.readInt64() as Int64, to: ElementCodec.Value.self) + case .varint64: + return uncheckedScalarCast(try context.buffer.readVarInt64() as Int64, to: ElementCodec.Value.self) + case .taggedInt64: + return uncheckedScalarCast(try context.buffer.readTaggedInt64() as Int64, to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self || ElementCodec.self == IntTaggedCodec.self { + switch remoteTypeID { + case .int64: + return uncheckedScalarCast(Int(try context.buffer.readInt64()), to: ElementCodec.Value.self) + case .varint64: + return uncheckedScalarCast(Int(try context.buffer.readVarInt64()), to: ElementCodec.Value.self) + case .taggedInt64: + return uncheckedScalarCast(Int(try context.buffer.readTaggedInt64()), to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { + switch remoteTypeID { + case .uint32: + return uncheckedScalarCast(try context.buffer.readUInt32() as UInt32, to: ElementCodec.Value.self) + case .varUInt32: + return uncheckedScalarCast(try 
context.buffer.readVarUInt32() as UInt32, to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self || ElementCodec.self == UInt64TaggedCodec.self { + switch remoteTypeID { + case .uint64: + return uncheckedScalarCast(try context.buffer.readUInt64() as UInt64, to: ElementCodec.Value.self) + case .varUInt64: + return uncheckedScalarCast(try context.buffer.readVarUInt64() as UInt64, to: ElementCodec.Value.self) + case .taggedUInt64: + return uncheckedScalarCast(try context.buffer.readTaggedUInt64() as UInt64, to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self || ElementCodec.self == UIntTaggedCodec.self { + switch remoteTypeID { + case .uint64: + return uncheckedScalarCast(UInt(try context.buffer.readUInt64()), to: ElementCodec.Value.self) + case .varUInt64: + return uncheckedScalarCast(UInt(try context.buffer.readVarUInt64()), to: ElementCodec.Value.self) + case .taggedUInt64: + return uncheckedScalarCast(UInt(try context.buffer.readTaggedUInt64()), to: ElementCodec.Value.self) + default: + break + } + } + throw ForyError.typeMismatch(expected: ElementCodec.typeId.rawValue, actual: remoteElementTypeID) +} + private func readPackedArrayElementCount( _ context: ReadContext, width: Int, @@ -1377,3 +1626,90 @@ private func readCollectionPayload( return result } } + +@inline(never) +private func readListPayloadAsArray( + _ context: ReadContext, + refMode: RefMode, + elementCodec _: ElementCodec.Type, + remoteElementTypeID: UInt32 +) throws -> [ElementCodec.Value] { + switch refMode { + case .none: + return try readListPayloadAsArrayPayload( + context, + elementCodec: ElementCodec.self, + remoteElementTypeID: remoteElementTypeID + ) + case .nullOnly, .tracking: + let rawFlag = try context.buffer.readInt8() + guard rawFlag != RefFlag.null.rawValue else { + return [] + } + if rawFlag == 
RefFlag.ref.rawValue { + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: [ElementCodec.Value].self) + } + let reservedRefID = (rawFlag == RefFlag.refValue.rawValue && context.trackRef) + ? context.refReader.reserveRefID() + : nil + let value = try readListPayloadAsArrayPayload( + context, + elementCodec: ElementCodec.self, + remoteElementTypeID: remoteElementTypeID + ) + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + return value + } +} + +private func readListPayloadAsArrayPayload( + _ context: ReadContext, + elementCodec _: ElementCodec.Type, + remoteElementTypeID: UInt32 +) throws -> [ElementCodec.Value] { + let buffer = context.buffer + let length = Int(try buffer.readVarUInt32()) + try context.ensureCollectionLength(length, label: "array") + if length == 0 { + return [] + } + + let header = try buffer.readUInt8() + let trackRef = (header & CollectionHeader.trackingRef) != 0 + let hasNull = (header & CollectionHeader.hasNull) != 0 + if hasNull { + throw ForyError.invalidData("compatible list-to-array field cannot read nullable elements") + } + let declared = (header & CollectionHeader.declaredElementType) != 0 + let sameType = (header & CollectionHeader.sameType) != 0 + + var result: [ElementCodec.Value] = [] + result.reserveCapacity(length) + + if !sameType { + throw ForyError.invalidData("compatible list-to-array field requires same-type elements") + } + + if trackRef { + throw ForyError.invalidData("compatible list-to-array field cannot read ref-tracked elements") + } + let elementTypeInfo: TypeInfo? + if declared { + elementTypeInfo = nil + } else { + throw ForyError.invalidData("compatible list-to-array field requires declared elements") + } + return try ElementCodec.withTypeInfo(elementTypeInfo, context) { + for _ in 0.. 
Bool { + guard + let listElementType = TypeId(rawValue: listElementTypeID), + let arrayElementType = TypeId(rawValue: arrayTypeID)?.denseArrayElementTypeID + else { + return false + } + if listElementType == arrayElementType { + return true + } + guard let listDomain = listElementType.compatibleIntegerEncodingDomain else { + return false + } + return listDomain == arrayElementType.compatibleIntegerEncodingDomain + } + + private var compatibleIntegerEncodingDomain: UInt8? { + switch self { + case .int32, .varint32: + return 1 + case .int64, .varint64, .taggedInt64: + return 2 + case .uint32, .varUInt32: + return 3 + case .uint64, .varUInt64, .taggedUInt64: + return 4 + default: + return nil + } + } } diff --git a/swift/Sources/Fory/TypeMeta.swift b/swift/Sources/Fory/TypeMeta.swift index 16cc64ecfe..4c123b2bb2 100644 --- a/swift/Sources/Fory/TypeMeta.swift +++ b/swift/Sources/Fory/TypeMeta.swift @@ -676,8 +676,12 @@ public final class TypeMeta: Equatable, @unchecked Sendable { private static func isCompatibleFieldType( _ remoteType: FieldType, - _ localType: FieldType + _ localType: FieldType, + topLevel: Bool = true ) -> Bool { + if topLevel, isCompatibleTopLevelListArrayFieldType(remoteType, localType) { + return true + } if normalizeCompatibleTypeIDForComparison(remoteType.typeID) != normalizeCompatibleTypeIDForComparison(localType.typeID) { return false @@ -686,12 +690,37 @@ public final class TypeMeta: Equatable, @unchecked Sendable { return false } for (remoteGeneric, localGeneric) in zip(remoteType.generics, localType.generics) - where !isCompatibleFieldType(remoteGeneric, localGeneric) { + where !isCompatibleFieldType(remoteGeneric, localGeneric, topLevel: false) { return false } return true } + private static func isCompatibleTopLevelListArrayFieldType( + _ remoteType: FieldType, + _ localType: FieldType + ) -> Bool { + if remoteType.typeID == TypeId.list.rawValue { + return listFieldType(remoteType, matchesDenseArrayTypeID: localType.typeID) + } + if 
localType.typeID == TypeId.list.rawValue { + return listFieldType(localType, matchesDenseArrayTypeID: remoteType.typeID) + } + return false + } + + private static func listFieldType( + _ listType: FieldType, + matchesDenseArrayTypeID arrayTypeID: UInt32 + ) -> Bool { + guard listType.typeID == TypeId.list.rawValue, + let elementType = listType.generics.first + else { + return false + } + return TypeId.listElementTypeID(elementType.typeID, matchesDenseArrayTypeID: arrayTypeID) + } + private static func normalizeCompatibleTypeIDForComparison(_ typeID: UInt32) -> UInt32 { switch typeID { case TypeId.structType.rawValue, diff --git a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift index f7c23b73f5..7a025e6437 100644 --- a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift +++ b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift @@ -404,6 +404,15 @@ private func readFieldExpr( } if let codecType = field.customCodecType { let fieldCodec = field.isOptional ? 
"OptionalFieldCodec<\(codecType)>" : codecType + if readTypeInfoExpr.contains("remoteField.fieldType") { + return """ + try \(fieldCodec).readCompatibleField( + context, + remoteFieldType: remoteField.fieldType, + refMode: \(refModeExpr) + ) + """ + } return "try \(fieldCodec).read(context, refMode: \(refModeExpr), readTypeInfo: false)" } return "try \(field.typeText).foryRead(context, refMode: \(refModeExpr), readTypeInfo: \(readTypeInfoExpr))" diff --git a/swift/Tests/ForyTests/CompatibilityTests.swift b/swift/Tests/ForyTests/CompatibilityTests.swift index 6633e0bf7b..2afa3f255c 100644 --- a/swift/Tests/ForyTests/CompatibilityTests.swift +++ b/swift/Tests/ForyTests/CompatibilityTests.swift @@ -119,6 +119,52 @@ private struct LocalNestedVarintMapV2: Equatable { var ids: Set = [] } +@ForyStruct +private struct CompatibleListFieldV1: Equatable { + @ListField(element: .int32(encoding: .fixed)) + var values: [Int32] = [] + + var extra: Int32 = 0 +} + +@ForyStruct +private struct CompatibleVarintListFieldV1: Equatable { + @ListField(element: .int32()) + var values: [Int32] = [] + + var extra: Int32 = 0 +} + +@ForyStruct +private struct CompatibleArrayFieldV2: Equatable { + @ArrayField(element: .int32()) + var values: [Int32] = [] +} + +@ForyStruct +private struct CompatibleNullableListFieldV1: Equatable { + @ListField(element: .int32(nullable: true, encoding: .fixed)) + var values: [Int32?] 
= [] + + var extra: Int32 = 0 +} + +@ForyStruct +private struct CompatibleNestedListArrayFieldV1: Equatable { + @ListField(element: .list(element: .int32(encoding: .fixed))) + var values: [[Int32]] = [] + + var keep: Int32 = 0 +} + +@ForyStruct +private struct CompatibleNestedArrayListFieldV2: Equatable { + @ListField(element: .array(element: .int32())) + var values: [[Int32]] = [] + + var keep: Int32 = 0 +} + @ForyStruct private struct SchemaVersionV1: Equatable { var id: Int32 = 0 @@ -353,6 +399,81 @@ func compatibleSkipUsesRemoteMetadataForNestedMapListSetFields() throws { #expect(decoded.ids.isEmpty) } +@Test +func compatibleReadAdaptsImmediateListAndArrayFieldPair() throws { + let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + writer.register(CompatibleListFieldV1.self, id: 9922) + + let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + reader.register(CompatibleArrayFieldV2.self, id: 9922) + + let decoded: CompatibleArrayFieldV2 = try reader.deserialize( + try writer.serialize(CompatibleListFieldV1(values: [1, 2, 3], extra: 9)) + ) + #expect(decoded.values == [1, 2, 3]) +} + +@Test +func compatibleReadAdaptsDefaultVarintListAndArrayFieldPair() throws { + let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + writer.register(CompatibleVarintListFieldV1.self, id: 9924) + + let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + reader.register(CompatibleArrayFieldV2.self, id: 9924) + + let decoded: CompatibleArrayFieldV2 = try reader.deserialize( + try writer.serialize(CompatibleVarintListFieldV1(values: [-1, 2, 3], extra: 9)) + ) + #expect(decoded.values == [-1, 2, 3]) +} + +@Test +func compatibleReadAdaptsArrayFieldToDefaultVarintListField() throws { + let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + writer.register(CompatibleArrayFieldV2.self, id: 9925) + + let reader = Fory(config: .init(xlang: true, trackRef: 
false, compatible: true)) + reader.register(CompatibleVarintListFieldV1.self, id: 9925) + + let decoded: CompatibleVarintListFieldV1 = try reader.deserialize( + try writer.serialize(CompatibleArrayFieldV2(values: [-1, 2, 3])) + ) + #expect(decoded.values == [-1, 2, 3]) +} + +@Test +func compatibleReadRejectsNullableListElementsForArrayField() throws { + let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + writer.register(CompatibleNullableListFieldV1.self, id: 9923) + + let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + reader.register(CompatibleArrayFieldV2.self, id: 9923) + + let bytes = try writer.serialize(CompatibleNullableListFieldV1(values: [1, 2, 3], extra: 9)) + let decoded: CompatibleArrayFieldV2 = try reader.deserialize(bytes) + #expect(decoded.values == [1, 2, 3]) + + let nullableBytes = try writer.serialize(CompatibleNullableListFieldV1(values: [1, nil, 3], extra: 9)) + #expect(throws: ForyError.invalidData("compatible list-to-array field cannot read nullable elements")) { + let _: CompatibleArrayFieldV2 = try reader.deserialize(nullableBytes) + } +} + +@Test +func compatibleReadSkipsNestedListArrayFieldPair() throws { + let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + writer.register(CompatibleNestedListArrayFieldV1.self, id: 9926) + + let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + reader.register(CompatibleNestedArrayListFieldV2.self, id: 9926) + + let decoded: CompatibleNestedArrayListFieldV2 = try reader.deserialize( + try writer.serialize(CompatibleNestedListArrayFieldV1(values: [[1, 2]], keep: 7)) + ) + #expect(decoded.values.isEmpty) + #expect(decoded.keep == 7) +} + @Test func compatibleNestedMapEvolves() throws { let writerV1 = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) diff --git a/swift/Tests/ForyXlangTests/main.swift b/swift/Tests/ForyXlangTests/main.swift index 0b9b5311b8..f96fd9c98a 
100644 --- a/swift/Tests/ForyXlangTests/main.swift +++ b/swift/Tests/ForyXlangTests/main.swift @@ -124,6 +124,27 @@ private struct ReducedPrecisionFloatStruct { var bfloat16Array: [BFloat16] = [] } +@ForyStruct +private struct CompatibleInt32ListField { + @ForyField(id: 1) + @ListField(element: .int32(encoding: .fixed)) + var values: [Int32] = [] +} + +@ForyStruct +private struct CompatibleNullableInt32ListField { + @ForyField(id: 1) + @ListField(element: .int32(nullable: true, encoding: .fixed)) + var values: [Int32?] = [] +} + +@ForyStruct +private struct CompatibleInt32ArrayField { + @ForyField(id: 1) + @ArrayField(element: .int32()) + var values: [Int32] = [] +} + @ForyStruct private struct OneEnumFieldStruct { var f1: PeerTestEnum = .valueA @@ -1079,6 +1100,31 @@ private func handleReducedPrecisionFloatStructCompatibleSkip(_ bytes: [UInt8]) t return try roundTripSingle(bytes, fory: fory, as: EmptyStructEvolution.self) } +private func handleListArrayCompatibleListToArray(_ bytes: [UInt8]) throws -> [UInt8] { + let fory = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + fory.register(CompatibleInt32ArrayField.self, id: 901) + return try roundTripSingle(bytes, fory: fory, as: CompatibleInt32ArrayField.self) +} + +private func handleListArrayCompatibleArrayToList(_ bytes: [UInt8]) throws -> [UInt8] { + let fory = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + fory.register(CompatibleInt32ListField.self, id: 901) + return try roundTripSingle(bytes, fory: fory, as: CompatibleInt32ListField.self) +} + +private func handleListArrayCompatibleNullableListToArrayError(_ bytes: [UInt8]) throws -> [UInt8] { + let fory = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + fory.register(CompatibleInt32ArrayField.self, id: 901) + do { + let _: CompatibleInt32ArrayField = try fory.deserialize(Data(bytes)) + } catch { + return bytes + } + throw PeerError.invalidFieldValue( + "Expected nullable list payload to fail 
compatible array read" + ) +} + private func rewritePayload(caseName: String, bytes: [UInt8]) throws -> [UInt8] { switch caseName { case "test_buffer", "test_buffer_var": @@ -1175,6 +1221,12 @@ private func rewritePayload(caseName: String, bytes: [UInt8]) throws -> [UInt8] return try handleReducedPrecisionFloatStruct(bytes) case "test_reduced_precision_float_struct_compatible_skip": return try handleReducedPrecisionFloatStructCompatibleSkip(bytes) + case "test_list_array_compatible_list_to_array": + return try handleListArrayCompatibleListToArray(bytes) + case "test_list_array_compatible_array_to_list": + return try handleListArrayCompatibleArrayToList(bytes) + case "test_list_array_compatible_nullable_list_to_array_error": + return try handleListArrayCompatibleNullableListToArrayError(bytes) default: throw PeerError.unsupportedCase(caseName) }