From 2b2221332fca20ba6e2d214167be2c892bff43e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=85=95=E7=99=BD?= Date: Mon, 4 May 2026 22:21:58 +0800 Subject: [PATCH 01/10] parallel tests run --- AGENTS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/AGENTS.md b/AGENTS.md index bb92c13873..492560d23e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -104,6 +104,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th ## Shared Validation Expectations - Run the relevant tests for every touched language or subsystem before finishing. +- When multiple independent language test suites are required, run them concurrently when the environment has enough resources instead of running them one by one; keep each language's logs and results separate, and rerun any failed suite with focused diagnostics. - Run applicable test commands in a subagent with a thinking budget one level lower than the main task budget, using medium when the current budget is unclear, unless the change is docs-only or the user explicitly asks to run them locally. - Reuse the same test subagent for repeated runs within one task and subsystem so it keeps failure context; create a fresh subagent when switching unrelated subsystems or when prior context may be stale or misleading. - Use `integration_tests/` for cross-language compatibility validation when behavior crosses runtimes. 
From 2188d134154d3c9e493c2220e5ec76d95dad010d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=85=95=E7=99=BD?= Date: Tue, 5 May 2026 22:16:55 +0800 Subject: [PATCH 02/10] docs: clarify task doc formatting exception --- .agents/docs-and-formatting.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.agents/docs-and-formatting.md b/.agents/docs-and-formatting.md index f4aaed7d14..0f9bfe258c 100644 --- a/.agents/docs-and-formatting.md +++ b/.agents/docs-and-formatting.md @@ -31,6 +31,8 @@ Load this file when changing documentation, public APIs, protocol specs, benchma ## Formatting Commands - Markdown: `prettier --write ` +- Do not format Markdown under `tasks/`, including task design, plan, progress, state, history, + and lessons files. These files are agent working state rather than repository documentation. - Python code, including `compiler/`, `benchmarks/`, `integration_tests/`, and `python/`: `python -m ruff format ` and `python -m ruff check --fix ` From 4e5ca490c5e31ccd3790945b416bf9ae6671ebc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=85=95=E7=99=BD?= Date: Tue, 5 May 2026 23:41:35 +0800 Subject: [PATCH 03/10] add comprehensive read checks --- BUILD | 1 + cpp/fory/serialization/context.cc | 40 +- cpp/fory/serialization/fory.h | 51 +- cpp/fory/serialization/serialization_test.cc | 200 +- cpp/fory/serialization/serializer.h | 12 +- cpp/fory/serialization/type_resolver.cc | 379 ++-- cpp/fory/serialization/type_resolver.h | 3 +- csharp/src/Fory/Fory.cs | 62 +- csharp/src/Fory/ForyFlags.cs | 6 +- csharp/src/Fory/ReadContext.cs | 2 + csharp/src/Fory/TypeInfo.cs | 9 +- csharp/src/Fory/TypeMeta.cs | 239 ++- csharp/src/Fory/TypeResolver.cs | 16 + csharp/tests/Fory.Tests/ForyRuntimeTests.cs | 65 +- .../tests/Fory.Tests/RuntimeEdgeCaseTests.cs | 42 +- .../fory/lib/src/context/read_context.dart | 33 +- dart/packages/fory/lib/src/fory.dart | 41 +- .../fory/lib/src/meta/meta_string.dart | 2 + .../packages/fory/lib/src/meta/type_meta.dart | 30 +- 
.../fory/lib/src/resolver/type_resolver.dart | 226 ++- .../packages/fory/lib/src/util/hash_util.dart | 9 +- .../fory/test/decimal_serializer_test.dart | 4 +- .../fory/test/time_serializer_test.dart | 2 +- .../fory/test/xlang_protocol_test.dart | 29 +- docs/specification/java_serialization_spec.md | 35 +- .../specification/xlang_serialization_spec.md | 46 +- go/fory/decimal.go | 4 +- go/fory/enum.go | 22 +- go/fory/enum_test.go | 98 + go/fory/fory.go | 180 +- go/fory/fory_typed_test.go | 38 + go/fory/map.go | 8 +- go/fory/map_primitive.go | 261 ++- go/fory/map_primitive_test.go | 61 + go/fory/primitive.go | 60 +- go/fory/reader.go | 143 +- go/fory/skip.go | 137 +- go/fory/skip_test.go | 132 ++ go/fory/stream.go | 13 +- go/fory/string.go | 8 +- go/fory/struct_test.go | 17 +- go/fory/time.go | 12 +- go/fory/type_def.go | 381 +++- go/fory/type_def_test.go | 188 +- go/fory/type_resolver.go | 7 +- go/fory/writer.go | 4 +- .../src/main/java/org/apache/fory/Fory.java | 47 +- .../fory/collection/LongLongByteMap.java | 20 + .../apache/fory/context/MetaStringReader.java | 114 +- .../apache/fory/io/ForyReadableChannel.java | 28 +- .../org/apache/fory/memory/MemoryBuffer.java | 22 + .../fory/meta/DeflaterMetaCompressor.java | 9 + .../apache/fory/meta/EncodedMetaString.java | 6 +- .../org/apache/fory/meta/MetaCompressor.java | 10 + .../fory/meta/NativeTypeDefDecoder.java | 92 +- .../fory/meta/NativeTypeDefEncoder.java | 63 +- .../java/org/apache/fory/meta/TypeDef.java | 42 +- .../org/apache/fory/meta/TypeDefDecoder.java | 130 +- .../org/apache/fory/meta/TypeDefEncoder.java | 95 +- .../fory/meta/TypeEqualMetaCompressor.java | 5 + .../apache/fory/resolver/ClassResolver.java | 52 +- .../apache/fory/resolver/SharedRegistry.java | 45 +- .../apache/fory/resolver/TypeNameBytes.java | 18 +- .../apache/fory/resolver/TypeResolver.java | 79 +- .../apache/fory/resolver/XtypeResolver.java | 17 +- .../fory/serializer/BufferSerializers.java | 7 + .../fory/serializer/ExceptionSerializers.java | 
5 + .../serializer/ObjectStreamSerializer.java | 66 +- .../collection/ChildContainerSerializers.java | 4 + .../test/java/org/apache/fory/ForyTest.java | 8 + .../test/java/org/apache/fory/StreamTest.java | 61 + .../fory/meta/DeflaterMetaCompressorTest.java | 11 + .../fory/meta/NativeTypeDefEncoderTest.java | 90 +- .../apache/fory/meta/TypeDefEncoderTest.java | 175 ++ .../org/apache/fory/meta/TypeDefTest.java | 2 +- .../fory/resolver/ClassResolverTest.java | 41 +- .../fory/resolver/MetaStringIOTest.java | 89 + .../serializer/BufferSerializersTest.java | 22 + .../ObjectStreamSerializerTest.java | 55 +- .../extension/meta/TypeDefEncoderTest.java | 2 +- javascript/packages/core/lib/context.ts | 123 +- javascript/packages/core/lib/fory.ts | 86 +- javascript/packages/core/lib/meta/TypeMeta.ts | 468 +++-- javascript/packages/core/lib/reader/index.ts | 96 +- javascript/packages/core/lib/type.ts | 5 +- javascript/test/datetime.test.ts | 41 +- javascript/test/fory.test.ts | 10 +- javascript/test/typemeta.test.ts | 184 +- python/pyfory/_fory.py | 22 +- python/pyfory/context.pxi | 89 +- python/pyfory/context.py | 18 +- python/pyfory/format/infer.py | 10 +- python/pyfory/format/tests/test_infer.py | 42 + python/pyfory/meta/typedef.py | 445 ++++- python/pyfory/meta/typedef_decoder.py | 108 +- python/pyfory/meta/typedef_encoder.py | 82 +- python/pyfory/registry.py | 16 +- python/pyfory/serialization.pyx | 47 +- python/pyfory/serializer.py | 211 +- .../pyfory/tests/test_metastring_resolver.py | 113 +- python/pyfory/tests/test_policy.py | 252 ++- python/pyfory/tests/test_typedef_encoding.py | 141 +- python/pyfory/type_util.py | 10 +- rust/fory-core/src/fory.rs | 76 +- rust/fory-core/src/meta/type_meta.rs | 322 ++- rust/fory-core/src/resolver/meta_resolver.rs | 8 +- rust/fory-core/src/resolver/type_resolver.rs | 24 +- rust/fory-core/src/type_id.rs | 5 +- rust/tests/tests/test_cross_language.rs | 2 +- rust/tests/tests/test_meta.rs | 2 +- swift/Sources/Fory/Fory.swift | 1026 +++++----- 
swift/Sources/Fory/ForyFlags.swift | 30 +- swift/Sources/Fory/ReadContext.swift | 6 + swift/Sources/Fory/TypeMeta.swift | 1275 ++++++------ swift/Sources/Fory/TypeResolver.swift | 1225 ++++++------ swift/Tests/ForyTests/DecimalTests.swift | 150 +- swift/Tests/ForyTests/ForySwiftTests.swift | 1762 +++++++++-------- 117 files changed, 8787 insertions(+), 4635 deletions(-) create mode 100644 go/fory/enum_test.go create mode 100644 go/fory/map_primitive_test.go create mode 100644 go/fory/skip_test.go diff --git a/BUILD b/BUILD index 69350bd55b..9928ab77c4 100644 --- a/BUILD +++ b/BUILD @@ -52,6 +52,7 @@ pyx_library( "//cpp/fory/type:fory_type", "//python/pyfory/cpp:_pyfory", "//cpp/fory/thirdparty:flat_hash_map", + "//cpp/fory/thirdparty:libmmh3", ], ) diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index 7967a33352..946d37d8f6 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -500,48 +500,28 @@ Result ReadContext::read_type_meta() { // Check if we already parsed this type meta (cache lookup by header) if (has_last_meta_header_ && meta_header == last_meta_header_) { - // Fast path: same header as last parsed + // Header-cache hits intentionally skip without rehashing. Entries reach + // this cache only after a successful TypeMeta parse and 52-bit body-hash + // validation. 
const TypeInfo *cached = last_meta_type_info_; reading_type_infos_.push_back(cached); - if (cached && !cached->type_def.empty()) { - const size_t type_def_size = cached->type_def.size(); - if (type_def_size >= sizeof(int64_t) && - type_def_size <= std::numeric_limits::max()) { - Error skip_error; - buffer_->skip(static_cast(type_def_size - sizeof(int64_t)), - skip_error); - if (FORY_PREDICT_FALSE(!skip_error.ok())) { - return Unexpected(std::move(skip_error)); - } - return cached; - } - } - FORY_RETURN_NOT_OK(TypeMeta::skip_bytes(*buffer_, meta_header)); + FORY_RETURN_NOT_OK( + TypeMeta::skip_bytes_for_validated_header(*buffer_, meta_header)); return cached; } auto *cache_entry = parsed_type_infos_.find(meta_header); if (cache_entry != nullptr) { - // Found in cache - reuse and skip the bytes + // Header-cache hits intentionally skip without rehashing. Entries reach + // this cache only after a successful TypeMeta parse and 52-bit body-hash + // validation. const TypeInfo *cached = cache_entry->second; reading_type_infos_.push_back(cached); has_last_meta_header_ = true; last_meta_header_ = meta_header; last_meta_type_info_ = cached; - if (cached && !cached->type_def.empty()) { - const size_t type_def_size = cached->type_def.size(); - if (type_def_size >= sizeof(int64_t) && - type_def_size <= std::numeric_limits::max()) { - Error skip_error; - buffer_->skip(static_cast(type_def_size - sizeof(int64_t)), - skip_error); - if (FORY_PREDICT_FALSE(!skip_error.ok())) { - return Unexpected(std::move(skip_error)); - } - return cached; - } - } - FORY_RETURN_NOT_OK(TypeMeta::skip_bytes(*buffer_, meta_header)); + FORY_RETURN_NOT_OK( + TypeMeta::skip_bytes_for_validated_header(*buffer_, meta_header)); return cached; } diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index fee0c71aa5..8edc1c50a9 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -629,15 +629,13 @@ class Fory : public BaseFory { Buffer 
buffer(const_cast(data), static_cast(size), false); - FORY_TRY(header, read_header(buffer)); - if (header.is_null) { - return Unexpected(Error::invalid_data("Cannot deserialize null object")); + Error header_error; + const uint8_t header = buffer.read_uint8(header_error); + if (FORY_PREDICT_FALSE(!header_error.ok())) { + return Unexpected(std::move(header_error)); } - if (FORY_PREDICT_FALSE(header.is_xlang != config_.xlang)) { - return Unexpected(Error::invalid_data( - "Protocol mismatch: payload xlang=" + - std::string(header.is_xlang ? "true" : "false") + - ", local xlang=" + std::string(config_.xlang ? "true" : "false"))); + if (FORY_PREDICT_FALSE(header != precomputed_header_)) { + return Unexpected(invalid_root_header(header)); } read_ctx_->attach(buffer); @@ -668,19 +666,13 @@ class Fory : public BaseFory { if (FORY_PREDICT_FALSE(!finalized_)) { ensure_finalized(); } - auto header_result = read_header(buffer); - if (FORY_PREDICT_FALSE(!header_result.ok())) { - return Unexpected(std::move(header_result).error()); + Error header_error; + const uint8_t header = buffer.read_uint8(header_error); + if (FORY_PREDICT_FALSE(!header_error.ok())) { + return Unexpected(std::move(header_error)); } - auto header = std::move(header_result).value(); - if (header.is_null) { - return Unexpected(Error::invalid_data("Cannot deserialize null object")); - } - if (FORY_PREDICT_FALSE(header.is_xlang != config_.xlang)) { - return Unexpected(Error::invalid_data( - "Protocol mismatch: payload xlang=" + - std::string(header.is_xlang ? "true" : "false") + - ", local xlang=" + std::string(config_.xlang ? 
"true" : "false"))); + if (FORY_PREDICT_FALSE(header != precomputed_header_)) { + return Unexpected(invalid_root_header(header)); } read_ctx_->attach(buffer); @@ -775,11 +767,28 @@ class Fory : public BaseFory { static uint8_t compute_header(bool xlang) { uint8_t flags = 0; if (xlang) { - flags |= (1 << 1); // bit 1: xlang flag + flags |= (1 << 0); } return flags; } + FORY_NOINLINE Error invalid_root_header(uint8_t header) const { + constexpr uint8_t xlang_flag = 1 << 0; + constexpr uint8_t oob_flag = 1 << 1; + constexpr uint8_t known_flags = xlang_flag | oob_flag; + if ((header & ~known_flags) != 0) { + return Error::invalid_data("Unsupported root header bitmap"); + } + if ((header & oob_flag) != 0) { + return Error::invalid_data("Out-of-band mode is not supported"); + } + const bool payload_xlang = (header & xlang_flag) != 0; + return Error::invalid_data( + "Protocol mismatch: payload xlang=" + + std::string(payload_xlang ? "true" : "false") + + ", local xlang=" + std::string(config_.xlang ? "true" : "false")); + } + /// Core serialization implementation. /// TypeMeta is written inline using streaming protocol (no deferred writing). 
template diff --git a/cpp/fory/serialization/serialization_test.cc b/cpp/fory/serialization/serialization_test.cc index c5c51e3f77..3bfb4ef39e 100644 --- a/cpp/fory/serialization/serialization_test.cc +++ b/cpp/fory/serialization/serialization_test.cc @@ -68,8 +68,8 @@ struct NestedStruct { }; enum class Color { RED, GREEN, BLUE }; -enum class LegacyStatus : int32_t { NEG = -3, ZERO = 0, LARGE = 42 }; -FORY_ENUM(LegacyStatus, NEG, ZERO, LARGE); +enum class SignedScopedStatus : int32_t { NEG = -3, ZERO = 0, LARGE = 42 }; +FORY_ENUM(SignedScopedStatus, NEG, ZERO, LARGE); enum class SparseStatus : int32_t { UNKNOWN = 4096, OK = 8192 }; FORY_ENUM(SparseStatus, UNKNOWN, OK); @@ -95,7 +95,7 @@ inline void register_test_types(Fory &fory) { // Register all enum types used in tests fory.register_enum(type_id++); - fory.register_enum(type_id++); + fory.register_enum(type_id++); fory.register_enum(type_id++); fory.register_enum(type_id++); } @@ -224,7 +224,7 @@ TEST(SerializationTest, DateExposesDaysSinceEpochAccessorAndRoundTrips) { std::vector bytes = std::move(serialize_result).value(); Buffer expected; - expected.write_uint8(0b10); + expected.write_uint8(0b1); expected.write_int8(NOT_NULL_VALUE_FLAG); expected.write_uint8(static_cast(TypeId::DATE)); expected.write_var_int64(-1); @@ -308,7 +308,7 @@ TEST(SerializationTest, DecimalRejectsNonCanonicalBigPayloads) { auto fory = Fory::builder().xlang(true).track_ref(false).build(); Buffer zero_big_encoding; - zero_big_encoding.write_uint8(0b10); + zero_big_encoding.write_uint8(0b1); zero_big_encoding.write_int8(NOT_NULL_VALUE_FLAG); zero_big_encoding.write_uint8(static_cast(TypeId::DECIMAL)); zero_big_encoding.write_var_int32(0); @@ -322,7 +322,7 @@ TEST(SerializationTest, DecimalRejectsNonCanonicalBigPayloads) { std::string::npos); Buffer trailing_zero_payload; - trailing_zero_payload.write_uint8(0b10); + trailing_zero_payload.write_uint8(0b1); trailing_zero_payload.write_int8(NOT_NULL_VALUE_FLAG); 
trailing_zero_payload.write_uint8(static_cast(TypeId::DECIMAL)); trailing_zero_payload.write_var_int32(0); @@ -459,9 +459,9 @@ TEST(SerializationTest, SparseEnumRoundtrip) { TEST(SerializationTest, EnumSerializesOrdinalValue) { auto fory = Fory::builder().xlang(true).track_ref(false).build(); - fory.register_enum(1); + fory.register_enum(1); - auto bytes_result = fory.serialize(LegacyStatus::LARGE); + auto bytes_result = fory.serialize(SignedScopedStatus::LARGE); ASSERT_TRUE(bytes_result.ok()) << "Serialization failed: " << bytes_result.error().to_string(); @@ -500,9 +500,9 @@ TEST(SerializationTest, OldEnumSerializesOrdinalValue) { TEST(SerializationTest, EnumOrdinalMappingHandlesNonZeroStart) { auto fory = Fory::builder().xlang(true).track_ref(false).build(); - fory.register_enum(1); + fory.register_enum(1); - auto bytes_result = fory.serialize(LegacyStatus::NEG); + auto bytes_result = fory.serialize(SignedScopedStatus::NEG); ASSERT_TRUE(bytes_result.ok()) << "Serialization failed: " << bytes_result.error().to_string(); @@ -516,17 +516,18 @@ TEST(SerializationTest, EnumOrdinalMappingHandlesNonZeroStart) { // Ordinal 0 encoded as varuint32 is just 1 byte with value 0 EXPECT_EQ(bytes[offset + 3], 0); - auto roundtrip = fory.deserialize(bytes.data(), bytes.size()); + auto roundtrip = + fory.deserialize(bytes.data(), bytes.size()); ASSERT_TRUE(roundtrip.ok()) << "Deserialization failed: " << roundtrip.error().to_string(); - EXPECT_EQ(roundtrip.value(), LegacyStatus::NEG); + EXPECT_EQ(roundtrip.value(), SignedScopedStatus::NEG); } TEST(SerializationTest, EnumOrdinalMappingRejectsInvalidOrdinal) { auto fory = Fory::builder().xlang(true).track_ref(false).build(); - fory.register_enum(1); + fory.register_enum(1); - auto bytes_result = fory.serialize(LegacyStatus::NEG); + auto bytes_result = fory.serialize(SignedScopedStatus::NEG); ASSERT_TRUE(bytes_result.ok()) << "Serialization failed: " << bytes_result.error().to_string(); @@ -536,7 +537,8 @@ TEST(SerializationTest, 
EnumOrdinalMappingRejectsInvalidOrdinal) { // offset + 3 Replace the valid ordinal with an invalid one (99 as varuint32) bytes[offset + 3] = 99; - auto decode = fory.deserialize(bytes.data(), bytes.size()); + auto decode = + fory.deserialize(bytes.data(), bytes.size()); EXPECT_FALSE(decode.ok()); } @@ -689,6 +691,41 @@ TEST(SerializationTest, DeserializeRejectsXlangProtocolMismatch) { std::string::npos); } +TEST(SerializationTest, RootHeaderUsesXlangBitZero) { + auto fory = Fory::builder().xlang(true).build(); + auto bytes_result = fory.serialize(123); + ASSERT_TRUE(bytes_result.ok()) + << "Serialization failed: " << bytes_result.error().to_string(); + ASSERT_FALSE(bytes_result.value().empty()); + EXPECT_EQ(bytes_result.value()[0], 0x01); +} + +TEST(SerializationTest, DeserializeRejectsRootHeaderReservedBits) { + auto fory = Fory::builder().xlang(true).build(); + auto bytes_result = fory.serialize(123); + ASSERT_TRUE(bytes_result.ok()) + << "Serialization failed: " << bytes_result.error().to_string(); + + std::vector bytes = bytes_result.value(); + bytes[0] = 0x05; + auto result = fory.deserialize(bytes.data(), bytes.size()); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.error().code(), ErrorCode::InvalidData); +} + +TEST(SerializationTest, DeserializeRejectsOutOfBandRootHeader) { + auto fory = Fory::builder().xlang(true).build(); + auto bytes_result = fory.serialize(123); + ASSERT_TRUE(bytes_result.ok()) + << "Serialization failed: " << bytes_result.error().to_string(); + + std::vector bytes = bytes_result.value(); + bytes[0] = 0x03; + auto result = fory.deserialize(bytes.data(), bytes.size()); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.error().code(), ErrorCode::InvalidData); +} + TEST(SerializationTest, RegistrationByIdFailureDoesNotLeakTypeInfo) { auto fory = Fory::builder().xlang(true).track_ref(false).build(); TypeResolver &resolver = fory.type_resolver(); @@ -756,6 +793,139 @@ TEST(SerializationTest, TypeMetaRejectsOverConsumedDeclaredSize) { 
EXPECT_EQ(parsed.error().code(), ErrorCode::InvalidData); } +TEST(SerializationTest, TypeMetaHeaderUses52BitBodyHash) { + std::vector fields; + fields.emplace_back( + "value", FieldType(static_cast(TypeId::VARINT32), false)); + TypeMeta meta = TypeMeta::from_fields(static_cast(TypeId::STRUCT), + "", "S", false, 1, std::move(fields)); + auto bytes_result = meta.to_bytes(); + ASSERT_TRUE(bytes_result.ok()) + << "TypeMeta serialization failed: " << bytes_result.error().to_string(); + + const std::vector &bytes = bytes_result.value(); + ASSERT_GT(bytes.size(), sizeof(uint64_t)); + uint64_t header = 0; + std::memcpy(&header, bytes.data(), sizeof(header)); + + constexpr uint64_t kMetaSizeMask = 0xff; + constexpr uint64_t kCompressMetaFlag = 0x100; + constexpr uint64_t kReservedBitsMask = 0xe00; + constexpr uint32_t kHashShift = 12; + + EXPECT_EQ(header & kCompressMetaFlag, 0); + EXPECT_EQ(header & kReservedBitsMask, 0); + ASSERT_NE(header & kMetaSizeMask, kMetaSizeMask); + uint64_t meta_size = header & kMetaSizeMask; + ASSERT_EQ(bytes.size(), sizeof(uint64_t) + meta_size); + ASSERT_GT(meta_size, 0); + uint8_t body_header = bytes[sizeof(uint64_t)]; + EXPECT_EQ(body_header & 0x80, 0x80); + EXPECT_EQ(body_header & 0x40, 0); + EXPECT_EQ(body_header & 0x20, 0); + EXPECT_EQ(body_header & 0x1F, 1); + + std::vector parse_bytes = bytes; + Buffer buffer(parse_bytes); + auto parsed = TypeMeta::from_bytes(buffer, nullptr); + ASSERT_TRUE(parsed.ok()) << parsed.error().to_string(); + EXPECT_EQ(static_cast(header >> kHashShift), + parsed.value()->get_hash()); +} + +TEST(SerializationTest, TypeMetaNonStructHeaderUsesDenseKindCode) { + TypeMeta meta = + TypeMeta::from_fields(static_cast(TypeId::ENUM), "", "E", false, + 7, std::vector{}); + auto bytes_result = meta.to_bytes(); + ASSERT_TRUE(bytes_result.ok()) + << "TypeMeta serialization failed: " << bytes_result.error().to_string(); + + std::vector bytes = bytes_result.value(); + ASSERT_GT(bytes.size(), sizeof(uint64_t)); + 
EXPECT_EQ(bytes[sizeof(uint64_t)], 0x00); + + Buffer buffer(bytes); + auto parsed = TypeMeta::from_bytes(buffer, nullptr); + ASSERT_TRUE(parsed.ok()) << parsed.error().to_string(); + EXPECT_EQ(parsed.value()->get_type_id(), static_cast(TypeId::ENUM)); +} + +TEST(SerializationTest, TypeMetaRejectsNonStructReservedKindBits) { + TypeMeta meta = + TypeMeta::from_fields(static_cast(TypeId::ENUM), "", "E", false, + 7, std::vector{}); + auto bytes_result = meta.to_bytes(); + ASSERT_TRUE(bytes_result.ok()) + << "TypeMeta serialization failed: " << bytes_result.error().to_string(); + + std::vector bytes = bytes_result.value(); + bytes[sizeof(uint64_t)] |= 0x10; + + Buffer buffer(bytes); + auto parsed = TypeMeta::from_bytes(buffer, nullptr); + ASSERT_FALSE(parsed.ok()); + EXPECT_EQ(parsed.error().code(), ErrorCode::InvalidData); +} + +TEST(SerializationTest, TypeMetaRejectsReservedHeaderBits) { + TypeMeta meta = + TypeMeta::from_fields(static_cast(TypeId::STRUCT), "", "S", + false, 1, std::vector{}); + auto bytes_result = meta.to_bytes(); + ASSERT_TRUE(bytes_result.ok()) + << "TypeMeta serialization failed: " << bytes_result.error().to_string(); + + std::vector bytes = bytes_result.value(); + uint64_t header = 0; + std::memcpy(&header, bytes.data(), sizeof(header)); + header |= 0x200; + std::memcpy(bytes.data(), &header, sizeof(header)); + + Buffer buffer(bytes); + auto parsed = TypeMeta::from_bytes(buffer, nullptr); + ASSERT_FALSE(parsed.ok()); + EXPECT_EQ(parsed.error().code(), ErrorCode::InvalidData); +} + +TEST(SerializationTest, TypeMetaRejectsUnsupportedCompressedHeader) { + TypeMeta meta = + TypeMeta::from_fields(static_cast(TypeId::STRUCT), "", "S", + false, 1, std::vector{}); + auto bytes_result = meta.to_bytes(); + ASSERT_TRUE(bytes_result.ok()) + << "TypeMeta serialization failed: " << bytes_result.error().to_string(); + + std::vector bytes = bytes_result.value(); + uint64_t header = 0; + std::memcpy(&header, bytes.data(), sizeof(header)); + header |= 0x100; + 
std::memcpy(bytes.data(), &header, sizeof(header)); + + Buffer buffer(bytes); + auto parsed = TypeMeta::from_bytes(buffer, nullptr); + ASSERT_FALSE(parsed.ok()); + EXPECT_EQ(parsed.error().code(), ErrorCode::InvalidData); +} + +TEST(SerializationTest, TypeMetaRejectsBodyHashMismatchAfterParse) { + TypeMeta meta = + TypeMeta::from_fields(static_cast(TypeId::STRUCT), "", "S", + false, 1, std::vector{}); + auto bytes_result = meta.to_bytes(); + ASSERT_TRUE(bytes_result.ok()) + << "TypeMeta serialization failed: " << bytes_result.error().to_string(); + + std::vector bytes = bytes_result.value(); + ASSERT_GT(bytes.size(), sizeof(uint64_t)); + bytes.back() ^= 0x01; + + Buffer buffer(bytes); + auto parsed = TypeMeta::from_bytes(buffer, nullptr); + ASSERT_FALSE(parsed.ok()); + EXPECT_EQ(parsed.error().code(), ErrorCode::InvalidData); +} + // ============================================================================ // Configuration Tests // ============================================================================ diff --git a/cpp/fory/serialization/serializer.h b/cpp/fory/serialization/serializer.h index 36477101a1..8b6a7b15a5 100644 --- a/cpp/fory/serialization/serializer.h +++ b/cpp/fory/serialization/serializer.h @@ -72,7 +72,6 @@ inline bool is_little_endian_system() { /// Fory header information struct HeaderInfo { - bool is_null; bool is_xlang; bool is_oob; uint32_t meta_start_offset; // 0 if not present @@ -89,9 +88,14 @@ inline Result read_header(Buffer &buffer) { return Unexpected(std::move(error)); } HeaderInfo info; - info.is_null = (flags & (1 << 0)) != 0; - info.is_xlang = (flags & (1 << 1)) != 0; - info.is_oob = (flags & (1 << 2)) != 0; + constexpr uint8_t xlang_flag = 1 << 0; + constexpr uint8_t oob_flag = 1 << 1; + constexpr uint8_t known_flags = xlang_flag | oob_flag; + if (FORY_PREDICT_FALSE((flags & ~known_flags) != 0)) { + return Unexpected(Error::invalid_data("Unsupported root header bitmap")); + } + info.is_xlang = (flags & xlang_flag) != 0; + 
info.is_oob = (flags & oob_flag) != 0; // Note: Meta start offset would be read here if present info.meta_start_offset = 0; diff --git a/cpp/fory/serialization/type_resolver.cc b/cpp/fory/serialization/type_resolver.cc index 86074a7f77..460218399d 100644 --- a/cpp/fory/serialization/type_resolver.cc +++ b/cpp/fory/serialization/type_resolver.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -36,13 +37,18 @@ using namespace meta; // Constants from xlang spec constexpr size_t SMALL_NUM_FIELDS_THRESHOLD = 0b11111; constexpr uint8_t REGISTER_BY_NAME_FLAG = 0b100000; +constexpr uint8_t COMPATIBLE_TYPEDEF_FLAG = 0b01000000; +constexpr uint8_t STRUCT_TYPEDEF_FLAG = 0b10000000; +constexpr uint8_t NON_STRUCT_RESERVED_BITS_MASK = 0b01110000; constexpr size_t FIELD_NAME_SIZE_THRESHOLD = 0b1111; constexpr size_t BIG_NAME_THRESHOLD = 0b111111; -constexpr int64_t META_SIZE_MASK = 0xff; -// Temporary xlang behavior: keep TypeMeta uncompressed because some runtimes -// still do not support TypeMeta decompression. 
-constexpr int64_t HAS_FIELDS_META_FLAG = 0b1 << 8; -constexpr int8_t NUM_HASH_BITS = 50; +constexpr uint64_t META_SIZE_MASK = 0xff; +constexpr uint64_t COMPRESS_META_FLAG = 0x100; +constexpr uint64_t TYPE_META_RESERVED_BITS_MASK = 0xe00; +constexpr int8_t NUM_HASH_BITS = 52; +constexpr uint32_t TYPE_META_HASH_SHIFT = 64 - NUM_HASH_BITS; +constexpr uint64_t TYPE_META_HASH_BITS_MASK = ~uint64_t{0} + << TYPE_META_HASH_SHIFT; // ============================================================================ // FieldType Implementation @@ -329,6 +335,122 @@ write_meta_name(Buffer &buffer, const std::string &name, return Result(); } +inline bool is_compatible_struct_type_id(uint32_t type_id) { + return type_id == static_cast(TypeId::COMPATIBLE_STRUCT) || + type_id == static_cast(TypeId::NAMED_COMPATIBLE_STRUCT); +} + +inline Result type_meta_kind_code(uint32_t type_id) { + switch (static_cast(type_id)) { + case TypeId::ENUM: + return static_cast(0); + case TypeId::NAMED_ENUM: + return static_cast(1); + case TypeId::EXT: + return static_cast(2); + case TypeId::NAMED_EXT: + return static_cast(3); + case TypeId::TYPED_UNION: + return static_cast(4); + case TypeId::NAMED_UNION: + return static_cast(5); + default: + return Unexpected(Error::type_error("Unsupported TypeMeta kind")); + } +} + +inline Result type_id_from_type_meta_kind(uint8_t kind_code) { + switch (kind_code) { + case 0: + return static_cast(TypeId::ENUM); + case 1: + return static_cast(TypeId::NAMED_ENUM); + case 2: + return static_cast(TypeId::EXT); + case 3: + return static_cast(TypeId::NAMED_EXT); + case 4: + return static_cast(TypeId::TYPED_UNION); + case 5: + return static_cast(TypeId::NAMED_UNION); + default: + return Unexpected(Error::invalid_data("Unsupported TypeMeta kind code")); + } +} + +inline uint64_t compute_type_meta_hash_bits(const uint8_t *meta_bytes, + size_t meta_size) { + int64_t hash_out[2] = {0, 0}; + MurmurHash3_x64_128(meta_bytes, static_cast(meta_size), 47, hash_out); + uint64_t shifted 
= static_cast(hash_out[0]) << TYPE_META_HASH_SHIFT; + if (static_cast(shifted) < 0) { + shifted = ~shifted + 1; + } + return shifted & TYPE_META_HASH_BITS_MASK; +} + +inline int64_t compute_type_meta_hash(const uint8_t *meta_bytes, + size_t meta_size) { + return static_cast( + compute_type_meta_hash_bits(meta_bytes, meta_size) >> + TYPE_META_HASH_SHIFT); +} + +inline Result validate_type_meta_header(uint64_t header) { + if (FORY_PREDICT_FALSE((header & TYPE_META_RESERVED_BITS_MASK) != 0)) { + return Unexpected( + Error::invalid_data("TypeMeta reserved header bits must be zero")); + } + if (FORY_PREDICT_FALSE((header & COMPRESS_META_FLAG) != 0)) { + return Unexpected( + Error::invalid_data("Compressed TypeMeta is not supported")); + } + return Result(); +} + +inline Result +read_type_meta_size(Buffer &buffer, uint64_t header, size_t *header_size) { + Error error; + uint64_t meta_size = header & META_SIZE_MASK; + if (meta_size == META_SIZE_MASK) { + uint32_t before = buffer.reader_index(); + uint32_t extra = buffer.read_var_uint32(error); + if (FORY_PREDICT_FALSE(!error.ok())) { + return Unexpected(std::move(error)); + } + meta_size += extra; + uint32_t after = buffer.reader_index(); + if (header_size != nullptr) { + *header_size += (after - before); + } + } + if (FORY_PREDICT_FALSE( + meta_size > static_cast(std::numeric_limits::max()))) { + return Unexpected( + Error::invalid_data("TypeMeta body size exceeds supported range")); + } + return static_cast(meta_size); +} + +inline Result validate_type_meta_hash(Buffer &buffer, + uint32_t body_start, + uint32_t meta_size, + int64_t header_hash) { + uint64_t body_end = static_cast(body_start) + meta_size; + if (FORY_PREDICT_FALSE(body_end > buffer.reader_index() || + body_end > buffer.size())) { + return Unexpected( + Error::invalid_data("TypeMeta body range is not readable")); + } + uint64_t computed_hash_bits = compute_type_meta_hash_bits( + buffer.data() + body_start, static_cast(meta_size)); + if 
(FORY_PREDICT_FALSE((computed_hash_bits >> TYPE_META_HASH_SHIFT) != + static_cast(header_hash))) { + return Unexpected(Error::invalid_data("TypeMeta body hash mismatch")); + } + return Result(); +} + inline Result read_meta_name(Buffer &buffer, const MetaStringDecoder &decoder, const MetaEncoding *encodings, size_t enc_count) { @@ -392,21 +514,33 @@ TypeMeta TypeMeta::from_fields(uint32_t tid, const std::string &ns, Result, Error> TypeMeta::to_bytes() const { Buffer layer_buffer; - // write meta header + bool is_struct = is_struct_type(static_cast(type_id)); size_t num_fields = field_infos.size(); - uint8_t meta_header = - static_cast(std::min(num_fields, SMALL_NUM_FIELDS_THRESHOLD)); - if (register_by_name) { - meta_header |= REGISTER_BY_NAME_FLAG; + if (FORY_PREDICT_FALSE(!is_struct && num_fields != 0)) { + return Unexpected( + Error::invalid_data("Non-struct TypeMeta cannot carry field metadata")); } - layer_buffer.write_uint8(meta_header); - if (num_fields >= SMALL_NUM_FIELDS_THRESHOLD) { - layer_buffer.write_var_uint32(num_fields - SMALL_NUM_FIELDS_THRESHOLD); + if (is_struct) { + uint8_t meta_header = + STRUCT_TYPEDEF_FLAG | + static_cast(std::min(num_fields, SMALL_NUM_FIELDS_THRESHOLD)); + if (is_compatible_struct_type_id(type_id)) { + meta_header |= COMPATIBLE_TYPEDEF_FLAG; + } + if (register_by_name) { + meta_header |= REGISTER_BY_NAME_FLAG; + } + layer_buffer.write_uint8(meta_header); + + if (num_fields >= SMALL_NUM_FIELDS_THRESHOLD) { + layer_buffer.write_var_uint32(num_fields - SMALL_NUM_FIELDS_THRESHOLD); + } + } else { + FORY_TRY(kind_code, type_meta_kind_code(type_id)); + layer_buffer.write_uint8(kind_code); } - // write namespace and type name (if registered by name) using the - // same compact meta string format as Rust/Java. 
if (register_by_name) { FORY_RETURN_NOT_OK(write_meta_name( layer_buffer, namespace_str, k_namespace_encoder, @@ -417,7 +551,6 @@ Result, Error> TypeMeta::to_bytes() const { k_type_name_encodings, sizeof(k_type_name_encodings) / sizeof(k_type_name_encodings[0]))); } else { - layer_buffer.write_uint8(static_cast(type_id)); if (user_type_id == kInvalidUserTypeId) { return Unexpected( Error::type_error("User type id is required for this type")); @@ -434,26 +567,23 @@ Result, Error> TypeMeta::to_bytes() const { // Now write global binary header Buffer result_buffer; const uint32_t layer_size = layer_buffer.writer_index(); - int64_t meta_size = layer_size; - int64_t header = std::min(META_SIZE_MASK, meta_size); - - bool write_meta_fields_flag = !field_infos.empty(); - if (write_meta_fields_flag) { - header |= HAS_FIELDS_META_FLAG; + if (FORY_PREDICT_FALSE(layer_size > static_cast( + std::numeric_limits::max()))) { + return Unexpected( + Error::invalid_data("TypeMeta body size exceeds supported range")); } + uint64_t meta_size = layer_size; + uint64_t header = std::min(META_SIZE_MASK, meta_size); - // Compute hash - std::vector layer_data(layer_buffer.data(), - layer_buffer.data() + layer_size); - int64_t meta_hash = compute_hash(layer_data); - header |= (meta_hash << (64 - NUM_HASH_BITS)); + header |= compute_type_meta_hash_bits(layer_buffer.data(), layer_size); result_buffer.write_bytes(reinterpret_cast(&header), sizeof(header)); if (meta_size >= META_SIZE_MASK) { - result_buffer.write_var_uint32(meta_size - META_SIZE_MASK); + result_buffer.write_var_uint32( + static_cast(meta_size - META_SIZE_MASK)); } - result_buffer.write_bytes(layer_data.data(), layer_data.size()); + result_buffer.write_bytes(layer_buffer.data(), layer_size); // Use actual bytes written to construct return vector return std::vector(result_buffer.data(), result_buffer.data() + @@ -473,39 +603,52 @@ TypeMeta::from_bytes(Buffer &buffer, const TypeMeta *local_type_info) { } size_t header_size = 
sizeof(header); - int64_t meta_size = header & META_SIZE_MASK; - if (meta_size == META_SIZE_MASK) { - uint32_t before = buffer.reader_index(); - uint32_t extra = buffer.read_var_uint32(error); - if (FORY_PREDICT_FALSE(!error.ok())) { - return Unexpected(std::move(error)); - } - meta_size += extra; - uint32_t after = buffer.reader_index(); - header_size += (after - before); - } - int64_t meta_hash = header >> (64 - NUM_HASH_BITS); + uint64_t header_bits = static_cast(header); + FORY_RETURN_IF_ERROR(validate_type_meta_header(header_bits)); + FORY_TRY(meta_size, read_type_meta_size(buffer, header_bits, &header_size)); + int64_t meta_hash = static_cast(header_bits >> TYPE_META_HASH_SHIFT); + uint32_t body_start = static_cast(start_pos + header_size); // Read meta header uint8_t meta_header = buffer.read_uint8(error); if (FORY_PREDICT_FALSE(!error.ok())) { return Unexpected(std::move(error)); } - bool register_by_name = (meta_header & REGISTER_BY_NAME_FLAG) != 0; - size_t num_fields = meta_header & SMALL_NUM_FIELDS_THRESHOLD; - if (num_fields == SMALL_NUM_FIELDS_THRESHOLD) { - uint32_t extra = buffer.read_var_uint32(error); - if (FORY_PREDICT_FALSE(!error.ok())) { - return Unexpected(std::move(error)); - } - num_fields += extra; - } - - // Read type ID or namespace/type name uint32_t type_id = 0; uint32_t user_type_id = kInvalidUserTypeId; std::string namespace_str; std::string type_name; + bool register_by_name = false; + size_t num_fields = 0; + + if ((meta_header & STRUCT_TYPEDEF_FLAG) != 0) { + register_by_name = (meta_header & REGISTER_BY_NAME_FLAG) != 0; + bool compatible = (meta_header & COMPATIBLE_TYPEDEF_FLAG) != 0; + if (register_by_name) { + type_id = static_cast( + compatible ? TypeId::NAMED_COMPATIBLE_STRUCT : TypeId::NAMED_STRUCT); + } else { + type_id = static_cast(compatible ? 
TypeId::COMPATIBLE_STRUCT + : TypeId::STRUCT); + } + num_fields = meta_header & SMALL_NUM_FIELDS_THRESHOLD; + if (num_fields == SMALL_NUM_FIELDS_THRESHOLD) { + uint32_t extra = buffer.read_var_uint32(error); + if (FORY_PREDICT_FALSE(!error.ok())) { + return Unexpected(std::move(error)); + } + num_fields += extra; + } + } else { + if (FORY_PREDICT_FALSE((meta_header & NON_STRUCT_RESERVED_BITS_MASK) != + 0)) { + return Unexpected(Error::invalid_data("Invalid TypeMeta kind header")); + } + FORY_TRY(decoded_type_id, + type_id_from_type_meta_kind(meta_header & 0b1111)); + type_id = decoded_type_id; + register_by_name = is_namespaced_type(static_cast(type_id)); + } if (register_by_name) { static const MetaStringDecoder k_namespace_decoder('.', '_'); @@ -523,11 +666,6 @@ TypeMeta::from_bytes(Buffer &buffer, const TypeMeta *local_type_info) { sizeof(k_type_name_encodings[0]))); type_name = std::move(tn); } else { - uint32_t tid = buffer.read_uint8(error); - if (FORY_PREDICT_FALSE(!error.ok())) { - return Unexpected(std::move(error)); - } - type_id = tid; uint32_t uid = buffer.read_var_uint32(error); if (FORY_PREDICT_FALSE(!error.ok())) { return Unexpected(std::move(error)); @@ -559,13 +697,12 @@ TypeMeta::from_bytes(Buffer &buffer, const TypeMeta *local_type_info) { return Unexpected(Error::invalid_data( "TypeMeta parser consumed beyond declared meta size")); } - if (current_pos < expected_end_pos) { - size_t remaining = expected_end_pos - current_pos; - buffer.skip(static_cast(remaining), error); - if (FORY_PREDICT_FALSE(!error.ok())) { - return Unexpected(std::move(error)); - } + if (FORY_PREDICT_FALSE(current_pos < expected_end_pos)) { + return Unexpected(Error::invalid_data( + "TypeMeta parser did not consume declared meta size")); } + FORY_RETURN_IF_ERROR( + validate_type_meta_hash(buffer, body_start, meta_size, meta_hash)); auto meta = std::make_unique(); meta->hash = meta_hash; @@ -581,18 +718,13 @@ TypeMeta::from_bytes(Buffer &buffer, const TypeMeta 
*local_type_info) { Result, Error> TypeMeta::from_bytes_with_header(Buffer &buffer, int64_t header) { - Error error; - int64_t meta_size = header & META_SIZE_MASK; - if (meta_size == META_SIZE_MASK) { - uint32_t extra = buffer.read_var_uint32(error); - if (FORY_PREDICT_FALSE(!error.ok())) { - return Unexpected(std::move(error)); - } - meta_size += extra; - } - int64_t meta_hash = header >> (64 - NUM_HASH_BITS); + uint64_t header_bits = static_cast(header); + FORY_RETURN_IF_ERROR(validate_type_meta_header(header_bits)); + FORY_TRY(meta_size, read_type_meta_size(buffer, header_bits, nullptr)); + int64_t meta_hash = static_cast(header_bits >> TYPE_META_HASH_SHIFT); - size_t start_pos = buffer.reader_index(); + uint32_t start_pos = buffer.reader_index(); + Error error; // Read meta header uint8_t meta_header = buffer.read_uint8(error); @@ -600,21 +732,41 @@ TypeMeta::from_bytes_with_header(Buffer &buffer, int64_t header) { return Unexpected(std::move(error)); } - bool register_by_name = (meta_header & REGISTER_BY_NAME_FLAG) != 0; - size_t num_fields = meta_header & SMALL_NUM_FIELDS_THRESHOLD; - if (num_fields == SMALL_NUM_FIELDS_THRESHOLD) { - uint32_t extra = buffer.read_var_uint32(error); - if (FORY_PREDICT_FALSE(!error.ok())) { - return Unexpected(std::move(error)); - } - num_fields += extra; - } - - // Read type ID or namespace/type name uint32_t type_id = 0; uint32_t user_type_id = kInvalidUserTypeId; std::string namespace_str; std::string type_name; + bool register_by_name = false; + size_t num_fields = 0; + + if ((meta_header & STRUCT_TYPEDEF_FLAG) != 0) { + register_by_name = (meta_header & REGISTER_BY_NAME_FLAG) != 0; + bool compatible = (meta_header & COMPATIBLE_TYPEDEF_FLAG) != 0; + if (register_by_name) { + type_id = static_cast( + compatible ? TypeId::NAMED_COMPATIBLE_STRUCT : TypeId::NAMED_STRUCT); + } else { + type_id = static_cast(compatible ? 
TypeId::COMPATIBLE_STRUCT + : TypeId::STRUCT); + } + num_fields = meta_header & SMALL_NUM_FIELDS_THRESHOLD; + if (num_fields == SMALL_NUM_FIELDS_THRESHOLD) { + uint32_t extra = buffer.read_var_uint32(error); + if (FORY_PREDICT_FALSE(!error.ok())) { + return Unexpected(std::move(error)); + } + num_fields += extra; + } + } else { + if (FORY_PREDICT_FALSE((meta_header & NON_STRUCT_RESERVED_BITS_MASK) != + 0)) { + return Unexpected(Error::invalid_data("Invalid TypeMeta kind header")); + } + FORY_TRY(decoded_type_id, + type_id_from_type_meta_kind(meta_header & 0b1111)); + type_id = decoded_type_id; + register_by_name = is_namespaced_type(static_cast(type_id)); + } if (register_by_name) { static const MetaStringDecoder k_namespace_decoder('.', '_'); @@ -632,27 +784,11 @@ TypeMeta::from_bytes_with_header(Buffer &buffer, int64_t header) { sizeof(k_type_name_encodings[0]))); type_name = std::move(tn); } else { - uint32_t tid = buffer.read_uint8(error); + uint32_t uid = buffer.read_var_uint32(error); if (FORY_PREDICT_FALSE(!error.ok())) { return Unexpected(std::move(error)); } - type_id = tid; - switch (static_cast(type_id)) { - case TypeId::ENUM: - case TypeId::STRUCT: - case TypeId::COMPATIBLE_STRUCT: - case TypeId::EXT: - case TypeId::TYPED_UNION: { - uint32_t uid = buffer.read_var_uint32(error); - if (FORY_PREDICT_FALSE(!error.ok())) { - return Unexpected(std::move(error)); - } - user_type_id = uid; - break; - } - default: - break; - } + user_type_id = uid; } // Read field infos @@ -668,18 +804,17 @@ TypeMeta::from_bytes_with_header(Buffer &buffer, int64_t header) { // CRITICAL FIX: Ensure we consume exactly meta_size bytes size_t current_pos = buffer.reader_index(); - size_t expected_end_pos = start_pos + meta_size; + size_t expected_end_pos = static_cast(start_pos) + meta_size; if (FORY_PREDICT_FALSE(current_pos > expected_end_pos)) { return Unexpected(Error::invalid_data( "TypeMeta parser consumed beyond declared meta size")); } - if (current_pos < expected_end_pos) { 
- size_t remaining = expected_end_pos - current_pos; - buffer.skip(static_cast(remaining), error); - if (FORY_PREDICT_FALSE(!error.ok())) { - return Unexpected(std::move(error)); - } + if (FORY_PREDICT_FALSE(current_pos < expected_end_pos)) { + return Unexpected(Error::invalid_data( + "TypeMeta parser did not consume declared meta size")); } + FORY_RETURN_IF_ERROR( + validate_type_meta_hash(buffer, start_pos, meta_size, meta_hash)); auto meta = std::make_unique(); meta->hash = meta_hash; @@ -693,9 +828,10 @@ TypeMeta::from_bytes_with_header(Buffer &buffer, int64_t header) { return meta; } -Result TypeMeta::skip_bytes(Buffer &buffer, int64_t header) { +Result TypeMeta::skip_bytes_for_validated_header(Buffer &buffer, + int64_t header) { Error error; - int64_t meta_size = header & META_SIZE_MASK; + uint64_t meta_size = static_cast(header) & META_SIZE_MASK; if (meta_size == META_SIZE_MASK) { uint32_t extra = buffer.read_var_uint32(error); if (FORY_PREDICT_FALSE(!error.ok())) { @@ -703,6 +839,12 @@ Result TypeMeta::skip_bytes(Buffer &buffer, int64_t header) { } meta_size += extra; } + if (FORY_PREDICT_FALSE( + meta_size > + static_cast(std::numeric_limits::max()))) { + return Unexpected( + Error::invalid_data("TypeMeta body size exceeds supported range")); + } buffer.skip(static_cast(meta_size), error); if (FORY_PREDICT_FALSE(!error.ok())) { return Unexpected(std::move(error)); @@ -1026,18 +1168,7 @@ void TypeMeta::assign_field_ids(const TypeMeta *local_type, } int64_t TypeMeta::compute_hash(const std::vector &meta_bytes) { - // Compute hash using MurmurHash3_x64_128 to match Rust/Java - // TypeMeta implementation. We take the high 64 bits and then - // keep only the lower NUM_HASH_BITS bits. - int64_t hash_out[2] = {0, 0}; - MurmurHash3_x64_128(meta_bytes.data(), static_cast(meta_bytes.size()), - 47, hash_out); - - // hash_out[0] is the low 64 bits, hash_out[1] the high 64 bits. - uint64_t high = static_cast(hash_out[1]); - uint64_t mask = (NUM_HASH_BITS >= 64) ? 
~uint64_t{0} - : ((uint64_t{1} << NUM_HASH_BITS) - 1); - return static_cast(high & mask); + return compute_type_meta_hash(meta_bytes.data(), meta_bytes.size()); } namespace { diff --git a/cpp/fory/serialization/type_resolver.h b/cpp/fory/serialization/type_resolver.h index 85d928d568..cf42126324 100644 --- a/cpp/fory/serialization/type_resolver.h +++ b/cpp/fory/serialization/type_resolver.h @@ -319,7 +319,8 @@ class TypeMeta { from_bytes_with_header(Buffer &buffer, int64_t header); /// skip type meta in buffer without parsing - static Result skip_bytes(Buffer &buffer, int64_t header); + static Result skip_bytes_for_validated_header(Buffer &buffer, + int64_t header); /// Check struct version consistency static Result check_struct_version(int32_t read_version, diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index 60bf46c095..032fec11e7 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -16,6 +16,7 @@ // under the License. using System.Buffers; +using System.Runtime.CompilerServices; namespace Apache.Fory; @@ -142,15 +143,11 @@ public byte[] Serialize(in T value) ByteWriter writer = _writeContext.Writer; writer.Reset(); Serializer serializer = _typeResolver.GetSerializer(); - bool isNone = value is null; - WriteHead(writer, isNone); - if (!isNone) - { - _writeContext.ResetFor(writer); - RefMode refMode = Config.TrackRef ? RefMode.Tracking : RefMode.NullOnly; - serializer.Write(_writeContext, value, refMode, true, false); - _writeContext.RefWriter.Reset(); - } + WriteHead(writer); + _writeContext.ResetFor(writer); + RefMode refMode = Config.TrackRef ? 
RefMode.Tracking : RefMode.NullOnly; + serializer.Write(_writeContext, value, refMode, true, false); + _writeContext.RefWriter.Reset(); return writer.ToArray(); } @@ -181,7 +178,7 @@ public T Deserialize(ReadOnlySpan payload) T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { - throw new InvalidDataException($"unexpected trailing bytes after deserializing {typeof(T)}"); + ThrowUnexpectedTrailingBytes(); } return value; @@ -201,7 +198,7 @@ public T Deserialize(byte[] payload) T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { - throw new InvalidDataException($"unexpected trailing bytes after deserializing {typeof(T)}"); + ThrowUnexpectedTrailingBytes(); } return value; @@ -228,46 +225,43 @@ public T Deserialize(ref ReadOnlySequence payload) /// Writes the frame header for a payload. /// /// Destination writer. - /// Whether the payload value is null. - internal void WriteHead(ByteWriter writer, bool isNone) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void WriteHead(ByteWriter writer) { - byte bitmap = ForyHeaderFlag.IsXlang; - - if (isNone) - { - bitmap |= ForyHeaderFlag.IsNull; - } - - writer.WriteUInt8(bitmap); + writer.WriteUInt8(ForyHeaderFlag.IsXlang); } /// /// Reads and validates the frame header. /// /// Source reader. - /// true if the payload value is null; otherwise false. /// Thrown when the peer xlang bitmap does not match this runtime mode. 
- internal bool ReadHead(ByteReader reader) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReadHead(ByteReader reader) { byte bitmap = reader.ReadUInt8(); - bool peerIsXlang = (bitmap & ForyHeaderFlag.IsXlang) != 0; - if (!peerIsXlang) + if (bitmap == ForyHeaderFlag.IsXlang) { - throw new InvalidDataException("xlang bitmap mismatch"); + return; } - - return (bitmap & ForyHeaderFlag.IsNull) != 0; + ThrowInvalidRootHeader(bitmap); } + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowUnexpectedTrailingBytes() => + throw new InvalidDataException($"unexpected trailing bytes after deserializing {typeof(T)}"); + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowInvalidRootHeader(byte bitmap) => + throw new InvalidDataException((bitmap & ForyHeaderFlag.IsXlang) == 0 + ? "xlang bitmap mismatch" + : $"unsupported root header bitmap 0x{bitmap:X2}"); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] private T DeserializeFromReader(ByteReader reader) { - bool isNone = ReadHead(reader); + ReadHead(reader); Serializer serializer = _typeResolver.GetSerializer(); - if (isNone) - { - return serializer.DefaultValue; - } - ReadContext readContext = _readContext; readContext.ResetFor(reader); RefMode refMode = Config.TrackRef ? 
RefMode.Tracking : RefMode.NullOnly; diff --git a/csharp/src/Fory/ForyFlags.cs b/csharp/src/Fory/ForyFlags.cs index 1a732c1d45..6f8613870a 100644 --- a/csharp/src/Fory/ForyFlags.cs +++ b/csharp/src/Fory/ForyFlags.cs @@ -47,7 +47,7 @@ public static RefMode From(bool nullable, bool trackRef) internal static class ForyHeaderFlag { - public const byte IsNull = 0x01; - public const byte IsXlang = 0x02; - public const byte IsOutOfBand = 0x04; + public const byte IsXlang = 0x01; + public const byte IsOutOfBand = 0x02; + public const byte KnownMask = IsXlang | IsOutOfBand; } diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index 53e23e2ee1..beddd01f89 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -200,6 +200,8 @@ internal TypeMeta ReadTypeMeta() ulong header = Reader.ReadUInt64(); if (TryGetCachedReadTypeMeta(header, out TypeMeta cachedTypeMeta, out int skipBytesAfterHeader)) { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeMeta parse and 52-bit body-hash validation. 
Reader.Skip(skipBytesAfterHeader); StoreReadTypeMeta(cachedTypeMeta, index); return cachedTypeMeta; diff --git a/csharp/src/Fory/TypeInfo.cs b/csharp/src/Fory/TypeInfo.cs index e6a9ecc328..95c47ec86f 100644 --- a/csharp/src/Fory/TypeInfo.cs +++ b/csharp/src/Fory/TypeInfo.cs @@ -712,7 +712,6 @@ private TypeMetaCacheEntry BuildTypeMetaCacheEntry(bool trackRef) compatible: true, Evolving); IReadOnlyList fields = TypeMetaFields(trackRef); - bool hasFieldsMeta = fields.Count > 0; TypeMeta typeMeta; if (RegisterByName) { @@ -727,8 +726,7 @@ private TypeMetaCacheEntry BuildTypeMetaCacheEntry(bool trackRef) NamespaceName.Value, TypeName.Value, true, - fields, - hasFieldsMeta); + fields); } else { @@ -743,13 +741,12 @@ private TypeMetaCacheEntry BuildTypeMetaCacheEntry(bool trackRef) MetaString.Empty('.', '_'), MetaString.Empty('$', '_'), false, - fields, - hasFieldsMeta); + fields); } byte[] encoded = typeMeta.Encode(); ulong header = BinaryPrimitives.ReadUInt64LittleEndian(encoded); - ulong headerHash = header >> (int)(64 - TypeMetaConstants.TypeMetaNumHashBits); + ulong headerHash = header >> TypeMetaConstants.TypeMetaHashShift; return new TypeMetaCacheEntry(typeMeta, encoded, headerHash); } } diff --git a/csharp/src/Fory/TypeMeta.cs b/csharp/src/Fory/TypeMeta.cs index 4d8415260a..597cffe473 100644 --- a/csharp/src/Fory/TypeMeta.cs +++ b/csharp/src/Fory/TypeMeta.cs @@ -19,14 +19,18 @@ namespace Apache.Fory; internal static class TypeMetaConstants { + // 8 size bits + 1 compression bit + 3 reserved bits. 
+ public const int TypeMetaHashShift = 12; + public const ulong TypeMetaHashMask = ulong.MaxValue << TypeMetaHashShift; public const int SmallNumFieldsThreshold = 0b1_1111; public const byte RegisterByNameFlag = 0b10_0000; + public const byte CompatibleFlag = 0b0100_0000; + public const byte StructFlag = 0b1000_0000; public const int FieldNameSizeThreshold = 0b1111; public const int BigNameThreshold = 0b11_1111; - public const ulong TypeMetaHasFieldsMetaFlag = 1UL << 8; - public const ulong TypeMetaCompressedFlag = 1UL << 9; + public const ulong TypeMetaCompressedFlag = 1UL << 8; + public const ulong TypeMetaReservedFlags = 0b111UL << 9; public const ulong TypeMetaSizeMask = 0xFF; - public const ulong TypeMetaNumHashBits = 50; public const ulong TypeMetaHashSeed = 47; public const uint NoUserTypeId = uint.MaxValue; } @@ -395,7 +399,6 @@ public TypeMeta( MetaString typeName, bool registerByName, IReadOnlyList fields, - bool hasFieldsMeta = true, bool compressed = false, ulong headerHash = 0) { @@ -425,7 +428,6 @@ public TypeMeta( TypeName = typeName; RegisterByName = registerByName; Fields = fields; - HasFieldsMeta = hasFieldsMeta; Compressed = compressed; HeaderHash = headerHash; } @@ -442,8 +444,6 @@ public TypeMeta( public IReadOnlyList Fields { get; } - public bool HasFieldsMeta { get; } - public bool Compressed { get; } public ulong HeaderHash { get; } @@ -467,22 +467,7 @@ public byte[] Encode() } byte[] body = EncodeBody(); - (ulong bodyHash, _) = MurmurHash3.X64_128(body, TypeMetaConstants.TypeMetaHashSeed); - ulong shifted = bodyHash << (int)(64 - TypeMetaConstants.TypeMetaNumHashBits); - long signed = unchecked((long)shifted); - long absSigned = signed == long.MinValue ? 
signed : Math.Abs(signed); - - ulong header = unchecked((ulong)absSigned); - if (HasFieldsMeta) - { - header |= TypeMetaConstants.TypeMetaHasFieldsMetaFlag; - } - - if (Compressed) - { - header |= TypeMetaConstants.TypeMetaCompressedFlag; - } - + ulong header = ComputeHeaderHashBits(body); uint bodySize = (uint)Math.Min(body.Length, (int)TypeMetaConstants.TypeMetaSizeMask); header |= bodySize; ByteWriter writer = new(body.Length + 16); @@ -504,43 +489,50 @@ public static TypeMeta Decode(byte[] bytes) public static TypeMeta Decode(ByteReader reader) { ulong header = reader.ReadUInt64(); - bool compressed = (header & TypeMetaConstants.TypeMetaCompressedFlag) != 0; - bool hasFieldsMeta = (header & TypeMetaConstants.TypeMetaHasFieldsMetaFlag) != 0; - int metaSize = (int)(header & TypeMetaConstants.TypeMetaSizeMask); - if (metaSize == (int)TypeMetaConstants.TypeMetaSizeMask) - { - metaSize += (int)reader.ReadVarUInt32(); - } - + ValidateGlobalHeader(header); + int metaSize = ReadBodySize(reader, header); byte[] encodedBody = reader.ReadBytes(metaSize); - if (compressed) - { - throw new EncodingException("compressed TypeMeta is not supported yet"); - } - ByteReader bodyReader = new(encodedBody); byte metaHeader = bodyReader.ReadUInt8(); - int numFields = metaHeader & TypeMetaConstants.SmallNumFieldsThreshold; - if (numFields == TypeMetaConstants.SmallNumFieldsThreshold) - { - numFields += (int)bodyReader.ReadVarUInt32(); - } - - bool registerByName = (metaHeader & TypeMetaConstants.RegisterByNameFlag) != 0; + bool isStruct = (metaHeader & TypeMetaConstants.StructFlag) != 0; + int numFields = 0; + bool registerByName; uint? typeId; uint? userTypeId; MetaString namespaceName; MetaString typeName; + if (isStruct) + { + registerByName = (metaHeader & TypeMetaConstants.RegisterByNameFlag) != 0; + bool compatible = (metaHeader & TypeMetaConstants.CompatibleFlag) != 0; + typeId = (uint)(registerByName + ? compatible ? 
global::Apache.Fory.TypeId.NamedCompatibleStruct : global::Apache.Fory.TypeId.NamedStruct + : compatible ? global::Apache.Fory.TypeId.CompatibleStruct : global::Apache.Fory.TypeId.Struct); + numFields = metaHeader & TypeMetaConstants.SmallNumFieldsThreshold; + if (numFields == TypeMetaConstants.SmallNumFieldsThreshold) + { + numFields += (int)bodyReader.ReadVarUInt32(); + } + } + else + { + if ((metaHeader & 0b0111_0000) != 0) + { + throw new InvalidDataException("invalid TypeMeta kind header"); + } + + typeId = NonStructTypeId(metaHeader & 0b1111); + registerByName = IsNamedKind(typeId.Value); + } + if (registerByName) { namespaceName = ReadName(bodyReader, MetaStringDecoder.Namespace, TypeMetaEncodings.NamespaceMetaStringEncodings); typeName = ReadName(bodyReader, MetaStringDecoder.TypeName, TypeMetaEncodings.TypeNameMetaStringEncodings); - typeId = null; userTypeId = null; } else { - typeId = bodyReader.ReadUInt8(); userTypeId = bodyReader.ReadVarUInt32(); namespaceName = MetaString.Empty('.', '_'); typeName = MetaString.Empty('$', '_'); @@ -552,11 +544,17 @@ public static TypeMeta Decode(ByteReader reader) fields.Add(TypeMetaFieldInfo.Read(bodyReader)); } + if (!isStruct && fields.Count != 0) + { + throw new InvalidDataException("non-struct TypeMeta cannot carry field metadata"); + } + if (bodyReader.Remaining != 0) { throw new InvalidDataException("unexpected trailing bytes in TypeMeta body"); } + ValidateParsedTypeMetaHash(header, encodedBody); return new TypeMeta( typeId, userTypeId, @@ -564,9 +562,112 @@ public static TypeMeta Decode(ByteReader reader) typeName, registerByName, fields, - hasFieldsMeta, - compressed, - header >> (int)(64 - TypeMetaConstants.TypeMetaNumHashBits)); + compressed: false, + header >> TypeMetaConstants.TypeMetaHashShift); + } + + internal static void ValidateAndSkipBody(ByteReader reader, ulong header) + { + ValidateGlobalHeader(header); + int metaSize = ReadBodySize(reader, header); + ReadOnlySpan encodedBody = 
reader.ReadSpan(metaSize); + ValidateParsedTypeMetaHash(header, encodedBody); + } + + private static void ValidateGlobalHeader(ulong header) + { + if ((header & TypeMetaConstants.TypeMetaReservedFlags) != 0) + { + throw new InvalidDataException("invalid TypeMeta global header"); + } + + if ((header & TypeMetaConstants.TypeMetaCompressedFlag) != 0) + { + throw new EncodingException("compressed TypeMeta is not supported yet"); + } + } + + private static int ReadBodySize(ByteReader reader, ulong header) + { + int metaSize = (int)(header & TypeMetaConstants.TypeMetaSizeMask); + if (metaSize == (int)TypeMetaConstants.TypeMetaSizeMask) + { + uint moreSize = reader.ReadVarUInt32(); + if (moreSize > int.MaxValue - metaSize) + { + throw new InvalidDataException("invalid TypeMeta metadata size"); + } + + metaSize += (int)moreSize; + } + + return metaSize; + } + + private static ulong ComputeHeaderHashBits(ReadOnlySpan body) + { + (ulong bodyHash, _) = MurmurHash3.X64_128(body, TypeMetaConstants.TypeMetaHashSeed); + ulong shifted = bodyHash << TypeMetaConstants.TypeMetaHashShift; + long signed = unchecked((long)shifted); + long absSigned = signed == long.MinValue ? 
signed : Math.Abs(signed); + return unchecked((ulong)absSigned) & TypeMetaConstants.TypeMetaHashMask; + } + + private static void ValidateParsedTypeMetaHash(ulong header, ReadOnlySpan body) + { + ulong expectedHeaderHash = ComputeHeaderHashBits(body); + ulong actualHeaderHash = header & TypeMetaConstants.TypeMetaHashMask; + if (actualHeaderHash != expectedHeaderHash) + { + throw new InvalidDataException("TypeMeta metadata hash mismatch"); + } + } + + private static uint NonStructKindCode(uint typeId) + { + return (global::Apache.Fory.TypeId)typeId switch + { + global::Apache.Fory.TypeId.Enum => 0, + global::Apache.Fory.TypeId.NamedEnum => 1, + global::Apache.Fory.TypeId.Ext => 2, + global::Apache.Fory.TypeId.NamedExt => 3, + global::Apache.Fory.TypeId.TypedUnion => 4, + global::Apache.Fory.TypeId.NamedUnion => 5, + _ => throw new EncodingException($"unsupported TypeMeta kind {typeId}"), + }; + } + + private static uint NonStructTypeId(int kindCode) + { + return kindCode switch + { + 0 => (uint)global::Apache.Fory.TypeId.Enum, + 1 => (uint)global::Apache.Fory.TypeId.NamedEnum, + 2 => (uint)global::Apache.Fory.TypeId.Ext, + 3 => (uint)global::Apache.Fory.TypeId.NamedExt, + 4 => (uint)global::Apache.Fory.TypeId.TypedUnion, + 5 => (uint)global::Apache.Fory.TypeId.NamedUnion, + _ => throw new InvalidDataException($"unsupported TypeMeta kind code {kindCode}"), + }; + } + + private static bool IsStructKind(uint typeId) + { + return (global::Apache.Fory.TypeId)typeId is + global::Apache.Fory.TypeId.Struct or + global::Apache.Fory.TypeId.CompatibleStruct or + global::Apache.Fory.TypeId.NamedStruct or + global::Apache.Fory.TypeId.NamedCompatibleStruct; + } + + private static bool IsNamedKind(uint typeId) + { + return (global::Apache.Fory.TypeId)typeId is + global::Apache.Fory.TypeId.NamedStruct or + global::Apache.Fory.TypeId.NamedCompatibleStruct or + global::Apache.Fory.TypeId.NamedEnum or + global::Apache.Fory.TypeId.NamedExt or + global::Apache.Fory.TypeId.NamedUnion; } 
/// @@ -708,14 +809,40 @@ private static uint NormalizeTypeIdForMatch(uint typeId) private byte[] EncodeBody() { ByteWriter writer = new(128); - byte metaHeader = (byte)Math.Min(Fields.Count, TypeMetaConstants.SmallNumFieldsThreshold); - if (RegisterByName) + if (!TypeId.HasValue) + { + throw new EncodingException("type id is required"); + } + + bool isStruct = IsStructKind(TypeId.Value); + if (!isStruct && Fields.Count != 0) + { + throw new EncodingException("non-struct TypeMeta cannot carry field metadata"); + } + + byte metaHeader; + if (isStruct) { - metaHeader |= TypeMetaConstants.RegisterByNameFlag; + metaHeader = (byte)(TypeMetaConstants.StructFlag | + Math.Min(Fields.Count, TypeMetaConstants.SmallNumFieldsThreshold)); + if (TypeId.Value is (uint)global::Apache.Fory.TypeId.CompatibleStruct or + (uint)global::Apache.Fory.TypeId.NamedCompatibleStruct) + { + metaHeader |= TypeMetaConstants.CompatibleFlag; + } + + if (RegisterByName) + { + metaHeader |= TypeMetaConstants.RegisterByNameFlag; + } + } + else + { + metaHeader = (byte)NonStructKindCode(TypeId.Value); } writer.WriteUInt8(metaHeader); - if (Fields.Count >= TypeMetaConstants.SmallNumFieldsThreshold) + if (isStruct && Fields.Count >= TypeMetaConstants.SmallNumFieldsThreshold) { writer.WriteVarUInt32((uint)(Fields.Count - TypeMetaConstants.SmallNumFieldsThreshold)); } @@ -727,17 +854,11 @@ private byte[] EncodeBody() } else { - if (!TypeId.HasValue) - { - throw new EncodingException("type id is required in register-by-id mode"); - } - if (!UserTypeId.HasValue || UserTypeId == TypeMetaConstants.NoUserTypeId) { throw new EncodingException("user type id is required in register-by-id mode"); } - writer.WriteUInt8(unchecked((byte)TypeId.Value)); writer.WriteVarUInt32(UserTypeId.Value); } @@ -811,7 +932,6 @@ public bool Equals(TypeMeta? 
other) TypeName.Equals(other.TypeName) && RegisterByName == other.RegisterByName && Fields.SequenceEqual(other.Fields) && - HasFieldsMeta == other.HasFieldsMeta && Compressed == other.Compressed && HeaderHash == other.HeaderHash; } @@ -829,7 +949,6 @@ public override int GetHashCode() hc.Add(NamespaceName); hc.Add(TypeName); hc.Add(RegisterByName); - hc.Add(HasFieldsMeta); hc.Add(Compressed); hc.Add(HeaderHash); foreach (TypeMetaFieldInfo f in Fields) diff --git a/csharp/src/Fory/TypeResolver.cs b/csharp/src/Fory/TypeResolver.cs index 86b86e680f..91a3556304 100644 --- a/csharp/src/Fory/TypeResolver.cs +++ b/csharp/src/Fory/TypeResolver.cs @@ -937,6 +937,7 @@ private TypeInfo ReadNamedAnyTypeInfo(TypeId wireTypeId, ReadContext context) private TypeInfo ResolveAnyTypeInfoFromMeta(TypeId wireTypeId, TypeMeta typeMeta, bool compatible) { + ValidateTypeMetaWireType(typeMeta, wireTypeId); TypeInfo typeInfo = typeMeta.RegisterByName ? RequireRegisteredTypeInfoByName(typeMeta.NamespaceName.Value, typeMeta.TypeName.Value) : typeMeta.UserTypeId.HasValue @@ -1233,6 +1234,8 @@ private static void ValidateTypeMeta( bool compatible, TypeId actualWireTypeId) { + ValidateTypeMetaWireType(remoteTypeMeta, actualWireTypeId); + if (remoteTypeMeta.RegisterByName) { if (!localInfo.RegisterByName || !localInfo.NamespaceName.HasValue || !localInfo.TypeName.HasValue) @@ -1288,6 +1291,19 @@ private static void ValidateTypeMeta( } } + private static void ValidateTypeMetaWireType(TypeMeta remoteTypeMeta, TypeId actualWireTypeId) + { + if (!remoteTypeMeta.TypeId.HasValue || !IsKnownTypeId(remoteTypeMeta.TypeId.Value)) + { + throw new InvalidDataException("missing or unknown TypeMeta kind"); + } + + if ((TypeId)remoteTypeMeta.TypeId.Value != actualWireTypeId) + { + throw new TypeMismatchException((uint)actualWireTypeId, remoteTypeMeta.TypeId.Value); + } + } + private static void WriteMetaString( WriteContext context, MetaString value, diff --git a/csharp/tests/Fory.Tests/ForyRuntimeTests.cs 
b/csharp/tests/Fory.Tests/ForyRuntimeTests.cs index 610b44fe1c..b3faee5e7f 100644 --- a/csharp/tests/Fory.Tests/ForyRuntimeTests.cs +++ b/csharp/tests/Fory.Tests/ForyRuntimeTests.cs @@ -833,7 +833,7 @@ public void MacroFieldOrderFollowsForyRules() byte[] data = fory.Serialize(value); ByteReader reader = new(data); - _ = fory.ReadHead(reader); + fory.ReadHead(reader); _ = reader.ReadInt8(); _ = reader.ReadVarUInt32(); _ = reader.ReadVarUInt32(); @@ -1520,13 +1520,28 @@ public void CompatibleStructFastPathValidatesEmbeddedTypeMetaTypeId() writer.Register(200); byte[] payload = writer.Serialize(new OneStringField { F1 = "hello" }); - byte[] tamperedPayload = RewriteCompatibleTypeMetaTypeId(payload, (uint)TypeId.Map); + byte[] tamperedPayload = RewriteCompatibleTypeMetaTypeId(payload, (uint)TypeId.Struct); ForyRuntime reader = ForyRuntime.Builder().Compatible(true).Build(); reader.Register(200); Assert.Throws(() => reader.Deserialize(tamperedPayload)); } + [Fact] + public void CompatibleTypeMetaCacheMissValidatesBodyHashBeforeCaching() + { + ForyRuntime writer = ForyRuntime.Builder().Compatible(true).Build(); + writer.Register(200); + byte[] payload = writer.Serialize(new OneStringField { F1 = "hello" }); + byte[] tamperedPayload = CorruptCompatibleTypeMetaBody(payload); + + ForyRuntime reader = ForyRuntime.Builder().Compatible(true).Build(); + reader.Register(200); + InvalidDataException exception = + Assert.Throws(() => reader.Deserialize(tamperedPayload)); + Assert.Contains("TypeMeta metadata hash mismatch", exception.Message, StringComparison.Ordinal); + } + [Fact] public void TypeMetaAssignFieldIdsPrefersIdAndFallsBackToName() { @@ -1653,21 +1668,7 @@ public void TypeMetaAssignFieldIdsThrowsOnDuplicateRemoteFieldId() private static byte[] RewriteCompatibleTypeMetaTypeId(byte[] payload, uint embeddedTypeId) { - ByteReader reader = new(payload); - _ = reader.ReadUInt8(); // frame header bitmap - - sbyte refFlag = reader.ReadInt8(); - 
Assert.Equal((sbyte)RefFlag.NotNullValue, refFlag); - - uint wireTypeId = reader.ReadUInt8(); - Assert.Equal((uint)TypeId.CompatibleStruct, wireTypeId); - - uint typeMetaIndexMarker = reader.ReadVarUInt32(); - Assert.Equal(0u, typeMetaIndexMarker & 1u); - - int typeMetaStart = reader.Cursor; - TypeMeta originalTypeMeta = TypeMeta.Decode(reader); - int typeMetaEnd = reader.Cursor; + (int typeMetaStart, int typeMetaEnd, TypeMeta originalTypeMeta) = ReadCompatibleTypeMetaRange(payload); TypeMeta rewrittenTypeMeta = new( embeddedTypeId, @@ -1676,7 +1677,6 @@ private static byte[] RewriteCompatibleTypeMetaTypeId(byte[] payload, uint embed originalTypeMeta.TypeName, originalTypeMeta.RegisterByName, originalTypeMeta.Fields, - originalTypeMeta.HasFieldsMeta, originalTypeMeta.Compressed); byte[] rewrittenTypeMetaBytes = rewrittenTypeMeta.Encode(); @@ -1687,6 +1687,35 @@ private static byte[] RewriteCompatibleTypeMetaTypeId(byte[] payload, uint embed return rewrittenPayload; } + private static byte[] CorruptCompatibleTypeMetaBody(byte[] payload) + { + (int typeMetaStart, int typeMetaEnd, _) = ReadCompatibleTypeMetaRange(payload); + Assert.True(typeMetaEnd > typeMetaStart + sizeof(ulong)); + byte[] malformed = (byte[])payload.Clone(); + malformed[typeMetaEnd - 1] ^= 1; + return malformed; + } + + private static (int TypeMetaStart, int TypeMetaEnd, TypeMeta TypeMeta) ReadCompatibleTypeMetaRange(byte[] payload) + { + ByteReader reader = new(payload); + _ = reader.ReadUInt8(); // frame header bitmap + + sbyte refFlag = reader.ReadInt8(); + Assert.Equal((sbyte)RefFlag.NotNullValue, refFlag); + + uint wireTypeId = reader.ReadUInt8(); + Assert.Equal((uint)TypeId.CompatibleStruct, wireTypeId); + + uint typeMetaIndexMarker = reader.ReadVarUInt32(); + Assert.Equal(0u, typeMetaIndexMarker & 1u); + + int typeMetaStart = reader.Cursor; + TypeMeta originalTypeMeta = TypeMeta.Decode(reader); + int typeMetaEnd = reader.Cursor; + return (typeMetaStart, typeMetaEnd, originalTypeMeta); + } + 
private static (ulong Encoding, string Decoded) WriteAndReadString(string value) { ByteWriter writer = new(); diff --git a/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs b/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs index a74b21fd3a..171b4bf5d7 100644 --- a/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs +++ b/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs @@ -160,7 +160,7 @@ public void TimeSpanUsesVarIntSeconds() byte[] payload = fory.Serialize(TimeSpan.FromSeconds(1) + TimeSpan.FromTicks(3)); ByteReader reader = new(payload); - Assert.False(fory.ReadHead(reader)); + fory.ReadHead(reader); Assert.Equal((sbyte)RefFlag.NotNullValue, reader.ReadInt8()); Assert.Equal((uint)TypeId.Duration, reader.ReadUInt8()); Assert.Equal(1L, reader.ReadVarInt64()); @@ -175,7 +175,7 @@ public void DateOnlyUsesVarInt64Days() byte[] payload = fory.Serialize(new DateOnly(2021, 11, 23)); ByteReader reader = new(payload); - Assert.False(fory.ReadHead(reader)); + fory.ReadHead(reader); Assert.Equal((sbyte)RefFlag.NotNullValue, reader.ReadInt8()); Assert.Equal((uint)TypeId.Date, reader.ReadUInt8()); Assert.Equal(18_954L, reader.ReadVarInt64()); @@ -252,7 +252,7 @@ public void DecimalUsesCanonicalWireFormat() byte[] payload = fory.Serialize(new ForyDecimal(BigInteger.Zero, 2)); ByteReader reader = new(payload); - Assert.False(fory.ReadHead(reader)); + fory.ReadHead(reader); Assert.Equal((sbyte)RefFlag.NotNullValue, reader.ReadInt8()); Assert.Equal((uint)TypeId.Decimal, reader.ReadUInt8()); Assert.Equal(2, reader.ReadVarInt32()); @@ -261,7 +261,7 @@ public void DecimalUsesCanonicalWireFormat() payload = fory.Serialize(new ForyDecimal(BigInteger.Parse("9223372036854775808"), 0)); reader.Reset(payload); - Assert.False(fory.ReadHead(reader)); + fory.ReadHead(reader); Assert.Equal((sbyte)RefFlag.NotNullValue, reader.ReadInt8()); Assert.Equal((uint)TypeId.Decimal, reader.ReadUInt8()); Assert.Equal(0, reader.ReadVarInt32()); @@ -275,7 +275,7 @@ public void 
DecimalRejectsNonCanonicalBigPayload() { ForyRuntime fory = ForyRuntime.Builder().Build(); ByteWriter writer = new(); - fory.WriteHead(writer, isNone: false); + fory.WriteHead(writer); writer.WriteInt8((sbyte)RefFlag.NotNullValue); writer.WriteUInt8((byte)TypeId.Decimal); writer.WriteVarInt32(0); @@ -283,7 +283,7 @@ public void DecimalRejectsNonCanonicalBigPayload() _ = Assert.Throws(() => fory.Deserialize(writer.ToArray())); writer.Reset(); - fory.WriteHead(writer, isNone: false); + fory.WriteHead(writer); writer.WriteInt8((sbyte)RefFlag.NotNullValue); writer.WriteUInt8((byte)TypeId.Decimal); writer.WriteVarInt32(0); @@ -302,7 +302,7 @@ public void TimestampNormalizesNegativeFractionalSecond() byte[] payload = fory.Serialize(DateTimeOffset.FromUnixTimeMilliseconds(-1)); ByteReader reader = new(payload); - Assert.False(fory.ReadHead(reader)); + fory.ReadHead(reader); Assert.Equal((sbyte)RefFlag.NotNullValue, reader.ReadInt8()); Assert.Equal((uint)TypeId.Timestamp, reader.ReadUInt8()); Assert.Equal(-1L, reader.ReadInt64()); @@ -407,6 +407,34 @@ public void DeserializeRejectsNonXlangBitmap() Assert.Contains("xlang bitmap mismatch", exception.Message, StringComparison.Ordinal); } + [Fact] + public void SerializeNullRootUsesRefMeta() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + byte[] payload = fory.Serialize(null); + + Assert.Equal(ForyHeaderFlag.IsXlang, payload[0]); + Assert.Equal(unchecked((byte)(sbyte)RefFlag.Null), payload[1]); + Assert.Null(fory.Deserialize(payload)); + } + + [Fact] + public void DeserializeRejectsUnsupportedRootHeaderBits() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + byte[] payload = fory.Serialize(123); + + foreach (byte bitmap in new[] { (byte)0x03, (byte)0x05, (byte)0x81 }) + { + byte[] invalidPayload = [.. 
payload]; + invalidPayload[0] = bitmap; + + InvalidDataException exception = + Assert.Throws(() => fory.Deserialize(invalidPayload)); + Assert.Contains("unsupported root header bitmap", exception.Message, StringComparison.Ordinal); + } + } + [Fact] public void DynamicAnyRejectsUnknownUserTypeId() { diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index ebec1f320c..27dd038db9 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -233,9 +233,38 @@ final class ReadContext { /// Reads a root value using Ref semantics and expected root type [T]. Object? readRefAs() { - return _readRefWithResolved( - (resolved) => _typeResolver.resolveExpectedRootWireType(resolved), + final flag = _refReader.tryPreserveRefId(_buffer); + final preservedRefId = flag >= RefWriter.refValueFlag ? flag : null; + if (flag == RefWriter.nullFlag) { + return null; + } + if (flag == RefWriter.refFlag) { + return _refReader.getReadRef(); + } + final resolved = _typeResolver.resolveExpectedRootWireType( + _readTypeMeta(), ); + final rootPreservedRefId = preservedRefId == null && + flag == RefWriter.notNullValueFlag && + _depth == 0 && + resolved.supportsRef + ? _refReader.preserveRefId() + : null; + final value = readResolvedValue( + resolved, + null, + hasPreservedRef: preservedRefId != null || rootPreservedRefId != null, + ); + if (preservedRefId != null && + resolved.supportsRef && + _refReader.readRefAt(preservedRefId) == null) { + _refReader.setReadRef(preservedRefId, value); + } + if (rootPreservedRefId != null && + _refReader.readRefAt(rootPreservedRefId) == null) { + _refReader.setReadRef(rootPreservedRefId, value); + } + return value; } Object? 
_readRefWithResolved(TypeInfo Function(TypeInfo) resolveRootType) { diff --git a/dart/packages/fory/lib/src/fory.dart b/dart/packages/fory/lib/src/fory.dart index 80d8190486..96d917538a 100644 --- a/dart/packages/fory/lib/src/fory.dart +++ b/dart/packages/fory/lib/src/fory.dart @@ -39,9 +39,9 @@ import 'package:fory/src/serializer/serializer.dart'; /// /// The Dart runtime only supports xlang payloads. final class Fory { - static const int _nullHeaderFlag = 0x01; - static const int _xlangHeaderFlag = 0x02; - static const int _outOfBandHeaderFlag = 0x04; + static const int _xlangHeaderFlag = 0x01; + static const int _outOfBandHeaderFlag = 0x02; + static const int _knownHeaderFlags = _xlangHeaderFlag | _outOfBandHeaderFlag; late final Buffer _readBuffer; late final Buffer _writeBuffer; @@ -124,11 +124,11 @@ final class Fory { buffer.clear(); _writeContext.prepare(buffer, trackRef: trackRef); try { + buffer.writeUint8(_xlangHeaderFlag); if (value == null) { - buffer.writeUint8(_nullHeaderFlag); + _writeContext.writeRootValue(null, trackRef: trackRef); return; } - buffer.writeUint8(_xlangHeaderFlag); _writeContext.writeRootValue(value, trackRef: trackRef); } finally { _writeContext.reset(); @@ -183,18 +183,8 @@ final class Fory { _readContext.prepare(buffer); try { final header = buffer.readUint8(); - if ((header & _outOfBandHeaderFlag) != 0) { - throw StateError( - 'Out-of-band buffers are not supported by the Dart runtime.', - ); - } - if ((header & _nullHeaderFlag) != 0) { - return null as T; - } - if ((header & _xlangHeaderFlag) == 0) { - throw StateError( - 'Only xlang payloads are supported by the Dart runtime.', - ); + if (header != _xlangHeaderFlag) { + _throwInvalidRootHeader(header); } final value = _readContext.readRefAs(); if (value is T) { @@ -208,6 +198,23 @@ final class Fory { } } + @pragma('vm:never-inline') + Never _throwInvalidRootHeader(int header) { + if ((header & ~_knownHeaderFlags) != 0) { + throw StateError( + 'Unsupported root header bitmap 
0x${header.toRadixString(16)}.', + ); + } + if ((header & _outOfBandHeaderFlag) != 0) { + throw StateError( + 'Out-of-band buffers are not supported by the Dart runtime.', + ); + } + throw StateError( + 'Only xlang payloads are supported by the Dart runtime.', + ); + } + /// Registers a generated type. /// /// Exactly one registration mode is required: diff --git a/dart/packages/fory/lib/src/meta/meta_string.dart b/dart/packages/fory/lib/src/meta/meta_string.dart index abd60d517a..7089d0071d 100644 --- a/dart/packages/fory/lib/src/meta/meta_string.dart +++ b/dart/packages/fory/lib/src/meta/meta_string.dart @@ -32,6 +32,8 @@ const int metaStringAllToLowerSpecialEncoding = 4; const int metaStringSmallThreshold = 16; const int typeDefSmallFieldCountThreshold = 0x1f; const int typeDefRegisterByNameFlag = 0x20; +const int typeDefCompatibleFlag = 0x40; +const int typeDefStructFlag = 0x80; const int typeDefBigFieldNameThreshold = 0x0f; const int typeDefBigNameThreshold = 0x3f; diff --git a/dart/packages/fory/lib/src/meta/type_meta.dart b/dart/packages/fory/lib/src/meta/type_meta.dart index dc8a2df5e1..0602f040be 100644 --- a/dart/packages/fory/lib/src/meta/type_meta.dart +++ b/dart/packages/fory/lib/src/meta/type_meta.dart @@ -18,6 +18,7 @@ */ import 'dart:collection'; +import 'dart:typed_data'; import 'package:fory/src/memory/buffer.dart'; import 'package:fory/src/config.dart'; @@ -25,6 +26,7 @@ import 'package:fory/src/meta/meta_string.dart'; import 'package:fory/src/meta/type_ids.dart'; import 'package:fory/src/resolver/type_resolver.dart'; import 'package:fory/src/types/int64.dart'; +import 'package:fory/src/util/hash_util.dart'; /// Wire-level type metadata for one value. 
final class WireTypeMeta { @@ -53,10 +55,24 @@ final class WireTypeMeta { } final class TypeHeader { + static const int _compressMetaFlag = 1 << 8; + static const int _reservedMetaFlags = 0x0e00; + static const int _hashLow32Mask = 0xfffff000; + final Int64 value; const TypeHeader(this.value); + @pragma('vm:prefer-inline') + void validateGlobal() { + if ((value.low32 & _reservedMetaFlags) != 0) { + throw StateError('Invalid TypeDef global header.'); + } + if ((value.low32 & _compressMetaFlag) != 0) { + throw StateError('Compressed TypeDef metadata is not supported.'); + } + } + @pragma('vm:prefer-inline') int readMetaSize(Buffer buffer) { final lowBits = value.low32 & 0xff; @@ -70,6 +86,15 @@ final class TypeHeader { void skipRemaining(Buffer buffer) { buffer.skip(readMetaSize(buffer)); } + + @pragma('vm:prefer-inline') + void validateBodyHash(Uint8List body) { + final expected = typeDefHeader(body); + if (value.high32Unsigned != expected.high32Unsigned || + (value.low32 & _hashLow32Mask) != (expected.low32 & _hashLow32Mask)) { + throw StateError('Invalid TypeDef metadata hash.'); + } + } } final class ParsedTypeMetaCache { @@ -109,7 +134,7 @@ final class WireTypeMetaEncoder { const WireTypeMetaEncoder(); WireTypeMeta typeMetaFor(Config config, TypeInfo resolvedType) { - final wireTypeId = _wireTypeIdFor(config, resolvedType); + final wireTypeId = wireTypeIdFor(config, resolvedType); final writesTypeDef = wireTypeId == TypeIds.compatibleStruct || wireTypeId == TypeIds.namedCompatibleStruct || (config.compatible && @@ -146,7 +171,8 @@ final class WireTypeMetaEncoder { } } - int _wireTypeIdFor(Config config, TypeInfo resolvedType) { + @pragma('vm:prefer-inline') + int wireTypeIdFor(Config config, TypeInfo resolvedType) { switch (resolvedType.kind) { case RegistrationKind.builtin: return resolvedType.typeId; diff --git a/dart/packages/fory/lib/src/resolver/type_resolver.dart b/dart/packages/fory/lib/src/resolver/type_resolver.dart index cc7a87c0f8..40be46571b 100644 
--- a/dart/packages/fory/lib/src/resolver/type_resolver.dart +++ b/dart/packages/fory/lib/src/resolver/type_resolver.dart @@ -55,6 +55,57 @@ import 'package:fory/src/util/hash_util.dart'; enum RegistrationKind { builtin, struct, enumType, ext, union } +bool _isStructTypeDefKind(int typeId) => + typeId == TypeIds.struct || + typeId == TypeIds.compatibleStruct || + typeId == TypeIds.namedStruct || + typeId == TypeIds.namedCompatibleStruct; + +bool _isNamedTypeDefKind(int typeId) => + typeId == TypeIds.namedStruct || + typeId == TypeIds.namedCompatibleStruct || + typeId == TypeIds.namedEnum || + typeId == TypeIds.namedExt || + typeId == TypeIds.namedUnion; + +int _nonStructTypeDefKindCode(int typeId) { + switch (typeId) { + case TypeIds.enumById: + return 0; + case TypeIds.namedEnum: + return 1; + case TypeIds.ext: + return 2; + case TypeIds.namedExt: + return 3; + case TypeIds.typedUnion: + return 4; + case TypeIds.namedUnion: + return 5; + default: + throw StateError('Unsupported TypeDef kind $typeId.'); + } +} + +int _nonStructTypeDefTypeId(int kindCode) { + switch (kindCode) { + case 0: + return TypeIds.enumById; + case 1: + return TypeIds.namedEnum; + case 2: + return TypeIds.ext; + case 3: + return TypeIds.namedExt; + case 4: + return TypeIds.typedUnion; + case 5: + return TypeIds.namedUnion; + default: + throw StateError('Unsupported TypeDef kind code $kindCode.'); + } +} + final class TypeInfo { final Type type; final RegistrationKind kind; @@ -588,25 +639,47 @@ final class TypeResolver { required LinkedHashMap typeDefIds, required MetaStringWriter metaStringWriter, }) { - _wireTypeMetaEncoder.write( - buffer, - wireTypeMetaForResolved(resolved), - writeTypeDef: (wireTypeMeta) => _writeTypeDef( - buffer, - typeDef ?? 
wireTypeMeta.resolvedType.typeDef!, - typeDefIds: typeDefIds, - ), - writePackageMetaString: (value) => metaStringWriter.writeMetaString( - buffer, - value, - ), - writeTypeNameMetaString: (value) => metaStringWriter.writeMetaString( - buffer, - value, - ), - ); + final wireTypeId = _wireTypeMetaEncoder.wireTypeIdFor(config, resolved); + buffer.writeVarUint32Small7(wireTypeId); + if (_wireTypeWritesUserTypeId(wireTypeId)) { + buffer.writeVarUint32(resolved.userTypeId!); + return; + } + if (_wireTypeWritesTypeDef(wireTypeId)) { + _writeTypeDef(buffer, typeDef ?? resolved.typeDef!, + typeDefIds: typeDefIds); + return; + } + if (_wireTypeWritesNamedType(wireTypeId)) { + metaStringWriter.writeMetaString(buffer, resolved.encodedNamespace!); + metaStringWriter.writeMetaString(buffer, resolved.encodedTypeName!); + } } + @pragma('vm:prefer-inline') + bool _wireTypeWritesUserTypeId(int wireTypeId) => + wireTypeId == TypeIds.enumById || + wireTypeId == TypeIds.struct || + wireTypeId == TypeIds.ext || + wireTypeId == TypeIds.typedUnion; + + @pragma('vm:prefer-inline') + bool _wireTypeWritesTypeDef(int wireTypeId) => + wireTypeId == TypeIds.compatibleStruct || + wireTypeId == TypeIds.namedCompatibleStruct || + (config.compatible && + (wireTypeId == TypeIds.namedEnum || + wireTypeId == TypeIds.namedStruct || + wireTypeId == TypeIds.namedExt || + wireTypeId == TypeIds.namedUnion)); + + @pragma('vm:prefer-inline') + bool _wireTypeWritesNamedType(int wireTypeId) => + wireTypeId == TypeIds.namedEnum || + wireTypeId == TypeIds.namedStruct || + wireTypeId == TypeIds.namedExt || + wireTypeId == TypeIds.namedUnion; + @pragma('vm:prefer-inline') TypeInfo readTypeMeta( Buffer buffer, { @@ -671,6 +744,7 @@ final class TypeResolver { }) { final encoded = _encodeTypeDef( kind: kind, + evolving: evolving, userTypeId: userTypeId, encodedNamespace: encodedNamespace, encodedTypeName: encodedTypeName, @@ -687,24 +761,44 @@ final class TypeResolver { Uint8List _encodeTypeDef({ required 
RegistrationKind kind, + required bool evolving, required int? userTypeId, required EncodedMetaString? encodedNamespace, required EncodedMetaString? encodedTypeName, required List fields, }) { final metaBuffer = Buffer(); - var classHeader = fields.length; + final byName = userTypeId == null && + encodedNamespace != null && + encodedTypeName != null; + final typeId = _typeDefTypeId(kind, byName: byName, evolving: evolving); + if (!_isStructTypeDefKind(typeId) && fields.isNotEmpty) { + throw StateError( + 'Non-struct TypeDef $typeId cannot carry field metadata.'); + } + var classHeader = 0; metaBuffer.writeByte(0xff); - if (fields.length >= typeDefSmallFieldCountThreshold) { - classHeader = typeDefSmallFieldCountThreshold; - metaBuffer.writeVarUint32Small7( - fields.length - typeDefSmallFieldCountThreshold, - ); + if (_isStructTypeDefKind(typeId)) { + final inlineFieldCount = fields.length >= typeDefSmallFieldCountThreshold + ? typeDefSmallFieldCountThreshold + : fields.length; + classHeader = typeDefStructFlag | inlineFieldCount; + if (typeId == TypeIds.compatibleStruct || + typeId == TypeIds.namedCompatibleStruct) { + classHeader |= typeDefCompatibleFlag; + } + if (fields.length >= typeDefSmallFieldCountThreshold) { + metaBuffer.writeVarUint32Small7( + fields.length - typeDefSmallFieldCountThreshold, + ); + } + if (byName) { + classHeader |= typeDefRegisterByNameFlag; + } + } else { + classHeader = _nonStructTypeDefKindCode(typeId); } - if (userTypeId == null && - encodedNamespace != null && - encodedTypeName != null) { - classHeader |= typeDefRegisterByNameFlag; + if (byName) { _writeTypeDefName( metaBuffer, encodedNamespace.bytes, @@ -720,17 +814,18 @@ final class TypeResolver { ), ); } else { - metaBuffer.writeUint8(_typeDefTypeId(kind)); metaBuffer.writeVarUint32(userTypeId!); } metaBuffer.toBytes()[0] = classHeader; - for (final field in fields) { - _writeTypeDefField(metaBuffer, field); + if (_isStructTypeDefKind(typeId)) { + for (final field in fields) { + 
_writeTypeDefField(metaBuffer, field); + } } final body = metaBuffer.toBytes(); final buffer = Buffer(); buffer.writeInt64( - typeDefHeader(body, hasFieldsMeta: fields.isNotEmpty), + typeDefHeader(body), ); if (body.length >= 0xff) { buffer.writeVarUint32(body.length - 0xff); @@ -805,16 +900,23 @@ final class TypeResolver { target.writeBytes(bytes); } - int _typeDefTypeId(RegistrationKind kind) { + int _typeDefTypeId( + RegistrationKind kind, { + required bool byName, + required bool evolving, + }) { switch (kind) { case RegistrationKind.struct: - return TypeIds.struct; + if (byName) { + return evolving ? TypeIds.namedCompatibleStruct : TypeIds.namedStruct; + } + return evolving ? TypeIds.compatibleStruct : TypeIds.struct; case RegistrationKind.enumType: - return TypeIds.enumById; + return byName ? TypeIds.namedEnum : TypeIds.enumById; case RegistrationKind.ext: - return TypeIds.ext; + return byName ? TypeIds.namedExt : TypeIds.ext; case RegistrationKind.union: - return TypeIds.typedUnion; + return byName ? TypeIds.namedUnion : TypeIds.typedUnion; case RegistrationKind.builtin: throw StateError('Built-in types do not write TypeDef metadata.'); } @@ -849,12 +951,16 @@ final class TypeResolver { final header = TypeHeader(buffer.readInt64()); final expectedTypeDef = expectedType?.typeDef; if (expectedTypeDef != null && expectedTypeDef.header == header.value) { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit body-hash validation. header.skipRemaining(buffer); sharedTypes.add(expectedType!); return wireTypeMetaForResolved(expectedType); } final cached = _parsedTypeMetaCache.lookup(header); if (cached != null) { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit body-hash validation. 
header.skipRemaining(buffer); sharedTypes.add(cached); return wireTypeMetaForResolved(cached); @@ -866,33 +972,69 @@ final class TypeResolver { } TypeInfo _readTypeDefWithHeader(Buffer buffer, TypeHeader header) { + header.validateGlobal(); final metaSize = header.readMetaSize(buffer); - final metaBytes = Buffer.wrap(buffer.readBytes(metaSize)); + final metaBody = buffer.readBytes(metaSize); + final metaBytes = Buffer.wrap(metaBody); final classHeader = metaBytes.readUint8(); - var fieldCount = classHeader & typeDefSmallFieldCountThreshold; - if (fieldCount == typeDefSmallFieldCountThreshold) { - fieldCount += metaBytes.readVarUint32Small7(); + final isStruct = (classHeader & typeDefStructFlag) != 0; + var fieldCount = 0; + bool byName; + int typeId; + if (isStruct) { + byName = (classHeader & typeDefRegisterByNameFlag) != 0; + final compatible = (classHeader & typeDefCompatibleFlag) != 0; + if (byName) { + typeId = + compatible ? TypeIds.namedCompatibleStruct : TypeIds.namedStruct; + } else { + typeId = compatible ? TypeIds.compatibleStruct : TypeIds.struct; + } + fieldCount = classHeader & typeDefSmallFieldCountThreshold; + if (fieldCount == typeDefSmallFieldCountThreshold) { + fieldCount += metaBytes.readVarUint32Small7(); + } + } else { + if ((classHeader & 0x70) != 0) { + throw StateError('Invalid TypeDef kind header.'); + } + typeId = _nonStructTypeDefTypeId(classHeader & 0x0f); + byName = _isNamedTypeDefKind(typeId); } - final byName = (classHeader & typeDefRegisterByNameFlag) != 0; final encodedNamespace = byName ? _readTypeDefName(metaBytes, packageNameEncoding) : null; final encodedTypeName = byName ? _readTypeDefName(metaBytes, typeNameEncoding) : null; int? 
userTypeId; if (!byName) { - metaBytes.readUint8(); userTypeId = metaBytes.readVarUint32(); } final fields = []; for (var i = 0; i < fieldCount; i += 1) { fields.add(_readTypeDefField(metaBytes)); } + if (!isStruct && fields.isNotEmpty) { + throw StateError('Non-struct TypeDef cannot carry field metadata.'); + } + if (metaBytes.readableBytes != 0) { + throw StateError('Invalid TypeDef metadata size.'); + } + header.validateBodyHash(metaBody); final resolved = userTypeId != null ? resolveUserById(userTypeId) : resolveUserByEncodedName( encodedNamespace!, encodedTypeName!, ); + if (resolved.typeDef?.header != header.value && + _typeDefTypeId( + resolved.kind, + byName: byName, + evolving: resolved.typeDef?.evolving ?? false, + ) != + typeId) { + throw StateError('TypeDef kind does not match registered type metadata.'); + } if (resolved.kind != RegistrationKind.struct) { return resolved; } diff --git a/dart/packages/fory/lib/src/util/hash_util.dart b/dart/packages/fory/lib/src/util/hash_util.dart index c08372b277..834f6d620d 100644 --- a/dart/packages/fory/lib/src/util/hash_util.dart +++ b/dart/packages/fory/lib/src/util/hash_util.dart @@ -25,10 +25,9 @@ import 'package:fory/src/meta/type_ids.dart'; import 'package:fory/src/types/int64.dart'; import 'package:fory/src/types/uint64.dart'; -const int _typeDefCompressMetaFlag = 1 << 9; -const int _typeDefHasFieldsMetaFlag = 1 << 8; +const int _typeDefCompressMetaFlag = 1 << 8; const int _typeDefMetaSizeMask = 0xff; -const int _typeDefHashShift = 14; +const int _typeDefHashShift = 12; final Uint64 _metaStringHashMask = Uint64.fromWords(0xffffff00, 0xffffffff); final Uint64 _c1 = Uint64.fromWords(0x114253d5, 0x87c37b91); @@ -154,7 +153,6 @@ Int64 metaStringHash(List bytes, {int encoding = 0}) { Int64 typeDefHeader( List bytes, { - required bool hasFieldsMeta, bool compressed = false, }) { final hash = _int64FromUint64( @@ -164,9 +162,6 @@ Int64 typeDefHeader( if (compressed) { header = header | _typeDefCompressMetaFlag; } - 
if (hasFieldsMeta) { - header = header | _typeDefHasFieldsMetaFlag; - } header = header | (bytes.length > _typeDefMetaSizeMask ? _typeDefMetaSizeMask diff --git a/dart/packages/fory/test/decimal_serializer_test.dart b/dart/packages/fory/test/decimal_serializer_test.dart index ed009b5ad4..6c42175aa1 100644 --- a/dart/packages/fory/test/decimal_serializer_test.dart +++ b/dart/packages/fory/test/decimal_serializer_test.dart @@ -88,14 +88,14 @@ void main() { test('rejects non-canonical big decimal payloads', () { final fory = Fory(); final zeroBigEncoding = Uint8List.fromList([ - 0x02, + 0x01, 0xff, TypeIds.decimal, 0x00, 0x01, ]); final trailingZeroPayload = Uint8List.fromList([ - 0x02, + 0x01, 0xff, TypeIds.decimal, 0x00, diff --git a/dart/packages/fory/test/time_serializer_test.dart b/dart/packages/fory/test/time_serializer_test.dart index 8715f845df..0da5524e26 100644 --- a/dart/packages/fory/test/time_serializer_test.dart +++ b/dart/packages/fory/test/time_serializer_test.dart @@ -98,7 +98,7 @@ void main() { final bytes = fory.serialize(value); expect( - bytes, equals(Uint8List.fromList([0x02, 0xff, TypeIds.date, 0x01]))); + bytes, equals(Uint8List.fromList([0x01, 0xff, TypeIds.date, 0x01]))); expect(fory.deserialize(bytes), equals(value)); }); diff --git a/dart/packages/fory/test/xlang_protocol_test.dart b/dart/packages/fory/test/xlang_protocol_test.dart index be3b9bb763..714edd2675 100644 --- a/dart/packages/fory/test/xlang_protocol_test.dart +++ b/dart/packages/fory/test/xlang_protocol_test.dart @@ -20,13 +20,15 @@ import 'dart:typed_data'; import 'package:fory/fory.dart'; +import 'package:fory/src/meta/type_meta.dart'; +import 'package:fory/src/util/hash_util.dart'; import 'package:test/test.dart'; void main() { group('xlang protocol regressions', () { test('deserializes NONE wire values as null', () { final fory = Fory(); - final bytes = Uint8List.fromList([0x02, 0xff, TypeIds.none]); + final bytes = Uint8List.fromList([0x01, 0xff, TypeIds.none]); 
expect(fory.deserialize(bytes), isNull); expect(fory.deserialize(bytes), isNull); @@ -77,7 +79,7 @@ void main() { final fory = Fory(); final bytes = fory.serializeBuiltin(7, wireTypeId: TypeIds.varInt32); - expect(bytes[0], equals(0x02)); + expect(bytes[0], equals(0x01)); expect(bytes[1], equals(0xff)); expect(bytes[2], equals(TypeIds.varInt32)); expect(fory.deserialize(bytes), equals(7)); @@ -86,7 +88,7 @@ void main() { test('rejects out-of-band xlang payload headers', () { final fory = Fory(); final bytes = Uint8List.fromList(fory.serialize('value')); - bytes[0] |= 0x04; + bytes[0] |= 0x02; expect( () => fory.deserialize(bytes), @@ -99,5 +101,26 @@ void main() { ), ); }); + + test('validates parsed TypeDef body hash before caching', () { + final body = Uint8List.fromList([0x80]); + final header = TypeHeader(typeDefHeader(body)); + final valid = Buffer.wrap(body); + header.skipRemaining(valid); + expect(valid.readableBytes, equals(0)); + + final malformed = Uint8List.fromList(body); + malformed[0] ^= 1; + expect( + () => header.validateBodyHash(malformed), + throwsA( + isA().having( + (error) => error.toString(), + 'message', + contains('metadata hash'), + ), + ), + ); + }); }); } diff --git a/docs/specification/java_serialization_spec.md b/docs/specification/java_serialization_spec.md index 2bac7bc0da..c52695f1cc 100644 --- a/docs/specification/java_serialization_spec.md +++ b/docs/specification/java_serialization_spec.md @@ -44,14 +44,14 @@ Java native serialization writes a one byte bitmap header. The header layout mir bitmap and uses the same flag bits. ``` -| 5 bits | 1 bit | 1 bit | 1 bit | -+--------------+-------+-------+-------+ -| reserved | oob | xlang | null | +| 6 bits | 1 bit | 1 bit | ++---------------+-------+-------+ +| reserved | oob | xlang | ``` -- null flag: 1 when object is null, 0 otherwise. If object is null, other bits are not set. -- xlang flag: 1 when serialization uses xlang format, 0 when serialization uses Java native format. 
-- oob flag: 1 when `BufferCallback` is not null, 0 otherwise. +- xlang flag: bit 0, set when serialization uses xlang format and clear for Java native format. +- oob flag: bit 1, set when `BufferCallback` is not null. +- reserved bits: bits 2-7, must be zero. The header is always a single byte; no language ID is written. @@ -202,32 +202,37 @@ when shared meta is enabled, or referenced by index when already seen. Header layout (lower bits on the right): ``` -| 50-bit hash | 4 bits reserved | 1 bit compress | 1 bit has_fields_meta | 8-bit size | +| 52-bit hash | 3 bits reserved | 1 bit compress | 8-bit size | ``` - size: lower 8 bits. If size equals the mask (0xFF), write extra size as varuint32 and add it. -- compress: set when payload is compressed. -- has_fields_meta: set when field metadata is present. -- reserved: bits 10-13 are reserved for future use and must be zero. -- hash: 50-bit hash of the payload and flags. +- compress: bit 8, set when payload is compressed. +- reserved: bits 9-11 are reserved for future use and must be zero. +- hash: 52-bit hash of the payload. ### Class meta bytes Class meta encodes a linearized class hierarchy (from parent to leaf) and field metadata: ``` -| num_classes | class_layer_0 | class_layer_1 | ... | +| root_kind_and_num_classes | class_layer_0 | class_layer_1 | ... | class_layer: | num_fields << 1 | registered_flag | [type_id if registered] | | namespace | type_name | field_infos | ``` -- `num_classes` stores `(num_layers - 1)` in a single byte. - - If it equals `0b1111`, read an extra varuint32 small7 and add it. +- `root_kind_and_num_classes` stores the root TypeDef kind in the high four bits and + `(num_layers - 1)` in the low four bits. + - Root kind codes are `STRUCT=0`, `COMPATIBLE_STRUCT=1`, `NAMED_STRUCT=2`, + `NAMED_COMPATIBLE_STRUCT=3`, `ENUM=4`, `NAMED_ENUM=5`, `EXT=6`, `NAMED_EXT=7`, + `TYPED_UNION=8`, and `NAMED_UNION=9`. 
+ - Kind codes `10-14` are reserved and `15` is an extended-kind escape rejected until defined. + - If the low four bits equal `0b1111`, read an extra varuint32 small7 and add it. - The actual number of layers is `num_classes + 1`. - `registered_flag` is 1 if the class is registered by numeric ID. -- If registered by ID, the class type ID follows (varuint32 small7). +- If registered by ID, the one-byte class type ID follows. For user-registered ID kinds, the + user type ID follows as varuint32. - If registered by name or unregistered, namespace and type name are written as meta strings. ### Field info diff --git a/docs/specification/xlang_serialization_spec.md b/docs/specification/xlang_serialization_spec.md index 24adbe3331..b1743918a9 100644 --- a/docs/specification/xlang_serialization_spec.md +++ b/docs/specification/xlang_serialization_spec.md @@ -323,15 +323,14 @@ Detailed byte layout: ``` Byte 0: Bitmap flags - - Bit 0: null flag (0x01) - - Bit 1: xlang flag (0x02) - - Bit 2: oob flag (0x04) - - Bits 3-7: reserved + - Bit 0: xlang flag (0x01) + - Bit 1: oob flag (0x02) + - Bits 2-7: reserved ``` -- **null flag** (bit 0): 1 when object is null, 0 otherwise. If an object is null, only this flag is set. -- **xlang flag** (bit 1): 1 when serialization uses Fory xlang format, 0 when serialization uses Fory language-native format. -- **oob flag** (bit 2): 1 when out-of-band serialization is enabled (BufferCallback is not null), 0 otherwise. +- **xlang flag** (bit 0): 1 when serialization uses Fory xlang format, 0 when serialization uses Fory language-native format. +- **oob flag** (bit 1): 1 when out-of-band serialization is enabled (BufferCallback is not null), 0 otherwise. +- **reserved bits** (bits 2-7): must be zero. All data is encoded in little-endian format. @@ -536,12 +535,11 @@ The 8-byte header is a little-endian uint64: - Low 8 bits: meta size (number of bytes in the TypeDef body). 
- If meta size >= 0xFF, the low 8 bits are set to 0xFF and an extra `varuint32(meta_size - 0xFF)` follows immediately after the header. -- Bit 8: `HAS_FIELDS_META` (1 = fields metadata present). -- Bit 9: `COMPRESS_META` is reserved for a future xlang metadata-compression extension. +- Bit 8: `COMPRESS_META` is reserved for a future xlang metadata-compression extension. Current xlang writers MUST leave this bit unset and current xlang readers MUST treat a set bit as unsupported. -- Bits 10-13: reserved for future extension (must be zero). -- High 50 bits: hash of the TypeDef body. +- Bits 9-11: reserved for future extension (must be zero). +- High 52 bits: hash of the TypeDef body. #### TypeDef body @@ -551,12 +549,30 @@ TypeDef body has a single layer (fields are flattened in class hierarchy order): | meta header (1 byte) | type spec | field info ... | ``` -Meta header byte: +Meta header byte for struct TypeDefs: +- Bit 7: `IS_STRUCT` (1). +- Bit 6: `COMPATIBLE`. +- Bit 5: `REGISTER_BY_NAME` (1 = namespace + type name, 0 = numeric user type ID). - Bits 0-4: `num_fields` (0-30). - If `num_fields == 31`, read an extra `varuint32` and add it. -- Bit 5: `REGISTER_BY_NAME` (1 = namespace + type name, 0 = numeric type ID). -- Bits 6-7: reserved. + +Meta header byte for non-struct TypeDefs: + +- Bit 7: `IS_STRUCT` (0). +- Bits 4-6: reserved (must be zero). +- Bits 0-3: kind code. 
+ +Non-struct kind codes: + +- `0`: `ENUM` +- `1`: `NAMED_ENUM` +- `2`: `EXT` +- `3`: `NAMED_EXT` +- `4`: `TYPED_UNION` +- `5`: `NAMED_UNION` +- `6-14`: reserved +- `15`: extended-kind escape, rejected until defined Type spec: @@ -564,7 +580,7 @@ Type spec: - `namespace` meta string - `type_name` meta string - Otherwise: - - `type_id` as `varuint32` (small7) + - user type ID as `varuint32` Field info list: diff --git a/go/fory/decimal.go b/go/fory/decimal.go index 92d6e15bea..5bc3a0d540 100644 --- a/go/fory/decimal.go +++ b/go/fory/decimal.go @@ -77,8 +77,8 @@ func (s decimalSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(DECIMAL) { + return } if ctx.HasError() { return diff --git a/go/fory/enum.go b/go/fory/enum.go index 1a1e63c617..b8f615b977 100644 --- a/go/fory/enum.go +++ b/go/fory/enum.go @@ -85,7 +85,27 @@ func (s *enumSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, } } if readType { - _ = ctx.buffer.ReadUint8(err) + typeID := uint32(ctx.buffer.ReadUint8(err)) + if ctx.HasError() { + return + } + internalID := TypeId(typeID) + if internalID != ENUM && internalID != NAMED_ENUM { + ctx.SetError(TypeMismatchError(internalID, ENUM)) + return + } + typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, err) + if ctx.HasError() { + return + } + if typeInfo == nil || typeInfo.Type != s.type_ { + var actualType reflect.Type + if typeInfo != nil { + actualType = typeInfo.Type + } + ctx.SetError(DeserializationErrorf("enum type mismatch: expected %v, got %v", s.type_, actualType)) + return + } } if ctx.HasError() { return diff --git a/go/fory/enum_test.go b/go/fory/enum_test.go new file mode 100644 index 0000000000..f92c4d086d --- /dev/null +++ b/go/fory/enum_test.go @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +type auditEnum int32 +type otherAuditEnum int32 +type namedAuditEnum int32 + +func TestEnumReadConsumesRegisteredTypeInfo(t *testing.T) { + f := NewFory(WithXlang(true)) + require.NoError(t, f.RegisterEnum(auditEnum(0), 101)) + enumType := reflect.TypeOf(auditEnum(0)) + serializer, err := f.typeResolver.getSerializerByType(enumType, false) + require.NoError(t, err) + typeInfo := f.typeResolver.typesInfo[enumType] + require.NotNil(t, typeInfo) + + buf := NewByteBuffer(nil) + bufErr := &Error{} + f.typeResolver.WriteTypeInfo(buf, typeInfo, bufErr) + require.NoError(t, bufErr.CheckError()) + buf.WriteVarUint32Small7(2) + + f.readCtx.SetData(buf.Bytes()) + var result auditEnum + serializer.Read(f.readCtx, RefModeNone, true, false, reflect.ValueOf(&result).Elem()) + require.NoError(t, f.readCtx.CheckError()) + require.Equal(t, auditEnum(2), result) + require.Equal(t, buf.WriterIndex(), f.readCtx.Buffer().ReaderIndex()) +} + +func TestEnumReadRejectsMismatchedRegisteredTypeInfo(t *testing.T) { + f := NewFory(WithXlang(true)) + require.NoError(t, f.RegisterEnum(auditEnum(0), 101)) + require.NoError(t, f.RegisterEnum(otherAuditEnum(0), 102)) + enumType := reflect.TypeOf(auditEnum(0)) 
+ otherType := reflect.TypeOf(otherAuditEnum(0)) + serializer, err := f.typeResolver.getSerializerByType(enumType, false) + require.NoError(t, err) + otherTypeInfo := f.typeResolver.typesInfo[otherType] + require.NotNil(t, otherTypeInfo) + + buf := NewByteBuffer(nil) + bufErr := &Error{} + f.typeResolver.WriteTypeInfo(buf, otherTypeInfo, bufErr) + require.NoError(t, bufErr.CheckError()) + buf.WriteVarUint32Small7(2) + + f.readCtx.SetData(buf.Bytes()) + var result auditEnum + serializer.Read(f.readCtx, RefModeNone, true, false, reflect.ValueOf(&result).Elem()) + require.Error(t, f.readCtx.CheckError()) +} + +func TestNamedEnumReadConsumesNamedTypeInfo(t *testing.T) { + f := NewFory(WithXlang(true)) + require.NoError(t, f.RegisterNamedEnum(namedAuditEnum(0), "example.NamedAuditEnum")) + enumType := reflect.TypeOf(namedAuditEnum(0)) + serializer, err := f.typeResolver.getSerializerByType(enumType, false) + require.NoError(t, err) + typeInfo := f.typeResolver.typesInfo[enumType] + require.NotNil(t, typeInfo) + + buf := NewByteBuffer(nil) + bufErr := &Error{} + f.typeResolver.WriteTypeInfo(buf, typeInfo, bufErr) + require.NoError(t, bufErr.CheckError()) + buf.WriteVarUint32Small7(3) + + f.readCtx.SetData(buf.Bytes()) + var result namedAuditEnum + serializer.Read(f.readCtx, RefModeNone, true, false, reflect.ValueOf(&result).Elem()) + require.NoError(t, f.readCtx.CheckError()) + require.Equal(t, namedAuditEnum(3), result) + require.Equal(t, buf.WriterIndex(), f.readCtx.Buffer().ReaderIndex()) +} diff --git a/go/fory/fory.go b/go/fory/fory.go index 61723d190d..f50fe1f770 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -39,9 +39,9 @@ var ErrNoSerializer = errors.New("fory: no serializer registered for type") // Bitmap flags for protocol header const ( - IsNilFlag = 1 << 0 - XLangFlag = 1 << 1 - OutOfBandFlag = 1 << 2 + XLangFlag = 1 << 0 + OutOfBandFlag = 1 << 1 + headerFlagMask = XLangFlag | OutOfBandFlag ) // 
============================================================================ @@ -185,6 +185,9 @@ func New(opts ...Option) *Fory { f.readCtx.refResolver = f.refResolver f.readCtx.compatible = f.config.Compatible f.readCtx.xlang = f.config.IsXlang + if f.config.IsXlang { + f.readCtx.rootHeader = XLangFlag + } return f } @@ -487,13 +490,6 @@ func (f *Fory) Reset() { // For thread-safe usage, use threadsafe.Fory which copies the data internally. func (f *Fory) Serialize(value any) ([]byte, error) { defer f.resetWriteState() - // Check if value is nil interface OR a nil pointer/slice/map/etc. - // In Go, `*int32(nil)` wrapped in `any` is NOT equal to `nil`, but we need to serialize it as null. - if isNilValue(value) { - // Use Java-compatible null format: 3 bytes (magic + bitmap with isNilFlag) - writeNullHeader(f.writeCtx) - return f.writeCtx.buffer.GetByteSlice(0, f.writeCtx.buffer.writerIndex), nil - } // WriteData protocol header writeHeader(f.writeCtx, f.config) @@ -521,16 +517,11 @@ func (f *Fory) Deserialize(data []byte, v any) error { defer f.resetReadState() f.readCtx.SetData(data) - isNull := readHeader(f.readCtx) + readHeader(f.readCtx) if f.readCtx.HasError() { return f.readCtx.TakeError() } - // Check if the serialized object is null - if isNull { - return nil - } - // Deserialize the value - TypeMeta is read inline using streaming protocol target := reflect.ValueOf(v).Elem() f.readCtx.ReadValue(target, RefModeTracking, true) @@ -561,13 +552,6 @@ func (f *Fory) resetWriteState() { // This is useful when you need to write multiple serialized values to the same buffer. // Returns error if serialization fails. 
func (f *Fory) SerializeTo(buf *ByteBuffer, value any) error { - // Handle nil values - if isNilValue(value) { - // Use Java-compatible null format: 1 byte (bitmap with isNilFlag) - buf.WriteByte_(IsNilFlag) - return nil - } - defer f.resetWriteState() // Temporarily swap buffer @@ -625,18 +609,12 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = buf - isNull := readHeader(f.readCtx) + readHeader(f.readCtx) if f.readCtx.HasError() { f.readCtx.buffer = origBuffer return f.readCtx.TakeError() } - // Check if the serialized object is null - if isNull { - f.readCtx.buffer = origBuffer - return nil - } - // Deserialize the value - TypeMeta is read inline using streaming protocol target := reflect.ValueOf(v).Elem() f.readCtx.ReadValue(target, RefModeTracking, true) @@ -731,21 +709,11 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers } // ReadData and validate header - isNull := readHeader(f.readCtx) + readHeader(f.readCtx) if f.readCtx.HasError() { return f.readCtx.TakeError() } - // Check if the serialized object is null - if isNull { - // v must be a pointer so we can set it to nil - rv := reflect.ValueOf(v) - if rv.Kind() == reflect.Ptr && !rv.IsNil() { - rv.Elem().Set(reflect.Zero(rv.Elem().Type())) - } - return nil - } - // v must be a pointer so we can deserialize into it if v == nil { return fmt.Errorf("v cannot be nil") @@ -803,50 +771,34 @@ func writeHeader(ctx *WriteContext, config Config) { ctx.buffer.WriteByte_(bitmap) } -// isNilValue checks if a value is nil, including nil pointers wrapped in any -// In Go, `*int32(nil)` wrapped in `any` is NOT equal to `nil`, but we need to treat it as null. 
-// -//go:noinline -func isNilValue(value any) bool { - if value == nil { - return true - } - rv := reflect.ValueOf(value) - switch rv.Kind() { - case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func, reflect.Interface: - return rv.IsNil() - } - return false -} - -// writeNullHeader writes a null object header (1 byte: bitmap with isNilFlag) -// This is compatible with Java's null serialization format -// -//go:noinline -func writeNullHeader(ctx *WriteContext) { - ctx.buffer.WriteByte_(IsNilFlag) // bitmap with only isNilFlag set -} - -// Special return value indicating null object in readHeader -// Using math.MinInt32 to avoid conflict with -1 which is used for "no meta offset" -const NullObjectMetaOffset int32 = -0x7FFFFFFF - // readHeader reads and validates the Fory protocol header -// Returns true if the serialized object is null // Sets error on ctx if header is invalid (use ctx.HasError() to check) -func readHeader(ctx *ReadContext) bool { +func readHeader(ctx *ReadContext) { err := ctx.Err() bitmap := ctx.buffer.ReadByte(err) if ctx.HasError() { - return false + return } - - // Check if this is a null object - only bitmap with isNilFlag was written - if (bitmap & IsNilFlag) != 0 { - return true // is null + if bitmap == ctx.rootHeader { + return } + readHeaderSlow(ctx, bitmap) +} - return false // not null +//go:noinline +func readHeaderSlow(ctx *ReadContext, bitmap byte) { + if bitmap&^headerFlagMask != 0 { + ctx.SetError(DeserializationErrorf("unsupported root header bitmap 0x%02x", bitmap)) + return + } + if ((bitmap & XLangFlag) != 0) != ctx.xlang { + ctx.SetError(DeserializationErrorf("header bitmap mismatch at xlang bit")) + return + } + if (bitmap&OutOfBandFlag) != 0 && ctx.outOfBandBuffers == nil { + ctx.SetError(DeserializationErrorf("out-of-band buffers are required by root header")) + return + } } // ============================================================================ @@ -1025,66 +977,84 @@ func Deserialize[T any](f 
*Fory, data []byte, target *T) error { f.readCtx.SetData(data) // ReadData and validate header - isNull := readHeader(f.readCtx) + readHeader(f.readCtx) if f.readCtx.HasError() { return f.readCtx.TakeError() } - // Check if the serialized object is null - if isNull { - var zero T - *target = zero - return nil - } - // Fast path: type switch for common types (Go compiler can optimize this) // For primitives, read null flag, skip type ID, then read value from buffer buf := f.readCtx.buffer err := f.readCtx.Err() switch t := any(target).(type) { case *bool: - _ = buf.ReadInt8(err) // null flag - _ = buf.ReadUint8(err) // type ID + _ = buf.ReadInt8(err) // null flag + if !f.readCtx.readExpectedTypeID(BOOL) { + return f.readCtx.CheckError() + } *t = buf.ReadBool(err) return f.readCtx.CheckError() case *int8: _ = buf.ReadInt8(err) - _ = buf.ReadUint8(err) + if !f.readCtx.readExpectedTypeID(INT8) { + return f.readCtx.CheckError() + } *t = buf.ReadInt8(err) return f.readCtx.CheckError() case *int16: _ = buf.ReadInt8(err) - _ = buf.ReadUint8(err) + if !f.readCtx.readExpectedTypeID(INT16) { + return f.readCtx.CheckError() + } *t = buf.ReadInt16(err) return f.readCtx.CheckError() case *int32: _ = buf.ReadInt8(err) - _ = buf.ReadUint8(err) + if !f.readCtx.readExpectedTypeID(VARINT32) { + return f.readCtx.CheckError() + } *t = buf.ReadVarint32(err) return f.readCtx.CheckError() case *int64: _ = buf.ReadInt8(err) - _ = buf.ReadUint8(err) + if !f.readCtx.readExpectedTypeID(VARINT64) { + return f.readCtx.CheckError() + } *t = buf.ReadVarint64(err) return f.readCtx.CheckError() case *int: _ = buf.ReadInt8(err) - _ = buf.ReadUint8(err) + if strconv.IntSize == 32 { + if !f.readCtx.readExpectedTypeID(VARINT32) { + return f.readCtx.CheckError() + } + *t = int(buf.ReadVarint32(err)) + return f.readCtx.CheckError() + } + if !f.readCtx.readExpectedTypeID(VARINT64) { + return f.readCtx.CheckError() + } *t = int(buf.ReadVarint64(err)) return f.readCtx.CheckError() case *float32: _ = 
buf.ReadInt8(err) - _ = buf.ReadUint8(err) + if !f.readCtx.readExpectedTypeID(FLOAT32) { + return f.readCtx.CheckError() + } *t = buf.ReadFloat32(err) return f.readCtx.CheckError() case *float64: _ = buf.ReadInt8(err) - _ = buf.ReadUint8(err) + if !f.readCtx.readExpectedTypeID(FLOAT64) { + return f.readCtx.CheckError() + } *t = buf.ReadFloat64(err) return f.readCtx.CheckError() case *string: - _ = buf.ReadInt8(err) // null flag - _ = buf.ReadUint8(err) // type ID + _ = buf.ReadInt8(err) // null flag + if !f.readCtx.readExpectedTypeID(STRING) { + return f.readCtx.CheckError() + } *t = f.readCtx.ReadString() return f.readCtx.CheckError() case *[]byte: @@ -1116,31 +1086,31 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { return f.readCtx.CheckError() case *map[string]string: *t = f.readCtx.ReadStringStringMap(RefModeNullOnly, true) - return nil + return f.readCtx.CheckError() case *map[string]int64: *t = f.readCtx.ReadStringInt64Map(RefModeNullOnly, true) - return nil + return f.readCtx.CheckError() case *map[string]int32: *t = f.readCtx.ReadStringInt32Map(RefModeNullOnly, true) - return nil + return f.readCtx.CheckError() case *map[string]int: *t = f.readCtx.ReadStringIntMap(RefModeNullOnly, true) - return nil + return f.readCtx.CheckError() case *map[string]float64: *t = f.readCtx.ReadStringFloat64Map(RefModeNullOnly, true) - return nil + return f.readCtx.CheckError() case *map[string]bool: *t = f.readCtx.ReadStringBoolMap(RefModeNullOnly, true) - return nil + return f.readCtx.CheckError() case *map[int32]int32: *t = f.readCtx.ReadInt32Int32Map(RefModeNullOnly, true) - return nil + return f.readCtx.CheckError() case *map[int64]int64: *t = f.readCtx.ReadInt64Int64Map(RefModeNullOnly, true) - return nil + return f.readCtx.CheckError() case *map[int]int: *t = f.readCtx.ReadIntIntMap(RefModeNullOnly, true) - return nil + return f.readCtx.CheckError() default: // Slow path: use serializer-based deserialization targetVal := 
reflect.ValueOf(target).Elem() diff --git a/go/fory/fory_typed_test.go b/go/fory/fory_typed_test.go index 8108076572..127c05c891 100644 --- a/go/fory/fory_typed_test.go +++ b/go/fory/fory_typed_test.go @@ -132,6 +132,44 @@ func TestSerializeGenericPrimitives(t *testing.T) { }) } +func TestDeserializeRejectsRootTypeMismatch(t *testing.T) { + f := NewFory() + + data := []byte{0, 0xff, byte(STRING)} + var result bool + require.Error(t, Deserialize(f, data, &result)) + + data = []byte{0, 0xff, byte(BOOL)} + var mapResult map[string]string + require.Error(t, Deserialize(f, data, &mapResult)) +} + +func TestDeserializeRejectsRootPrimitiveSliceTypeMismatch(t *testing.T) { + f := NewFory() + + data := []byte{0, 0xff, byte(BINARY)} + var int32Result []int32 + require.Error(t, Deserialize(f, data, &int32Result)) + + data = []byte{0, 0xff, byte(INT8_ARRAY)} + var byteResult []byte + require.Error(t, Deserialize(f, data, &byteResult)) +} + +func TestDeserializeByteSliceAcceptsUint8ArrayRootType(t *testing.T) { + f := NewFory() + buf := NewByteBuffer(nil) + buf.WriteByte(0) + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint8(uint8(UINT8_ARRAY)) + buf.WriteLength(3) + buf.WriteBinary([]byte{1, 2, 3}) + + var result []byte + require.NoError(t, Deserialize(f, buf.Bytes(), &result)) + require.Equal(t, []byte{1, 2, 3}, result) +} + // TestSerializeGenericComplex tests Serialize[T]/DeserializeWithCallbackBuffers[T] with complex types. // These fall back to reflection-based serialization. 
func TestSerializeGenericComplex(t *testing.T) { diff --git a/go/fory/map.go b/go/fory/map.go index cbce365e2b..a59fd8e652 100644 --- a/go/fory/map.go +++ b/go/fory/map.go @@ -484,6 +484,10 @@ func (s mapSerializer) readChunk(ctx *ReadContext, mapVal reflect.Value, header if ctx.HasError() { return 0 } + if chunkSize == 0 || chunkSize > size { + ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size)) + return 0 + } // Read type info if not declared var keyTypeInfo, valueTypeInfo *TypeInfo @@ -618,7 +622,9 @@ func readMapRefAndType(ctx *ReadContext, refMode RefMode, readType bool, value r } } if readType { - buf.ReadUint8(ctxErr) + if !ctx.readExpectedTypeID(MAP) { + return false + } } return false } diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go index 16e40f5aef..287777eaea 100644 --- a/go/fory/map_primitive.go +++ b/go/fory/map_primitive.go @@ -81,50 +81,32 @@ func readMapStringString(ctx *ReadContext) map[string]string { for size > 0 { chunkHeader := buf.ReadUint8(err) - // Handle null key/value cases - keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0 - valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0 - - if keyHasNull && valueHasNull { - // Both null - use empty strings for key and value - result[""] = "" - size-- - continue - } else if keyHasNull { - // Null key with non-null value - valueDeclared := (chunkHeader & VALUE_DECL_TYPE) != 0 - if !valueDeclared { - buf.ReadUint8(err) // skip value type - } - v := readString(buf, err) - result[""] = v // empty string as null key - size-- - continue - } else if valueHasNull { - // Non-null key with null value - keyDeclared := (chunkHeader & KEY_DECL_TYPE) != 0 - if !keyDeclared { - buf.ReadUint8(err) // skip key type - } - k := readString(buf, err) - result[k] = "" // empty string as null value - size-- - continue + if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 { + ctx.SetError(DeserializationError("typed map 
reader does not support ref/null chunks")) + return result } - // ReadData chunk size chunkSize := int(buf.ReadUint8(err)) + if ctx.HasError() { + return result + } + if chunkSize == 0 || chunkSize > size { + ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size)) + return result + } - // Read type info if not DECL_TYPE if (chunkHeader & KEY_DECL_TYPE) == 0 { - buf.ReadUint8(err) // skip key type + if !ctx.readExpectedTypeID(STRING) { + return result + } } if (chunkHeader & VALUE_DECL_TYPE) == 0 { - buf.ReadUint8(err) // skip value type + if !ctx.readExpectedTypeID(STRING) { + return result + } } - // ReadData chunk entries - for i := 0; i < chunkSize && size > 0; i++ { + for i := 0; i < chunkSize; i++ { k := readString(buf, err) v := readString(buf, err) result[k] = v @@ -185,22 +167,30 @@ func readMapStringInt64(ctx *ReadContext) map[string]int64 { for size > 0 { chunkHeader := buf.ReadUint8(err) - keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0 - valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0 - - if keyHasNull || valueHasNull { - size-- - continue + if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 { + ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks")) + return result } chunkSize := int(buf.ReadUint8(err)) + if ctx.HasError() { + return result + } + if chunkSize == 0 || chunkSize > size { + ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size)) + return result + } if (chunkHeader & KEY_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(STRING) { + return result + } } if (chunkHeader & VALUE_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(VARINT64) { + return result + } } - for i := 0; i < chunkSize && size > 0; i++ { + for i := 0; i < chunkSize; i++ { k := readString(buf, err) v := buf.ReadVarint64(err) result[k] = v @@ -261,22 +251,30 @@ func 
readMapStringInt32(ctx *ReadContext) map[string]int32 { for size > 0 { chunkHeader := buf.ReadUint8(err) - keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0 - valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0 - - if keyHasNull || valueHasNull { - size-- - continue + if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 { + ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks")) + return result } chunkSize := int(buf.ReadUint8(err)) + if ctx.HasError() { + return result + } + if chunkSize == 0 || chunkSize > size { + ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size)) + return result + } if (chunkHeader & KEY_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(STRING) { + return result + } } if (chunkHeader & VALUE_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(VARINT32) { + return result + } } - for i := 0; i < chunkSize && size > 0; i++ { + for i := 0; i < chunkSize; i++ { k := readString(buf, err) v := buf.ReadVarint32(err) result[k] = v @@ -337,22 +335,30 @@ func readMapStringInt(ctx *ReadContext) map[string]int { for size > 0 { chunkHeader := buf.ReadUint8(err) - keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0 - valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0 - - if keyHasNull || valueHasNull { - size-- - continue + if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 { + ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks")) + return result } chunkSize := int(buf.ReadUint8(err)) + if ctx.HasError() { + return result + } + if chunkSize == 0 || chunkSize > size { + ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size)) + return result + } if (chunkHeader & KEY_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(STRING) { + return result + } } if (chunkHeader & 
VALUE_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(VARINT64) { + return result + } } - for i := 0; i < chunkSize && size > 0; i++ { + for i := 0; i < chunkSize; i++ { k := readString(buf, err) v := buf.ReadVarint64(err) result[k] = int(v) @@ -413,22 +419,30 @@ func readMapStringFloat64(ctx *ReadContext) map[string]float64 { for size > 0 { chunkHeader := buf.ReadUint8(err) - keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0 - valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0 - - if keyHasNull || valueHasNull { - size-- - continue + if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 { + ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks")) + return result } chunkSize := int(buf.ReadUint8(err)) + if ctx.HasError() { + return result + } + if chunkSize == 0 || chunkSize > size { + ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size)) + return result + } if (chunkHeader & KEY_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(STRING) { + return result + } } if (chunkHeader & VALUE_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(FLOAT64) { + return result + } } - for i := 0; i < chunkSize && size > 0; i++ { + for i := 0; i < chunkSize; i++ { k := readString(buf, err) v := buf.ReadFloat64(err) result[k] = v @@ -489,27 +503,34 @@ func readMapStringBool(ctx *ReadContext) map[string]bool { for size > 0 { chunkHeader := buf.ReadUint8(err) - keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0 - valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0 - - if keyHasNull || valueHasNull { - size-- - continue + if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 { + ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks")) + return result } chunkSize := int(buf.ReadUint8(err)) + if ctx.HasError() { + return result + } + if chunkSize == 0 || chunkSize 
> size { + ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size)) + return result + } - // Read type info (written by writeMapStringBool) keyDeclType := (chunkHeader & KEY_DECL_TYPE) != 0 valDeclType := (chunkHeader & VALUE_DECL_TYPE) != 0 if !keyDeclType { - buf.ReadUint8(err) // skip key type info + if !ctx.readExpectedTypeID(STRING) { + return result + } } if !valDeclType { - buf.ReadUint8(err) // skip value type info + if !ctx.readExpectedTypeID(BOOL) { + return result + } } - for i := 0; i < chunkSize && size > 0; i++ { + for i := 0; i < chunkSize; i++ { k := readString(buf, err) v := buf.ReadBool(err) result[k] = v @@ -570,22 +591,30 @@ func readMapInt32Int32(ctx *ReadContext) map[int32]int32 { for size > 0 { chunkHeader := buf.ReadUint8(err) - keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0 - valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0 - - if keyHasNull || valueHasNull { - size-- - continue + if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 { + ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks")) + return result } chunkSize := int(buf.ReadUint8(err)) + if ctx.HasError() { + return result + } + if chunkSize == 0 || chunkSize > size { + ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size)) + return result + } if (chunkHeader & KEY_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(VARINT32) { + return result + } } if (chunkHeader & VALUE_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(VARINT32) { + return result + } } - for i := 0; i < chunkSize && size > 0; i++ { + for i := 0; i < chunkSize; i++ { k := buf.ReadVarint32(err) v := buf.ReadVarint32(err) result[k] = v @@ -646,22 +675,30 @@ func readMapInt64Int64(ctx *ReadContext) map[int64]int64 { for size > 0 { chunkHeader := buf.ReadUint8(err) - keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0 - 
valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0 - - if keyHasNull || valueHasNull { - size-- - continue + if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 { + ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks")) + return result } chunkSize := int(buf.ReadUint8(err)) + if ctx.HasError() { + return result + } + if chunkSize == 0 || chunkSize > size { + ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size)) + return result + } if (chunkHeader & KEY_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(VARINT64) { + return result + } } if (chunkHeader & VALUE_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(VARINT64) { + return result + } } - for i := 0; i < chunkSize && size > 0; i++ { + for i := 0; i < chunkSize; i++ { k := buf.ReadVarint64(err) v := buf.ReadVarint64(err) result[k] = v @@ -722,22 +759,30 @@ func readMapIntInt(ctx *ReadContext) map[int]int { for size > 0 { chunkHeader := buf.ReadUint8(err) - keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0 - valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0 - - if keyHasNull || valueHasNull { - size-- - continue + if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 { + ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks")) + return result } chunkSize := int(buf.ReadUint8(err)) + if ctx.HasError() { + return result + } + if chunkSize == 0 || chunkSize > size { + ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size)) + return result + } if (chunkHeader & KEY_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(VARINT64) { + return result + } } if (chunkHeader & VALUE_DECL_TYPE) == 0 { - buf.ReadUint8(err) + if !ctx.readExpectedTypeID(VARINT64) { + return result + } } - for i := 0; i < chunkSize && size > 0; i++ { + for i := 0; i < 
chunkSize; i++ { k := buf.ReadVarint64(err) v := buf.ReadVarint64(err) result[int(k)] = int(v) diff --git a/go/fory/map_primitive_test.go b/go/fory/map_primitive_test.go new file mode 100644 index 0000000000..d36fb6c1f9 --- /dev/null +++ b/go/fory/map_primitive_test.go @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package fory + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPrimitiveMapReaderRejectsInvalidChunkSize(t *testing.T) { + f := NewFory() + buf := NewByteBuffer(nil) + buf.WriteLength(1) + buf.WriteUint8(KEY_DECL_TYPE | VALUE_DECL_TYPE) + buf.WriteUint8(2) + + f.readCtx.SetData(buf.Bytes()) + _ = f.readCtx.ReadStringStringMap(RefModeNone, false) + require.Error(t, f.readCtx.CheckError()) +} + +func TestPrimitiveMapReaderRejectsUnexpectedTypeInfo(t *testing.T) { + f := NewFory() + buf := NewByteBuffer(nil) + buf.WriteLength(1) + buf.WriteUint8(0) + buf.WriteUint8(1) + buf.WriteUint8(uint8(STRING)) + buf.WriteUint8(uint8(BOOL)) + + f.readCtx.SetData(buf.Bytes()) + _ = f.readCtx.ReadStringStringMap(RefModeNone, false) + require.Error(t, f.readCtx.CheckError()) +} + +func TestPrimitiveMapReaderRejectsNullChunks(t *testing.T) { + f := NewFory() + buf := NewByteBuffer(nil) + buf.WriteLength(1) + buf.WriteUint8(KEY_HAS_NULL) + + f.readCtx.SetData(buf.Bytes()) + _ = f.readCtx.ReadStringStringMap(RefModeNone, false) + require.Error(t, f.readCtx.CheckError()) +} diff --git a/go/fory/primitive.go b/go/fory/primitive.go index f5a7e40550..1748bc27a8 100644 --- a/go/fory/primitive.go +++ b/go/fory/primitive.go @@ -56,8 +56,8 @@ func (s boolSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(BOOL) { + return } if ctx.HasError() { return @@ -103,8 +103,8 @@ func (s int8Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(INT8) { + return } if ctx.HasError() { return @@ -149,8 +149,8 @@ func (s byteSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(UINT8) { + return } if ctx.HasError() { return @@ -195,8 
+195,8 @@ func (s uint16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(UINT16) { + return } if ctx.HasError() { return @@ -241,8 +241,8 @@ func (s uint32Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(VAR_UINT32) { + return } if ctx.HasError() { return @@ -287,8 +287,8 @@ func (s uint64Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(VAR_UINT64) { + return } if ctx.HasError() { return @@ -331,8 +331,8 @@ func (s uintSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(VAR_UINT64) { + return } if ctx.HasError() { return @@ -375,8 +375,8 @@ func (s int16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(INT16) { + return } if ctx.HasError() { return @@ -419,8 +419,8 @@ func (s int32Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(VARINT32) { + return } if ctx.HasError() { return @@ -463,8 +463,8 @@ func (s int64Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(VARINT64) { + return } if ctx.HasError() { return @@ -505,8 +505,8 @@ func (s intSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, ha return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(VARINT64) { + return } if ctx.HasError() { return @@ 
-549,8 +549,8 @@ func (s float32Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(FLOAT32) { + return } if ctx.HasError() { return @@ -593,8 +593,8 @@ func (s float64Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(FLOAT64) { + return } if ctx.HasError() { return @@ -651,8 +651,8 @@ func (s float16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(FLOAT16) { + return } if ctx.HasError() { return @@ -704,8 +704,8 @@ func (s bfloat16Serializer) Read(ctx *ReadContext, refMode RefMode, readType boo return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(BFLOAT16) { + return } if ctx.HasError() { return diff --git a/go/fory/reader.go b/go/fory/reader.go index 9c8b049ad2..d0f89fa81c 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -31,11 +31,12 @@ import ( type ReadContext struct { buffer *ByteBuffer refReader *RefReader - trackRef bool // Cached flag to avoid indirection - xlang bool // Cross-language serialization mode + trackRef bool // Cached flag to avoid indirection + xlang bool // Cross-language serialization mode + rootHeader byte compatible bool // Schema evolution compatibility mode typeResolver *TypeResolver // For complex type deserialization - refResolver *RefResolver // For reference tracking (legacy) + refResolver *RefResolver // For reference tracking in native-mode paths outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization outOfBandIndex int // Current index into out-of-band buffers depth int // Current nesting depth for cycle detection @@ -109,7 +110,7 @@ func (c *ReadContext) TypeResolver() *TypeResolver { return c.typeResolver } -// 
RefResolver returns the reference resolver (legacy) +// RefResolver returns the reference resolver. func (c *ReadContext) RefResolver() *RefResolver { return c.refResolver } @@ -151,6 +152,18 @@ func (c *ReadContext) CheckError() error { return nil } +func (c *ReadContext) readExpectedTypeID(expected TypeId) bool { + actual := TypeId(c.buffer.ReadUint8(c.Err())) + if c.HasError() { + return false + } + if actual != expected { + c.SetError(TypeMismatchError(actual, expected)) + return false + } + return true +} + // Inline primitive reads func (c *ReadContext) RawBool() bool { return c.buffer.ReadBool(c.Err()) } func (c *ReadContext) RawInt8() int8 { return int8(c.buffer.ReadByte(c.Err())) } @@ -288,7 +301,11 @@ func (c *ReadContext) ReadBoolSlice(refMode RefMode, readType bool) []bool { } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + if actual != BOOL_ARRAY { + c.SetError(TypeMismatchError(actual, BOOL_ARRAY)) + return nil + } } return ReadBoolSlice(c.buffer, err) } @@ -302,7 +319,11 @@ func (c *ReadContext) ReadInt8Slice(refMode RefMode, readType bool) []int8 { } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + if actual != INT8_ARRAY { + c.SetError(TypeMismatchError(actual, INT8_ARRAY)) + return nil + } } return ReadInt8Slice(c.buffer, err) } @@ -316,7 +337,11 @@ func (c *ReadContext) ReadInt16Slice(refMode RefMode, readType bool) []int16 { } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + if actual != INT16_ARRAY { + c.SetError(TypeMismatchError(actual, INT16_ARRAY)) + return nil + } } return ReadInt16Slice(c.buffer, err) } @@ -330,7 +355,11 @@ func (c *ReadContext) ReadInt32Slice(refMode RefMode, readType bool) []int32 { } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + if actual != INT32_ARRAY { + c.SetError(TypeMismatchError(actual, INT32_ARRAY)) + return nil + } } return 
ReadInt32Slice(c.buffer, err) } @@ -344,7 +373,11 @@ func (c *ReadContext) ReadInt64Slice(refMode RefMode, readType bool) []int64 { } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + if actual != INT64_ARRAY { + c.SetError(TypeMismatchError(actual, INT64_ARRAY)) + return nil + } } return ReadInt64Slice(c.buffer, err) } @@ -358,7 +391,11 @@ func (c *ReadContext) ReadUint16Slice(refMode RefMode, readType bool) []uint16 { } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + if actual != UINT16_ARRAY { + c.SetError(TypeMismatchError(actual, UINT16_ARRAY)) + return nil + } } return ReadUint16Slice(c.buffer, err) } @@ -372,7 +409,11 @@ func (c *ReadContext) ReadUint32Slice(refMode RefMode, readType bool) []uint32 { } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + if actual != UINT32_ARRAY { + c.SetError(TypeMismatchError(actual, UINT32_ARRAY)) + return nil + } } return ReadUint32Slice(c.buffer, err) } @@ -386,7 +427,11 @@ func (c *ReadContext) ReadUint64Slice(refMode RefMode, readType bool) []uint64 { } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + if actual != UINT64_ARRAY { + c.SetError(TypeMismatchError(actual, UINT64_ARRAY)) + return nil + } } return ReadUint64Slice(c.buffer, err) } @@ -400,7 +445,15 @@ func (c *ReadContext) ReadIntSlice(refMode RefMode, readType bool) []int { } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + expected := TypeId(INT64_ARRAY) + if strconv.IntSize == 32 { + expected = INT32_ARRAY + } + if actual != expected { + c.SetError(TypeMismatchError(actual, expected)) + return nil + } } return ReadIntSlice(c.buffer, err) } @@ -414,7 +467,15 @@ func (c *ReadContext) ReadUintSlice(refMode RefMode, readType bool) []uint { } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + expected := 
TypeId(UINT64_ARRAY) + if strconv.IntSize == 32 { + expected = UINT32_ARRAY + } + if actual != expected { + c.SetError(TypeMismatchError(actual, expected)) + return nil + } } return ReadUintSlice(c.buffer, err) } @@ -428,7 +489,11 @@ func (c *ReadContext) ReadFloat32Slice(refMode RefMode, readType bool) []float32 } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + if actual != FLOAT32_ARRAY { + c.SetError(TypeMismatchError(actual, FLOAT32_ARRAY)) + return nil + } } return ReadFloat32Slice(c.buffer, err) } @@ -442,7 +507,11 @@ func (c *ReadContext) ReadFloat64Slice(refMode RefMode, readType bool) []float64 } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + if actual != FLOAT64_ARRAY { + c.SetError(TypeMismatchError(actual, FLOAT64_ARRAY)) + return nil + } } return ReadFloat64Slice(c.buffer, err) } @@ -456,7 +525,11 @@ func (c *ReadContext) ReadByteSlice(refMode RefMode, readType bool) []byte { } } if readType { - _ = c.buffer.ReadUint8(err) + actual := TypeId(c.buffer.ReadUint8(err)) + if actual != BINARY && actual != UINT8_ARRAY { + c.SetError(DeserializationErrorf("slice type mismatch: expected BINARY (%d) or UINT8_ARRAY (%d), got %d", BINARY, UINT8_ARRAY, actual)) + return nil + } } size := c.ReadBinaryLength() return c.buffer.ReadBinary(size, err) @@ -484,8 +557,8 @@ func (c *ReadContext) ReadStringStringMap(refMode RefMode, readType bool) map[st return nil } } - if readType { - _ = c.buffer.ReadUint8(err) + if readType && !c.readExpectedTypeID(MAP) { + return nil } return readMapStringString(c) } @@ -498,8 +571,8 @@ func (c *ReadContext) ReadStringInt64Map(refMode RefMode, readType bool) map[str return nil } } - if readType { - _ = c.buffer.ReadUint8(err) + if readType && !c.readExpectedTypeID(MAP) { + return nil } return readMapStringInt64(c) } @@ -512,8 +585,8 @@ func (c *ReadContext) ReadStringInt32Map(refMode RefMode, readType bool) map[str return nil } } - if readType { - 
_ = c.buffer.ReadUint8(err) + if readType && !c.readExpectedTypeID(MAP) { + return nil } return readMapStringInt32(c) } @@ -526,8 +599,8 @@ func (c *ReadContext) ReadStringIntMap(refMode RefMode, readType bool) map[strin return nil } } - if readType { - _ = c.buffer.ReadUint8(err) + if readType && !c.readExpectedTypeID(MAP) { + return nil } return readMapStringInt(c) } @@ -540,8 +613,8 @@ func (c *ReadContext) ReadStringFloat64Map(refMode RefMode, readType bool) map[s return nil } } - if readType { - _ = c.buffer.ReadUint8(err) + if readType && !c.readExpectedTypeID(MAP) { + return nil } return readMapStringFloat64(c) } @@ -554,8 +627,8 @@ func (c *ReadContext) ReadStringBoolMap(refMode RefMode, readType bool) map[stri return nil } } - if readType { - _ = c.buffer.ReadUint8(err) + if readType && !c.readExpectedTypeID(MAP) { + return nil } return readMapStringBool(c) } @@ -568,8 +641,8 @@ func (c *ReadContext) ReadInt32Int32Map(refMode RefMode, readType bool) map[int3 return nil } } - if readType { - _ = c.buffer.ReadUint8(err) + if readType && !c.readExpectedTypeID(MAP) { + return nil } return readMapInt32Int32(c) } @@ -582,8 +655,8 @@ func (c *ReadContext) ReadInt64Int64Map(refMode RefMode, readType bool) map[int6 return nil } } - if readType { - _ = c.buffer.ReadUint8(err) + if readType && !c.readExpectedTypeID(MAP) { + return nil } return readMapInt64Int64(c) } @@ -596,8 +669,8 @@ func (c *ReadContext) ReadIntIntMap(refMode RefMode, readType bool) map[int]int return nil } } - if readType { - _ = c.buffer.ReadUint8(err) + if readType && !c.readExpectedTypeID(MAP) { + return nil } return readMapIntInt(c) } diff --git a/go/fory/skip.go b/go/fory/skip.go index 9d08f465c4..abc7466449 100644 --- a/go/fory/skip.go +++ b/go/fory/skip.go @@ -56,7 +56,10 @@ func SkipFieldValueWithTypeFlag(ctx *ReadContext, fieldDef FieldDef, readRefFlag // Check if it's an EXT type first - EXT types don't have meta info like structs if internalID == EXT { - typeInfo := 
ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, wroteTypeID, err) + typeInfo := readKnownTypeInfoForSkip(ctx, wroteTypeID) + if ctx.HasError() { + return + } if typeInfo != nil && typeInfo.Serializer != nil { // Use the serializer to read and discard the value var dummy any @@ -71,8 +74,11 @@ func SkipFieldValueWithTypeFlag(ctx *ReadContext, fieldDef FieldDef, readRefFlag // Check if it's a NAMED_EXT type - need to read type info to find serializer if internalID == NAMED_EXT { - typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, wroteTypeID, err) - if typeInfo.Serializer != nil { + typeInfo := readKnownTypeInfoForSkip(ctx, wroteTypeID) + if ctx.HasError() { + return + } + if typeInfo != nil && typeInfo.Serializer != nil { // Use the serializer to read and discard the value var dummy any dummyVal := reflect.ValueOf(&dummy).Elem() @@ -86,14 +92,20 @@ func SkipFieldValueWithTypeFlag(ctx *ReadContext, fieldDef FieldDef, readRefFlag // Check if it's a struct type - need to read type info and skip struct data if internalID == COMPATIBLE_STRUCT || internalID == STRUCT || internalID == NAMED_STRUCT || internalID == NAMED_COMPATIBLE_STRUCT { - typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, wroteTypeID, err) + typeInfo := readKnownTypeInfoForSkip(ctx, wroteTypeID) + if ctx.HasError() { + return + } // Now skip the struct data using the typeInfo from the written type skipStruct(ctx, typeInfo) return } if IsNamespacedType(internalID) { - typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, wroteTypeID, err) + typeInfo := readKnownTypeInfoForSkip(ctx, wroteTypeID) + if ctx.HasError() { + return + } // Now skip the struct data using the typeInfo from the written type skipStruct(ctx, typeInfo) return @@ -153,38 +165,22 @@ func SkipAnyValue(ctx *ReadContext, readRefFlag bool) { typeSpec: NewMapTypeSpec(TypeId(typeID), NewSimpleTypeSpec(UNKNOWN), NewSimpleTypeSpec(UNKNOWN)), nullable: true, } - case NAMED_UNION: - resolver := 
ctx.TypeResolver() - _, _ = resolver.metaStringResolver.ReadMetaStringBytes(ctx.buffer, err) - if ctx.HasError() { - return - } - _, _ = resolver.metaStringResolver.ReadMetaStringBytes(ctx.buffer, err) - if ctx.HasError() { - return - } - fieldDef = FieldDef{ - typeSpec: NewSimpleTypeSpec(TypeId(typeID)), - nullable: true, - } - case COMPATIBLE_STRUCT, NAMED_COMPATIBLE_STRUCT, STRUCT, NAMED_STRUCT, EXT, TYPED_UNION: + case ENUM, NAMED_ENUM, COMPATIBLE_STRUCT, NAMED_COMPATIBLE_STRUCT, STRUCT, NAMED_STRUCT, + EXT, NAMED_EXT, TYPED_UNION, NAMED_UNION: // Read type info using the shared meta reader when enabled. typeInfo = ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, err) if ctx.HasError() { return } + if typeInfo == nil { + ctx.SetError(DeserializationErrorf("cannot skip type %d: type info not found", typeID)) + return + } fieldDef = FieldDef{ typeSpec: NewSimpleTypeSpec(TypeId(typeID)), nullable: true, } default: - if internalID == ENUM || internalID == STRUCT || - internalID == EXT || internalID == TYPED_UNION { - ctx.buffer.ReadVarUint32(err) - if ctx.HasError() { - return - } - } fieldDef = FieldDef{ typeSpec: NewSimpleTypeSpec(TypeId(typeID)), nullable: true, @@ -206,7 +202,25 @@ func readTypeInfoForSkip(ctx *ReadContext, fieldTypeId TypeId) *TypeInfo { return nil } // Use readTypeInfoWithTypeID which handles both namespaced and non-namespaced types correctly - return ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, err) + typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, err) + if ctx.HasError() { + return nil + } + if typeInfo == nil { + ctx.SetError(DeserializationErrorf("cannot skip type %d: type info not found", typeID)) + } + return typeInfo +} + +func readKnownTypeInfoForSkip(ctx *ReadContext, typeID uint32) *TypeInfo { + typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, ctx.Err()) + if ctx.HasError() { + return nil + } + if typeInfo == nil { + 
ctx.SetError(DeserializationErrorf("cannot skip type %d: type info not found", typeID)) + } + return typeInfo } // skipCollection skips a collection (list/set) value @@ -236,7 +250,10 @@ func skipCollection(ctx *ReadContext, fieldDef FieldDef) { if ctx.HasError() { return } - elemTypeInfo = ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, err) + elemTypeInfo = readKnownTypeInfoForSkip(ctx, typeID) + if ctx.HasError() { + return + } elemDef = FieldDef{ typeSpec: NewSimpleTypeSpec(TypeId(elemTypeInfo.TypeID)), nullable: hasNull, @@ -336,7 +353,10 @@ func skipMap(ctx *ReadContext, fieldDef FieldDef) { if ctx.HasError() { return } - valueTypeInfo = ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, bufErr) + valueTypeInfo = readKnownTypeInfoForSkip(ctx, typeID) + if ctx.HasError() { + return + } valueDef = FieldDef{ typeSpec: NewSimpleTypeSpec(TypeId(valueTypeInfo.TypeID)), nullable: true, @@ -368,7 +388,10 @@ func skipMap(ctx *ReadContext, fieldDef FieldDef) { if ctx.HasError() { return } - keyTypeInfo = ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, bufErr) + keyTypeInfo = readKnownTypeInfoForSkip(ctx, typeID) + if ctx.HasError() { + return + } keyDef = FieldDef{ typeSpec: NewSimpleTypeSpec(TypeId(keyTypeInfo.TypeID)), nullable: true, @@ -395,6 +418,10 @@ func skipMap(ctx *ReadContext, fieldDef FieldDef) { if ctx.HasError() { return } + if chunkSize == 0 || uint32(chunkSize) > length-lenCounter { + ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, length-lenCounter)) + return + } keyDeclared := (header & KEY_DECL_TYPE) != 0 valueDeclared := (header & VALUE_DECL_TYPE) != 0 @@ -406,7 +433,10 @@ func skipMap(ctx *ReadContext, fieldDef FieldDef) { if ctx.HasError() { return } - keyTypeInfo = ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, bufErr) + keyTypeInfo = readKnownTypeInfoForSkip(ctx, typeID) + if ctx.HasError() { + return + } keyDef = FieldDef{ typeSpec: 
NewSimpleTypeSpec(TypeId(keyTypeInfo.TypeID)), nullable: true, @@ -420,7 +450,10 @@ func skipMap(ctx *ReadContext, fieldDef FieldDef) { if ctx.HasError() { return } - valueTypeInfo = ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, bufErr) + valueTypeInfo = readKnownTypeInfoForSkip(ctx, typeID) + if ctx.HasError() { + return + } valueDef = FieldDef{ typeSpec: NewSimpleTypeSpec(TypeId(valueTypeInfo.TypeID)), nullable: true, @@ -527,25 +560,32 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo internalID := TypeId(typeIDNum) // Handle struct-like types if internalID == COMPATIBLE_STRUCT || internalID == STRUCT || - internalID == NAMED_STRUCT || internalID == NAMED_COMPATIBLE_STRUCT || - internalID == UNKNOWN { + internalID == NAMED_STRUCT || internalID == NAMED_COMPATIBLE_STRUCT { // If type_info is provided (from SkipAnyValue), use skipStruct directly if typeInfo != nil { skipStruct(ctx, typeInfo) return } // Otherwise we need to read type info - ti := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeIDNum, err) + ti := readKnownTypeInfoForSkip(ctx, typeIDNum) + if ctx.HasError() { + return + } skipStruct(ctx, ti) return } - if internalID == ENUM { + if internalID == ENUM || internalID == NAMED_ENUM { // Enum values are encoded as ordinal only (VarUint32Small7) for xlang. 
- _ = ctx.buffer.ReadUint8(err) + _ = ctx.buffer.ReadVarUint32Small7(err) return } if internalID == EXT || internalID == NAMED_EXT || internalID == TYPED_UNION || internalID == NAMED_UNION { - typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeIDNum, err) + if typeInfo == nil { + typeInfo = readKnownTypeInfoForSkip(ctx, typeIDNum) + if ctx.HasError() { + return + } + } if typeInfo != nil && typeInfo.Serializer != nil { // Use the serializer to read and discard the value var dummy any @@ -569,11 +609,15 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo case INT16: _ = ctx.buffer.ReadInt16(err) case INT32: - _ = ctx.buffer.ReadUint8(err) + _ = ctx.buffer.ReadInt32(err) case VARINT32: - _ = ctx.buffer.ReadUint8(err) - case INT64, VARINT64, TAGGED_INT64: + _ = ctx.buffer.ReadVarint32(err) + case INT64: + _ = ctx.buffer.ReadInt64(err) + case VARINT64: _ = ctx.buffer.ReadVarint64(err) + case TAGGED_INT64: + _ = ctx.buffer.ReadTaggedInt64(err) // Floating point types case BFLOAT16, FLOAT16: @@ -676,7 +720,7 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo // Enum types case ENUM: - _ = ctx.buffer.ReadVarUint32(err) + _ = ctx.buffer.ReadVarUint32Small7(err) // Unsigned integer types case UINT8: @@ -692,12 +736,7 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo case VAR_UINT64: _ = ctx.buffer.ReadVarUint64(err) case TAGGED_UINT64: - firstInt32 := ctx.buffer.ReadInt32(err) - if (firstInt32 & 1) != 0 { - // 9-byte encoding - _ = ctx.buffer.ReadUint64(err) - } - // Otherwise it's 4-byte encoding, already read + _ = ctx.buffer.ReadTaggedUint64(err) // Unknown (polymorphic) type - read type info and skip dynamically case UNKNOWN: diff --git a/go/fory/skip_test.go b/go/fory/skip_test.go new file mode 100644 index 0000000000..c60a40e355 --- /dev/null +++ b/go/fory/skip_test.go @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// 
or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSkipEnumConsumesSmall7Ordinal(t *testing.T) { + f := New(WithXlang(true)) + buf := NewByteBuffer(nil) + buf.WriteVarUint32Small7(128) + buf.WriteByte(0x7f) + + f.readCtx.SetData(buf.Bytes()) + skipValue( + f.readCtx, + FieldDef{typeSpec: NewSimpleTypeSpec(ENUM), nullable: true}, + false, + false, + nil, + ) + require.NoError(t, f.readCtx.CheckError()) + require.Equal(t, 2, f.readCtx.Buffer().ReaderIndex()) + require.Equal(t, byte(0x7f), f.readCtx.Buffer().ReadByte(f.readCtx.Err())) +} + +func TestSkipPrimitiveConsumesExactEncoding(t *testing.T) { + tests := []struct { + name string + typeID TypeId + write func(*ByteBuffer) + }{ + { + name: "int32", + typeID: INT32, + write: func(buf *ByteBuffer) { buf.WriteInt32(0x01020304) }, + }, + { + name: "varint32", + typeID: VARINT32, + write: func(buf *ByteBuffer) { buf.WriteVarint32(300) }, + }, + { + name: "int64", + typeID: INT64, + write: func(buf *ByteBuffer) { buf.WriteInt64(0x0102030405060708) }, + }, + { + name: "varint64", + typeID: VARINT64, + write: func(buf *ByteBuffer) { buf.WriteVarint64(1 << 35) }, + }, + { + name: "tagged_int64_small", + typeID: TAGGED_INT64, + write: func(buf 
*ByteBuffer) { buf.WriteTaggedInt64(1073741823) }, + }, + { + name: "tagged_int64_large", + typeID: TAGGED_INT64, + write: func(buf *ByteBuffer) { buf.WriteTaggedInt64(1 << 40) }, + }, + { + name: "tagged_uint64_small", + typeID: TAGGED_UINT64, + write: func(buf *ByteBuffer) { buf.WriteTaggedUint64(0x7fffffff) }, + }, + { + name: "tagged_uint64_large", + typeID: TAGGED_UINT64, + write: func(buf *ByteBuffer) { buf.WriteTaggedUint64(1 << 40) }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + f := New(WithXlang(true)) + buf := NewByteBuffer(nil) + tc.write(buf) + wantIndex := buf.WriterIndex() + buf.WriteByte(0x7f) + + f.readCtx.SetData(buf.Bytes()) + skipValue( + f.readCtx, + FieldDef{typeSpec: NewSimpleTypeSpec(tc.typeID), nullable: true}, + false, + false, + nil, + ) + require.NoError(t, f.readCtx.CheckError()) + require.Equal(t, wantIndex, f.readCtx.Buffer().ReaderIndex()) + require.Equal(t, byte(0x7f), f.readCtx.Buffer().ReadByte(f.readCtx.Err())) + }) + } +} + +func TestSkipMapRejectsInvalidChunkSize(t *testing.T) { + f := New(WithXlang(true)) + buf := NewByteBuffer(nil) + buf.WriteLength(1) + buf.WriteByte(KEY_DECL_TYPE | VALUE_DECL_TYPE) + buf.WriteByte(2) + + f.readCtx.SetData(buf.Bytes()) + skipMap( + f.readCtx, + FieldDef{ + typeSpec: NewMapTypeSpec(MAP, NewSimpleTypeSpec(INT32), NewSimpleTypeSpec(INT32)), + nullable: true, + }, + ) + require.Error(t, f.readCtx.CheckError()) +} diff --git a/go/fory/stream.go b/go/fory/stream.go index 420a6d4825..00de53abe2 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -112,17 +112,12 @@ func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = is.buffer - isNull := readHeader(f.readCtx) + readHeader(f.readCtx) if f.readCtx.HasError() { f.readCtx.buffer = origBuffer return f.readCtx.TakeError() } - if isNull { - f.readCtx.buffer = origBuffer - return nil - } - target := reflect.ValueOf(v).Elem() 
f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { @@ -145,15 +140,11 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { // Always reset to enforce stateless semantics. f.readCtx.buffer.ResetWithReader(r, 0) - isNull := readHeader(f.readCtx) + readHeader(f.readCtx) if f.readCtx.HasError() { return f.readCtx.TakeError() } - if isNull { - return nil - } - target := reflect.ValueOf(v).Elem() f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { diff --git a/go/fory/string.go b/go/fory/string.go index 7c7a41bba3..e20ac25af5 100644 --- a/go/fory/string.go +++ b/go/fory/string.go @@ -152,8 +152,8 @@ func (s stringSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(STRING) { + return } if ctx.HasError() { return @@ -194,8 +194,8 @@ func (s ptrToStringSerializer) Read(ctx *ReadContext, refMode RefMode, readType return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(STRING) { + return } if ctx.HasError() { return diff --git a/go/fory/struct_test.go b/go/fory/struct_test.go index 9dc03d1611..eb9287a0c2 100644 --- a/go/fory/struct_test.go +++ b/go/fory/struct_test.go @@ -604,14 +604,12 @@ func TestSkipAnyValueReadsSharedTypeMeta(t *testing.T) { f.resetReadState() f.readCtx.SetData(buf.Bytes()) - isNull := readHeader(f.readCtx) - require.False(t, isNull) + readHeader(f.readCtx) SkipAnyValue(f.readCtx, true) require.NoError(t, f.readCtx.CheckError()) f.resetReadState() - isNull = readHeader(f.readCtx) - require.False(t, isNull) + readHeader(f.readCtx) var out any f.readCtx.ReadValue(reflect.ValueOf(&out).Elem(), RefModeTracking, true) @@ -622,6 +620,17 @@ func TestSkipAnyValueReadsSharedTypeMeta(t *testing.T) { require.Equal(t, "ok", result.Name) } +func TestReadHeaderRejectsOutOfBandWithoutBuffers(t *testing.T) { + f := New(WithXlang(true)) + 
f.readCtx.SetData([]byte{XLangFlag | OutOfBandFlag}) + + readHeader(f.readCtx) + + err := f.readCtx.TakeError() + require.Error(t, err) + require.Contains(t, err.Error(), "out-of-band buffers") +} + func TestFloat16StructField(t *testing.T) { type StructWithFloat16 struct { F16 float16.Float16 diff --git a/go/fory/time.go b/go/fory/time.go index 676efe00e5..1a389ccbbf 100644 --- a/go/fory/time.go +++ b/go/fory/time.go @@ -159,8 +159,8 @@ func (s dateSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(DATE) { + return } if ctx.HasError() { return @@ -268,8 +268,8 @@ func (s durationSerializer) Read(ctx *ReadContext, refMode RefMode, readType boo return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(DURATION) { + return } if ctx.HasError() { return @@ -311,8 +311,8 @@ func (s timeSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h return } } - if readType { - _ = ctx.buffer.ReadUint8(err) + if readType && !ctx.readExpectedTypeID(TIMESTAMP) { + return } if ctx.HasError() { return diff --git a/go/fory/type_def.go b/go/fory/type_def.go index 81bfa76e87..70ea7e1a03 100644 --- a/go/fory/type_def.go +++ b/go/fory/type_def.go @@ -18,10 +18,7 @@ package fory import ( - "bytes" - "compress/zlib" "fmt" - "io" "reflect" "strings" @@ -29,17 +26,17 @@ import ( ) const ( - META_SIZE_MASK = 0xFF - COMPRESS_META_FLAG = 0b1 << 9 - HAS_FIELDS_META_FLAG = 0b1 << 8 - NUM_HASH_BITS = 50 + META_SIZE_MASK = 0xFF + COMPRESS_META_FLAG = 0b1 << 8 + RESERVED_META_BITS = 0b111 << 9 + NUM_HASH_BITS = 52 ) /* TypeDef represents a transportable value object containing type information and field definitions. 
typeDef are layout as following: - - first 8 bytes: global header (50 bits hash + 1 bit compress flag + write fields meta + 8 bits meta size) - - next 1 byte: meta header (2 bits reserved + 1 bit register by name flag + 5 bits num fields) + - first 8 bytes: global header (52 bits hash + 1 bit compress flag + 8 bits meta size) + - next 1 byte: kind header - next variable bytes: type id (varint) or ns name + type name - next variable bytes: field definitions (see below) */ @@ -218,21 +215,21 @@ func (td *TypeDef) buildTypeInfoWithResolver(resolver *TypeResolver) (TypeInfo, type_ := td.type_ var serializer Serializer - // For extension types, use the registered serializer if available - if type_ != nil && resolver != nil { - if existingSerializer, ok := resolver.typeToSerializers[type_]; ok { - // Only use registered serializer for extension types (not struct types) - if _, isExt := existingSerializer.(*extensionSerializerAdapter); isExt { - serializer = existingSerializer - } else if ptrSer, isPtrSer := existingSerializer.(*ptrToValueSerializer); isPtrSer { - if _, isExtInner := ptrSer.valueSerializer.(*extensionSerializerAdapter); isExtInner { - serializer = existingSerializer - } + if !isStructTypeId(TypeId(td.typeId)) { + if type_ != nil && resolver != nil { + var err error + serializer, err = resolver.getSerializerByType(type_, false) + if err != nil { + return TypeInfo{}, err } } - } - // If no extension serializer, create struct serializer - if serializer == nil { + if serializer == nil && resolver != nil { + serializer = resolver.getSerializerByTypeID(td.typeId) + } + if serializer == nil { + return TypeInfo{}, fmt.Errorf("no serializer registered for TypeDef kind %d", td.typeId) + } + } else { if type_ == nil { // Unknown struct type - use skipStructSerializer to skip data serializer = &skipStructSerializer{ @@ -304,7 +301,14 @@ func readPkgName(buffer *ByteBuffer, namespaceDecoder *meta.Decoder, err *Error) encodingFlags := header & 0b11 // 2 bits for 
encoding size := header >> 2 // 6 bits for size if size == BIG_NAME_THRESHOLD { - size = int(buffer.ReadVarUint32Small7(err)) + BIG_NAME_THRESHOLD + extra := buffer.ReadVarUint32Small7(err) + if err.HasError() { + return "", err.TakeError() + } + if uint64(extra) > uint64(MaxInt-BIG_NAME_THRESHOLD) { + return "", fmt.Errorf("invalid TypeDef namespace length") + } + size = int(extra) + BIG_NAME_THRESHOLD } var encoding meta.Encoding @@ -319,6 +323,9 @@ func readPkgName(buffer *ByteBuffer, namespaceDecoder *meta.Decoder, err *Error) return "", fmt.Errorf("invalid package encoding flags: %d", encodingFlags) } + if size > buffer.remaining() { + return "", fmt.Errorf("TypeDef namespace length %d exceeds remaining metadata %d", size, buffer.remaining()) + } data := make([]byte, size) if _, err := buffer.Read(data); err != nil { return "", err @@ -335,7 +342,14 @@ func readTypeName(buffer *ByteBuffer, typeNameDecoder *meta.Decoder, err *Error) encodingFlags := header & 0b11 // 2 bits for encoding size := header >> 2 // 6 bits for size if size == BIG_NAME_THRESHOLD { - size = int(buffer.ReadVarUint32Small7(err)) + BIG_NAME_THRESHOLD + extra := buffer.ReadVarUint32Small7(err) + if err.HasError() { + return "", err.TakeError() + } + if uint64(extra) > uint64(MaxInt-BIG_NAME_THRESHOLD) { + return "", fmt.Errorf("invalid TypeDef typename length") + } + size = int(extra) + BIG_NAME_THRESHOLD } var encoding meta.Encoding @@ -352,6 +366,9 @@ func readTypeName(buffer *ByteBuffer, typeNameDecoder *meta.Decoder, err *Error) return "", fmt.Errorf("invalid typename encoding flags: %d", encodingFlags) } + if size > buffer.remaining() { + return "", fmt.Errorf("TypeDef typename length %d exceeds remaining metadata %d", size, buffer.remaining()) + } data := make([]byte, size) if _, err := buffer.Read(data); err != nil { return "", err @@ -362,16 +379,18 @@ func readTypeName(buffer *ByteBuffer, typeNameDecoder *meta.Decoder, err *Error) // buildTypeDef constructs a TypeDef from a value 
func buildTypeDef(fory *Fory, value reflect.Value) (*TypeDef, error) { - fieldDefs, err := buildFieldDefs(fory, value) - if err != nil { - return nil, fmt.Errorf("failed to extract field infos: %w", err) - } - infoPtr, err := fory.typeResolver.getTypeInfo(value, true) if err != nil { return nil, fmt.Errorf("failed to get type info for value %v: %w", value, err) } typeId := uint32(infoPtr.TypeID) + var fieldDefs []FieldDef + if isStructTypeId(TypeId(typeId)) { + fieldDefs, err = buildFieldDefs(fory, value) + if err != nil { + return nil, fmt.Errorf("failed to extract field infos: %w", err) + } + } registerByName := IsNamespacedType(TypeId(typeId)) typeDef := NewTypeDef(typeId, infoPtr.UserTypeID, infoPtr.PkgPathBytes, infoPtr.NameBytes, registerByName, false, fieldDefs) @@ -616,7 +635,9 @@ func newTypeSpecForTypeID(typeID TypeId) (*TypeSpec, error) { const ( SmallNumFieldsThreshold = 31 - REGISTER_BY_NAME_FLAG = 0b1 << 5 + RegisterByNameFlag = 0b0010_0000 + CompatibleTypeDefFlag = 0b0100_0000 + StructTypeDefFlag = 0b1000_0000 FieldNameSizeThreshold = 15 ) @@ -648,8 +669,8 @@ func getFieldNameEncodingIndex(encoding meta.Encoding) int { /* encodingTypeDef encodes a TypeDef into binary format according to the specification typeDef are layout as following: -- first 8 bytes: global header (50 bits hash + 1 bit compress flag + write fields meta + 8 bits meta size) -- next 1 byte: meta header (2 bits reserved + 1 bit register by name flag + 5 bits num fields) +- first 8 bytes: global header (52 bits hash + 1 bit compress flag + 8 bits meta size) +- next 1 byte: kind header - next variable bytes: type id (varint) or ns name + type name - next variable bytes: field defs (see below) */ @@ -758,20 +779,21 @@ func encodingTypeDef(typeResolver *TypeResolver, typeDef *TypeDef) ([]byte, erro return nil, fmt.Errorf("failed to write typename: %w", err) } } else { - buffer.WriteUint8(uint8(typeDef.typeId)) if typeDef.userTypeId == invalidUserTypeID { return nil, fmt.Errorf("missing 
user type ID for typeID %d", typeDef.typeId) } buffer.WriteVarUint32(typeDef.userTypeId) } - if err := writeFieldDefs(typeResolver, buffer, typeDef.fieldDefs); err != nil { - return nil, fmt.Errorf("failed to write fields def: %w", err) + if isStructTypeId(TypeId(typeDef.typeId)) { + if err := writeFieldDefs(typeResolver, buffer, typeDef.fieldDefs); err != nil { + return nil, fmt.Errorf("failed to write fields def: %w", err) + } + } else if len(typeDef.fieldDefs) != 0 { + return nil, fmt.Errorf("non-struct TypeDef %d cannot carry field metadata", typeDef.typeId) } - // Temporary xlang behavior: keep TypeMeta uncompressed. - // Some runtimes still do not support TypeMeta decompression. - result, err := prependGlobalHeader(buffer, false, len(typeDef.fieldDefs) > 0) + result, err := prependGlobalHeader(buffer, false) if err != nil { return nil, fmt.Errorf("failed to write global binary header: %w", err) } @@ -780,16 +802,11 @@ func encodingTypeDef(typeResolver *TypeResolver, typeDef *TypeDef) ([]byte, erro } // prependGlobalHeader writes the 8-byte global header -func prependGlobalHeader(buffer *ByteBuffer, isCompressed bool, hasFieldsMeta bool) (*ByteBuffer, error) { +func prependGlobalHeader(buffer *ByteBuffer, isCompressed bool) (*ByteBuffer, error) { var header uint64 metaSize := buffer.WriterIndex() - hashValue := Murmur3Sum64WithSeed(buffer.GetByteSlice(0, metaSize), 47) - header |= hashValue << (64 - NUM_HASH_BITS) - - if hasFieldsMeta { - header |= HAS_FIELDS_META_FLAG - } + header |= typeDefHeaderHash(buffer.GetByteSlice(0, metaSize)) if isCompressed { header |= COMPRESS_META_FLAG @@ -814,25 +831,82 @@ func prependGlobalHeader(buffer *ByteBuffer, isCompressed bool, hasFieldsMeta bo // writeMetaHeader writes the 1-byte meta header func writeMetaHeader(buffer *ByteBuffer, typeDef *TypeDef) error { - // 2 bits reserved + 1 bit register by name flag + 5 bits num fields offset := buffer.writerIndex if err := buffer.WriteByte(0xFF); err != nil { return err } 
fieldInfos := typeDef.fieldDefs - header := len(fieldInfos) - if header > SmallNumFieldsThreshold { - header = SmallNumFieldsThreshold - buffer.WriteVarUint32(uint32(len(fieldInfos) - SmallNumFieldsThreshold)) - } - if typeDef.registerByName { - header |= REGISTER_BY_NAME_FLAG + typeID := TypeId(typeDef.typeId) + var header int + if isStructTypeId(typeID) { + fieldCount := len(fieldInfos) + inlineFieldCount := fieldCount + if inlineFieldCount > SmallNumFieldsThreshold { + inlineFieldCount = SmallNumFieldsThreshold + } + header = StructTypeDefFlag | inlineFieldCount + if typeID == COMPATIBLE_STRUCT || typeID == NAMED_COMPATIBLE_STRUCT { + header |= CompatibleTypeDefFlag + } + if fieldCount >= SmallNumFieldsThreshold { + buffer.WriteVarUint32(uint32(fieldCount - SmallNumFieldsThreshold)) + } + if typeDef.registerByName { + header |= RegisterByNameFlag + } + } else { + if len(fieldInfos) != 0 { + return fmt.Errorf("non-struct TypeDef %d cannot carry field metadata", typeDef.typeId) + } + kindCode, err := xlangNonStructKindCode(typeID) + if err != nil { + return err + } + header = kindCode } buffer.PutUint8(offset, uint8(header)) return nil } +func xlangNonStructKindCode(typeID TypeId) (int, error) { + switch typeID { + case ENUM: + return 0, nil + case NAMED_ENUM: + return 1, nil + case EXT: + return 2, nil + case NAMED_EXT: + return 3, nil + case TYPED_UNION: + return 4, nil + case NAMED_UNION: + return 5, nil + default: + return 0, fmt.Errorf("unsupported TypeDef kind %d", typeID) + } +} + +func xlangNonStructTypeID(kindCode int) (TypeId, error) { + switch kindCode { + case 0: + return ENUM, nil + case 1: + return NAMED_ENUM, nil + case 2: + return EXT, nil + case 3: + return NAMED_EXT, nil + case 4: + return TYPED_UNION, nil + case 5: + return NAMED_UNION, nil + default: + return UNKNOWN, fmt.Errorf("unsupported TypeDef kind code %d", kindCode) + } +} + // writeFieldDefs writes field definitions according to the specification // field def layout as following: // - 
first 1 byte: header (2 bits field name encoding + 4 bits size + nullability flag + ref tracking flag) @@ -913,8 +987,8 @@ func writeFieldDef(typeResolver *TypeResolver, buffer *ByteBuffer, field FieldDe /* decodeTypeDef decodes a TypeDef from the buffer typeDef are layout as following: - - first 8 bytes: global header (50 bits hash + 1 bit compress flag + write fields meta + 8 bits meta size) - - next 1 byte: meta header (2 bits reserved + 1 bit register by name flag + 5 bits num fields) + - first 8 bytes: global header (52 bits hash + 1 bit compress flag + 8 bits meta size) + - next 1 byte: kind header - next variable bytes: type id (varint) or ns name + type name - next variable bytes: field definitions (see below) */ @@ -922,49 +996,84 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro // ReadData 8-byte global header var bufErr Error globalHeader := header - hasFieldsMeta := (globalHeader & HAS_FIELDS_META_FLAG) != 0 + if (globalHeader & RESERVED_META_BITS) != 0 { + return nil, fmt.Errorf("invalid TypeDef global header") + } isCompressed := (globalHeader & COMPRESS_META_FLAG) != 0 + if isCompressed { + return nil, fmt.Errorf("compressed xlang TypeDef is not supported") + } metaSizeBits := int(globalHeader & META_SIZE_MASK) metaSize := metaSizeBits extraMetaSize := 0 if metaSizeBits == META_SIZE_MASK { - extraMetaSize = int(buffer.ReadVarUint32(&bufErr)) + extra := buffer.ReadVarUint32(&bufErr) + if bufErr.HasError() { + return nil, bufErr.TakeError() + } + if uint64(extra) > uint64(MaxInt-metaSize) { + return nil, fmt.Errorf("invalid TypeDef metadata size") + } + extraMetaSize = int(extra) metaSize += extraMetaSize } + if metaSize > fory.config.MaxBinarySize { + return nil, MaxBinarySizeExceededError(metaSize, fory.config.MaxBinarySize) + } // Store the encoded bytes for the TypeDef (including meta header and metadata) encodedMeta := buffer.ReadBinary(metaSize, &bufErr) if bufErr.HasError() { return nil, bufErr.TakeError() } - 
decodedMeta := encodedMeta - if isCompressed { - decodedMetaBytes, err := decompressMeta(encodedMeta) - if err != nil { - return nil, err - } - decodedMeta = decodedMetaBytes - } - metaBuffer := NewByteBuffer(decodedMeta) + metaBuffer := NewByteBuffer(encodedMeta) var metaErr Error // ReadData 1-byte meta header metaHeaderByte := metaBuffer.ReadByte(&metaErr) - // Extract field count from lower 5 bits - fieldCount := int(metaHeaderByte & SmallNumFieldsThreshold) - if fieldCount == SmallNumFieldsThreshold { - fieldCount += int(metaBuffer.ReadVarUint32(&metaErr)) - } - if fieldCount > fory.config.MaxTypeFields || fieldCount > metaBuffer.remaining() { - return nil, fmt.Errorf("field count exceeds maximum allowed limit or available buffer size") - } - registeredByName := (metaHeaderByte & REGISTER_BY_NAME_FLAG) != 0 + isStruct := (metaHeaderByte & StructTypeDefFlag) != 0 + fieldCount := 0 + registeredByName := false // ReadData name or type ID according to the registerByName flag var typeId uint32 userTypeId := invalidUserTypeID var nsBytes, nameBytes *MetaStringBytes var type_ reflect.Type + if isStruct { + registeredByName = (metaHeaderByte & RegisterByNameFlag) != 0 + fieldCount = int(metaHeaderByte & SmallNumFieldsThreshold) + if fieldCount == SmallNumFieldsThreshold { + fieldCount += int(metaBuffer.ReadVarUint32(&metaErr)) + } + if metaErr.HasError() { + return nil, metaErr.TakeError() + } + if fieldCount > fory.config.MaxTypeFields || fieldCount > metaBuffer.remaining() { + return nil, fmt.Errorf("field count exceeds maximum allowed limit or available buffer size") + } + if registeredByName { + if (metaHeaderByte & CompatibleTypeDefFlag) != 0 { + typeId = uint32(NAMED_COMPATIBLE_STRUCT) + } else { + typeId = uint32(NAMED_STRUCT) + } + } else if (metaHeaderByte & CompatibleTypeDefFlag) != 0 { + typeId = uint32(COMPATIBLE_STRUCT) + } else { + typeId = uint32(STRUCT) + } + } else { + if (metaHeaderByte & 0b0111_0000) != 0 { + return nil, fmt.Errorf("invalid TypeDef 
kind header") + } + kindType, err := xlangNonStructTypeID(int(metaHeaderByte & 0b1111)) + if err != nil { + return nil, err + } + typeId = uint32(kindType) + registeredByName = IsNamespacedType(kindType) + } if registeredByName { // ReadData namespace and type name for namespaced types // NOTE: TypeDefs use simple name format, not meta string format with dynamic IDs @@ -974,7 +1083,14 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro nsEncodingFlags := nsHeader & 0b11 // 2 bits for encoding nsSize := nsHeader >> 2 // 6 bits for size if nsSize == BIG_NAME_THRESHOLD { - nsSize = int(metaBuffer.ReadVarUint32Small7(&metaErr)) + BIG_NAME_THRESHOLD + extra := metaBuffer.ReadVarUint32Small7(&metaErr) + if metaErr.HasError() { + return nil, metaErr.TakeError() + } + if uint64(extra) > uint64(MaxInt-BIG_NAME_THRESHOLD) { + return nil, fmt.Errorf("invalid TypeDef namespace length") + } + nsSize = int(extra) + BIG_NAME_THRESHOLD } // Java pkg encoding: 0=UTF_8, 1=ALL_TO_LOWER_SPECIAL, 2=LOWER_UPPER_DIGIT_SPECIAL @@ -989,6 +1105,9 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro default: return nil, fmt.Errorf("invalid package encoding flags: %d", nsEncodingFlags) } + if nsSize > metaBuffer.remaining() { + return nil, fmt.Errorf("TypeDef namespace length %d exceeds remaining metadata %d", nsSize, metaBuffer.remaining()) + } nsData := make([]byte, nsSize) if _, err := metaBuffer.Read(nsData); err != nil { return nil, fmt.Errorf("failed to read namespace data: %w", err) @@ -1000,7 +1119,14 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro typeEncodingFlags := typeHeader & 0b11 // 2 bits for encoding typeSize := typeHeader >> 2 // 6 bits for size if typeSize == BIG_NAME_THRESHOLD { - typeSize = int(metaBuffer.ReadVarUint32Small7(&metaErr)) + BIG_NAME_THRESHOLD + extra := metaBuffer.ReadVarUint32Small7(&metaErr) + if metaErr.HasError() { + return nil, metaErr.TakeError() + } + if 
uint64(extra) > uint64(MaxInt-BIG_NAME_THRESHOLD) { + return nil, fmt.Errorf("invalid TypeDef typename length") + } + typeSize = int(extra) + BIG_NAME_THRESHOLD } // Java typename encoding: 0=UTF_8, 1=ALL_TO_LOWER_SPECIAL, 2=LOWER_UPPER_DIGIT_SPECIAL, 3=FIRST_TO_LOWER_SPECIAL @@ -1017,6 +1143,9 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro default: return nil, fmt.Errorf("invalid typename encoding flags: %d", typeEncodingFlags) } + if typeSize > metaBuffer.remaining() { + return nil, fmt.Errorf("TypeDef typename length %d exceeds remaining metadata %d", typeSize, metaBuffer.remaining()) + } typeData := make([]byte, typeSize) if _, err := metaBuffer.Read(typeData); err != nil { return nil, fmt.Errorf("failed to read typename data: %w", err) @@ -1059,18 +1188,19 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro if type_.Kind() == reflect.Ptr { type_ = type_.Elem() } - typeId = uint32(info.TypeID) + if uint32(info.TypeID) != typeId { + return nil, fmt.Errorf("TypeDef kind does not match registered type metadata") + } userTypeId = info.UserTypeID } else { - // Type not registered - use NAMED_STRUCT as default typeId - // The type_ will remain nil and will be set from field definitions later - typeId = uint32(NAMED_STRUCT) type_ = nil } } else { - typeId = uint32(metaBuffer.ReadUint8(&metaErr)) userTypeId = metaBuffer.ReadVarUint32(&metaErr) if info, exists := fory.typeResolver.userTypeIdToTypeInfo[userTypeId]; exists { + if uint32(info.TypeID) != typeId { + return nil, fmt.Errorf("TypeDef kind does not match registered type metadata") + } type_ = info.Type } else if info, exists := fory.typeResolver.typeIDToTypeInfo[typeId]; exists { type_ = info.Type @@ -1083,14 +1213,24 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro // ReadData fields information fieldInfos := make([]FieldDef, fieldCount) - if hasFieldsMeta { - for i := 0; i < fieldCount; i++ { - fieldInfo, 
err := readFieldDef(fory.typeResolver, metaBuffer) - if err != nil { - return nil, fmt.Errorf("failed to read field def %d: %w", i, err) - } - fieldInfos[i] = fieldInfo + for i := 0; i < fieldCount; i++ { + fieldInfo, err := readFieldDef(fory.typeResolver, metaBuffer) + if err != nil { + return nil, fmt.Errorf("failed to read field def %d: %w", i, err) } + fieldInfos[i] = fieldInfo + } + if !isStruct && len(fieldInfos) != 0 { + return nil, fmt.Errorf("non-struct TypeDef cannot carry field metadata") + } + if metaErr.HasError() { + return nil, metaErr.TakeError() + } + if remaining := metaBuffer.remaining(); remaining != 0 { + return nil, fmt.Errorf("TypeDef metadata body has %d trailing bytes", remaining) + } + if err := validateParsedTypeDefHash(globalHeader, metaSizeBits, extraMetaSize, encodedMeta); err != nil { + return nil, err } encoded := buildTypeDefEncoded(globalHeader, metaSizeBits, extraMetaSize, encodedMeta) @@ -1130,17 +1270,32 @@ func buildTypeDefEncoded(header int64, metaSizeBits, extraMetaSize int, metaByte return buffer.Bytes() } -func decompressMeta(encoded []byte) ([]byte, error) { - reader, err := zlib.NewReader(bytes.NewReader(encoded)) - if err != nil { - return nil, fmt.Errorf("failed to create meta decompressor: %w", err) +func typeDefHeaderHash(data []byte) uint64 { + hash := int64(Murmur3Sum64WithSeed(data, 47) << (64 - NUM_HASH_BITS)) + if hash < 0 { + hash = -hash } - defer reader.Close() - decoded, err := io.ReadAll(reader) - if err != nil { - return nil, fmt.Errorf("failed to decompress meta: %w", err) + hashMask := ^uint64(0) + hashMask <<= uint(64 - NUM_HASH_BITS) + return uint64(hash) & hashMask +} + +func validateParsedTypeDefHash(header int64, metaSizeBits, extraMetaSize int, encoded []byte) error { + size := metaSizeBits + if size == META_SIZE_MASK { + size += extraMetaSize + } + if len(encoded) != size { + return fmt.Errorf("invalid TypeDef encoded size") } - return decoded, nil + hashMask := ^uint64(0) + hashMask <<= uint(64 - 
NUM_HASH_BITS) + expectedHeaderHash := typeDefHeaderHash(encoded) + actualHeaderHash := uint64(header) & hashMask + if expectedHeaderHash != actualHeaderHash { + return fmt.Errorf("invalid TypeDef metadata hash") + } + return nil } /* @@ -1174,7 +1329,14 @@ func readFieldDef(typeResolver *TypeResolver, buffer *ByteBuffer) (FieldDef, err // Read tag ID tagID := sizeBits if sizeBits == 0x0F { - tagID = FieldNameSizeThreshold + int(buffer.ReadVarUint32(&bufErr)) + extra := buffer.ReadVarUint32(&bufErr) + if bufErr.HasError() { + return FieldDef{}, bufErr.TakeError() + } + if uint64(extra) > uint64(MaxInt-FieldNameSizeThreshold) { + return FieldDef{}, fmt.Errorf("invalid TypeDef field tag ID") + } + tagID = FieldNameSizeThreshold + int(extra) } // Read field type @@ -1182,6 +1344,9 @@ func readFieldDef(typeResolver *TypeResolver, buffer *ByteBuffer) (FieldDef, err if err != nil { return FieldDef{}, err } + if bufErr.HasError() { + return FieldDef{}, bufErr.TakeError() + } return FieldDef{ name: "", // No field name when using tag ID @@ -1197,7 +1362,14 @@ func readFieldDef(typeResolver *TypeResolver, buffer *ByteBuffer) (FieldDef, err nameEncoding := fieldNameEncodings[nameEncodingFlag] nameLen := sizeBits if nameLen == 0x0F { - nameLen = FieldNameSizeThreshold + int(buffer.ReadVarUint32(&bufErr)) + extra := buffer.ReadVarUint32(&bufErr) + if bufErr.HasError() { + return FieldDef{}, bufErr.TakeError() + } + if uint64(extra) > uint64(MaxInt-FieldNameSizeThreshold) { + return FieldDef{}, fmt.Errorf("invalid TypeDef field name length") + } + nameLen = FieldNameSizeThreshold + int(extra) } else { nameLen++ // Adjust for 1-based encoding } @@ -1207,9 +1379,18 @@ func readFieldDef(typeResolver *TypeResolver, buffer *ByteBuffer) (FieldDef, err if err != nil { return FieldDef{}, err } + if bufErr.HasError() { + return FieldDef{}, bufErr.TakeError() + } // Read field name based on encoding + if nameLen > buffer.remaining() { + return FieldDef{}, fmt.Errorf("TypeDef field name 
length %d exceeds remaining metadata %d", nameLen, buffer.remaining()) + } nameBytes := buffer.ReadBinary(nameLen, &bufErr) + if bufErr.HasError() { + return FieldDef{}, bufErr.TakeError() + } fieldName, err := typeResolver.typeNameDecoder.Decode(nameBytes, nameEncoding) if err != nil { return FieldDef{}, fmt.Errorf("failed to decode field name: %w", err) diff --git a/go/fory/type_def_test.go b/go/fory/type_def_test.go index 5e88bd1744..55a8a4efeb 100644 --- a/go/fory/type_def_test.go +++ b/go/fory/type_def_test.go @@ -18,10 +18,13 @@ package fory import ( + "bytes" + "compress/zlib" "reflect" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // Test structs for encoding/decoding @@ -176,6 +179,35 @@ func checkTypeSpecRecursivelyOrNil(t *testing.T, original, decoded *TypeSpec, pa checkTypeSpecRecursively(t, original, decoded, path, compareRootFlags) } +func typeDefTestBodyWithoutFields() []byte { + buffer := NewByteBuffer(nil) + buffer.WriteByte(StructTypeDefFlag) + buffer.WriteVarUint32(0) + return buffer.Bytes() +} + +func typeDefTestFrame(t *testing.T, body []byte, compressed bool) (*ByteBuffer, int64) { + t.Helper() + bodyBuffer := NewByteBuffer(nil) + bodyBuffer.WriteBinary(body) + frame, err := prependGlobalHeader(bodyBuffer, compressed) + require.NoError(t, err) + readErr := &Error{} + header := frame.ReadInt64(readErr) + require.NoError(t, readErr.CheckError()) + return frame, header +} + +func deflateTypeDefTestBody(t *testing.T, body []byte) []byte { + t.Helper() + var compressed bytes.Buffer + writer := zlib.NewWriter(&compressed) + _, err := writer.Write(body) + require.NoError(t, err) + require.NoError(t, writer.Close()) + return compressed.Bytes() +} + // Item1 struct with mixed nullable (pointer) and non-nullable (primitive) fields type Item1 struct { F1 int32 @@ -301,22 +333,166 @@ func TestTypeDefNullableFields(t *testing.T) { // allocation that would OOM-crash the process. 
func TestTypeDefFieldCountOOMPanic(t *testing.T) { fory := NewFory() - header := int64(HAS_FIELDS_META_FLAG | 8) - // metaHeaderByte value of 31 triggers the extended VarUint32 field-count path. buffer := NewByteBuffer(make([]byte, 0, 8)) - buffer.WriteByte(31) + buffer.WriteByte(StructTypeDefFlag | SmallNumFieldsThreshold) buffer.WriteVarUint32(2000000000) - buffer.WriteUint8(0) - buffer.WriteVarUint32(0) buffer.SetReaderIndex(0) - _, err := decodeTypeDef(fory, buffer, header) + _, err := decodeTypeDef(fory, buffer, int64(buffer.WriterIndex())) if err == nil { t.Fatal("expected error for oversized fieldCount, got nil") } } +func TestTypeDefRejectsReservedGlobalHeaderBits(t *testing.T) { + fory := NewFory() + buffer := NewByteBuffer(nil) + buffer.WriteByte(StructTypeDefFlag) + buffer.WriteVarUint32(0) + buffer.SetReaderIndex(0) + + _, err := decodeTypeDef(fory, buffer, int64(RESERVED_META_BITS|uint64(buffer.WriterIndex()))) + if err == nil { + t.Fatal("expected error for reserved TypeDef global header bits") + } +} + +func TestTypeDefRejectsReservedNonStructKindBits(t *testing.T) { + fory := NewFory() + body := []byte{0b0001_0000} + frame, header := typeDefTestFrame(t, body, false) + + _, err := decodeTypeDef(fory, frame, header) + if err == nil { + t.Fatal("expected error for reserved non-struct TypeDef kind bits") + } +} + +func TestTypeDefRejectsTrailingMetadataBytes(t *testing.T) { + fory := NewFory() + meta := NewByteBuffer(nil) + meta.WriteByte(StructTypeDefFlag) + meta.WriteVarUint32(0) + meta.WriteByte(0xff) + + buffer := NewByteBuffer(nil) + _, writeErr := buffer.Write(meta.Bytes()) + if writeErr != nil { + t.Fatalf("failed to write type def metadata: %v", writeErr) + } + buffer.SetReaderIndex(0) + + _, err := decodeTypeDef(fory, buffer, int64(len(meta.Bytes()))) + if err == nil { + t.Fatal("expected error for trailing TypeDef metadata bytes") + } +} + +func TestTypeDefExtendedFieldCountHeaderDoesNotSetRegisterByName(t *testing.T) { + fields := 
make([]FieldDef, 32) + for i := range fields { + fields[i] = FieldDef{ + typeSpec: NewSimpleTypeSpec(INT32), + tagID: i + 1, + } + } + typeDef := NewTypeDef(uint32(STRUCT), 700, nil, nil, false, false, fields) + buffer := NewByteBuffer(nil) + + require.NoError(t, writeMetaHeader(buffer, typeDef)) + header := buffer.Bytes()[0] + require.Equal(t, byte(StructTypeDefFlag|SmallNumFieldsThreshold), header) + require.Zero(t, header&RegisterByNameFlag) +} + +func TestTypeDefRejectsMetadataHashMismatch(t *testing.T) { + fory := NewFory() + body := typeDefTestBodyWithoutFields() + buffer := NewByteBuffer(nil) + buffer.WriteBinary(body) + buffer.SetReaderIndex(0) + + _, err := decodeTypeDef(fory, buffer, int64(len(body))) + require.Error(t, err) + require.Contains(t, err.Error(), "metadata hash") +} + +func TestTypeDefRejectsEncodedMetadataAboveMaxBinarySize(t *testing.T) { + fory := NewFory(WithMaxBinarySize(1)) + body := typeDefTestBodyWithoutFields() + frame, header := typeDefTestFrame(t, body, false) + + _, err := decodeTypeDef(fory, frame, header) + require.Error(t, err) + require.Contains(t, err.Error(), "max binary size exceeded") +} + +func TestTypeDefRejectsCompressedMetadata(t *testing.T) { + decoded := typeDefTestBodyWithoutFields() + compressed := deflateTypeDefTestBody(t, decoded) + fory := NewFory(WithMaxBinarySize(4096)) + frame, header := typeDefTestFrame(t, compressed, true) + + _, err := decodeTypeDef(fory, frame, header) + require.Error(t, err) + require.Contains(t, err.Error(), "compressed xlang TypeDef") +} + +func TestReadSharedTypeMetaCapsParsedTypeDefCache(t *testing.T) { + fory := NewFory(WithCompatible(true)) + require.NoError(t, fory.RegisterNamedStruct(SimpleStruct{}, "example.SimpleStruct")) + typeDef, err := buildTypeDef(fory, reflect.ValueOf(SimpleStruct{})) + require.NoError(t, err) + require.NotEmpty(t, typeDef.encoded) + + for i := 0; i < maxCachedTypeDefs; i++ { + fory.typeResolver.defIdToTypeDef[int64(i)] = typeDef + } + headerErr := 
&Error{} + header := NewByteBuffer(typeDef.encoded).ReadInt64(headerErr) + require.NoError(t, headerErr.CheckError()) + require.NotContains(t, fory.typeResolver.defIdToTypeDef, header) + + buffer := NewByteBuffer(nil) + buffer.WriteVarUint32(0) + buffer.WriteBinary(typeDef.encoded) + readErr := &Error{} + typeInfo := fory.typeResolver.readSharedTypeMeta(buffer, readErr) + require.NoError(t, readErr.CheckError()) + require.NotNil(t, typeInfo) + require.Len(t, fory.typeResolver.defIdToTypeDef, maxCachedTypeDefs) + require.NotContains(t, fory.typeResolver.defIdToTypeDef, header) +} + +func TestTypeDefRejectsNamespaceLengthBeyondMetadata(t *testing.T) { + fory := NewFory() + meta := NewByteBuffer(nil) + meta.WriteByte(StructTypeDefFlag | RegisterByNameFlag) + meta.WriteByte(byte(BIG_NAME_THRESHOLD << 2)) + meta.WriteVarUint32Small7(100) + frame, header := typeDefTestFrame(t, meta.Bytes(), false) + + _, err := decodeTypeDef(fory, frame, header) + require.Error(t, err) + require.Contains(t, err.Error(), "namespace length") +} + +func TestTypeDefRejectsFieldNameLengthBeyondMetadata(t *testing.T) { + fory := NewFory() + meta := NewByteBuffer(nil) + meta.WriteByte(StructTypeDefFlag | 1) + meta.WriteVarUint32(0) + meta.WriteByte(0x0F << 2) + meta.WriteVarUint32(100) + meta.WriteUint8(uint8(INT32)) + frame, header := typeDefTestFrame(t, meta.Bytes(), false) + + _, err := decodeTypeDef(fory, frame, header) + require.Error(t, err) + require.Contains(t, err.Error(), "field name length") +} + // TestTypeDefNestedRecursionStackOverflowPanic verifies that readFieldTypeWithFlags // rejects a crafted payload with 20 million nested LIST types, returning an error // at depth 64 instead of recursing until a goroutine stack overflow crashes the process. 
diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index 3f0e7ee84c..a42b01ec95 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -55,6 +55,7 @@ const ( maxUserTypeID uint32 = 0xfffffffe invalidUserTypeID uint32 = 0xffffffff internalTypeIDLimit = 0xFF + maxCachedTypeDefs = 8192 ) var ( @@ -1633,6 +1634,8 @@ func (r *TypeResolver) readSharedTypeMeta(buffer *ByteBuffer, err *Error) *TypeI var td *TypeDef if existingTd, exists := r.defIdToTypeDef[id]; exists { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit body-hash validation. skipTypeDef(buffer, id, err) td = existingTd } else { @@ -1640,7 +1643,9 @@ func (r *TypeResolver) readSharedTypeMeta(buffer *ByteBuffer, err *Error) *TypeI if err.HasError() { return nil } - r.defIdToTypeDef[id] = newTd + if len(r.defIdToTypeDef) < maxCachedTypeDefs { + r.defIdToTypeDef[id] = newTd + } td = newTd } diff --git a/go/fory/writer.go b/go/fory/writer.go index 4fccf37112..41ac503000 100644 --- a/go/fory/writer.go +++ b/go/fory/writer.go @@ -38,7 +38,7 @@ type WriteContext struct { depth int maxDepth int typeResolver *TypeResolver // For complex type serialization - refResolver *RefResolver // For reference tracking (legacy) + refResolver *RefResolver // For reference tracking in native-mode paths bufferCallback func(BufferObject) bool // Callback for out-of-band buffers outOfBand bool // Whether out-of-band serialization is enabled err Error // Accumulated error state for deferred checking @@ -108,7 +108,7 @@ func (c *WriteContext) TypeResolver() *TypeResolver { return c.typeResolver } -// RefResolver returns the reference resolver (legacy) +// RefResolver returns the reference resolver. 
func (c *WriteContext) RefResolver() *RefResolver { return c.refResolver } diff --git a/java/fory-core/src/main/java/org/apache/fory/Fory.java b/java/fory-core/src/main/java/org/apache/fory/Fory.java index adf1e8c6d9..56f6addbd0 100644 --- a/java/fory-core/src/main/java/org/apache/fory/Fory.java +++ b/java/fory-core/src/main/java/org/apache/fory/Fory.java @@ -68,7 +68,7 @@ /** * Cross-language header layout: 1-byte bitmap. * - *

Bit 0: null flag, Bit 1: xlang flag, Bit 2: out-of-band flag, Bits 3-7 reserved. + *

Bit 0: xlang flag, Bit 1: out-of-band flag, Bits 2-7 reserved. * *

serialize/deserialize are the root object APIs. Nested serialization and deserialization go * through {@link WriteContext} and {@link ReadContext}. @@ -86,9 +86,9 @@ public final class Fory implements BaseFory { // this flag indicates that the object is a referencable and first write. public static final byte REF_VALUE_FLAG = 0; public static final byte NOT_SUPPORT_XLANG = 0; - private static final byte isNilFlag = 1; - private static final byte isCrossLanguageFlag = 1 << 1; - private static final byte isOutOfBandFlag = 1 << 2; + private static final byte isCrossLanguageFlag = 1; + private static final byte isOutOfBandFlag = 1 << 1; + private static final byte reservedBitmapFlags = (byte) ~0b11; private final Config config; private final TypeResolver typeResolver; @@ -98,6 +98,7 @@ public final class Fory implements BaseFory { private final WriteContext writeContext; private final ReadContext readContext; private final CopyContext copyContext; + private final byte headerBitmap; private MemoryBuffer buffer; public Fory(ForyBuilder builder, ClassLoader classLoader) { @@ -119,6 +120,7 @@ public Fory(ForyBuilder builder, ClassLoader classLoader, SharedRegistry sharedR this.sharedRegistry = sharedRegistry; this.classLoader = classLoader; config = new Config(builder); + headerBitmap = config.isXlang() ? 
isCrossLanguageFlag : 0; RefWriter refWriter; RefReader refReader; if (config.trackingRef()) { @@ -293,15 +295,7 @@ public MemoryBuffer serialize(MemoryBuffer buffer, Object obj, BufferCallback ca ensureRegistrationFinished(); writeContext.prepare(buffer, callback); try { - byte bitmap = 0; - if (config.isXlang()) { - bitmap |= isCrossLanguageFlag; - } - if (obj == null) { - bitmap |= isNilFlag; - buffer.writeByte(bitmap); - return buffer; - } + byte bitmap = headerBitmap; if (callback != null) { bitmap |= isOutOfBandFlag; } @@ -379,12 +373,9 @@ public T deserialize(byte[] bytes, Class type) { public T deserialize(MemoryBuffer buffer, Class type) { ensureRegistrationFinished(); byte bitmap = buffer.readByte(); - if ((bitmap & isNilFlag) == isNilFlag) { - return null; + if (bitmap != headerBitmap) { + checkHeaderBitmapWithoutOutOfBand(bitmap); } - boolean peerOutOfBandEnabled = (bitmap & isOutOfBandFlag) == isOutOfBandFlag; - assert !peerOutOfBandEnabled : "Out of band buffers not passed in when deserializing"; - checkXlangBitmap(bitmap); readContext.prepare(buffer, null, false); try { try { @@ -449,11 +440,10 @@ public Object deserialize(MemoryBuffer buffer) { public Object deserialize(MemoryBuffer buffer, Iterable outOfBandBuffers) { ensureRegistrationFinished(); byte bitmap = buffer.readByte(); - if ((bitmap & isNilFlag) == isNilFlag) { - return null; + boolean peerOutOfBandEnabled = false; + if (bitmap != headerBitmap) { + peerOutOfBandEnabled = checkHeaderBitmap(bitmap); } - checkXlangBitmap(bitmap); - boolean peerOutOfBandEnabled = (bitmap & isOutOfBandFlag) == isOutOfBandFlag; if (peerOutOfBandEnabled) { Preconditions.checkNotNull( outOfBandBuffers, @@ -530,13 +520,24 @@ private T deserializeByType(MemoryBuffer buffer, Class type) { } } - private void checkXlangBitmap(byte bitmap) { + private void checkHeaderBitmapWithoutOutOfBand(byte bitmap) { + if (checkHeaderBitmap(bitmap)) { + throw new IllegalArgumentException("Out of band buffers not passed in when 
deserializing"); + } + } + + private boolean checkHeaderBitmap(byte bitmap) { + Preconditions.checkArgument( + (bitmap & reservedBitmapFlags) == 0, + "Serialized payload uses reserved header bitmap flags 0x%s", + Integer.toHexString(Byte.toUnsignedInt((byte) (bitmap & reservedBitmapFlags)))); boolean payloadCrossLanguage = (bitmap & isCrossLanguageFlag) == isCrossLanguageFlag; Preconditions.checkArgument( payloadCrossLanguage == config.isXlang(), "Serialized payload xlang flag %s does not match this Fory mode %s", payloadCrossLanguage, config.isXlang()); + return (bitmap & isOutOfBandFlag) == isOutOfBandFlag; } @Override diff --git a/java/fory-core/src/main/java/org/apache/fory/collection/LongLongByteMap.java b/java/fory-core/src/main/java/org/apache/fory/collection/LongLongByteMap.java index 3df143179b..26bc53190b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/collection/LongLongByteMap.java +++ b/java/fory-core/src/main/java/org/apache/fory/collection/LongLongByteMap.java @@ -19,6 +19,7 @@ package org.apache.fory.collection; +import java.util.Arrays; import org.apache.fory.annotation.Internal; import org.apache.fory.util.Preconditions; @@ -129,6 +130,25 @@ public V get(long k1, long k2, byte k3) { } } + public void clear() { + if (size == 0) { + return; + } + size = 0; + Arrays.fill(keyTable, null); + ObjectArray.clearObjectArray(valueTable, 0, valueTable.length); + } + + public void clear(int maximumCapacity) { + int tableSize = ForyObjectMap.tableSize(maximumCapacity, loadFactor); + if (keyTable.length <= tableSize) { + clear(); + return; + } + size = 0; + resize(tableSize); + } + private void resize(int newSize) { int oldCapacity = keyTable.length; threshold = (int) (newSize * loadFactor); diff --git a/java/fory-core/src/main/java/org/apache/fory/context/MetaStringReader.java b/java/fory-core/src/main/java/org/apache/fory/context/MetaStringReader.java index 23b8d444ee..ccd85f8001 100644 --- 
a/java/fory-core/src/main/java/org/apache/fory/context/MetaStringReader.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/MetaStringReader.java @@ -23,11 +23,12 @@ import org.apache.fory.annotation.Internal; import org.apache.fory.collection.LongLongByteMap; import org.apache.fory.collection.LongMap; +import org.apache.fory.exception.ForyException; import org.apache.fory.memory.LittleEndian; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.meta.EncodedMetaString; +import org.apache.fory.meta.MetaString; import org.apache.fory.resolver.SharedRegistry; -import org.apache.fory.util.MurmurHash3; /** * Read-side state for meta-string references. @@ -40,6 +41,9 @@ public final class MetaStringReader { private static final int INITIAL_CAPACITY = 2; private static final float LOAD_FACTOR = 0.5f; private static final int SMALL_STRING_THRESHOLD = 16; + private static final int ENCODING_BITS = 4; + private static final int MAX_CACHED_READ_META_STRINGS = 8192; + private static final int MAX_CACHED_READ_META_STRING_LENGTH = 2048; private final LongMap hash2MetaStringMap = new LongMap<>(INITIAL_CAPACITY, LOAD_FACTOR); @@ -47,7 +51,7 @@ public final class MetaStringReader { new LongLongByteMap<>(INITIAL_CAPACITY, LOAD_FACTOR); private final SharedRegistry sharedRegistry; private EncodedMetaString[] dynamicReadStringIds = new EncodedMetaString[INITIAL_CAPACITY]; - private short dynamicReadStringId; + private int dynamicReadStringId; /** Creates an empty reader state for one deserialization stream. 
*/ public MetaStringReader(SharedRegistry sharedRegistry) { @@ -68,7 +72,7 @@ public EncodedMetaString readMetaStringWithFlag(MemoryBuffer buffer, int header) updateDynamicString(encodedMetaString); return encodedMetaString; } - return dynamicReadStringIds[len - 1]; + return readDynamicString(len); } /** @@ -88,7 +92,7 @@ public EncodedMetaString readMetaStringWithFlag( updateDynamicString(encodedMetaString); return encodedMetaString; } - return dynamicReadStringIds[len - 1]; + return readDynamicString(len); } /** Reads a meta string from the current buffer, including any dynamic-id indirection. */ @@ -103,7 +107,7 @@ public EncodedMetaString readMetaString(MemoryBuffer buffer) { updateDynamicString(encodedMetaString); return encodedMetaString; } - return dynamicReadStringIds[len - 1]; + return readDynamicString(len); } /** @@ -121,31 +125,71 @@ public EncodedMetaString readMetaString(MemoryBuffer buffer, EncodedMetaString c updateDynamicString(encodedMetaString); return encodedMetaString; } - return dynamicReadStringIds[len - 1]; + return readDynamicString(len); } private EncodedMetaString readBigMetaString( MemoryBuffer buffer, EncodedMetaString cache, int len) { long hashCode = buffer.readInt64(); - if (cache.hash == hashCode) { - buffer.increaseReaderIndex(len); + if (cache.hash == hashCode && cache.bytes.length == len) { + buffer.checkReadableBytes(len); + buffer._increaseReaderIndexUnsafe(len); return cache; } return readBigMetaString(buffer, len, hashCode); } private EncodedMetaString readBigMetaString(MemoryBuffer buffer, int len, long hashCode) { + buffer.checkReadableBytes(len); EncodedMetaString encodedMetaString = hash2MetaStringMap.get(hashCode); - if (encodedMetaString == null) { - encodedMetaString = - sharedRegistry.getOrCreateEncodedMetaString(buffer.readBytes(len), hashCode); - hash2MetaStringMap.put(hashCode, encodedMetaString); + if (encodedMetaString != null && encodedMetaString.bytes.length == len) { + buffer._increaseReaderIndexUnsafe(len); 
return encodedMetaString; } - buffer.increaseReaderIndex(len); + byte[] bytes = readAndValidateBigMetaString(buffer, len, hashCode); + EncodedMetaString canonicalMetaString = + sharedRegistry.getOrCreateEncodedMetaString(bytes, hashCode); + if (encodedMetaString == null + && len <= MAX_CACHED_READ_META_STRING_LENGTH + && hash2MetaStringMap.size < MAX_CACHED_READ_META_STRINGS) { + hash2MetaStringMap.put(hashCode, canonicalMetaString); + } + return canonicalMetaString; + } + + private byte[] readAndValidateBigMetaString(MemoryBuffer buffer, int len, long hashCode) { + byte[] bytes = buffer.readBytes(len); + MetaString.Encoding encoding = MetaString.Encoding.fromInt((int) (hashCode & 0xff)); + long canonicalHash = EncodedMetaString.computeHash(bytes, encoding); + if (canonicalHash != hashCode) { + throw new ForyException("Malformed meta string hash"); + } + return bytes; + } + + private boolean shouldCacheSmallMetaString() { + return longLongMetaStringMap.size < MAX_CACHED_READ_META_STRINGS; + } + + private EncodedMetaString cacheSmallMetaString( + long v1, long v2, byte key, EncodedMetaString encodedMetaString) { + if (shouldCacheSmallMetaString()) { + longLongMetaStringMap.put(v1, v2, key, encodedMetaString); + } return encodedMetaString; } + private EncodedMetaString createSmallMetaString( + int len, MetaString.Encoding encoding, byte key, long v1, long v2) { + byte[] data = new byte[16]; + LittleEndian.putInt64(data, 0, v1); + LittleEndian.putInt64(data, 8, v2); + byte[] bytes = Arrays.copyOf(data, len); + long hashCode = EncodedMetaString.computeHash(bytes, encoding); + return cacheSmallMetaString( + v1, v2, key, sharedRegistry.getOrCreateEncodedMetaString(bytes, hashCode)); + } + private EncodedMetaString readSmallMetaString(MemoryBuffer buffer, int len) { if (len == 0) { return EncodedMetaString.EMPTY; @@ -159,9 +203,11 @@ private EncodedMetaString readSmallMetaString(MemoryBuffer buffer, int len) { v1 = buffer.readInt64(); v2 = buffer.readBytesAsInt64(len - 
8); } - EncodedMetaString encodedMetaString = longLongMetaStringMap.get(v1, v2, encoding); + int encodingValue = encoding & 0xff; + byte key = smallMetaStringKey(len, encodingValue); + EncodedMetaString encodedMetaString = longLongMetaStringMap.get(v1, v2, key); if (encodedMetaString == null) { - return createSmallMetaString(len, encoding, v1, v2); + return createSmallMetaString(len, MetaString.Encoding.fromInt(encodingValue), key, v1, v2); } return encodedMetaString; } @@ -180,39 +226,45 @@ private EncodedMetaString readSmallMetaString( v1 = buffer.readInt64(); v2 = buffer.readBytesAsInt64(len - 8); } - if (cache.first8Bytes == v1 && cache.second8Bytes == v2) { + int encodingValue = encoding & 0xff; + if (cache.bytes.length == len + && cache.encodingValue == encodingValue + && cache.first8Bytes == v1 + && cache.second8Bytes == v2) { return cache; } - EncodedMetaString encodedMetaString = longLongMetaStringMap.get(v1, v2, encoding); + byte key = smallMetaStringKey(len, encodingValue); + EncodedMetaString encodedMetaString = longLongMetaStringMap.get(v1, v2, key); if (encodedMetaString == null) { - return createSmallMetaString(len, encoding, v1, v2); + return createSmallMetaString(len, MetaString.Encoding.fromInt(encodingValue), key, v1, v2); } return encodedMetaString; } - private EncodedMetaString createSmallMetaString(int len, byte encoding, long v1, long v2) { - byte[] data = new byte[16]; - LittleEndian.putInt64(data, 0, v1); - LittleEndian.putInt64(data, 8, v2); - long hashCode = MurmurHash3.murmurhash3_x64_128(data, 0, len, 47)[0]; - hashCode = Math.abs(hashCode); - hashCode = (hashCode & 0xffffffffffffff00L) | encoding; - EncodedMetaString encodedMetaString = - sharedRegistry.getOrCreateEncodedMetaString(Arrays.copyOf(data, len), hashCode); - longLongMetaStringMap.put(v1, v2, encoding, encodedMetaString); - return encodedMetaString; + private static byte smallMetaStringKey(int len, int encodingValue) { + return (byte) (((len - 1) << ENCODING_BITS) | 
encodingValue); + } + + private EncodedMetaString readDynamicString(int dynamicId) { + if (dynamicId <= 0 || dynamicId > dynamicReadStringId) { + throw new ForyException("Invalid meta string reference id " + dynamicId); + } + return dynamicReadStringIds[dynamicId - 1]; } private void updateDynamicString(EncodedMetaString encodedMetaString) { - short currentDynamicReadId = dynamicReadStringId++; + int currentDynamicReadId = dynamicReadStringId++; EncodedMetaString[] readStringIds = dynamicReadStringIds; if (readStringIds.length <= currentDynamicReadId) { + if (currentDynamicReadId >= MAX_CACHED_READ_META_STRINGS) { + throw new ForyException("Too many meta string references in payload"); + } readStringIds = dynamicReadStringIds = growRead(readStringIds, currentDynamicReadId); } readStringIds[currentDynamicReadId] = encodedMetaString; } - private EncodedMetaString[] growRead(EncodedMetaString[] current, int id) { + private static EncodedMetaString[] growRead(EncodedMetaString[] current, int id) { int newLength = current.length; while (newLength <= id) { newLength <<= 1; diff --git a/java/fory-core/src/main/java/org/apache/fory/io/ForyReadableChannel.java b/java/fory-core/src/main/java/org/apache/fory/io/ForyReadableChannel.java index 1afbcbf7b3..ff8d9f85ac 100644 --- a/java/fory-core/src/main/java/org/apache/fory/io/ForyReadableChannel.java +++ b/java/fory-core/src/main/java/org/apache/fory/io/ForyReadableChannel.java @@ -66,9 +66,9 @@ public int fillBuffer(int minFillSize) { memoryBuf.initDirectBuffer(ByteBufferUtil.getAddress(byteBuf), position, byteBuf); } byteBuf.limit(newLimit); - int readCount = channel.read(byteBuf); - memoryBuf.increaseSize(readCount); - return readCount; + readFully(byteBuf, minFillSize); + memoryBuf.increaseSize(minFillSize); + return minFillSize; } catch (IOException e) { throw new DeserializationException("Failed to read the provided byte channel", e); } @@ -98,7 +98,7 @@ public void readTo(byte[] dst, int dstIndex, int length) { 
buf.readBytes(dst, dstIndex, remaining); try { ByteBuffer buffer = ByteBuffer.wrap(dst, dstIndex + remaining, length - remaining); - channel.read(buffer); + readFully(buffer, length - remaining); } catch (IOException e) { throw new DeserializationException("Failed to read the provided byte channel", e); } @@ -130,10 +130,13 @@ public void readToByteBuffer(ByteBuffer dst, int length) { int newLimit = dst.position() + length - remaining; if (dstLimit > newLimit) { dst.limit(newLimit); - channel.read(dst); - dst.limit(dstLimit); + try { + readFully(dst, length - remaining); + } finally { + dst.limit(dstLimit); + } } else { - channel.read(dst); + readFully(dst, length - remaining); } } catch (IOException e) { throw new DeserializationException("Failed to read the provided byte channel", e); @@ -169,4 +172,15 @@ public void close() throws IOException { public MemoryBuffer getBuffer() { return memoryBuffer; } + + private void readFully(ByteBuffer dst, int length) throws IOException { + int remaining = length; + while (remaining > 0) { + int read = channel.read(dst); + if (read <= 0) { + throw new DeserializationException("Unexpected end of byte channel"); + } + remaining -= read; + } + } } diff --git a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java index 48d37e97d8..2ae29eaab1 100644 --- a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java +++ b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java @@ -2716,6 +2716,28 @@ public boolean equalTo(MemoryBuffer buf2, int offset1, int offset2, int len) { return Platform.arrayEquals(heapMemory, pos1, buf2.heapMemory, pos2, len); } + /** + * Equals a memory buffer region with a byte array region. 
+ * + * @param bytes Array to compare with + * @param bytesOffset Offset of bytes to start comparing + * @param offset Offset of this buffer to start comparing + * @param len Length of the compared memory region + * @return true if regions are equal or len zero, false otherwise + */ + public boolean equalTo(byte[] bytes, int bytesOffset, int offset, int len) { + checkArgument(bytes != null); + checkArgument(len >= 0); + checkArgument(bytesOffset >= 0 && bytesOffset <= bytes.length - len); + checkArgument(offset >= 0 && offset <= size - len); + if (len == 0) { + return true; + } + final long pos = address + offset; + return Platform.arrayEquals( + heapMemory, pos, bytes, Platform.BYTE_ARRAY_OFFSET + bytesOffset, len); + } + @Override public String toString() { return "MemoryBuffer{" diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/DeflaterMetaCompressor.java b/java/fory-core/src/main/java/org/apache/fory/meta/DeflaterMetaCompressor.java index 7611741124..cf8cee2682 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/DeflaterMetaCompressor.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/DeflaterMetaCompressor.java @@ -43,6 +43,11 @@ public byte[] compress(byte[] input, int offset, int size) { @Override public byte[] decompress(byte[] input, int offset, int size) { + return decompress(input, offset, size, Integer.MAX_VALUE); + } + + @Override + public byte[] decompress(byte[] input, int offset, int size, int maxOutputSize) { Inflater inflater = new Inflater(); inflater.setInput(input, offset, size); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); @@ -51,6 +56,10 @@ public byte[] decompress(byte[] input, int offset, int size) { while (!inflater.finished()) { int decompressedSize = inflater.inflate(buffer); if (decompressedSize > 0) { + if (outputStream.size() > maxOutputSize - decompressedSize) { + throw new InvalidDataException( + "Decompressed TypeDef metadata exceeds the maximum size."); + } 
outputStream.write(buffer, 0, decompressedSize); continue; } diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/EncodedMetaString.java b/java/fory-core/src/main/java/org/apache/fory/meta/EncodedMetaString.java index b13c1630b7..1692aa63d0 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/EncodedMetaString.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/EncodedMetaString.java @@ -33,6 +33,7 @@ public final class EncodedMetaString { public final byte[] bytes; public final long hash; + public final int encodingValue; public final MetaString.Encoding encoding; public final long first8Bytes; public final long second8Bytes; @@ -45,7 +46,8 @@ public EncodedMetaString(byte[] bytes, long hash) { assert hash != 0; this.bytes = bytes; this.hash = hash; - this.encoding = MetaString.Encoding.fromInt((int) (hash & HEADER_MASK)); + this.encodingValue = (int) (hash & HEADER_MASK); + this.encoding = MetaString.Encoding.fromInt(encodingValue); byte[] data = bytes; if (bytes.length < 16) { data = new byte[16]; @@ -55,7 +57,7 @@ public EncodedMetaString(byte[] bytes, long hash) { second8Bytes = LittleEndian.getInt64(data, Platform.BYTE_ARRAY_OFFSET + 8); } - private static long computeHash(byte[] bytes, MetaString.Encoding encoding) { + public static long computeHash(byte[] bytes, MetaString.Encoding encoding) { long hash = MurmurHash3.murmurhash3_x64_128(bytes, 0, bytes.length, 47)[0]; hash = Math.abs(hash); if (hash == 0) { diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/MetaCompressor.java b/java/fory-core/src/main/java/org/apache/fory/meta/MetaCompressor.java index 0ed3974b1c..cf74e20c2d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/MetaCompressor.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/MetaCompressor.java @@ -19,6 +19,8 @@ package org.apache.fory.meta; +import org.apache.fory.exception.InvalidDataException; + /** * An interface used to compress class metadata such as field names and types. 
The implementation of * this interface should be thread safe. @@ -28,6 +30,14 @@ public interface MetaCompressor { byte[] decompress(byte[] data, int offset, int size); + default byte[] decompress(byte[] data, int offset, int size, int maxOutputSize) { + byte[] decompressed = decompress(data, offset, size); + if (decompressed.length > maxOutputSize) { + throw new InvalidDataException("Decompressed TypeDef metadata exceeds the maximum size."); + } + return decompressed; + } + /** * Check whether {@link MetaCompressor} implements `equals/hashCode` method. If not implemented, * return {@link TypeEqualMetaCompressor} instead which compare equality by the compressor type diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java index f3c02af179..84f3c5b5f1 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java @@ -25,12 +25,12 @@ import static org.apache.fory.meta.NativeTypeDefEncoder.BIG_NAME_THRESHOLD; import static org.apache.fory.meta.NativeTypeDefEncoder.NUM_CLASS_THRESHOLD; import static org.apache.fory.meta.TypeDef.COMPRESS_META_FLAG; -import static org.apache.fory.meta.TypeDef.HAS_FIELDS_META_FLAG; import static org.apache.fory.meta.TypeDef.META_SIZE_MASKS; import java.util.ArrayList; import java.util.List; import org.apache.fory.collection.Tuple2; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.meta.FieldTypes.FieldType; import org.apache.fory.meta.MetaString.Encoding; @@ -38,6 +38,7 @@ import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.UnknownClass; import org.apache.fory.type.Types; +import org.apache.fory.util.MurmurHash3; import org.apache.fory.util.Preconditions; /** @@ -46,8 +47,13 @@ * 
href="https://fory.apache.org/docs/specification/fory_java_serialization_spec">... */ class NativeTypeDefDecoder { + private static final int MAX_TYPE_DEF_SIZE_BYTES = 16 * 1024 * 1024; + static Tuple2 decodeTypeDefBuf( MemoryBuffer inputBuffer, TypeResolver resolver, long id) { + if ((id & TypeDef.RESERVED_META_FLAGS) != 0) { + throw new DeserializationException("Invalid TypeDef global header"); + } MemoryBuffer encoded = MemoryBuffer.newHeapBuffer(64); encoded.writeInt64(id); int size = (int) (id & META_SIZE_MASKS); @@ -56,10 +62,17 @@ static Tuple2 decodeTypeDefBuf( encoded.writeVarUInt32(moreSize); size += moreSize; } + if (size > MAX_TYPE_DEF_SIZE_BYTES) { + throw new DeserializationException("TypeDef metadata size exceeds the maximum size"); + } byte[] encodedTypeDef = inputBuffer.readBytes(size); encoded.writeBytes(encodedTypeDef); if ((id & COMPRESS_META_FLAG) != 0) { - encodedTypeDef = resolver.getConfig().getMetaCompressor().decompress(encodedTypeDef, 0, size); + encodedTypeDef = + resolver + .getConfig() + .getMetaCompressor() + .decompress(encodedTypeDef, 0, size, MAX_TYPE_DEF_SIZE_BYTES); } return Tuple2.of(encodedTypeDef, encoded.getBytes(0, encoded.writerIndex())); } @@ -67,7 +80,9 @@ static Tuple2 decodeTypeDefBuf( public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, long id) { Tuple2 decoded = decodeTypeDefBuf(buffer, resolver, id); MemoryBuffer typeDefBuf = MemoryBuffer.fromByteArray(decoded.f0); - int numClasses = typeDefBuf.readByte(); + int bodyHeader = typeDefBuf.readByte() & 0xff; + int rootTypeId = nativeTypeId(bodyHeader >>> 4); + int numClasses = bodyHeader & NUM_CLASS_THRESHOLD; if (numClasses == NUM_CLASS_THRESHOLD) { numClasses += typeDefBuf.readVarUInt32Small7(); } @@ -111,12 +126,14 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, resolver.loadClassForMeta( decodedSpec.entireClassName, decodedSpec.isEnum, decodedSpec.dimension); if (UnknownClass.isUnknowClass(cls)) { - int 
typeId; + int decodedTypeId; if (decodedSpec.isEnum) { - typeId = Types.NAMED_ENUM; + decodedTypeId = Types.NAMED_ENUM; } else { - typeId = resolver.isCompatible() ? Types.NAMED_COMPATIBLE_STRUCT : Types.NAMED_STRUCT; + decodedTypeId = + resolver.isCompatible() ? Types.NAMED_COMPATIBLE_STRUCT : Types.NAMED_STRUCT; } + int typeId = i == numClasses - 1 ? rootTypeId : decodedTypeId; classSpec = new ClassSpec( decodedSpec.entireClassName, @@ -134,12 +151,71 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, } } } + if (i == numClasses - 1 && classSpec.typeId != rootTypeId) { + throw new DeserializationException("TypeDef root kind does not match root class metadata"); + } List fieldInfos = readFieldsInfo(typeDefBuf, resolver, className, numFields); classFields.addAll(fieldInfos); } Preconditions.checkNotNull(classSpec); - boolean hasFieldsMeta = (id & HAS_FIELDS_META_FLAG) != 0; - return new TypeDef(classSpec, classFields, hasFieldsMeta, id, decoded.f1); + if (!Types.isStructType(rootTypeId) && !classFields.isEmpty()) { + throw new DeserializationException("Non-struct TypeDef cannot carry field metadata"); + } + if (typeDefBuf.remaining() != 0) { + throw new DeserializationException("Invalid TypeDef metadata size"); + } + validateParsedTypeDefHash(id, decoded.f1); + return new TypeDef(classSpec, classFields, id, decoded.f1); + } + + static int nativeTypeId(int kindCode) { + switch (kindCode) { + case 0: + return Types.STRUCT; + case 1: + return Types.COMPATIBLE_STRUCT; + case 2: + return Types.NAMED_STRUCT; + case 3: + return Types.NAMED_COMPATIBLE_STRUCT; + case 4: + return Types.ENUM; + case 5: + return Types.NAMED_ENUM; + case 6: + return Types.EXT; + case 7: + return Types.NAMED_EXT; + case 8: + return Types.TYPED_UNION; + case 9: + return Types.NAMED_UNION; + default: + throw new DeserializationException("Unsupported TypeDef kind code " + kindCode); + } + } + + static void validateParsedTypeDefHash(long id, byte[] encoded) { + int 
size = (int) (id & META_SIZE_MASKS); + int bodyOffset = Long.BYTES; + if (size == META_SIZE_MASKS) { + MemoryBuffer encodedBuffer = MemoryBuffer.fromByteArray(encoded); + encodedBuffer.readerIndex(Long.BYTES); + int moreSize = encodedBuffer.readVarUInt32Small14(); + size += moreSize; + bodyOffset = encodedBuffer.readerIndex(); + } + if (encoded.length - bodyOffset != size) { + throw new DeserializationException("Invalid TypeDef encoded size"); + } + long hash = MurmurHash3.murmurhash3_x64_128(encoded, bodyOffset, size, 47)[0]; + hash <<= (Long.SIZE - TypeDef.NUM_HASH_BITS); + long hashMask = -1L << (Long.SIZE - TypeDef.NUM_HASH_BITS); + long expectedHeaderHash = Math.abs(hash) & hashMask; + long actualHeaderHash = id & hashMask; + if (expectedHeaderHash != actualHeaderHash) { + throw new DeserializationException("Invalid TypeDef metadata hash"); + } } private static List readFieldsInfo( diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefEncoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefEncoder.java index 3f7998184d..bb2e104ee2 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefEncoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefEncoder.java @@ -23,7 +23,6 @@ import static org.apache.fory.meta.Encoders.pkgEncodingsList; import static org.apache.fory.meta.Encoders.typeNameEncodingsList; import static org.apache.fory.meta.TypeDef.COMPRESS_META_FLAG; -import static org.apache.fory.meta.TypeDef.HAS_FIELDS_META_FLAG; import static org.apache.fory.meta.TypeDef.META_SIZE_MASKS; import static org.apache.fory.meta.TypeDef.NUM_HASH_BITS; @@ -56,7 +55,6 @@ */ @Internal public class NativeTypeDefEncoder { - // a flag to mark a type is not struct. 
static final int NUM_CLASS_THRESHOLD = 0b1111; private static final java.util.function.Function IDENTITY_DESCRIPTOR = descriptor -> descriptor; @@ -142,43 +140,35 @@ public static List buildFieldsInfo(TypeResolver resolver, List } /** Build class definition from fields of class. */ - static TypeDef buildTypeDef( - ClassResolver classResolver, Class type, List fields, boolean hasFieldsMeta) { - return buildTypeDefWithFieldInfos( - classResolver, type, buildFieldsInfo(classResolver, fields), hasFieldsMeta); + static TypeDef buildTypeDef(ClassResolver classResolver, Class type, List fields) { + return buildTypeDefWithFieldInfos(classResolver, type, buildFieldsInfo(classResolver, fields)); } public static TypeDef buildTypeDefWithFieldInfos( - ClassResolver classResolver, - Class type, - List fieldInfos, - boolean hasFieldsMeta) { + ClassResolver classResolver, Class type, List fieldInfos) { Map> classLayers = getClassFields(type, fieldInfos); fieldInfos = new ArrayList<>(fieldInfos.size()); classLayers.values().forEach(fieldInfos::addAll); - MemoryBuffer encodeTypeDef = encodeTypeDef(classResolver, type, classLayers, hasFieldsMeta); + MemoryBuffer encodeTypeDef = encodeTypeDef(classResolver, type, classLayers); byte[] typeDefBytes = encodeTypeDef.getBytes(0, encodeTypeDef.writerIndex()); int typeId = classResolver.getTypeIdForTypeDef(type); int userTypeId = classResolver.getUserTypeIdForTypeDef(type); ClassSpec classSpec = new ClassSpec(type, typeId, userTypeId); - return new TypeDef( - classSpec, fieldInfos, hasFieldsMeta, encodeTypeDef.getInt64(0), typeDefBytes); + return new TypeDef(classSpec, fieldInfos, encodeTypeDef.getInt64(0), typeDefBytes); } // see spec documentation: docs/specification/java_serialization_spec.md // https://fory.apache.org/docs/specification/fory_java_serialization_spec public static MemoryBuffer encodeTypeDef( - ClassResolver classResolver, - Class type, - Map> classLayers, - boolean hasFieldsMeta) { + ClassResolver classResolver, Class 
type, Map> classLayers) { MemoryBuffer typeDefBuf = MemoryBuffer.newHeapBuffer(128); int numClasses = classLayers.size() - 1; // num class must be greater than 0 + int firstBodyByte = nativeKindCode(classResolver.getTypeIdForTypeDef(type)) << 4; if (numClasses >= NUM_CLASS_THRESHOLD) { - typeDefBuf.writeByte(NUM_CLASS_THRESHOLD); + typeDefBuf.writeByte(firstBodyByte | NUM_CLASS_THRESHOLD); typeDefBuf.writeVarUInt32Small7(numClasses - NUM_CLASS_THRESHOLD); } else { - typeDefBuf.writeByte(numClasses); + typeDefBuf.writeByte(firstBodyByte | numClasses); } for (Map.Entry> entry : classLayers.entrySet()) { String className = entry.getKey(); @@ -224,11 +214,10 @@ public static MemoryBuffer encodeTypeDef( typeDefBuf = MemoryBuffer.fromByteArray(compressed); typeDefBuf.writerIndex(compressed.length); } - return prependHeader(typeDefBuf, isCompressed, hasFieldsMeta); + return prependHeader(typeDefBuf, isCompressed); } - static MemoryBuffer prependHeader( - MemoryBuffer buffer, boolean isCompressed, boolean hasFieldsMeta) { + static MemoryBuffer prependHeader(MemoryBuffer buffer, boolean isCompressed) { int metaSize = buffer.writerIndex(); long hash = MurmurHash3.murmurhash3_x64_128(buffer.getHeapMemory(), 0, metaSize, 47)[0]; hash <<= (64 - NUM_HASH_BITS); @@ -237,9 +226,6 @@ static MemoryBuffer prependHeader( if (isCompressed) { header |= COMPRESS_META_FLAG; } - if (hasFieldsMeta) { - header |= HAS_FIELDS_META_FLAG; - } header |= Math.min(metaSize, META_SIZE_MASKS); MemoryBuffer result = MemoryUtils.buffer(metaSize + 8); result.writeInt64(header); @@ -250,6 +236,33 @@ static MemoryBuffer prependHeader( return result; } + static int nativeKindCode(int typeId) { + switch (typeId) { + case Types.STRUCT: + return 0; + case Types.COMPATIBLE_STRUCT: + return 1; + case Types.NAMED_STRUCT: + return 2; + case Types.NAMED_COMPATIBLE_STRUCT: + return 3; + case Types.ENUM: + return 4; + case Types.NAMED_ENUM: + return 5; + case Types.EXT: + return 6; + case Types.NAMED_EXT: + return 
7; + case Types.TYPED_UNION: + return 8; + case Types.NAMED_UNION: + return 9; + default: + throw new IllegalArgumentException("Unsupported TypeDef kind " + typeId); + } + } + private static Class getType(Class cls, String type) { Class c = cls; while (cls != null) { diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java index 4fd91857c9..abfa826fcc 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java @@ -69,11 +69,11 @@ public class TypeDef implements Serializable { private static final Logger LOG = LoggerFactory.getLogger(TypeDef.class); - static final int COMPRESS_META_FLAG = 0b1 << 9; - static final int HAS_FIELDS_META_FLAG = 0b1 << 8; + static final int COMPRESS_META_FLAG = 0b1 << 8; + static final long RESERVED_META_FLAGS = 0b111L << 9; // low 8 bits static final int META_SIZE_MASKS = 0xff; - static final int NUM_HASH_BITS = 50; + static final int NUM_HASH_BITS = 52; // TODO use field offset to sort field, which will hit l1-cache more. Since // `objectFieldOffset` is not part of jvm-specification, it may change between different jdk @@ -101,26 +101,21 @@ public class TypeDef implements Serializable { private final ClassSpec classSpec; private final List fieldsInfo; - private final boolean hasFieldsMeta; // Unique id for class def. If class def are same between processes, then the id will // be same too. 
private final long id; private final byte[] encoded; - TypeDef( - ClassSpec classSpec, - List fieldsInfo, - boolean hasFieldsMeta, - long id, - byte[] encoded) { + TypeDef(ClassSpec classSpec, List fieldsInfo, long id, byte[] encoded) { this.classSpec = classSpec; this.fieldsInfo = fieldsInfo; - this.hasFieldsMeta = hasFieldsMeta; this.id = id; this.encoded = encoded; } public static void skipTypeDef(MemoryBuffer buffer, long id) { + // Header-cache hits use the validated header as the cache key. The current body is skipped by + // its declared size; body hash validation belongs to the parse-before-cache-publication path. int size = (int) (id & META_SIZE_MASKS); if (size == META_SIZE_MASKS) { size += buffer.readVarUInt32Small14(); @@ -146,11 +141,6 @@ public List getFieldsInfo() { return fieldsInfo; } - /** Returns ext meta for the class. */ - public boolean hasFieldsMeta() { - return hasFieldsMeta; - } - /** * Returns an unique id for class def. If class def are same between processes, then the id will * be same too. 
@@ -179,6 +169,10 @@ public boolean isCompatible() { || classSpec.typeId == Types.NAMED_COMPATIBLE_STRUCT; } + public boolean isStructSchemaKind() { + return Types.isStructType(classSpec.typeId); + } + public int getUserTypeId() { Preconditions.checkArgument(!isNamed(), "Named types don't have user type id"); return classSpec.userTypeId; @@ -190,8 +184,7 @@ public boolean equals(Object o) { return false; } TypeDef typeDef = (TypeDef) o; - return hasFieldsMeta == typeDef.hasFieldsMeta - && id == typeDef.id + return id == typeDef.id && Objects.equals(classSpec, typeDef.classSpec) && Objects.equals(fieldsInfo, typeDef.fieldsInfo); } @@ -209,8 +202,6 @@ public String toString() { + '\'' + ", fieldsInfo=" + fieldsInfo - + ", hasFieldsMeta=" - + hasFieldsMeta + ", id=" + id + '}'; @@ -429,17 +420,12 @@ public static TypeDef buildTypeDef(TypeResolver resolver, Class cls, boolean return TypeDefEncoder.buildTypeDef((XtypeResolver) resolver, cls); } return NativeTypeDefEncoder.buildTypeDef( - (ClassResolver) resolver, cls, buildFields(resolver, cls, resolveParent), true); + (ClassResolver) resolver, cls, buildFields(resolver, cls, resolveParent)); } /** Build class definition from fields of class. 
*/ static TypeDef buildTypeDef(ClassResolver classResolver, Class type, List fields) { - return buildTypeDef(classResolver, type, fields, true); - } - - public static TypeDef buildTypeDef( - ClassResolver classResolver, Class type, List fields, boolean hasFieldsMeta) { - return NativeTypeDefEncoder.buildTypeDef(classResolver, type, fields, hasFieldsMeta); + return NativeTypeDefEncoder.buildTypeDef(classResolver, type, fields); } public TypeDef replaceRootClassTo(TypeResolver resolver, Class targetCls) { @@ -460,6 +446,6 @@ public TypeDef replaceRootClassTo(TypeResolver resolver, Class targetCls) { (XtypeResolver) resolver, targetCls, fieldInfos); } return NativeTypeDefEncoder.buildTypeDefWithFieldInfos( - (ClassResolver) resolver, targetCls, fieldInfos, hasFieldsMeta); + (ClassResolver) resolver, targetCls, fieldInfos); } } diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java index 296316b8fa..61ef2fb0f5 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java @@ -23,14 +23,18 @@ import static org.apache.fory.meta.NativeTypeDefDecoder.decodeTypeDefBuf; import static org.apache.fory.meta.NativeTypeDefDecoder.readPkgName; import static org.apache.fory.meta.NativeTypeDefDecoder.readTypeName; -import static org.apache.fory.meta.TypeDef.HAS_FIELDS_META_FLAG; +import static org.apache.fory.meta.NativeTypeDefDecoder.validateParsedTypeDefHash; +import static org.apache.fory.meta.TypeDef.COMPRESS_META_FLAG; +import static org.apache.fory.meta.TypeDefEncoder.COMPATIBLE_FLAG; import static org.apache.fory.meta.TypeDefEncoder.FIELD_NAME_SIZE_THRESHOLD; import static org.apache.fory.meta.TypeDefEncoder.REGISTER_BY_NAME_FLAG; import static org.apache.fory.meta.TypeDefEncoder.SMALL_NUM_FIELDS_THRESHOLD; +import static org.apache.fory.meta.TypeDefEncoder.STRUCT_FLAG; import 
java.util.ArrayList; import java.util.List; import org.apache.fory.collection.Tuple2; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.logging.Logger; import org.apache.fory.logging.LoggerFactory; import org.apache.fory.memory.MemoryBuffer; @@ -39,12 +43,12 @@ import org.apache.fory.resolver.TypeInfo; import org.apache.fory.resolver.XtypeResolver; import org.apache.fory.serializer.UnknownClass; +import org.apache.fory.type.Types; import org.apache.fory.util.StringUtils; import org.apache.fory.util.Utils; /** - * A decoder which decode binary into {@link TypeDef}. Global header layout follows the xlang spec - * with an 8-bit meta size and flags at bits 8/9. See spec documentation: + * A decoder which decode binary into {@link TypeDef}. See spec documentation: * docs/specification/fory_xlang_serialization_spec.md ... */ @@ -52,40 +56,88 @@ class TypeDefDecoder { private static final Logger LOG = LoggerFactory.getLogger(TypeDefDecoder.class); public static TypeDef decodeTypeDef(XtypeResolver resolver, MemoryBuffer inputBuffer, long id) { + if ((id & COMPRESS_META_FLAG) != 0) { + throw new DeserializationException("Compressed xlang TypeDef is not supported"); + } Tuple2 decoded = decodeTypeDefBuf(inputBuffer, resolver, id); MemoryBuffer buffer = MemoryBuffer.fromByteArray(decoded.f0); - byte header = buffer.readByte(); - int numFields = header & SMALL_NUM_FIELDS_THRESHOLD; - if (numFields == SMALL_NUM_FIELDS_THRESHOLD) { - numFields += buffer.readVarUInt32Small7(); - } + int header = buffer.readByte() & 0xff; + boolean isStruct = (header & STRUCT_FLAG) != 0; + int numFields = 0; ClassSpec classSpec; - if ((header & REGISTER_BY_NAME_FLAG) != 0) { - String namespace = readPkgName(buffer); - String typeName = readTypeName(buffer); - if (Utils.DEBUG_OUTPUT_ENABLED) { - LOG.info("Decode class {} using namespace {}", typeName, namespace); + if (isStruct) { + boolean named = (header & REGISTER_BY_NAME_FLAG) != 0; + boolean compatible = 
(header & COMPATIBLE_FLAG) != 0; + int typeId; + if (named) { + typeId = compatible ? Types.NAMED_COMPATIBLE_STRUCT : Types.NAMED_STRUCT; + } else { + typeId = compatible ? Types.COMPATIBLE_STRUCT : Types.STRUCT; } - TypeInfo userTypeInfo = resolver.getUserTypeInfo(namespace, typeName); - if (userTypeInfo == null) { - classSpec = new ClassSpec(UnknownClass.UnknownStruct.class); + numFields = header & SMALL_NUM_FIELDS_THRESHOLD; + if (numFields == SMALL_NUM_FIELDS_THRESHOLD) { + numFields += buffer.readVarUInt32Small7(); + } + if (named) { + String namespace = readPkgName(buffer); + String typeName = readTypeName(buffer); + if (Utils.DEBUG_OUTPUT_ENABLED) { + LOG.info("Decode class {} using namespace {}", typeName, namespace); + } + TypeInfo userTypeInfo = resolver.getUserTypeInfo(namespace, typeName); + if (userTypeInfo == null) { + classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, -1); + } else { + validateRegisteredTypeDefKind(userTypeInfo, typeId); + classSpec = new ClassSpec(userTypeInfo.getType(), typeId, userTypeInfo.getUserTypeId()); + } } else { - classSpec = new ClassSpec(userTypeInfo.getType()); + int userTypeId = buffer.readVarUInt32(); + TypeInfo userTypeInfo = resolver.getUserTypeInfo(userTypeId); + if (userTypeInfo == null) { + classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, userTypeId); + } else { + validateRegisteredTypeDefKind(userTypeInfo, typeId); + classSpec = new ClassSpec(userTypeInfo.getType(), typeId, userTypeId); + } } } else { - int typeId = buffer.readUInt8(); - int userTypeId = buffer.readVarUInt32(); - TypeInfo userTypeInfo = resolver.getUserTypeInfo(userTypeId); - if (userTypeInfo == null) { - classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, userTypeId); + if ((header & 0b0111_0000) != 0) { + throw new DeserializationException("Invalid TypeDef kind header"); + } + int typeId = nonStructTypeId(header & 0b1111); + boolean named = Types.isNamedType(typeId); + if (named) { + String 
namespace = readPkgName(buffer); + String typeName = readTypeName(buffer); + TypeInfo userTypeInfo = resolver.getUserTypeInfo(namespace, typeName); + if (userTypeInfo == null) { + classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, -1); + } else { + validateRegisteredTypeDefKind(userTypeInfo, typeId); + classSpec = new ClassSpec(userTypeInfo.getType(), typeId, userTypeInfo.getUserTypeId()); + } } else { - classSpec = new ClassSpec(userTypeInfo.getType(), typeId, userTypeId); + int userTypeId = buffer.readVarUInt32(); + TypeInfo userTypeInfo = resolver.getUserTypeInfo(userTypeId); + if (userTypeInfo == null) { + classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, userTypeId); + } else { + validateRegisteredTypeDefKind(userTypeInfo, typeId); + classSpec = new ClassSpec(userTypeInfo.getType(), typeId, userTypeId); + } } } List classFields = readFieldsInfo(buffer, resolver, classSpec.entireClassName, numFields); - boolean hasFieldsMeta = (id & HAS_FIELDS_META_FLAG) != 0; - TypeDef typeDef = new TypeDef(classSpec, classFields, hasFieldsMeta, id, decoded.f1); + if (!isStruct && !classFields.isEmpty()) { + throw new DeserializationException("Non-struct TypeDef cannot carry field metadata"); + } + if (buffer.remaining() != 0) { + throw new DeserializationException("Invalid TypeDef metadata size"); + } + validateParsedTypeDefHash(id, decoded.f1); + TypeDef typeDef = new TypeDef(classSpec, classFields, id, decoded.f1); if (Utils.DEBUG_OUTPUT_ENABLED) { LOG.info("[Java TypeDef DECODED] " + typeDef); // Compute and print diff with local TypeDef @@ -103,6 +155,34 @@ public static TypeDef decodeTypeDef(XtypeResolver resolver, MemoryBuffer inputBu return typeDef; } + private static void validateRegisteredTypeDefKind(TypeInfo userTypeInfo, int typeId) { + if (userTypeInfo.getTypeId() != typeId) { + throw new DeserializationException( + String.format( + "TypeDef kind %s does not match registered kind %s for %s", + typeId, userTypeInfo.getTypeId(), 
userTypeInfo.getType())); + } + } + + static int nonStructTypeId(int kindCode) { + switch (kindCode) { + case 0: + return Types.ENUM; + case 1: + return Types.NAMED_ENUM; + case 2: + return Types.EXT; + case 3: + return Types.NAMED_EXT; + case 4: + return Types.TYPED_UNION; + case 5: + return Types.NAMED_UNION; + default: + throw new DeserializationException("Unsupported TypeDef kind code " + kindCode); + } + } + // | header + type info + field name | ... | header + type info + field name | private static List readFieldsInfo( MemoryBuffer buffer, XtypeResolver resolver, String className, int numFields) { diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefEncoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefEncoder.java index d5eada0f0a..4bed7c13d6 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefEncoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefEncoder.java @@ -49,8 +49,7 @@ import org.apache.fory.util.Utils; /** - * An encoder which encode {@link TypeDef} into binary. Global header layout follows the xlang spec - * with an 8-bit meta size and flags at bits 8/9. See spec documentation: + * An encoder which encode {@link TypeDef} into binary. See spec documentation: * docs/specification/fory_xlang_serialization_spec.md ... 
*/ @@ -115,7 +114,6 @@ static TypeDef buildTypeDefWithFieldInfos( new TypeDef( new ClassSpec(type, typeInfo.getTypeId(), typeInfo.getUserTypeId()), fieldInfos, - true, encodeTypeDef.getInt64(0), typeDefBytes); if (Utils.DEBUG_OUTPUT_ENABLED) { @@ -125,43 +123,82 @@ static TypeDef buildTypeDefWithFieldInfos( } static final int SMALL_NUM_FIELDS_THRESHOLD = 0b11111; - static final int REGISTER_BY_NAME_FLAG = 0b100000; + static final int REGISTER_BY_NAME_FLAG = 0b0010_0000; + static final int COMPATIBLE_FLAG = 0b0100_0000; + static final int STRUCT_FLAG = 0b1000_0000; static final int FIELD_NAME_SIZE_THRESHOLD = 0b1111; // see spec documentation: docs/specification/xlang_serialization_spec.md // https://fory.apache.org/docs/specification/fory_xlang_serialization_spec static MemoryBuffer encodeTypeDef(XtypeResolver resolver, Class type, List fields) { TypeInfo typeInfo = resolver.getTypeInfo(type); + int typeId = typeInfo.getTypeId(); + boolean isStruct = Types.isStructType(typeId); + Preconditions.checkArgument( + isStruct || fields.isEmpty(), "Non-struct TypeDef %s cannot carry field metadata", typeId); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(128); buffer.writeByte(-1); // placeholder for header, update later - int currentClassHeader = fields.size(); - if (fields.size() >= SMALL_NUM_FIELDS_THRESHOLD) { - currentClassHeader = SMALL_NUM_FIELDS_THRESHOLD; - buffer.writeVarUInt32(fields.size() - SMALL_NUM_FIELDS_THRESHOLD); - } - if (resolver.isRegisteredById(type)) { - buffer.writeUInt8(typeInfo.getTypeId()); - Preconditions.checkArgument( - typeInfo.getUserTypeId() != -1, - "User type id is required for typeId %s", - typeInfo.getTypeId()); - buffer.writeVarUInt32(typeInfo.getUserTypeId()); + if (isStruct) { + int fieldCount = fields.size(); + int currentClassHeader = STRUCT_FLAG | Math.min(fieldCount, SMALL_NUM_FIELDS_THRESHOLD); + if (typeId == Types.COMPATIBLE_STRUCT || typeId == Types.NAMED_COMPATIBLE_STRUCT) { + currentClassHeader |= COMPATIBLE_FLAG; + } + 
if (fieldCount >= SMALL_NUM_FIELDS_THRESHOLD) { + buffer.writeVarUInt32(fieldCount - SMALL_NUM_FIELDS_THRESHOLD); + } + if (resolver.isRegisteredById(type)) { + Preconditions.checkArgument( + typeInfo.getUserTypeId() != -1, + "User type id is required for typeId %s", + typeInfo.getTypeId()); + buffer.writeVarUInt32(typeInfo.getUserTypeId()); + } else { + Preconditions.checkArgument(resolver.isRegisteredByName(type)); + currentClassHeader |= REGISTER_BY_NAME_FLAG; + String ns = typeInfo.decodeNamespace(); + String typename = typeInfo.decodeTypeName(); + writePkgName(buffer, ns); + writeTypeName(buffer, typename); + } + buffer.putByte(0, currentClassHeader); + writeFieldsInfo(resolver, buffer, fields); } else { - Preconditions.checkArgument(resolver.isRegisteredByName(type)); - currentClassHeader |= REGISTER_BY_NAME_FLAG; - String ns = typeInfo.decodeNamespace(); - String typename = typeInfo.decodeTypeName(); - writePkgName(buffer, ns); - writeTypeName(buffer, typename); + buffer.putByte(0, nonStructKindCode(typeId)); + if (resolver.isRegisteredById(type)) { + Preconditions.checkArgument( + typeInfo.getUserTypeId() != -1, + "User type id is required for typeId %s", + typeInfo.getTypeId()); + buffer.writeVarUInt32(typeInfo.getUserTypeId()); + } else { + Preconditions.checkArgument(resolver.isRegisteredByName(type)); + String ns = typeInfo.decodeNamespace(); + String typename = typeInfo.decodeTypeName(); + writePkgName(buffer, ns); + writeTypeName(buffer, typename); + } } - buffer.putByte(0, currentClassHeader); - writeFieldsInfo(resolver, buffer, fields); + return prependHeader(buffer, false); + } - // Temporary xlang behavior: always write TypeMeta uncompressed. - // Some runtimes still don't support TypeMeta decompression, so we must avoid emitting - // compressed xlang TypeMeta until all xlang implementations support decompress. - // Note: native mode is unchanged and still uses NativeTypeDefEncoder compression flow. 
- return prependHeader(buffer, false, !fields.isEmpty()); + static int nonStructKindCode(int typeId) { + switch (typeId) { + case Types.ENUM: + return 0; + case Types.NAMED_ENUM: + return 1; + case Types.EXT: + return 2; + case Types.NAMED_EXT: + return 3; + case Types.TYPED_UNION: + return 4; + case Types.NAMED_UNION: + return 5; + default: + throw new IllegalArgumentException("Unsupported TypeDef kind " + typeId); + } } static Map getClassFields(Class type, List fieldsInfo) { diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeEqualMetaCompressor.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeEqualMetaCompressor.java index 1c21296ce6..34eb2fe2dd 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeEqualMetaCompressor.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeEqualMetaCompressor.java @@ -52,6 +52,11 @@ public byte[] decompress(byte[] data, int offset, int size) { return compressor.decompress(data, offset, size); } + @Override + public byte[] decompress(byte[] data, int offset, int size, int maxOutputSize) { + return compressor.decompress(data, offset, size, maxOutputSize); + } + @Override public boolean equals(Object obj) { if (obj == null || obj.getClass() != getClass()) { diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java index 6705de531b..9903000d74 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java @@ -41,6 +41,7 @@ import java.util.ArrayList; import java.util.Calendar; import java.util.Collection; +import java.util.Collections; import java.util.Comparator; import java.util.Date; import java.util.EnumSet; @@ -101,6 +102,7 @@ import org.apache.fory.meta.ClassSpec; import org.apache.fory.meta.EncodedMetaString; import org.apache.fory.meta.Encoders; +import 
org.apache.fory.meta.NativeTypeDefEncoder; import org.apache.fory.meta.TypeDef; import org.apache.fory.reflect.ObjectCreators; import org.apache.fory.reflect.ReflectionUtils; @@ -537,7 +539,7 @@ public void register(Class cls, String namespace, String name) { buildUnregisteredTypeId(cls, existingInfo == null ? null : existingInfo.serializer); TypeInfo typeInfo = new TypeInfo(cls, nsBytes, nameBytes, null, typeId, -1); classInfoMap.put(cls, typeInfo); - compositeNameBytes2TypeInfo.put(new TypeNameBytes(nsBytes.hash, nameBytes.hash), typeInfo); + compositeNameBytes2TypeInfo.put(new TypeNameBytes(nsBytes, nameBytes), typeInfo); extRegistry.registeredClasses.put(fullname, cls); registerGraalvmClass(cls); } @@ -583,7 +585,7 @@ public void registerUnion(Class cls, String namespace, String name, Serialize TypeInfo typeInfo = new TypeInfo(cls, nsBytes, nameBytes, serializer, typeId, -1); typeInfo.setSerializer(this, serializer); classInfoMap.put(cls, typeInfo); - compositeNameBytes2TypeInfo.put(new TypeNameBytes(nsBytes.hash, nameBytes.hash), typeInfo); + compositeNameBytes2TypeInfo.put(new TypeNameBytes(nsBytes, nameBytes), typeInfo); extRegistry.registeredClasses.put(fullname, cls); registerGraalvmClass(cls); } @@ -864,17 +866,30 @@ public int getTypeIdForTypeDef(Class cls) { } return typeInfo.typeId; } - int typeId = buildUnregisteredTypeId(cls, null); + int typeId = usesNonStructTypeDef(cls) ? 
Types.NAMED_EXT : buildUnregisteredTypeId(cls, null); typeInfo = new TypeInfo(this, cls, null, typeId, INVALID_USER_TYPE_ID); classInfoMap.put(cls, typeInfo); if (typeInfo.namespace != null && typeInfo.typeName != null) { - TypeNameBytes typeNameBytes = - new TypeNameBytes(typeInfo.namespace.hash, typeInfo.typeName.hash); + TypeNameBytes typeNameBytes = new TypeNameBytes(typeInfo.namespace, typeInfo.typeName); compositeNameBytes2TypeInfo.put(typeNameBytes, typeInfo); } return typeId; } + private boolean usesNonStructTypeDef(Class cls) { + return !cls.isEnum() + && (isCollection(cls) + || isMap(cls) + || Externalizable.class.isAssignableFrom(cls) + || requireJavaSerialization(cls) + || useReplaceResolveSerializer(cls) + || Functions.isLambda(cls) + || Calendar.class.isAssignableFrom(cls) + || ZoneId.class.isAssignableFrom(cls) + || TimeZone.class.isAssignableFrom(cls) + || ByteBuffer.class.isAssignableFrom(cls)); + } + /** * Compute the user type id used in TypeDef without forcing serializer creation. Returns -1 when * the class isn't registered by numeric id. @@ -1079,8 +1094,7 @@ private void registerSerializerImpl(Class type, Serializer serializer) { typeInfo = sharedTypeInfo; updateTypeInfo(type, typeInfo); if (typeInfo.namespace != null && typeInfo.typeName != null) { - TypeNameBytes typeNameBytes = - new TypeNameBytes(typeInfo.namespace.hash, typeInfo.typeName.hash); + TypeNameBytes typeNameBytes = new TypeNameBytes(typeInfo.namespace, typeInfo.typeName); compositeNameBytes2TypeInfo.put(typeNameBytes, typeInfo); } if (typeInfoCache.type == type) { @@ -1226,8 +1240,7 @@ public void addSerializer(Class type, Serializer serializer) { // readTypeInfo can find the TypeInfo by name bytes during deserialization. // This is important for dynamically created classes that can't be loaded by name. 
if (typeInfo.namespace != null && typeInfo.typeName != null) { - TypeNameBytes typeNameBytes = - new TypeNameBytes(typeInfo.namespace.hash, typeInfo.typeName.hash); + TypeNameBytes typeNameBytes = new TypeNameBytes(typeInfo.namespace, typeInfo.typeName); compositeNameBytes2TypeInfo.put(typeNameBytes, typeInfo); } } @@ -1708,7 +1721,6 @@ private void registerGraalvmSerializerClass(Class cls) { getGraalvmClassRegistry() .putDeserializerClass( typeDef.getId(), getMetaSharedDeserializerClassForGraalvmBuild(cls, typeDef)); - extRegistry.typeInfoByTypeDefId.remove(typeDef.getId()); } typeInfoCache = NIL_TYPE_INFO; if (RecordUtils.isRecord(cls)) { @@ -1809,15 +1821,15 @@ private TypeDef buildTypeDef(TypeInfo typeInfo, Class seri Preconditions.checkArgument( serializerClass != UnknownClassSerializers.UnknownStructSerializer.class); if (needToWriteTypeDef(serializerClass)) { - typeDef = - cacheTypeDef( - typeDefMap.computeIfAbsent(typeInfo.type, cls -> TypeDef.buildTypeDef(this, cls))); + typeDef = typeDefMap.computeIfAbsent(typeInfo.type, cls -> TypeDef.buildTypeDef(this, cls)); } else { // Some type will use other serializers such MapSerializer and so on. 
typeDef = - cacheTypeDef( - typeDefMap.computeIfAbsent( - typeInfo.type, cls -> TypeDef.buildTypeDef(this, cls, new ArrayList<>(), false))); + typeDefMap.computeIfAbsent( + typeInfo.type, + cls -> + NativeTypeDefEncoder.buildTypeDefWithFieldInfos( + this, cls, Collections.emptyList())); } typeInfo.typeDef = typeDef; return typeDef; @@ -1924,7 +1936,7 @@ private TypeInfo getTypeInfoByTypeIdForReadClassInternal(int typeId, int userTyp @Override protected TypeInfo loadBytesToTypeInfo( EncodedMetaString packageBytes, EncodedMetaString simpleClassNameBytes) { - TypeNameBytes typeNameBytes = new TypeNameBytes(packageBytes.hash, simpleClassNameBytes.hash); + TypeNameBytes typeNameBytes = new TypeNameBytes(packageBytes, simpleClassNameBytes); TypeInfo typeInfo = compositeNameBytes2TypeInfo.get(typeNameBytes); if (typeInfo == null) { typeInfo = populateBytesToTypeInfo(typeNameBytes, packageBytes, simpleClassNameBytes); @@ -1946,8 +1958,7 @@ protected TypeInfo ensureSerializerForTypeInfo(TypeInfo typeInfo) { TypeInfo newTypeInfo = getTypeInfo(typeInfo.type); // Update the cache with the correct TypeInfo that has a serializer if (typeInfo.typeName != null) { - TypeNameBytes typeNameBytes = - new TypeNameBytes(typeInfo.namespace.hash, typeInfo.typeName.hash); + TypeNameBytes typeNameBytes = new TypeNameBytes(typeInfo.namespace, typeInfo.typeName); compositeNameBytes2TypeInfo.put(typeNameBytes, newTypeInfo); } return newTypeInfo; @@ -1985,8 +1996,7 @@ public Class loadClassForMeta(String className, boolean isEnum, int arrayDims String typeName = ReflectionUtils.getClassNameWithoutPackage(className); EncodedMetaString pkgBytes = sharedRegistry.getPackageEncodedMetaString(pkg); EncodedMetaString typeBytes = sharedRegistry.getTypeNameEncodedMetaString(typeName); - TypeInfo cachedInfo = - compositeNameBytes2TypeInfo.get(new TypeNameBytes(pkgBytes.hash, typeBytes.hash)); + TypeInfo cachedInfo = compositeNameBytes2TypeInfo.get(new TypeNameBytes(pkgBytes, typeBytes)); if (cachedInfo 
!= null) { return cachedInfo.type; } diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/SharedRegistry.java b/java/fory-core/src/main/java/org/apache/fory/resolver/SharedRegistry.java index 8e6b50484f..5dec2d5ff6 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/SharedRegistry.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/SharedRegistry.java @@ -27,6 +27,7 @@ import java.util.Objects; import java.util.SortedMap; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.fory.annotation.Internal; import org.apache.fory.codegen.CodeGenerator; import org.apache.fory.collection.BiMap; @@ -54,6 +55,7 @@ public final class SharedRegistry { private static final int MAX_CACHED_ENCODED_META_STRINGS = 32768; private static final int MAX_CACHED_ENCODED_META_STRING_LENGTH = 2048; + private static final int MAX_CACHED_TYPE_DEFS = 8192; final ConcurrentIdentityMap, TypeDef> typeDefMap = new ConcurrentIdentityMap<>(); final ConcurrentIdentityMap, TypeDef> currentLayerTypeDef = @@ -80,6 +82,7 @@ public final class SharedRegistry { final ConcurrentIdentityMap, Serializer> registeredSerializerCache = new ConcurrentIdentityMap<>(); private final Object metaStringCacheLock = new Object(); + private final AtomicInteger cachedTypeDefCount = new AtomicInteger(); volatile IdentityHashMap, Integer> registeredClassIdMap; volatile BiMap> registeredClasses; @@ -125,6 +128,43 @@ TypeInfo cacheRegisteredTypeInfo(Class type, TypeInfo typeInfo) { return existing == null ? 
typeInfo : existing; } + TypeDef getOrCreateTypeDef(TypeDef typeDef) { + long id = typeDef.getId(); + TypeDef existing = typeDefById.get(id); + if (existing != null) { + return existing; + } + if (!reserveTypeDefCacheSlot()) { + return typeDef; + } + existing = typeDefById.putIfAbsent(id, typeDef); + if (existing != null) { + cachedTypeDefCount.decrementAndGet(); + return existing; + } + return typeDef; + } + + private boolean reserveTypeDefCacheSlot() { + while (true) { + int count = cachedTypeDefCount.get(); + int mapSize = typeDefById.size(); + if (mapSize > count) { + if (cachedTypeDefCount.compareAndSet(count, mapSize)) { + count = mapSize; + } else { + continue; + } + } + if (count >= MAX_CACHED_TYPE_DEFS) { + return false; + } + if (cachedTypeDefCount.compareAndSet(count, count + 1)) { + return true; + } + } + } + EncodedMetaString getPackageEncodedMetaString(String string) { return getEncodedMetaString( string, @@ -210,11 +250,6 @@ private boolean shouldCacheEncodedMetaStringLength(EncodedMetaString encodedMeta return encodedMetaString.bytes.length <= MAX_CACHED_ENCODED_META_STRING_LENGTH; } - TypeDef getOrCreateTypeDef(TypeDef typeDef) { - TypeDef existingTypeDef = typeDefById.putIfAbsent(typeDef.getId(), typeDef); - return existingTypeDef == null ? 
typeDef : existingTypeDef; - } - List getOrCreateFieldDescriptors( Class type, boolean searchParent, java.util.function.Supplier> factory) { if (GraalvmSupport.isGraalBuildtime()) { diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeNameBytes.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeNameBytes.java index 20fe814825..f0ae54e53f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeNameBytes.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeNameBytes.java @@ -19,20 +19,30 @@ package org.apache.fory.resolver; +import java.util.Arrays; +import org.apache.fory.meta.EncodedMetaString; + class TypeNameBytes { private final long packageHash; private final long classNameHash; + private final byte[] packageBytes; + private final byte[] classNameBytes; - TypeNameBytes(long packageHash, long classNameHash) { - this.packageHash = packageHash; - this.classNameHash = classNameHash; + TypeNameBytes(EncodedMetaString packageBytes, EncodedMetaString classNameBytes) { + this.packageHash = packageBytes.hash; + this.classNameHash = classNameBytes.hash; + this.packageBytes = packageBytes.bytes; + this.classNameBytes = classNameBytes.bytes; } @Override public boolean equals(Object o) { // ClassNameBytes is used internally, skip TypeNameBytes that = (TypeNameBytes) o; - return packageHash == that.packageHash && classNameHash == that.classNameHash; + return packageHash == that.packageHash + && classNameHash == that.classNameHash + && Arrays.equals(packageBytes, that.packageBytes) + && Arrays.equals(classNameBytes, that.classNameBytes); } @Override diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index 37cbe1e9bc..cf5569e033 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -116,8 +116,21 @@ 
public abstract class TypeResolver { static final int INTERNAL_NATIVE_ID_LIMIT = 250; private static final GenericType OBJECT_GENERIC_TYPE = GenericType.build(Object.class); private static final float TYPE_ID_MAP_LOAD_FACTOR = 0.5f; + private static final int MAX_CACHED_TYPE_DEFS = 8192; static final long MAX_USER_TYPE_ID = 0xffff_fffEL; + private static final class TransformedTypeInfo { + final Class readClass; + final long typeDefId; + final TypeInfo typeInfo; + + TransformedTypeInfo(Class readClass, long typeDefId, TypeInfo typeInfo) { + this.readClass = readClass; + this.typeDefId = typeDefId; + this.typeInfo = typeInfo; + } + } + final Config config; final boolean metaContextShareEnabled; final SharedRegistry sharedRegistry; @@ -729,7 +742,7 @@ protected final TypeInfo readTypeInfoByCache( /** * Read class info from bytes with cache optimization. Uses the cached namespace and type name - * bytes to avoid map lookups when the class is the same as the cached one (hash comparison). + * bytes to avoid map lookups when the class is the same as the cached one. */ protected final TypeInfo readTypeInfoFromBytes( ReadContext readContext, TypeInfo typeInfoCache, int header) { @@ -746,9 +759,8 @@ protected final TypeInfo readTypeInfoFromBytes( assert packageNameBytesCache != null; simpleClassNameBytes = metaStringReader.readMetaString(buffer, typeNameBytesCache); - // Fast path: if hashes match, return cached TypeInfo (already has serializer) - if (typeNameBytesCache.hash == simpleClassNameBytes.hash - && packageNameBytesCache.hash == namespaceBytes.hash) { + // MetaStringReader returns the provided cache object only after validating the encoded body. 
+ if (typeNameBytesCache == simpleClassNameBytes && packageNameBytesCache == namespaceBytes) { return typeInfoCache; } } else { @@ -775,9 +787,11 @@ protected final TypeInfo readSharedClassMeta(ReadContext readContext) { TypeInfo typeInfo; if (isRef) { // Reference to previously read type in this stream - typeInfo = metaReadContext.readTypeInfos.get(index); + typeInfo = getMetaReadTypeInfo(metaReadContext, index); } else { - // New type in stream - but may already be known from registry + // New type in stream, with optimized reuse by validated TypeDef header. A header-cache + // hit intentionally skips the body without rehashing: entries are published only after the + // TypeDef body has parsed successfully and matched the 52-bit body hash. long id = buffer.readInt64(); typeInfo = extRegistry.typeInfoByTypeDefId.get(id); if (typeInfo != null) { @@ -797,6 +811,17 @@ protected final TypeInfo readSharedClassMeta(ReadContext readContext) { return typeInfo; } + private static TypeInfo getMetaReadTypeInfo(MetaReadContext metaReadContext, int index) { + if (index < 0 || index >= metaReadContext.readTypeInfos.size) { + throw new ForyException("Invalid class metadata reference id " + index); + } + TypeInfo typeInfo = metaReadContext.readTypeInfos.get(index); + if (typeInfo == null) { + throw new ForyException("Invalid class metadata reference id " + index); + } + return typeInfo; + } + public final TypeInfo readSharedClassMeta(ReadContext readContext, Class targetClass) { TypeInfo typeInfo = readSharedClassMeta(readContext); Class readClass = typeInfo.getType(); @@ -808,20 +833,26 @@ public final TypeInfo readSharedClassMeta(ReadContext readContext, Class targ } private TypeInfo getTargetTypeInfo(TypeInfo typeInfo, Class targetClass) { - Tuple2, TypeInfo>[] infos = extRegistry.transformedTypeInfo.get(targetClass); + TransformedTypeInfo[] infos = extRegistry.transformedTypeInfo.get(targetClass); Class readClass = typeInfo.getType(); + long typeDefId = 
transformCacheTypeDefId(typeInfo); if (infos != null) { // It's ok to use loop here since most of case the array size will be 1. - for (Tuple2, TypeInfo> info : infos) { - if (info.f0 == readClass) { - return info.f1; + for (TransformedTypeInfo info : infos) { + if (info.readClass == readClass && info.typeDefId == typeDefId) { + return info.typeInfo; } } } - return transformTypeInfo(typeInfo, targetClass); + return transformTypeInfo(typeInfo, targetClass, typeDefId); + } + + private static long transformCacheTypeDefId(TypeInfo typeInfo) { + TypeDef typeDef = typeInfo.getTypeDef(); + return typeDef == null ? 0 : typeDef.getId(); } - private TypeInfo transformTypeInfo(TypeInfo typeInfo, Class targetClass) { + private TypeInfo transformTypeInfo(TypeInfo typeInfo, Class targetClass, long typeDefId) { Class readClass = typeInfo.getType(); TypeInfo newTypeInfo; if (targetClass.isAssignableFrom(readClass)) { @@ -832,14 +863,16 @@ private TypeInfo transformTypeInfo(TypeInfo typeInfo, Class targetClass) { getMetaSharedTypeInfo( typeInfo.typeDef.replaceRootClassTo(this, targetClass), targetClass); } - Tuple2, TypeInfo>[] infos = extRegistry.transformedTypeInfo.get(targetClass); + TransformedTypeInfo[] infos = extRegistry.transformedTypeInfo.get(targetClass); int size = infos == null ? 
0 : infos.length; - @SuppressWarnings("unchecked") - Tuple2, TypeInfo>[] newInfos = (Tuple2, TypeInfo>[]) new Tuple2[size + 1]; + if (size >= MAX_CACHED_TYPE_DEFS) { + return newTypeInfo; + } + TransformedTypeInfo[] newInfos = new TransformedTypeInfo[size + 1]; if (size > 0) { System.arraycopy(infos, 0, newInfos, 0, size); } - newInfos[size] = Tuple2.of(readClass, newTypeInfo); + newInfos[size] = new TransformedTypeInfo(readClass, typeDefId, newTypeInfo); extRegistry.transformedTypeInfo.put(targetClass, newInfos); return newTypeInfo; } @@ -921,10 +954,7 @@ final TypeInfo buildMetaSharedTypeInfo(TypeDef typeDef) { return typeInfo; } Class cls = loadClass(typeDef.getClassSpec()); - // For nonexistent classes, always create a new TypeInfo with the correct typeDef, - // even if the typeDef has no fields meta. This ensures the UnknownStructSerializer - // has access to the typeDef for proper deserialization. - if (!typeDef.hasFieldsMeta() + if (!typeDef.isStructSchemaKind() && !UnknownClass.class.isAssignableFrom(TypeUtils.getComponentIfArray(cls))) { typeInfo = getTypeInfo(cls); } else if (ClassResolver.useReplaceResolveSerializer(cls)) { @@ -934,7 +964,9 @@ final TypeInfo buildMetaSharedTypeInfo(TypeDef typeDef) { } else { typeInfo = getMetaSharedTypeInfo(typeDef, cls); } - extRegistry.typeInfoByTypeDefId.put(typeDef.getId(), typeInfo); + if (extRegistry.typeInfoByTypeDefId.size < MAX_CACHED_TYPE_DEFS) { + extRegistry.typeInfoByTypeDefId.put(typeDef.getId(), typeInfo); + } return typeInfo; } @@ -1807,15 +1839,12 @@ class ExtRegistry { int userIdGenerator = 0; SerializerFactory serializerFactory; final LongMap typeInfoByTypeDefId = new LongMap<>(2, 0.5f); - // cache absTypeInfo, support customized serializer for abstract or interface. 
// IdentityHashMap is more memory efficient than fory IdentityMap, and this is not in hotpath // for query final IdentityHashMap, TypeInfo> abstractTypeInfo = new IdentityHashMap<>(); - // Tuple2: Tuple2 - final IdentityHashMap, Tuple2, TypeInfo>[]> transformedTypeInfo = + final IdentityHashMap, TransformedTypeInfo[]> transformedTypeInfo = new IdentityHashMap<>(); - // avoid potential recursive call for seq codec generation. // ex. A->field1: B, B.field1: A final Set> getClassCtx = new HashSet<>(); diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java index 3f1ee75fed..39e70a5431 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java @@ -334,8 +334,7 @@ private void register( typeInfo.setSerializer(this, serializer); extRegistry.registeredClasses.put(qualifiedName, type); if (typeInfo.typeName != null) { - TypeNameBytes typeNameBytes = - new TypeNameBytes(typeInfo.namespace.hash, typeInfo.typeName.hash); + TypeNameBytes typeNameBytes = new TypeNameBytes(typeInfo.namespace, typeInfo.typeName); compositeClassNameBytes2TypeInfo.put(typeNameBytes, typeInfo); } registerGraalvmClass(type); @@ -481,8 +480,7 @@ public void registerSerializer(Class type, Serializer serializer) { } updateTypeInfo(type, typeInfo); if (typeInfo.typeName != null) { - TypeNameBytes typeNameBytes = - new TypeNameBytes(typeInfo.namespace.hash, typeInfo.typeName.hash); + TypeNameBytes typeNameBytes = new TypeNameBytes(typeInfo.namespace, typeInfo.typeName); compositeClassNameBytes2TypeInfo.put(typeNameBytes, typeInfo); } } @@ -740,8 +738,7 @@ public TypeInfo getXtypeInfo(int typeId) { public TypeInfo getUserTypeInfo(String namespace, String typeName) { EncodedMetaString namespaceBytes = sharedRegistry.getPackageEncodedMetaString(namespace); EncodedMetaString typeNameBytes = 
sharedRegistry.getTypeNameEncodedMetaString(typeName); - return compositeClassNameBytes2TypeInfo.get( - new TypeNameBytes(namespaceBytes.hash, typeNameBytes.hash)); + return compositeClassNameBytes2TypeInfo.get(new TypeNameBytes(namespaceBytes, typeNameBytes)); } public TypeInfo getUserTypeInfo(int userTypeId) { @@ -1113,8 +1110,7 @@ public TypeInfo writeTypeInfo(MemoryBuffer buffer, Object obj) { @Override protected TypeDef buildTypeDef(TypeInfo typeInfo) { TypeDef typeDef = - cacheTypeDef( - typeDefMap.computeIfAbsent(typeInfo.type, cls -> TypeDef.buildTypeDef(this, cls))); + typeDefMap.computeIfAbsent(typeInfo.type, cls -> TypeDef.buildTypeDef(this, cls)); typeInfo.typeDef = typeDef; return typeDef; } @@ -1196,8 +1192,7 @@ protected TypeInfo ensureSerializerForTypeInfo(TypeInfo typeInfo) { TypeInfo newTypeInfo = getTypeInfo(typeInfo.type); // Update the cache with the correct TypeInfo that has a serializer if (typeInfo.typeName != null) { - TypeNameBytes typeNameBytes = - new TypeNameBytes(typeInfo.namespace.hash, typeInfo.typeName.hash); + TypeNameBytes typeNameBytes = new TypeNameBytes(typeInfo.namespace, typeInfo.typeName); compositeClassNameBytes2TypeInfo.put(typeNameBytes, newTypeInfo); } return newTypeInfo; @@ -1207,7 +1202,7 @@ protected TypeInfo ensureSerializerForTypeInfo(TypeInfo typeInfo) { private TypeInfo loadBytesToTypeInfoWithTypeId( int internalTypeId, EncodedMetaString packageBytes, EncodedMetaString simpleClassNameBytes) { - TypeNameBytes typeNameBytes = new TypeNameBytes(packageBytes.hash, simpleClassNameBytes.hash); + TypeNameBytes typeNameBytes = new TypeNameBytes(packageBytes, simpleClassNameBytes); TypeInfo typeInfo = compositeClassNameBytes2TypeInfo.get(typeNameBytes); if (typeInfo == null) { typeInfo = diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/BufferSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/BufferSerializers.java index d5d52a4ce4..74619ff9fc 100644 --- 
a/java/fory-core/src/main/java/org/apache/fory/serializer/BufferSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/BufferSerializers.java @@ -24,6 +24,7 @@ import org.apache.fory.context.CopyContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.ByteBufferUtil; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.resolver.TypeResolver; @@ -60,8 +61,14 @@ public ByteBuffer read(ReadContext readContext) { MemoryBuffer newBuffer = readContext.readBufferObject(); int readerIndex = newBuffer.readerIndex(); int size = newBuffer.remaining(); + if (size < 1) { + throw new DeserializationException("Invalid ByteBuffer payload"); + } ByteBuffer originalBuffer = newBuffer.sliceAsByteBuffer(readerIndex, size - 1); byte isBigEndian = newBuffer.getByte(readerIndex + size - 1); + if (isBigEndian != 0 && isBigEndian != (byte) 1) { + throw new DeserializationException("Invalid ByteBuffer byte order marker"); + } originalBuffer.order( isBigEndian == (byte) 1 ? 
ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN); return originalBuffer; diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java index c714cedb65..5bbb5c822e 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java @@ -34,6 +34,7 @@ import org.apache.fory.context.MetaReadContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.ForyException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.Platform; import org.apache.fory.meta.TypeDef; @@ -302,7 +303,11 @@ private static void readAndSkipLayerClassMeta(ReadContext readContext) { } int indexMarker = buffer.readVarUInt32Small14(); boolean isRef = (indexMarker & 1) == 1; + int index = indexMarker >>> 1; if (isRef) { + if (index >= metaReadContext.readTypeInfos.size) { + throw new ForyException("Invalid layer metadata reference id " + index); + } return; } long typeDefId = buffer.readInt64(); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java index 3f0f60c3e7..b4c0780ef3 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java @@ -51,6 +51,7 @@ import org.apache.fory.context.MetaReadContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.ForyException; import org.apache.fory.logging.Logger; import org.apache.fory.logging.LoggerFactory; import org.apache.fory.memory.MemoryBuffer; @@ -89,9 +90,10 @@ @SuppressWarnings({"unchecked", "rawtypes"}) 
public class ObjectStreamSerializer extends AbstractObjectSerializer { private static final Logger LOG = LoggerFactory.getLogger(ObjectStreamSerializer.class); + private static final int MAX_CACHED_TYPE_DEFS = 8192; private final SlotInfo[] slotsInfos; - // Instance-level cache: TypeDef ID -> TypeInfo (shared across all slots) + // Instance-level cache: TypeDef ID -> TypeInfo (shared across all slots). private final LongMap typeDefIdToTypeInfo = new LongMap<>(4, 0.4f); private static MetaSharedLayerSerializerBase newGeneratedSerializer( @@ -428,21 +430,11 @@ private void skipUnknownLayerData(ReadContext readContext, Class senderClass) TypeInfo typeInfo; if (isRef) { // Reference to previously read TypeInfo - typeInfo = metaReadContext.readTypeInfos.get(index); + typeInfo = getMetaReadTypeInfo(metaReadContext, index); } else { - // New TypeDef in stream - read ID first to check cache + // New TypeDef in stream, with optimized reuse by validated TypeDef header. long typeDefId = buffer.readInt64(); - typeInfo = typeDefIdToTypeInfo.get(typeDefId); - if (typeInfo != null) { - // Already cached - skip the TypeDef bytes, reuse existing TypeInfo - TypeDef.skipTypeDef(buffer, typeDefId); - } else { - // Not cached - read full TypeDef and create TypeInfo - TypeDef typeDef = - typeResolver.cacheTypeDef(TypeDef.readTypeDef(typeResolver, buffer, typeDefId)); - typeInfo = new TypeInfo(senderClass, typeDef); - typeDefIdToTypeInfo.put(typeDefId, typeInfo); - } + typeInfo = readLayerTypeInfo(typeResolver, buffer, senderClass, typeDefId); metaReadContext.readTypeInfos.add(typeInfo); } @@ -460,6 +452,33 @@ private void skipUnknownLayerData(ReadContext readContext, Class senderClass) skipSerializer.skipFields(readContext); } + private static TypeInfo getMetaReadTypeInfo(MetaReadContext metaReadContext, int index) { + if (index < 0 || index >= metaReadContext.readTypeInfos.size) { + throw new ForyException("Invalid layer metadata reference id " + index); + } + TypeInfo typeInfo = 
metaReadContext.readTypeInfos.get(index); + if (typeInfo == null) { + throw new ForyException("Invalid layer metadata reference id " + index); + } + return typeInfo; + } + + private TypeInfo readLayerTypeInfo( + TypeResolver typeResolver, MemoryBuffer buffer, Class cls, long typeDefId) { + TypeInfo typeInfo = typeDefIdToTypeInfo.get(typeDefId); + if (typeInfo != null) { + TypeDef.skipTypeDef(buffer, typeDefId); + return typeInfo; + } + TypeDef typeDef = + typeResolver.cacheTypeDef(TypeDef.readTypeDef(typeResolver, buffer, typeDefId)); + typeInfo = new TypeInfo(cls, typeDef); + if (typeDefIdToTypeInfo.size < MAX_CACHED_TYPE_DEFS) { + typeDefIdToTypeInfo.put(typeDefId, typeInfo); + } + return typeInfo; + } + private static void throwUnsupportedEncodingException(Class cls) throws UnsupportedEncodingException { throw new UnsupportedEncodingException( @@ -666,7 +685,7 @@ public SlotsInfo(TypeResolver typeResolver, Class type) { buildFieldInfoFromObjectStreamClass(typeResolver, objectStreamClass, type); layerTypeDef = NativeTypeDefEncoder.buildTypeDefWithFieldInfos( - (ClassResolver) typeResolver, type, fieldInfos, true); + (ClassResolver) typeResolver, type, fieldInfos); } else { // Fallback when ObjectStreamClass is not available (e.g., GraalVM native image) layerTypeDef = typeResolver.getTypeDef(type, false); @@ -842,21 +861,12 @@ private TypeInfo readLayerTypeInfo(TypeResolver typeResolver, ReadContext readCo int index = indexMarker >>> 1; if (isRef) { // Reference to previously read TypeInfo - return metaReadContext.readTypeInfos.get(index); + return getMetaReadTypeInfo(metaReadContext, index); } else { - // New TypeDef in stream - read ID first to check cache + // New TypeDef in stream, with optimized reuse by validated TypeDef header. 
long typeDefId = buffer.readInt64(); - TypeInfo typeInfo = typeDefIdToTypeInfo.get(typeDefId); - if (typeInfo != null) { - // Already cached - skip the TypeDef bytes, reuse existing TypeInfo - TypeDef.skipTypeDef(buffer, typeDefId); - } else { - // Not cached - read full TypeDef and create TypeInfo - TypeDef typeDef = - typeResolver.cacheTypeDef(TypeDef.readTypeDef(typeResolver, buffer, typeDefId)); - typeInfo = new TypeInfo(cls, typeDef); - typeDefIdToTypeInfo.put(typeDefId, typeInfo); - } + TypeInfo typeInfo = + ObjectStreamSerializer.this.readLayerTypeInfo(typeResolver, buffer, cls, typeDefId); metaReadContext.readTypeInfos.add(typeInfo); return typeInfo; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java index 2250ed0e37..39ca5c3356 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java @@ -49,6 +49,7 @@ import org.apache.fory.context.MetaReadContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.ForyException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.meta.TypeDef; import org.apache.fory.reflect.ReflectionUtils; @@ -681,6 +682,9 @@ private static void readAndSkipLayerClassMeta(ReadContext readContext) { boolean isRef = (indexMarker & 1) == 1; int index = indexMarker >>> 1; if (isRef) { + if (index >= metaReadContext.readTypeInfos.size) { + throw new ForyException("Invalid layer metadata reference id " + index); + } // Reference to previously read type - nothing more to read return; } diff --git a/java/fory-core/src/test/java/org/apache/fory/ForyTest.java b/java/fory-core/src/test/java/org/apache/fory/ForyTest.java index ea2b3751bf..11e8e1c41c 
100644 --- a/java/fory-core/src/test/java/org/apache/fory/ForyTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/ForyTest.java @@ -89,6 +89,14 @@ public static Object[][] xlangConfig() { return new Object[][] {{false}, {true}}; } + @Test + public void typedDeserializeRejectsOutOfBandRootHeaderWithoutBuffers() { + Fory fory = Fory.builder().build(); + byte[] bytes = fory.serialize(7); + bytes[0] |= 0x02; + assertThrows(IllegalArgumentException.class, () -> fory.deserialize(bytes, Integer.class)); + } + @Test(dataProvider = "crossLanguageReferenceTrackingConfig") public void primitivesTest(boolean referenceTracking, boolean xlang) { Fory fory1 = diff --git a/java/fory-core/src/test/java/org/apache/fory/StreamTest.java b/java/fory-core/src/test/java/org/apache/fory/StreamTest.java index 2953f8f29f..a1a5aa82d0 100644 --- a/java/fory-core/src/test/java/org/apache/fory/StreamTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/StreamTest.java @@ -29,12 +29,15 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.io.ForyInputStream; import org.apache.fory.io.ForyReadableChannel; import org.apache.fory.io.ForyStreamReader; @@ -281,6 +284,25 @@ public void testReadableChannel() throws IOException { } } + @Test + public void testReadableChannelRequiresExactReads() throws IOException { + Fory fory = Fory.builder().requireClassRegistration(false).build(); + BeanA beanA = BeanA.createBeanA(2); + byte[] serialized = fory.serialize(beanA); + + try (ForyReadableChannel channel = + new ForyReadableChannel(new ChunkedReadableByteChannel(serialized, 1))) { + Assert.assertEquals(fory.deserialize(channel), beanA); + } 
+ + byte[] truncated = new byte[serialized.length - 1]; + System.arraycopy(serialized, 0, truncated, 0, truncated.length); + try (ForyReadableChannel channel = + new ForyReadableChannel(new ChunkedReadableByteChannel(truncated, 1))) { + Assert.assertThrows(DeserializationException.class, () -> fory.deserialize(channel)); + } + } + @Test public void testScopedMetaShare() throws IOException { Fory fory = @@ -306,6 +328,45 @@ public void testScopedMetaShare() throws IOException { Assert.assertEquals(fory.deserialize(stream), list2); } + private static final class ChunkedReadableByteChannel implements ReadableByteChannel { + private final byte[] data; + private final int chunkSize; + private int index; + private boolean open = true; + + private ChunkedReadableByteChannel(byte[] data, int chunkSize) { + this.data = data; + this.chunkSize = chunkSize; + } + + @Override + public int read(ByteBuffer dst) { + if (!open) { + throw new IllegalStateException("Channel is closed"); + } + if (!dst.hasRemaining()) { + return 0; + } + if (index == data.length) { + return -1; + } + int length = Math.min(chunkSize, Math.min(dst.remaining(), data.length - index)); + dst.put(data, index, length); + index += length; + return length; + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public void close() { + open = false; + } + } + @Test public void testBigBufferStreamingMetaShared() throws IOException { Fory fory = builder().withCompatible(true).build(); diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/DeflaterMetaCompressorTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/DeflaterMetaCompressorTest.java index d601f08033..dd99b31cd3 100644 --- a/java/fory-core/src/test/java/org/apache/fory/meta/DeflaterMetaCompressorTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/DeflaterMetaCompressorTest.java @@ -62,6 +62,17 @@ public void testDecompressCorruptedInputThrows() { assertTrue(e.getMessage().contains("Invalid 
compressed metadata")); } + @Test(timeOut = 5_000) + public void testDecompressRejectsOutputAboveLimit() { + byte[] input = new byte[4096]; + byte[] compressed = compressor.compress(input, 0, input.length); + InvalidDataException e = + Assert.expectThrows( + InvalidDataException.class, + () -> compressor.decompress(compressed, 0, compressed.length, 1024)); + assertTrue(e.getMessage().contains("maximum size")); + } + private static byte[] sampleInput() { return "0123456789abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz" .getBytes(StandardCharsets.UTF_8); diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java index ce6b974bd2..a78bf53e41 100644 --- a/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java @@ -27,6 +27,7 @@ import lombok.Data; import org.apache.fory.Fory; import org.apache.fory.annotation.ForyField; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.resolver.ClassResolver; import org.apache.fory.test.bean.BeanA; @@ -44,7 +45,7 @@ public void testBasicTypeDef() { List fieldsInfo = buildFieldsInfo((ClassResolver) fory.getTypeResolver(), type); MemoryBuffer buffer = NativeTypeDefEncoder.encodeTypeDef( - (ClassResolver) fory.getTypeResolver(), type, getClassFields(type, fieldsInfo), true); + (ClassResolver) fory.getTypeResolver(), type, getClassFields(type, fieldsInfo)); TypeDef typeDef = TypeDef.readTypeDef(fory.getTypeResolver(), buffer); Assert.assertEquals(typeDef.getClassName(), type.getName()); Assert.assertEquals(typeDef.getFieldsInfo().size(), type.getDeclaredFields().length); @@ -116,12 +117,95 @@ public static class InnerClassTestLengthInnerClassTestLengthInnerClassTestLength public void testPrependHeader() { MemoryBuffer 
inputBuffer = MemoryBuffer.newHeapBuffer(TypeDef.META_SIZE_MASKS + 1); inputBuffer.writerIndex(TypeDef.META_SIZE_MASKS + 1); - MemoryBuffer outputBuffer = NativeTypeDefEncoder.prependHeader(inputBuffer, true, false); + MemoryBuffer outputBuffer = NativeTypeDefEncoder.prependHeader(inputBuffer, true); long header = outputBuffer.readInt64(); Assert.assertEquals(header & TypeDef.META_SIZE_MASKS, TypeDef.META_SIZE_MASKS); Assert.assertEquals(header & TypeDef.COMPRESS_META_FLAG, TypeDef.COMPRESS_META_FLAG); - Assert.assertEquals(header & TypeDef.HAS_FIELDS_META_FLAG, 0); + } + + @Test + public void testDecodeRejectsReservedGlobalBits() { + Fory fory = Fory.builder().withMetaShare(true).build(); + TypeDef typeDef = TypeDef.buildTypeDef(fory.getTypeResolver(), Foo1.class); + MemoryBuffer encoded = MemoryBuffer.fromByteArray(typeDef.getEncoded()); + long header = encoded.readInt64(); + + MemoryBuffer malformed = MemoryBuffer.newHeapBuffer(typeDef.getEncoded().length); + malformed.writeInt64(header | TypeDef.RESERVED_META_FLAGS); + malformed.writeBytes( + typeDef.getEncoded(), Long.BYTES, typeDef.getEncoded().length - Long.BYTES); + Assert.assertThrows( + DeserializationException.class, + () -> TypeDef.readTypeDef(fory.getTypeResolver(), malformed)); + } + + @Test + public void testDecodeRejectsTrailingTypeDefBodyBytes() { + Fory fory = Fory.builder().withMetaShare(true).build(); + TypeDef typeDef = TypeDef.buildTypeDef(fory.getTypeResolver(), Foo1.class); + MemoryBuffer encoded = MemoryBuffer.fromByteArray(typeDef.getEncoded()); + long header = encoded.readInt64(); + long size = header & TypeDef.META_SIZE_MASKS; + Assert.assertTrue(size < TypeDef.META_SIZE_MASKS); + + MemoryBuffer malformed = MemoryBuffer.newHeapBuffer(typeDef.getEncoded().length + 1); + malformed.writeInt64((header & ~TypeDef.META_SIZE_MASKS) | (size + 1)); + malformed.writeBytes( + typeDef.getEncoded(), Long.BYTES, typeDef.getEncoded().length - Long.BYTES); + malformed.writeByte(0); + 
Assert.assertThrows( + DeserializationException.class, + () -> TypeDef.readTypeDef(fory.getTypeResolver(), malformed)); + } + + @Test + public void testDecodeRejectsParsedTypeDefWithMismatchedHash() { + Fory fory = Fory.builder().withMetaShare(true).build(); + TypeDef typeDef = TypeDef.buildTypeDef(fory.getTypeResolver(), Foo1.class); + MemoryBuffer encoded = MemoryBuffer.fromByteArray(typeDef.getEncoded()); + long header = encoded.readInt64(); + Assert.assertEquals(header & TypeDef.COMPRESS_META_FLAG, 0); + + byte[] malformed = corruptEncodedBody(typeDef, "f1"); + Assert.assertThrows( + DeserializationException.class, + () -> TypeDef.readTypeDef(fory.getTypeResolver(), MemoryBuffer.fromByteArray(malformed))); + } + + @Test + public void testDecodeRejectsHashConsistentMalformedTypeDefBody() { + Fory fory = Fory.builder().withMetaShare(true).build(); + MemoryBuffer body = MemoryBuffer.newHeapBuffer(1); + body.writeByte(0); + MemoryBuffer encoded = NativeTypeDefEncoder.prependHeader(body, false); + Assert.assertThrows( + RuntimeException.class, () -> TypeDef.readTypeDef(fory.getTypeResolver(), encoded)); + } + + private static byte[] corruptEncodedBody(TypeDef typeDef, String needle) { + byte[] malformed = typeDef.getEncoded().clone(); + byte[] needleBytes = Encoders.encodeFieldName(needle).getBytes(); + int index = indexOf(malformed, needleBytes, Long.BYTES); + Assert.assertTrue(index >= Long.BYTES); + malformed[index + needleBytes.length - 1] ^= 1; + return malformed; + } + + private static int indexOf(byte[] bytes, byte[] needle, int fromIndex) { + for (int i = fromIndex; i <= bytes.length - needle.length; i++) { + boolean match = true; + for (int j = 0; j < needle.length; j++) { + if (bytes[i + j] != needle[j]) { + match = false; + break; + } + } + if (match) { + return i; + } + } + return -1; } @Test diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java 
index dde086a204..1b256b3d73 100644 --- a/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java @@ -26,6 +26,8 @@ import lombok.Data; import org.apache.fory.Fory; import org.apache.fory.annotation.ForyField; +import org.apache.fory.exception.DeserializationException; +import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.resolver.TypeResolver; import org.testng.Assert; import org.testng.annotations.Test; @@ -113,6 +115,43 @@ public static class ClassWithNoAnnotations { private double field3; } + public static class EmptyStruct {} + + public static class ManyFields { + int f00; + int f01; + int f02; + int f03; + int f04; + int f05; + int f06; + int f07; + int f08; + int f09; + int f10; + int f11; + int f12; + int f13; + int f14; + int f15; + int f16; + int f17; + int f18; + int f19; + int f20; + int f21; + int f22; + int f23; + int f24; + int f25; + int f26; + int f27; + int f28; + int f29; + int f30; + int f31; + } + // Test data: Class with all fields using field names (tagId = -1) @Data public static class ClassWithAllFieldNames { @@ -365,6 +404,142 @@ public void testXlangTypeDefIsNotCompressed() { Assert.assertEquals(typeDef.getId() & TypeDef.COMPRESS_META_FLAG, 0); } + @Test + public void testExtendedFieldCountHeaderDoesNotSetRegisterByName() { + Fory fory = Fory.builder().withXlang(true).withMetaShare(true).build(); + fory.register(ManyFields.class, 6002); + TypeDef typeDef = TypeDef.buildTypeDef(fory.getTypeResolver(), ManyFields.class); + byte bodyHeader = typeDef.getEncoded()[typeDefBodyOffset(typeDef.getEncoded())]; + + Assert.assertEquals(bodyHeader & TypeDefEncoder.SMALL_NUM_FIELDS_THRESHOLD, 31); + Assert.assertEquals(bodyHeader & TypeDefEncoder.REGISTER_BY_NAME_FLAG, 0); + TypeDef decoded = + TypeDef.readTypeDef(fory.getTypeResolver(), MemoryBuffer.fromByteArray(typeDef.getEncoded())); + Assert.assertEquals(decoded.getFieldsInfo().size(), 
32); + } + + @Test + public void testDecodeRejectsCompressedXlangTypeDef() { + Fory fory = Fory.builder().withXlang(true).withMetaShare(true).build(); + fory.register(ClassWithNoAnnotations.class); + TypeDef typeDef = TypeDef.buildTypeDef(fory.getTypeResolver(), ClassWithNoAnnotations.class); + byte[] body = Arrays.copyOfRange(typeDef.getEncoded(), Long.BYTES, typeDef.getEncoded().length); + byte[] compressed = fory.getConfig().getMetaCompressor().compress(body, 0, body.length); + MemoryBuffer compressedBody = MemoryBuffer.fromByteArray(compressed); + MemoryBuffer encoded = NativeTypeDefEncoder.prependHeader(compressedBody, true); + + Assert.assertThrows( + DeserializationException.class, + () -> TypeDef.readTypeDef(fory.getTypeResolver(), encoded)); + } + + @Test + public void testDecodeRejectsReservedGlobalBits() { + Fory fory = Fory.builder().withXlang(true).withMetaShare(true).build(); + fory.register(ClassWithNoAnnotations.class); + TypeDef typeDef = TypeDef.buildTypeDef(fory.getTypeResolver(), ClassWithNoAnnotations.class); + MemoryBuffer encoded = MemoryBuffer.fromByteArray(typeDef.getEncoded()); + long header = encoded.readInt64(); + + MemoryBuffer malformed = MemoryBuffer.newHeapBuffer(typeDef.getEncoded().length); + malformed.writeInt64(header | TypeDef.RESERVED_META_FLAGS); + malformed.writeBytes( + typeDef.getEncoded(), Long.BYTES, typeDef.getEncoded().length - Long.BYTES); + Assert.assertThrows( + DeserializationException.class, + () -> TypeDef.readTypeDef(fory.getTypeResolver(), malformed)); + } + + @Test + public void testDecodeRejectsTrailingTypeDefBodyBytes() { + Fory fory = Fory.builder().withXlang(true).withMetaShare(true).build(); + fory.register(ClassWithNoAnnotations.class); + TypeDef typeDef = TypeDef.buildTypeDef(fory.getTypeResolver(), ClassWithNoAnnotations.class); + MemoryBuffer encoded = MemoryBuffer.fromByteArray(typeDef.getEncoded()); + long header = encoded.readInt64(); + long size = header & TypeDef.META_SIZE_MASKS; + 
Assert.assertTrue(size < TypeDef.META_SIZE_MASKS); + + MemoryBuffer malformed = MemoryBuffer.newHeapBuffer(typeDef.getEncoded().length + 1); + malformed.writeInt64((header & ~TypeDef.META_SIZE_MASKS) | (size + 1)); + malformed.writeBytes( + typeDef.getEncoded(), Long.BYTES, typeDef.getEncoded().length - Long.BYTES); + malformed.writeByte(0); + Assert.assertThrows( + DeserializationException.class, + () -> TypeDef.readTypeDef(fory.getTypeResolver(), malformed)); + } + + @Test + public void testDecodeRejectsParsedTypeDefWithMismatchedHash() { + Fory fory = Fory.builder().withXlang(true).withMetaShare(true).build(); + fory.register(ClassWithNoAnnotations.class); + TypeDef typeDef = TypeDef.buildTypeDef(fory.getTypeResolver(), ClassWithNoAnnotations.class); + byte[] malformed = corruptEncodedBody(typeDef, "field1"); + + Assert.assertThrows( + DeserializationException.class, + () -> TypeDef.readTypeDef(fory.getTypeResolver(), MemoryBuffer.fromByteArray(malformed))); + } + + @Test + public void testDecodeRejectsHashConsistentMalformedTypeDefBody() { + Fory fory = Fory.builder().withXlang(true).withMetaShare(true).build(); + MemoryBuffer body = MemoryBuffer.newHeapBuffer(1); + body.writeByte(0); + MemoryBuffer encoded = NativeTypeDefEncoder.prependHeader(body, false); + Assert.assertThrows( + RuntimeException.class, () -> TypeDef.readTypeDef(fory.getTypeResolver(), encoded)); + } + + @Test + public void testDecodeRejectsRegisteredTypeDefKindMismatch() { + Fory fory = Fory.builder().withXlang(true).withMetaShare(true).build(); + fory.register(EmptyStruct.class, 6001); + MemoryBuffer body = MemoryBuffer.newHeapBuffer(8); + body.writeByte(0); + body.writeVarUInt32(6001); + MemoryBuffer encoded = NativeTypeDefEncoder.prependHeader(body, false); + + Assert.assertThrows( + DeserializationException.class, + () -> TypeDef.readTypeDef(fory.getTypeResolver(), encoded)); + } + + private static byte[] corruptEncodedBody(TypeDef typeDef, String needle) { + byte[] malformed = 
typeDef.getEncoded().clone(); + byte[] needleBytes = Encoders.encodeFieldName(needle).getBytes(); + int index = indexOf(malformed, needleBytes, Long.BYTES); + Assert.assertTrue(index >= Long.BYTES); + malformed[index + needleBytes.length - 1] ^= 1; + return malformed; + } + + private static int indexOf(byte[] bytes, byte[] needle, int fromIndex) { + for (int i = fromIndex; i <= bytes.length - needle.length; i++) { + boolean match = true; + for (int j = 0; j < needle.length; j++) { + if (bytes[i + j] != needle[j]) { + match = false; + break; + } + } + if (match) { + return i; + } + } + return -1; + } + + private static int typeDefBodyOffset(byte[] encoded) { + MemoryBuffer buffer = MemoryBuffer.fromByteArray(encoded); + long header = buffer.readInt64(); + if ((header & TypeDef.META_SIZE_MASKS) == TypeDef.META_SIZE_MASKS) { + buffer.readVarUInt32(); + } + return buffer.readerIndex(); + } + /** Helper method to get a field from a class by name. */ private Field getField(Class clazz, String fieldName) { try { diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefTest.java index bee305b8d3..711a6510a7 100644 --- a/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefTest.java @@ -186,7 +186,7 @@ public void testInterface() { Fory fory = Fory.builder().withMetaShare(true).build(); TypeDef typeDef = TypeDef.buildTypeDef(fory.getTypeResolver(), Map.class); assertTrue(typeDef.getFieldsInfo().isEmpty()); - assertTrue(typeDef.hasFieldsMeta()); + assertFalse(typeDef.isStructSchemaKind()); } @Test diff --git a/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java b/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java index b99079e3b1..dc59cfde03 100644 --- a/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java +++ 
b/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java @@ -250,7 +250,7 @@ public void testSharedRegistrySharesTypeDefCachesAcrossForyInstances() { } @Test - public void testSharedRegistryCachesTypeDefByIdButKeepsTypeInfoLocal() { + public void testReadTypeDefPublishesValidatedTypeDefById() { ForyBuilder builder = Fory.builder().withXlang(false).requireClassRegistration(false).withMetaShare(true); finishBuilder(builder); @@ -265,23 +265,42 @@ public void testSharedRegistryCachesTypeDefByIdButKeepsTypeInfoLocal() { typeDef.writeTypeDef(buffer); buffer.readerIndex(0); - TypeDef canonicalTypeDef1 = resolver1.readTypeDef(buffer, buffer.readInt64()); + TypeDef readTypeDef1 = resolver1.readTypeDef(buffer, buffer.readInt64()); buffer.readerIndex(0); - TypeDef canonicalTypeDef2 = resolver2.readTypeDef(buffer, buffer.readInt64()); + TypeDef readTypeDef2 = resolver2.readTypeDef(buffer, buffer.readInt64()); - assertSame(sharedRegistry.typeDefById.get(typeDef.getId()), canonicalTypeDef1); - assertSame(canonicalTypeDef1, canonicalTypeDef2); - assertNull(resolver1.extRegistry.typeInfoByTypeDefId.get(typeDef.getId())); - assertNull(resolver2.extRegistry.typeInfoByTypeDefId.get(typeDef.getId())); + assertSame(readTypeDef1, readTypeDef2); - TypeInfo typeInfo1 = resolver1.buildMetaSharedTypeInfo(canonicalTypeDef1); - TypeInfo typeInfo2 = resolver2.buildMetaSharedTypeInfo(canonicalTypeDef2); + TypeInfo typeInfo1 = resolver1.buildMetaSharedTypeInfo(readTypeDef1); + TypeInfo typeInfo2 = resolver2.buildMetaSharedTypeInfo(readTypeDef2); - assertSame(resolver1.extRegistry.typeInfoByTypeDefId.get(typeDef.getId()), typeInfo1); - assertSame(resolver2.extRegistry.typeInfoByTypeDefId.get(typeDef.getId()), typeInfo2); assertNotSame(typeInfo1, typeInfo2); } + @Test + public void testTypeDefHeaderCacheStopsAtMaxEntries() { + ForyBuilder builder = + Fory.builder().withXlang(false).requireClassRegistration(false).withMetaShare(true); + finishBuilder(builder); + 
SharedRegistry sharedRegistry = new SharedRegistry(); + Fory fory = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); + ClassResolver resolver = (ClassResolver) fory.getTypeResolver(); + TypeDef typeDef = TypeDef.buildTypeDef(resolver, BeanB.class); + int maxCachedTypeDefs = 8192; + for (long i = 0; i < maxCachedTypeDefs; i++) { + sharedRegistry.typeDefById.put(i, typeDef); + } + + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(256); + typeDef.writeTypeDef(buffer); + buffer.readerIndex(0); + TypeDef readTypeDef = resolver.readTypeDef(buffer, buffer.readInt64()); + + assertNotNull(readTypeDef); + assertNull(sharedRegistry.typeDefById.get(typeDef.getId())); + assertEquals(sharedRegistry.typeDefById.size(), maxCachedTypeDefs); + } + @Test public void testSharedRegistryCachesFieldDescriptorsAndDescriptorGrouper() { ForyBuilder builder = Fory.builder().withXlang(false).requireClassRegistration(false); diff --git a/java/fory-core/src/test/java/org/apache/fory/resolver/MetaStringIOTest.java b/java/fory-core/src/test/java/org/apache/fory/resolver/MetaStringIOTest.java index ebdfc8a1b1..ea1bb16faa 100644 --- a/java/fory-core/src/test/java/org/apache/fory/resolver/MetaStringIOTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/resolver/MetaStringIOTest.java @@ -20,13 +20,18 @@ package org.apache.fory.resolver; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertNotSame; import static org.testng.Assert.assertSame; import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; import java.nio.ByteBuffer; +import org.apache.fory.TestUtils; +import org.apache.fory.collection.LongLongByteMap; import org.apache.fory.context.MetaStringReader; import org.apache.fory.context.MetaStringWriter; +import org.apache.fory.exception.ForyException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.MemoryUtils; import 
org.apache.fory.meta.EncodedMetaString; @@ -141,6 +146,90 @@ public void testSharedRegistryCapsEncodedMetaStringCount() { assertNotSame(overflow1, overflow2); } + @Test + public void testReadBigMetaStringRejectsNonCanonicalHash() { + SharedRegistry sharedRegistry = new SharedRegistry(); + MetaStringReader reader = new MetaStringReader(sharedRegistry); + EncodedMetaString encodedMetaString = newGenericMetaString(StringUtils.random(32, 0)); + MemoryBuffer buffer = MemoryUtils.buffer(64); + + buffer.writeVarUInt32Small7(encodedMetaString.bytes.length << 1); + buffer.writeInt64(encodedMetaString.hash + 0x100); + buffer.writeBytes(encodedMetaString.bytes); + + expectThrows(ForyException.class, () -> reader.readMetaString(buffer)); + } + + @Test + public void testCachedBigMetaStringReusesHeaderCache() { + SharedRegistry sharedRegistry = new SharedRegistry(); + MetaStringReader reader = new MetaStringReader(sharedRegistry); + EncodedMetaString encodedMetaString = newGenericMetaString(StringUtils.random(32, 0)); + MemoryBuffer buffer = MemoryUtils.buffer(128); + + buffer.writeVarUInt32Small7(encodedMetaString.bytes.length << 1); + buffer.writeInt64(encodedMetaString.hash); + buffer.writeBytes(encodedMetaString.bytes); + assertSame( + reader.readMetaString(buffer), + sharedRegistry.getOrCreateEncodedMetaString( + encodedMetaString.bytes, encodedMetaString.hash)); + assertEquals(buffer.readerIndex(), buffer.writerIndex()); + } + + @Test + public void testReadSmallMetaStringKeyIncludesLengthAndEncoding() { + SharedRegistry sharedRegistry = new SharedRegistry(); + MetaStringReader reader = new MetaStringReader(sharedRegistry); + MemoryBuffer buffer = MemoryUtils.buffer(32); + + buffer.writeVarUInt32Small7(1 << 1); + buffer.writeByte(0); + buffer.writeByte('a'); + buffer.writeVarUInt32Small7(2 << 1); + buffer.writeByte(0); + buffer.writeByte('a'); + buffer.writeByte(0); + + EncodedMetaString oneByte = reader.readMetaString(buffer); + EncodedMetaString twoBytes = 
reader.readMetaString(buffer); + + assertEquals(oneByte.bytes.length, 1); + assertEquals(twoBytes.bytes.length, 2); + assertNotEquals(oneByte.hash, twoBytes.hash); + } + + @Test + public void testMetaStringReaderResetClearsDynamicIdsOnly() { + SharedRegistry sharedRegistry = new SharedRegistry(); + MetaStringReader reader = new MetaStringReader(sharedRegistry); + MemoryBuffer buffer = MemoryUtils.buffer(32); + + buffer.writeVarUInt32Small7(1 << 1); + buffer.writeByte(0); + buffer.writeByte('a'); + reader.readMetaString(buffer); + + LongLongByteMap readCache = TestUtils.getFieldValue(reader, "longLongMetaStringMap"); + assertEquals(readCache.size, 1); + reader.reset(); + assertEquals(readCache.size, 1); + + MemoryBuffer refBuffer = MemoryUtils.buffer(8); + refBuffer.writeVarUInt32Small7((1 << 1) | 1); + expectThrows(ForyException.class, () -> reader.readMetaString(refBuffer)); + } + + @Test + public void testTypeNameBytesUsesBytesWhenHashesMatch() { + EncodedMetaString namespace1 = new EncodedMetaString(new byte[] {'a'}, 0x100); + EncodedMetaString namespace2 = new EncodedMetaString(new byte[] {'b'}, 0x100); + EncodedMetaString typeName = new EncodedMetaString(new byte[] {'C'}, 0x200); + + assertNotEquals( + new TypeNameBytes(namespace1, typeName), new TypeNameBytes(namespace2, typeName)); + } + private static EncodedMetaString newGenericMetaString(String str) { return Encoders.GENERIC_ENCODER.encodeBinary(str, Encoders.computeGenericEncoding(str)); } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/BufferSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/BufferSerializersTest.java index 8923f137d6..740f0bd6f1 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/BufferSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/BufferSerializersTest.java @@ -22,7 +22,9 @@ import java.nio.ByteBuffer; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; +import 
org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.ByteBufferUtil; +import org.apache.fory.memory.MemoryBuffer; import org.testng.annotations.Test; public class BufferSerializersTest extends ForyTestBase { @@ -51,4 +53,24 @@ public void testByteBuffer(Fory fory) { ByteBufferUtil.rewind(buffer2); copyCheck(fory, buffer2); } + + @Test + public void testByteBufferRejectsMalformedPayload() { + Fory fory = Fory.builder().build(); + Serializer serializer = + new BufferSerializers.ByteBufferSerializer(fory.getTypeResolver(), ByteBuffer.class); + + MemoryBuffer zeroSize = MemoryBuffer.newHeapBuffer(16); + zeroSize.writeBoolean(true); + zeroSize.writeVarUInt32Aligned(0); + org.testng.Assert.assertThrows( + DeserializationException.class, () -> readSerializer(fory, serializer, zeroSize)); + + MemoryBuffer invalidOrder = MemoryBuffer.newHeapBuffer(16); + invalidOrder.writeBoolean(true); + invalidOrder.writeVarUInt32Aligned(1); + invalidOrder.writeByte(2); + org.testng.Assert.assertThrows( + DeserializationException.class, () -> readSerializer(fory, serializer, invalidOrder)); + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ObjectStreamSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ObjectStreamSerializerTest.java index 7fed1fb5ff..d623d609c7 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ObjectStreamSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ObjectStreamSerializerTest.java @@ -29,7 +29,6 @@ import java.io.ObjectOutputStream; import java.io.ObjectStreamField; import java.io.Serializable; -import java.lang.reflect.Field; import java.lang.reflect.Method; import java.math.BigInteger; import java.net.Inet4Address; @@ -46,13 +45,11 @@ import lombok.EqualsAndHashCode; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; -import org.apache.fory.collection.LongMap; import org.apache.fory.config.ForyBuilder; import 
org.apache.fory.context.MetaReadContext; import org.apache.fory.context.MetaWriteContext; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.resolver.SharedRegistry; -import org.apache.fory.resolver.TypeInfo; import org.apache.fory.serializer.collection.CollectionSerializers; import org.apache.fory.serializer.collection.MapSerializers; import org.apache.fory.util.Preconditions; @@ -1071,8 +1068,7 @@ public void testCrossForyInstanceSerialization(boolean compatible) { } @Test(dataProvider = "compatibleModeProvider") - public void testObjectStreamSharedRegistryCanonicalizesTypeDef(boolean compatible) - throws Exception { + public void testObjectStreamReadersReuseValidatedTypeDefCache(boolean compatible) { ForyBuilder builder = Fory.builder() .withXlang(false) @@ -1102,20 +1098,14 @@ public void testObjectStreamSharedRegistryCanonicalizesTypeDef(boolean compatibl byte[] bytes = writerFory.serialize(new MixedSerializationClass("shared", 7)); readerFory1.setMetaReadContext(new MetaReadContext()); - readerFory1.deserialize(bytes); + MixedSerializationClass result1 = (MixedSerializationClass) readerFory1.deserialize(bytes); readerFory2.setMetaReadContext(new MetaReadContext()); - readerFory2.deserialize(bytes); - - TypeInfo typeInfo1 = - getFirstTypeInfo( - getTypeDefIdToTypeInfo( - (ObjectStreamSerializer) readerFory1.getSerializer(MixedSerializationClass.class))); - TypeInfo typeInfo2 = - getFirstTypeInfo( - getTypeDefIdToTypeInfo( - (ObjectStreamSerializer) readerFory2.getSerializer(MixedSerializationClass.class))); - - assertSame(typeInfo1.getTypeDef(), typeInfo2.getTypeDef()); + MixedSerializationClass result2 = (MixedSerializationClass) readerFory2.deserialize(bytes); + + assertEquals(result1.name, "shared"); + assertEquals(result1.value, 7); + assertEquals(result2.name, "shared"); + assertEquals(result2.value, 7); } // ==================== Default Value Tests ==================== @@ -1158,35 +1148,6 @@ private void readObject(ObjectInputStream s) 
throws IOException, ClassNotFoundEx } } - private static LongMap getTypeDefIdToTypeInfo(ObjectStreamSerializer serializer) - throws ReflectiveOperationException { - Field field = ObjectStreamSerializer.class.getDeclaredField("typeDefIdToTypeInfo"); - field.setAccessible(true); - @SuppressWarnings("unchecked") - LongMap typeDefIdToTypeInfo = (LongMap) field.get(serializer); - return typeDefIdToTypeInfo; - } - - private static TypeInfo getFirstTypeInfo(LongMap typeDefIdToTypeInfo) - throws ReflectiveOperationException { - Field zeroValueField = LongMap.class.getDeclaredField("zeroValue"); - zeroValueField.setAccessible(true); - TypeInfo zeroValue = (TypeInfo) zeroValueField.get(typeDefIdToTypeInfo); - if (zeroValue != null) { - return zeroValue; - } - Field valueTableField = LongMap.class.getDeclaredField("valueTable"); - valueTableField.setAccessible(true); - Object[] valueTable = (Object[]) valueTableField.get(typeDefIdToTypeInfo); - for (Object value : valueTable) { - if (value != null) { - return (TypeInfo) value; - } - } - Assert.fail("Expected at least one cached TypeInfo in ObjectStreamSerializer"); - return null; - } - private static void finishBuilder(ForyBuilder builder) { try { Method finish = ForyBuilder.class.getDeclaredMethod("finish"); diff --git a/java/fory-extensions/src/test/java/org/apache/fory/extension/meta/TypeDefEncoderTest.java b/java/fory-extensions/src/test/java/org/apache/fory/extension/meta/TypeDefEncoderTest.java index c75566b063..22199c4548 100644 --- a/java/fory-extensions/src/test/java/org/apache/fory/extension/meta/TypeDefEncoderTest.java +++ b/java/fory-extensions/src/test/java/org/apache/fory/extension/meta/TypeDefEncoderTest.java @@ -50,7 +50,7 @@ public void testBasicTypeDefZstdMetaCompressor() throws Exception { List fieldsInfo = buildFieldsInfo(classResolver, type); MemoryBuffer buffer = NativeTypeDefEncoder.encodeTypeDef( - classResolver, type, getClassFields(type, fieldsInfo), true); + classResolver, type, getClassFields(type, 
fieldsInfo)); TypeDef typeDef = TypeDef.readTypeDef(classResolver, buffer); Assert.assertEquals(typeDef.getClassName(), type.getName()); Assert.assertEquals(typeDef.getFieldsInfo().size(), type.getDeclaredFields().length); diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index 86284db62c..ce828d787a 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -19,7 +19,11 @@ import { BinaryReader } from "./reader"; import { BinaryWriter } from "./writer"; -import { MetaString, MetaStringDecoder, MetaStringEncoder } from "./meta/MetaString"; +import { + MetaString, + MetaStringDecoder, + MetaStringEncoder, +} from "./meta/MetaString"; import { InnerFieldInfo, TypeMeta } from "./meta/TypeMeta"; import { Type, TypeInfo } from "./typeInfo"; import { Config, RefFlags, Serializer, TypeId } from "./type"; @@ -38,8 +42,7 @@ type TypeResolverLike = { class MetaStringBytes { dynamicWriteStringId = -1; - constructor(public bytes: MetaString) { - } + constructor(public bytes: MetaString) {} } export class RefWriter { @@ -61,8 +64,7 @@ export class RefWriter { export class RefReader { private readObjects: any[] = []; - constructor(private reader: BinaryReader) { - } + constructor(private reader: BinaryReader) {} reset() { this.readObjects = []; @@ -196,8 +198,8 @@ export class WriteContext { checkCollectionSize(size: number) { if (size > this._maxCollectionSize) { throw new Error( - `Collection size ${size} exceeds maxCollectionSize ${this._maxCollectionSize}. ` - + "The data may be malicious, or increase maxCollectionSize if needed." + `Collection size ${size} exceeds maxCollectionSize ${this._maxCollectionSize}. ` + + "The data may be malicious, or increase maxCollectionSize if needed.", ); } } @@ -205,8 +207,8 @@ export class WriteContext { checkBinarySize(size: number) { if (size > this._maxBinarySize) { throw new Error( - `Binary size ${size} exceeds maxBinarySize ${this._maxBinarySize}. 
` - + "The data may be malicious, or increase maxBinarySize if needed." + `Binary size ${size} exceeds maxBinarySize ${this._maxBinarySize}. ` + + "The data may be malicious, or increase maxBinarySize if needed.", ); } } @@ -382,13 +384,18 @@ export class WriteContext { } export class ReadContext { + private static readonly MAX_CACHED_TYPE_META = 8192; + readonly reader: BinaryReader; readonly refReader: RefReader; readonly metaStringReader: MetaStringReader; private typeMeta: TypeMeta[] = []; /** Persistent cross-message cache keyed by 8-byte type meta header. */ - private typeMetaCache: Map = new Map(); + private typeMetaCache: Map< + bigint, + { readonly typeMeta: TypeMeta; readonly skipBytesAfterHeader: number } + > = new Map(); private _depth = 0; private _maxDepth: number; private _maxBinarySize: number; @@ -422,8 +429,8 @@ export class ReadContext { this._depth++; if (this._depth > this._maxDepth) { throw new Error( - `Deserialization depth limit exceeded: ${this._depth} > ${this._maxDepth}. ` - + "The data may be malicious, or increase maxDepth if needed." + `Deserialization depth limit exceeded: ${this._depth} > ${this._maxDepth}. ` + + "The data may be malicious, or increase maxDepth if needed.", ); } } @@ -435,8 +442,8 @@ export class ReadContext { checkCollectionSize(size: number) { if (size > this._maxCollectionSize) { throw new Error( - `Collection size ${size} exceeds maxCollectionSize ${this._maxCollectionSize}. ` - + "The data may be malicious, or increase maxCollectionSize if needed." + `Collection size ${size} exceeds maxCollectionSize ${this._maxCollectionSize}. ` + + "The data may be malicious, or increase maxCollectionSize if needed.", ); } } @@ -444,8 +451,8 @@ export class ReadContext { checkBinarySize(size: number) { if (size > this._maxBinarySize) { throw new Error( - `Binary size ${size} exceeds maxBinarySize ${this._maxBinarySize}. ` - + "The data may be malicious, or increase maxBinarySize if needed." 
+ `Binary size ${size} exceeds maxBinarySize ${this._maxBinarySize}. ` + + "The data may be malicious, or increase maxBinarySize if needed.", ); } } @@ -477,27 +484,54 @@ export class ReadContext { const cached = this.typeMetaCache.get(header); let typeMeta: TypeMeta; if (cached) { - TypeMeta.skipBody(this.reader, header); - typeMeta = cached; + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeMeta parse and 52-bit body-hash validation. + this.reader.readSkip(cached.skipBytesAfterHeader); + typeMeta = cached.typeMeta; } else { + const bodyStart = this.reader.readGetCursor(); typeMeta = TypeMeta.fromBytesAfterHeader(this.reader, header); - this.typeMetaCache.set(header, typeMeta); + if (this.typeMetaCache.size < ReadContext.MAX_CACHED_TYPE_META) { + this.typeMetaCache.set(header, { + typeMeta, + skipBytesAfterHeader: this.reader.readGetCursor() - bodyStart, + }); + } } this.typeMeta[dynamicTypeId] = typeMeta; return typeMeta; } - private fieldInfoToTypeInfo(fieldInfo: InnerFieldInfo, fallbackTypeInfo?: TypeInfo): TypeInfo { + private fieldInfoToTypeInfo( + fieldInfo: InnerFieldInfo, + fallbackTypeInfo?: TypeInfo, + ): TypeInfo { switch (fieldInfo.typeId) { case TypeId.MAP: return Type.map( - this.fieldInfoToTypeInfo(fieldInfo.options!.key!, fallbackTypeInfo?.options?.key), - this.fieldInfoToTypeInfo(fieldInfo.options!.value!, fallbackTypeInfo?.options?.value) + this.fieldInfoToTypeInfo( + fieldInfo.options!.key!, + fallbackTypeInfo?.options?.key, + ), + this.fieldInfoToTypeInfo( + fieldInfo.options!.value!, + fallbackTypeInfo?.options?.value, + ), ); case TypeId.LIST: - return Type.list(this.fieldInfoToTypeInfo(fieldInfo.options!.inner!, fallbackTypeInfo?.options?.inner)); + return Type.list( + this.fieldInfoToTypeInfo( + fieldInfo.options!.inner!, + fallbackTypeInfo?.options?.inner, + ), + ); case TypeId.SET: - return Type.set(this.fieldInfoToTypeInfo(fieldInfo.options!.key!, 
fallbackTypeInfo?.options?.key)); + return Type.set( + this.fieldInfoToTypeInfo( + fieldInfo.options!.key!, + fallbackTypeInfo?.options?.key, + ), + ); default: { // Remote TypeMeta only carries the nested user-defined type kind, not the // concrete named type or custom serializer identity. Reuse the local field @@ -508,13 +542,19 @@ export class ReadContext { if (fallbackTypeInfo) { return fallbackTypeInfo.clone(); } - const serializer = this.typeResolver.getSerializerById(fieldInfo.typeId, fieldInfo.userTypeId); + const serializer = this.typeResolver.getSerializerById( + fieldInfo.typeId, + fieldInfo.userTypeId, + ); if (serializer) { return serializer.getTypeInfo().clone(); } return Type.any(); } - const serializer = this.typeResolver.getSerializerById(fieldInfo.typeId, fieldInfo.userTypeId); + const serializer = this.typeResolver.getSerializerById( + fieldInfo.typeId, + fieldInfo.userTypeId, + ); if (serializer) { return serializer.getTypeInfo().clone(); } @@ -536,7 +576,10 @@ export class ReadContext { const named = `${typeMeta.getNs()}$${typeMeta.getTypeName()}`; original = this.typeResolver.getSerializerByName(named); } else { - original = this.typeResolver.getSerializerById(typeId, typeMeta.getUserTypeId()); + original = this.typeResolver.getSerializerById( + typeId, + typeMeta.getUserTypeId(), + ); } } let typeInfo: TypeInfo; @@ -545,17 +588,25 @@ export class ReadContext { } else if (!TypeId.isNamedType(typeId)) { typeInfo = Type.struct(typeMeta.getUserTypeId()); } else { - typeInfo = Type.struct({ typeName: typeMeta.getTypeName(), namespace: typeMeta.getNs() }); + typeInfo = Type.struct({ + typeName: typeMeta.getTypeName(), + namespace: typeMeta.getNs(), + }); } const localProps = original?.getTypeInfo().options?.props; - const props = Object.fromEntries(typeMeta.remapFieldNames(localProps).map((fieldInfo) => { - const localFieldTypeInfo = localProps?.[fieldInfo.getFieldName()]; - const fieldTypeInfo = this.fieldInfoToTypeInfo(fieldInfo, 
localFieldTypeInfo) - .setNullable(fieldInfo.nullable) - .setTrackingRef(fieldInfo.trackingRef) - .setId(fieldInfo.fieldId); - return [fieldInfo.getFieldName(), fieldTypeInfo]; - })); + const props = Object.fromEntries( + typeMeta.remapFieldNames(localProps).map((fieldInfo) => { + const localFieldTypeInfo = localProps?.[fieldInfo.getFieldName()]; + const fieldTypeInfo = this.fieldInfoToTypeInfo( + fieldInfo, + localFieldTypeInfo, + ) + .setNullable(fieldInfo.nullable) + .setTrackingRef(fieldInfo.trackingRef) + .setId(fieldInfo.fieldId); + return [fieldInfo.getFieldName(), fieldTypeInfo]; + }), + ); typeInfo.options = { ...typeInfo.options, props, diff --git a/javascript/packages/core/lib/fory.ts b/javascript/packages/core/lib/fory.ts index 8066698c7e..d96222b463 100644 --- a/javascript/packages/core/lib/fory.ts +++ b/javascript/packages/core/lib/fory.ts @@ -18,7 +18,15 @@ */ import TypeResolver from "./typeResolver"; -import { ConfigFlags, Serializer, Config, ForyTypeInfoSymbol, WithForyClsInfo, TypeId, CustomSerializer } from "./type"; +import { + ConfigFlags, + Serializer, + Config, + ForyTypeInfoSymbol, + WithForyClsInfo, + TypeId, + CustomSerializer, +} from "./type"; import { InputType, ResultType, TypeInfo } from "./typeInfo"; import { Gen } from "./gen"; import { PlatformBuffer } from "./platformBuffer"; @@ -40,15 +48,22 @@ export default class Fory { this.config = this.initConfig(config); const maxDepth = this.config.maxDepth ?? DEFAULT_DEPTH_LIMIT; if (!Number.isInteger(maxDepth) || maxDepth < MIN_DEPTH_LIMIT) { - throw new Error(`maxDepth must be an integer >= ${MIN_DEPTH_LIMIT} but got ${maxDepth}`); + throw new Error( + `maxDepth must be an integer >= ${MIN_DEPTH_LIMIT} but got ${maxDepth}`, + ); } const maxBinarySize = this.config.maxBinarySize ?? 
DEFAULT_MAX_BINARY_SIZE; if (!Number.isInteger(maxBinarySize) || maxBinarySize < 0) { - throw new Error(`maxBinarySize must be a non-negative integer but got ${maxBinarySize}`); + throw new Error( + `maxBinarySize must be a non-negative integer but got ${maxBinarySize}`, + ); } - const maxCollectionSize = this.config.maxCollectionSize ?? DEFAULT_MAX_COLLECTION_SIZE; + const maxCollectionSize = + this.config.maxCollectionSize ?? DEFAULT_MAX_COLLECTION_SIZE; if (!Number.isInteger(maxCollectionSize) || maxCollectionSize < 0) { - throw new Error(`maxCollectionSize must be a non-negative integer but got ${maxCollectionSize}`); + throw new Error( + `maxCollectionSize must be a non-negative integer but got ${maxCollectionSize}`, + ); } this.typeResolver = new TypeResolver(this.config); @@ -71,17 +86,24 @@ export default class Fory { }; } - register(constructor: new () => T, customSerializer: CustomSerializer): { + register( + constructor: new () => T, + customSerializer: CustomSerializer, + ): { serializer: Serializer; serialize(data: InputType | null): PlatformBuffer; deserialize(bytes: Uint8Array): ResultType; }; - register(typeInfo: T): { + register( + typeInfo: T, + ): { serializer: Serializer; serialize(data: InputType | null): PlatformBuffer; deserialize(bytes: Uint8Array): ResultType; }; - register any>(constructor: T): { + register any>( + constructor: T, + ): { serializer: Serializer; serialize(data: Partial> | null): PlatformBuffer; deserialize(bytes: Uint8Array): InstanceType | null; @@ -89,14 +111,21 @@ export default class Fory { register(constructor: any, customSerializer?: CustomSerializer) { let serializer: Serializer; if (constructor.prototype?.[ForyTypeInfoSymbol]) { - const typeInfo: TypeInfo = (constructor.prototype[ForyTypeInfoSymbol] as WithForyClsInfo).structTypeInfo; + const typeInfo: TypeInfo = ( + constructor.prototype[ForyTypeInfoSymbol] as WithForyClsInfo + ).structTypeInfo; typeInfo.freeze(); - serializer = new Gen(this.typeResolver, { 
creator: constructor, customSerializer }).generateSerializer(typeInfo); + serializer = new Gen(this.typeResolver, { + creator: constructor, + customSerializer, + }).generateSerializer(typeInfo); this.typeResolver.registerSerializer(typeInfo, serializer); } else { const typeInfo = constructor; typeInfo.freeze(); - serializer = new Gen(this.typeResolver, { customSerializer }).generateSerializer(typeInfo); + serializer = new Gen(this.typeResolver, { + customSerializer, + }).generateSerializer(typeInfo); this.typeResolver.registerSerializer(typeInfo, serializer); } return { @@ -111,32 +140,37 @@ export default class Fory { }; } - deserialize(bytes: Uint8Array, serializer: Serializer = this.anySerializer): T | null { + deserialize( + bytes: Uint8Array, + serializer: Serializer = this.anySerializer, + ): T | null { this.readContext.reset(bytes); const reader = this.readContext.reader; const bitmap = reader.readUint8(); - if ((bitmap & ConfigFlags.isNullFlag) === ConfigFlags.isNullFlag) { - return null; + if (bitmap !== ConfigFlags.isCrossLanguageFlag) { + this.throwInvalidRootHeader(bitmap); } - const isCrossLanguage = (bitmap & ConfigFlags.isCrossLanguageFlag) === ConfigFlags.isCrossLanguageFlag; - if (!isCrossLanguage) { - throw new Error("support crosslanguage mode only"); + return serializer.readRef(); + } + + private throwInvalidRootHeader(bitmap: number): never { + const knownFlags = + ConfigFlags.isCrossLanguageFlag | ConfigFlags.isOutOfBandFlag; + if ((bitmap & ~knownFlags) !== 0) { + throw new Error( + `unsupported root header bitmap 0x${bitmap.toString(16)}`, + ); } - const isOutOfBandEnabled = (bitmap & ConfigFlags.isOutOfBandFlag) === ConfigFlags.isOutOfBandFlag; - if (isOutOfBandEnabled) { - throw new Error("outofband mode is not supported now"); + if ((bitmap & ConfigFlags.isCrossLanguageFlag) === 0) { + throw new Error("support crosslanguage mode only"); } - return serializer.readRef(); + throw new Error("outofband mode is not supported now"); } private 
serializeInternal(data: T, serializer: Serializer) { this.writeContext.reset(); const writer = this.writeContext.writer; - let bitmap = 0; - if (data === null) { - bitmap |= ConfigFlags.isNullFlag; - } - bitmap |= ConfigFlags.isCrossLanguageFlag; + const bitmap = ConfigFlags.isCrossLanguageFlag; writer.writeUint8(bitmap); writer.reserve(serializer.fixedSize); serializer.writeRef(data); diff --git a/javascript/packages/core/lib/meta/TypeMeta.ts b/javascript/packages/core/lib/meta/TypeMeta.ts index 6069a5333c..231a13671e 100644 --- a/javascript/packages/core/lib/meta/TypeMeta.ts +++ b/javascript/packages/core/lib/meta/TypeMeta.ts @@ -36,27 +36,36 @@ const pkgDecoder = new MetaStringDecoder(".", "_"); const typeNameEncoder = new MetaStringEncoder("$", "."); const typeNameDecoder = new MetaStringDecoder("$", "."); -// Constants shared with python/java/rust/go 0.17+. See e.g. -// python/pyfory/meta/typedef.py, java/.../TypeDef.java, -// rust/fory-core/src/meta/type_meta.rs, go/fory/type_def.go. The -// JavaScript binding previously placed COMPRESS_META_FLAG at bit 63 -// and HAS_FIELDS_META_FLAG at bit 62, and used NUM_HASH_BITS = 41, -// producing an 8-byte TypeMeta preamble that no other xlang binding -// could decode. Aligning with the constants every other binding uses -// so NAMED_COMPATIBLE_STRUCT output is byte-compatible cross-binding. 
-const COMPRESS_META_FLAG = 1n << 9n; -const HAS_FIELDS_META_FLAG = 1n << 8n; -const META_SIZE_MASKS = 0xFF; // low 8 bits -const NUM_HASH_BITS = 50; +const COMPRESS_META_FLAG = 1n << 8n; +const RESERVED_META_FLAGS = 0b111n << 9n; +const META_SIZE_MASKS = 0xff; // low 8 bits +const NUM_HASH_BITS = 52; const HASH_SHIFT_BITS = 64n - BigInt(NUM_HASH_BITS); +const UINT64_MASK = 0xffffffffffffffffn; +const HEADER_HASH_MASK = UINT64_MASK ^ ((1n << HASH_SHIFT_BITS) - 1n); const BIG_NAME_THRESHOLD = 0b111111; const PRIMITIVE_TYPE_IDS = [ - TypeId.BOOL, TypeId.INT8, TypeId.INT16, TypeId.INT32, TypeId.VARINT32, - TypeId.INT64, TypeId.VARINT64, TypeId.TAGGED_INT64, TypeId.UINT8, - TypeId.UINT16, TypeId.UINT32, TypeId.VAR_UINT32, TypeId.UINT64, - TypeId.VAR_UINT64, TypeId.TAGGED_UINT64, TypeId.FLOAT8, TypeId.FLOAT16, - TypeId.BFLOAT16, TypeId.FLOAT32, TypeId.FLOAT64, + TypeId.BOOL, + TypeId.INT8, + TypeId.INT16, + TypeId.INT32, + TypeId.VARINT32, + TypeId.INT64, + TypeId.VARINT64, + TypeId.TAGGED_INT64, + TypeId.UINT8, + TypeId.UINT16, + TypeId.UINT32, + TypeId.VAR_UINT32, + TypeId.UINT64, + TypeId.VAR_UINT64, + TypeId.TAGGED_UINT64, + TypeId.FLOAT8, + TypeId.FLOAT16, + TypeId.BFLOAT16, + TypeId.FLOAT32, + TypeId.FLOAT64, ]; export const isPrimitiveTypeId = (typeId: number): boolean => { @@ -64,7 +73,12 @@ export const isPrimitiveTypeId = (typeId: number): boolean => { }; export const refTrackingUnableTypeId = (typeId: number): boolean => { - return PRIMITIVE_TYPE_IDS.includes(typeId as any) || [TypeId.DURATION, TypeId.DATE, TypeId.TIMESTAMP, TypeId.STRING].includes(typeId as any); + return ( + PRIMITIVE_TYPE_IDS.includes(typeId as any) || + [TypeId.DURATION, TypeId.DATE, TypeId.TIMESTAMP, TypeId.STRING].includes( + typeId as any, + ) + ); }; function getPrimitiveTypeSize(typeId: number) { @@ -114,7 +128,11 @@ function getPrimitiveTypeSize(typeId: number) { } } -export type InnerFieldInfoOptions = { key?: InnerFieldInfo; value?: InnerFieldInfo; inner?: InnerFieldInfo }; 
+export type InnerFieldInfoOptions = { + key?: InnerFieldInfo; + value?: InnerFieldInfo; + inner?: InnerFieldInfo; +}; export interface InnerFieldInfo { typeId: number; userTypeId: number; @@ -131,9 +149,8 @@ export class FieldInfo { public trackingRef = false, public nullable = false, public options: InnerFieldInfoOptions = {}, - public fieldId?: number - ) { - } + public fieldId?: number, + ) {} getFieldName() { return this.fieldName; @@ -155,7 +172,11 @@ export class FieldInfo { return this.fieldId; } - static writeTypeId(writer: BinaryWriter, typeInfo: InnerFieldInfo, writeFlags = false) { + static writeTypeId( + writer: BinaryWriter, + typeInfo: InnerFieldInfo, + writeFlags = false, + ) { let { typeId } = typeInfo; if (typeId === TypeId.NAMED_ENUM) { typeId = TypeId.ENUM; @@ -164,7 +185,7 @@ export class FieldInfo { } const { trackingRef, nullable } = typeInfo; if (writeFlags) { - typeId = (typeId << 2); + typeId = typeId << 2; if (nullable) { typeId |= 0b10; } @@ -205,24 +226,80 @@ export class FieldInfo { const SMALL_NUM_FIELDS_THRESHOLD = 0b11111; const REGISTER_BY_NAME_FLAG = 0b100000; +const COMPATIBLE_TYPEDEF_FLAG = 0b01000000; +const STRUCT_TYPEDEF_FLAG = 0b10000000; const FIELD_NAME_SIZE_THRESHOLD = 0b1111; -const pkgNameEncoding = [Encoding.UTF_8, Encoding.ALL_TO_LOWER_SPECIAL, Encoding.LOWER_UPPER_DIGIT_SPECIAL]; -const fieldNameEncoding = [Encoding.UTF_8, Encoding.ALL_TO_LOWER_SPECIAL, Encoding.LOWER_UPPER_DIGIT_SPECIAL]; -const typeNameEncoding = [Encoding.UTF_8, Encoding.ALL_TO_LOWER_SPECIAL, Encoding.LOWER_UPPER_DIGIT_SPECIAL, Encoding.FIRST_TO_LOWER_SPECIAL]; +const pkgNameEncoding = [ + Encoding.UTF_8, + Encoding.ALL_TO_LOWER_SPECIAL, + Encoding.LOWER_UPPER_DIGIT_SPECIAL, +]; +const fieldNameEncoding = [ + Encoding.UTF_8, + Encoding.ALL_TO_LOWER_SPECIAL, + Encoding.LOWER_UPPER_DIGIT_SPECIAL, +]; +const typeNameEncoding = [ + Encoding.UTF_8, + Encoding.ALL_TO_LOWER_SPECIAL, + Encoding.LOWER_UPPER_DIGIT_SPECIAL, + Encoding.FIRST_TO_LOWER_SPECIAL, 
+]; + +function nonStructKindCode(typeId: number): number { + switch (typeId) { + case TypeId.ENUM: + return 0; + case TypeId.NAMED_ENUM: + return 1; + case TypeId.EXT: + return 2; + case TypeId.NAMED_EXT: + return 3; + case TypeId.TYPED_UNION: + return 4; + case TypeId.NAMED_UNION: + return 5; + default: + throw new Error(`unsupported TypeMeta kind ${typeId}`); + } +} + +function nonStructTypeId(kindCode: number): number { + switch (kindCode) { + case 0: + return TypeId.ENUM; + case 1: + return TypeId.NAMED_ENUM; + case 2: + return TypeId.EXT; + case 3: + return TypeId.NAMED_EXT; + case 4: + return TypeId.TYPED_UNION; + case 5: + return TypeId.NAMED_UNION; + default: + throw new Error(`unsupported TypeMeta kind code ${kindCode}`); + } +} export class TypeMeta { private headerHash: number | null; - private readonly hasFieldsMeta: boolean; private readonly compressed: boolean; - private constructor(private fields: FieldInfo[], private type: { - typeId: number; - typeName: string; - namespace: string; - userTypeId: number; - }, headerHash?: number, hasFieldsMeta?: boolean, compressed = false) { + private constructor( + private fields: FieldInfo[], + private type: { + typeId: number; + typeName: string; + namespace: string; + userTypeId: number; + }, + headerHash?: number, + compressed = false, + ) { this.headerHash = headerHash ?? null; - this.hasFieldsMeta = hasFieldsMeta ?? 
fields.length > 0; this.compressed = compressed; } @@ -242,9 +319,15 @@ export class TypeMeta { } else { fieldIdentifier = TypeMeta.toSnakeCase(field.getFieldName()); } - fieldInfos.push([field, fieldIdentifier, this.computeFieldTypeFingerprint(field, true, true)]); + fieldInfos.push([ + field, + fieldIdentifier, + this.computeFieldTypeFingerprint(field, true, true), + ]); } - fieldInfos = fieldInfos.sort((a, b) => TypeMeta.compareFieldSortKey(a[0], b[0])); + fieldInfos = fieldInfos.sort((a, b) => + TypeMeta.compareFieldSortKey(a[0], b[0]), + ); let result = ""; for (const fieldInfo of fieldInfos) { result += `${fieldInfo[1]},${fieldInfo[2]};`; @@ -252,7 +335,11 @@ export class TypeMeta { return result; } - private computeFieldTypeFingerprint(field: InnerFieldInfo, includeRef: boolean, includeNullable: boolean) { + private computeFieldTypeFingerprint( + field: InnerFieldInfo, + includeRef: boolean, + includeNullable: boolean, + ) { const ref = includeRef && field.trackingRef ? "1" : "0"; const nullable = includeNullable && field.nullable ? "1" : "0"; let result = `${this.fingerprintTypeId(field.typeId)},${ref},${nullable}`; @@ -267,7 +354,12 @@ export class TypeMeta { } private fingerprintTypeId(typeId: number) { - if (TypeId.userDefinedType(typeId) || typeId === TypeId.UNION || typeId === TypeId.TYPED_UNION || typeId === TypeId.NAMED_UNION) { + if ( + TypeId.userDefinedType(typeId) || + typeId === TypeId.UNION || + typeId === TypeId.TYPED_UNION || + typeId === TypeId.NAMED_UNION + ) { return TypeId.UNKNOWN; } return typeId; @@ -286,29 +378,38 @@ export class TypeMeta { let fieldInfo: FieldInfo[] = []; if (TypeId.structType(typeInfo.typeId)) { const structTypeInfo = typeInfo; - fieldInfo = Object.entries(structTypeInfo.options!.props!).map(([fieldName, typeInfo]) => { - let fieldTypeId = typeResolver ? 
typeResolver.computeTypeId(typeInfo) : typeInfo.typeId; - if (fieldTypeId === TypeId.NAMED_ENUM) { - fieldTypeId = TypeId.ENUM; - } else if (fieldTypeId === TypeId.NAMED_UNION || fieldTypeId === TypeId.TYPED_UNION) { - fieldTypeId = TypeId.UNION; - } - const { trackingRef, nullable, id, userTypeId, options } = typeInfo; - return new FieldInfo( - fieldName, - fieldTypeId, - userTypeId, - trackingRef, - nullable, - options!, - id - ); - }); + fieldInfo = Object.entries(structTypeInfo.options!.props!).map( + ([fieldName, typeInfo]) => { + let fieldTypeId = typeResolver + ? typeResolver.computeTypeId(typeInfo) + : typeInfo.typeId; + if (fieldTypeId === TypeId.NAMED_ENUM) { + fieldTypeId = TypeId.ENUM; + } else if ( + fieldTypeId === TypeId.NAMED_UNION || + fieldTypeId === TypeId.TYPED_UNION + ) { + fieldTypeId = TypeId.UNION; + } + const { trackingRef, nullable, id, userTypeId, options } = typeInfo; + return new FieldInfo( + fieldName, + fieldTypeId, + userTypeId, + trackingRef, + nullable, + options!, + id, + ); + }, + ); } fieldInfo = TypeMeta.groupFieldsByType(fieldInfo); return new TypeMeta(fieldInfo, { - typeId: typeResolver ? typeResolver.computeTypeId(typeInfo) : typeInfo.typeId, + typeId: typeResolver + ? typeResolver.computeTypeId(typeInfo) + : typeInfo.typeId, namespace: typeInfo.namespace, typeName: typeInfo.typeName, userTypeId: typeInfo.userTypeId ?? -1, @@ -327,7 +428,8 @@ export class TypeMeta { * Skip the type meta body bytes after the header has already been read. */ static skipBody(reader: BinaryReader, header: bigint) { - reader.readSkip(TypeMeta.readMetaSize(reader, header)); + const metaSize = TypeMeta.readMetaSize(reader, header); + reader.readSkip(metaSize); } static fromBytes(reader: BinaryReader): TypeMeta { @@ -339,35 +441,48 @@ export class TypeMeta { * by readHeader(). Used by ReadContext to avoid re-reading the header. 
*/ static fromBytesAfterHeader(reader: BinaryReader, header: bigint): TypeMeta { - const compressed = (header & COMPRESS_META_FLAG) !== 0n; - if (compressed) { - throw new Error("compressed TypeMeta is not supported yet"); - } - const hasFieldsMeta = (header & HAS_FIELDS_META_FLAG) !== 0n; + TypeMeta.validateGlobalHeader(header); const metaSize = TypeMeta.readMetaSize(reader, header); + const compressed = false; const headerHash = Number(header >> HASH_SHIFT_BITS); const bodyStart = reader.readGetCursor(); - // Read class header const classHeader = reader.readUint8(); - let numFields = classHeader & SMALL_NUM_FIELDS_THRESHOLD; - - if (numFields === SMALL_NUM_FIELDS_THRESHOLD) { - numFields += reader.readVarUInt32(); - } + const isStruct = (classHeader & STRUCT_TYPEDEF_FLAG) !== 0; + let numFields = 0; let typeId: number; let userTypeId = -1; let namespace = ""; let typeName = ""; + let registerByName: boolean; + + if (isStruct) { + registerByName = (classHeader & REGISTER_BY_NAME_FLAG) !== 0; + const compatible = (classHeader & COMPATIBLE_TYPEDEF_FLAG) !== 0; + if (registerByName) { + typeId = compatible + ? TypeId.NAMED_COMPATIBLE_STRUCT + : TypeId.NAMED_STRUCT; + } else { + typeId = compatible ? 
TypeId.COMPATIBLE_STRUCT : TypeId.STRUCT; + } + numFields = classHeader & SMALL_NUM_FIELDS_THRESHOLD; + if (numFields === SMALL_NUM_FIELDS_THRESHOLD) { + numFields += reader.readVarUInt32(); + } + } else { + if ((classHeader & 0b01110000) !== 0) { + throw new Error("invalid TypeMeta kind header"); + } + typeId = nonStructTypeId(classHeader & 0b1111); + registerByName = TypeId.isNamedType(typeId); + } - if (classHeader & REGISTER_BY_NAME_FLAG) { - // Read namespace and type name + if (registerByName) { namespace = this.readPkgName(reader); typeName = this.readTypeName(reader); - typeId = TypeId.NAMED_STRUCT; // Default for named types } else { - typeId = reader.readUint8(); userTypeId = reader.readVarUInt32(); } @@ -377,6 +492,9 @@ export class TypeMeta { const fieldInfo = this.readFieldInfo(reader); fields.push(fieldInfo); } + if (!isStruct && fields.length !== 0) { + throw new Error("non-struct TypeMeta cannot carry field metadata"); + } // Create a basic TypeInfo for the decoded type const typeInfo = { @@ -388,16 +506,25 @@ export class TypeMeta { const consumed = reader.readGetCursor() - bodyStart; if (consumed !== metaSize) { - throw new Error(`unexpected TypeMeta body size: expected ${metaSize}, consumed ${consumed}`); + throw new Error( + `unexpected TypeMeta body size: expected ${metaSize}, consumed ${consumed}`, + ); } - - return new TypeMeta( - fields, - typeInfo, - headerHash, - hasFieldsMeta, - compressed + TypeMeta.validateParsedBodyHash( + header, + reader.bufferRefAt(bodyStart, metaSize), ); + + return new TypeMeta(fields, typeInfo, headerHash, compressed); + } + + private static validateGlobalHeader(header: bigint) { + if ((header & RESERVED_META_FLAGS) !== 0n) { + throw new Error("invalid TypeMeta global header"); + } + if ((header & COMPRESS_META_FLAG) !== 0n) { + throw new Error("compressed TypeMeta is not supported yet"); + } } private static readMetaSize(reader: BinaryReader, header: bigint): number { @@ -408,6 +535,14 @@ export class TypeMeta { 
return metaSize; } + private static validateParsedBodyHash(header: bigint, body: Uint8Array) { + const expectedHeaderHash = TypeMeta.headerHashBits(body); + const actualHeaderHash = header & HEADER_HASH_MASK; + if (expectedHeaderHash !== actualHeaderHash) { + throw new Error("TypeMeta metadata hash mismatch"); + } + } + private static readFieldInfo(reader: BinaryReader): FieldInfo { const header = reader.readInt8(); const encodingFlags = (header >>> 6) & 0b11; @@ -431,14 +566,29 @@ export class TypeMeta { // Read field name const encoding = FieldInfo.u8ToEncoding(encodingFlags); - fieldName = fieldDecoder.decode(reader, size + 1, encoding || Encoding.UTF_8); + fieldName = fieldDecoder.decode( + reader, + size + 1, + encoding || Encoding.UTF_8, + ); fieldName = TypeMeta.lowerUnderscoreToLowerCamelCase(fieldName); } - return new FieldInfo(fieldName, typeId, userTypeId, trackingRef, nullable, options, fieldId); + return new FieldInfo( + fieldName, + typeId, + userTypeId, + trackingRef, + nullable, + options, + fieldId, + ); } - private static readTypeId(reader: BinaryReader, readFlag = false): InnerFieldInfo { + private static readTypeId( + reader: BinaryReader, + readFlag = false, + ): InnerFieldInfo { const options: InnerFieldInfoOptions = {}; let nullable = false; let trackingRef = false; @@ -449,7 +599,10 @@ export class TypeMeta { typeId = typeId >> 2; if (typeId === TypeId.NAMED_ENUM) { typeId = TypeId.ENUM; - } else if (typeId === TypeId.NAMED_UNION || typeId === TypeId.TYPED_UNION) { + } else if ( + typeId === TypeId.NAMED_UNION || + typeId === TypeId.TYPED_UNION + ) { typeId = TypeId.UNION; } this.readNestedTypeInfo(reader, typeId, options); @@ -465,7 +618,11 @@ export class TypeMeta { return { typeId, userTypeId: -1, nullable, trackingRef, options }; } - private static readNestedTypeInfo(reader: BinaryReader, typeId: number, options: InnerFieldInfoOptions) { + private static readNestedTypeInfo( + reader: BinaryReader, + typeId: number, + options: 
InnerFieldInfoOptions, + ) { switch (typeId) { case TypeId.LIST: options.inner = this.readTypeId(reader, true); @@ -490,7 +647,11 @@ export class TypeMeta { return this.readName(reader, typeNameEncoding, typeNameDecoder); } - private static readName(reader: BinaryReader, encodings: Encoding[], decoder: MetaStringDecoder): string { + private static readName( + reader: BinaryReader, + encodings: Encoding[], + decoder: MetaStringDecoder, + ): string { const header = reader.readUint8(); const encodingIndex = header & 0b11; let size = (header >> 2) & 0b111111; @@ -523,7 +684,9 @@ export class TypeMeta { return this.fields; } - remapFieldNames(localProps: Record | undefined): FieldInfo[] { + remapFieldNames( + localProps: Record | undefined, + ): FieldInfo[] { if (!localProps) { return this.fields; } @@ -566,7 +729,7 @@ export class TypeMeta { fieldInfo.trackingRef, fieldInfo.nullable, fieldInfo.options, - fieldInfo.fieldId + fieldInfo.fieldId, ); }); } @@ -578,21 +741,40 @@ export class TypeMeta { const writer = new BinaryWriter({}); writer.writeUint8(-1); // placeholder for header, update later - let currentClassHeader = this.fields.length; + const isStruct = TypeId.structType(this.type.typeId); + if (!isStruct && this.fields.length !== 0) { + throw new Error( + `non-struct TypeMeta ${this.type.typeId} cannot carry field metadata`, + ); + } - if (this.fields.length >= SMALL_NUM_FIELDS_THRESHOLD) { - currentClassHeader = SMALL_NUM_FIELDS_THRESHOLD; - writer.writeVarUInt32(this.fields.length - SMALL_NUM_FIELDS_THRESHOLD); + let currentClassHeader: number; + if (isStruct) { + currentClassHeader = + STRUCT_TYPEDEF_FLAG | + Math.min(this.fields.length, SMALL_NUM_FIELDS_THRESHOLD); + if ( + this.type.typeId === TypeId.COMPATIBLE_STRUCT || + this.type.typeId === TypeId.NAMED_COMPATIBLE_STRUCT + ) { + currentClassHeader |= COMPATIBLE_TYPEDEF_FLAG; + } + if (this.fields.length >= SMALL_NUM_FIELDS_THRESHOLD) { + writer.writeVarUInt32(this.fields.length - 
SMALL_NUM_FIELDS_THRESHOLD); + } + if (TypeId.isNamedType(this.type.typeId)) { + currentClassHeader |= REGISTER_BY_NAME_FLAG; + } + } else { + currentClassHeader = nonStructKindCode(this.type.typeId); } if (!TypeId.isNamedType(this.type.typeId)) { - writer.writeUint8(this.type.typeId); if (this.type.userTypeId === undefined || this.type.userTypeId === -1) { throw new Error(`userTypeId required for typeId ${this.type.typeId}`); } writer.writeVarUInt32(this.type.userTypeId); } else { - currentClassHeader |= REGISTER_BY_NAME_FLAG; const ns = this.type.namespace; const typename = this.type.typeName; this.writePkgName(writer, ns); @@ -603,12 +785,13 @@ export class TypeMeta { writer.setUint8Position(0, currentClassHeader); // Write fields info - this.writeFieldsInfo(writer, this.fields); + if (isStruct) { + this.writeFieldsInfo(writer, this.fields); + } const buffer = writer.dump(); - // For now, skip compression and just add header - return this.prependHeader(buffer, false, this.fields.length > 0); + return this.prependHeader(buffer, false); } writePkgName(writer: BinaryWriter, pkg: string) { @@ -619,7 +802,10 @@ export class TypeMeta { } writeTypeName(writer: BinaryWriter, typeName: string) { - const metaString = typeNameEncoder.encodeByEncodings(typeName, typeNameEncoding); + const metaString = typeNameEncoder.encodeByEncodings( + typeName, + typeNameEncoding, + ); const encoded = metaString.getBytes(); const encoding = metaString.getEncoding(); this.writeName(writer, encoded, typeNameEncoding.indexOf(encoding)); @@ -660,14 +846,19 @@ export class TypeMeta { encodingFlags = 3; // TAG_ID encoding } else { // Convert camelCase to snake_case for xlang compatibility - const fieldName = TypeMeta.lowerCamelToLowerUnderscore(fieldInfo.getFieldName()); - const metaString = fieldEncoder.encodeByEncodings(fieldName, fieldNameEncoding); + const fieldName = TypeMeta.lowerCamelToLowerUnderscore( + fieldInfo.getFieldName(), + ); + const metaString = fieldEncoder.encodeByEncodings( 
+ fieldName, + fieldNameEncoding, + ); encodingFlags = fieldNameEncoding.indexOf(metaString.getEncoding()); encoded = metaString.getBytes(); size = encoded.length - 1; } - header |= (encodingFlags << 6); + header |= encodingFlags << 6; const bigSize = size >= FIELD_NAME_SIZE_THRESHOLD; if (bigSize) { @@ -675,7 +866,7 @@ export class TypeMeta { writer.writeInt8(header); writer.writeVarUint32Small7(size - FIELD_NAME_SIZE_THRESHOLD); } else { - header |= (size << 2); + header |= size << 2; writer.writeInt8(header); } @@ -746,7 +937,19 @@ export class TypeMeta { return result; } - private static buildHeader(buffer: Uint8Array, isCompressed: boolean, hasFieldsMeta: boolean) { + private static buildHeader(buffer: Uint8Array, isCompressed: boolean) { + let header = TypeMeta.headerHashBits(buffer); + if (isCompressed) { + header |= COMPRESS_META_FLAG; + } + header |= BigInt(Math.min(buffer.length, META_SIZE_MASKS)); + return { + header: BigInt.asUintN(64, header), + headerHash: Number(header >> HASH_SHIFT_BITS), + }; + } + + private static headerHashBits(buffer: Uint8Array) { const hash = x64hash128(buffer, 47); // Read the high 64 bits of the 128-bit MurmurHash3 as a SIGNED // int64 to match pyfory (`hash_buffer()[0]` unpacks `int64_t[0]`), @@ -757,33 +960,19 @@ export class TypeMeta { // is set -- unsigned BigInt can't go negative, so its sign-check // is always false and the abs is a no-op. Signed int64 here // matches the canonical behaviour of the other xlang bindings. - let header = hash.getBigInt64(0, false); - header = header << BigInt(64 - NUM_HASH_BITS); - // Arbitrary-precision abs + mask to 63 bits, matching pyfory's - // `abs(hash) & 0x7FFFFFFFFFFFFFFF`. The mask clears the sign bit - // so the COMPRESS_META_FLAG (bit 9) / HAS_FIELDS_META_FLAG - // (bit 8) / metaSize (low 8 bits) ORs below don't collide with - // residual hash bits. 
+ let header = BigInt.asIntN( + 64, + hash.getBigInt64(0, false) << HASH_SHIFT_BITS, + ); if (header < 0n) { header = -header; } - header = header & 0x7FFFFFFFFFFFFFFFn; - if (isCompressed) { - header |= COMPRESS_META_FLAG; - } - if (hasFieldsMeta) { - header |= HAS_FIELDS_META_FLAG; - } - header |= BigInt(Math.min(buffer.length, META_SIZE_MASKS)); - return { - header: BigInt.asUintN(64, header), - headerHash: Number(header >> HASH_SHIFT_BITS), - }; + return BigInt.asUintN(64, header) & HEADER_HASH_MASK; } - private prependHeader(buffer: Uint8Array, isCompressed: boolean, hasFieldsMeta: boolean): Uint8Array { + private prependHeader(buffer: Uint8Array, isCompressed: boolean): Uint8Array { const metaSize = buffer.length; - const { header, headerHash } = TypeMeta.buildHeader(buffer, isCompressed, hasFieldsMeta); + const { header, headerHash } = TypeMeta.buildHeader(buffer, isCompressed); this.headerHash = headerHash; const writer = new BinaryWriter({}); @@ -806,7 +995,9 @@ export class TypeMeta { if (c >= "A" && c <= "Z") { if (i > 0) { const prevUpper = chars[i - 1] >= "A" && chars[i - 1] <= "Z"; - const nextUpperOrEnd = i + 1 >= chars.length || (chars[i + 1] >= "A" && chars[i + 1] <= "Z"); + const nextUpperOrEnd = + i + 1 >= chars.length || + (chars[i + 1] >= "A" && chars[i + 1] <= "Z"); if (!prevUpper || !nextUpperOrEnd) { result.push("_"); @@ -829,16 +1020,29 @@ export class TypeMeta { static compareFieldSortKey( a: { fieldName: string; fieldId?: number }, - b: { fieldName: string; fieldId?: number } + b: { fieldName: string; fieldId?: number }, ) { - if (a.fieldId !== undefined && a.fieldId !== null - && b.fieldId !== undefined && b.fieldId !== null) { + if ( + a.fieldId !== undefined && + a.fieldId !== null && + b.fieldId !== undefined && + b.fieldId !== null + ) { return a.fieldId - b.fieldId; } - return TypeMeta.getFieldSortKey(a).localeCompare(TypeMeta.getFieldSortKey(b)); + return TypeMeta.getFieldSortKey(a).localeCompare( + TypeMeta.getFieldSortKey(b), + 
); } - static groupFieldsByType(typeInfos: Array): Array { + static groupFieldsByType< + T extends { + fieldName: string; + nullable?: boolean; + typeId: number; + fieldId?: number; + }, + >(typeInfos: Array): Array { const primitiveFields: Array = []; const nullablePrimitiveFields: Array = []; const internalTypeFields: Array = []; diff --git a/javascript/packages/core/lib/reader/index.ts b/javascript/packages/core/lib/reader/index.ts index 28be874d6a..02b7897d60 100644 --- a/javascript/packages/core/lib/reader/index.ts +++ b/javascript/packages/core/lib/reader/index.ts @@ -33,9 +33,7 @@ export class BinaryReader { /** Cached ArrayBuffer for fast-path DataView reuse. */ private cachedArrayBuffer: ArrayBuffer | null = null; - constructor(config: { - useSliceString?: boolean; - }) { + constructor(config: { useSliceString?: boolean }) { this.sliceStringEnable = isNodeEnv && config.useSliceString; } @@ -44,15 +42,25 @@ export class BinaryReader { this.byteLength = this.platformBuffer.byteLength; // Reuse DataView when the underlying ArrayBuffer, byteOffset, and byteLength are unchanged. 
const buf = this.platformBuffer.buffer; - if (buf !== this.cachedArrayBuffer - || !this.dataView - || this.dataView.byteOffset !== this.platformBuffer.byteOffset - || this.dataView.byteLength !== this.byteLength) { - this.dataView = new DataView(buf, this.platformBuffer.byteOffset, this.byteLength); + if ( + buf !== this.cachedArrayBuffer || + !this.dataView || + this.dataView.byteOffset !== this.platformBuffer.byteOffset || + this.dataView.byteLength !== this.byteLength + ) { + this.dataView = new DataView( + buf, + this.platformBuffer.byteOffset, + this.byteLength, + ); this.cachedArrayBuffer = buf; } if (this.sliceStringEnable) { - this.bigString = this.platformBuffer.toString("latin1", 0, this.byteLength); + this.bigString = this.platformBuffer.toString( + "latin1", + 0, + this.byteLength, + ); } this.cursor = 0; } @@ -184,13 +192,21 @@ export class BinaryReader { } stringUtf8(len: number) { - const result = this.platformBuffer.toString("utf8", this.cursor, this.cursor + len); + const result = this.platformBuffer.toString( + "utf8", + this.cursor, + this.cursor + len, + ); this.cursor += len; return result; } stringUtf16LE(len: number) { - const result = this.platformBuffer.toString("utf16le", this.cursor, this.cursor + len); + const result = this.platformBuffer.toString( + "utf16le", + this.cursor, + this.cursor + len, + ); this.cursor += len; return result; } @@ -243,6 +259,10 @@ export class BinaryReader { return result; } + bufferRefAt(start: number, len: number) { + return this.platformBuffer.subarray(start, start + len); + } + readVarUInt32() { // Reduce memory reads as much as possible. Reading a uint32 at once is far faster than reading four uint8s separately. 
if (this.byteLength - this.cursor >= 5) { @@ -262,7 +282,7 @@ export class BinaryReader { // 0xfe00000: 0b1111111 << 21 result |= (fourByteValue >>> 3) & 0xfe00000; if ((fourByteValue & 0x80000000) != 0) { - result |= (this.readUint8()) << 28; + result |= this.readUint8() << 28; } } } @@ -282,7 +302,7 @@ export class BinaryReader { result |= (byte & 0x7f) << 21; if ((byte & 0x80) != 0) { byte = this.readUint8(); - result |= (byte) << 28; + result |= byte << 28; } } } @@ -307,7 +327,7 @@ export class BinaryReader { if (this.byteLength - readIdx >= 5) { const fourByteValue = this.dataView.getUint32(readIdx, true); this.cursor = readIdx + 1; - let value = fourByteValue & 0x7F; + let value = fourByteValue & 0x7f; if ((fourByteValue & 0x80) !== 0) { this.cursor++; value |= (fourByteValue >>> 1) & 0x3f80; @@ -321,14 +341,18 @@ export class BinaryReader { } } - private continueReadVarUint32(readIdx: number, bulkRead: number, value: number): number { + private continueReadVarUint32( + readIdx: number, + bulkRead: number, + value: number, + ): number { readIdx++; value |= (bulkRead >>> 2) & 0x1fc000; if ((bulkRead & 0x800000) !== 0) { readIdx++; value |= (bulkRead >>> 3) & 0xfe00000; if ((bulkRead & 0x80000000) !== 0) { - value |= (this.dataView.getUint8(readIdx++) & 0x7F) << 28; + value |= (this.dataView.getUint8(readIdx++) & 0x7f) << 28; } } this.cursor = readIdx; @@ -340,7 +364,7 @@ export class BinaryReader { if (this.byteLength - readIdx >= 5) { const fourByteValue = this.dataView.getUint32(readIdx, true); this.cursor = readIdx + 1; - let result = fourByteValue & 0x7F; + let result = fourByteValue & 0x7f; if ((fourByteValue & 0x80) !== 0) { this.cursor++; result |= (fourByteValue >>> 1) & 0x3f80; @@ -354,14 +378,18 @@ export class BinaryReader { } } - private continueReadVarUint36(readIdx: number, fourByteValue: number, result: number): number { + private continueReadVarUint36( + readIdx: number, + fourByteValue: number, + result: number, + ): number { readIdx++; 
result |= (fourByteValue >>> 2) & 0x1fc000; if ((fourByteValue & 0x800000) !== 0) { readIdx++; result |= (fourByteValue >>> 3) & 0xfe00000; if ((fourByteValue & 0x80000000) !== 0) { - result |= (this.dataView.getUint8(readIdx++) & 0xFF) << 28; + result |= (this.dataView.getUint8(readIdx++) & 0xff) << 28; } } this.cursor = readIdx; @@ -370,19 +398,19 @@ export class BinaryReader { private readVarUint36Slow(): number { let b = this.readUint8(); - let result = b & 0x7F; + let result = b & 0x7f; if ((b & 0x80) !== 0) { b = this.readUint8(); - result |= (b & 0x7F) << 7; + result |= (b & 0x7f) << 7; if ((b & 0x80) !== 0) { b = this.readUint8(); - result |= (b & 0x7F) << 14; + result |= (b & 0x7f) << 14; if ((b & 0x80) !== 0) { b = this.readUint8(); - result |= (b & 0x7F) << 21; + result |= (b & 0x7f) << 21; if ((b & 0x80) !== 0) { b = this.readUint8(); - result |= (b & 0xFF) << 28; + result |= (b & 0xff) << 28; } } } @@ -427,7 +455,7 @@ export class BinaryReader { result |= (byte & 0x7fn) << 49n; if ((byte & 0x80n) != 0n) { byte = this.bigUInt8(); - result |= (byte) << 56n; + result |= byte << 56n; } } } @@ -457,7 +485,7 @@ export class BinaryReader { if ((byte & 0x80) != 0) { const h32 = this.dataView.getUint32(this.cursor++, true); byte = h32 & 0xff; - rh28 |= (byte & 0x7f); + rh28 |= byte & 0x7f; if ((byte & 0x80) != 0) { byte = (h32 >>> 8) & 0xff; this.cursor++; @@ -471,7 +499,11 @@ export class BinaryReader { this.cursor++; rh28 |= (byte & 0x7f) << 21; if ((byte & 0x80) != 0) { - return (BigInt(this.readUint8()) << 56n) | BigInt(rh28) << 28n | BigInt(rl28); + return ( + (BigInt(this.readUint8()) << 56n) | + (BigInt(rh28) << 28n) | + BigInt(rl28) + ); } } } @@ -481,7 +513,7 @@ export class BinaryReader { } } - return BigInt(rh28) << 28n | BigInt(rl28); + return (BigInt(rh28) << 28n) | BigInt(rl28); } readVarInt64() { @@ -492,8 +524,8 @@ export class BinaryReader { readFloat16() { const asUint16 = this.readUint16(); const sign = asUint16 >> 15; - const exponent = 
(asUint16 >> 10) & 0x1F; - const mantissa = asUint16 & 0x3FF; + const exponent = (asUint16 >> 10) & 0x1f; + const mantissa = asUint16 & 0x3ff; // IEEE 754-2008 if (exponent === 0) { @@ -514,7 +546,9 @@ export class BinaryReader { } } else { // Normalized number - return (sign === 0 ? 1 : -1) * (1 + mantissa * 2 ** -10) * 2 ** (exponent - 15); + return ( + (sign === 0 ? 1 : -1) * (1 + mantissa * 2 ** -10) * 2 ** (exponent - 15) + ); } } diff --git a/javascript/packages/core/lib/type.ts b/javascript/packages/core/lib/type.ts index 6c908c14ab..3f7fa5513d 100644 --- a/javascript/packages/core/lib/type.ts +++ b/javascript/packages/core/lib/type.ts @@ -208,9 +208,8 @@ export const TypeId = { } as const; export enum ConfigFlags { - isNullFlag = 1 << 0, - isCrossLanguageFlag = 1 << 1, - isOutOfBandFlag = 1 << 2, + isCrossLanguageFlag = 1 << 0, + isOutOfBandFlag = 1 << 1, } export type CustomSerializer = { diff --git a/javascript/test/datetime.test.ts b/javascript/test/datetime.test.ts index f68e1a8240..7f2ca76731 100644 --- a/javascript/test/datetime.test.ts +++ b/javascript/test/datetime.test.ts @@ -17,43 +17,38 @@ * under the License. 
*/ -import Fory, { Type } from '../packages/core/index'; -import {describe, expect, test} from '@jest/globals'; -import { TypeId } from '../packages/core/lib/type'; +import Fory, { Type } from "../packages/core/index"; +import { describe, expect, test } from "@jest/globals"; +import { TypeId } from "../packages/core/lib/type"; -describe('datetime', () => { - test('should date work', () => { - - const fory = new Fory({ ref: true }); +describe("datetime", () => { + test("should date work", () => { + const fory = new Fory({ ref: true }); const now = new Date(); const input = fory.serialize(now); - const result: Date | null = fory.deserialize( - input - ); - expect(result?.getFullYear()).toEqual(now.getFullYear()) - expect(result?.getDate()).toEqual(now.getDate()) + const result: Date | null = fory.deserialize(input); + expect(result?.getFullYear()).toEqual(now.getFullYear()); + expect(result?.getDate()).toEqual(now.getDate()); }); - test('should datetime work', () => { + test("should datetime work", () => { const typeinfo = Type.struct("example.foo", { a: Type.timestamp(), b: Type.duration(), - }) - const fory = new Fory({ ref: true }); + }); + const fory = new Fory({ ref: true }); const serializer = fory.register(typeinfo).serializer; - const d = new Date('2021/10/20 09:13'); - const input = fory.serialize({ a: d, b: d}, serializer); - const result = fory.deserialize( - input - ); - expect(result).toEqual({ a: d, b: d.getTime() }) + const d = new Date("2021/10/20 09:13"); + const input = fory.serialize({ a: d, b: d }, serializer); + const result = fory.deserialize(input); + expect(result).toEqual({ a: d, b: d.getTime() }); }); - test('should use signed varint64 for date payloads', () => { + test("should use signed varint64 for date payloads", () => { const fory = new Fory({ ref: true }); const serializer = fory.register(Type.date()).serializer; const value = new Date(1969, 11, 31); const encoded = fory.serialize(value, serializer); - 
expect(Array.from(encoded)).toEqual([0x02, 0xff, TypeId.DATE, 0x01]); + expect(Array.from(encoded)).toEqual([0x01, 0xff, TypeId.DATE, 0x01]); expect(fory.deserialize(encoded, serializer)).toEqual(value); }); }); diff --git a/javascript/test/fory.test.ts b/javascript/test/fory.test.ts index e8b99d1351..19989ca8bc 100644 --- a/javascript/test/fory.test.ts +++ b/javascript/test/fory.test.ts @@ -25,13 +25,13 @@ describe('fory', () => { test('should deserialize null work', () => { const fory = new Fory(); - expect(fory.deserialize(new Uint8Array([1]))).toBe(null) + expect(fory.deserialize(new Uint8Array([1, 253]))).toBe(null) }); test('should deserialize xlang disable work', () => { const fory = new Fory(); try { - // bit 0 = null flag, bit 1 = xlang flag, bit 2 = oob flag + // bit 0 = xlang flag, bit 1 = oob flag // value 0 means xlang is disabled fory.deserialize(new Uint8Array([0])) throw new Error('unreachable code') @@ -43,9 +43,9 @@ describe('fory', () => { test('should deserialize oob mode work', () => { const fory = new Fory(); try { - // bit 0 = null flag, bit 1 = xlang flag, bit 2 = oob flag - // value 6 = xlang (2) + oob (4) = 6 - fory.deserialize(new Uint8Array([6])) + // bit 0 = xlang flag, bit 1 = oob flag + // value 3 = xlang (1) + oob (2) + fory.deserialize(new Uint8Array([3])) throw new Error('unreachable code') } catch (error) { expect(error.message).toBe('outofband mode is not supported now'); diff --git a/javascript/test/typemeta.test.ts b/javascript/test/typemeta.test.ts index 9e7e93400a..f5040f7572 100644 --- a/javascript/test/typemeta.test.ts +++ b/javascript/test/typemeta.test.ts @@ -17,50 +17,57 @@ * under the License. 
*/ -import Fory, { Type } from '../packages/core/index'; -import { TypeMeta } from '../packages/core/lib/meta/TypeMeta'; -import { BinaryReader } from '../packages/core/lib/reader'; -import { describe, expect, test } from '@jest/globals'; - -const HAS_FIELDS_META_FLAG = 1n << 8n; -const COMPRESS_META_FLAG = 1n << 9n; -const META_SIZE_MASK = 0xFFn; -const HASH_SHIFT_BITS = 14n; - -describe('typemeta', () => { - test('writes TypeMeta header bits in the xlang layout', () => { +import Fory, { Type } from "../packages/core/index"; +import { TypeMeta } from "../packages/core/lib/meta/TypeMeta"; +import { BinaryReader } from "../packages/core/lib/reader"; +import { describe, expect, test } from "@jest/globals"; + +const COMPRESS_META_FLAG = 1n << 8n; +const RESERVED_META_FLAGS = 0b111n << 9n; +const META_SIZE_MASK = 0xffn; +const HASH_SHIFT_BITS = 12n; + +describe("typemeta", () => { + test("writes TypeMeta header bits in the xlang layout", () => { const typeInfo = Type.struct(7001, { fullName: Type.string().setId(1), age: Type.int32().setId(2), }); const bytes = TypeMeta.fromTypeInfo(typeInfo).toBytes(); - const header = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength).getBigUint64(0, true); + const header = new DataView( + bytes.buffer, + bytes.byteOffset, + bytes.byteLength, + ).getBigUint64(0, true); expect(Number(header & META_SIZE_MASK)).toBe(bytes.length - 8); - expect((header & HAS_FIELDS_META_FLAG) !== 0n).toBe(true); expect((header & COMPRESS_META_FLAG) !== 0n).toBe(false); + expect((header & RESERVED_META_FLAGS) !== 0n).toBe(false); expect(header >> HASH_SHIFT_BITS).toBeGreaterThan(0n); + expect((bytes[8] & 0x80) !== 0).toBe(true); }); - test('keeps tagged direct payload order grouped instead of field-id ordered', () => { - const typeMeta = TypeMeta.fromTypeInfo(Type.struct(7005, { - stringValue: Type.string().setId(1), - mapValue: Type.map(Type.string(), Type.int32()).setId(2), - intValue: Type.int32().setId(10), - })); + test("keeps tagged 
direct payload order grouped instead of field-id ordered", () => { + const typeMeta = TypeMeta.fromTypeInfo( + Type.struct(7005, { + stringValue: Type.string().setId(1), + mapValue: Type.map(Type.string(), Type.int32()).setId(2), + intValue: Type.int32().setId(10), + }), + ); expect(typeMeta.getFieldInfo().map((field) => field.fieldName)).toEqual([ - 'intValue', - 'stringValue', - 'mapValue', + "intValue", + "stringValue", + "mapValue", ]); }); - test('writes the zero size extension when the TypeMeta body is exactly 0xFF bytes', () => { + test("writes the zero size extension when the TypeMeta body is exactly 0xFF bytes", () => { const typeMeta = TypeMeta.fromTypeInfo(Type.struct(7003, {})) as any; - const body = new Uint8Array(0xFF); - const bytes = typeMeta.prependHeader(body, false, false) as Uint8Array; + const body = new Uint8Array(0xff); + const bytes = typeMeta.prependHeader(body, false) as Uint8Array; const reader = new BinaryReader({}); expect(bytes).toHaveLength(8 + 1 + body.length); @@ -72,24 +79,63 @@ describe('typemeta', () => { expect(reader.readGetCursor()).toBe(bytes.length); }); - test('regenerates compatible named serializers when schema changes but field count stays the same', () => { + test("validates TypeMeta body hash before caching parsed metadata", () => { + const bytes = TypeMeta.fromTypeInfo( + Type.struct(7006, { + value: Type.string().setId(1), + }), + ).toBytes(); + const malformed = new Uint8Array(bytes); + malformed[malformed.length - 1] ^= 1; + + const parseReader = new BinaryReader({}); + parseReader.reset(malformed); + expect(() => TypeMeta.fromBytes(parseReader)).toThrow( + "TypeMeta metadata hash mismatch", + ); + + const skipReader = new BinaryReader({}); + skipReader.reset(bytes); + const header = TypeMeta.readHeader(skipReader); + TypeMeta.skipBody(skipReader, header); + expect(skipReader.readGetCursor()).toBe(bytes.length); + }); + + test("encodes extended id-registered struct field counts without the name bit", () => { + 
const fields: Record = {}; + for (let i = 0; i < 32; i++) { + fields[`field${i}`] = Type.int32().setId(i + 1); + } + const bytes = TypeMeta.fromTypeInfo(Type.struct(7201, fields)).toBytes(); + const reader = new BinaryReader({}); + const bodyOffset = typeMetaBodyOffset(bytes); + + expect(bytes[bodyOffset] & 0x1f).toBe(0x1f); + expect(bytes[bodyOffset] & 0x20).toBe(0); + + reader.reset(bytes); + const decoded = TypeMeta.fromBytes(reader); + expect(decoded.getFieldInfo()).toHaveLength(32); + }); + + test("regenerates compatible named serializers when schema changes but field count stays the same", () => { const writerFory = new Fory({ compatible: true }); const readerFory = new Fory({ compatible: true }); - const writerType = Type.struct('example.item', { + const writerType = Type.struct("example.item", { value: Type.string(), }); - const readerType = Type.struct('example.item', { + const readerType = Type.struct("example.item", { value: Type.int32(), }); - const bytes = writerFory.register(writerType).serialize({ value: 'hello' }); + const bytes = writerFory.register(writerType).serialize({ value: "hello" }); const result = readerFory.register(readerType).deserialize(bytes); - expect(result).toEqual({ value: 'hello' }); + expect(result).toEqual({ value: "hello" }); }); - test('remaps compatible tag-id fields onto local property names during regeneration', () => { + test("remaps compatible tag-id fields onto local property names during regeneration", () => { const writerFory = new Fory({ compatible: true }); const readerFory = new Fory({ compatible: true }); @@ -103,42 +149,42 @@ describe('typemeta', () => { }); const bytes = writerFory.register(writerType).serialize({ - fullName: 'Alice', - note: 'ally', + fullName: "Alice", + note: "ally", }); const result = readerFory.register(readerType).deserialize(bytes); expect(result).toEqual({ - name: 'Alice', - alias: 'ally', + name: "Alice", + alias: "ally", }); }); - test('keeps compatible named schema evolution working 
when field count differs', () => { + test("keeps compatible named schema evolution working when field count differs", () => { const writerFory = new Fory({ compatible: true }); const readerFory = new Fory({ compatible: true }); - const writerType = Type.struct('example.foo', { + const writerType = Type.struct("example.foo", { bar: Type.string(), bar2: Type.int32(), }); - const readerType = Type.struct('example.foo', { + const readerType = Type.struct("example.foo", { bar: Type.string(), }); const bytes = writerFory.register(writerType).serialize({ - bar: 'hello', + bar: "hello", bar2: 123, }); const result = readerFory.register(readerType).deserialize(bytes); expect(result).toEqual({ - bar: 'hello', + bar: "hello", bar2: 123, }); }); - test('remaps regenerated compatible field names onto local snake_case properties', () => { + test("remaps regenerated compatible field names onto local snake_case properties", () => { const writerFory = new Fory({ compatible: true }); const readerFory = new Fory({ compatible: true }); @@ -162,25 +208,27 @@ describe('typemeta', () => { const readerReg = readerFory.register(ReaderHolder); const value = new WriterHolder(); - value.animalMap.set('dog', 7); + value.animalMap.set("dog", 7); value.marker = 99; const result = readerReg.deserialize(writerReg.serialize(value)); expect(result).toBeInstanceOf(ReaderHolder); - expect(result.animal_map.get('dog')).toBe(7); - expect((result as ReaderHolder & { animalMap?: Map }).animalMap).toBeUndefined(); + expect(result.animal_map.get("dog")).toBe(7); + expect( + (result as ReaderHolder & { animalMap?: Map }).animalMap, + ).toBeUndefined(); expect((result as ReaderHolder & { marker?: number }).marker).toBe(99); }); - test('skips unknown named custom fields by falling back to any when no local field exists', () => { + test("skips unknown named custom fields by falling back to any when no local field exists", () => { const writerFory = new Fory({ compatible: true }); const readerFory = new Fory({ 
compatible: true }); class MyExt { id = 0; } - Type.ext('my_ext')(MyExt); + Type.ext("my_ext")(MyExt); const customSerializer = { write: (writeContext: any, value: MyExt) => { @@ -195,22 +243,22 @@ describe('typemeta', () => { readerFory.register(MyExt, customSerializer); class WriterWrapper { - note = ''; + note = ""; myExt = new MyExt(); } - Type.struct('example.wrapper', { + Type.struct("example.wrapper", { note: Type.string(), - myExt: Type.ext('my_ext'), + myExt: Type.ext("my_ext"), })(WriterWrapper); class EmptyWrapper {} - Type.struct('example.wrapper', {})(EmptyWrapper); + Type.struct("example.wrapper", {})(EmptyWrapper); const writerReg = writerFory.register(WriterWrapper); const readerReg = readerFory.register(EmptyWrapper); const value = new WriterWrapper(); - value.note = 'hello'; + value.note = "hello"; value.myExt.id = 42; const result = readerReg.deserialize(writerReg.serialize(value)); @@ -218,7 +266,7 @@ describe('typemeta', () => { expect(result).toBeInstanceOf(EmptyWrapper); }); - test('skips unknown compatible enum fields when regenerating an empty reader', () => { + test("skips unknown compatible enum fields when regenerating an empty reader", () => { const writerFory = new Fory({ compatible: true }); const readerFory = new Fory({ compatible: true }); @@ -251,7 +299,7 @@ describe('typemeta', () => { expect(result).toBeInstanceOf(EmptyStruct); }); - test('skips unknown enum and named custom fields together during compatible regeneration', () => { + test("skips unknown enum and named custom fields together during compatible regeneration", () => { const writerFory = new Fory({ compatible: true }); const readerFory = new Fory({ compatible: true }); @@ -261,13 +309,13 @@ describe('typemeta', () => { Blue: 2, White: 3, }; - writerFory.register(Type.enum('color', Color)); - readerFory.register(Type.enum('color', Color)); + writerFory.register(Type.enum("color", Color)); + readerFory.register(Type.enum("color", Color)); class MyExt { id = 0; } - 
Type.ext('my_ext')(MyExt); + Type.ext("my_ext")(MyExt); const customSerializer = { write: (writeContext: any, value: MyExt) => { @@ -284,7 +332,7 @@ describe('typemeta', () => { class MyStruct { id = 0; } - Type.struct('my_struct', { + Type.struct("my_struct", { id: Type.int32(), })(MyStruct); @@ -296,14 +344,14 @@ describe('typemeta', () => { myStruct = new MyStruct(); myExt = new MyExt(); } - Type.struct('my_wrapper', { - color: Type.enum('color', Color), - myStruct: Type.struct('my_struct'), - myExt: Type.ext('my_ext'), + Type.struct("my_wrapper", { + color: Type.enum("color", Color), + myStruct: Type.struct("my_struct"), + myExt: Type.ext("my_ext"), })(WriterWrapper); class EmptyWrapper {} - Type.struct('my_wrapper', {})(EmptyWrapper); + Type.struct("my_wrapper", {})(EmptyWrapper); const writerReg = writerFory.register(WriterWrapper); const readerReg = readerFory.register(EmptyWrapper); @@ -317,3 +365,13 @@ describe('typemeta', () => { expect(result).toBeInstanceOf(EmptyWrapper); }); }); + +function typeMetaBodyOffset(bytes: Uint8Array) { + const reader = new BinaryReader({}); + reader.reset(bytes); + const header = TypeMeta.readHeader(reader); + if ((header & META_SIZE_MASK) === META_SIZE_MASK) { + reader.readVarUInt32(); + } + return reader.readGetCursor(); +} diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 0c03d9b5e6..9583b729e4 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -41,7 +41,6 @@ NOT_NULL_INT64_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (INT64_TYPE_ID << 8) from pyfory.serialization import Buffer, Config, ENABLE_FORY_CYTHON_SERIALIZATION -from pyfory.utils import set_bit, get_bit, clear_bit from pyfory.context import WriteContext, ReadContext @@ -490,16 +489,10 @@ def _serialize( mask_index = buffer.get_writer_index() buffer.grow(1) buffer.set_writer_index(mask_index + 1) - buffer.put_int8(mask_index, 0) - if obj is None: - set_bit(buffer, mask_index, 0) - else: - clear_bit(buffer, mask_index, 0) - 
set_bit(buffer, mask_index, 1) + bitmap = 1 if self.xlang else 0 if buffer_callback is not None: - set_bit(buffer, mask_index, 2) - else: - clear_bit(buffer, mask_index, 2) + bitmap |= 2 + buffer.put_int8(mask_index, bitmap) write_context.write_ref(obj) return buffer @@ -559,9 +552,12 @@ def _deserialize( read_context = self.read_context reader_index = buffer.get_reader_index() buffer.set_reader_index(reader_index + 1) - if get_bit(buffer, reader_index, 0): - return None - peer_out_of_band_enabled = get_bit(buffer, reader_index, 2) + bitmap = buffer.get_int8(reader_index) & 0xFF + if bitmap & ~0b11: + raise ValueError(f"Unsupported root header bitmap 0x{bitmap:02x}") + if bool(bitmap & 1) != self.xlang: + raise ValueError("Header bitmap mismatch at xlang bit") + peer_out_of_band_enabled = bool(bitmap & 2) if peer_out_of_band_enabled: assert buffers is not None, "buffers shouldn't be null when the serialized stream is produced with buffer_callback not null." else: diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi index f3c94422af..1a29fe8de3 100644 --- a/python/pyfory/context.pxi +++ b/python/pyfory/context.pxi @@ -19,11 +19,17 @@ from pyfory.context import EncodedMetaString, EMPTY_ENCODED_META_STRING from pyfory.resolver import NULL_FLAG, REF_FLAG, NOT_NULL_VALUE_FLAG, REF_VALUE_FLAG +cdef extern from "fory/thirdparty/MurmurHash3.h": + void MurmurHash3_x64_128(const void *key, int len, uint32_t seed, void *out) nogil + + INT64_TYPE_ID = TypeId.VARINT64 FLOAT64_TYPE_ID = TypeId.FLOAT64 BOOL_TYPE_ID = TypeId.BOOL STRING_TYPE_ID = TypeId.STRING SMALL_STRING_THRESHOLD = 16 +cdef int32_t MAX_CACHED_META_STRINGS = 8192 +cdef int32_t MAX_CACHED_META_STRING_LENGTH = 2048 cdef inline uint64_t _mix64(uint64_t x): @@ -309,24 +315,37 @@ cdef class MetaStringWriter: cdef class MetaStringReader: cdef object shared_registry cdef vector[PyObject *] _c_dynamic_id_to_encoded_meta_string_vec + cdef vector[PyObject *] _c_owned_dynamic_encoded_meta_string_vec + cdef 
vector[PyObject *] _c_cached_encoded_meta_string_vec cdef flat_hash_map[int64_t, PyObject *] _c_hash_to_encoded_meta_string cdef flat_hash_map[int64_t, PyObject *] _c_hash_to_small_encoded_meta_string def __init__(self, shared_registry): self.shared_registry = shared_registry + def __dealloc__(self): + cdef PyObject *item + self.reset() + for item in self._c_cached_encoded_meta_string_vec: + Py_XDECREF(item) + self._c_cached_encoded_meta_string_vec.clear() + cpdef inline read_encoded_meta_string(self, Buffer buffer): cdef int32_t header = buffer.read_var_uint32() cdef int32_t length = header >> 1 cdef int64_t v1 = 0 cdef int64_t v2 = 0 cdef int64_t hashcode + cdef int64_t canonical_hash + cdef int64_t[2] hash_out cdef PyObject *encoded_meta_string_ptr cdef int32_t reader_index cdef int8_t encoding = 0 cdef bytes data + cdef bytes cached_data cdef object encoded_meta_string cdef pair[int64_t, PyObject *] *entry + cdef bint cache_entry if header & 0b1: if length <= 0: raise ValueError("Invalid dynamic metastring id 0") @@ -344,39 +363,91 @@ cdef class MetaStringReader: else: v1 = buffer.read_int64() v2 = buffer.read_bytes_as_int64(length - 8) + if encoding > 4: + raise ValueError(f"Unexpected encoding flag: {encoding}") hashcode = _hash_small_metastring(v1, v2, length, encoding) entry = self._c_hash_to_small_encoded_meta_string.find(hashcode) if entry == NULL or deref(entry).second == NULL: reader_index = buffer.get_reader_index() data = buffer.get_bytes(reader_index - length, length) + cache_entry = self._c_hash_to_small_encoded_meta_string.size() < MAX_CACHED_META_STRINGS encoded_meta_string = self.shared_registry.get_or_create_encoded_meta_string( data, hashcode, ) encoded_meta_string_ptr = encoded_meta_string - self._c_hash_to_small_encoded_meta_string[hashcode] = encoded_meta_string_ptr + if cache_entry: + Py_INCREF( encoded_meta_string_ptr) + self._c_cached_encoded_meta_string_vec.push_back(encoded_meta_string_ptr) + 
self._c_hash_to_small_encoded_meta_string[hashcode] = encoded_meta_string_ptr + else: + Py_INCREF( encoded_meta_string_ptr) + self._c_owned_dynamic_encoded_meta_string_vec.push_back(encoded_meta_string_ptr) else: encoded_meta_string_ptr = deref(entry).second else: hashcode = buffer.read_int64() + if (hashcode & 0xFF) > 4: + raise ValueError(f"Unexpected encoding flag: {hashcode & 0xFF}") reader_index = buffer.get_reader_index() buffer.check_bound(reader_index, length) - buffer.set_reader_index(reader_index + length) entry = self._c_hash_to_encoded_meta_string.find(hashcode) + if entry != NULL and deref(entry).second != NULL: + cached_data = ( deref(entry).second).data + if ( + PyBytes_GET_SIZE(cached_data) == length + and memcmp( + (buffer.c_buffer.data() + reader_index), + PyBytes_AS_STRING(cached_data), + length, + ) == 0 + ): + buffer.set_reader_index(reader_index + length) + encoded_meta_string_ptr = deref(entry).second + self._c_dynamic_id_to_encoded_meta_string_vec.push_back(encoded_meta_string_ptr) + return encoded_meta_string_ptr + MurmurHash3_x64_128( + (buffer.c_buffer.data() + reader_index), + length, + 47, + &hash_out[0], + ) + canonical_hash = ( + (( hash_out[0]) & 0xffffffffffffff00) + | (hashcode & 0xFF) + ) + if canonical_hash != hashcode: + raise ValueError("Malformed metastring hash") + buffer.set_reader_index(reader_index + length) + data = buffer.get_bytes(reader_index, length) + encoded_meta_string = self.shared_registry.get_or_create_encoded_meta_string( + data, + hashcode, + ) + encoded_meta_string_ptr = encoded_meta_string if entry == NULL or deref(entry).second == NULL: - data = buffer.get_bytes(reader_index, length) - encoded_meta_string = self.shared_registry.get_or_create_encoded_meta_string( - data, - hashcode, + cache_entry = ( + self._c_hash_to_encoded_meta_string.size() < MAX_CACHED_META_STRINGS + and length <= MAX_CACHED_META_STRING_LENGTH ) - encoded_meta_string_ptr = encoded_meta_string - 
self._c_hash_to_encoded_meta_string[hashcode] = encoded_meta_string_ptr + if cache_entry: + Py_INCREF( encoded_meta_string_ptr) + self._c_cached_encoded_meta_string_vec.push_back(encoded_meta_string_ptr) + self._c_hash_to_encoded_meta_string[hashcode] = encoded_meta_string_ptr + else: + Py_INCREF( encoded_meta_string_ptr) + self._c_owned_dynamic_encoded_meta_string_vec.push_back(encoded_meta_string_ptr) else: - encoded_meta_string_ptr = deref(entry).second + Py_INCREF( encoded_meta_string_ptr) + self._c_owned_dynamic_encoded_meta_string_vec.push_back(encoded_meta_string_ptr) self._c_dynamic_id_to_encoded_meta_string_vec.push_back(encoded_meta_string_ptr) return encoded_meta_string_ptr cpdef inline reset(self): + cdef PyObject *item + for item in self._c_owned_dynamic_encoded_meta_string_vec: + Py_XDECREF(item) + self._c_owned_dynamic_encoded_meta_string_vec.clear() self._c_dynamic_id_to_encoded_meta_string_vec.clear() diff --git a/python/pyfory/context.py b/python/pyfory/context.py index 9f5becfe6f..2439b9046d 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -31,6 +31,8 @@ from pyfory.types import TypeId SMALL_STRING_THRESHOLD = 16 +MAX_CACHED_META_STRINGS = 8192 +MAX_CACHED_META_STRING_LENGTH = 2048 INT64_TYPE_ID = TypeId.VARINT64 FLOAT64_TYPE_ID = TypeId.FLOAT64 BOOL_TYPE_ID = TypeId.BOOL @@ -154,7 +156,9 @@ def read_encoded_meta_string(self, buffer): def _read_small_meta_string(self, buffer, length: int): if length == 0: return EMPTY_ENCODED_META_STRING - encoding = buffer.read_int8() + encoding = buffer.read_int8() & 0xFF + if encoding > 4: + raise ValueError(f"Unexpected encoding flag: {encoding}") if length <= 8: v1 = buffer.read_bytes_as_int64(length) v2 = 0 @@ -169,20 +173,28 @@ def _read_small_meta_string(self, buffer, length: int): reader_index = buffer.get_reader_index() data = buffer.get_bytes(reader_index - length, length) encoded_meta_string = self.shared_registry.get_or_create_encoded_meta_string(data, hashcode) - 
self._small_encoded_meta_strings[key] = encoded_meta_string + if len(self._small_encoded_meta_strings) < MAX_CACHED_META_STRINGS: + self._small_encoded_meta_strings[key] = encoded_meta_string return encoded_meta_string def _read_big_meta_string(self, buffer, length: int): hashcode = buffer.read_int64() + encoding = hashcode & 0xFF + if encoding > 4: + raise ValueError(f"Unexpected encoding flag: {encoding}") reader_index = buffer.get_reader_index() buffer.check_bound(reader_index, length) data = buffer.get_bytes(reader_index, length) buffer.set_reader_index(reader_index + length) + canonical_hash = hash_meta_string_data(data, encoding) + if canonical_hash != hashcode: + raise ValueError("Malformed metastring hash") key = (hashcode, data) encoded_meta_string = self._hash_to_encoded_meta_strings.get(key) if encoded_meta_string is None: encoded_meta_string = self.shared_registry.get_or_create_encoded_meta_string(data, hashcode) - self._hash_to_encoded_meta_strings[key] = encoded_meta_string + if length <= MAX_CACHED_META_STRING_LENGTH and len(self._hash_to_encoded_meta_strings) < MAX_CACHED_META_STRINGS: + self._hash_to_encoded_meta_strings[key] = encoded_meta_string return encoded_meta_string def reset(self): diff --git a/python/pyfory/format/infer.py b/python/pyfory/format/infer.py index 2d610a7259..3865773dc9 100644 --- a/python/pyfory/format/infer.py +++ b/python/pyfory/format/infer.py @@ -63,14 +63,20 @@ def get_cls_by_schema(schema): else: from pyfory.type_util import record_class_factory - cls_ = record_class_factory("Record" + str(id(schema)), [schema.field(i).name for i in range(schema.num_fields)]) + return record_class_factory( + "Record" + str(id(schema)), + [schema.field(i).name for i in range(schema.num_fields)], + publish=False, + ) __type_map__[id_] = cls_ __schemas__[id_] = schema return __type_map__[id_] def remove_schema(schema): - __schemas__.pop(id(schema)) + id_ = id(schema) + __schemas__.pop(id_, None) + __type_map__.pop(id_, None) def reset(): 
diff --git a/python/pyfory/format/tests/test_infer.py b/python/pyfory/format/tests/test_infer.py index 630828a0a7..810e436148 100644 --- a/python/pyfory/format/tests/test_infer.py +++ b/python/pyfory/format/tests/test_infer.py @@ -23,10 +23,13 @@ from pyfory.format.infer import ( ForyTypeVisitor, from_arrow_schema, + get_cls_by_schema, infer_field, infer_schema, + remove_schema, to_arrow_schema, ) +import pyfory.format.infer as infer_module from pyfory.format import ( TypeId, ) @@ -51,6 +54,22 @@ class Foo: f7: Bar +class FakeField: + def __init__(self, name): + self.name = name + + +class FakeSchema: + def __init__(self, field_name, cls=None): + self.num_fields = 1 + self._field = FakeField(field_name) + self.metadata = {} if cls is None else {b"cls": f"{cls.__module__}.{cls.__name__}".encode()} + + def field(self, index): + assert index == 0 + return self._field + + def _infer_field(field_name, type_, types_path=None): return infer_field(field_name, type_, ForyTypeVisitor(), types_path=types_path) @@ -142,5 +161,28 @@ def test_row_format_rejects_xlang_array_carrier_annotations(): _infer_field("values", pyfory.Array[pyfory.Int32]) +def test_row_schema_without_class_metadata_is_not_globally_cached(): + infer_module.reset() + classes = [get_cls_by_schema(FakeSchema(f"f{i}")) for i in range(8)] + + assert len({id(cls) for cls in classes}) == len(classes) + assert infer_module.__type_map__ == {} + assert infer_module.__schemas__ == {} + for cls in classes: + assert cls.__name__ not in pyfory.type_util.__dict__ + + +def test_remove_schema_clears_row_class_cache(): + infer_module.reset() + schema = FakeSchema("f", Foo) + + assert get_cls_by_schema(schema) is Foo + assert id(schema) in infer_module.__type_map__ + assert id(schema) in infer_module.__schemas__ + remove_schema(schema) + assert id(schema) not in infer_module.__type_map__ + assert id(schema) not in infer_module.__schemas__ + + if __name__ == "__main__": test_infer_class_schema() diff --git 
a/python/pyfory/meta/typedef.py b/python/pyfory/meta/typedef.py index 5529c433e0..0a444d1ce9 100644 --- a/python/pyfory/meta/typedef.py +++ b/python/pyfory/meta/typedef.py @@ -26,30 +26,111 @@ from pyfory.type_util import get_homogeneous_tuple_elem_type, infer_field from pyfory.meta.metastring import Encoding from pyfory.type_util import infer_field_types +from pyfory.lib.mmh3 import hash_buffer # Constants from the specification SMALL_NUM_FIELDS_THRESHOLD = 0b11111 -REGISTER_BY_NAME_FLAG = 0b100000 +REGISTER_BY_NAME_FLAG = 0b00100000 +COMPATIBLE_TYPEDEF_FLAG = 0b01000000 +STRUCT_TYPEDEF_FLAG = 0b10000000 FIELD_NAME_SIZE_THRESHOLD = 0b1111 # 4-bit threshold for field names BIG_NAME_THRESHOLD = 0b111111 # 6-bit threshold for namespace/typename -COMPRESS_META_FLAG = 0b1 << 9 -HAS_FIELDS_META_FLAG = 0b1 << 8 +COMPRESS_META_FLAG = 0b1 << 8 +RESERVED_META_FLAGS = 0b111 << 9 META_SIZE_MASKS = 0xFF -NUM_HASH_BITS = 50 - -NAMESPACE_ENCODINGS = [Encoding.UTF_8, Encoding.ALL_TO_LOWER_SPECIAL, Encoding.LOWER_UPPER_DIGIT_SPECIAL] -TYPE_NAME_ENCODINGS = [Encoding.UTF_8, Encoding.ALL_TO_LOWER_SPECIAL, Encoding.LOWER_UPPER_DIGIT_SPECIAL, Encoding.FIRST_TO_LOWER_SPECIAL] +NUM_HASH_BITS = 52 +TYPEDEF_HASH_SHIFT = 64 - NUM_HASH_BITS +TYPEDEF_HASH_MASK = ((1 << 64) - 1) ^ ((1 << TYPEDEF_HASH_SHIFT) - 1) +_INT64_MIN = -(1 << 63) +_UINT64_MASK = (1 << 64) - 1 + +NAMESPACE_ENCODINGS = [ + Encoding.UTF_8, + Encoding.ALL_TO_LOWER_SPECIAL, + Encoding.LOWER_UPPER_DIGIT_SPECIAL, +] +TYPE_NAME_ENCODINGS = [ + Encoding.UTF_8, + Encoding.ALL_TO_LOWER_SPECIAL, + Encoding.LOWER_UPPER_DIGIT_SPECIAL, + Encoding.FIRST_TO_LOWER_SPECIAL, +] # Field name encoding constants FIELD_NAME_ENCODING_UTF8 = 0b00 FIELD_NAME_ENCODING_ALL_TO_LOWER_SPECIAL = 0b01 FIELD_NAME_ENCODING_LOWER_UPPER_DIGIT_SPECIAL = 0b10 FIELD_NAME_ENCODING_TAG_ID = 0b11 -FIELD_NAME_ENCODINGS = [Encoding.UTF_8, Encoding.ALL_TO_LOWER_SPECIAL, Encoding.LOWER_UPPER_DIGIT_SPECIAL] +FIELD_NAME_ENCODINGS = [ + Encoding.UTF_8, + 
Encoding.ALL_TO_LOWER_SPECIAL, + Encoding.LOWER_UPPER_DIGIT_SPECIAL, +] # TAG_ID encoding constants -TAG_ID_SIZE_THRESHOLD = 0b1111 # 4-bit threshold for tag IDs (0-14 inline, 15 = overflow) +TAG_ID_SIZE_THRESHOLD = ( + 0b1111 # 4-bit threshold for tag IDs (0-14 inline, 15 = overflow) +) + + +def is_struct_typedef_kind(type_id: int) -> bool: + return type_id in { + TypeId.STRUCT, + TypeId.COMPATIBLE_STRUCT, + TypeId.NAMED_STRUCT, + TypeId.NAMED_COMPATIBLE_STRUCT, + } + + +def is_named_typedef_kind(type_id: int) -> bool: + return type_id in { + TypeId.NAMED_STRUCT, + TypeId.NAMED_COMPATIBLE_STRUCT, + TypeId.NAMED_ENUM, + TypeId.NAMED_EXT, + TypeId.NAMED_UNION, + } + + +def xlang_non_struct_kind_code(type_id: int) -> int: + mapping = { + TypeId.ENUM: 0, + TypeId.NAMED_ENUM: 1, + TypeId.EXT: 2, + TypeId.NAMED_EXT: 3, + TypeId.TYPED_UNION: 4, + TypeId.NAMED_UNION: 5, + } + try: + return mapping[type_id] + except KeyError as exc: + raise ValueError(f"Unsupported TypeDef kind {type_id}") from exc + + +def xlang_non_struct_type_id(kind_code: int) -> int: + mapping = { + 0: TypeId.ENUM, + 1: TypeId.NAMED_ENUM, + 2: TypeId.EXT, + 3: TypeId.NAMED_EXT, + 4: TypeId.TYPED_UNION, + 5: TypeId.NAMED_UNION, + } + try: + return mapping[kind_code] + except KeyError as exc: + raise ValueError(f"Unsupported TypeDef kind code {kind_code}") from exc + + +def _typedef_header_hash(encoded: bytes) -> int: + hash_value = hash_buffer(encoded, 47)[0] + shifted = (hash_value << TYPEDEF_HASH_SHIFT) & _UINT64_MASK + if shifted >= (1 << 63): + shifted -= 1 << 64 + if shifted != _INT64_MIN and shifted < 0: + shifted = -shifted + return (shifted & _UINT64_MASK) & TYPEDEF_HASH_MASK class TypeDef: @@ -73,7 +154,9 @@ def __init__( self.encoded = encoded self.is_compressed = is_compressed - def create_fields_serializer(self, resolver, resolved_field_names=None, local_field_types=None): + def create_fields_serializer( + self, resolver, resolved_field_names=None, local_field_types=None + ): """Create 
serializers for each field. Args: @@ -84,12 +167,18 @@ def create_fields_serializer(self, resolver, resolved_field_names=None, local_fi """ field_types = local_field_types if field_types is None: - field_types = infer_field_types(self.cls, field_nullable=resolver.field_nullable) + field_types = infer_field_types( + self.cls, field_nullable=resolver.field_nullable + ) serializers = [] for i, field_info in enumerate(self.fields): # Use resolved name if provided, otherwise use original name - lookup_name = resolved_field_names[i] if resolved_field_names else field_info.name - serializer = field_info.field_type.create_serializer(resolver, field_types.get(lookup_name, None)) + lookup_name = ( + resolved_field_names[i] if resolved_field_names else field_info.name + ) + serializer = field_info.field_type.create_serializer( + resolver, field_types.get(lookup_name, None) + ) serializers.append(serializer) return serializers @@ -146,18 +235,21 @@ def _resolve_field_names_from_tag_ids(self): return resolved_names def create_serializer(self, resolver): - if self.type_id == TypeId.NAMED_EXT: - return resolver.get_type_info_by_name(self.namespace, self.typename).serializer - if self.type_id == TypeId.NAMED_ENUM: - try: - return resolver.get_type_info_by_name(self.namespace, self.typename).serializer - except Exception: - from pyfory.serializer import NonExistEnumSerializer - - return NonExistEnumSerializer(resolver) - if self.type_id == TypeId.NAMED_UNION: - return resolver.get_type_info_by_name(self.namespace, self.typename).serializer - + if not is_struct_typedef_kind(self.type_id): + if is_named_typedef_kind(self.type_id): + try: + return resolver.get_type_info_by_name( + self.namespace, self.typename + ).serializer + except Exception: + if self.type_id == TypeId.NAMED_ENUM: + from pyfory.serializer import NonExistEnumSerializer + + return NonExistEnumSerializer(resolver) + raise + return resolver.get_type_info_by_id( + self.type_id, user_type_id=self.user_type_id + 
).serializer from pyfory.struct import DataClassSerializer from pyfory.struct import FieldInfo as StructFieldInfo from pyfory.type_util import get_type_hints, unwrap_optional @@ -166,9 +258,17 @@ def create_serializer(self, resolver): field_names = self._resolve_field_names_from_tag_ids() local_field_infos = build_field_infos(resolver, self.cls) - local_infos_by_name = {field_info.name: field_info for field_info in local_field_infos} - local_infos_by_tag = {field_info.tag_id: field_info for field_info in local_field_infos if field_info.tag_id >= 0} - local_field_types = infer_field_types(self.cls, field_nullable=resolver.field_nullable) + local_infos_by_name = { + field_info.name: field_info for field_info in local_field_infos + } + local_infos_by_tag = { + field_info.tag_id: field_info + for field_info in local_field_infos + if field_info.tag_id >= 0 + } + local_field_types = infer_field_types( + self.cls, field_nullable=resolver.field_nullable + ) type_hints = get_type_hints(self.cls) runtime_field_infos = [] for i, field_info in enumerate(self.fields): @@ -183,7 +283,9 @@ def create_serializer(self, resolver): local_info.field_type if local_info is not None else None, ) type_hint = type_hints.get(resolved_name, typing.Any) - unwrapped_type, _ = unwrap_optional(type_hint, field_nullable=resolver.field_nullable) + unwrapped_type, _ = unwrap_optional( + type_hint, field_nullable=resolver.field_nullable + ) serializer = _create_compatible_field_serializer( resolver, resolved_name, @@ -244,7 +346,9 @@ def _snake_to_camel(s: str) -> str: class FieldInfo: - def __init__(self, name: str, field_type: "FieldType", defined_class: str, tag_id: int = -1): + def __init__( + self, name: str, field_type: "FieldType", defined_class: str, tag_id: int = -1 + ): self.name = name self.field_type = field_type self.defined_class = defined_class @@ -308,24 +412,45 @@ def read(cls, buffer: Buffer, resolver): is_tracking_ref = (xtype_id & 0b1) != 0 is_nullable = (xtype_id & 0b10) != 0 
xtype_id = xtype_id >> 2 - return cls.read_with_type(buffer, resolver, xtype_id, is_nullable, is_tracking_ref) + return cls.read_with_type( + buffer, resolver, xtype_id, is_nullable, is_tracking_ref + ) @classmethod - def read_with_type(cls, buffer: Buffer, resolver, xtype_id: int, is_nullable: bool, is_tracking_ref: bool): + def read_with_type( + cls, + buffer: Buffer, + resolver, + xtype_id: int, + is_nullable: bool, + is_tracking_ref: bool, + ): user_type_id = NO_USER_TYPE_ID if xtype_id in [TypeId.LIST, TypeId.SET]: element_type = cls.read(buffer, resolver) - return CollectionFieldType(xtype_id, True, is_nullable, is_tracking_ref, element_type) + return CollectionFieldType( + xtype_id, True, is_nullable, is_tracking_ref, element_type + ) elif xtype_id == TypeId.MAP: key_type = cls.read(buffer, resolver) value_type = cls.read(buffer, resolver) - return MapFieldType(xtype_id, True, is_nullable, is_tracking_ref, key_type, value_type) + return MapFieldType( + xtype_id, True, is_nullable, is_tracking_ref, key_type, value_type + ) elif xtype_id == TypeId.UNKNOWN: - return DynamicFieldType(xtype_id, False, is_nullable, is_tracking_ref, user_type_id=user_type_id) + return DynamicFieldType( + xtype_id, False, is_nullable, is_tracking_ref, user_type_id=user_type_id + ) else: # For primitive types, determine if they are monomorphic based on the type is_monomorphic = not is_polymorphic_type(xtype_id) - return FieldType(xtype_id, is_monomorphic, is_nullable, is_tracking_ref, user_type_id=user_type_id) + return FieldType( + xtype_id, + is_monomorphic, + is_nullable, + is_tracking_ref, + user_type_id=user_type_id, + ) def create_serializer(self, resolver, type_): # Handle list wrapper @@ -477,7 +602,13 @@ def __init__( is_tracking_ref: bool, user_type_id: int = NO_USER_TYPE_ID, ): - super().__init__(type_id, is_monomorphic, is_nullable, is_tracking_ref, user_type_id=user_type_id) + super().__init__( + type_id, + is_monomorphic, + is_nullable, + is_tracking_ref, + 
user_type_id=user_type_id, + ) def create_serializer(self, resolver, type_): # For dynamic field types (UNKNOWN, STRUCT, etc.), default to None so @@ -486,10 +617,10 @@ def create_serializer(self, resolver, type_): # to write/read the union payload correctly. if isinstance(type_, list): type_ = type_[0] - assert not is_union_type(self.type_id), ( - "Union fields don't write field type info, \ + assert not is_union_type( + self.type_id + ), "Union fields don't write field type info, \ they are not dynamic field types" - ) if self.type_id != TypeId.UNKNOWN: return FieldType.create_serializer(self, resolver, type_) return None @@ -517,7 +648,9 @@ def __repr__(self): ) -def _payload_shape_matches(remote_field_type: FieldType, local_field_type: FieldType) -> bool: +def _payload_shape_matches( + remote_field_type: FieldType, local_field_type: FieldType +) -> bool: if local_field_type is None: return False remote_type_id = remote_field_type.type_id @@ -527,16 +660,22 @@ def _payload_shape_matches(remote_field_type: FieldType, local_field_type: Field if remote_type_id != local_type_id: return False if remote_type_id in (TypeId.LIST, TypeId.SET): - return _payload_shape_matches(remote_field_type.element_type, local_field_type.element_type) + return _payload_shape_matches( + remote_field_type.element_type, local_field_type.element_type + ) if remote_type_id == TypeId.MAP: - return _payload_shape_matches(remote_field_type.key_type, local_field_type.key_type) and _payload_shape_matches( + return _payload_shape_matches( + remote_field_type.key_type, local_field_type.key_type + ) and _payload_shape_matches( remote_field_type.value_type, local_field_type.value_type, ) return True -def _payload_shape_needs_local_carrier(remote_field_type: FieldType, local_field_type: FieldType) -> bool: +def _payload_shape_needs_local_carrier( + remote_field_type: FieldType, local_field_type: FieldType +) -> bool: remote_type_id = remote_field_type.type_id local_type_id = local_field_type.type_id 
if _is_bytes_uint8_array_pair(remote_type_id, local_type_id): @@ -546,12 +685,16 @@ def _payload_shape_needs_local_carrier(remote_field_type: FieldType, local_field if remote_type_id in _ARRAY_TYPE_IDS: return True if remote_type_id in (TypeId.LIST, TypeId.SET): - return _payload_shape_needs_local_carrier(remote_field_type.element_type, local_field_type.element_type) + return _payload_shape_needs_local_carrier( + remote_field_type.element_type, local_field_type.element_type + ) if remote_type_id == TypeId.MAP: return _payload_shape_needs_local_carrier( remote_field_type.key_type, local_field_type.key_type, - ) or _payload_shape_needs_local_carrier(remote_field_type.value_type, local_field_type.value_type) + ) or _payload_shape_needs_local_carrier( + remote_field_type.value_type, local_field_type.value_type + ) return False @@ -559,8 +702,12 @@ def _create_local_typehint_serializer(resolver, field_name, type_hint): from pyfory.struct import StructFieldSerializerVisitor from pyfory.type_util import infer_field, unwrap_optional - unwrapped_type, _ = unwrap_optional(type_hint, field_nullable=resolver.field_nullable) - return infer_field(field_name, unwrapped_type, StructFieldSerializerVisitor(resolver)) + unwrapped_type, _ = unwrap_optional( + type_hint, field_nullable=resolver.field_nullable + ) + return infer_field( + field_name, unwrapped_type, StructFieldSerializerVisitor(resolver) + ) def _create_compatible_field_serializer( @@ -571,7 +718,9 @@ def _create_compatible_field_serializer( local_field_type: typing.Optional[FieldType], local_declared_type, ): - if _payload_shape_matches(remote_field_type, local_field_type) and _payload_shape_needs_local_carrier(remote_field_type, local_field_type): + if _payload_shape_matches( + remote_field_type, local_field_type + ) and _payload_shape_needs_local_carrier(remote_field_type, local_field_type): serializer = _create_local_typehint_serializer(resolver, field_name, type_hint) if serializer is not None: return serializer @@ 
-581,7 +730,9 @@ def _create_compatible_field_serializer( _SIGNED_INT32_TYPE_IDS = frozenset((TypeId.INT32, TypeId.VARINT32)) _SIGNED_INT64_TYPE_IDS = frozenset((TypeId.INT64, TypeId.VARINT64, TypeId.TAGGED_INT64)) _UNSIGNED_INT32_TYPE_IDS = frozenset((TypeId.UINT32, TypeId.VAR_UINT32)) -_UNSIGNED_INT64_TYPE_IDS = frozenset((TypeId.UINT64, TypeId.VAR_UINT64, TypeId.TAGGED_UINT64)) +_UNSIGNED_INT64_TYPE_IDS = frozenset( + (TypeId.UINT64, TypeId.VAR_UINT64, TypeId.TAGGED_UINT64) +) _INT_TYPE_DOMAINS = {type_id: (True, 32) for type_id in _SIGNED_INT32_TYPE_IDS} _INT_TYPE_DOMAINS.update({type_id: (True, 64) for type_id in _SIGNED_INT64_TYPE_IDS}) _INT_TYPE_DOMAINS.update({type_id: (False, 32) for type_id in _UNSIGNED_INT32_TYPE_IDS}) @@ -594,20 +745,26 @@ def _create_compatible_field_serializer( } -def _requires_nullable_validation(remote_field_type: FieldType, local_field_type: FieldType) -> bool: +def _requires_nullable_validation( + remote_field_type: FieldType, local_field_type: FieldType +) -> bool: return remote_field_type.is_nullable and not local_field_type.is_nullable def _is_bytes_uint8_array_pair(remote_type_id: int, local_type_id: int) -> bool: - return (remote_type_id == TypeId.BINARY and local_type_id == TypeId.UINT8_ARRAY) or ( - remote_type_id == TypeId.UINT8_ARRAY and local_type_id == TypeId.BINARY - ) + return ( + remote_type_id == TypeId.BINARY and local_type_id == TypeId.UINT8_ARRAY + ) or (remote_type_id == TypeId.UINT8_ARRAY and local_type_id == TypeId.BINARY) -def _field_type_assignment(remote_field_type: FieldType, local_field_type: FieldType) -> typing.Tuple[bool, bool]: +def _field_type_assignment( + remote_field_type: FieldType, local_field_type: FieldType +) -> typing.Tuple[bool, bool]: if local_field_type is None: return False, False - needs_validation = _requires_nullable_validation(remote_field_type, local_field_type) + needs_validation = _requires_nullable_validation( + remote_field_type, local_field_type + ) remote_type_id = 
remote_field_type.type_id local_type_id = local_field_type.type_id if local_type_id == TypeId.UNKNOWN: @@ -633,7 +790,10 @@ def _field_type_assignment(remote_field_type: FieldType, local_field_type: Field remote_field_type.value_type, local_field_type.value_type, ) - return key_assignable and value_assignable, needs_validation or key_needs_validation or value_needs_validation + return ( + key_assignable and value_assignable, + needs_validation or key_needs_validation or value_needs_validation, + ) if _is_bytes_uint8_array_pair(remote_type_id, local_type_id): return True, True remote_int_domain = _INT_TYPE_DOMAINS.get(remote_type_id) @@ -654,7 +814,9 @@ def _field_type_assignment(remote_field_type: FieldType, local_field_type: Field def plan_field_assignment( remote_field_type: FieldType, local_field_type: typing.Optional[FieldType] ) -> typing.Tuple[bool, typing.Optional[FieldType]]: - assignable, needs_validation = _field_type_assignment(remote_field_type, local_field_type) + assignable, needs_validation = _field_type_assignment( + remote_field_type, local_field_type + ) if not assignable: return False, None return True, local_field_type if needs_validation else None @@ -698,7 +860,12 @@ def _is_uint8_array_like(value) -> bool: if isinstance(value, array.array): return value.typecode == "B" np, ndarray, uint8_dtype = _numpy_uint8_type() - return np is not None and isinstance(value, ndarray) and value.ndim == 1 and value.dtype == uint8_dtype + return ( + np is not None + and isinstance(value, ndarray) + and value.ndim == 1 + and value.dtype == uint8_dtype + ) def _bytes_from_uint8_value(value) -> bytes: @@ -714,7 +881,9 @@ def _bytes_from_uint8_value(value) -> bytes: return value.tobytes() if _is_uint8_array_like(value): return value.tobytes() - raise TypeError(f"Expected bytes or array compatible value, got {type(value)!r}") + raise TypeError( + f"Expected bytes or array compatible value, got {type(value)!r}" + ) def _uint8_array_from_bytes(value): @@ -733,12 
+902,16 @@ def is_value_assignable(value, local_field_type: FieldType) -> bool: if type_id in (TypeId.LIST, TypeId.SET): if not isinstance(value, (list, tuple, set)): return False - return all(is_value_assignable(element, local_field_type.element_type) for element in value) + return all( + is_value_assignable(element, local_field_type.element_type) + for element in value + ) if type_id == TypeId.MAP: if not isinstance(value, dict): return False return all( - is_value_assignable(key, local_field_type.key_type) and is_value_assignable(map_value, local_field_type.value_type) + is_value_assignable(key, local_field_type.key_type) + and is_value_assignable(map_value, local_field_type.value_type) for key, map_value in value.items() ) if type_id in _INT_TYPE_DOMAINS: @@ -765,12 +938,20 @@ def coerce_assignable_value(value, local_field_type: FieldType): if type_id == TypeId.UINT8_ARRAY and _is_bytes_like(value): return _uint8_array_from_bytes(value) if type_id == TypeId.LIST: - return [coerce_assignable_value(element, local_field_type.element_type) for element in value] + return [ + coerce_assignable_value(element, local_field_type.element_type) + for element in value + ] if type_id == TypeId.SET: - return {coerce_assignable_value(element, local_field_type.element_type) for element in value} + return { + coerce_assignable_value(element, local_field_type.element_type) + for element in value + } if type_id == TypeId.MAP: return { - coerce_assignable_value(key, local_field_type.key_type): coerce_assignable_value(map_value, local_field_type.value_type) + coerce_assignable_value( + key, local_field_type.key_type + ): coerce_assignable_value(map_value, local_field_type.value_type) for key, map_value in value.items() } return value @@ -808,7 +989,9 @@ def build_field_infos(type_resolver, cls): for field_name in field_names: field_type_hint = type_hints.get(field_name, typing.Any) - unwrapped_type, is_optional = unwrap_optional(field_type_hint, field_nullable=field_nullable) + 
unwrapped_type, is_optional = unwrap_optional( + field_type_hint, field_nullable=field_nullable + ) # Get field metadata if available fory_meta = field_metas.get(field_name) @@ -836,12 +1019,24 @@ def build_field_infos(type_resolver, cls): tag_id = fory_meta.id if fory_meta is not None else -1 nullable_map[field_name] = is_nullable - field_type = build_field_type_with_ref(type_resolver, field_name, unwrapped_type, visitor, is_nullable, is_tracking_ref) + field_type = build_field_type_with_ref( + type_resolver, + field_name, + unwrapped_type, + visitor, + is_nullable, + is_tracking_ref, + ) field_info = FieldInfo(field_name, field_type, cls.__name__, tag_id) field_infos.append(field_info) field_types = infer_field_types(cls, field_nullable=field_nullable) - serializers = [field_info.field_type.create_serializer(type_resolver, field_types.get(field_info.name, None)) for field_info in field_infos] + serializers = [ + field_info.field_type.create_serializer( + type_resolver, field_types.get(field_info.name, None) + ) + for field_info in field_infos + ] # Get just the field names for sorting current_field_names = [fi.name for fi in field_infos] @@ -860,7 +1055,14 @@ def build_field_infos(type_resolver, cls): return new_field_infos -def build_field_type_with_ref(type_resolver, field_name: str, type_hint, visitor, is_nullable=False, is_tracking_ref=True): +def build_field_type_with_ref( + type_resolver, + field_name: str, + type_hint, + visitor, + is_nullable=False, + is_tracking_ref=True, +): """Build field type from type hint with explicit ref tracking control.""" type_ids = infer_field(field_name, type_hint, visitor) try: @@ -874,7 +1076,9 @@ def build_field_type_with_ref(type_resolver, field_name: str, type_hint, visitor type_hint=type_hint, ) except Exception as e: - raise TypeError(f"Error building field type for field: {field_name} with type hint: {type_hint} in class: {visitor.cls}") from e + raise TypeError( + f"Error building field type for field: {field_name} 
with type hint: {type_hint} in class: {visitor.cls}" + ) from e def build_field_type_from_type_ids_with_ref( @@ -904,9 +1108,17 @@ def build_field_type_from_type_ids_with_ref( elem_nullable = False elem_ref_override = None if type_hint is not None: - origin = typing.get_origin(type_hint) if hasattr(typing, "get_origin") else getattr(type_hint, "__origin__", None) + origin = ( + typing.get_origin(type_hint) + if hasattr(typing, "get_origin") + else getattr(type_hint, "__origin__", None) + ) if origin in (list, typing.List, set, typing.Set): - args = typing.get_args(type_hint) if hasattr(typing, "get_args") else getattr(type_hint, "__args__", ()) + args = ( + typing.get_args(type_hint) + if hasattr(typing, "get_args") + else getattr(type_hint, "__args__", ()) + ) if args: elem_hint, elem_ref_override = unwrap_ref(args[0]) elem_hint, elem_nullable = unwrap_optional(elem_hint) @@ -929,7 +1141,9 @@ def build_field_type_from_type_ids_with_ref( ) if elem_ref_override is not None: elem_type.tracking_ref_override = elem_ref_override - return CollectionFieldType(type_id, morphic, is_nullable, is_tracking_ref, elem_type) + return CollectionFieldType( + type_id, morphic, is_nullable, is_tracking_ref, elem_type + ) elif type_id == TypeId.MAP: key_hint = None value_hint = None @@ -938,9 +1152,17 @@ def build_field_type_from_type_ids_with_ref( key_ref_override = None value_ref_override = None if type_hint is not None: - origin = typing.get_origin(type_hint) if hasattr(typing, "get_origin") else getattr(type_hint, "__origin__", None) + origin = ( + typing.get_origin(type_hint) + if hasattr(typing, "get_origin") + else getattr(type_hint, "__origin__", None) + ) if origin in (dict, typing.Dict): - args = typing.get_args(type_hint) if hasattr(typing, "get_args") else getattr(type_hint, "__args__", ()) + args = ( + typing.get_args(type_hint) + if hasattr(typing, "get_args") + else getattr(type_hint, "__args__", ()) + ) if len(args) >= 2: key_hint, key_ref_override = 
unwrap_ref(args[0]) key_hint, key_nullable = unwrap_optional(key_hint) @@ -974,7 +1196,9 @@ def build_field_type_from_type_ids_with_ref( key_type.tracking_ref_override = key_ref_override if value_ref_override is not None: value_type.tracking_ref_override = value_ref_override - return MapFieldType(type_id, morphic, is_nullable, is_tracking_ref, key_type, value_type) + return MapFieldType( + type_id, morphic, is_nullable, is_tracking_ref, key_type, value_type + ) elif type_id in [ TypeId.UNKNOWN, TypeId.EXT, @@ -983,24 +1207,41 @@ def build_field_type_from_type_ids_with_ref( TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT, ]: - return DynamicFieldType(type_id, False, is_nullable, is_tracking_ref, user_type_id=NO_USER_TYPE_ID) + return DynamicFieldType( + type_id, False, is_nullable, is_tracking_ref, user_type_id=NO_USER_TYPE_ID + ) else: if type_id <= 0 or type_id >= TypeId.BOUND: raise TypeError(f"Unknown type: {type_id} for field: {field_name}") # union/enum go here too - return FieldType(type_id, morphic, is_nullable, is_tracking_ref, user_type_id=NO_USER_TYPE_ID) + return FieldType( + type_id, morphic, is_nullable, is_tracking_ref, user_type_id=NO_USER_TYPE_ID + ) -def build_field_type(type_resolver, field_name: str, type_hint, visitor, is_nullable=False): +def build_field_type( + type_resolver, field_name: str, type_hint, visitor, is_nullable=False +): """Build field type from type hint.""" type_ids = infer_field(field_name, type_hint, visitor) try: - return build_field_type_from_type_ids(type_resolver, field_name, type_ids, visitor, is_nullable, type_hint=type_hint) + return build_field_type_from_type_ids( + type_resolver, + field_name, + type_ids, + visitor, + is_nullable, + type_hint=type_hint, + ) except Exception as e: - raise TypeError(f"Error building field type for field: {field_name} with type hint: {type_hint} in class: {visitor.cls}") from e + raise TypeError( + f"Error building field type for field: {field_name} with type hint: {type_hint} 
in class: {visitor.cls}" + ) from e -def build_field_type_from_type_ids(type_resolver, field_name: str, type_ids, visitor, is_nullable=False, type_hint=None): +def build_field_type_from_type_ids( + type_resolver, field_name: str, type_ids, visitor, is_nullable=False, type_hint=None +): from pyfory.type_util import unwrap_optional, unwrap_ref tracking_ref = type_resolver.track_ref @@ -1017,9 +1258,17 @@ def build_field_type_from_type_ids(type_resolver, field_name: str, type_ids, vis elem_hint = None elem_nullable = False if type_hint is not None: - origin = typing.get_origin(type_hint) if hasattr(typing, "get_origin") else getattr(type_hint, "__origin__", None) + origin = ( + typing.get_origin(type_hint) + if hasattr(typing, "get_origin") + else getattr(type_hint, "__origin__", None) + ) if origin in (list, typing.List, set, typing.Set): - args = typing.get_args(type_hint) if hasattr(typing, "get_args") else getattr(type_hint, "__args__", ()) + args = ( + typing.get_args(type_hint) + if hasattr(typing, "get_args") + else getattr(type_hint, "__args__", ()) + ) if args: elem_hint, _ = unwrap_ref(args[0]) elem_hint, elem_nullable = unwrap_optional(elem_hint) @@ -1036,16 +1285,26 @@ def build_field_type_from_type_ids(type_resolver, field_name: str, type_ids, vis is_nullable=elem_nullable, type_hint=elem_hint, ) - return CollectionFieldType(type_id, morphic, is_nullable, tracking_ref, elem_type) + return CollectionFieldType( + type_id, morphic, is_nullable, tracking_ref, elem_type + ) elif type_id == TypeId.MAP: key_hint = None value_hint = None key_nullable = False value_nullable = False if type_hint is not None: - origin = typing.get_origin(type_hint) if hasattr(typing, "get_origin") else getattr(type_hint, "__origin__", None) + origin = ( + typing.get_origin(type_hint) + if hasattr(typing, "get_origin") + else getattr(type_hint, "__origin__", None) + ) if origin in (dict, typing.Dict): - args = typing.get_args(type_hint) if hasattr(typing, "get_args") else 
getattr(type_hint, "__args__", ()) + args = ( + typing.get_args(type_hint) + if hasattr(typing, "get_args") + else getattr(type_hint, "__args__", ()) + ) if len(args) >= 2: key_hint, _ = unwrap_ref(args[0]) key_hint, key_nullable = unwrap_optional(key_hint) @@ -1067,7 +1326,9 @@ def build_field_type_from_type_ids(type_resolver, field_name: str, type_ids, vis is_nullable=value_nullable, type_hint=value_hint, ) - return MapFieldType(type_id, morphic, is_nullable, tracking_ref, key_type, value_type) + return MapFieldType( + type_id, morphic, is_nullable, tracking_ref, key_type, value_type + ) elif type_id in [ TypeId.UNKNOWN, TypeId.EXT, @@ -1076,8 +1337,12 @@ def build_field_type_from_type_ids(type_resolver, field_name: str, type_ids, vis TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT, ]: - return DynamicFieldType(type_id, False, is_nullable, tracking_ref, user_type_id=NO_USER_TYPE_ID) + return DynamicFieldType( + type_id, False, is_nullable, tracking_ref, user_type_id=NO_USER_TYPE_ID + ) else: if type_id <= 0 or type_id >= TypeId.BOUND: raise TypeError(f"Unknown type: {type_id} for field: {field_name}") - return FieldType(type_id, morphic, is_nullable, tracking_ref, user_type_id=NO_USER_TYPE_ID) + return FieldType( + type_id, morphic, is_nullable, tracking_ref, user_type_id=NO_USER_TYPE_ID + ) diff --git a/python/pyfory/meta/typedef_decoder.py b/python/pyfory/meta/typedef_decoder.py index b763519ad6..6a5670f4f4 100644 --- a/python/pyfory/meta/typedef_decoder.py +++ b/python/pyfory/meta/typedef_decoder.py @@ -28,16 +28,23 @@ from pyfory.meta.typedef import ( SMALL_NUM_FIELDS_THRESHOLD, REGISTER_BY_NAME_FLAG, + COMPATIBLE_TYPEDEF_FLAG, + STRUCT_TYPEDEF_FLAG, FIELD_NAME_SIZE_THRESHOLD, BIG_NAME_THRESHOLD, COMPRESS_META_FLAG, - HAS_FIELDS_META_FLAG, + RESERVED_META_FLAGS, META_SIZE_MASKS, + TYPEDEF_HASH_MASK, FIELD_NAME_ENCODINGS, NAMESPACE_ENCODINGS, TYPE_NAME_ENCODINGS, FIELD_NAME_ENCODING_TAG_ID, TAG_ID_SIZE_THRESHOLD, + is_struct_typedef_kind, + 
is_named_typedef_kind, + xlang_non_struct_type_id, + _typedef_header_hash, ) from pyfory.types import TypeId from pyfory._fory import NO_USER_TYPE_ID @@ -86,20 +93,20 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: header = buffer.read_int64() # Extract components from header + if header & RESERVED_META_FLAGS: + raise ValueError("Invalid TypeDef global header") meta_size = header & META_SIZE_MASKS - has_fields_meta = (header & HAS_FIELDS_META_FLAG) != 0 is_compressed = (header & COMPRESS_META_FLAG) != 0 + if is_compressed: + raise ValueError("Compressed xlang TypeDef is not supported") # If meta size is at maximum, read additional size if meta_size == META_SIZE_MASKS: meta_size += buffer.read_var_uint32() # Read meta data - meta_data = buffer.read_bytes(meta_size) - - # Decompress if needed - if is_compressed: - meta_data = resolver.get_meta_compressor().decompress(meta_data) + encoded_meta_data = buffer.read_bytes(meta_size) + meta_data = encoded_meta_data # Create a new buffer for meta data meta_buffer = Buffer(meta_data) @@ -107,22 +114,34 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: # Read meta header meta_header = meta_buffer.read_uint8() - # Extract number of fields - num_fields = meta_header & 0b11111 - if num_fields == SMALL_NUM_FIELDS_THRESHOLD: - num_fields += meta_buffer.read_var_uint32() - - # Check field count limit - if num_fields > MAX_FIELDS_PER_CLASS: - raise ValueError( - f"Class has {num_fields} fields, exceeding the maximum allowed {MAX_FIELDS_PER_CLASS} fields. This may indicate malicious data." 
- ) - - # Check if registered by name - is_registered_by_name = (meta_header & REGISTER_BY_NAME_FLAG) != 0 + is_struct = (meta_header & STRUCT_TYPEDEF_FLAG) != 0 + num_fields = 0 + is_registered_by_name = False type_cls = None user_type_id = NO_USER_TYPE_ID + if is_struct: + is_registered_by_name = (meta_header & REGISTER_BY_NAME_FLAG) != 0 + compatible = (meta_header & COMPATIBLE_TYPEDEF_FLAG) != 0 + if is_registered_by_name: + type_id = ( + TypeId.NAMED_COMPATIBLE_STRUCT if compatible else TypeId.NAMED_STRUCT + ) + else: + type_id = TypeId.COMPATIBLE_STRUCT if compatible else TypeId.STRUCT + num_fields = meta_header & SMALL_NUM_FIELDS_THRESHOLD + if num_fields == SMALL_NUM_FIELDS_THRESHOLD: + num_fields += meta_buffer.read_var_uint32() + if num_fields > MAX_FIELDS_PER_CLASS: + raise ValueError( + f"Class has {num_fields} fields, exceeding the maximum allowed {MAX_FIELDS_PER_CLASS} fields." + ) + else: + if meta_header & 0b01110000: + raise ValueError("Invalid TypeDef kind header") + type_id = xlang_non_struct_type_id(meta_header & 0b1111) + is_registered_by_name = is_named_typedef_kind(type_id) + # Read type info if is_registered_by_name: namespace = read_namespace(meta_buffer) @@ -130,13 +149,10 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: # Look up the type_id from namespace and typename type_info = resolver.get_type_info_by_name(namespace, typename) if type_info: - type_id = type_info.type_id + if type_info.type_id != type_id: + raise ValueError("TypeDef kind does not match registered type metadata") type_cls = type_info.cls - else: - # Fallback to COMPATIBLE_STRUCT if not found - type_id = TypeId.COMPATIBLE_STRUCT else: - type_id = meta_buffer.read_uint8() user_type_id = meta_buffer.read_var_uint32() if resolver.is_registered_by_id(type_id=type_id, user_type_id=user_type_id): type_info = resolver.get_type_info_by_id(type_id, user_type_id=user_type_id) @@ -147,12 +163,17 @@ def decode_typedef(buffer: Buffer, resolver, header=None) 
-> TypeDef: namespace = "fory" typename = f"UnknownStruct{user_type_id if user_type_id != NO_USER_TYPE_ID else type_id}" name = namespace + "." + typename if namespace else typename - # Read fields info if present - field_infos = [] - if has_fields_meta: - field_infos = read_fields_info(meta_buffer, resolver, name, num_fields) - if type_cls is None: - if getattr(resolver, "strict", False) and not getattr(resolver, "_allow_unregistered_typedef", False): + + field_infos = read_fields_info(meta_buffer, resolver, name, num_fields) + if not is_struct and field_infos: + raise ValueError("Non-struct TypeDef cannot carry field metadata") + if meta_buffer.get_reader_index() != meta_buffer.size(): + raise ValueError("Invalid TypeDef metadata size") + _validate_parsed_typedef_hash(header, encoded_meta_data) + if type_cls is None and is_struct_typedef_kind(type_id): + if getattr(resolver, "strict", False) and not getattr( + resolver, "_allow_unregistered_typedef", False + ): raise ValueError(f"TypeDef {name} is not registered in strict mode") # Check generated class count limit if _generated_class_count >= MAX_GENERATED_CLASSES: @@ -168,9 +189,11 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: type_cls = make_dataclass(class_name, field_definitions) policy = getattr(resolver, "policy", None) if policy is not None: - result = policy.validate_class(type_cls, is_local=False) + result = policy.validate_class(type_cls, is_local=True) if result is not None: type_cls = result + elif type_cls is None: + raise ValueError(f"TypeDef {name} is not registered") # Create TypeDef object type_def = TypeDef( @@ -186,6 +209,11 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: return type_def +def _validate_parsed_typedef_hash(header: int, encoded_meta_data: bytes) -> None: + if _typedef_header_hash(encoded_meta_data) != (header & TYPEDEF_HASH_MASK): + raise ValueError("Invalid TypeDef metadata hash") + + def read_namespace(buffer: Buffer) -> str: 
"""Read namespace from the buffer.""" return read_meta_string(buffer, NAMESPACE_DECODER, NAMESPACE_ENCODINGS) @@ -196,7 +224,9 @@ def read_typename(buffer: Buffer) -> str: return read_meta_string(buffer, TYPENAME_DECODER, TYPE_NAME_ENCODINGS) -def read_meta_string(buffer: Buffer, decoder: MetaStringDecoder, encodings: List[Encoding]) -> str: +def read_meta_string( + buffer: Buffer, decoder: MetaStringDecoder, encodings: List[Encoding] +) -> str: """Read a big meta string (namespace/typename) from the buffer using 6-bit size field.""" # Read encoding and length combined in first byte header = buffer.read_uint8() @@ -222,7 +252,9 @@ def read_meta_string(buffer: Buffer, decoder: MetaStringDecoder, encodings: List return "" -def read_fields_info(buffer: Buffer, resolver, defined_class: str, num_fields: int) -> List[FieldInfo]: +def read_fields_info( + buffer: Buffer, resolver, defined_class: str, num_fields: int +) -> List[FieldInfo]: """Read field information from the buffer.""" field_infos = [] for _ in range(num_fields): @@ -268,7 +300,9 @@ def read_field_info(buffer: Buffer, resolver, defined_class: str) -> FieldInfo: # Read field type info (no field name to read for TAG_ID) xtype_id = buffer.read_uint8() - field_type = FieldType.read_with_type(buffer, resolver, xtype_id, is_nullable, is_tracking_ref) + field_type = FieldType.read_with_type( + buffer, resolver, xtype_id, is_nullable, is_tracking_ref + ) # For TAG_ID encoding, use tag_id as field name placeholder field_name = f"__tag_{tag_id}__" @@ -283,7 +317,9 @@ def read_field_info(buffer: Buffer, resolver, defined_class: str) -> FieldInfo: # Read field type info BEFORE field name (matching Java TypeDefDecoder order) xtype_id = buffer.read_uint8() - field_type = FieldType.read_with_type(buffer, resolver, xtype_id, is_nullable, is_tracking_ref) + field_type = FieldType.read_with_type( + buffer, resolver, xtype_id, is_nullable, is_tracking_ref + ) # Read field name meta string # Keep the wire field name as-is; 
TypeDef._resolve_field_names_from_tag_ids() diff --git a/python/pyfory/meta/typedef_encoder.py b/python/pyfory/meta/typedef_encoder.py index 7cfb348896..4877d73cbf 100644 --- a/python/pyfory/meta/typedef_encoder.py +++ b/python/pyfory/meta/typedef_encoder.py @@ -22,23 +22,26 @@ build_field_infos, SMALL_NUM_FIELDS_THRESHOLD, REGISTER_BY_NAME_FLAG, + COMPATIBLE_TYPEDEF_FLAG, + STRUCT_TYPEDEF_FLAG, FIELD_NAME_SIZE_THRESHOLD, BIG_NAME_THRESHOLD, COMPRESS_META_FLAG, - HAS_FIELDS_META_FLAG, META_SIZE_MASKS, - NUM_HASH_BITS, FIELD_NAME_ENCODINGS, NAMESPACE_ENCODINGS, TYPE_NAME_ENCODINGS, FIELD_NAME_ENCODING_TAG_ID, TAG_ID_SIZE_THRESHOLD, + is_struct_typedef_kind, + xlang_non_struct_kind_code, + _typedef_header_hash, ) from pyfory.meta.metastring import MetaStringEncoder from pyfory._fory import NO_USER_TYPE_ID +from pyfory.types import TypeId from pyfory.serialization import Buffer -from pyfory.lib.mmh3 import hash_buffer # Meta string encoders @@ -58,35 +61,43 @@ def encode_typedef(type_resolver, cls, include_fields: bool = True): Returns: The encoded TypeDef. 
""" - if include_fields: + type_id, user_type_id = type_resolver.get_registered_type_ids(cls) + if include_fields and is_struct_typedef_kind(type_id): field_infos = build_field_infos(type_resolver, cls) else: field_infos = [] buffer = Buffer.allocate(64) - # Write meta header - header = len(field_infos) - if len(field_infos) >= SMALL_NUM_FIELDS_THRESHOLD: - header = SMALL_NUM_FIELDS_THRESHOLD + # Write kind header + if is_struct_typedef_kind(type_id): + num_fields = len(field_infos) + header = STRUCT_TYPEDEF_FLAG | min(num_fields, SMALL_NUM_FIELDS_THRESHOLD) + if type_id in {TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT}: + header |= COMPATIBLE_TYPEDEF_FLAG if type_resolver.is_registered_by_name(cls): header |= REGISTER_BY_NAME_FLAG - buffer.write_uint8(header) - buffer.write_var_uint32(len(field_infos) - SMALL_NUM_FIELDS_THRESHOLD) + if num_fields >= SMALL_NUM_FIELDS_THRESHOLD: + buffer.write_uint8(header) + buffer.write_var_uint32(num_fields - SMALL_NUM_FIELDS_THRESHOLD) + else: + buffer.write_uint8(header) else: - if type_resolver.is_registered_by_name(cls): - header |= REGISTER_BY_NAME_FLAG - buffer.write_uint8(header) + if field_infos: + raise ValueError( + f"Non-struct TypeDef {type_id} cannot carry field metadata" + ) + buffer.write_uint8(xlang_non_struct_kind_code(type_id)) # Write type info - type_id, user_type_id = type_resolver.get_registered_type_ids(cls) if type_resolver.is_registered_by_name(cls): namespace, typename = type_resolver.get_registered_name(cls) write_namespace(buffer, namespace) write_typename(buffer, typename) else: - assert type_resolver.is_registered_by_id(cls=cls), "Class must be registered by name or id" - buffer.write_uint8(type_id) + assert type_resolver.is_registered_by_id( + cls=cls + ), "Class must be registered by name or id" if user_type_id in {None, NO_USER_TYPE_ID}: raise ValueError(f"user_type_id required for type_id {type_id}") buffer.write_var_uint32(user_type_id) @@ -97,12 +108,9 @@ def 
encode_typedef(type_resolver, cls, include_fields: bool = True): # Get the encoded binary (only the written portion, not the full buffer) binary = buffer.to_bytes(0, buffer.get_writer_index()) - # Temporary xlang behavior: always write TypeDef metadata uncompressed. - # Some runtimes still do not support TypeMeta decompression, so keep the - # xlang wire payload uncompressed until all xlang implementations support it. is_compressed = False # Prepend header - binary = prepend_header(binary, is_compressed, len(field_infos) > 0) + binary = prepend_header(binary, is_compressed) # Extract namespace and typename if type_resolver.is_registered_by_name(cls): namespace, typename = type_resolver.get_registered_name(cls) @@ -112,23 +120,29 @@ def encode_typedef(type_resolver, cls, include_fields: bool = True): splits.insert(0, "") namespace, typename = splits - result = TypeDef(namespace, typename, cls, type_id, field_infos, binary, is_compressed, user_type_id=user_type_id) + result = TypeDef( + namespace, + typename, + cls, + type_id, + field_infos, + binary, + is_compressed, + user_type_id=user_type_id, + ) return result -def prepend_header(buffer: bytes, is_compressed: bool, has_fields_meta: bool): +def prepend_header(buffer: bytes, is_compressed: bool): """Prepend header to the buffer.""" meta_size = len(buffer) - hash = hash_buffer(buffer, 47)[0] - hash <<= 64 - NUM_HASH_BITS - header = abs(hash) & 0x7FFFFFFFFFFFFFFF # Ensure it fits in 63 bits + header = _typedef_header_hash(buffer) if is_compressed: header |= COMPRESS_META_FLAG - if has_fields_meta: - header |= HAS_FIELDS_META_FLAG - header |= min(meta_size, META_SIZE_MASKS) + if header >= (1 << 63): + header -= 1 << 64 result = Buffer.allocate(meta_size + 8) result.write_int64(header) if meta_size >= META_SIZE_MASKS: @@ -146,7 +160,9 @@ def write_namespace(buffer: Buffer, namespace: str): # The `6 bits size: 0~63` will be used to indicate size `0~62`, # the value `63` the size need more byte to read, the encoding will 
encode `size - 62` as a varint next. meta_string = NAMESPACE_ENCODER.encode(namespace, NAMESPACE_ENCODINGS) - write_meta_string(buffer, meta_string, NAMESPACE_ENCODINGS.index(meta_string.encoding)) + write_meta_string( + buffer, meta_string, NAMESPACE_ENCODINGS.index(meta_string.encoding) + ) def write_typename(buffer: Buffer, typename: str): @@ -158,7 +174,9 @@ def write_typename(buffer: Buffer, typename: str): # The `6 bits size: 0~63` will be used to indicate size `1~64`, # the value `63` the size need more byte to read, the encoding will encode `size - 63` as a varint next. meta_string = TYPENAME_ENCODER.encode(typename, TYPE_NAME_ENCODINGS) - write_meta_string(buffer, meta_string, TYPE_NAME_ENCODINGS.index(meta_string.encoding)) + write_meta_string( + buffer, meta_string, TYPE_NAME_ENCODINGS.index(meta_string.encoding) + ) def write_meta_string(buffer: Buffer, meta_string, encoding_value: int): @@ -230,7 +248,9 @@ def write_field_info(buffer: Buffer, field_info: FieldInfo): field_info.field_type.write(buffer, False) else: # Field name encoding - encoding = FIELD_NAME_ENCODER.compute_encoding(field_info.name, FIELD_NAME_ENCODINGS) + encoding = FIELD_NAME_ENCODER.compute_encoding( + field_info.name, FIELD_NAME_ENCODINGS + ) meta_string = FIELD_NAME_ENCODER.encode_with_encoding(field_info.name, encoding) # Store (length - 1) in size field, matching Java TypeDefEncoder field_name_binary_size = len(meta_string.encoded_data) - 1 diff --git a/python/pyfory/registry.py b/python/pyfory/registry.py index bdb270aea6..50a869319e 100644 --- a/python/pyfory/registry.py +++ b/python/pyfory/registry.py @@ -156,6 +156,8 @@ logger = logging.getLogger(__name__) namespace_decoder = MetaStringDecoder(".", "_") typename_decoder = MetaStringDecoder("$", "_") +MAX_CACHED_TYPE_DEFS = 8192 +MAX_CACHED_ENCODED_META_STRINGS = 8192 _NO_REF_NUMERIC_TYPE_IDS = frozenset( { @@ -313,7 +315,8 @@ def get_or_create_encoded_meta_string(self, data: bytes, hashcode: int) -> Encod 
encoded_meta_string = self._encoded_metastrings.get(key) if encoded_meta_string is None: encoded_meta_string = EncodedMetaString(data, hashcode) - self._encoded_metastrings[key] = encoded_meta_string + if len(self._encoded_metastrings) < MAX_CACHED_ENCODED_META_STRINGS: + self._encoded_metastrings[key] = encoded_meta_string return encoded_meta_string @@ -374,7 +377,6 @@ def __init__(self, config, *, shared_registry): self._named_type_to_type_info = dict() self.namespace_encoder = MetaStringEncoder(".", "_") self.namespace_decoder = MetaStringDecoder(".", "_") - # Cache for TypeDef and TypeInfo tuples (similar to Java's classIdToDef) self._meta_shared_type_info = {} self.typename_encoder = MetaStringEncoder("$", "_") self.typename_decoder = MetaStringDecoder("$", "_") @@ -1157,12 +1159,14 @@ def _read_and_build_type_info(self, buffer): """ # Read the header (first 8 bytes) to get the type ID header = buffer.read_int64() - # Check if we already have this TypeDef cached type_info = self._meta_shared_type_info.get(header) if type_info is not None: + # Header-cache hits intentionally skip without rehashing. Entries reach this cache only + # after a successful TypeDef parse and 52-bit body-hash validation. 
skip_typedef(buffer, header) - else: - type_def = decode_typedef(buffer, self, header=header) - type_info = self._build_type_info_from_typedef(type_def) + return type_info + type_def = decode_typedef(buffer, self, header=header) + type_info = self._build_type_info_from_typedef(type_def) + if len(self._meta_shared_type_info) < MAX_CACHED_TYPE_DEFS: self._meta_shared_type_info[header] = type_info return type_info diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 683025a955..954afb6f22 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -34,6 +34,8 @@ from cpython.dict cimport PyDict_Next from cpython.list cimport PyList_New, PyList_SET_ITEM from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM from cpython.ref cimport Py_INCREF, Py_XDECREF +from cpython.bytes cimport PyBytes_GET_SIZE +from libc.string cimport memcmp from pyfory.includes.libflat_hash_map cimport flat_hash_map from pyfory.includes.libutil cimport FlatIntMap from pyfory._fory import ( @@ -79,6 +81,7 @@ ENABLE_FORY_CYTHON_SERIALIZATION = os.environ.get( cdef int32_t NOT_NULL_BOOL_FLAG = (NOT_NULL_VALUE_FLAG & 0xFF) | (TypeId.BOOL << 8) cdef int32_t NOT_NULL_STRING_FLAG = (NOT_NULL_VALUE_FLAG & 0xFF) | (TypeId.STRING << 8) cdef int32_t NOT_NULL_FLOAT64_FLAG = (NOT_NULL_VALUE_FLAG & 0xFF) | (TypeId.FLOAT64 << 8) +cdef int32_t MAX_CACHED_TYPE_DEFS = 8192 _PRIMITIVE_TYPEVAR_NAMES = frozenset( { @@ -492,6 +495,7 @@ cdef class TypeResolver: self._c_types_info[ typeinfo.cls] = typeinfo if self._c_types_info.size() * 10 >= self._c_types_info.bucket_count() * 5: self._c_types_info.rehash(self._c_types_info.size() * 2) + if typeinfo.typename_bytes is not None: self._c_meta_hash_to_type_info[ pair[int64_t, int64_t]( @@ -556,28 +560,35 @@ cdef class TypeResolver: cdef TypeInfo typeinfo = self._meta_shared_type_info.get(header) cdef object type_def if typeinfo is not None: + # Header-cache hits intentionally skip without rehashing. 
Entries reach this cache only + # after a successful TypeDef parse and 52-bit body-hash validation. _skip_typedef_fast(buffer, header) return typeinfo type_def = decode_typedef(buffer, self.resolver, header=header) typeinfo = self.resolver._build_type_info_from_typedef(type_def) - self._meta_shared_type_info[header] = typeinfo + if len(self._meta_shared_type_info) < MAX_CACHED_TYPE_DEFS: + self._meta_shared_type_info[header] = typeinfo return typeinfo - cdef inline TypeInfo _load_bytes_to_type_info(self, object ns_metabytes, object type_metabytes): + cdef inline TypeInfo _load_bytes_to_type_info( + self, object ns_metabytes, object type_metabytes + ): cdef pair[int64_t, int64_t] hash_key = pair[int64_t, int64_t]( ns_metabytes.hashcode, type_metabytes.hashcode, ) - cdef pair[pair[int64_t, int64_t], PyObject *] *entry = self._c_meta_hash_to_type_info.find(hash_key) + cdef pair[pair[int64_t, int64_t], PyObject *] *entry = ( + self._c_meta_hash_to_type_info.find(hash_key) + ) cdef TypeInfo typeinfo if entry != NULL and deref(entry).second != NULL: return deref(entry).second typeinfo = self.resolver._load_metabytes_to_type_info(ns_metabytes, type_metabytes) - self._c_meta_hash_to_type_info[ - hash_key - ] = typeinfo + if self._c_meta_hash_to_type_info.size() < MAX_CACHED_TYPE_DEFS: + self._c_meta_hash_to_type_info[hash_key] = typeinfo return typeinfo + cdef inline void _skip_typedef_fast(Buffer buffer, int64_t header): cdef int32_t meta_size = (header & 0xFF) cdef int32_t reader_index @@ -591,6 +602,7 @@ cdef inline void _skip_typedef_fast(Buffer buffer, int64_t header): buffer.set_reader_index(reader_index + meta_size) + namespace_decoder = MetaStringDecoder(".", "_") typename_decoder = MetaStringDecoder("$", "_") @@ -998,6 +1010,7 @@ cdef class Fory: cdef Buffer _serialize(self, obj, Buffer buffer=None, buffer_callback=None, unsupported_callback=None): cdef WriteContext write_context = self.write_context cdef int32_t mask_index + cdef uint8_t bitmap if buffer is None: 
self.buffer.set_writer_index(0) buffer = self.buffer @@ -1011,16 +1024,10 @@ cdef class Fory: mask_index = buffer.get_writer_index() buffer.grow(1) buffer.set_writer_index(mask_index + 1) - buffer.put_int8(mask_index, 0) - if obj is None: - set_bit(buffer, mask_index, 0) - else: - clear_bit(buffer, mask_index, 0) - set_bit(buffer, mask_index, 1) + bitmap = 1 if self.xlang else 0 if buffer_callback is not None: - set_bit(buffer, mask_index, 2) - else: - clear_bit(buffer, mask_index, 2) + bitmap |= 2 + buffer.put_int8(mask_index, bitmap) write_context.write_ref(obj) return buffer @@ -1038,15 +1045,19 @@ cdef class Fory: cdef ReadContext read_context = self.read_context cdef Buffer read_buffer cdef int32_t reader_index + cdef uint8_t bitmap cdef bint peer_out_of_band_enabled if isinstance(buffer, bytes): buffer = Buffer(buffer, max_binary_size=self.max_binary_size) read_buffer = buffer reader_index = read_buffer.get_reader_index() read_buffer.set_reader_index(reader_index + 1) - if get_bit(read_buffer, reader_index, 0): - return None - peer_out_of_band_enabled = get_bit(read_buffer, reader_index, 2) + bitmap = read_buffer.get_int8(reader_index) + if bitmap & 0xFC: + raise ValueError(f"Unsupported root header bitmap 0x{bitmap:02x}") + if ((bitmap & 1) != 0) != self.xlang: + raise ValueError("Header bitmap mismatch at xlang bit") + peer_out_of_band_enabled = (bitmap & 2) != 0 if peer_out_of_band_enabled: assert buffers is not None, ( "buffers shouldn't be null when the serialized stream is produced with buffer_callback not null." 
diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index 55ccd473ae..b4d9732454 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -51,7 +51,9 @@ def _import_validated_module(policy, module_name): if result is not None: if isinstance(result, types.ModuleType): return result - assert isinstance(result, str), f"validate_module must return module, str, or None, got {type(result)}" + assert isinstance( + result, str + ), f"validate_module must return module, str, or None, got {type(result)}" module_name = result return importlib.import_module(module_name) @@ -72,7 +74,43 @@ def _check_collection_size(read_context, size, kind): if size < 0: raise ValueError(f"{kind} size {size} must be non-negative") if size > read_context.max_collection_size: - raise ValueError(f"{kind} size {size} exceeds the configured limit of {read_context.max_collection_size}") + raise ValueError( + f"{kind} size {size} exceeds the configured limit of {read_context.max_collection_size}" + ) + + +def _is_local_qualname(module_name, qualname): + return module_name == "__main__" or "" in qualname + + +def _is_local_class(cls): + return _is_local_qualname(cls.__module__, cls.__qualname__) + + +def _is_local_receiver(obj): + cls = obj if isinstance(obj, type) else obj.__class__ + return _is_local_class(cls) + + +def _is_local_callable(obj): + if isinstance(obj, type): + return _is_local_class(obj) + if isinstance(obj, (types.MethodType, types.BuiltinMethodType)): + receiver = getattr(obj, "__self__", None) + if receiver is not None and not inspect.ismodule(receiver): + return _is_local_receiver(receiver) + module_name = getattr(obj, "__module__", "") + qualname = getattr(obj, "__qualname__", getattr(obj, "__name__", "")) + return _is_local_qualname(module_name, qualname) + + +def _is_bound_method_value(obj): + if isinstance(obj, types.MethodType): + return True + if isinstance(obj, types.BuiltinMethodType): + receiver = getattr(obj, "__self__", None) + 
return receiver is not None and not inspect.ismodule(receiver) + return False def _validate_function_value(policy, func, is_local): @@ -81,7 +119,14 @@ def _validate_function_value(policy, func, is_local): if result is not None: func = result if isinstance(func, type): - raise TypeError(f"Function serializer resolved class {func.__module__}.{func.__qualname__}") + raise TypeError( + f"Function serializer resolved class {func.__module__}.{func.__qualname__}" + ) + if _is_bound_method_value(func): + result = policy.validate_method(func, is_local=is_local) + if result is not None: + func = result + return func if not callable(func): raise TypeError(f"Function serializer resolved non-callable object {func!r}") result = policy.validate_function(func, is_local=is_local) @@ -336,7 +381,9 @@ def _write_decimal_parts(write_context, scale: int, unscaled: int): magnitude = abs(unscaled) if magnitude == 0: raise ValueError("Zero must use the small decimal encoding") - payload = magnitude.to_bytes((magnitude.bit_length() + 7) // 8, "little", signed=False) + payload = magnitude.to_bytes( + (magnitude.bit_length() + 7) // 8, "little", signed=False + ) meta = (len(payload) << 1) | (1 if unscaled < 0 else 0) _write_var_uint64(write_context, (meta << 1) | 1) write_context.write_bytes(payload) @@ -507,7 +554,10 @@ def _build_pyarray_typecode_tables(): class PyArraySerializer(Serializer): typecode_dict = typecode_dict - typecodearray_type = {typecode: ftype for typecode, (_itemsize, ftype, _type_id) in typecode_dict.items()} + typecodearray_type = { + typecode: ftype + for typecode, (_itemsize, ftype, _type_id) in typecode_dict.items() + } def __init__(self, type_resolver, ftype, type_id: str): super().__init__(type_resolver, ftype) @@ -520,13 +570,17 @@ def _array_type_id(self, value): raise TypeError(f"Unsupported array.array typecode {value.typecode!r}") itemsize, _ftype, type_id = entry if value.itemsize != itemsize: - raise TypeError(f"array.array typecode {value.typecode!r} has 
itemsize {value.itemsize}, expected {itemsize}") + raise TypeError( + f"array.array typecode {value.typecode!r} has itemsize {value.itemsize}, expected {itemsize}" + ) return type_id def write(self, buffer, value): actual_type_id = self._array_type_id(value) if actual_type_id != self.type_id: - raise TypeError(f"array.array typecode {value.typecode!r} maps to type id {actual_type_id}, expected {self.type_id}") + raise TypeError( + f"array.array typecode {value.typecode!r} maps to type id {actual_type_id}, expected {self.type_id}" + ) view = memoryview(value) assert view.itemsize == self.itemsize assert view.c_contiguous # TODO handle contiguous @@ -592,7 +646,9 @@ def fory_array_serializer_type(type_id): class ForyArrayListAdapterSerializer(Serializer): - def __init__(self, type_resolver, wrapper_type, wrapper_serializer, field_name=None): + def __init__( + self, type_resolver, wrapper_type, wrapper_serializer, field_name=None + ): super().__init__(type_resolver, wrapper_type) self.wrapper_type = wrapper_type self.wrapper_serializer = wrapper_serializer @@ -601,13 +657,17 @@ def __init__(self, type_resolver, wrapper_type, wrapper_serializer, field_name=N def _copy_list_to_wrapper(self, value): if type(value) is not list: - raise TypeError(f"pyfory.Array list adapter for {self.field_name!r} requires list, got {type(value)!r}") + raise TypeError( + f"pyfory.Array list adapter for {self.field_name!r} requires list, got {type(value)!r}" + ) wrapper = self.wrapper_type() for index, item in enumerate(value): try: wrapper.append(item) except (TypeError, ValueError, OverflowError) as exc: - raise type(exc)(f"{self.field_name}[{index}] invalid for {self.wrapper_type.__name__}: {exc}") from exc + raise type(exc)( + f"{self.field_name}[{index}] invalid for {self.wrapper_type.__name__}: {exc}" + ) from exc return wrapper def write(self, buffer, value): @@ -645,7 +705,12 @@ def _build_pyarray_serializer(self, type_resolver, type_id): def _build_ndarray_serializer(self, 
type_resolver, type_id): if np is None: return None - for dtype, (_itemsize, _format, ftype, dtype_type_id) in Numpy1DArraySerializer.dtypes_dict.items(): + for dtype, ( + _itemsize, + _format, + ftype, + dtype_type_id, + ) in Numpy1DArraySerializer.dtypes_dict.items(): if dtype_type_id == type_id: return Numpy1DArraySerializer(type_resolver, ftype, dtype) return None @@ -660,7 +725,9 @@ def write(self, buffer, value): return if value_type is array.array: if self.pyarray_serializer is None: - raise TypeError(f"pyfory.Array field {self.field_name!r} does not support array.array for this element type") + raise TypeError( + f"pyfory.Array field {self.field_name!r} does not support array.array for this element type" + ) actual_type_id = self.pyarray_serializer._array_type_id(value) if actual_type_id != self.type_id: raise TypeError( @@ -670,7 +737,9 @@ def write(self, buffer, value): return if np is not None and value_type is np.ndarray: if self.ndarray_serializer is None: - raise TypeError(f"pyfory.Array field {self.field_name!r} does not support numpy ndarray for this element type") + raise TypeError( + f"pyfory.Array field {self.field_name!r} does not support numpy ndarray for this element type" + ) if value.dtype != self.ndarray_serializer.dtype or value.ndim != 1: raise TypeError( f"pyfory.Array field {self.field_name!r} requires 1D ndarray with dtype " @@ -679,7 +748,9 @@ def write(self, buffer, value): self.ndarray_serializer.write(buffer, value) return if value is None: - raise TypeError(f"pyfory.Array field {self.field_name!r} value must not be None") + raise TypeError( + f"pyfory.Array field {self.field_name!r} value must not be None" + ) raise TypeError( f"pyfory.Array field {self.field_name!r} requires {self.wrapper_type.__name__}, list, numpy.ndarray, or array.array, got {type(value)!r}" ) @@ -698,9 +769,13 @@ def write(self, buffer, value): try: itemsize, ftype, type_id = typecode_dict[value.typecode] except KeyError as exc: - raise 
TypeError(f"Unsupported array.array typecode {value.typecode!r}") from exc + raise TypeError( + f"Unsupported array.array typecode {value.typecode!r}" + ) from exc if value.itemsize != itemsize: - raise TypeError(f"array.array typecode {value.typecode!r} has itemsize {value.itemsize}, expected {itemsize}") + raise TypeError( + f"array.array typecode {value.typecode!r} has itemsize {value.itemsize}, expected {itemsize}" + ) view = memoryview(value) nbytes = len(value) * itemsize buffer.write_uint8(type_id) @@ -767,7 +842,9 @@ def read(self, buffer): ) else: _np_dtypes_dict = {} -_np_typeid_to_dtype = {type_id: dtype for dtype, (_, _, _, type_id) in _np_dtypes_dict.items()} +_np_typeid_to_dtype = { + type_id: dtype for dtype, (_, _, _, type_id) in _np_dtypes_dict.items() +} class Numpy1DArraySerializer(Serializer): @@ -791,7 +868,9 @@ def write(self, buffer, value): if self.dtype == np.dtype("bool") or not view.c_contiguous: if not is_little_endian and self.itemsize > 1: # Swap bytes on big-endian machines for multi-byte types - buffer.write_bytes(value.astype(value.dtype.newbyteorder("<")).tobytes()) + buffer.write_bytes( + value.astype(value.dtype.newbyteorder("<")).tobytes() + ) else: buffer.write_bytes(value.tobytes()) elif is_little_endian or self.itemsize == 1: @@ -818,7 +897,9 @@ def write(self, buffer, value): # Write concrete 1D primitive ndarray using type id + bytes payload. 
dtype_info = _np_dtypes_dict.get(value.dtype) if dtype_info is None or value.ndim != 1: - raise NotImplementedError(f"Unsupported ndarray: dtype={value.dtype}, ndim={value.ndim}") + raise NotImplementedError( + f"Unsupported ndarray: dtype={value.dtype}, ndim={value.ndim}" + ) itemsize, _typecode, _ftype, type_id = dtype_info view = memoryview(value) nbytes = len(value) * itemsize @@ -826,7 +907,9 @@ def write(self, buffer, value): buffer.write_var_uint32(nbytes) if value.dtype == np.dtype("bool") or not view.c_contiguous: if not is_little_endian and itemsize > 1: - buffer.write_bytes(value.astype(value.dtype.newbyteorder("<")).tobytes()) + buffer.write_bytes( + value.astype(value.dtype.newbyteorder("<")).tobytes() + ) else: buffer.write_bytes(value.tobytes()) elif is_little_endian or itemsize == 1: @@ -1082,17 +1165,11 @@ def __init__(self, type_resolver, cls): def _validate_global_object(self, policy, obj): result = None if isinstance(obj, type): - result = policy.validate_class(obj, is_local=False) - elif isinstance( - obj, - ( - types.FunctionType, - types.BuiltinFunctionType, - types.MethodType, - types.BuiltinMethodType, - ), - ): - result = policy.validate_function(obj, is_local=False) + result = policy.validate_class(obj, is_local=_is_local_class(obj)) + elif _is_bound_method_value(obj): + result = policy.validate_method(obj, is_local=_is_local_callable(obj)) + elif isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)): + result = policy.validate_function(obj, is_local=_is_local_callable(obj)) if result is not None: obj = result return obj @@ -1112,7 +1189,10 @@ def _resolve_global_name(self, read_context, global_name): def write(self, write_context, value): # Try __reduce_ex__ first (with protocol 5 for pickle5 out-of-band buffer support), then __reduce__ # Check if the object has a custom __reduce_ex__ method (not just the default from object) - if hasattr(value, "__reduce_ex__") and value.__class__.__reduce_ex__ is not object.__reduce_ex__: + 
if ( + hasattr(value, "__reduce_ex__") + and value.__class__.__reduce_ex__ is not object.__reduce_ex__ + ): try: reduce_result = value.__reduce_ex__(5) except TypeError: @@ -1121,7 +1201,9 @@ def write(self, write_context, value): elif hasattr(value, "__reduce__"): reduce_result = value.__reduce__() else: - raise ValueError(f"Object {value} has no __reduce__ or __reduce_ex__ method") + raise ValueError( + f"Object {value} has no __reduce__ or __reduce_ex__ method" + ) # Handle different __reduce__ return formats if isinstance(reduce_result, str): @@ -1152,7 +1234,9 @@ def write(self, write_context, value): dictitems, ) else: - raise ValueError(f"Invalid __reduce__ result length: {len(reduce_result)}") + raise ValueError( + f"Invalid __reduce__ result length: {len(reduce_result)}" + ) else: raise ValueError(f"Invalid __reduce__ result type: {type(reduce_result)}") write_context.write_var_uint32(len(reduce_data)) @@ -1241,15 +1325,19 @@ def read(self, read_context): return self._deserialize_local_class(read_context) module_name = read_context.read_string() qualname = read_context.read_string() - cls = _resolve_validated_module_qualname(read_context.policy, module_name, qualname) - result = read_context.policy.validate_class(cls, is_local=False) + cls = _resolve_validated_module_qualname( + read_context.policy, module_name, qualname + ) + result = read_context.policy.validate_class(cls, is_local=_is_local_class(cls)) if result is not None: cls = result return cls def _serialize_local_class(self, write_context, cls): """Serialize a local class by capturing its creation context.""" - assert self.type_resolver.track_ref, "Reference tracking must be enabled for local classes serialization" + assert ( + self.type_resolver.track_ref + ), "Reference tracking must be enabled for local classes serialization" module = cls.__module__ qualname = cls.__qualname__ write_context.write_string(module) @@ -1282,7 +1370,9 @@ def _serialize_local_class(self, write_context, cls): def 
_deserialize_local_class(self, read_context): """Deserialize a local class by recreating it with the captured context.""" - assert self.type_resolver.track_ref, "Reference tracking must be enabled for local classes deserialization" + assert ( + self.type_resolver.track_ref + ), "Reference tracking must be enabled for local classes deserialization" module = read_context.read_string() qualname = read_context.read_string() name = qualname.rsplit(".", 1)[-1] @@ -1291,7 +1381,9 @@ def _deserialize_local_class(self, read_context): num_bases = read_context.read_var_uint32() _check_collection_size(read_context, num_bases, "local class base") bases = tuple(read_context.read_ref() for _ in range(num_bases)) - read_context.policy.authorize_instantiation(type, module=module, qualname=qualname, bases=bases) + read_context.policy.authorize_instantiation( + type, module=module, qualname=qualname, bases=bases + ) cls = type(name, bases, {}) read_context.set_read_ref(ref_id, cls) result = read_context.policy.validate_class(cls, is_local=True) @@ -1452,7 +1544,9 @@ def _serialize_function(self, write_context, func): global_names.add(name) # Create and serialize a dictionary with only the necessary globals - globals_to_serialize = {name: globals_dict[name] for name in global_names if name in globals_dict} + globals_to_serialize = { + name: globals_dict[name] for name in global_names if name in globals_dict + } write_context.write_ref(globals_to_serialize) # Handle additional attributes @@ -1479,13 +1573,19 @@ def _deserialize_function(self, read_context): policy = read_context.policy if policy is DEFAULT_POLICY: return getattr(self_obj, method_name) - return _resolve_validated_bound_method(policy, self_obj, method_name, is_local=False) + return _resolve_validated_bound_method( + policy, self_obj, method_name, is_local=_is_local_receiver(self_obj) + ) if func_type_id == 1: module = read_context.read_string() qualname = read_context.read_string() - mod = 
_resolve_validated_module_qualname(read_context.policy, module, qualname) - return _validate_function_value(read_context.policy, mod, is_local=False) + mod = _resolve_validated_module_qualname( + read_context.policy, module, qualname + ) + return _validate_function_value( + read_context.policy, mod, is_local=_is_local_callable(mod) + ) module = read_context.read_string() qualname = read_context.read_string() @@ -1572,14 +1672,18 @@ def read(self, read_context): if read_context.read_bool(): module = read_context.read_string() func = _resolve_validated_module_attr(read_context.policy, module, name) - func = _validate_function_value(read_context.policy, func, is_local=False) + func = _validate_function_value( + read_context.policy, func, is_local=_is_local_callable(func) + ) else: obj = read_context.read_ref() policy = read_context.policy if policy is DEFAULT_POLICY: func = getattr(obj, name) else: - func = _resolve_validated_bound_method(policy, obj, name, is_local=False) + func = _resolve_validated_bound_method( + policy, obj, name, is_local=_is_local_receiver(obj) + ) return func @@ -1604,9 +1708,12 @@ def read(self, read_context): if self._use_default_policy: return getattr(instance, method_name) - cls = instance if isinstance(instance, type) else instance.__class__ - is_local = cls.__module__ == "__main__" or "" in cls.__qualname__ - return _resolve_validated_bound_method(read_context.policy, instance, method_name, is_local=is_local) + return _resolve_validated_bound_method( + read_context.policy, + instance, + method_name, + is_local=_is_local_receiver(instance), + ) class ObjectSerializer(Serializer): @@ -1645,7 +1752,9 @@ def read(self, read_context): read_context.reference(obj) num_fields = read_context.read_var_uint32() if num_fields > read_context.max_collection_size: - raise ValueError(f"object field size {num_fields} exceeds the configured limit of {read_context.max_collection_size}") + raise ValueError( + f"object field size {num_fields} exceeds the 
configured limit of {read_context.max_collection_size}" + ) state = {} for _ in range(num_fields): field_name = read_context.read_string() @@ -1663,7 +1772,9 @@ def read(self, read_context): read_context.reference(obj) num_fields = read_context.read_var_uint32() if num_fields > read_context.max_collection_size: - raise ValueError(f"object field size {num_fields} exceeds the configured limit of {read_context.max_collection_size}") + raise ValueError( + f"object field size {num_fields} exceeds the configured limit of {read_context.max_collection_size}" + ) for _ in range(num_fields): field_name = read_context.read_string() field_value = read_context.read_ref() diff --git a/python/pyfory/tests/test_metastring_resolver.py b/python/pyfory/tests/test_metastring_resolver.py index 4ea23921e7..256f2c371f 100644 --- a/python/pyfory/tests/test_metastring_resolver.py +++ b/python/pyfory/tests/test_metastring_resolver.py @@ -15,10 +15,17 @@ # specific language governing permissions and limitations # under the License. 
+import pytest + from pyfory import Buffer from pyfory.context import EncodedMetaString, MetaStringReader, MetaStringWriter from pyfory.meta.metastring import MetaStringEncoder -from pyfory.registry import SharedRegistry +from pyfory.registry import MAX_CACHED_ENCODED_META_STRINGS, SharedRegistry + +try: + from pyfory.serialization import MetaStringReader as CythonMetaStringReader +except ImportError: + CythonMetaStringReader = None def _roundtrip_meta_string(encoded_meta_string): @@ -47,3 +54,107 @@ def test_meta_string_writer_reader(): _roundtrip_meta_string(shared_registry.get_encoded_meta_string(encoder.encode("你好,世界"))) _roundtrip_meta_string(shared_registry.get_encoded_meta_string(encoder.encode("こんにちは世界"))) _roundtrip_meta_string(shared_registry.get_encoded_meta_string(encoder.encode("hello, world" * 10))) + + +def test_read_big_metastring_rejects_noncanonical_hash(): + shared_registry = SharedRegistry() + encoder = MetaStringEncoder("$", "_") + encoded_meta_string = shared_registry.get_encoded_meta_string(encoder.encode("hello, world" * 10)) + reader = MetaStringReader(shared_registry) + buffer = Buffer.allocate(128) + + buffer.write_var_uint32(encoded_meta_string.length << 1) + buffer.write_int64(encoded_meta_string.hashcode + 0x100) + buffer.write_bytes(encoded_meta_string.data) + buffer.set_reader_index(0) + + with pytest.raises(ValueError, match="Malformed metastring hash"): + reader.read_encoded_meta_string(buffer) + + +def test_cached_big_metastring_validates_bytes_before_reuse(): + shared_registry = SharedRegistry() + encoder = MetaStringEncoder("$", "_") + encoded_meta_string = shared_registry.get_encoded_meta_string(encoder.encode("hello, world" * 10)) + reader = MetaStringReader(shared_registry) + buffer = Buffer.allocate(128) + + buffer.write_var_uint32(encoded_meta_string.length << 1) + buffer.write_int64(encoded_meta_string.hashcode) + buffer.write_bytes(encoded_meta_string.data) + buffer.set_reader_index(0) + assert 
reader.read_encoded_meta_string(buffer) is encoded_meta_string + + forged_data = bytes([encoded_meta_string.data[0] ^ 1]) + encoded_meta_string.data[1:] + buffer.set_writer_index(0) + buffer.set_reader_index(0) + buffer.write_var_uint32(len(forged_data) << 1) + buffer.write_int64(encoded_meta_string.hashcode) + buffer.write_bytes(forged_data) + buffer.set_reader_index(0) + + with pytest.raises(ValueError, match="Malformed metastring hash"): + reader.read_encoded_meta_string(buffer) + + +@pytest.mark.skipif(CythonMetaStringReader is None, reason="Cython serialization extension is unavailable") +def test_cython_cached_big_metastring_validates_bytes_before_reuse(): + shared_registry = SharedRegistry() + encoder = MetaStringEncoder("$", "_") + encoded_meta_string = shared_registry.get_encoded_meta_string(encoder.encode("hello, world" * 10)) + reader = CythonMetaStringReader(shared_registry) + buffer = Buffer.allocate(128) + + buffer.write_var_uint32(encoded_meta_string.length << 1) + buffer.write_int64(encoded_meta_string.hashcode) + buffer.write_bytes(encoded_meta_string.data) + buffer.set_reader_index(0) + assert reader.read_encoded_meta_string(buffer) is encoded_meta_string + + forged_data = bytes([encoded_meta_string.data[0] ^ 1]) + encoded_meta_string.data[1:] + buffer.set_writer_index(0) + buffer.set_reader_index(0) + buffer.write_var_uint32(len(forged_data) << 1) + buffer.write_int64(encoded_meta_string.hashcode) + buffer.write_bytes(forged_data) + buffer.set_reader_index(0) + + with pytest.raises(ValueError, match="Malformed metastring hash"): + reader.read_encoded_meta_string(buffer) + + +def test_read_metastring_reset_clears_dynamic_ids_only(): + shared_registry = SharedRegistry() + encoded_meta_string = shared_registry.get_encoded_meta_string(MetaStringEncoder("$", "_").encode("hello")) + shared_registry._encoded_metastrings.clear() + reader = MetaStringReader(shared_registry) + buffer = Buffer.allocate(64) + + 
buffer.write_var_uint32(encoded_meta_string.length << 1) + buffer.write_int8(encoded_meta_string.encoding) + buffer.write_bytes(encoded_meta_string.data) + buffer.set_reader_index(0) + + assert reader.read_encoded_meta_string(buffer) == encoded_meta_string + assert reader._small_encoded_meta_strings + assert shared_registry._encoded_metastrings + reader.reset() + assert reader._small_encoded_meta_strings + + ref_buffer = Buffer.allocate(8) + ref_buffer.write_var_uint32((1 << 1) | 1) + ref_buffer.set_reader_index(0) + with pytest.raises(ValueError, match="Invalid dynamic metastring id 1"): + reader.read_encoded_meta_string(ref_buffer) + + +def test_encoded_metastring_registry_cache_is_bounded(): + shared_registry = SharedRegistry() + for i in range(MAX_CACHED_ENCODED_META_STRINGS): + shared_registry.get_or_create_encoded_meta_string(f"name-{i}".encode(), i << 8) + + encoded_meta_string = shared_registry.get_or_create_encoded_meta_string(b"overflow", 123 << 8) + + assert encoded_meta_string.data == b"overflow" + assert len(shared_registry._encoded_metastrings) == MAX_CACHED_ENCODED_META_STRINGS + assert ((123 << 8), b"overflow") not in shared_registry._encoded_metastrings diff --git a/python/pyfory/tests/test_policy.py b/python/pyfory/tests/test_policy.py index 7da8dd099f..e659ce30d3 100644 --- a/python/pyfory/tests/test_policy.py +++ b/python/pyfory/tests/test_policy.py @@ -19,13 +19,30 @@ import pytest from pyfory import Fory, DeserializationPolicy -from pyfory.serializer import FunctionSerializer, NativeFuncMethodSerializer +from pyfory.serializer import ( + FunctionSerializer, + NativeFuncMethodSerializer, + TypeSerializer, +) def policy_global_function(): return "safe" +class PolicyMethodHolder: + def run(self): + return "safe" + + +policy_method_holder = PolicyMethodHolder() +policy_global_bound_method = policy_method_holder.run + + +class PolicyGlobalClass: + pass + + class FakeReadContext: def __init__(self, policy, values): self.policy = policy @@ -40,6 
+57,9 @@ def read_bool(self): def read_string(self): return next(self._values) + def read_ref(self): + return next(self._values) + class FalseyState: bool_called = False @@ -96,7 +116,10 @@ def __init__(self, blocked_names): self.blocked_names = blocked_names def intercept_reduce_call(self, callable_obj, args, **kwargs): - if hasattr(callable_obj, "__name__") and callable_obj.__name__ in self.blocked_names: + if ( + hasattr(callable_obj, "__name__") + and callable_obj.__name__ in self.blocked_names + ): raise ValueError(f"Callable {callable_obj.__name__} is blocked") return None @@ -207,7 +230,11 @@ def intercept_setstate(self, obj, state, **kwargs): class SecretReduceHolder: def __reduce__(self): - return (SecretReduceHolder, (), {"username": "admin", "password": "secret123"}) + return ( + SecretReduceHolder, + (), + {"username": "admin", "password": "secret123"}, + ) def __setstate__(self, state): self.__dict__.update(state) @@ -322,7 +349,9 @@ def validate_class(self, cls, is_local, **kwargs): def intercept_reduce_call(self, callable_obj, args, **kwargs): if hasattr(callable_obj, "__name__"): - self.hooks_called.append(("intercept_reduce_call", callable_obj.__name__)) + self.hooks_called.append( + ("intercept_reduce_call", callable_obj.__name__) + ) return None def inspect_reduced_object(self, obj, **kwargs): @@ -569,6 +598,77 @@ def validate_method(self, method, is_local, **kwargs): assert not GuardedMethod.getattribute_called +def test_type_global_path_reports_main_class_as_local(): + class CaptureClassPolicy(DeserializationPolicy): + def __init__(self): + self.is_local_values = [] + + def validate_class(self, cls, is_local, **kwargs): + self.is_local_values.append(is_local) + return None + + original_module = PolicyGlobalClass.__module__ + PolicyGlobalClass.__module__ = "__main__" + try: + policy = CaptureClassPolicy() + fory = Fory(ref=True, strict=False, policy=policy) + serializer = TypeSerializer(fory.type_resolver, type) + read_context = 
FakeReadContext(policy, [0, __name__, "PolicyGlobalClass"]) + + assert serializer.read(read_context) is PolicyGlobalClass + assert policy.is_local_values == [True] + finally: + PolicyGlobalClass.__module__ = original_module + + +def test_function_bound_method_reports_receiver_locality_to_policy(): + class LocalReceiver: + def run(self): + return "safe" + + class CaptureMethodPolicy(DeserializationPolicy): + def __init__(self): + self.is_local_values = [] + + def validate_method(self, method, is_local, **kwargs): + self.is_local_values.append(is_local) + raise ValueError("method blocked") + + policy = CaptureMethodPolicy() + fory = Fory(ref=True, strict=False, policy=policy) + serializer = FunctionSerializer(fory.type_resolver, type(policy_global_function)) + read_context = FakeReadContext(policy, [0, LocalReceiver(), "run"]) + + with pytest.raises(ValueError, match="method blocked"): + serializer._deserialize_function(read_context) + assert policy.is_local_values == [True] + + +def test_native_bound_method_reports_receiver_locality_to_policy(): + class LocalReceiver: + def run(self): + return "safe" + + class CaptureMethodPolicy(DeserializationPolicy): + def __init__(self): + self.is_local_values = [] + + def validate_method(self, method, is_local, **kwargs): + self.is_local_values.append(is_local) + raise ValueError("method blocked") + + policy = CaptureMethodPolicy() + fory = Fory(ref=True, strict=False, policy=policy) + serializer = NativeFuncMethodSerializer( + fory.type_resolver, type(policy_global_function) + ) + read_context = FakeReadContext(policy, ["run", False, LocalReceiver()]) + + with pytest.raises(ValueError, match="method blocked"): + serializer.read(read_context) + assert policy.is_local_values == [True] + + def test_function_serializer_rejects_class_resolution(): """Test function deserialization cannot resolve classes through the function policy.""" @@ -596,6 +696,56 @@ def validate_function(self, func, is_local, **kwargs): assert 
policy.validate_function_calls == 0 +def test_function_global_method_resolution_uses_validate_method(): + class MethodPolicy(DeserializationPolicy): + def __init__(self): + self.validate_method_calls = 0 + self.validate_function_calls = 0 + + def validate_method(self, method, is_local, **kwargs): + self.validate_method_calls += 1 + raise ValueError("method blocked") + + def validate_function(self, func, is_local, **kwargs): + self.validate_function_calls += 1 + return None + + policy = MethodPolicy() + fory = Fory(ref=True, strict=False, policy=policy) + serializer = FunctionSerializer(fory.type_resolver, type(policy_global_function)) + read_context = FakeReadContext(policy, [1, __name__, "policy_global_bound_method"]) + + with pytest.raises(ValueError, match="method blocked"): + serializer._deserialize_function(read_context) + assert policy.validate_method_calls == 1 + assert policy.validate_function_calls == 0 + + +def test_function_global_path_reports_main_function_as_local(): + class CaptureFunctionPolicy(DeserializationPolicy): + def __init__(self): + self.is_local_values = [] + + def validate_function(self, func, is_local, **kwargs): + self.is_local_values.append(is_local) + return None + + original_module = policy_global_function.__module__ + policy_global_function.__module__ = "__main__" + try: + policy = CaptureFunctionPolicy() + fory = Fory(ref=True, strict=False, policy=policy) + serializer = FunctionSerializer( + fory.type_resolver, type(policy_global_function) + ) + read_context = FakeReadContext(policy, [1, __name__, "policy_global_function"]) + + assert serializer._deserialize_function(read_context) is policy_global_function + assert policy.is_local_values == [True] + finally: + policy_global_function.__module__ = original_module + + def test_native_function_serializer_rejects_class_resolution(): """Test native function deserialization cannot resolve classes through the function policy.""" @@ -614,7 +764,9 @@ def validate_function(self, func, 
is_local, **kwargs): policy = BlockClassPolicy() fory = Fory(ref=True, strict=False, policy=policy) - serializer = NativeFuncMethodSerializer(fory.type_resolver, type(policy_global_function)) + serializer = NativeFuncMethodSerializer( + fory.type_resolver, type(policy_global_function) + ) read_context = FakeReadContext(policy, ["Popen", True, "subprocess"]) with pytest.raises(ValueError, match="class blocked"): @@ -623,6 +775,62 @@ def validate_function(self, func, is_local, **kwargs): assert policy.validate_function_calls == 0 +def test_native_function_global_method_resolution_uses_validate_method(): + class MethodPolicy(DeserializationPolicy): + def __init__(self): + self.validate_method_calls = 0 + self.validate_function_calls = 0 + + def validate_method(self, method, is_local, **kwargs): + self.validate_method_calls += 1 + raise ValueError("method blocked") + + def validate_function(self, func, is_local, **kwargs): + self.validate_function_calls += 1 + return None + + policy = MethodPolicy() + fory = Fory(ref=True, strict=False, policy=policy) + serializer = NativeFuncMethodSerializer( + fory.type_resolver, type(policy_global_function) + ) + read_context = FakeReadContext( + policy, ["policy_global_bound_method", True, __name__] + ) + + with pytest.raises(ValueError, match="method blocked"): + serializer.read(read_context) + assert policy.validate_method_calls == 1 + assert policy.validate_function_calls == 0 + + +def test_native_function_global_path_reports_main_function_as_local(): + class CaptureFunctionPolicy(DeserializationPolicy): + def __init__(self): + self.is_local_values = [] + + def validate_function(self, func, is_local, **kwargs): + self.is_local_values.append(is_local) + return None + + original_module = policy_global_function.__module__ + policy_global_function.__module__ = "__main__" + try: + policy = CaptureFunctionPolicy() + fory = Fory(ref=True, strict=False, policy=policy) + serializer = NativeFuncMethodSerializer( + fory.type_resolver, 
type(policy_global_function) + ) + read_context = FakeReadContext( + policy, ["policy_global_function", True, __name__] + ) + + assert serializer.read(read_context) is policy_global_function + assert policy.is_local_values == [True] + finally: + policy_global_function.__module__ = original_module + + def test_global_function_deserialization_validates_module(): """Test validate_module policy hook for global function deserialization.""" @@ -825,3 +1033,37 @@ def validate_function(self, func, is_local, **kwargs): fory.deserialize(fory.serialize(GlobalNamePayload())) assert policy.validate_module_calls == 1 assert policy.validate_function_calls == 1 + + +def test_reduce_global_method_resolution_uses_validate_method(): + """Test reduce global-name method deserialization uses validate_method.""" + + class GlobalNamePayload: + def __reduce__(self): + return f"{__name__}.policy_global_bound_method" + + class MethodPolicy(DeserializationPolicy): + def __init__(self): + self.validate_module_calls = 0 + self.validate_method_calls = 0 + self.validate_function_calls = 0 + + def validate_module(self, module_name, **kwargs): + self.validate_module_calls += 1 + return None + + def validate_method(self, method, is_local, **kwargs): + self.validate_method_calls += 1 + raise ValueError("method blocked") + + def validate_function(self, func, is_local, **kwargs): + self.validate_function_calls += 1 + return None + + policy = MethodPolicy() + fory = Fory(ref=True, strict=False, policy=policy) + with pytest.raises(ValueError, match="method blocked"): + fory.deserialize(fory.serialize(GlobalNamePayload())) + assert policy.validate_module_calls == 1 + assert policy.validate_method_calls == 1 + assert policy.validate_function_calls == 0 diff --git a/python/pyfory/tests/test_typedef_encoding.py b/python/pyfory/tests/test_typedef_encoding.py index a7de4e2eb5..0af48dd991 100644 --- a/python/pyfory/tests/test_typedef_encoding.py +++ b/python/pyfory/tests/test_typedef_encoding.py @@ -20,7 +20,7 
@@ """ import array -from dataclasses import dataclass +from dataclasses import dataclass, make_dataclass from typing import List, Dict import pytest @@ -34,8 +34,14 @@ CollectionFieldType, MapFieldType, DynamicFieldType, + FIELD_NAME_ENCODINGS, + COMPRESS_META_FLAG, +) +from pyfory.meta.typedef_encoder import ( + FIELD_NAME_ENCODER, + encode_typedef, + prepend_header, ) -from pyfory.meta.typedef_encoder import encode_typedef from pyfory.meta.typedef_decoder import decode_typedef from pyfory.types import TypeId from pyfory import Fory @@ -124,7 +130,9 @@ def test_typedef_creation(): FieldInfo("age", FieldType(TypeId.INT32, True, True, False), "TestTypeDef"), ] - typedef = TypeDef("", "TestTypeDef", None, TypeId.STRUCT, fields, b"encoded_data", False) + typedef = TypeDef( + "", "TestTypeDef", None, TypeId.STRUCT, fields, b"encoded_data", False + ) assert typedef.namespace == "" assert typedef.typename == "TestTypeDef" @@ -186,31 +194,128 @@ def test_encode_decode_typedef(): for i, field in enumerate(decoded_typedef.fields): assert field.name == typedef.fields[i].name assert field.field_type.type_id == typedef.fields[i].field_type.type_id - assert field.field_type.is_nullable == typedef.fields[i].field_type.is_nullable + assert ( + field.field_type.is_nullable == typedef.fields[i].field_type.is_nullable + ) + + +def test_decode_typedef_rejects_parsed_body_with_mismatched_hash(): + fory = Fory(xlang=True) + fory.register(SimpleTypeDef, namespace="example", typename="SimpleTypeDef") + typedef = encode_typedef(fory.type_resolver, SimpleTypeDef) + malformed = _corrupt_encoded_field_name(typedef, "value") + + with pytest.raises(ValueError, match="Invalid TypeDef metadata hash"): + decode_typedef(Buffer(malformed), fory.type_resolver) + + +def test_decode_typedef_rejects_hash_consistent_malformed_body(): + fory = Fory(xlang=True) + encoded = prepend_header(b"\x00", False) + + with pytest.raises(Exception): + decode_typedef(Buffer(encoded), fory.type_resolver) + + +def 
test_decode_typedef_rejects_compressed_xlang_metadata(): + fory = Fory(xlang=True) + fory.register(SimpleTypeDef, namespace="example", typename="SimpleTypeDef") + typedef = encode_typedef(fory.type_resolver, SimpleTypeDef) + source = Buffer(typedef.encoded) + header = source.read_int64() + malformed = Buffer.allocate(len(typedef.encoded)) + malformed.write_int64(header | COMPRESS_META_FLAG) + malformed.write_bytes(typedef.encoded[8:]) + + with pytest.raises(ValueError, match="Compressed xlang TypeDef"): + decode_typedef(Buffer(malformed.to_bytes()), fory.type_resolver) + + +def test_id_registered_typedef_extended_field_count_header(): + many_fields_type = make_dataclass( + "ManyTypeDefFields", [(f"field_{i}", int) for i in range(32)] + ) + fory = Fory(xlang=True) + fory.register(many_fields_type, type_id=701) + typedef = encode_typedef(fory.type_resolver, many_fields_type) + body_offset = _typedef_body_offset(typedef.encoded) + + assert typedef.encoded[body_offset] & 0x1F == 0x1F + assert typedef.encoded[body_offset] & 0x20 == 0 + decoded_typedef = decode_typedef(Buffer(typedef.encoded), fory.type_resolver) + assert len(decoded_typedef.fields) == 32 + + +def test_meta_shared_typedef_cache_is_bounded(): + fory = Fory(xlang=True, compatible=True) + fory.register(SimpleTypeDef, namespace="example", typename="SimpleTypeDef") + resolver = fory.type_resolver + read_and_build = getattr(resolver, "_read_and_build_type_info", None) + if read_and_build is None: + pytest.skip("pure-Python resolver internals are not exposed by this runtime") + typedef = encode_typedef(resolver, SimpleTypeDef) + header_buffer = Buffer(typedef.encoded) + header = header_buffer.read_int64() + for i in range(8192): + resolver._meta_shared_type_info[i] = object() + + typeinfo = read_and_build(Buffer(typedef.encoded)) + + assert typeinfo.type_def.type_id == typedef.type_id + assert header not in resolver._meta_shared_type_info + assert len(resolver._meta_shared_type_info) == 8192 + + +def 
_corrupt_encoded_field_name(typedef, field_name): + malformed = bytearray(typedef.encoded) + needle = FIELD_NAME_ENCODER.encode(field_name, FIELD_NAME_ENCODINGS).encoded_data + index = bytes(malformed).find(needle, 8) + assert index >= 8 + malformed[index + len(needle) - 1] ^= 1 + return bytes(malformed) + + +def _typedef_body_offset(encoded): + buffer = Buffer(encoded) + header = buffer.read_int64() + if header & 0xFF == 0xFF: + buffer.read_var_uint32() + return buffer.get_reader_index() def test_nested_container_typedef_preserves_declared_encoding(): fory = Fory(xlang=True) - fory.register(NestedEncodingTypeDef, namespace="example", typename="NestedEncodingTypeDef") + fory.register( + NestedEncodingTypeDef, namespace="example", typename="NestedEncodingTypeDef" + ) typedef = encode_typedef(fory.type_resolver, NestedEncodingTypeDef) values_field = next(field for field in typedef.fields if field.name == "values") assert values_field.field_type.type_id == TypeId.MAP assert values_field.field_type.key_type.type_id == TypeId.INT32 assert values_field.field_type.value_type.type_id == TypeId.LIST - assert values_field.field_type.value_type.element_type.type_id == TypeId.TAGGED_INT64 + assert ( + values_field.field_type.value_type.element_type.type_id == TypeId.TAGGED_INT64 + ) decoded_typedef = decode_typedef(Buffer(typedef.encoded), fory.type_resolver) - decoded_values_field = next(field for field in decoded_typedef.fields if field.name == "values") + decoded_values_field = next( + field for field in decoded_typedef.fields if field.name == "values" + ) assert decoded_values_field.field_type.type_id == TypeId.MAP assert decoded_values_field.field_type.key_type.type_id == TypeId.INT32 assert decoded_values_field.field_type.value_type.type_id == TypeId.LIST - assert decoded_values_field.field_type.value_type.element_type.type_id == TypeId.TAGGED_INT64 + assert ( + decoded_values_field.field_type.value_type.element_type.type_id + == TypeId.TAGGED_INT64 + ) def 
test_python_array_typehint_lowering_keeps_list_schema_distinct(): fory = Fory(xlang=True) - fory.register(PythonArrayTypeHints, namespace="example", typename="PythonArrayTypeHints") + fory.register( + PythonArrayTypeHints, namespace="example", typename="PythonArrayTypeHints" + ) typedef = encode_typedef(fory.type_resolver, PythonArrayTypeHints) fields = {field.name: field.field_type for field in typedef.fields} @@ -231,8 +336,14 @@ def test_python_array_typehint_lowering_keeps_list_schema_distinct(): def test_python_array_typehint_rejects_scalar_encoding_modifier(): fory = Fory(xlang=True) - fory.register(InvalidArrayModifierTypeDef, namespace="example", typename="InvalidArrayModifierTypeDef") - with pytest.raises(TypeError, match="array does not allow scalar encoding modifier"): + fory.register( + InvalidArrayModifierTypeDef, + namespace="example", + typename="InvalidArrayModifierTypeDef", + ) + with pytest.raises( + TypeError, match="array does not allow scalar encoding modifier" + ): encode_typedef(fory.type_resolver, InvalidArrayModifierTypeDef) @@ -257,7 +368,9 @@ def test_compatible_bytes_assigns_to_uint8_array(): _register_byte_sequence(writer, BytesPayload) _register_byte_sequence(reader, UInt8ArrayPayload) - decoded = reader.deserialize(writer.serialize(BytesPayload(payload=b"\x01\x02\xff"))) + decoded = reader.deserialize( + writer.serialize(BytesPayload(payload=b"\x01\x02\xff")) + ) assert isinstance(decoded, UInt8ArrayPayload) _assert_uint8_array_value(decoded.payload, [1, 2, 255]) @@ -269,7 +382,9 @@ def test_compatible_uint8_array_assigns_to_bytes(): _register_byte_sequence(writer, UInt8ArrayPayload) _register_byte_sequence(reader, BytesPayload) - decoded = reader.deserialize(writer.serialize(UInt8ArrayPayload(payload=_uint8_array_value([1, 2, 255])))) + decoded = reader.deserialize( + writer.serialize(UInt8ArrayPayload(payload=_uint8_array_value([1, 2, 255]))) + ) assert isinstance(decoded, BytesPayload) assert decoded.payload == b"\x01\x02\xff" diff 
--git a/python/pyfory/type_util.py b/python/pyfory/type_util.py index 533403a6a3..d0e8b02a93 100644 --- a/python/pyfory/type_util.py +++ b/python/pyfory/type_util.py @@ -113,7 +113,7 @@ def unwrap_array(type_): # modified from `fluent python` -def record_class_factory(cls_name, field_names): +def record_class_factory(cls_name, field_names, *, publish=True): """ record_factory: create simple classes just for holding data fields @@ -204,8 +204,9 @@ def as_dict(self): ) cls_ = type(cls_name, (object,), cls_attrs) - # combined with __reduce__ to make it pickable - globals()[cls_name] = cls_ + if publish: + # combined with __reduce__ to make it pickable + globals()[cls_name] = cls_ return cls_ @@ -390,7 +391,8 @@ def load_class(classname: str, policy=None): while classes: cls = getattr(cls, classes.pop(0)) if policy is not None: - result = policy.validate_class(cls, is_local=False) + is_local = cls.__module__ == "__main__" or "<locals>" in cls.__qualname__ + result = policy.validate_class(cls, is_local=is_local) if result is not None: cls = result return cls diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index f82d2251cd..bfe61dbd9a 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -24,7 +24,7 @@ use crate::resolver::RefMode; use crate::resolver::TypeResolver; use crate::serializer::ForyDefault; use crate::serializer::{Serializer, StructSerializer}; -use crate::type_id::config_flags::{IS_CROSS_LANGUAGE_FLAG, IS_NULL_FLAG}; +use crate::type_id::config_flags::{IS_CROSS_LANGUAGE_FLAG, IS_OUT_OF_BAND_FLAG}; use crate::type_id::SIZE_OF_REF_AND_TYPE; use std::cell::UnsafeCell; use std::mem; @@ -738,20 +738,17 @@ impl Fory { record: &T, context: &mut WriteContext, ) -> Result<(), Error> { - let is_none = record.fory_is_none(); - self.write_head::<T>(is_none, &mut context.writer); - if !is_none { - // Use RefMode based on config: - // - If track_ref is enabled, use RefMode::Tracking for the root object - // - Otherwise, use RefMode::NullOnly
which writes NOT_NULL_VALUE_FLAG - let ref_mode = if self.config.track_ref { - RefMode::Tracking - } else { - RefMode::NullOnly - }; - // TypeMeta is written inline during serialization (streaming protocol) - ::fory_write(record, context, ref_mode, true, false)?; - } + self.write_head::(&mut context.writer); + // Use RefMode based on config: + // - If track_ref is enabled, use RefMode::Tracking for the root object + // - Otherwise, use RefMode::NullOnly which writes NOT_NULL_VALUE_FLAG + let ref_mode = if self.config.track_ref { + RefMode::Tracking + } else { + RefMode::NullOnly + }; + // TypeMeta is written inline during serialization (streaming protocol) + ::fory_write(record, context, ref_mode, true, false)?; Ok(()) } @@ -991,16 +988,14 @@ impl Fory { /// Writes the serialization header to the writer. #[inline(always)] - pub fn write_head(&self, is_none: bool, writer: &mut Writer) { + pub fn write_head(&self, writer: &mut Writer) { const HEAD_SIZE: usize = 10; writer.reserve(T::fory_reserved_space() + SIZE_OF_REF_AND_TYPE + HEAD_SIZE); - let mut bitmap: u8 = 0; - if self.config.xlang { - bitmap |= IS_CROSS_LANGUAGE_FLAG; - } - if is_none { - bitmap |= IS_NULL_FLAG; - } + let bitmap = if self.config.xlang { + IS_CROSS_LANGUAGE_FLAG + } else { + 0 + }; writer.write_u8(bitmap); } @@ -1153,10 +1148,7 @@ impl Fory { &self, context: &mut ReadContext, ) -> Result { - let is_none = self.read_head(&mut context.reader)?; - if is_none { - return Ok(T::fory_default()); - } + self.read_head(&mut context.reader)?; // Use RefMode based on config: // - If track_ref is enabled, use RefMode::Tracking for the root object // - Otherwise, use RefMode::NullOnly @@ -1172,17 +1164,31 @@ impl Fory { } #[inline(always)] - fn read_head(&self, reader: &mut Reader) -> Result { + fn read_head(&self, reader: &mut Reader) -> Result<(), Error> { let bitmap = reader.read_u8()?; - let peer_is_xlang = (bitmap & IS_CROSS_LANGUAGE_FLAG) != 0; + let expected = if self.config.xlang { + 
IS_CROSS_LANGUAGE_FLAG + } else { + 0 + }; + if bitmap != expected { + return self.read_head_slow(bitmap, expected); + } + Ok(()) + } + + #[cold] + #[inline(never)] + fn read_head_slow(&self, bitmap: u8, expected: u8) -> Result<(), Error> { + const KNOWN_FLAGS: u8 = IS_CROSS_LANGUAGE_FLAG | IS_OUT_OF_BAND_FLAG; + ensure!( + (bitmap & !KNOWN_FLAGS) == 0 && (bitmap & IS_OUT_OF_BAND_FLAG) == 0, + Error::invalid_data("unsupported root header bitmap") + ); ensure!( - self.config.xlang == peer_is_xlang, + (bitmap & IS_CROSS_LANGUAGE_FLAG) == (expected & IS_CROSS_LANGUAGE_FLAG), Error::invalid_data("header bitmap mismatch at xlang bit") ); - let is_none = (bitmap & IS_NULL_FLAG) != 0; - if is_none { - return Ok(true); - } - Ok(false) + Ok(()) } } diff --git a/rust/fory-core/src/meta/type_meta.rs b/rust/fory-core/src/meta/type_meta.rs index ca32a39efd..ff40dd4a2b 100644 --- a/rust/fory-core/src/meta/type_meta.rs +++ b/rust/fory-core/src/meta/type_meta.rs @@ -24,7 +24,7 @@ use crate::meta::{ use crate::resolver::{TypeInfo, TypeResolver}; use crate::type_id::{ TypeId, BINARY, COMPATIBLE_STRUCT, ENUM, EXT, NAMED_COMPATIBLE_STRUCT, NAMED_ENUM, NAMED_EXT, - NAMED_STRUCT, STRUCT, UINT8_ARRAY, UNKNOWN, + NAMED_STRUCT, NAMED_UNION, STRUCT, TYPED_UNION, UINT8_ARRAY, UNKNOWN, }; use crate::util::{murmurhash3_x64_128, to_snake_case}; @@ -61,7 +61,9 @@ use std::rc::Rc; const SMALL_NUM_FIELDS_THRESHOLD: usize = 0b11111; const MAX_TYPE_META_FIELDS: usize = i16::MAX as usize; -const REGISTER_BY_NAME_FLAG: u8 = 0b100000; +const REGISTER_BY_NAME_FLAG: u8 = 0b0010_0000; +const COMPATIBLE_TYPEDEF_FLAG: u8 = 0b0100_0000; +const STRUCT_TYPEDEF_FLAG: u8 = 0b1000_0000; const FIELD_NAME_SIZE_THRESHOLD: usize = 0b1111; /// Marker value in encoding bits to indicate field ID mode (instead of field name) const FIELD_ID_ENCODING_MARKER: u8 = 0b11; @@ -71,9 +73,10 @@ const SMALL_FIELD_ID_THRESHOLD: i16 = 0b1111; const BIG_NAME_THRESHOLD: usize = 0b111111; const META_SIZE_MASK: i64 = 0xff; -const 
COMPRESS_META_FLAG: i64 = 0b1 << 9; -const HAS_FIELDS_META_FLAG: i64 = 0b1 << 8; -const NUM_HASH_BITS: i8 = 50; +const COMPRESS_META_FLAG: i64 = 0b1 << 8; +const RESERVED_META_FLAGS: i64 = 0b111 << 9; +const NUM_HASH_BITS: i8 = 52; +const TYPE_META_HASH_SHIFT: u32 = 64 - NUM_HASH_BITS as u32; const NO_USER_TYPE_ID: u32 = u32::MAX; const MAX_HASH32: u64 = (1 << 31) - 1; @@ -96,6 +99,88 @@ static FIELD_NAME_ENCODINGS: &[Encoding] = &[ Encoding::LowerUpperDigitSpecial, ]; +#[inline(always)] +fn is_struct_type_def_kind(type_id: u32) -> bool { + type_id == STRUCT + || type_id == COMPATIBLE_STRUCT + || type_id == NAMED_STRUCT + || type_id == NAMED_COMPATIBLE_STRUCT +} + +#[inline(always)] +fn is_named_type_def_kind(type_id: u32) -> bool { + type_id == NAMED_STRUCT + || type_id == NAMED_COMPATIBLE_STRUCT + || type_id == NAMED_ENUM + || type_id == NAMED_EXT + || type_id == NAMED_UNION +} + +fn non_struct_kind_code(type_id: u32) -> Result { + match type_id { + x if x == ENUM => Ok(0), + x if x == NAMED_ENUM => Ok(1), + x if x == EXT => Ok(2), + x if x == NAMED_EXT => Ok(3), + x if x == TYPED_UNION => Ok(4), + x if x == NAMED_UNION => Ok(5), + _ => Err(Error::invalid_data(format!( + "unsupported TypeMeta kind {type_id}" + ))), + } +} + +fn non_struct_type_id(kind_code: u8) -> Result { + match kind_code { + 0 => Ok(ENUM), + 1 => Ok(NAMED_ENUM), + 2 => Ok(EXT), + 3 => Ok(NAMED_EXT), + 4 => Ok(TYPED_UNION), + 5 => Ok(NAMED_UNION), + _ => Err(Error::invalid_data(format!( + "unsupported TypeMeta kind code {kind_code}" + ))), + } +} + +#[inline(always)] +fn validate_type_meta_header(header: i64) -> Result<(), Error> { + if (header & RESERVED_META_FLAGS) != 0 { + return Err(Error::invalid_data("invalid TypeMeta global header")); + } + if (header & COMPRESS_META_FLAG) != 0 { + return Err(Error::invalid_data("compressed TypeMeta is not supported")); + } + Ok(()) +} + +#[inline(always)] +fn read_type_meta_body_size(reader: &mut Reader, header: i64) -> Result { + let mut meta_size = 
(header & META_SIZE_MASK) as usize; + if meta_size == META_SIZE_MASK as usize { + meta_size = meta_size + .checked_add(reader.read_var_u32()? as usize) + .ok_or_else(|| Error::invalid_data("invalid TypeMeta metadata size"))?; + } + Ok(meta_size) +} + +#[inline(always)] +fn type_meta_hash_bits(body: &[u8]) -> u64 { + let hash_value = murmurhash3_x64_128(body, 47).0 as i64; + hash_value.wrapping_shl(TYPE_META_HASH_SHIFT).wrapping_abs() as u64 +} + +#[inline(always)] +fn validate_type_meta_body_hash(header: i64, body: &[u8]) -> Result<(), Error> { + let hash_mask = u64::MAX << TYPE_META_HASH_SHIFT; + if ((header as u64) & hash_mask) != (type_meta_hash_bits(body) & hash_mask) { + return Err(Error::invalid_data("TypeMeta metadata hash mismatch")); + } + Ok(()) +} + #[derive(Eq, Clone)] pub struct FieldType { pub type_id: u32, @@ -203,7 +288,12 @@ impl FieldType { _nullable = (header & 2) != 0; } else { type_id = header; - _nullable = nullable.unwrap(); + _nullable = match nullable { + Some(value) => value, + None => { + return Err(Error::invalid_data("missing TypeMeta field nullability")); + } + }; _ref_tracking = false; } if type_id == NAMED_ENUM { @@ -746,23 +836,33 @@ impl TypeMeta { fn to_meta_bytes(&self) -> Result, Error> { let mut buffer = vec![]; - // meta_bytes:| meta_header | fields meta | let mut writer = Writer::from_buffer(&mut buffer); let num_fields = self.field_infos.len(); - // meta_header: | unuse:2 bits | is_register_by_id:1 bit | num_fields:4 bits | - let mut meta_header: u8 = min(num_fields, SMALL_NUM_FIELDS_THRESHOLD) as u8; - if self.register_by_name { - meta_header |= REGISTER_BY_NAME_FLAG; + let mut meta_header: u8; + if is_struct_type_def_kind(self.type_id) { + meta_header = STRUCT_TYPEDEF_FLAG | min(num_fields, SMALL_NUM_FIELDS_THRESHOLD) as u8; + if self.type_id == COMPATIBLE_STRUCT || self.type_id == NAMED_COMPATIBLE_STRUCT { + meta_header |= COMPATIBLE_TYPEDEF_FLAG; + } + if self.register_by_name { + meta_header |= REGISTER_BY_NAME_FLAG; + 
} + } else { + if num_fields != 0 { + return Err(Error::invalid_data( + "non-struct TypeMeta cannot carry field metadata", + )); + } + meta_header = non_struct_kind_code(self.type_id)?; } writer.write_u8(meta_header); - if num_fields >= SMALL_NUM_FIELDS_THRESHOLD { + if is_struct_type_def_kind(self.type_id) && num_fields >= SMALL_NUM_FIELDS_THRESHOLD { writer.write_var_u32((num_fields - SMALL_NUM_FIELDS_THRESHOLD) as u32); } if self.register_by_name { self.write_namespace(&mut writer); self.write_type_name(&mut writer); } else { - writer.write_u8(self.type_id as u8); if self.user_type_id == NO_USER_TYPE_ID { return Err(Error::type_error( "User type id is required for this type id", @@ -770,8 +870,10 @@ impl TypeMeta { } writer.write_var_u32(self.user_type_id); } - for field in self.field_infos.iter() { - writer.write_bytes(field.to_bytes()?.as_slice()); + if is_struct_type_def_kind(self.type_id) { + for field in self.field_infos.iter() { + writer.write_bytes(field.to_bytes()?.as_slice()); + } } Ok(buffer) } @@ -781,28 +883,48 @@ impl TypeMeta { type_resolver: &TypeResolver, ) -> Result { let meta_header = reader.read_u8()?; - let register_by_name = (meta_header & REGISTER_BY_NAME_FLAG) != 0; - let mut num_fields = meta_header as usize & SMALL_NUM_FIELDS_THRESHOLD; - if num_fields == SMALL_NUM_FIELDS_THRESHOLD { - num_fields += reader.read_var_u32()? 
as usize; - } - // limit the number of fields to prevent potential OOM when creating Vec - if num_fields > MAX_TYPE_META_FIELDS { - return Err(Error::invalid_data(format!( - "too many fields in type meta: {}, max: {}", - num_fields, MAX_TYPE_META_FIELDS - ))); - } - let mut type_id; + let is_struct = (meta_header & STRUCT_TYPEDEF_FLAG) != 0; + let register_by_name; + let mut num_fields = 0usize; + let type_id; let mut user_type_id = NO_USER_TYPE_ID; let namespace; let type_name; + if is_struct { + register_by_name = (meta_header & REGISTER_BY_NAME_FLAG) != 0; + let compatible = (meta_header & COMPATIBLE_TYPEDEF_FLAG) != 0; + type_id = if register_by_name { + if compatible { + NAMED_COMPATIBLE_STRUCT + } else { + NAMED_STRUCT + } + } else if compatible { + COMPATIBLE_STRUCT + } else { + STRUCT + }; + num_fields = meta_header as usize & SMALL_NUM_FIELDS_THRESHOLD; + if num_fields == SMALL_NUM_FIELDS_THRESHOLD { + num_fields += reader.read_var_u32()? as usize; + } + if num_fields > MAX_TYPE_META_FIELDS { + return Err(Error::invalid_data(format!( + "too many fields in type meta: {}, max: {}", + num_fields, MAX_TYPE_META_FIELDS + ))); + } + } else { + if (meta_header & 0b0111_0000) != 0 { + return Err(Error::invalid_data("invalid TypeMeta kind header")); + } + type_id = non_struct_type_id(meta_header & 0b1111)?; + register_by_name = is_named_type_def_kind(type_id); + } if register_by_name { namespace = Self::read_namespace(reader)?; type_name = Self::read_type_name(reader)?; - type_id = 0; } else { - type_id = reader.read_u8()? as u32; user_type_id = reader.read_var_u32()?; let empty_name = MetaString::default(); namespace = empty_name.clone(); @@ -813,6 +935,11 @@ impl TypeMeta { for _ in 0..num_fields { field_infos.push(FieldInfo::from_bytes(reader)?); } + if !is_struct && !field_infos.is_empty() { + return Err(Error::invalid_data( + "non-struct TypeMeta cannot carry field metadata", + )); + } // TypeMeta field order is the payload order. 
Preserve the peer's encoded order while only // remapping matched fields to local generated field indexes. let mut sorted_field_infos = field_infos; @@ -821,11 +948,20 @@ impl TypeMeta { if let Some(type_info_current) = type_resolver.get_type_info_by_name(&namespace.original, &type_name.original) { - type_id = type_info_current.get_type_id() as u32; + if type_info_current.get_type_id() as u32 != type_id { + return Err(Error::invalid_data( + "TypeMeta kind does not match registered type metadata", + )); + } Self::assign_field_ids(&type_info_current, &mut sorted_field_infos); } } else if user_type_id != NO_USER_TYPE_ID { if let Some(type_info_current) = type_resolver.get_user_type_info_by_id(user_type_id) { + if type_info_current.get_type_id() as u32 != type_id { + return Err(Error::invalid_data( + "TypeMeta kind does not match registered type metadata", + )); + } Self::assign_field_ids(&type_info_current, &mut sorted_field_infos); } } else if let Some(type_info_current) = type_resolver.get_type_info_by_id(type_id) { @@ -933,21 +1069,7 @@ impl TypeMeta { type_resolver: &TypeResolver, ) -> Result { let header = reader.read_i64()?; - let meta_size = header & META_SIZE_MASK; - if meta_size == META_SIZE_MASK { - // meta_size += reader.read_var_u32() as i64; - reader.read_var_u32()?; - } - - // let write_fields_meta = (header & HAS_FIELDS_META_FLAG) != 0; - // let is_compressed: bool = (header & COMPRESS_META_FLAG) != 0; - let meta_hash = header >> (64 - NUM_HASH_BITS); - - // let current_meta_size = 0; - // while current_meta_size < meta_size {} - let mut meta = Self::from_meta_bytes(reader, type_resolver)?; - meta.hash = meta_hash; - Ok(meta) + Self::from_bytes_with_header(reader, type_resolver, header) } pub(crate) fn from_bytes_with_header( @@ -955,30 +1077,27 @@ impl TypeMeta { type_resolver: &TypeResolver, header: i64, ) -> Result { - let meta_size = header & META_SIZE_MASK; - if meta_size == META_SIZE_MASK { - // meta_size += reader.read_var_u32()? 
as i64; - reader.read_var_u32()?; + validate_type_meta_header(header)?; + let meta_size = read_type_meta_body_size(reader, header)?; + let body = reader.read_bytes(meta_size)?; + let mut body_reader = Reader::new(body); + let mut meta = Self::from_meta_bytes(&mut body_reader, type_resolver)?; + if !body_reader.slice_after_cursor().is_empty() { + return Err(Error::invalid_data("invalid TypeMeta metadata size")); } - - // let write_fields_meta = (header & HAS_FIELDS_META_FLAG) != 0; - // let is_compressed: bool = (header & COMPRESS_META_FLAG) != 0; - let meta_hash = header >> (64 - NUM_HASH_BITS); - - // let current_meta_size = 0; - // while current_meta_size < meta_size {} - let mut meta = Self::from_meta_bytes(reader, type_resolver)?; + validate_type_meta_body_hash(header, body)?; + let meta_hash = header >> TYPE_META_HASH_SHIFT; meta.hash = meta_hash; Ok(meta) } #[inline(always)] - pub fn skip_bytes(reader: &mut Reader, header: i64) -> Result<(), Error> { - let mut meta_size = header & META_SIZE_MASK; - if meta_size == META_SIZE_MASK { - meta_size += reader.read_var_u32()? 
as i64; - } - reader.skip(meta_size as usize) + pub(crate) fn skip_bytes_for_validated_header( + reader: &mut Reader, + header: i64, + ) -> Result<(), Error> { + let meta_size = read_type_meta_body_size(reader, header)?; + reader.skip(meta_size) } /// Check class version consistency, similar to Java's checkClassVersion @@ -1005,22 +1124,14 @@ impl TypeMeta { let mut meta_buffer = vec![]; let mut meta_writer = Writer::from_buffer(&mut meta_buffer); meta_writer.write_bytes(self.to_meta_bytes()?.as_slice()); - // global_binary_header:| hash:50bits | reserved:4bits | is_compressed:1bit | write_fields_meta:1bit | meta_size:8bits | let meta_size = meta_writer.len() as i64; let mut header: i64 = min(META_SIZE_MASK, meta_size); - let write_meta_fields_flag = !self.get_field_infos().is_empty(); - if write_meta_fields_flag { - header |= HAS_FIELDS_META_FLAG; - } - // Temporary xlang behavior: keep TypeMeta uncompressed. - // Some runtimes still do not support TypeMeta decompression. let is_compressed = false; if is_compressed { header |= COMPRESS_META_FLAG; } - let hash_value = murmurhash3_x64_128(meta_writer.dump().as_slice(), 47).0 as i64; - let meta_hash_shifted = (hash_value << (64 - NUM_HASH_BITS)).abs(); - let meta_hash = meta_hash_shifted >> (64 - NUM_HASH_BITS); + let meta_hash_shifted = type_meta_hash_bits(meta_writer.dump().as_slice()) as i64; + let meta_hash = meta_hash_shifted >> TYPE_META_HASH_SHIFT; header |= meta_hash_shifted; result.write_i64(header); if meta_size >= META_SIZE_MASK { @@ -1030,3 +1141,66 @@ impl TypeMeta { Ok((buffer, meta_hash)) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rejects_body_hash_mismatch_after_successful_parse() { + let meta = TypeMeta::new( + STRUCT, + 1, + MetaString::get_empty().clone(), + MetaString::get_empty().clone(), + false, + vec![], + ) + .unwrap(); + let (mut bytes, _) = meta.to_bytes().unwrap(); + let last = bytes.len() - 1; + bytes[last] ^= 1; + + let mut reader = Reader::new(&bytes); + let 
result = TypeMeta::from_bytes(&mut reader, &TypeResolver::default()); + let message = result + .err() + .map(|error| error.to_string()) + .unwrap_or_default(); + assert!(message.contains("hash mismatch")); + } + + #[test] + fn rejects_hash_consistent_trailing_body_bytes() { + let meta = TypeMeta::new( + STRUCT, + 1, + MetaString::get_empty().clone(), + MetaString::get_empty().clone(), + false, + vec![], + ) + .unwrap(); + let mut body = meta.to_meta_bytes().unwrap(); + body.push(0); + + let mut frame = vec![]; + let mut writer = Writer::from_buffer(&mut frame); + let body_size = body.len() as i64; + let mut header = type_meta_hash_bits(&body) as i64; + header |= min(META_SIZE_MASK, body_size); + writer.write_i64(header); + if body_size >= META_SIZE_MASK { + writer.write_var_u32((body_size - META_SIZE_MASK) as u32); + } + writer.write_bytes(&body); + + let mut reader = Reader::new(&frame); + let result = TypeMeta::from_bytes(&mut reader, &TypeResolver::default()); + let message = result + .err() + .map(|error| error.to_string()) + .unwrap_or_default(); + assert!(message.contains("metadata size")); + } +} diff --git a/rust/fory-core/src/resolver/meta_resolver.rs b/rust/fory-core/src/resolver/meta_resolver.rs index 43f9851f42..c7d649aa7a 100644 --- a/rust/fory-core/src/resolver/meta_resolver.rs +++ b/rust/fory-core/src/resolver/meta_resolver.rs @@ -148,15 +148,19 @@ impl MetaReaderResolver { .as_ref() .filter(|_| self.last_meta_header == meta_header) { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache + // only after a successful TypeMeta parse and 52-bit body-hash validation. self.reading_type_infos.push(type_info.clone()); - TypeMeta::skip_bytes(reader, meta_header)?; + TypeMeta::skip_bytes_for_validated_header(reader, meta_header)?; return Ok(type_info.clone()); } if let Some(type_info) = self.parsed_type_infos.get(&meta_header) { + // Header-cache hits intentionally skip without rehashing. 
Entries reach this cache + // only after a successful TypeMeta parse and 52-bit body-hash validation. self.last_meta_header = meta_header; self.last_type_info = Some(type_info.clone()); self.reading_type_infos.push(type_info.clone()); - TypeMeta::skip_bytes(reader, meta_header)?; + TypeMeta::skip_bytes_for_validated_header(reader, meta_header)?; Ok(type_info.clone()) } else { let type_meta = Rc::new(TypeMeta::from_bytes_with_header( diff --git a/rust/fory-core/src/resolver/type_resolver.rs b/rust/fory-core/src/resolver/type_resolver.rs index 6045e3c825..1a0e3fbaa7 100644 --- a/rust/fory-core/src/resolver/type_resolver.rs +++ b/rust/fory-core/src/resolver/type_resolver.rs @@ -32,6 +32,23 @@ use std::vec; use std::{any::Any, collections::HashMap}; +#[inline(always)] +fn supports_type_def(type_id: u32) -> bool { + matches!( + type_id, + x if x == TypeId::ENUM as u32 + || x == TypeId::NAMED_ENUM as u32 + || x == TypeId::STRUCT as u32 + || x == TypeId::COMPATIBLE_STRUCT as u32 + || x == TypeId::NAMED_STRUCT as u32 + || x == TypeId::NAMED_COMPATIBLE_STRUCT as u32 + || x == TypeId::EXT as u32 + || x == TypeId::NAMED_EXT as u32 + || x == TypeId::TYPED_UNION as u32 + || x == TypeId::NAMED_UNION as u32 + ) +} + type WriteFn = fn( &dyn Any, &mut WriteContext, @@ -411,7 +428,7 @@ fn build_struct_type_infos( let type_name_ms = TYPE_NAME_ENCODER .encode_with_encodings(&variant_type_name, TYPE_NAME_ENCODINGS)?; TypeMeta::from_fields( - TypeId::ENUM as u32, + TypeId::NAMED_COMPATIBLE_STRUCT as u32, NO_USER_TYPE_ID, namespace_ms, type_name_ms, @@ -442,7 +459,7 @@ fn build_struct_type_infos( ))); } TypeMeta::from_fields( - TypeId::ENUM as u32, + TypeId::NAMED_COMPATIBLE_STRUCT as u32, NO_USER_TYPE_ID, namespace_ms, type_name_ms, @@ -467,6 +484,9 @@ fn build_serializer_type_infos( partial_info: &TypeInfo, rust_type_id: std::any::TypeId, ) -> Result, Error> { + if !supports_type_def(partial_info.type_id as u32) { + return Ok(vec![(rust_type_id, partial_info.clone())]); + } // For ext 
types, we just build the type info with empty fields let type_meta = TypeMeta::from_fields( partial_info.type_id as u32, diff --git a/rust/fory-core/src/type_id.rs b/rust/fory-core/src/type_id.rs index 9cff317346..5c7ed103e6 100644 --- a/rust/fory-core/src/type_id.rs +++ b/rust/fory-core/src/type_id.rs @@ -414,9 +414,8 @@ pub const fn needs_user_type_id(type_id: u32) -> bool { } pub mod config_flags { - pub const IS_NULL_FLAG: u8 = 1 << 0; - pub const IS_CROSS_LANGUAGE_FLAG: u8 = 1 << 1; - pub const IS_OUT_OF_BAND_FLAG: u8 = 1 << 2; + pub const IS_CROSS_LANGUAGE_FLAG: u8 = 1 << 0; + pub const IS_OUT_OF_BAND_FLAG: u8 = 1 << 1; } // every object start with i8 i16 reference flag and type flag diff --git a/rust/tests/tests/test_cross_language.rs b/rust/tests/tests/test_cross_language.rs index 3c2f93589f..f255f0b6fb 100644 --- a/rust/tests/tests/test_cross_language.rs +++ b/rust/tests/tests/test_cross_language.rs @@ -124,7 +124,7 @@ fn test_naive_date_uses_var_i64_day_count() { fory.serialize_to(&mut buf, &day).unwrap(); let mut reader = Reader::new(buf.as_slice()); - assert_eq!(reader.read_u8().unwrap(), 2); + assert_eq!(reader.read_u8().unwrap(), 1); assert_eq!(reader.read_i8().unwrap(), -1); assert_eq!(reader.read_u8().unwrap(), TypeId::DATE as u8); assert_eq!(reader.read_var_i64().unwrap(), -1); diff --git a/rust/tests/tests/test_meta.rs b/rust/tests/tests/test_meta.rs index fae325dc61..86fbd6882f 100644 --- a/rust/tests/tests/test_meta.rs +++ b/rust/tests/tests/test_meta.rs @@ -21,7 +21,7 @@ use fory_core::type_id::TypeId; #[test] fn test_meta_hash() { let meta = TypeMeta::new( - 42, + TypeId::STRUCT as u32, 1, MetaString::get_empty().clone(), MetaString::get_empty().clone(), diff --git a/swift/Sources/Fory/Fory.swift b/swift/Sources/Fory/Fory.swift index 0c5ed23366..d36152f305 100644 --- a/swift/Sources/Fory/Fory.swift +++ b/swift/Sources/Fory/Fory.swift @@ -18,32 +18,32 @@ import Foundation public struct Config { - public var xlang: Bool - public var trackRef: 
Bool - public var compatible: Bool - public var checkClassVersion: Bool - public var maxCollectionSize: Int - public var maxBinarySize: Int - public var maxDepth: Int - - public init( - xlang: Bool = true, - trackRef: Bool = false, - compatible: Bool = false, - checkClassVersion: Bool? = nil, - maxCollectionSize: Int = 1_000_000, - maxBinarySize: Int = 64 * 1024 * 1024, - maxDepth: Int = 5 - ) { - let effectiveCheckClassVersion = checkClassVersion ?? (xlang && !compatible) - self.xlang = xlang - self.trackRef = trackRef - self.compatible = compatible - self.checkClassVersion = effectiveCheckClassVersion - self.maxCollectionSize = maxCollectionSize - self.maxBinarySize = maxBinarySize - self.maxDepth = maxDepth - } + public var xlang: Bool + public var trackRef: Bool + public var compatible: Bool + public var checkClassVersion: Bool + public var maxCollectionSize: Int + public var maxBinarySize: Int + public var maxDepth: Int + + public init( + xlang: Bool = true, + trackRef: Bool = false, + compatible: Bool = false, + checkClassVersion: Bool? = nil, + maxCollectionSize: Int = 1_000_000, + maxBinarySize: Int = 64 * 1024 * 1024, + maxDepth: Int = 5 + ) { + let effectiveCheckClassVersion = checkClassVersion ?? (xlang && !compatible) + self.xlang = xlang + self.trackRef = trackRef + self.compatible = compatible + self.checkClassVersion = effectiveCheckClassVersion + self.maxCollectionSize = maxCollectionSize + self.maxBinarySize = maxBinarySize + self.maxDepth = maxDepth + } } /// Single-threaded Fory runtime. @@ -52,495 +52,487 @@ public struct Config { /// reusable read/write context pair and must not be used concurrently from /// multiple threads. public final class Fory { - public let config: Config - let typeResolver: TypeResolver - private let writeContext: WriteContext - private let readContext: ReadContext - - public convenience init( - xlang: Bool = true, - trackRef: Bool = false, - compatible: Bool = false, - checkClassVersion: Bool? 
= nil, - maxCollectionSize: Int = 1_000_000, - maxBinarySize: Int = 64 * 1024 * 1024, - maxDepth: Int = 5 - ) { - self.init(config: Config( - xlang: xlang, - trackRef: trackRef, - compatible: compatible, - checkClassVersion: checkClassVersion, - maxCollectionSize: maxCollectionSize, - maxBinarySize: maxBinarySize, - maxDepth: maxDepth - )) - } - - public init(config: Config) { - self.config = config - self.typeResolver = TypeResolver(trackRef: self.config.trackRef) - self.writeContext = WriteContext( - buffer: ByteBuffer(), - typeResolver: typeResolver, - xlang: self.config.xlang, - trackRef: self.config.trackRef, - compatible: self.config.compatible, - checkClassVersion: self.config.checkClassVersion, - maxDepth: self.config.maxDepth, - metaStringWriteState: MetaStringWriteState() - ) - self.readContext = ReadContext( - buffer: ByteBuffer(), - typeResolver: typeResolver, - xlang: self.config.xlang, - trackRef: self.config.trackRef, - compatible: self.config.compatible, - checkClassVersion: self.config.checkClassVersion, - maxCollectionSize: self.config.maxCollectionSize, - maxBinarySize: self.config.maxBinarySize, - maxDepth: self.config.maxDepth - ) - } - - public func register(_ type: T.Type, id: UInt32) { - typeResolver.register(type, id: id) - } - - public func register(_ type: T.Type, name: String) throws { - try typeResolver.register(type, name: name) - } - - public func register(_ type: T.Type, namespace: String, name: String) throws { - try typeResolver.register(type, namespace: namespace, typeName: name) - } - - public func serialize(_ value: T) throws -> Data { - try serializeRoot(isNone: value.foryIsNone) { context in - try writeRootTypedValue(value, context: context) - } - } - - public func deserialize(_ data: Data, as _: T.Type = T.self) throws -> T { - try deserializeRoot( - data: data, - nilValue: T.foryDefault() - ) { context in - try readRootTypedValue(context: context) - } - } - - public func serialize(_ value: T, to buffer: inout Data) throws { 
- try appendSerializedRoot(to: &buffer, isNone: value.foryIsNone) { context in - try writeRootTypedValue(value, context: context) - } - } - - public func deserialize(from buffer: ByteBuffer, as _: T.Type = T.self) throws -> T { - try deserializeRoot( - from: buffer, - nilValue: T.foryDefault() - ) { context in - try readRootTypedValue(context: context) - } - } - - @_disfavoredOverload - public func serialize(_ value: Any) throws -> Data { - try serializeRoot(isNone: false) { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: Any.Type = Any.self) throws -> Any { - try deserializeRoot( - data: data, - nilValue: ForyAnyNullValue() - ) { context in - try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), - to: Any.self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: AnyObject) throws -> Data { - try serializeRoot(isNone: false) { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: AnyObject.Type = AnyObject.self) throws -> AnyObject { - try deserializeRoot( - data: data, - nilValue: NSNull() - ) { context in - try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), - to: AnyObject.self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: any Serializer) throws -> Data { - try serializeRoot(isNone: false) { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: (any Serializer).Type = (any Serializer).self) throws -> any Serializer { - try deserializeRoot( - data: data, - nilValue: ForyAnyNullValue() - ) { context in - try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), - to: 
(any Serializer).self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: [Any]) throws -> Data { - try serializeRoot(isNone: false) { context in - try context.writeListOfAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: [Any].Type = [Any].self) throws -> [Any] { - try deserializeRoot( - data: data, - nilValue: [] - ) { context in - try context.readListOfAny(refMode: refMode, readTypeInfo: true) ?? [] - } - } - - @_disfavoredOverload - public func serialize(_ value: [String: Any]) throws -> Data { - try serializeRoot(isNone: false) { context in - try context.writeMapStringToAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: [String: Any].Type = [String: Any].self) throws -> [String: Any] { - try deserializeRoot( - data: data, - nilValue: [:] - ) { context in - try context.readMapStringToAny(refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func serialize(_ value: [Int32: Any]) throws -> Data { - try serializeRoot(isNone: false) { context in - try context.writeMapInt32ToAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: [Int32: Any].Type = [Int32: Any].self) throws -> [Int32: Any] { - try deserializeRoot( - data: data, - nilValue: [:] - ) { context in - try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: true) ?? 
[:] - } - } - - @_disfavoredOverload - public func serialize(_ value: [AnyHashable: Any]) throws -> Data { - try serializeRoot(isNone: false) { context in - try context.writeMapAnyHashableToAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self) throws -> [AnyHashable: Any] { - try deserializeRoot( - data: data, - nilValue: [:] - ) { context in - try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func serialize(_ value: [Any], to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer, isNone: false) { context in - try context.writeListOfAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func serialize(_ value: Any, to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer, isNone: false) { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: Any.Type = Any.self) throws -> Any { - try deserializeRoot( - from: buffer, - nilValue: ForyAnyNullValue() - ) { context in - try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), - to: Any.self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: AnyObject, to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer, isNone: false) { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: AnyObject.Type = AnyObject.self) throws -> AnyObject { - try deserializeRoot( - from: buffer, - nilValue: NSNull() - ) { context in - try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), - to: 
AnyObject.self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: any Serializer, to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer, isNone: false) { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize( - from buffer: ByteBuffer, - as _: (any Serializer).Type = (any Serializer).self - ) throws -> any Serializer { - try deserializeRoot( - from: buffer, - nilValue: ForyAnyNullValue() - ) { context in - try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), - to: (any Serializer).self - ) - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: [Any].Type = [Any].self) throws -> [Any] { - try deserializeRoot( - from: buffer, - nilValue: [] - ) { context in - try context.readListOfAny(refMode: refMode, readTypeInfo: true) ?? [] - } - } - - @_disfavoredOverload - public func serialize(_ value: [String: Any], to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer, isNone: false) { context in - try context.writeMapStringToAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: [String: Any].Type = [String: Any].self) throws -> [String: Any] { - try deserializeRoot( - from: buffer, - nilValue: [:] - ) { context in - try context.readMapStringToAny(refMode: refMode, readTypeInfo: true) ?? 
[:] - } - } - - @_disfavoredOverload - public func serialize(_ value: [Int32: Any], to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer, isNone: false) { context in - try context.writeMapInt32ToAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func serialize(_ value: [AnyHashable: Any], to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer, isNone: false) { context in - try context.writeMapAnyHashableToAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: [Int32: Any].Type = [Int32: Any].self) throws -> [Int32: Any] { - try deserializeRoot( - from: buffer, - nilValue: [:] - ) { context in - try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self) throws -> [AnyHashable: Any] { - try deserializeRoot( - from: buffer, - nilValue: [:] - ) { context in - try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @inlinable - @inline(__always) - func writeHead(buffer: ByteBuffer, isNone: Bool) { - var bitmap: UInt8 = 0 - if config.xlang { - bitmap |= ForyHeaderFlag.isXlang - } - if isNone { - bitmap |= ForyHeaderFlag.isNull - } - buffer.writeUInt8(bitmap) - } - - @inlinable - @inline(__always) - func readHead(buffer: ByteBuffer) throws -> Bool { - let bitmap = try buffer.readUInt8() - let peerIsXlang = (bitmap & ForyHeaderFlag.isXlang) != 0 - if peerIsXlang != config.xlang { - throw ForyError.invalidData("xlang bitmap mismatch") - } - return (bitmap & ForyHeaderFlag.isNull) != 0 - } - - @inline(__always) - private var refMode: RefMode { - config.trackRef ? 
.tracking : .nullOnly - } - - private func writeRootTypedValue( - _ value: T, - context: WriteContext - ) throws { - let writeTypeInfo = config.xlang || config.compatible - try value.foryWrite( - context, - refMode: config.trackRef ? .tracking : (writeTypeInfo ? .nullOnly : .none), - writeTypeInfo: writeTypeInfo, - hasGenerics: false - ) - } - - @inline(__always) - private func readRootTypedValue( - context: ReadContext - ) throws -> T { - let readTypeInfo = config.xlang || config.compatible - return try T.foryRead( - context, - refMode: config.trackRef ? .tracking : (readTypeInfo ? .nullOnly : .none), - readTypeInfo: readTypeInfo - ) - } - - @inline(__always) - func withReusableReadContext( - data: Data, - _ body: (ReadContext) throws -> R - ) rethrows -> R { - readContext.buffer.replace(with: data) - defer { - readContext.reset() - } - return try body(readContext) - } - - @inline(__always) - private func serializeRoot( - isNone: Bool, - _ body: (WriteContext) throws -> Void - ) throws -> Data { - typeResolver.finishRegistration() - let context = writeContext - context.buffer.clear() - defer { - context.reset() - } - writeHead(buffer: context.buffer, isNone: isNone) - if !isNone { - try body(context) - } - return context.buffer.copyToData() - } - - @inline(__always) - private func appendSerializedRoot( - to output: inout Data, - isNone: Bool, - _ body: (WriteContext) throws -> Void - ) throws { - typeResolver.finishRegistration() - let context = writeContext - context.buffer.clear() - defer { - context.reset() - } - writeHead(buffer: context.buffer, isNone: isNone) - if !isNone { - try body(context) - } - output.append(contentsOf: context.buffer.storage.prefix(context.buffer.count)) - } - - @inline(__always) - private func deserializeRoot( - data: Data, - nilValue: @autoclosure () -> R, - _ body: (ReadContext) throws -> R - ) throws -> R { - typeResolver.finishRegistration() - return try withReusableReadContext(data: data) { context in - if try readHead(buffer: 
context.buffer) { - return nilValue() - } - let value = try body(context) - if context.buffer.remaining != 0 { - throw ForyError.invalidData("unexpected trailing bytes at root: \(context.buffer.remaining)") - } - return value - } - } - - @inline(__always) - private func deserializeRoot( - from buffer: ByteBuffer, - nilValue: @autoclosure () -> R, - _ body: (ReadContext) throws -> R - ) throws -> R { - typeResolver.finishRegistration() - readContext.buffer.swapState(with: buffer) - defer { - readContext.buffer.swapState(with: buffer) - readContext.reset() - } - if try readHead(buffer: readContext.buffer) { - return nilValue() - } - return try body(readContext) - } + public let config: Config + let typeResolver: TypeResolver + private let writeContext: WriteContext + private let readContext: ReadContext + + public convenience init( + xlang: Bool = true, + trackRef: Bool = false, + compatible: Bool = false, + checkClassVersion: Bool? = nil, + maxCollectionSize: Int = 1_000_000, + maxBinarySize: Int = 64 * 1024 * 1024, + maxDepth: Int = 5 + ) { + self.init( + config: Config( + xlang: xlang, + trackRef: trackRef, + compatible: compatible, + checkClassVersion: checkClassVersion, + maxCollectionSize: maxCollectionSize, + maxBinarySize: maxBinarySize, + maxDepth: maxDepth + )) + } + + public init(config: Config) { + self.config = config + self.typeResolver = TypeResolver(trackRef: self.config.trackRef) + self.writeContext = WriteContext( + buffer: ByteBuffer(), + typeResolver: typeResolver, + xlang: self.config.xlang, + trackRef: self.config.trackRef, + compatible: self.config.compatible, + checkClassVersion: self.config.checkClassVersion, + maxDepth: self.config.maxDepth, + metaStringWriteState: MetaStringWriteState() + ) + self.readContext = ReadContext( + buffer: ByteBuffer(), + typeResolver: typeResolver, + xlang: self.config.xlang, + trackRef: self.config.trackRef, + compatible: self.config.compatible, + checkClassVersion: self.config.checkClassVersion, + 
maxCollectionSize: self.config.maxCollectionSize, + maxBinarySize: self.config.maxBinarySize, + maxDepth: self.config.maxDepth + ) + } + + public func register(_ type: T.Type, id: UInt32) { + typeResolver.register(type, id: id) + } + + public func register(_ type: T.Type, name: String) throws { + try typeResolver.register(type, name: name) + } + + public func register(_ type: T.Type, namespace: String, name: String) throws { + try typeResolver.register(type, namespace: namespace, typeName: name) + } + + public func serialize(_ value: T) throws -> Data { + try serializeRoot { context in + try writeRootTypedValue(value, context: context) + } + } + + public func deserialize(_ data: Data, as _: T.Type = T.self) throws -> T { + try deserializeRoot( + data: data + ) { context in + try readRootTypedValue(context: context) + } + } + + public func serialize(_ value: T, to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try writeRootTypedValue(value, context: context) + } + } + + public func deserialize(from buffer: ByteBuffer, as _: T.Type = T.self) throws -> T { + try deserializeRoot( + from: buffer + ) { context in + try readRootTypedValue(context: context) + } + } + + @_disfavoredOverload + public func serialize(_ value: Any) throws -> Data { + try serializeRoot { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: Any.Type = Any.self) throws -> Any { + try deserializeRoot( + data: data + ) { context in + try castAnyDynamicValue( + context.readAny(refMode: refMode, readTypeInfo: true), + to: Any.self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: AnyObject) throws -> Data { + try serializeRoot { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: AnyObject.Type = 
AnyObject.self) throws -> AnyObject { + try deserializeRoot( + data: data + ) { context in + try castAnyDynamicValue( + context.readAny(refMode: refMode, readTypeInfo: true), + to: AnyObject.self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: any Serializer) throws -> Data { + try serializeRoot { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: (any Serializer).Type = (any Serializer).self) throws + -> any Serializer { + try deserializeRoot( + data: data + ) { context in + try castAnyDynamicValue( + context.readAny(refMode: refMode, readTypeInfo: true), + to: (any Serializer).self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: [Any]) throws -> Data { + try serializeRoot { context in + try context.writeListOfAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: [Any].Type = [Any].self) throws -> [Any] { + try deserializeRoot( + data: data + ) { context in + try context.readListOfAny(refMode: refMode, readTypeInfo: true) ?? [] + } + } + + @_disfavoredOverload + public func serialize(_ value: [String: Any]) throws -> Data { + try serializeRoot { context in + try context.writeMapStringToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: [String: Any].Type = [String: Any].self) throws + -> [String: Any] { + try deserializeRoot( + data: data + ) { context in + try context.readMapStringToAny(refMode: refMode, readTypeInfo: true) ?? 
[:] + } + } + + @_disfavoredOverload + public func serialize(_ value: [Int32: Any]) throws -> Data { + try serializeRoot { context in + try context.writeMapInt32ToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: [Int32: Any].Type = [Int32: Any].self) throws + -> [Int32: Any] { + try deserializeRoot( + data: data + ) { context in + try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func serialize(_ value: [AnyHashable: Any]) throws -> Data { + try serializeRoot { context in + try context.writeMapAnyHashableToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self) + throws -> [AnyHashable: Any] { + try deserializeRoot( + data: data + ) { context in + try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: true) ?? 
[:] + } + } + + @_disfavoredOverload + public func serialize(_ value: [Any], to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeListOfAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func serialize(_ value: Any, to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(from buffer: ByteBuffer, as _: Any.Type = Any.self) throws -> Any { + try deserializeRoot( + from: buffer + ) { context in + try castAnyDynamicValue( + context.readAny(refMode: refMode, readTypeInfo: true), + to: Any.self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: AnyObject, to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(from buffer: ByteBuffer, as _: AnyObject.Type = AnyObject.self) throws + -> AnyObject { + try deserializeRoot( + from: buffer + ) { context in + try castAnyDynamicValue( + context.readAny(refMode: refMode, readTypeInfo: true), + to: AnyObject.self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: any Serializer, to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize( + from buffer: ByteBuffer, + as _: (any Serializer).Type = (any Serializer).self + ) throws -> any Serializer { + try deserializeRoot( + from: buffer + ) { context in + try castAnyDynamicValue( + context.readAny(refMode: refMode, readTypeInfo: true), + to: (any Serializer).self + ) + } + } + + @_disfavoredOverload + public func 
deserialize(from buffer: ByteBuffer, as _: [Any].Type = [Any].self) throws -> [Any] { + try deserializeRoot( + from: buffer + ) { context in + try context.readListOfAny(refMode: refMode, readTypeInfo: true) ?? [] + } + } + + @_disfavoredOverload + public func serialize(_ value: [String: Any], to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeMapStringToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(from buffer: ByteBuffer, as _: [String: Any].Type = [String: Any].self) + throws -> [String: Any] { + try deserializeRoot( + from: buffer + ) { context in + try context.readMapStringToAny(refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func serialize(_ value: [Int32: Any], to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeMapInt32ToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func serialize(_ value: [AnyHashable: Any], to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeMapAnyHashableToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(from buffer: ByteBuffer, as _: [Int32: Any].Type = [Int32: Any].self) + throws -> [Int32: Any] { + try deserializeRoot( + from: buffer + ) { context in + try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func deserialize( + from buffer: ByteBuffer, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self + ) throws -> [AnyHashable: Any] { + try deserializeRoot( + from: buffer + ) { context in + try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: true) ?? 
[:] + } + } + + @inlinable + @inline(__always) + func writeHead(buffer: ByteBuffer) { + buffer.writeUInt8(config.xlang ? ForyHeaderFlag.isXlang : 0) + } + + @inlinable + @inline(__always) + func readHead(buffer: ByteBuffer) throws { + let bitmap = try buffer.readUInt8() + let expected = config.xlang ? ForyHeaderFlag.isXlang : 0 + if bitmap != expected { + try readHeadSlow(bitmap: bitmap, expected: expected) + } + } + + @usableFromInline + @inline(never) + func readHeadSlow(bitmap: UInt8, expected: UInt8) throws { + if (bitmap & ~ForyHeaderFlag.knownMask) != 0 || (bitmap & ForyHeaderFlag.isOutOfBand) != 0 { + throw ForyError.invalidData("unsupported root header bitmap 0x\(String(bitmap, radix: 16))") + } + if (bitmap & ForyHeaderFlag.isXlang) != (expected & ForyHeaderFlag.isXlang) { + throw ForyError.invalidData("xlang bitmap mismatch") + } + } + + @inline(__always) + private var refMode: RefMode { + config.trackRef ? .tracking : .nullOnly + } + + private func writeRootTypedValue( + _ value: T, + context: WriteContext + ) throws { + let writeTypeInfo = config.xlang || config.compatible + try value.foryWrite( + context, + refMode: config.trackRef ? .tracking : (writeTypeInfo ? .nullOnly : .none), + writeTypeInfo: writeTypeInfo, + hasGenerics: false + ) + } + + @inline(__always) + private func readRootTypedValue( + context: ReadContext + ) throws -> T { + let readTypeInfo = config.xlang || config.compatible + return try T.foryRead( + context, + refMode: config.trackRef ? .tracking : (readTypeInfo ? 
.nullOnly : .none), + readTypeInfo: readTypeInfo + ) + } + + @inline(__always) + func withReusableReadContext( + data: Data, + _ body: (ReadContext) throws -> R + ) rethrows -> R { + readContext.buffer.replace(with: data) + defer { + readContext.reset() + } + return try body(readContext) + } + + @inline(__always) + private func serializeRoot( + _ body: (WriteContext) throws -> Void + ) throws -> Data { + typeResolver.finishRegistration() + let context = writeContext + context.buffer.clear() + defer { + context.reset() + } + writeHead(buffer: context.buffer) + try body(context) + return context.buffer.copyToData() + } + + @inline(__always) + private func appendSerializedRoot( + to output: inout Data, + _ body: (WriteContext) throws -> Void + ) throws { + typeResolver.finishRegistration() + let context = writeContext + context.buffer.clear() + defer { + context.reset() + } + writeHead(buffer: context.buffer) + try body(context) + output.append(contentsOf: context.buffer.storage.prefix(context.buffer.count)) + } + + @inline(__always) + private func deserializeRoot( + data: Data, + _ body: (ReadContext) throws -> R + ) throws -> R { + typeResolver.finishRegistration() + return try withReusableReadContext(data: data) { context in + try readHead(buffer: context.buffer) + let value = try body(context) + if context.buffer.remaining != 0 { + throw ForyError.invalidData( + "unexpected trailing bytes at root: \(context.buffer.remaining)") + } + return value + } + } + + @inline(__always) + private func deserializeRoot( + from buffer: ByteBuffer, + _ body: (ReadContext) throws -> R + ) throws -> R { + typeResolver.finishRegistration() + readContext.buffer.swapState(with: buffer) + defer { + readContext.buffer.swapState(with: buffer) + readContext.reset() + } + try readHead(buffer: readContext.buffer) + return try body(readContext) + } } diff --git a/swift/Sources/Fory/ForyFlags.swift b/swift/Sources/Fory/ForyFlags.swift index e0a4203760..848538b746 100644 --- 
a/swift/Sources/Fory/ForyFlags.swift +++ b/swift/Sources/Fory/ForyFlags.swift @@ -18,27 +18,27 @@ import Foundation public enum RefFlag: Int8 { - case null = -3 - case ref = -2 - case notNullValue = -1 - case refValue = 0 + case null = -3 + case ref = -2 + case notNullValue = -1 + case refValue = 0 } public enum RefMode: UInt8, Equatable { - case none = 0 - case nullOnly = 1 - case tracking = 2 + case none = 0 + case nullOnly = 1 + case tracking = 2 - public static func from(nullable: Bool, trackRef: Bool) -> Self { - if trackRef { - return .tracking - } - return nullable ? .nullOnly : .none + public static func from(nullable: Bool, trackRef: Bool) -> Self { + if trackRef { + return .tracking } + return nullable ? .nullOnly : .none + } } public enum ForyHeaderFlag { - public static let isNull: UInt8 = 0x01 - public static let isXlang: UInt8 = 0x02 - public static let isOutOfBand: UInt8 = 0x04 + public static let isXlang: UInt8 = 0x01 + public static let isOutOfBand: UInt8 = 0x02 + public static let knownMask: UInt8 = isXlang | isOutOfBand } diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index fb98f3b90d..922c170198 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -268,6 +268,8 @@ public final class ReadContext { bodySize += Int(try buffer.readVarUInt32()) } if header == localTypeDefHeader { + // Header-cache hits intentionally skip without rehashing. Entries reach this + // cache only after a successful TypeDef parse and 52-bit body-hash validation. compatibleTypeDefTypeInfos.push(localTypeInfo) try buffer.skip(bodySize) return nil @@ -301,6 +303,8 @@ public final class ReadContext { bodySize += Int(try buffer.readVarUInt32()) } if let cached = typeResolver.getTypeInfo(forHeader: header) { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit body-hash validation. 
try buffer.skip(bodySize) compatibleTypeDefTypeInfos.push(cached) return cached @@ -336,6 +340,8 @@ public final class ReadContext { } if header == localTypeDefHeader { + // Header-cache hits intentionally skip without rehashing. Entries reach this + // cache only after a successful TypeDef parse and 52-bit body-hash validation. compatibleTypeDefTypeInfos.push(localTypeInfo) try buffer.skip(bodySize) return localTypeInfo diff --git a/swift/Sources/Fory/TypeMeta.swift b/swift/Sources/Fory/TypeMeta.swift index e9396ec7ab..16cc64ecfe 100644 --- a/swift/Sources/Fory/TypeMeta.swift +++ b/swift/Sources/Fory/TypeMeta.swift @@ -19,668 +19,757 @@ import Foundation private let smallNumFieldsThreshold = 0b1_1111 private let registerByNameFlag: UInt8 = 0b10_0000 +private let compatibleTypeMetaFlag: UInt8 = 0b0100_0000 +private let structTypeMetaFlag: UInt8 = 0b1000_0000 private let fieldNameSizeThreshold = 0b1111 private let bigNameThreshold = 0b11_1111 -private let typeMetaHasFieldsMetaFlag: UInt64 = 1 << 8 -private let typeMetaCompressedFlag: UInt64 = 1 << 9 +private let typeMetaCompressedFlag: UInt64 = 1 << 8 +private let typeMetaReservedFlags: UInt64 = 0b111 << 9 private let typeMetaSizeMask: UInt64 = 0xFF -private let typeMetaNumHashBits: UInt64 = 50 +private let typeMetaNumHashBits: UInt64 = 52 private let typeMetaHashSeed: UInt64 = 47 private let noUserTypeID: UInt32 = UInt32.max public let namespaceMetaStringEncodings: [MetaStringEncoding] = [ - .utf8, - .allToLowerSpecial, - .lowerUpperDigitSpecial + .utf8, + .allToLowerSpecial, + .lowerUpperDigitSpecial ] public let typeNameMetaStringEncodings: [MetaStringEncoding] = [ - .utf8, - .allToLowerSpecial, - .lowerUpperDigitSpecial, - .firstToLowerSpecial + .utf8, + .allToLowerSpecial, + .lowerUpperDigitSpecial, + .firstToLowerSpecial ] public let fieldNameMetaStringEncodings: [MetaStringEncoding] = [ - .utf8, - .allToLowerSpecial, - .lowerUpperDigitSpecial + .utf8, + .allToLowerSpecial, + .lowerUpperDigitSpecial ] public 
final class TypeMeta: Equatable, @unchecked Sendable { - public struct FieldType: Equatable, Sendable { - public var typeID: UInt32 - public var nullable: Bool - public var trackRef: Bool - public var generics: [FieldType] - - public init( - typeID: UInt32, - nullable: Bool, - trackRef: Bool = false, - generics: [FieldType] = [] - ) { - self.typeID = typeID - self.nullable = nullable - self.trackRef = trackRef - self.generics = generics - } - - fileprivate func write( - _ buffer: ByteBuffer, - writeFlags: Bool, - nullableOverride: Bool? = nil - ) { - if writeFlags { - var header = typeID << 2 - if nullableOverride ?? nullable { - header |= 0b10 - } - if trackRef { - header |= 0b1 - } - buffer.writeVarUInt32(header) - } else { - buffer.writeUInt8(UInt8(truncatingIfNeeded: typeID)) - } - - if typeID == TypeId.list.rawValue || typeID == TypeId.set.rawValue { - let element = generics.first ?? FieldType(typeID: TypeId.unknown.rawValue, nullable: true) - element.write(buffer, writeFlags: true, nullableOverride: element.nullable) - } else if typeID == TypeId.map.rawValue { - let key = generics.first ?? FieldType(typeID: TypeId.unknown.rawValue, nullable: true) - let value = generics.dropFirst().first ?? FieldType(typeID: TypeId.unknown.rawValue, nullable: true) - key.write(buffer, writeFlags: true, nullableOverride: key.nullable) - value.write(buffer, writeFlags: true, nullableOverride: value.nullable) - } - } + public struct FieldType: Equatable, Sendable { + public var typeID: UInt32 + public var nullable: Bool + public var trackRef: Bool + public var generics: [FieldType] - fileprivate static func read( - _ buffer: ByteBuffer, - readFlags: Bool, - nullable: Bool? = nil, - trackRef: Bool? 
= nil - ) throws -> FieldType { - let header: UInt32 - if readFlags { - header = try buffer.readVarUInt32() - } else { - header = UInt32(try buffer.readUInt8()) - } - - let typeID: UInt32 - let resolvedNullable: Bool - let resolvedTrackRef: Bool - - if readFlags { - typeID = header >> 2 - resolvedNullable = (header & 0b10) != 0 - resolvedTrackRef = (header & 0b1) != 0 - } else { - typeID = header - resolvedNullable = nullable ?? false - resolvedTrackRef = trackRef ?? false - } - - if typeID == TypeId.list.rawValue || typeID == TypeId.set.rawValue { - let element = try read(buffer, readFlags: true) - return FieldType( - typeID: typeID, - nullable: resolvedNullable, - trackRef: resolvedTrackRef, - generics: [element] - ) - } - if typeID == TypeId.map.rawValue { - let key = try read(buffer, readFlags: true) - let value = try read(buffer, readFlags: true) - return FieldType( - typeID: typeID, - nullable: resolvedNullable, - trackRef: resolvedTrackRef, - generics: [key, value] - ) - } - - return FieldType( - typeID: typeID, - nullable: resolvedNullable, - trackRef: resolvedTrackRef, - generics: [] - ) - } + public init( + typeID: UInt32, + nullable: Bool, + trackRef: Bool = false, + generics: [FieldType] = [] + ) { + self.typeID = typeID + self.nullable = nullable + self.trackRef = trackRef + self.generics = generics } - public struct FieldInfo: Equatable, Sendable { - public var fieldID: Int16? - public var fieldName: String - public var fieldType: FieldType + fileprivate func write( + _ buffer: ByteBuffer, + writeFlags: Bool, + nullableOverride: Bool? = nil + ) { + if writeFlags { + var header = typeID << 2 + if nullableOverride ?? nullable { + header |= 0b10 + } + if trackRef { + header |= 0b1 + } + buffer.writeVarUInt32(header) + } else { + buffer.writeUInt8(UInt8(truncatingIfNeeded: typeID)) + } + + if typeID == TypeId.list.rawValue || typeID == TypeId.set.rawValue { + let element = generics.first ?? 
FieldType(typeID: TypeId.unknown.rawValue, nullable: true) + element.write(buffer, writeFlags: true, nullableOverride: element.nullable) + } else if typeID == TypeId.map.rawValue { + let key = generics.first ?? FieldType(typeID: TypeId.unknown.rawValue, nullable: true) + let value = + generics.dropFirst().first ?? FieldType(typeID: TypeId.unknown.rawValue, nullable: true) + key.write(buffer, writeFlags: true, nullableOverride: key.nullable) + value.write(buffer, writeFlags: true, nullableOverride: value.nullable) + } + } - public init(fieldID: Int16?, fieldName: String, fieldType: FieldType) { - self.fieldID = fieldID - self.fieldName = fieldName - self.fieldType = fieldType - } + fileprivate static func read( + _ buffer: ByteBuffer, + readFlags: Bool, + nullable: Bool? = nil, + trackRef: Bool? = nil + ) throws -> FieldType { + let header: UInt32 + if readFlags { + header = try buffer.readVarUInt32() + } else { + header = UInt32(try buffer.readUInt8()) + } + + let typeID: UInt32 + let resolvedNullable: Bool + let resolvedTrackRef: Bool + + if readFlags { + typeID = header >> 2 + resolvedNullable = (header & 0b10) != 0 + resolvedTrackRef = (header & 0b1) != 0 + } else { + typeID = header + resolvedNullable = nullable ?? false + resolvedTrackRef = trackRef ?? 
false + } + + if typeID == TypeId.list.rawValue || typeID == TypeId.set.rawValue { + let element = try read(buffer, readFlags: true) + return FieldType( + typeID: typeID, + nullable: resolvedNullable, + trackRef: resolvedTrackRef, + generics: [element] + ) + } + if typeID == TypeId.map.rawValue { + let key = try read(buffer, readFlags: true) + let value = try read(buffer, readFlags: true) + return FieldType( + typeID: typeID, + nullable: resolvedNullable, + trackRef: resolvedTrackRef, + generics: [key, value] + ) + } + + return FieldType( + typeID: typeID, + nullable: resolvedNullable, + trackRef: resolvedTrackRef, + generics: [] + ) + } + } - fileprivate func write(_ buffer: ByteBuffer) throws { - var header: UInt8 = 0 - if fieldType.trackRef { - header |= 0b1 - } - if fieldType.nullable { - header |= 0b10 - } - - if let fieldID { - if fieldID < 0 { - throw ForyError.encodingError("negative field id is invalid") - } - let size = Int(fieldID) - header |= UInt8(0b11 << 6) - if size >= fieldNameSizeThreshold { - header |= 0b0011_1100 - buffer.writeUInt8(header) - buffer.writeVarUInt32(UInt32(size - fieldNameSizeThreshold)) - } else { - header |= UInt8(size << 2) - buffer.writeUInt8(header) - } - fieldType.write(buffer, writeFlags: false) - return - } - - let snakeName = lowerCamelToLowerUnderscore(fieldName) - let encoded = try MetaStringEncoder.fieldName.encode(snakeName, allowedEncodings: fieldNameMetaStringEncodings) - guard let encodingIndex = fieldNameMetaStringEncodings.firstIndex(of: encoded.encoding) else { - throw ForyError.encodingError("unsupported field name encoding") - } - - let size = encoded.bytes.count - 1 - header |= UInt8(encodingIndex << 6) - if size >= fieldNameSizeThreshold { - header |= 0b0011_1100 - buffer.writeUInt8(header) - buffer.writeVarUInt32(UInt32(size - fieldNameSizeThreshold)) - } else { - header |= UInt8(size << 2) - buffer.writeUInt8(header) - } - - fieldType.write(buffer, writeFlags: false) - buffer.writeBytes(encoded.bytes) - } + 
public struct FieldInfo: Equatable, Sendable { + public var fieldID: Int16? + public var fieldName: String + public var fieldType: FieldType - fileprivate static func read(_ buffer: ByteBuffer) throws -> FieldInfo { - let header = try buffer.readUInt8() - let encodingFlags = Int((header >> 6) & 0b11) - var size = Int((header >> 2) & 0b1111) - if size == fieldNameSizeThreshold { - size += Int(try buffer.readVarUInt32()) - } - size += 1 - - let nullable = (header & 0b10) != 0 - let trackRef = (header & 0b1) != 0 - let fieldType = try FieldType.read( - buffer, - readFlags: false, - nullable: nullable, - trackRef: trackRef - ) - - if encodingFlags == 3 { - let fieldID = Int16(size - 1) - return FieldInfo( - fieldID: fieldID, - fieldName: "$tag\(fieldID)", - fieldType: fieldType - ) - } - - guard encodingFlags < fieldNameMetaStringEncodings.count else { - throw ForyError.invalidData("invalid field name encoding id") - } - let nameBytes = try buffer.readBytes(count: size) - let name = try MetaStringDecoder.fieldName - .decode(bytes: nameBytes, encoding: fieldNameMetaStringEncodings[encodingFlags]) - .value - - return FieldInfo(fieldID: nil, fieldName: name, fieldType: fieldType) - } + public init(fieldID: Int16?, fieldName: String, fieldType: FieldType) { + self.fieldID = fieldID + self.fieldName = fieldName + self.fieldType = fieldType } - public let typeID: UInt32? - public let userTypeID: UInt32? 
- public let namespace: MetaString - public let typeName: MetaString - public let registerByName: Bool - public let fields: [FieldInfo] - public let hasFieldsMeta: Bool - public let compressed: Bool - public let headerHash: UInt64 - - public init( - typeID: UInt32?, - userTypeID: UInt32?, - namespace: MetaString, - typeName: MetaString, - registerByName: Bool, - fields: [FieldInfo], - hasFieldsMeta: Bool = true, - compressed: Bool = false, - headerHash: UInt64 = 0 - ) throws { - if registerByName { - if typeName.value.isEmpty { - throw ForyError.encodingError("type name is required in register-by-name mode") - } + fileprivate func write(_ buffer: ByteBuffer) throws { + var header: UInt8 = 0 + if fieldType.trackRef { + header |= 0b1 + } + if fieldType.nullable { + header |= 0b10 + } + + if let fieldID { + if fieldID < 0 { + throw ForyError.encodingError("negative field id is invalid") + } + let size = Int(fieldID) + header |= UInt8(0b11 << 6) + if size >= fieldNameSizeThreshold { + header |= 0b0011_1100 + buffer.writeUInt8(header) + buffer.writeVarUInt32(UInt32(size - fieldNameSizeThreshold)) } else { - guard typeID != nil else { - throw ForyError.encodingError("type id is required in register-by-id mode") - } - guard let userTypeID, userTypeID != noUserTypeID else { - throw ForyError.encodingError("user type id is required in register-by-id mode") - } - } - - self.typeID = typeID - self.userTypeID = userTypeID - self.namespace = namespace - self.typeName = typeName - self.registerByName = registerByName - self.fields = fields - self.hasFieldsMeta = hasFieldsMeta - self.compressed = compressed - self.headerHash = headerHash - } - - public static func == (lhs: TypeMeta, rhs: TypeMeta) -> Bool { - lhs.typeID == rhs.typeID && - lhs.userTypeID == rhs.userTypeID && - lhs.namespace == rhs.namespace && - lhs.typeName == rhs.typeName && - lhs.registerByName == rhs.registerByName && - lhs.fields == rhs.fields && - lhs.hasFieldsMeta == rhs.hasFieldsMeta && - lhs.compressed == 
rhs.compressed && - lhs.headerHash == rhs.headerHash - } - - public func encode() throws -> [UInt8] { - if compressed { - throw ForyError.encodingError("compressed TypeMeta is not supported yet") - } + header |= UInt8(size << 2) + buffer.writeUInt8(header) + } + fieldType.write(buffer, writeFlags: false) + return + } + + let snakeName = lowerCamelToLowerUnderscore(fieldName) + let encoded = try MetaStringEncoder.fieldName.encode( + snakeName, allowedEncodings: fieldNameMetaStringEncodings) + guard let encodingIndex = fieldNameMetaStringEncodings.firstIndex(of: encoded.encoding) else { + throw ForyError.encodingError("unsupported field name encoding") + } + + let size = encoded.bytes.count - 1 + header |= UInt8(encodingIndex << 6) + if size >= fieldNameSizeThreshold { + header |= 0b0011_1100 + buffer.writeUInt8(header) + buffer.writeVarUInt32(UInt32(size - fieldNameSizeThreshold)) + } else { + header |= UInt8(size << 2) + buffer.writeUInt8(header) + } + + fieldType.write(buffer, writeFlags: false) + buffer.writeBytes(encoded.bytes) + } - let body = try encodeBody() - let bodyHash = MurmurHash3.x64_128(body, seed: typeMetaHashSeed).0 - let shifted = bodyHash << (64 - typeMetaNumHashBits) - let signed = Int64(bitPattern: shifted) - let absSigned = signed == Int64.min ? 
signed : Swift.abs(signed) + fileprivate static func read(_ buffer: ByteBuffer) throws -> FieldInfo { + let header = try buffer.readUInt8() + let encodingFlags = Int((header >> 6) & 0b11) + var size = Int((header >> 2) & 0b1111) + if size == fieldNameSizeThreshold { + size += Int(try buffer.readVarUInt32()) + } + size += 1 + + let nullable = (header & 0b10) != 0 + let trackRef = (header & 0b1) != 0 + let fieldType = try FieldType.read( + buffer, + readFlags: false, + nullable: nullable, + trackRef: trackRef + ) + + if encodingFlags == 3 { + let fieldID = Int16(size - 1) + return FieldInfo( + fieldID: fieldID, + fieldName: "$tag\(fieldID)", + fieldType: fieldType + ) + } - var header = UInt64(bitPattern: absSigned) - if hasFieldsMeta { - header |= typeMetaHasFieldsMetaFlag - } - if compressed { - header |= typeMetaCompressedFlag - } - header |= UInt64(min(body.count, Int(typeMetaSizeMask))) + guard encodingFlags < fieldNameMetaStringEncodings.count else { + throw ForyError.invalidData("invalid field name encoding id") + } + let nameBytes = try buffer.readBytes(count: size) + let name = try MetaStringDecoder.fieldName + .decode(bytes: nameBytes, encoding: fieldNameMetaStringEncodings[encodingFlags]) + .value - let buffer = ByteBuffer(capacity: body.count + 16) - buffer.writeUInt64(header) - if body.count >= Int(typeMetaSizeMask) { - buffer.writeVarUInt32(UInt32(body.count - Int(typeMetaSizeMask))) - } - buffer.writeBytes(body) - return Array(buffer.storage.prefix(buffer.count)) + return FieldInfo(fieldID: nil, fieldName: name, fieldType: fieldType) } - - public static func decode(_ bytes: [UInt8]) throws -> TypeMeta { - try decode(ByteBuffer(bytes: bytes)) + } + + public let typeID: UInt32? + public let userTypeID: UInt32? 
+ public let namespace: MetaString + public let typeName: MetaString + public let registerByName: Bool + public let fields: [FieldInfo] + public let compressed: Bool + public let headerHash: UInt64 + + public init( + typeID: UInt32?, + userTypeID: UInt32?, + namespace: MetaString, + typeName: MetaString, + registerByName: Bool, + fields: [FieldInfo], + compressed: Bool = false, + headerHash: UInt64 = 0 + ) throws { + guard typeID != nil else { + throw ForyError.encodingError("type id is required in type metadata") + } + if registerByName { + if typeName.value.isEmpty { + throw ForyError.encodingError("type name is required in register-by-name mode") + } + } else { + guard let userTypeID, userTypeID != noUserTypeID else { + throw ForyError.encodingError("user type id is required in register-by-id mode") + } } - public static func decode(_ buffer: ByteBuffer) throws -> TypeMeta { - let header = try buffer.readUInt64() - let compressed = (header & typeMetaCompressedFlag) != 0 - let hasFieldsMeta = (header & typeMetaHasFieldsMetaFlag) != 0 - - var metaSize = Int(header & typeMetaSizeMask) - if metaSize == Int(typeMetaSizeMask) { - metaSize += Int(try buffer.readVarUInt32()) - } - - let encodedBody = try buffer.readBytes(count: metaSize) - if compressed { - throw ForyError.encodingError("compressed TypeMeta is not supported yet") - } - - let bodyReader = ByteBuffer(bytes: encodedBody) - let metaHeader = try bodyReader.readUInt8() - - var numFields = Int(metaHeader & UInt8(smallNumFieldsThreshold)) - if numFields == smallNumFieldsThreshold { - numFields += Int(try bodyReader.readVarUInt32()) - } - - let registerByName = (metaHeader & registerByNameFlag) != 0 + self.typeID = typeID + self.userTypeID = userTypeID + self.namespace = namespace + self.typeName = typeName + self.registerByName = registerByName + self.fields = fields + self.compressed = compressed + self.headerHash = headerHash + } + + public static func == (lhs: TypeMeta, rhs: TypeMeta) -> Bool { + lhs.typeID 
== rhs.typeID && lhs.userTypeID == rhs.userTypeID && lhs.namespace == rhs.namespace + && lhs.typeName == rhs.typeName && lhs.registerByName == rhs.registerByName + && lhs.fields == rhs.fields && lhs.compressed == rhs.compressed + && lhs.headerHash == rhs.headerHash + } + + public func encode() throws -> [UInt8] { + if compressed { + throw ForyError.encodingError("compressed TypeMeta is not supported yet") + } - let typeID: UInt32? - let userTypeID: UInt32? - let namespace: MetaString - let typeName: MetaString + let body = try encodeBody() + var header = Self.typeMetaHeaderHash(body) + if compressed { + header |= typeMetaCompressedFlag + } + header |= UInt64(min(body.count, Int(typeMetaSizeMask))) - if registerByName { - namespace = try readName(bodyReader, decoder: .namespace, encodings: namespaceMetaStringEncodings) - typeName = try readName(bodyReader, decoder: .typeName, encodings: typeNameMetaStringEncodings) - typeID = nil - userTypeID = nil - } else { - let rawTypeID = try bodyReader.readUInt8() - typeID = UInt32(rawTypeID) - userTypeID = try bodyReader.readVarUInt32() - namespace = MetaString.empty(specialChar1: ".", specialChar2: "_") - typeName = MetaString.empty(specialChar1: "$", specialChar2: "_") - } + let buffer = ByteBuffer(capacity: body.count + 16) + buffer.writeUInt64(header) + if body.count >= Int(typeMetaSizeMask) { + buffer.writeVarUInt32(UInt32(body.count - Int(typeMetaSizeMask))) + } + buffer.writeBytes(body) + return Array(buffer.storage.prefix(buffer.count)) + } + + public static func decode(_ bytes: [UInt8]) throws -> TypeMeta { + try decode(ByteBuffer(bytes: bytes)) + } + + public static func decode(_ buffer: ByteBuffer) throws -> TypeMeta { + let header = try buffer.readUInt64() + if (header & typeMetaReservedFlags) != 0 { + throw ForyError.invalidData("invalid TypeMeta global header") + } + let compressed = (header & typeMetaCompressedFlag) != 0 - var fieldInfos: [FieldInfo] = [] - if numFields > bodyReader.remaining { - throw 
ForyError.invalidData( - "type meta field count \(numFields) exceeds remaining bytes \(bodyReader.remaining)" - ) - } - fieldInfos.reserveCapacity(numFields) - for _ in 0..> (64 - typeMetaNumHashBits) - ) + let bodyReader = ByteBuffer(bytes: encodedBody) + let metaHeader = try bodyReader.readUInt8() + + let isStruct = (metaHeader & structTypeMetaFlag) != 0 + var numFields = 0 + let registerByName: Bool + let typeID: UInt32 + let userTypeID: UInt32? + let namespace: MetaString + let typeName: MetaString + + if isStruct { + registerByName = (metaHeader & registerByNameFlag) != 0 + let compatible = (metaHeader & compatibleTypeMetaFlag) != 0 + numFields = Int(metaHeader & UInt8(smallNumFieldsThreshold)) + if numFields == smallNumFieldsThreshold { + numFields += Int(try bodyReader.readVarUInt32()) + } + if registerByName { + typeID = compatible ? TypeId.namedCompatibleStruct.rawValue : TypeId.namedStruct.rawValue + } else { + typeID = compatible ? TypeId.compatibleStruct.rawValue : TypeId.structType.rawValue + } + } else { + if (metaHeader & 0b0111_0000) != 0 { + throw ForyError.invalidData("invalid non-struct TypeMeta kind header") + } + let kind = try Self.typeID(forNonStructKindCode: metaHeader & 0b1111) + registerByName = Self.isNamedTypeMetaKind(kind) + typeID = kind.rawValue } - private func encodeBody() throws -> [UInt8] { - let buffer = ByteBuffer(capacity: 128) + if registerByName { + namespace = try readName( + bodyReader, decoder: .namespace, encodings: namespaceMetaStringEncodings) + typeName = try readName( + bodyReader, decoder: .typeName, encodings: typeNameMetaStringEncodings) + userTypeID = nil + } else { + userTypeID = try bodyReader.readVarUInt32() + namespace = MetaString.empty(specialChar1: ".", specialChar2: "_") + typeName = MetaString.empty(specialChar1: "$", specialChar2: "_") + } - var metaHeader = UInt8(min(fields.count, smallNumFieldsThreshold)) - if registerByName { - metaHeader |= registerByNameFlag - } - buffer.writeUInt8(metaHeader) + var 
fieldInfos: [FieldInfo] = [] + if numFields > bodyReader.remaining { + throw ForyError.invalidData( + "type meta field count \(numFields) exceeds remaining bytes \(bodyReader.remaining)" + ) + } + fieldInfos.reserveCapacity(numFields) + for _ in 0..= smallNumFieldsThreshold { - buffer.writeVarUInt32(UInt32(fields.count - smallNumFieldsThreshold)) - } + if !isStruct && !fieldInfos.isEmpty { + throw ForyError.invalidData("non-struct TypeMeta cannot carry field metadata") + } + if bodyReader.remaining != 0 { + throw ForyError.invalidData("unexpected trailing bytes in TypeMeta body") + } + if (header & Self.hashMask()) != Self.typeMetaHeaderHash(encodedBody) { + throw ForyError.invalidData("invalid TypeMeta metadata hash") + } - if registerByName { - try Self.writeName(buffer, name: namespace, encodings: namespaceMetaStringEncodings) - try Self.writeName(buffer, name: typeName, encodings: typeNameMetaStringEncodings) - } else { - guard let typeID else { - throw ForyError.encodingError("type id is required in register-by-id mode") - } - guard let userTypeID, userTypeID != noUserTypeID else { - throw ForyError.encodingError("user type id is required in register-by-id mode") - } - buffer.writeUInt8(UInt8(truncatingIfNeeded: typeID)) - buffer.writeVarUInt32(userTypeID) - } + return try TypeMeta( + typeID: typeID, + userTypeID: userTypeID, + namespace: namespace, + typeName: typeName, + registerByName: registerByName, + fields: fieldInfos, + compressed: compressed, + headerHash: header >> (64 - typeMetaNumHashBits) + ) + } + + private func encodeBody() throws -> [UInt8] { + let buffer = ByteBuffer(capacity: 128) + guard let rawTypeID = typeID, let resolvedTypeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.encodingError("unsupported TypeMeta kind \(String(describing: typeID))") + } - for field in fields { - try field.write(buffer) - } + if Self.isStructTypeMetaKind(resolvedTypeID) { + var metaHeader = structTypeMetaFlag | UInt8(min(fields.count, 
smallNumFieldsThreshold)) + if resolvedTypeID == .compatibleStruct || resolvedTypeID == .namedCompatibleStruct { + metaHeader |= compatibleTypeMetaFlag + } + if registerByName { + metaHeader |= registerByNameFlag + } + buffer.writeUInt8(metaHeader) + + if fields.count >= smallNumFieldsThreshold { + buffer.writeVarUInt32(UInt32(fields.count - smallNumFieldsThreshold)) + } + } else { + if !fields.isEmpty { + throw ForyError.encodingError("non-struct TypeMeta cannot carry field metadata") + } + buffer.writeUInt8(try Self.nonStructKindCode(resolvedTypeID)) + } - return Array(buffer.storage.prefix(buffer.count)) + if registerByName { + try Self.writeName(buffer, name: namespace, encodings: namespaceMetaStringEncodings) + try Self.writeName(buffer, name: typeName, encodings: typeNameMetaStringEncodings) + } else { + guard let userTypeID, userTypeID != noUserTypeID else { + throw ForyError.encodingError("user type id is required in register-by-id mode") + } + buffer.writeVarUInt32(userTypeID) } - private static func writeName( - _ buffer: ByteBuffer, - name: MetaString, - encodings: [MetaStringEncoding] - ) throws { - let normalizedName: MetaString - if encodings.contains(name.encoding) { - normalizedName = name - } else { - let encoder: MetaStringEncoder - if encodings == namespaceMetaStringEncodings { - encoder = .namespace - } else if encodings == typeNameMetaStringEncodings { - encoder = .typeName - } else { - encoder = .fieldName - } - normalizedName = try encoder.encode(name.value, allowedEncodings: encodings) - } + for field in fields { + try field.write(buffer) + } - guard let encodingIndex = encodings.firstIndex(of: normalizedName.encoding) else { - throw ForyError.encodingError("failed to normalize meta string encoding") - } + return Array(buffer.storage.prefix(buffer.count)) + } + + private static func hashMask() -> UInt64 { + UInt64.max << (64 - typeMetaNumHashBits) + } + + private static func typeMetaHeaderHash(_ body: [UInt8]) -> UInt64 { + let bodyHash = 
MurmurHash3.x64_128(body, seed: typeMetaHashSeed).0 + let shifted = bodyHash << (64 - typeMetaNumHashBits) + let signed = Int64(bitPattern: shifted) + let absSigned = signed == Int64.min ? signed : Swift.abs(signed) + return UInt64(bitPattern: absSigned) & hashMask() + } + + private static func isStructTypeMetaKind(_ typeID: TypeId) -> Bool { + switch typeID { + case .structType, .compatibleStruct, .namedStruct, .namedCompatibleStruct: + return true + default: + return false + } + } + + private static func isNamedTypeMetaKind(_ typeID: TypeId) -> Bool { + switch typeID { + case .namedStruct, .namedCompatibleStruct, .namedEnum, .namedExt, .namedUnion: + return true + default: + return false + } + } + + private static func nonStructKindCode(_ typeID: TypeId) throws -> UInt8 { + switch typeID { + case .enumType: + return 0 + case .namedEnum: + return 1 + case .ext: + return 2 + case .namedExt: + return 3 + case .typedUnion: + return 4 + case .namedUnion: + return 5 + default: + throw ForyError.encodingError("unsupported TypeMeta kind \(typeID)") + } + } + + private static func typeID(forNonStructKindCode code: UInt8) throws -> TypeId { + switch code { + case 0: + return .enumType + case 1: + return .namedEnum + case 2: + return .ext + case 3: + return .namedExt + case 4: + return .typedUnion + case 5: + return .namedUnion + default: + throw ForyError.invalidData("unsupported TypeMeta kind code \(code)") + } + } + + private static func writeName( + _ buffer: ByteBuffer, + name: MetaString, + encodings: [MetaStringEncoding] + ) throws { + let normalizedName: MetaString + if encodings.contains(name.encoding) { + normalizedName = name + } else { + let encoder: MetaStringEncoder + if encodings == namespaceMetaStringEncodings { + encoder = .namespace + } else if encodings == typeNameMetaStringEncodings { + encoder = .typeName + } else { + encoder = .fieldName + } + normalizedName = try encoder.encode(name.value, allowedEncodings: encodings) + } - let bytes = 
normalizedName.bytes - if bytes.count >= bigNameThreshold { - buffer.writeUInt8(UInt8((bigNameThreshold << 2) | encodingIndex)) - buffer.writeVarUInt32(UInt32(bytes.count - bigNameThreshold)) - } else { - buffer.writeUInt8(UInt8((bytes.count << 2) | encodingIndex)) - } - buffer.writeBytes(bytes) - } - - private static func readName( - _ buffer: ByteBuffer, - decoder: MetaStringDecoder, - encodings: [MetaStringEncoding] - ) throws -> MetaString { - let header = try buffer.readUInt8() - let encodingIndex = Int(header & 0b11) - guard encodingIndex < encodings.count else { - throw ForyError.invalidData("invalid meta string encoding index") - } + guard let encodingIndex = encodings.firstIndex(of: normalizedName.encoding) else { + throw ForyError.encodingError("failed to normalize meta string encoding") + } - var length = Int(header >> 2) - if length >= bigNameThreshold { - length = bigNameThreshold + Int(try buffer.readVarUInt32()) - } - let bytes = try buffer.readBytes(count: length) - return try decoder.decode(bytes: bytes, encoding: encodings[encodingIndex]) + let bytes = normalizedName.bytes + if bytes.count >= bigNameThreshold { + buffer.writeUInt8(UInt8((bigNameThreshold << 2) | encodingIndex)) + buffer.writeVarUInt32(UInt32(bytes.count - bigNameThreshold)) + } else { + buffer.writeUInt8(UInt8((bytes.count << 2) | encodingIndex)) + } + buffer.writeBytes(bytes) + } + + private static func readName( + _ buffer: ByteBuffer, + decoder: MetaStringDecoder, + encodings: [MetaStringEncoding] + ) throws -> MetaString { + let header = try buffer.readUInt8() + let encodingIndex = Int(header & 0b11) + guard encodingIndex < encodings.count else { + throw ForyError.invalidData("invalid meta string encoding index") } - func assigningFieldIDs(from localTypeMeta: TypeMeta) throws -> TypeMeta { - guard !fields.isEmpty else { - return self - } + var length = Int(header >> 2) + if length >= bigNameThreshold { + length = bigNameThreshold + Int(try buffer.readVarUInt32()) + } + let 
bytes = try buffer.readBytes(count: length) + return try decoder.decode(bytes: bytes, encoding: encodings[encodingIndex]) + } - let localFields = localTypeMeta.fields - guard !localFields.isEmpty else { - return self - } + func assigningFieldIDs(from localTypeMeta: TypeMeta) throws -> TypeMeta { + guard !fields.isEmpty else { + return self + } - var fieldIndexByName: [String: (Int, FieldInfo)] = [:] - var fieldIndexByID: [Int16: (Int, FieldInfo)] = [:] - fieldIndexByName.reserveCapacity(localFields.count) - fieldIndexByID.reserveCapacity(localFields.count) + let localFields = localTypeMeta.fields + guard !localFields.isEmpty else { + return self + } - for (index, localField) in localFields.enumerated() { - fieldIndexByName[toSnakeCase(localField.fieldName)] = (index, localField) - if let fieldID = localField.fieldID, fieldID >= 0 { - fieldIndexByID[fieldID] = (index, localField) - } - } + var fieldIndexByName: [String: (Int, FieldInfo)] = [:] + var fieldIndexByID: [Int16: (Int, FieldInfo)] = [:] + fieldIndexByName.reserveCapacity(localFields.count) + fieldIndexByID.reserveCapacity(localFields.count) - var resolvedFields = fields - var changed = false - var usedLocalFields = Array(repeating: false, count: localFields.count) - - for index in resolvedFields.indices { - let field = resolvedFields[index] - - var localMatch: (Int, FieldInfo)? 
- if let fieldID = field.fieldID, fieldID >= 0 { - if let candidate = fieldIndexByID[fieldID], - Self.isCompatibleFieldType(field.fieldType, candidate.1.fieldType) { - localMatch = candidate - } - } - - if localMatch == nil { - if let candidate = fieldIndexByName[toSnakeCase(field.fieldName)], - Self.isCompatibleFieldType(field.fieldType, candidate.1.fieldType) { - localMatch = candidate - } - } - - if localMatch == nil { - for localIndex in localFields.indices where !usedLocalFields[localIndex] { - if Self.isCompatibleFieldType(field.fieldType, localFields[localIndex].fieldType) { - localMatch = (localIndex, localFields[localIndex]) - break - } - } - } - - guard let (sortedIndex, _) = localMatch, - sortedIndex <= Int(Int16.max) else { - if field.fieldID != -1 { - resolvedFields[index].fieldID = -1 - changed = true - } - continue - } - - let resolvedFieldID = Int16(sortedIndex) - if field.fieldID != resolvedFieldID { - resolvedFields[index].fieldID = resolvedFieldID - changed = true - } - usedLocalFields[sortedIndex] = true - } + for (index, localField) in localFields.enumerated() { + fieldIndexByName[toSnakeCase(localField.fieldName)] = (index, localField) + if let fieldID = localField.fieldID, fieldID >= 0 { + fieldIndexByID[fieldID] = (index, localField) + } + } - guard changed else { - return self - } + var resolvedFields = fields + var changed = false + var usedLocalFields = Array(repeating: false, count: localFields.count) + + for index in resolvedFields.indices { + let field = resolvedFields[index] + + var localMatch: (Int, FieldInfo)? 
+ if let fieldID = field.fieldID, fieldID >= 0 { + if let candidate = fieldIndexByID[fieldID], + Self.isCompatibleFieldType(field.fieldType, candidate.1.fieldType) { + localMatch = candidate + } + } + + if localMatch == nil { + if let candidate = fieldIndexByName[toSnakeCase(field.fieldName)], + Self.isCompatibleFieldType(field.fieldType, candidate.1.fieldType) { + localMatch = candidate + } + } + + if localMatch == nil { + for localIndex in localFields.indices where !usedLocalFields[localIndex] { + if Self.isCompatibleFieldType(field.fieldType, localFields[localIndex].fieldType) { + localMatch = (localIndex, localFields[localIndex]) + break + } + } + } + + guard let (sortedIndex, _) = localMatch, + sortedIndex <= Int(Int16.max) + else { + if field.fieldID != -1 { + resolvedFields[index].fieldID = -1 + changed = true + } + continue + } + + let resolvedFieldID = Int16(sortedIndex) + if field.fieldID != resolvedFieldID { + resolvedFields[index].fieldID = resolvedFieldID + changed = true + } + usedLocalFields[sortedIndex] = true + } - return try TypeMeta( - typeID: typeID, - userTypeID: userTypeID, - namespace: namespace, - typeName: typeName, - registerByName: registerByName, - fields: resolvedFields, - hasFieldsMeta: hasFieldsMeta, - compressed: compressed, - headerHash: headerHash - ) + guard changed else { + return self } - private static func isCompatibleFieldType( - _ remoteType: FieldType, - _ localType: FieldType - ) -> Bool { - if normalizeCompatibleTypeIDForComparison(remoteType.typeID) != normalizeCompatibleTypeIDForComparison(localType.typeID) { - return false - } - if remoteType.generics.count != localType.generics.count { - return false - } - for (remoteGeneric, localGeneric) in zip(remoteType.generics, localType.generics) - where !isCompatibleFieldType(remoteGeneric, localGeneric) { - return false - } - return true - } - - private static func normalizeCompatibleTypeIDForComparison(_ typeID: UInt32) -> UInt32 { - switch typeID { - case 
TypeId.structType.rawValue, - TypeId.compatibleStruct.rawValue, - TypeId.namedStruct.rawValue, - TypeId.namedCompatibleStruct.rawValue, - TypeId.unknown.rawValue: - return TypeId.structType.rawValue - case TypeId.enumType.rawValue, - TypeId.namedEnum.rawValue: - return TypeId.enumType.rawValue - case TypeId.ext.rawValue, - TypeId.namedExt.rawValue: - return TypeId.ext.rawValue - case TypeId.binary.rawValue, - TypeId.int8Array.rawValue, - TypeId.uint8Array.rawValue: - return TypeId.binary.rawValue - case TypeId.union.rawValue, - TypeId.typedUnion.rawValue, - TypeId.namedUnion.rawValue: - return TypeId.union.rawValue - default: - return typeID - } + return try TypeMeta( + typeID: typeID, + userTypeID: userTypeID, + namespace: namespace, + typeName: typeName, + registerByName: registerByName, + fields: resolvedFields, + compressed: compressed, + headerHash: headerHash + ) + } + + private static func isCompatibleFieldType( + _ remoteType: FieldType, + _ localType: FieldType + ) -> Bool { + if normalizeCompatibleTypeIDForComparison(remoteType.typeID) + != normalizeCompatibleTypeIDForComparison(localType.typeID) { + return false + } + if remoteType.generics.count != localType.generics.count { + return false + } + for (remoteGeneric, localGeneric) in zip(remoteType.generics, localType.generics) + where !isCompatibleFieldType(remoteGeneric, localGeneric) { + return false } + return true + } + + private static func normalizeCompatibleTypeIDForComparison(_ typeID: UInt32) -> UInt32 { + switch typeID { + case TypeId.structType.rawValue, + TypeId.compatibleStruct.rawValue, + TypeId.namedStruct.rawValue, + TypeId.namedCompatibleStruct.rawValue, + TypeId.unknown.rawValue: + return TypeId.structType.rawValue + case TypeId.enumType.rawValue, + TypeId.namedEnum.rawValue: + return TypeId.enumType.rawValue + case TypeId.ext.rawValue, + TypeId.namedExt.rawValue: + return TypeId.ext.rawValue + case TypeId.binary.rawValue, + TypeId.int8Array.rawValue, + TypeId.uint8Array.rawValue: + 
return TypeId.binary.rawValue + case TypeId.union.rawValue, + TypeId.typedUnion.rawValue, + TypeId.namedUnion.rawValue: + return TypeId.union.rawValue + default: + return typeID + } + } } private func lowerCamelToLowerUnderscore(_ name: String) -> String { - if name.isEmpty { - return name - } - - let chars = Array(name) - var result = String() - result.reserveCapacity(name.count + 4) - - for (index, char) in chars.enumerated() { - if char.isUppercase { - if index > 0 { - let prevUpper = chars[index - 1].isUppercase - let nextUpperOrEnd = (index + 1 >= chars.count) || chars[index + 1].isUppercase - if !prevUpper || !nextUpperOrEnd { - result.append("_") - } - } - result.append(char.lowercased()) - } else { - result.append(char) - } + if name.isEmpty { + return name + } + + let chars = Array(name) + var result = String() + result.reserveCapacity(name.count + 4) + + for (index, char) in chars.enumerated() { + if char.isUppercase { + if index > 0 { + let prevUpper = chars[index - 1].isUppercase + let nextUpperOrEnd = (index + 1 >= chars.count) || chars[index + 1].isUppercase + if !prevUpper || !nextUpperOrEnd { + result.append("_") + } + } + result.append(char.lowercased()) + } else { + result.append(char) } + } - return result + return result } private func toSnakeCase(_ name: String) -> String { - if name.isEmpty { - return name - } - - let chars = Array(name) - var result = String() - result.reserveCapacity(name.count + 4) - - for (index, char) in chars.enumerated() { - if char.isUppercase { - if index > 0 { - let prevUpper = chars[index - 1].isUppercase - let nextUpperOrEnd = (index + 1 >= chars.count) || chars[index + 1].isUppercase - if !prevUpper || !nextUpperOrEnd { - result.append("_") - } - } - result.append(char.lowercased()) - } else { - result.append(char) - } + if name.isEmpty { + return name + } + + let chars = Array(name) + var result = String() + result.reserveCapacity(name.count + 4) + + for (index, char) in chars.enumerated() { + if char.isUppercase 
{ + if index > 0 { + let prevUpper = chars[index - 1].isUppercase + let nextUpperOrEnd = (index + 1 >= chars.count) || chars[index + 1].isUppercase + if !prevUpper || !nextUpperOrEnd { + result.append("_") + } + } + result.append(char.lowercased()) + } else { + result.append(char) } + } - return result + return result } diff --git a/swift/Sources/Fory/TypeResolver.swift b/swift/Sources/Fory/TypeResolver.swift index 1483130e33..f1fd683174 100644 --- a/swift/Sources/Fory/TypeResolver.swift +++ b/swift/Sources/Fory/TypeResolver.swift @@ -19,691 +19,692 @@ import Foundation @inline(__always) func normalizeRegisteredTypeID(_ typeID: TypeId) -> TypeId { - switch typeID { - case .namedEnum: - return .enumType - case .compatibleStruct, .namedCompatibleStruct, .namedStruct: - return .structType - case .namedExt: - return .ext - case .namedUnion, .union: - return .typedUnion - default: - return typeID - } + switch typeID { + case .namedEnum: + return .enumType + case .compatibleStruct, .namedCompatibleStruct, .namedStruct: + return .structType + case .namedExt: + return .ext + case .namedUnion, .union: + return .typedUnion + default: + return typeID + } } @inline(__always) func namedRegisteredTypeID(for baseTypeID: TypeId, compatible: Bool, evolving: Bool) -> TypeId { - switch baseTypeID { - case .structType: - return compatible && evolving ? .namedCompatibleStruct : .namedStruct - case .enumType: - return .namedEnum - case .ext: - return .namedExt - case .typedUnion: - return .namedUnion - default: - return baseTypeID - } + switch baseTypeID { + case .structType: + return compatible && evolving ? .namedCompatibleStruct : .namedStruct + case .enumType: + return .namedEnum + case .ext: + return .namedExt + case .typedUnion: + return .namedUnion + default: + return baseTypeID + } } @inline(__always) func idRegisteredTypeID(for baseTypeID: TypeId, compatible: Bool, evolving: Bool) -> TypeId { - switch baseTypeID { - case .structType: - return compatible && evolving ? 
.compatibleStruct : .structType - default: - return baseTypeID - } + switch baseTypeID { + case .structType: + return compatible && evolving ? .compatibleStruct : .structType + default: + return baseTypeID + } } @inline(__always) func resolveRegisteredWireTypeID( - declaredTypeID: TypeId, - registerByName: Bool, - compatible: Bool, - evolving: Bool = true + declaredTypeID: TypeId, + registerByName: Bool, + compatible: Bool, + evolving: Bool = true ) -> TypeId { - let baseTypeID = normalizeRegisteredTypeID(declaredTypeID) - if registerByName { - return namedRegisteredTypeID(for: baseTypeID, compatible: compatible, evolving: evolving) - } - return idRegisteredTypeID(for: baseTypeID, compatible: compatible, evolving: evolving) + let baseTypeID = normalizeRegisteredTypeID(declaredTypeID) + if registerByName { + return namedRegisteredTypeID(for: baseTypeID, compatible: compatible, evolving: evolving) + } + return idRegisteredTypeID(for: baseTypeID, compatible: compatible, evolving: evolving) } @inline(__always) func isAllowedRegisteredWireTypeID( - _ wireTypeID: TypeId, - declaredTypeID: TypeId, - registerByName: Bool, - compatible: Bool, - evolving: Bool = true + _ wireTypeID: TypeId, + declaredTypeID: TypeId, + registerByName: Bool, + compatible: Bool, + evolving: Bool = true ) -> Bool { - let baseTypeID = normalizeRegisteredTypeID(declaredTypeID) - let expected = resolveRegisteredWireTypeID( - declaredTypeID: declaredTypeID, - registerByName: registerByName, - compatible: compatible, - evolving: evolving - ) - if wireTypeID == expected { - return true - } - if baseTypeID == .structType, compatible { - return wireTypeID == .compatibleStruct || - wireTypeID == .namedCompatibleStruct || - wireTypeID == .structType || - wireTypeID == .namedStruct - } - if baseTypeID == .typedUnion { - return wireTypeID == .union || (registerByName && wireTypeID == .namedUnion) - } - return false + let baseTypeID = normalizeRegisteredTypeID(declaredTypeID) + let expected = 
resolveRegisteredWireTypeID( + declaredTypeID: declaredTypeID, + registerByName: registerByName, + compatible: compatible, + evolving: evolving + ) + if wireTypeID == expected { + return true + } + if baseTypeID == .structType, compatible { + return wireTypeID == .compatibleStruct || wireTypeID == .namedCompatibleStruct + || wireTypeID == .structType || wireTypeID == .namedStruct + } + if baseTypeID == .typedUnion { + return wireTypeID == .union || (registerByName && wireTypeID == .namedUnion) + } + return false } @inline(__always) func registeredWireTypeNeedsUserTypeID(_ wireTypeID: TypeId) -> Bool { - switch wireTypeID { - case .enumType, .structType, .ext, .typedUnion, .union: - return true - default: - return false - } + switch wireTypeID { + case .enumType, .structType, .ext, .typedUnion, .union: + return true + default: + return false + } } @inline(__always) private func encodedTypeDefHeader(_ bytes: [UInt8]) throws -> UInt64 { - guard bytes.count >= 8 else { - throw ForyError.invalidData("encoded compatible type metadata must include an 8-byte header") - } - let buffer = ByteBuffer(bytes: bytes) - return try buffer.readUInt64() + guard bytes.count >= 8 else { + throw ForyError.invalidData("encoded compatible type metadata must include an 8-byte header") + } + let buffer = ByteBuffer(bytes: bytes) + return try buffer.readUInt64() } @inline(__always) private func encodedTypeDefHeaderHash(_ bytes: [UInt8]) throws -> UInt64 { - guard bytes.count >= 8 else { - throw ForyError.invalidData("encoded compatible type metadata must include an 8-byte header") - } - let buffer = ByteBuffer(bytes: bytes) - let header = try buffer.readUInt64() - return header >> 14 + guard bytes.count >= 8 else { + throw ForyError.invalidData("encoded compatible type metadata must include an 8-byte header") + } + let buffer = ByteBuffer(bytes: bytes) + let header = try buffer.readUInt64() + return header >> 12 } private func fieldNeedsTypeInfo(_ fieldType: TypeMeta.FieldType) -> Bool { - 
if let typeID = TypeId(rawValue: fieldType.typeID), - TypeId.needsTypeInfoForField(typeID) { - return true - } - return fieldType.generics.contains { fieldNeedsTypeInfo($0) } + if let typeID = TypeId(rawValue: fieldType.typeID), + TypeId.needsTypeInfoForField(typeID) { + return true + } + return fieldType.generics.contains { fieldNeedsTypeInfo($0) } } private func encodedTypeDefHasUserTypeFields(_ fields: [TypeMeta.FieldInfo]) -> Bool { - fields.contains { fieldNeedsTypeInfo($0.fieldType) } + fields.contains { fieldNeedsTypeInfo($0.fieldType) } } @inline(__always) private func readRegisteredValue(_ context: ReadContext, as type: T.Type) throws -> T { - try T.foryRead( - context, - refMode: T.isRefType ? .tracking : .none, - readTypeInfo: false - ) + try T.foryRead( + context, + refMode: T.isRefType ? .tracking : .none, + readTypeInfo: false + ) } @inline(__always) private func readCompatibleRegisteredValue( - _ context: ReadContext, - as type: T.Type, - remoteTypeInfo: TypeInfo + _ context: ReadContext, + as type: T.Type, + remoteTypeInfo: TypeInfo ) throws -> T { - guard T.isRefType else { - return try T.foryReadCompatibleData(context, remoteTypeInfo: remoteTypeInfo) - } - - let rawFlag = try context.buffer.readInt8() - guard let flag = RefFlag(rawValue: rawFlag) else { - throw ForyError.refError("invalid ref flag \(rawFlag)") - } - - switch flag { - case .null: - return T.foryDefault() - case .ref: - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: T.self) - case .refValue: - let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil - let value = try T.foryReadCompatibleData(context, remoteTypeInfo: remoteTypeInfo) - if let reservedRefID, let object = value as AnyObject? 
{ - context.refReader.storeRef(object, at: reservedRefID) - } - return value - case .notNullValue: - return try T.foryReadCompatibleData(context, remoteTypeInfo: remoteTypeInfo) - } + guard T.isRefType else { + return try T.foryReadCompatibleData(context, remoteTypeInfo: remoteTypeInfo) + } + + let rawFlag = try context.buffer.readInt8() + guard let flag = RefFlag(rawValue: rawFlag) else { + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + + switch flag { + case .null: + return T.foryDefault() + case .ref: + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: T.self) + case .refValue: + let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil + let value = try T.foryReadCompatibleData(context, remoteTypeInfo: remoteTypeInfo) + if let reservedRefID, let object = value as AnyObject? { + context.refReader.storeRef(object, at: reservedRefID) + } + return value + case .notNullValue: + return try T.foryReadCompatibleData(context, remoteTypeInfo: remoteTypeInfo) + } } public final class TypeInfo: @unchecked Sendable { - static let uncached = TypeInfo(typeID: .unknown) - - let swiftTypeID: ObjectIdentifier - let typeID: TypeId - let userTypeID: UInt32? - let registerByName: Bool - let evolving: Bool - let namespace: MetaString - let typeName: MetaString - let typeMeta: TypeMeta? - public let compatibleTypeMeta: TypeMeta? - let typeDefBytes: [UInt8]? - let firstTypeDefBytes: [UInt8]? - let typeDefHeader: UInt64? - public let typeDefHeaderHash: UInt64? - public let typeDefHasUserTypeFields: Bool - - private let reader: (ReadContext) throws -> Any - private let compatibleReader: (ReadContext, TypeInfo) throws -> Any - private let nativeWireTypeID: TypeId - private let compatibleWireTypeID: TypeId - - init( - swiftTypeID: ObjectIdentifier, - typeID: TypeId, - userTypeID: UInt32?, - registerByName: Bool, - evolving: Bool, - namespace: MetaString, - typeName: MetaString, - typeMeta: TypeMeta? 
= nil, - compatibleTypeMeta: TypeMeta? = nil, - typeDefBytes: [UInt8]? = nil, - firstTypeDefBytes: [UInt8]? = nil, - typeDefHeader: UInt64? = nil, - typeDefHeaderHash: UInt64? = nil, - typeDefHasUserTypeFields: Bool = true, - reader: @escaping (ReadContext) throws -> Any, - compatibleReader: @escaping (ReadContext, TypeInfo) throws -> Any - ) { - self.swiftTypeID = swiftTypeID - self.typeID = typeID - self.userTypeID = userTypeID - self.registerByName = registerByName - self.evolving = evolving - self.namespace = namespace - self.typeName = typeName - self.typeMeta = typeMeta - self.compatibleTypeMeta = compatibleTypeMeta ?? typeMeta - self.typeDefBytes = typeDefBytes - self.firstTypeDefBytes = firstTypeDefBytes - self.typeDefHeader = typeDefHeader - self.typeDefHeaderHash = typeDefHeaderHash - self.typeDefHasUserTypeFields = typeDefHasUserTypeFields - self.reader = reader - self.compatibleReader = compatibleReader - nativeWireTypeID = resolveRegisteredWireTypeID( - declaredTypeID: typeID, - registerByName: registerByName, - compatible: false, - evolving: evolving - ) - compatibleWireTypeID = resolveRegisteredWireTypeID( - declaredTypeID: typeID, - registerByName: registerByName, - compatible: true, - evolving: evolving - ) - } - - convenience init( - swiftTypeID: ObjectIdentifier, - typeID: TypeId, - userTypeID: UInt32?, - registerByName: Bool, - evolving: Bool, - namespace: MetaString, - typeName: MetaString, - fields: [TypeMeta.FieldInfo], - reader: @escaping (ReadContext) throws -> Any, - compatibleReader: @escaping (ReadContext, TypeInfo) throws -> Any - ) throws { - let compatibleWireTypeID = resolveRegisteredWireTypeID( - declaredTypeID: typeID, - registerByName: registerByName, - compatible: true, - evolving: evolving - ) - let typeMeta = try TypeMeta( - typeID: registerByName ? nil : compatibleWireTypeID.rawValue, - userTypeID: registerByName ? 
nil : userTypeID, - namespace: namespace, - typeName: typeName, - registerByName: registerByName, - fields: fields, - hasFieldsMeta: !fields.isEmpty - ) - let typeDefBytes = try typeMeta.encode() - var firstTypeDefBytes = [UInt8]() - firstTypeDefBytes.reserveCapacity(typeDefBytes.count + 1) - firstTypeDefBytes.append(0) - firstTypeDefBytes.append(contentsOf: typeDefBytes) - let typeDefHeader = try encodedTypeDefHeader(typeDefBytes) - let typeDefHeaderHash = try encodedTypeDefHeaderHash(typeDefBytes) - let canonicalTypeMeta = try TypeMeta( - typeID: registerByName ? nil : compatibleWireTypeID.rawValue, - userTypeID: registerByName ? nil : userTypeID, - namespace: namespace, - typeName: typeName, - registerByName: registerByName, - fields: fields, - hasFieldsMeta: !fields.isEmpty, - headerHash: typeDefHeaderHash - ) - self.init( - swiftTypeID: swiftTypeID, - typeID: typeID, - userTypeID: userTypeID, - registerByName: registerByName, - evolving: evolving, - namespace: namespace, - typeName: typeName, - typeMeta: canonicalTypeMeta, - compatibleTypeMeta: canonicalTypeMeta, - typeDefBytes: typeDefBytes, - firstTypeDefBytes: firstTypeDefBytes, - typeDefHeader: typeDefHeader, - typeDefHeaderHash: typeDefHeaderHash, - typeDefHasUserTypeFields: encodedTypeDefHasUserTypeFields(fields), - reader: reader, - compatibleReader: compatibleReader - ) - } - - convenience init(typeID: TypeId) { - self.init( - swiftTypeID: ObjectIdentifier(TypeInfo.self), - typeID: typeID, - userTypeID: nil, - registerByName: false, - evolving: true, - namespace: MetaString.empty(specialChar1: ".", specialChar2: "_"), - typeName: MetaString.empty(specialChar1: "$", specialChar2: "_"), - reader: { _ in - throw ForyError.invalidData("dynamic type \(typeID) uses runtime-only decode path") - }, - compatibleReader: { _, _ in - throw ForyError.invalidData("dynamic compatible type \(typeID) uses runtime-only decode path") - } - ) - } - - convenience init(dynamic typeInfo: TypeInfo, compatibleTypeMeta: 
TypeMeta) { - self.init( - swiftTypeID: typeInfo.swiftTypeID, - typeID: typeInfo.typeID, - userTypeID: typeInfo.userTypeID, - registerByName: typeInfo.registerByName, - evolving: typeInfo.evolving, - namespace: typeInfo.namespace, - typeName: typeInfo.typeName, - typeMeta: typeInfo.typeMeta, - compatibleTypeMeta: compatibleTypeMeta, - typeDefBytes: typeInfo.typeDefBytes, - firstTypeDefBytes: typeInfo.firstTypeDefBytes, - typeDefHeader: typeInfo.typeDefHeader, - typeDefHeaderHash: typeInfo.typeDefHeaderHash, - typeDefHasUserTypeFields: typeInfo.typeDefHasUserTypeFields, - reader: typeInfo.reader, - compatibleReader: typeInfo.compatibleReader - ) - } - - @inline(__always) - func matches( - typeID: TypeId, - userTypeID: UInt32?, - registerByName: Bool, - evolving: Bool, - typeName: (namespace: String, name: String) - ) -> Bool { - self.typeID == typeID && - self.userTypeID == userTypeID && - self.registerByName == registerByName && - self.evolving == evolving && - self.namespace.value == typeName.namespace && - self.typeName.value == typeName.name - } + static let uncached = TypeInfo(typeID: .unknown) + + let swiftTypeID: ObjectIdentifier + let typeID: TypeId + let userTypeID: UInt32? + let registerByName: Bool + let evolving: Bool + let namespace: MetaString + let typeName: MetaString + let typeMeta: TypeMeta? + public let compatibleTypeMeta: TypeMeta? + let typeDefBytes: [UInt8]? + let firstTypeDefBytes: [UInt8]? + let typeDefHeader: UInt64? + public let typeDefHeaderHash: UInt64? + public let typeDefHasUserTypeFields: Bool + + private let reader: (ReadContext) throws -> Any + private let compatibleReader: (ReadContext, TypeInfo) throws -> Any + private let nativeWireTypeID: TypeId + private let compatibleWireTypeID: TypeId + + init( + swiftTypeID: ObjectIdentifier, + typeID: TypeId, + userTypeID: UInt32?, + registerByName: Bool, + evolving: Bool, + namespace: MetaString, + typeName: MetaString, + typeMeta: TypeMeta? = nil, + compatibleTypeMeta: TypeMeta? 
= nil, + typeDefBytes: [UInt8]? = nil, + firstTypeDefBytes: [UInt8]? = nil, + typeDefHeader: UInt64? = nil, + typeDefHeaderHash: UInt64? = nil, + typeDefHasUserTypeFields: Bool = true, + reader: @escaping (ReadContext) throws -> Any, + compatibleReader: @escaping (ReadContext, TypeInfo) throws -> Any + ) { + self.swiftTypeID = swiftTypeID + self.typeID = typeID + self.userTypeID = userTypeID + self.registerByName = registerByName + self.evolving = evolving + self.namespace = namespace + self.typeName = typeName + self.typeMeta = typeMeta + self.compatibleTypeMeta = compatibleTypeMeta ?? typeMeta + self.typeDefBytes = typeDefBytes + self.firstTypeDefBytes = firstTypeDefBytes + self.typeDefHeader = typeDefHeader + self.typeDefHeaderHash = typeDefHeaderHash + self.typeDefHasUserTypeFields = typeDefHasUserTypeFields + self.reader = reader + self.compatibleReader = compatibleReader + nativeWireTypeID = resolveRegisteredWireTypeID( + declaredTypeID: typeID, + registerByName: registerByName, + compatible: false, + evolving: evolving + ) + compatibleWireTypeID = resolveRegisteredWireTypeID( + declaredTypeID: typeID, + registerByName: registerByName, + compatible: true, + evolving: evolving + ) + } - @inline(__always) - func wireTypeID(compatible: Bool) -> TypeId { - compatible ? compatibleWireTypeID : nativeWireTypeID - } + convenience init( + swiftTypeID: ObjectIdentifier, + typeID: TypeId, + userTypeID: UInt32?, + registerByName: Bool, + evolving: Bool, + namespace: MetaString, + typeName: MetaString, + fields: [TypeMeta.FieldInfo], + reader: @escaping (ReadContext) throws -> Any, + compatibleReader: @escaping (ReadContext, TypeInfo) throws -> Any + ) throws { + let compatibleWireTypeID = resolveRegisteredWireTypeID( + declaredTypeID: typeID, + registerByName: registerByName, + compatible: true, + evolving: evolving + ) + let typeMeta = try TypeMeta( + typeID: compatibleWireTypeID.rawValue, + userTypeID: registerByName ? 
nil : userTypeID, + namespace: namespace, + typeName: typeName, + registerByName: registerByName, + fields: fields + ) + let typeDefBytes = try typeMeta.encode() + var firstTypeDefBytes = [UInt8]() + firstTypeDefBytes.reserveCapacity(typeDefBytes.count + 1) + firstTypeDefBytes.append(0) + firstTypeDefBytes.append(contentsOf: typeDefBytes) + let typeDefHeader = try encodedTypeDefHeader(typeDefBytes) + let typeDefHeaderHash = try encodedTypeDefHeaderHash(typeDefBytes) + let canonicalTypeMeta = try TypeMeta( + typeID: compatibleWireTypeID.rawValue, + userTypeID: registerByName ? nil : userTypeID, + namespace: namespace, + typeName: typeName, + registerByName: registerByName, + fields: fields, + headerHash: typeDefHeaderHash + ) + self.init( + swiftTypeID: swiftTypeID, + typeID: typeID, + userTypeID: userTypeID, + registerByName: registerByName, + evolving: evolving, + namespace: namespace, + typeName: typeName, + typeMeta: canonicalTypeMeta, + compatibleTypeMeta: canonicalTypeMeta, + typeDefBytes: typeDefBytes, + firstTypeDefBytes: firstTypeDefBytes, + typeDefHeader: typeDefHeader, + typeDefHeaderHash: typeDefHeaderHash, + typeDefHasUserTypeFields: encodedTypeDefHasUserTypeFields(fields), + reader: reader, + compatibleReader: compatibleReader + ) + } + + convenience init(typeID: TypeId) { + self.init( + swiftTypeID: ObjectIdentifier(TypeInfo.self), + typeID: typeID, + userTypeID: nil, + registerByName: false, + evolving: true, + namespace: MetaString.empty(specialChar1: ".", specialChar2: "_"), + typeName: MetaString.empty(specialChar1: "$", specialChar2: "_"), + reader: { _ in + throw ForyError.invalidData("dynamic type \(typeID) uses runtime-only decode path") + }, + compatibleReader: { _, _ in + throw ForyError.invalidData( + "dynamic compatible type \(typeID) uses runtime-only decode path") + } + ) + } + + convenience init(dynamic typeInfo: TypeInfo, compatibleTypeMeta: TypeMeta) { + self.init( + swiftTypeID: typeInfo.swiftTypeID, + typeID: typeInfo.typeID, + 
userTypeID: typeInfo.userTypeID, + registerByName: typeInfo.registerByName, + evolving: typeInfo.evolving, + namespace: typeInfo.namespace, + typeName: typeInfo.typeName, + typeMeta: typeInfo.typeMeta, + compatibleTypeMeta: compatibleTypeMeta, + typeDefBytes: typeInfo.typeDefBytes, + firstTypeDefBytes: typeInfo.firstTypeDefBytes, + typeDefHeader: typeInfo.typeDefHeader, + typeDefHeaderHash: typeInfo.typeDefHeaderHash, + typeDefHasUserTypeFields: typeInfo.typeDefHasUserTypeFields, + reader: typeInfo.reader, + compatibleReader: typeInfo.compatibleReader + ) + } - @inline(__always) - func read(_ context: ReadContext, typeInfo: TypeInfo? = nil) throws -> Any { - if let typeInfo { - return try compatibleReader(context, typeInfo) - } - if context.compatible && - (compatibleWireTypeID == .compatibleStruct || compatibleWireTypeID == .namedCompatibleStruct) { - return try compatibleReader(context, self) - } - if compatibleTypeMeta !== typeMeta { - return try compatibleReader(context, self) - } - return try reader(context) - } + @inline(__always) + func matches( + typeID: TypeId, + userTypeID: UInt32?, + registerByName: Bool, + evolving: Bool, + typeName: (namespace: String, name: String) + ) -> Bool { + self.typeID == typeID && self.userTypeID == userTypeID && self.registerByName == registerByName + && self.evolving == evolving && self.namespace.value == typeName.namespace + && self.typeName.value == typeName.name + } + + @inline(__always) + func wireTypeID(compatible: Bool) -> TypeId { + compatible ? compatibleWireTypeID : nativeWireTypeID + } + + @inline(__always) + func read(_ context: ReadContext, typeInfo: TypeInfo? 
= nil) throws -> Any { + if let typeInfo { + return try compatibleReader(context, typeInfo) + } + if context.compatible + && (compatibleWireTypeID == .compatibleStruct + || compatibleWireTypeID == .namedCompatibleStruct) { + return try compatibleReader(context, self) + } + if compatibleTypeMeta !== typeMeta { + return try compatibleReader(context, self) + } + return try reader(context) + } } private struct TypeNameKey: Hashable { - let namespace: String - let typeName: String + let namespace: String + let typeName: String } final class TypeResolver { - private let trackRef: Bool - private var registrationFinished = false - - private var bySwiftType = UInt64Map(initialCapacity: 64) - private var byUserTypeID = UInt64Map(initialCapacity: 64) - private var byTypeName: [TypeNameKey: TypeInfo] = [:] - private var builtinTypeInfoByID: [TypeInfo?] = [] - private var typeInfoByHeader = UInt64Map(initialCapacity: 64) - - init(trackRef: Bool = false) { - self.trackRef = trackRef - } - - func finishRegistration() { - registrationFinished = true - } - - func register(_ type: T.Type, id: UInt32) { - do { - try registerByID(type, id: id) - } catch { - preconditionFailure("registration failed for \(type): \(error)") - } - } - - @inline(__always) - private func evolving(for type: T.Type) -> Bool { - guard let type = type as? any StructSerializer.Type else { - return true - } - return type.foryEvolving - } + private let trackRef: Bool + private var registrationFinished = false + + private var bySwiftType = UInt64Map(initialCapacity: 64) + private var byUserTypeID = UInt64Map(initialCapacity: 64) + private var byTypeName: [TypeNameKey: TypeInfo] = [:] + private var builtinTypeInfoByID: [TypeInfo?] 
= [] + private var typeInfoByHeader = UInt64Map(initialCapacity: 64) + + init(trackRef: Bool = false) { + self.trackRef = trackRef + } + + func finishRegistration() { + registrationFinished = true + } + + func register(_ type: T.Type, id: UInt32) { + do { + try registerByID(type, id: id) + } catch { + preconditionFailure("registration failed for \(type): \(error)") + } + } + + @inline(__always) + private func evolving(for type: T.Type) -> Bool { + guard let type = type as? any StructSerializer.Type else { + return true + } + return type.foryEvolving + } + + private func registerByID(_ type: T.Type, id: UInt32) throws { + try ensureRegistrationAllowed() + let swiftTypeID = ObjectIdentifier(type) + try validateIDRegistration(key: swiftTypeID, type: type, id: id) + let evolving = evolving(for: type) + + let typeInfo = try TypeInfo( + swiftTypeID: swiftTypeID, + typeID: T.staticTypeId, + userTypeID: id, + registerByName: false, + evolving: evolving, + namespace: MetaString.empty(specialChar1: ".", specialChar2: "_"), + typeName: MetaString.empty(specialChar1: "$", specialChar2: "_"), + fields: T.foryFieldsInfo(trackRef: trackRef), + reader: { context in + try readRegisteredValue(context, as: T.self) + }, + compatibleReader: { context, remoteTypeInfo in + try readCompatibleRegisteredValue(context, as: T.self, remoteTypeInfo: remoteTypeInfo) + } + ) - private func registerByID(_ type: T.Type, id: UInt32) throws { - try ensureRegistrationAllowed() - let swiftTypeID = ObjectIdentifier(type) - try validateIDRegistration(key: swiftTypeID, type: type, id: id) - let evolving = evolving(for: type) - - let typeInfo = try TypeInfo( - swiftTypeID: swiftTypeID, - typeID: T.staticTypeId, - userTypeID: id, - registerByName: false, - evolving: evolving, - namespace: MetaString.empty(specialChar1: ".", specialChar2: "_"), - typeName: MetaString.empty(specialChar1: "$", specialChar2: "_"), - fields: T.foryFieldsInfo(trackRef: trackRef), - reader: { context in - try 
readRegisteredValue(context, as: T.self) - }, - compatibleReader: { context, remoteTypeInfo in - try readCompatibleRegisteredValue(context, as: T.self, remoteTypeInfo: remoteTypeInfo) - } - ) + if let existing = bySwiftType.value(for: UInt64(UInt(bitPattern: swiftTypeID))), + existing.matches( + typeID: T.staticTypeId, + userTypeID: id, + registerByName: false, + evolving: evolving, + typeName: (namespace: "", name: "") + ) { + return + } + + try store(typeInfo, for: swiftTypeID, userTypeID: id) + } + + func register(_ type: T.Type, namespace: String, typeName: String) throws { + try ensureRegistrationAllowed() + let namespaceMeta = try MetaStringEncoder.namespace.encode( + namespace, + allowedEncodings: namespaceMetaStringEncodings + ) + let typeNameMeta = try MetaStringEncoder.typeName.encode( + typeName, + allowedEncodings: typeNameMetaStringEncodings + ) + let swiftTypeID = ObjectIdentifier(type) + try validateNameRegistration( + key: swiftTypeID, + type: type, + namespace: namespace, + typeName: typeName + ) + let evolving = evolving(for: type) + + let typeInfo = try TypeInfo( + swiftTypeID: swiftTypeID, + typeID: T.staticTypeId, + userTypeID: nil, + registerByName: true, + evolving: evolving, + namespace: namespaceMeta, + typeName: typeNameMeta, + fields: T.foryFieldsInfo(trackRef: trackRef), + reader: { context in + try readRegisteredValue(context, as: T.self) + }, + compatibleReader: { context, remoteTypeInfo in + try readCompatibleRegisteredValue(context, as: T.self, remoteTypeInfo: remoteTypeInfo) + } + ) - if let existing = bySwiftType.value(for: UInt64(UInt(bitPattern: swiftTypeID))), - existing.matches( - typeID: T.staticTypeId, - userTypeID: id, - registerByName: false, - evolving: evolving, - typeName: (namespace: "", name: "") - ) { - return - } - - try store(typeInfo, for: swiftTypeID, userTypeID: id) + if let existing = bySwiftType.value(for: UInt64(UInt(bitPattern: swiftTypeID))), + existing.matches( + typeID: T.staticTypeId, + userTypeID: nil, + 
registerByName: true, + evolving: evolving, + typeName: (namespace: namespace, name: typeName) + ) { + return } - func register(_ type: T.Type, namespace: String, typeName: String) throws { - try ensureRegistrationAllowed() - let namespaceMeta = try MetaStringEncoder.namespace.encode( - namespace, - allowedEncodings: namespaceMetaStringEncodings - ) - let typeNameMeta = try MetaStringEncoder.typeName.encode( - typeName, - allowedEncodings: typeNameMetaStringEncodings - ) - let swiftTypeID = ObjectIdentifier(type) - try validateNameRegistration( - key: swiftTypeID, - type: type, - namespace: namespace, - typeName: typeName + try store( + typeInfo, for: swiftTypeID, typeNameKey: TypeNameKey(namespace: namespace, typeName: typeName) + ) + } + + func register(_ type: T.Type, name: String) throws { + let parts = name.components(separatedBy: ".") + if parts.count <= 1 { + try register(type, namespace: "", typeName: name) + return + } + + let resolvedTypeName = parts[parts.count - 1] + let resolvedNamespace = parts.dropLast().joined(separator: ".") + try register(type, namespace: resolvedNamespace, typeName: resolvedTypeName) + } + + func requireTypeInfo(for type: T.Type) throws -> TypeInfo { + if let info = bySwiftType.value(for: UInt64(UInt(bitPattern: ObjectIdentifier(type)))) { + return info + } + throw ForyError.typeNotRegistered("\(type) is not registered") + } + + @inline(__always) + func getTypeInfo(forHeader header: UInt64) -> TypeInfo? { + typeInfoByHeader.value(for: header) + } + + @inline(__always) + func cacheTypeInfo(_ typeMeta: TypeMeta, forHeader header: UInt64) throws -> TypeInfo { + if let cached = typeInfoByHeader.value(for: header) { + return cached + } + let localTypeInfo = try requireTypeInfo(for: typeMeta) + if header == localTypeInfo.typeDefHeader { + typeInfoByHeader.set(localTypeInfo, for: header) + return localTypeInfo + } + let canonicalTypeMeta: TypeMeta + if let localTypeMeta = localTypeInfo.typeMeta, + let remapped = try? 
typeMeta.assigningFieldIDs(from: localTypeMeta) { + canonicalTypeMeta = remapped + } else { + canonicalTypeMeta = typeMeta + } + let typeInfo = TypeInfo(dynamic: localTypeInfo, compatibleTypeMeta: canonicalTypeMeta) + typeInfoByHeader.set(typeInfo, for: header) + return typeInfo + } + + private func store( + _ typeInfo: TypeInfo, + for swiftTypeID: ObjectIdentifier, + userTypeID: UInt32? = nil, + typeNameKey: TypeNameKey? = nil + ) throws { + bySwiftType.set(typeInfo, for: UInt64(UInt(bitPattern: swiftTypeID))) + if let userTypeID { + byUserTypeID.set(typeInfo, for: UInt64(userTypeID)) + } + if let typeNameKey { + byTypeName[typeNameKey] = typeInfo + } + if let typeMeta = typeInfo.typeMeta, + let typeDefHeader = typeInfo.typeDefHeader { + typeInfoByHeader.set( + TypeInfo( + dynamic: typeInfo, + compatibleTypeMeta: typeMeta + ), + for: typeDefHeader + ) + } + } + + @inline(__always) + func builtinTypeInfo(for typeID: TypeId) -> TypeInfo { + let index = Int(typeID.rawValue) + if index < builtinTypeInfoByID.count, let cached = builtinTypeInfoByID[index] { + return cached + } + let info = TypeInfo(typeID: typeID) + if index >= builtinTypeInfoByID.count { + builtinTypeInfoByID.append( + contentsOf: repeatElement(nil, count: index - builtinTypeInfoByID.count + 1)) + } + builtinTypeInfoByID[index] = info + return info + } + + @inline(__always) + func requireTypeInfo(userTypeID: UInt32) throws -> TypeInfo { + guard let typeInfo = byUserTypeID.value(for: UInt64(userTypeID)) else { + throw ForyError.typeNotRegistered("user_type_id=\(userTypeID)") + } + return typeInfo + } + + @inline(__always) + func requireTypeInfo(namespace: String, typeName: String) throws -> TypeInfo { + guard let typeInfo = byTypeName[TypeNameKey(namespace: namespace, typeName: typeName)] else { + throw ForyError.typeNotRegistered("namespace=\(namespace), type=\(typeName)") + } + return typeInfo + } + + private func validateIDRegistration( + key: ObjectIdentifier, + type: T.Type, + id: UInt32 + ) throws 
{ + let swiftKey = UInt64(UInt(bitPattern: key)) + if let existing = bySwiftType.value(for: swiftKey) { + if existing.registerByName { + throw ForyError.invalidData( + "\(type) was already registered by name, cannot re-register by id" ) - let evolving = evolving(for: type) - - let typeInfo = try TypeInfo( - swiftTypeID: swiftTypeID, - typeID: T.staticTypeId, - userTypeID: nil, - registerByName: true, - evolving: evolving, - namespace: namespaceMeta, - typeName: typeNameMeta, - fields: T.foryFieldsInfo(trackRef: trackRef), - reader: { context in - try readRegisteredValue(context, as: T.self) - }, - compatibleReader: { context, remoteTypeInfo in - try readCompatibleRegisteredValue(context, as: T.self, remoteTypeInfo: remoteTypeInfo) - } + } + if existing.typeID != T.staticTypeId || existing.userTypeID != id { + let existingID = existing.userTypeID.map { String($0) } ?? "nil" + throw ForyError.invalidData( + "\(type) registration conflict: existing id=\(existingID), new id=\(id)" ) - - if let existing = bySwiftType.value(for: UInt64(UInt(bitPattern: swiftTypeID))), - existing.matches( - typeID: T.staticTypeId, - userTypeID: nil, - registerByName: true, - evolving: evolving, - typeName: (namespace: namespace, name: typeName) - ) { - return - } - - try store(typeInfo, for: swiftTypeID, typeNameKey: TypeNameKey(namespace: namespace, typeName: typeName)) - } - - func register(_ type: T.Type, name: String) throws { - let parts = name.components(separatedBy: ".") - if parts.count <= 1 { - try register(type, namespace: "", typeName: name) - return - } - - let resolvedTypeName = parts[parts.count - 1] - let resolvedNamespace = parts.dropLast().joined(separator: ".") - try register(type, namespace: resolvedNamespace, typeName: resolvedTypeName) - } - - func requireTypeInfo(for type: T.Type) throws -> TypeInfo { - if let info = bySwiftType.value(for: UInt64(UInt(bitPattern: ObjectIdentifier(type)))) { - return info - } - throw ForyError.typeNotRegistered("\(type) is not 
registered") - } - - @inline(__always) - func getTypeInfo(forHeader header: UInt64) -> TypeInfo? { - typeInfoByHeader.value(for: header) - } - - @inline(__always) - func cacheTypeInfo(_ typeMeta: TypeMeta, forHeader header: UInt64) throws -> TypeInfo { - if let cached = typeInfoByHeader.value(for: header) { - return cached - } - let localTypeInfo = try requireTypeInfo(for: typeMeta) - if header == localTypeInfo.typeDefHeader { - typeInfoByHeader.set(localTypeInfo, for: header) - return localTypeInfo - } - let canonicalTypeMeta: TypeMeta - if let localTypeMeta = localTypeInfo.typeMeta, - let remapped = try? typeMeta.assigningFieldIDs(from: localTypeMeta) { - canonicalTypeMeta = remapped - } else { - canonicalTypeMeta = typeMeta - } - let typeInfo = TypeInfo(dynamic: localTypeInfo, compatibleTypeMeta: canonicalTypeMeta) - typeInfoByHeader.set(typeInfo, for: header) - return typeInfo + } } - private func store( - _ typeInfo: TypeInfo, - for swiftTypeID: ObjectIdentifier, - userTypeID: UInt32? = nil, - typeNameKey: TypeNameKey? 
= nil - ) throws { - bySwiftType.set(typeInfo, for: UInt64(UInt(bitPattern: swiftTypeID))) - if let userTypeID { - byUserTypeID.set(typeInfo, for: UInt64(userTypeID)) - } - if let typeNameKey { - byTypeName[typeNameKey] = typeInfo - } - if let typeMeta = typeInfo.typeMeta, - let typeDefHeader = typeInfo.typeDefHeader { - typeInfoByHeader.set( - TypeInfo( - dynamic: typeInfo, - compatibleTypeMeta: typeMeta - ), - for: typeDefHeader - ) - } + if let existing = byUserTypeID.value(for: UInt64(id)), existing.swiftTypeID != key { + throw ForyError.invalidData("user type id \(id) is already registered by another type") } + } - @inline(__always) - func builtinTypeInfo(for typeID: TypeId) -> TypeInfo { - let index = Int(typeID.rawValue) - if index < builtinTypeInfoByID.count, let cached = builtinTypeInfoByID[index] { - return cached - } - let info = TypeInfo(typeID: typeID) - if index >= builtinTypeInfoByID.count { - builtinTypeInfoByID.append(contentsOf: repeatElement(nil, count: index - builtinTypeInfoByID.count + 1)) - } - builtinTypeInfoByID[index] = info - return info - } - - @inline(__always) - func requireTypeInfo(userTypeID: UInt32) throws -> TypeInfo { - guard let typeInfo = byUserTypeID.value(for: UInt64(userTypeID)) else { - throw ForyError.typeNotRegistered("user_type_id=\(userTypeID)") - } - return typeInfo - } - - @inline(__always) - func requireTypeInfo(namespace: String, typeName: String) throws -> TypeInfo { - guard let typeInfo = byTypeName[TypeNameKey(namespace: namespace, typeName: typeName)] else { - throw ForyError.typeNotRegistered("namespace=\(namespace), type=\(typeName)") - } - return typeInfo + private func validateNameRegistration( + key: ObjectIdentifier, + type: T.Type, + namespace: String, + typeName: String + ) throws { + if let existing = bySwiftType.value(for: UInt64(UInt(bitPattern: key))) { + if !existing.registerByName { + throw ForyError.invalidData( + "\(type) was already registered by id, cannot re-register by name" + ) + } + if 
existing.typeID != T.staticTypeId || existing.namespace.value != namespace + || existing.typeName.value != typeName { + throw ForyError.invalidData( + """ + \(type) registration conflict: existing name=\(existing.namespace.value)::\(existing.typeName.value), \ + new name=\(namespace)::\(typeName) + """ + ) + } } - private func validateIDRegistration( - key: ObjectIdentifier, - type: T.Type, - id: UInt32 - ) throws { - let swiftKey = UInt64(UInt(bitPattern: key)) - if let existing = bySwiftType.value(for: swiftKey) { - if existing.registerByName { - throw ForyError.invalidData( - "\(type) was already registered by name, cannot re-register by id" - ) - } - if existing.typeID != T.staticTypeId || existing.userTypeID != id { - let existingID = existing.userTypeID.map { String($0) } ?? "nil" - throw ForyError.invalidData( - "\(type) registration conflict: existing id=\(existingID), new id=\(id)" - ) - } - } - - if let existing = byUserTypeID.value(for: UInt64(id)), existing.swiftTypeID != key { - throw ForyError.invalidData("user type id \(id) is already registered by another type") - } + let nameKey = TypeNameKey(namespace: namespace, typeName: typeName) + if let existing = byTypeName[nameKey], existing.swiftTypeID != key { + throw ForyError.invalidData( + "type name \(namespace)::\(typeName) is already registered by another type") } + } - private func validateNameRegistration( - key: ObjectIdentifier, - type: T.Type, - namespace: String, - typeName: String - ) throws { - if let existing = bySwiftType.value(for: UInt64(UInt(bitPattern: key))) { - if !existing.registerByName { - throw ForyError.invalidData( - "\(type) was already registered by id, cannot re-register by name" - ) - } - if existing.typeID != T.staticTypeId || - existing.namespace.value != namespace || - existing.typeName.value != typeName { - throw ForyError.invalidData( - """ - \(type) registration conflict: existing name=\(existing.namespace.value)::\(existing.typeName.value), \ - new 
name=\(namespace)::\(typeName) - """ - ) - } - } - - let nameKey = TypeNameKey(namespace: namespace, typeName: typeName) - if let existing = byTypeName[nameKey], existing.swiftTypeID != key { - throw ForyError.invalidData("type name \(namespace)::\(typeName) is already registered by another type") - } + @inline(__always) + private func requireTypeInfo(for typeMeta: TypeMeta) throws -> TypeInfo { + if typeMeta.registerByName { + guard + let typeInfo = byTypeName[ + TypeNameKey(namespace: typeMeta.namespace.value, typeName: typeMeta.typeName.value)] + else { + throw ForyError.typeNotRegistered( + "namespace=\(typeMeta.namespace.value), type=\(typeMeta.typeName.value)" + ) + } + return typeInfo } - - @inline(__always) - private func requireTypeInfo(for typeMeta: TypeMeta) throws -> TypeInfo { - if typeMeta.registerByName { - guard let typeInfo = byTypeName[TypeNameKey(namespace: typeMeta.namespace.value, typeName: typeMeta.typeName.value)] else { - throw ForyError.typeNotRegistered( - "namespace=\(typeMeta.namespace.value), type=\(typeMeta.typeName.value)" - ) - } - return typeInfo - } - if let userTypeID = typeMeta.userTypeID { - guard let typeInfo = byUserTypeID.value(for: UInt64(userTypeID)) else { - throw ForyError.typeNotRegistered("user_type_id=\(userTypeID)") - } - return typeInfo - } - throw ForyError.invalidData("missing user type id in compatible dynamic type meta") + if let userTypeID = typeMeta.userTypeID { + guard let typeInfo = byUserTypeID.value(for: UInt64(userTypeID)) else { + throw ForyError.typeNotRegistered("user_type_id=\(userTypeID)") + } + return typeInfo } + throw ForyError.invalidData("missing user type id in compatible dynamic type meta") + } - private func ensureRegistrationAllowed() throws { - guard !registrationFinished else { - throw ForyError.invalidData( - "cannot register more types after top-level serialize/deserialize has frozen registration" - ) - } + private func ensureRegistrationAllowed() throws { + guard !registrationFinished 
else { + throw ForyError.invalidData( + "cannot register more types after top-level serialize/deserialize has frozen registration" + ) } + } } diff --git a/swift/Tests/ForyTests/DecimalTests.swift b/swift/Tests/ForyTests/DecimalTests.swift index f7f3f266b5..69b33ae739 100644 --- a/swift/Tests/ForyTests/DecimalTests.swift +++ b/swift/Tests/ForyTests/DecimalTests.swift @@ -17,107 +17,109 @@ import Foundation import Testing + @testable import Fory @ForyStruct private struct DecimalEnvelope: Equatable { - var amount: Decimal = .zero - var note: String = "" + var amount: Decimal = .zero + var note: String = "" } private func makeDecimal(unscaled: String, scale: Int32) throws -> Decimal { - var digits = unscaled - var sign = "" - if digits.first == "-" { - sign = "-" - digits.removeFirst() - } - guard !digits.isEmpty, digits.allSatisfy(\.isNumber) else { - throw ForyError.invalidData("failed to create decimal \(unscaled) scale \(scale)") - } - if scale == 0 { - guard let value = Decimal(string: sign + digits, locale: Locale(identifier: "en_US_POSIX")) else { - throw ForyError.invalidData("failed to create decimal \(unscaled) scale \(scale)") - } - return value + var digits = unscaled + var sign = "" + if digits.first == "-" { + sign = "-" + digits.removeFirst() + } + guard !digits.isEmpty, digits.allSatisfy(\.isNumber) else { + throw ForyError.invalidData("failed to create decimal \(unscaled) scale \(scale)") + } + if scale == 0 { + guard let value = Decimal(string: sign + digits, locale: Locale(identifier: "en_US_POSIX")) + else { + throw ForyError.invalidData("failed to create decimal \(unscaled) scale \(scale)") } - let valueString: String - if scale > 0 { - let scaleInt = Int(scale) - if digits.count > scaleInt { - let split = digits.index(digits.endIndex, offsetBy: -scaleInt) - valueString = sign + String(digits[.. 
0 { + let scaleInt = Int(scale) + if digits.count > scaleInt { + let split = digits.index(digits.endIndex, offsetBy: -scaleInt) + valueString = sign + String(digits[.. - var addresses: [Address] - var metadata: [Int8: Int32?] + var id: Int64 + var name: String + var nickname: String? + var scores: [Int32] + var tags: Set + var addresses: [Address] + var metadata: [Int8: Int32?] } @ForyStruct struct FieldOrder: Equatable { - var textTail: String - var longValue: Int64 - var shortValue: Int16 - var intValue: Int32 + var textTail: String + var longValue: Int64 + var shortValue: Int16 + var intValue: Int32 } @ForyStruct struct TaggedFieldOrder: Equatable { - @ForyField(id: 1) - var textTail: String + @ForyField(id: 1) + var textTail: String - @ForyField(id: 10) - var intValue: Int32 + @ForyField(id: 10) + var intValue: Int32 } @ForyStruct struct EncodedNumberFields: Equatable { - @ForyField(encoding: .fixed) - var u32Fixed: UInt32 + @ForyField(encoding: .fixed) + var u32Fixed: UInt32 - @ForyField(encoding: .tagged) - var u64Tagged: UInt64 + @ForyField(encoding: .tagged) + var u64Tagged: UInt64 } @ForyStruct struct ReducedPrecisionMacroFields: Equatable { - var float16Value: Float16 - var bfloat16Value: BFloat16 - @ArrayField(element: .float16) - var float16Array: [Float16] - @ArrayField(element: .bfloat16) - var bfloat16Array: [BFloat16] + var float16Value: Float16 + var bfloat16Value: BFloat16 + @ArrayField(element: .float16) + var float16Array: [Float16] + @ArrayField(element: .bfloat16) + var bfloat16Array: [BFloat16] } @ForyStruct struct FieldIdConfigured: Equatable { - @ForyField(id: 2) - var stableID: Int32 + @ForyField(id: 2) + var stableID: Int32 - @ForyField(id: 5, encoding: .fixed) - var fixedValue: Int32 + @ForyField(id: 5, encoding: .fixed) + var fixedValue: Int32 } @ForyStruct struct FieldIdSource: Equatable { - @ForyField(id: 1) - var value: Int32 + @ForyField(id: 1) + var value: Int32 - @ForyField(id: 4) - var label: String + @ForyField(id: 4) + var 
label: String } @ForyStruct struct FieldIdTarget: Equatable { - @ForyField(id: 1) - var renamedValue: Int32 + @ForyField(id: 1) + var renamedValue: Int32 - @ForyField(id: 4) - var renamedLabel: String + @ForyField(id: 4) + var renamedLabel: String } @ForyEnum enum SparseStatus: Int32, CaseIterable { - case unknown = 4096 - case ok = 8192 + case unknown = 4096 + case ok = 8192 } @ForyStruct struct EvolvingOverrideValue: Equatable { - var f1: String = "" + var f1: String = "" } @ForyStruct(evolving: false) struct FixedOverrideValue: Equatable { - var f1: String = "" + var f1: String = "" } @ForyUnion enum FieldIdUnionSource: Equatable { - @ForyCase(id: 3) - case number(Int32) + @ForyCase(id: 3) + case number(Int32) - @ForyCase(id: 9) - case text(String) + @ForyCase(id: 9) + case text(String) } @ForyUnion enum FieldIdUnionTarget: Equatable { - @ForyCase(id: 3) - case renamedNumber(Int32) + @ForyCase(id: 3) + case renamedNumber(Int32) - @ForyCase(id: 9) - case renamedText(String) + @ForyCase(id: 9) + case renamedText(String) } @ForyStruct struct CompatibleNestedItem: Equatable { - var id: Int32 - var name: String + var id: Int32 + var name: String } @ForyStruct struct CompatibleNestedArrayHolder: Equatable { - var items: [CompatibleNestedItem] + var items: [CompatibleNestedItem] } @ForyStruct struct CompatibleNestedOptionalArrayHolder: Equatable { - var items: [CompatibleNestedItem?] + var items: [CompatibleNestedItem?] } @ForyStruct struct CompatibleNestedMapHolder: Equatable { - var items: [Int32: CompatibleNestedItem] + var items: [Int32: CompatibleNestedItem] } @ForyStruct final class Node { - var value: Int32 = 0 - var next: Node? + var value: Int32 = 0 + var next: Node? - required init() {} + required init() {} - init(value: Int32, next: Node? = nil) { - self.value = value - self.next = next - } + init(value: Int32, next: Node? 
= nil) { + self.value = value + self.next = next + } } @ForyStruct final class WeakNode { - var value: Int32 = 0 - weak var next: WeakNode? + var value: Int32 = 0 + weak var next: WeakNode? - required init() {} + required init() {} - init(value: Int32, next: WeakNode? = nil) { - self.value = value - self.next = next - } + init(value: Int32, next: WeakNode? = nil) { + self.value = value + self.next = next + } } @ForyStruct struct AnyObjectHolder { - var value: AnyObject - var optionalValue: AnyObject? - var items: [AnyObject] + var value: AnyObject + var optionalValue: AnyObject? + var items: [AnyObject] } @ForyStruct struct AnySerializerHolder { - var value: any Serializer - var items: [any Serializer] - var map: [String: any Serializer] + var value: any Serializer + var items: [any Serializer] + var map: [String: any Serializer] } @ForyStruct struct AnyFieldHolder { - var value: Any - var optionalValue: Any? - var list: [Any] - var stringMap: [String: Any] - var int32Map: [Int32: Any] + var value: Any + var optionalValue: Any? 
+ var list: [Any] + var stringMap: [String: Any] + var int32Map: [Int32: Any] } @Test func primitiveRoundTrip() throws { - let fory = Fory() + let fory = Fory() - let boolData = try fory.serialize(true) - let boolValue: Bool = try fory.deserialize(boolData) - #expect(boolValue == true) + let boolData = try fory.serialize(true) + let boolValue: Bool = try fory.deserialize(boolData) + #expect(boolValue == true) - let int32Data = try fory.serialize(Int32(-123456)) - let int32Value: Int32 = try fory.deserialize(int32Data) - #expect(int32Value == -123456) + let int32Data = try fory.serialize(Int32(-123456)) + let int32Value: Int32 = try fory.deserialize(int32Data) + #expect(int32Value == -123456) - let int64Data = try fory.serialize(Int64(9_223_372_036_854_775_000)) - let int64Value: Int64 = try fory.deserialize(int64Data) - #expect(int64Value == 9_223_372_036_854_775_000) + let int64Data = try fory.serialize(Int64(9_223_372_036_854_775_000)) + let int64Value: Int64 = try fory.deserialize(int64Data) + #expect(int64Value == 9_223_372_036_854_775_000) - let uint32Data = try fory.serialize(UInt32(123456)) - let uint32Value: UInt32 = try fory.deserialize(uint32Data) - #expect(uint32Value == 123456) + let uint32Data = try fory.serialize(UInt32(123456)) + let uint32Value: UInt32 = try fory.deserialize(uint32Data) + #expect(uint32Value == 123456) - let uint64Data = try fory.serialize(UInt64(9_223_372_036_854_775_000)) - let uint64Value: UInt64 = try fory.deserialize(uint64Data) - #expect(uint64Value == 9_223_372_036_854_775_000) + let uint64Data = try fory.serialize(UInt64(9_223_372_036_854_775_000)) + let uint64Value: UInt64 = try fory.deserialize(uint64Data) + #expect(uint64Value == 9_223_372_036_854_775_000) - let floatData = try fory.serialize(Float(3.25)) - let floatValue: Float = try fory.deserialize(floatData) - #expect(floatValue == 3.25) + let floatData = try fory.serialize(Float(3.25)) + let floatValue: Float = try fory.deserialize(floatData) + #expect(floatValue == 
3.25) - let doubleData = try fory.serialize(Double(3.1415926)) - let doubleValue: Double = try fory.deserialize(doubleData) - #expect(doubleValue == 3.1415926) + let doubleData = try fory.serialize(Double(3.1415926)) + let doubleValue: Double = try fory.deserialize(doubleData) + #expect(doubleValue == 3.1415926) - let stringData = try fory.serialize("hello_fory") - let stringValue: String = try fory.deserialize(stringData) - #expect(stringValue == "hello_fory") + let stringData = try fory.serialize("hello_fory") + let stringValue: String = try fory.deserialize(stringData) + #expect(stringValue == "hello_fory") - let binary = Data([0x01, 0x02, 0x03, 0xFF]) - let binaryData = try fory.serialize(binary) - let binaryValue: Data = try fory.deserialize(binaryData) - #expect(binaryValue == binary) + let binary = Data([0x01, 0x02, 0x03, 0xFF]) + let binaryData = try fory.serialize(binary) + let binaryValue: Data = try fory.deserialize(binaryData) + #expect(binaryValue == binary) } @Test func extendedWireTypesRoundTrip() throws { - let fory = Fory() + let fory = Fory() - let float16Value = Float16(3.5) - let float16Data = try fory.serialize(float16Value) - let float16Decoded: Float16 = try fory.deserialize(float16Data) - #expect(float16Decoded.bitPattern == float16Value.bitPattern) + let float16Value = Float16(3.5) + let float16Data = try fory.serialize(float16Value) + let float16Decoded: Float16 = try fory.deserialize(float16Data) + #expect(float16Decoded.bitPattern == float16Value.bitPattern) - let bfloatValue = BFloat16(rawValue: 0x3F80) - let bfloatData = try fory.serialize(bfloatValue) - let bfloatDecoded: BFloat16 = try fory.deserialize(bfloatData) - #expect(bfloatDecoded == bfloatValue) + let bfloatValue = BFloat16(rawValue: 0x3F80) + let bfloatData = try fory.serialize(bfloatValue) + let bfloatDecoded: BFloat16 = try fory.deserialize(bfloatData) + #expect(bfloatDecoded == bfloatValue) - let durationValue = Duration.seconds(-2) + Duration.nanoseconds(123_456_789) - 
let durationData = try fory.serialize(durationValue) - let durationDecoded: Duration = try fory.deserialize(durationData) - #expect(durationDecoded == durationValue) + let durationValue = Duration.seconds(-2) + Duration.nanoseconds(123_456_789) + let durationData = try fory.serialize(durationValue) + let durationDecoded: Duration = try fory.deserialize(durationData) + #expect(durationDecoded == durationValue) - let float16Array: [Float16] = [Float16(1), Float16(-2), Float16(4.5)] - let float16ArrayData = try fory.serialize(float16Array) - let float16ArrayDecoded: [Float16] = try fory.deserialize(float16ArrayData) - #expect(float16ArrayDecoded.map(\.bitPattern) == float16Array.map(\.bitPattern)) + let float16Array: [Float16] = [Float16(1), Float16(-2), Float16(4.5)] + let float16ArrayData = try fory.serialize(float16Array) + let float16ArrayDecoded: [Float16] = try fory.deserialize(float16ArrayData) + #expect(float16ArrayDecoded.map(\.bitPattern) == float16Array.map(\.bitPattern)) } @Test func floatingSpecialsRoundTrip() throws { - let fory = Fory() - - let floatValues: [Float] = [ - 0.0, - -0.0, - .infinity, - -.infinity, - .leastNonzeroMagnitude, - .greatestFiniteMagnitude, - Float(bitPattern: 0x7FC0_1234) - ] - for value in floatValues { - let decoded: Float = try fory.deserialize(try fory.serialize(value)) - #expect(decoded.bitPattern == value.bitPattern) - } - - let doubleValues: [Double] = [ - 0.0, - -0.0, - .infinity, - -.infinity, - .leastNonzeroMagnitude, - .greatestFiniteMagnitude, - Double(bitPattern: 0x7FF8_0000_0000_1234) - ] - for value in doubleValues { - let decoded: Double = try fory.deserialize(try fory.serialize(value)) - #expect(decoded.bitPattern == value.bitPattern) - } - - let float16Values: [Float16] = [ - .init(bitPattern: 0x0000), - .init(bitPattern: 0x8000), - .init(bitPattern: 0x7C00), - .init(bitPattern: 0xFC00), - .init(bitPattern: 0x0001), - .init(bitPattern: 0x7BFF), - .init(bitPattern: 0x7E11) - ] - for value in float16Values { - let 
decoded: Float16 = try fory.deserialize(try fory.serialize(value)) - #expect(decoded.bitPattern == value.bitPattern) - } - - let bfloat16Values: [BFloat16] = [ - .init(rawValue: 0x0000), - .init(rawValue: 0x8000), - .init(rawValue: 0x7F80), - .init(rawValue: 0xFF80), - .init(rawValue: 0x0001), - .init(rawValue: 0x7FC1) - ] - for value in bfloat16Values { - let decoded: BFloat16 = try fory.deserialize(try fory.serialize(value)) - #expect(decoded.rawValue == value.rawValue) - } + let fory = Fory() + + let floatValues: [Float] = [ + 0.0, + -0.0, + .infinity, + -.infinity, + .leastNonzeroMagnitude, + .greatestFiniteMagnitude, + Float(bitPattern: 0x7FC0_1234) + ] + for value in floatValues { + let decoded: Float = try fory.deserialize(try fory.serialize(value)) + #expect(decoded.bitPattern == value.bitPattern) + } + + let doubleValues: [Double] = [ + 0.0, + -0.0, + .infinity, + -.infinity, + .leastNonzeroMagnitude, + .greatestFiniteMagnitude, + Double(bitPattern: 0x7FF8_0000_0000_1234) + ] + for value in doubleValues { + let decoded: Double = try fory.deserialize(try fory.serialize(value)) + #expect(decoded.bitPattern == value.bitPattern) + } + + let float16Values: [Float16] = [ + .init(bitPattern: 0x0000), + .init(bitPattern: 0x8000), + .init(bitPattern: 0x7C00), + .init(bitPattern: 0xFC00), + .init(bitPattern: 0x0001), + .init(bitPattern: 0x7BFF), + .init(bitPattern: 0x7E11) + ] + for value in float16Values { + let decoded: Float16 = try fory.deserialize(try fory.serialize(value)) + #expect(decoded.bitPattern == value.bitPattern) + } + + let bfloat16Values: [BFloat16] = [ + .init(rawValue: 0x0000), + .init(rawValue: 0x8000), + .init(rawValue: 0x7F80), + .init(rawValue: 0xFF80), + .init(rawValue: 0x0001), + .init(rawValue: 0x7FC1) + ] + for value in bfloat16Values { + let decoded: BFloat16 = try fory.deserialize(try fory.serialize(value)) + #expect(decoded.rawValue == value.rawValue) + } } @Test func namedInitializerBuildsConfig() { - let defaultConfig = Fory() - 
#expect(defaultConfig.config.xlang == true) - #expect(defaultConfig.config.trackRef == false) - #expect(defaultConfig.config.compatible == false) - #expect(defaultConfig.config.checkClassVersion == true) - #expect(defaultConfig.config.maxDepth == 5) - - let explicitConfig = Fory(xlang: false, trackRef: true, compatible: true, maxDepth: 7) - #expect(explicitConfig.config.xlang == false) - #expect(explicitConfig.config.trackRef == true) - #expect(explicitConfig.config.compatible == true) - #expect(explicitConfig.config.checkClassVersion == false) - #expect(explicitConfig.config.maxDepth == 7) - - let configInit = Fory(config: .init(xlang: false, trackRef: false, compatible: true, maxDepth: 9)) - #expect(configInit.config.xlang == false) - #expect(configInit.config.trackRef == false) - #expect(configInit.config.compatible == true) - #expect(configInit.config.checkClassVersion == false) - #expect(configInit.config.maxDepth == 9) - - let nativeDirect = Fory(xlang: false, trackRef: true, compatible: false) - let nativeViaConfig = Fory(config: Config(xlang: false, trackRef: true, compatible: false)) - #expect(nativeDirect.config.checkClassVersion == false) - #expect(nativeViaConfig.config.checkClassVersion == false) + let defaultConfig = Fory() + #expect(defaultConfig.config.xlang == true) + #expect(defaultConfig.config.trackRef == false) + #expect(defaultConfig.config.compatible == false) + #expect(defaultConfig.config.checkClassVersion == true) + #expect(defaultConfig.config.maxDepth == 5) + + let explicitConfig = Fory(xlang: false, trackRef: true, compatible: true, maxDepth: 7) + #expect(explicitConfig.config.xlang == false) + #expect(explicitConfig.config.trackRef == true) + #expect(explicitConfig.config.compatible == true) + #expect(explicitConfig.config.checkClassVersion == false) + #expect(explicitConfig.config.maxDepth == 7) + + let configInit = Fory(config: .init(xlang: false, trackRef: false, compatible: true, maxDepth: 9)) + #expect(configInit.config.xlang == 
false) + #expect(configInit.config.trackRef == false) + #expect(configInit.config.compatible == true) + #expect(configInit.config.checkClassVersion == false) + #expect(configInit.config.maxDepth == 9) + + let nativeDirect = Fory(xlang: false, trackRef: true, compatible: false) + let nativeViaConfig = Fory(config: Config(xlang: false, trackRef: true, compatible: false)) + #expect(nativeDirect.config.checkClassVersion == false) + #expect(nativeViaConfig.config.checkClassVersion == false) } @Test func structEvolvingOverrideUsesSmallerCompatiblePayload() throws { - let fory = Fory(compatible: true) - fory.register(EvolvingOverrideValue.self, id: 1001) - fory.register(FixedOverrideValue.self, id: 1002) + let fory = Fory(compatible: true) + fory.register(EvolvingOverrideValue.self, id: 1001) + fory.register(FixedOverrideValue.self, id: 1002) - let evolving = EvolvingOverrideValue(f1: "payload") - let fixed = FixedOverrideValue(f1: "payload") + let evolving = EvolvingOverrideValue(f1: "payload") + let fixed = FixedOverrideValue(f1: "payload") - let evolvingData = try fory.serialize(evolving) - let fixedData = try fory.serialize(fixed) + let evolvingData = try fory.serialize(evolving) + let fixedData = try fory.serialize(fixed) - #expect(fixedData.count < evolvingData.count) - let decodedEvolving: EvolvingOverrideValue = try fory.deserialize(evolvingData) - let decodedFixed: FixedOverrideValue = try fory.deserialize(fixedData) - #expect(decodedEvolving == evolving) - #expect(decodedFixed == fixed) + #expect(fixedData.count < evolvingData.count) + let decodedEvolving: EvolvingOverrideValue = try fory.deserialize(evolvingData) + let decodedFixed: FixedOverrideValue = try fory.deserialize(fixedData) + #expect(decodedEvolving == evolving) + #expect(decodedFixed == fixed) } @Test func decodeLimitsRejectOversizedPayloads() throws { - let writer = Fory() - - let oversizedCollection = try writer.serialize(["a", "b", "c"]) - let collectionLimited = Fory(config: 
.init(maxCollectionSize: 2)) - do { - let _: [String] = try collectionLimited.deserialize(oversizedCollection) - #expect(Bool(false)) - } catch {} - - let oversizedMap = try writer.serialize([Int32(1): Int32(1), 2: 2, 3: 3]) - do { - let _: [Int32: Int32] = try collectionLimited.deserialize(oversizedMap) - #expect(Bool(false)) - } catch {} - - let oversizedBinary = try writer.serialize(Data([0x01, 0x02, 0x03, 0x04])) - let binaryLimited = Fory(config: .init(maxBinarySize: 3)) - do { - let _: Data = try binaryLimited.deserialize(oversizedBinary) - #expect(Bool(false)) - } catch {} - - let oversizedArrayPayload = try writer.serialize([UInt16(1), 2]) - let payloadLimited = Fory(config: .init(maxCollectionSize: 1)) - do { - let _: [UInt16] = try payloadLimited.deserialize(oversizedArrayPayload) - #expect(Bool(false)) - } catch {} + let writer = Fory() + + let oversizedCollection = try writer.serialize(["a", "b", "c"]) + let collectionLimited = Fory(config: .init(maxCollectionSize: 2)) + do { + let _: [String] = try collectionLimited.deserialize(oversizedCollection) + #expect(Bool(false)) + } catch {} + + let oversizedMap = try writer.serialize([Int32(1): Int32(1), 2: 2, 3: 3]) + do { + let _: [Int32: Int32] = try collectionLimited.deserialize(oversizedMap) + #expect(Bool(false)) + } catch {} + + let oversizedBinary = try writer.serialize(Data([0x01, 0x02, 0x03, 0x04])) + let binaryLimited = Fory(config: .init(maxBinarySize: 3)) + do { + let _: Data = try binaryLimited.deserialize(oversizedBinary) + #expect(Bool(false)) + } catch {} + + let oversizedArrayPayload = try writer.serialize([UInt16(1), 2]) + let payloadLimited = Fory(config: .init(maxCollectionSize: 1)) + do { + let _: [UInt16] = try payloadLimited.deserialize(oversizedArrayPayload) + #expect(Bool(false)) + } catch {} } @Test func deserializeRejectsTrailingBytes() throws { - let fory = Fory() - let payload = try fory.serialize(Int32(7)) - var bytes = [UInt8](payload) - bytes.append(0xFF) - let withTrailing = 
Data(bytes) + let fory = Fory() + let payload = try fory.serialize(Int32(7)) + var bytes = [UInt8](payload) + bytes.append(0xFF) + let withTrailing = Data(bytes) - do { - let _: Int32 = try fory.deserialize(withTrailing) - #expect(Bool(false)) - } catch {} + do { + let _: Int32 = try fory.deserialize(withTrailing) + #expect(Bool(false)) + } catch {} } @Test func optionalRoundTrip() throws { - let fory = Fory() + let fory = Fory() - let some: String? = "present" - let someData = try fory.serialize(some) - let someValue: String? = try fory.deserialize(someData) - #expect(someValue == "present") + let some: String? = "present" + let someData = try fory.serialize(some) + let someValue: String? = try fory.deserialize(someData) + #expect(someValue == "present") - let none: String? = nil - let noneData = try fory.serialize(none) - let noneValue: String? = try fory.deserialize(noneData) - #expect(noneValue == nil) + let none: String? = nil + let noneData = try fory.serialize(none) + let noneValue: String? = try fory.deserialize(noneData) + #expect(noneValue == nil) } @Test func collectionsRoundTrip() throws { - let fory = Fory() + let fory = Fory() - let list: [String?] = ["a", nil, "b"] - let listData = try fory.serialize(list) - let listValue: [String?] = try fory.deserialize(listData) - #expect(listValue == list) + let list: [String?] = ["a", nil, "b"] + let listData = try fory.serialize(list) + let listValue: [String?] 
= try fory.deserialize(listData) + #expect(listValue == list) - let intArray: [Int32] = [1, 2, 3, 4] - let intArrayData = try fory.serialize(intArray) - let intArrayValue: [Int32] = try fory.deserialize(intArrayData) - #expect(intArrayValue == intArray) + let intArray: [Int32] = [1, 2, 3, 4] + let intArrayData = try fory.serialize(intArray) + let intArrayValue: [Int32] = try fory.deserialize(intArrayData) + #expect(intArrayValue == intArray) - let uint8Array: [UInt8] = [1, 2, 3, 250] - let uint8ArrayData = try fory.serialize(uint8Array) - let uint8ArrayValue: [UInt8] = try fory.deserialize(uint8ArrayData) - #expect(uint8ArrayValue == uint8Array) + let uint8Array: [UInt8] = [1, 2, 3, 250] + let uint8ArrayData = try fory.serialize(uint8Array) + let uint8ArrayValue: [UInt8] = try fory.deserialize(uint8ArrayData) + #expect(uint8ArrayValue == uint8Array) - let set: Set = [1, 5, 8] - let setData = try fory.serialize(set) - let setValue: Set = try fory.deserialize(setData) - #expect(setValue == set) + let set: Set = [1, 5, 8] + let setData = try fory.serialize(set) + let setValue: Set = try fory.deserialize(setData) + #expect(setValue == set) - let map: [Int8: Int32?] = [1: 100, 2: nil, 3: -7] - let mapData = try fory.serialize(map) - let mapValue: [Int8: Int32?] = try fory.deserialize(mapData) - #expect(mapValue == map) + let map: [Int8: Int32?] = [1: 100, 2: nil, 3: -7] + let mapData = try fory.serialize(map) + let mapValue: [Int8: Int32?] = try fory.deserialize(mapData) + #expect(mapValue == map) - let nullableKeyMap: [Int8?: Int32?] = [1: 10, nil: nil] - let nullableMapData = try fory.serialize(nullableKeyMap) - let nullableMapValue: [Int8?: Int32?] = try fory.deserialize(nullableMapData) - #expect(nullableMapValue == nullableKeyMap) + let nullableKeyMap: [Int8?: Int32?] = [1: 10, nil: nil] + let nullableMapData = try fory.serialize(nullableKeyMap) + let nullableMapValue: [Int8?: Int32?] 
= try fory.deserialize(nullableMapData) + #expect(nullableMapValue == nullableKeyMap) } @Test func primitiveArrayTypeIDs() throws { - let fory = Fory() + let fory = Fory() - let int32Data = try fory.serialize([Int32(7), 9]) - let int32Bytes = [UInt8](int32Data) - #expect(int32Bytes[0] == ForyHeaderFlag.isXlang) - #expect(Int8(bitPattern: int32Bytes[1]) == RefFlag.notNullValue.rawValue) - #expect(UInt32(int32Bytes[2]) == TypeId.list.rawValue) + let int32Data = try fory.serialize([Int32(7), 9]) + let int32Bytes = [UInt8](int32Data) + #expect(int32Bytes[0] == ForyHeaderFlag.isXlang) + #expect(Int8(bitPattern: int32Bytes[1]) == RefFlag.notNullValue.rawValue) + #expect(UInt32(int32Bytes[2]) == TypeId.list.rawValue) - let uint8Data = try fory.serialize([UInt8(1), 2, 3]) - let uint8Bytes = [UInt8](uint8Data) - #expect(UInt32(uint8Bytes[2]) == TypeId.list.rawValue) + let uint8Data = try fory.serialize([UInt8(1), 2, 3]) + let uint8Bytes = [UInt8](uint8Data) + #expect(UInt32(uint8Bytes[2]) == TypeId.list.rawValue) } @Test func macroStructRoundTrip() throws { - let fory = Fory() - fory.register(Address.self, id: 100) - fory.register(Person.self, id: 101) - - let person = Person( - id: 42, - name: "Alice", - nickname: nil, - scores: [10, 20, 30], - tags: ["swift", "xlang"], - addresses: [Address(street: "Main", zip: 94107)], - metadata: [1: 100, 2: nil] - ) + let fory = Fory() + fory.register(Address.self, id: 100) + fory.register(Person.self, id: 101) + + let person = Person( + id: 42, + name: "Alice", + nickname: nil, + scores: [10, 20, 30], + tags: ["swift", "xlang"], + addresses: [Address(street: "Main", zip: 94107)], + metadata: [1: 100, 2: nil] + ) - let data = try fory.serialize(person) - let decoded: Person = try fory.deserialize(data) - #expect(decoded == person) + let data = try fory.serialize(person) + let decoded: Person = try fory.deserialize(data) + #expect(decoded == person) } @Test func macroClassRefTracking() throws { - let fory = Fory(config: .init(xlang: 
true, trackRef: true)) - fory.register(Node.self, id: 200) + let fory = Fory(config: .init(xlang: true, trackRef: true)) + fory.register(Node.self, id: 200) - let node = Node(value: 7) - node.next = node + let node = Node(value: 7) + node.next = node - let data = try fory.serialize(node) - let decoded: Node = try fory.deserialize(data) + let data = try fory.serialize(node) + let decoded: Node = try fory.deserialize(data) - #expect(decoded.value == 7) - #expect(decoded.next === decoded) + #expect(decoded.value == 7) + #expect(decoded.next === decoded) } @Test func macroClassWeakRefTracking() throws { - let fory = Fory(config: .init(xlang: true, trackRef: true)) - fory.register(WeakNode.self, id: 201) + let fory = Fory(config: .init(xlang: true, trackRef: true)) + fory.register(WeakNode.self, id: 201) - let node = WeakNode(value: 13) - node.next = node + let node = WeakNode(value: 13) + node.next = node - let data = try fory.serialize(node) - let decoded: WeakNode = try fory.deserialize(data) + let data = try fory.serialize(node) + let decoded: WeakNode = try fory.deserialize(data) - #expect(decoded.value == 13) - #expect(decoded.next === decoded) + #expect(decoded.value == 13) + #expect(decoded.next === decoded) } @Test func topLevelAnyRoundTrip() throws { - let fory = Fory() - fory.register(Address.self, id: 209) + let fory = Fory() + fory.register(Address.self, id: 209) - let value: Any = Address(street: "AnyTop", zip: 8080) - let data = try fory.serialize(value) - let decoded: Any = try fory.deserialize(data) - #expect(decoded as? Address == Address(street: "AnyTop", zip: 8080)) + let value: Any = Address(street: "AnyTop", zip: 8080) + let data = try fory.serialize(value) + let decoded: Any = try fory.deserialize(data) + #expect(decoded as? Address == Address(street: "AnyTop", zip: 8080)) - var buffer = Data() - try fory.serialize(value, to: &buffer) - let decodedFrom: Any = try fory.deserialize(from: ByteBuffer(data: buffer)) - #expect(decodedFrom as? 
Address == Address(street: "AnyTop", zip: 8080)) + var buffer = Data() + try fory.serialize(value, to: &buffer) + let decodedFrom: Any = try fory.deserialize(from: ByteBuffer(data: buffer)) + #expect(decodedFrom as? Address == Address(street: "AnyTop", zip: 8080)) - let nullAny: Any = Optional.none as Any - let nullData = try fory.serialize(nullAny) - let nullDecoded: Any = try fory.deserialize(nullData) - #expect(nullDecoded is ForyAnyNullValue) + let nullAny: Any = Optional.none as Any + let nullData = try fory.serialize(nullAny) + let nullDecoded: Any = try fory.deserialize(nullData) + #expect(nullDecoded is ForyAnyNullValue) } @Test func dynamicUserTypesDecodeByID() throws { - let fory = Fory() - fory.register(Address.self, id: 600) - try fory.register(Person.self, name: "demo.person") + let fory = Fory() + fory.register(Address.self, id: 600) + try fory.register(Person.self, name: "demo.person") - let value: Any = Address(street: "mixed", zip: 7788) - let data = try fory.serialize(value) - let decoded: Any = try fory.deserialize(data) - #expect(decoded as? Address == Address(street: "mixed", zip: 7788)) + let value: Any = Address(street: "mixed", zip: 7788) + let data = try fory.serialize(value) + let decoded: Any = try fory.deserialize(data) + #expect(decoded as? 
Address == Address(street: "mixed", zip: 7788)) } @Test func duplicateNameRegistrationIsRejected() throws { - let resolver = TypeResolver(trackRef: false) - try resolver.register(Address.self, namespace: "demo", typeName: "entity") + let resolver = TypeResolver(trackRef: false) + try resolver.register(Address.self, namespace: "demo", typeName: "entity") - do { - try resolver.register(Person.self, namespace: "demo", typeName: "entity") - #expect(Bool(false)) - } catch {} + do { + try resolver.register(Person.self, namespace: "demo", typeName: "entity") + #expect(Bool(false)) + } catch {} } @Test func registrationIsRejectedAfterFirstTopLevelUse() throws { - let fory = Fory() - _ = try fory.serialize(Int32(7)) - - do { - try fory.register(Address.self, name: "demo.address") - #expect(Bool(false)) - } catch { - #expect("\(error)".contains("cannot register more types")) - } + let fory = Fory() + _ = try fory.serialize(Int32(7)) + + do { + try fory.register(Address.self, name: "demo.address") + #expect(Bool(false)) + } catch { + #expect("\(error)".contains("cannot register more types")) + } } @Test func serializeToAppendsRoots() throws { - let fory = Fory() - let first = Int32(7) - let second = "swift-buffer" - let third: String? = nil + let fory = Fory() + let first = Int32(7) + let second = "swift-buffer" + let third: String? 
= nil - let firstData = try fory.serialize(first) - let secondData = try fory.serialize(second) - let thirdData = try fory.serialize(third) + let firstData = try fory.serialize(first) + let secondData = try fory.serialize(second) + let thirdData = try fory.serialize(third) - var stream = Data() - try fory.serialize(first, to: &stream) - try fory.serialize(second, to: &stream) - try fory.serialize(third, to: &stream) + var stream = Data() + try fory.serialize(first, to: &stream) + try fory.serialize(second, to: &stream) + try fory.serialize(third, to: &stream) - var expected = Data() - expected.append(firstData) - expected.append(secondData) - expected.append(thirdData) - #expect(stream == expected) + var expected = Data() + expected.append(firstData) + expected.append(secondData) + expected.append(thirdData) + #expect(stream == expected) - let buffer = ByteBuffer(data: stream) - let decodedFirst: Int32 = try fory.deserialize(from: buffer) - #expect(decodedFirst == first) - #expect(buffer.getCursor() == firstData.count) + let buffer = ByteBuffer(data: stream) + let decodedFirst: Int32 = try fory.deserialize(from: buffer) + #expect(decodedFirst == first) + #expect(buffer.getCursor() == firstData.count) - let decodedSecond: String = try fory.deserialize(from: buffer) - #expect(decodedSecond == second) - #expect(buffer.getCursor() == firstData.count + secondData.count) + let decodedSecond: String = try fory.deserialize(from: buffer) + #expect(decodedSecond == second) + #expect(buffer.getCursor() == firstData.count + secondData.count) - let decodedThird: String? = try fory.deserialize(from: buffer) - #expect(decodedThird == nil) - #expect(buffer.remaining == 0) + let decodedThird: String? 
= try fory.deserialize(from: buffer) + #expect(decodedThird == nil) + #expect(buffer.remaining == 0) } @Test func rootBufferHonorsCursor() throws { - let fory = Fory() - let prefix: [UInt8] = [0xAA, 0xBB, 0xCC] - let payload = try fory.serialize("offset") + let fory = Fory() + let prefix: [UInt8] = [0xAA, 0xBB, 0xCC] + let payload = try fory.serialize("offset") - let buffer = ByteBuffer() - buffer.writeBytes(prefix) - buffer.writeBytes(Array(payload)) - buffer.setCursor(prefix.count) + let buffer = ByteBuffer() + buffer.writeBytes(prefix) + buffer.writeBytes(Array(payload)) + buffer.setCursor(prefix.count) - let decoded: String = try fory.deserialize(from: buffer) - #expect(decoded == "offset") - #expect(buffer.getCursor() == buffer.count) - #expect(Array(buffer.storage.prefix(prefix.count)) == prefix) + let decoded: String = try fory.deserialize(from: buffer) + #expect(decoded == "offset") + #expect(buffer.getCursor() == buffer.count) + #expect(Array(buffer.storage.prefix(prefix.count)) == prefix) } @Test func topLevelAnyObjectRoundTrip() throws { - let fory = Fory(config: .init(xlang: true, trackRef: true)) - fory.register(Node.self, id: 210) + let fory = Fory(config: .init(xlang: true, trackRef: true)) + fory.register(Node.self, id: 210) - let value: AnyObject = Node(value: 123) - let data = try fory.serialize(value) - let decoded: AnyObject = try fory.deserialize(data) + let value: AnyObject = Node(value: 123) + let data = try fory.serialize(value) + let decoded: AnyObject = try fory.deserialize(data) - let node = decoded as? Node - #expect(node != nil) - #expect(node?.value == 123) + let node = decoded as? Node + #expect(node != nil) + #expect(node?.value == 123) - var buffer = Data() - try fory.serialize(value, to: &buffer) - let decodedFrom: AnyObject = try fory.deserialize(from: ByteBuffer(data: buffer)) - #expect((decodedFrom as? 
Node)?.value == 123) + var buffer = Data() + try fory.serialize(value, to: &buffer) + let decodedFrom: AnyObject = try fory.deserialize(from: ByteBuffer(data: buffer)) + #expect((decodedFrom as? Node)?.value == 123) } @Test func topLevelAnySerializerRoundTrip() throws { - let fory = Fory() - fory.register(Address.self, id: 211) + let fory = Fory() + fory.register(Address.self, id: 211) - let value: any Serializer = Address(street: "AnyStreet", zip: 9090) - let data = try fory.serialize(value) - let decoded: any Serializer = try fory.deserialize(data) + let value: any Serializer = Address(street: "AnyStreet", zip: 9090) + let data = try fory.serialize(value) + let decoded: any Serializer = try fory.deserialize(data) - let address = decoded as? Address - #expect(address == Address(street: "AnyStreet", zip: 9090)) + let address = decoded as? Address + #expect(address == Address(street: "AnyStreet", zip: 9090)) - var buffer = Data() - try fory.serialize(value, to: &buffer) - let decodedFrom: any Serializer = try fory.deserialize(from: ByteBuffer(data: buffer)) - #expect(decodedFrom as? Address == Address(street: "AnyStreet", zip: 9090)) + var buffer = Data() + try fory.serialize(value, to: &buffer) + let decodedFrom: any Serializer = try fory.deserialize(from: ByteBuffer(data: buffer)) + #expect(decodedFrom as? 
Address == Address(street: "AnyStreet", zip: 9090)) } @Test func macroDynamicAnyObjectAndAnySerializerFieldsRoundTrip() throws { - let fory = Fory(config: .init(xlang: true, trackRef: true)) - fory.register(Node.self, id: 220) - fory.register(Address.self, id: 221) - fory.register(AnyObjectHolder.self, id: 222) - fory.register(AnySerializerHolder.self, id: 223) - - let sharedNode = Node(value: 77) - let objectHolder = AnyObjectHolder( - value: sharedNode, - optionalValue: nil, - items: [sharedNode, NSNull()] - ) - let objectData = try fory.serialize(objectHolder) - let objectDecoded: AnyObjectHolder = try fory.deserialize(objectData) - #expect((objectDecoded.value as? Node)?.value == 77) - #expect(objectDecoded.optionalValue == nil) - #expect(objectDecoded.items.count == 2) - #expect((objectDecoded.items[0] as? Node)?.value == 77) - #expect(objectDecoded.items[1] is NSNull) - - let serializerHolder = AnySerializerHolder( - value: Address(street: "Root", zip: 10001), - items: [Int32(11), Address(street: "Nested", zip: 10002)], - map: [ - "age": Int64(19), - "address": Address(street: "Mapped", zip: 10003) - ] - ) - let serializerData = try fory.serialize(serializerHolder) - let serializerDecoded: AnySerializerHolder = try fory.deserialize(serializerData) + let fory = Fory(config: .init(xlang: true, trackRef: true)) + fory.register(Node.self, id: 220) + fory.register(Address.self, id: 221) + fory.register(AnyObjectHolder.self, id: 222) + fory.register(AnySerializerHolder.self, id: 223) + + let sharedNode = Node(value: 77) + let objectHolder = AnyObjectHolder( + value: sharedNode, + optionalValue: nil, + items: [sharedNode, NSNull()] + ) + let objectData = try fory.serialize(objectHolder) + let objectDecoded: AnyObjectHolder = try fory.deserialize(objectData) + #expect((objectDecoded.value as? Node)?.value == 77) + #expect(objectDecoded.optionalValue == nil) + #expect(objectDecoded.items.count == 2) + #expect((objectDecoded.items[0] as? 
Node)?.value == 77) + #expect(objectDecoded.items[1] is NSNull) + + let serializerHolder = AnySerializerHolder( + value: Address(street: "Root", zip: 10001), + items: [Int32(11), Address(street: "Nested", zip: 10002)], + map: [ + "age": Int64(19), + "address": Address(street: "Mapped", zip: 10003) + ] + ) + let serializerData = try fory.serialize(serializerHolder) + let serializerDecoded: AnySerializerHolder = try fory.deserialize(serializerData) - #expect(serializerDecoded.value as? Address == Address(street: "Root", zip: 10001)) - #expect(serializerDecoded.items.count == 2) - #expect(serializerDecoded.items[0] as? Int32 == 11) - #expect(serializerDecoded.items[1] as? Address == Address(street: "Nested", zip: 10002)) - #expect(serializerDecoded.map["age"] as? Int64 == 19) - #expect(serializerDecoded.map["address"] as? Address == Address(street: "Mapped", zip: 10003)) + #expect(serializerDecoded.value as? Address == Address(street: "Root", zip: 10001)) + #expect(serializerDecoded.items.count == 2) + #expect(serializerDecoded.items[0] as? Int32 == 11) + #expect(serializerDecoded.items[1] as? Address == Address(street: "Nested", zip: 10002)) + #expect(serializerDecoded.map["age"] as? Int64 == 19) + #expect(serializerDecoded.map["address"] as? Address == Address(street: "Mapped", zip: 10003)) } @Test func dynamicAnySerializerTracksRefs() throws { - let fory = Fory(config: .init(xlang: true, trackRef: true)) - fory.register(Node.self, id: 226) - fory.register(AnySerializerHolder.self, id: 227) - - let shared = Node(value: 88) - shared.next = shared - let value = AnySerializerHolder( - value: shared, - items: [shared], - map: ["shared": shared] - ) + let fory = Fory(config: .init(xlang: true, trackRef: true)) + fory.register(Node.self, id: 226) + fory.register(AnySerializerHolder.self, id: 227) - let decoded: AnySerializerHolder = try fory.deserialize(try fory.serialize(value)) - let root = decoded.value as? Node - let item = decoded.items.first as? 
Node - let mapped = decoded.map["shared"] as? Node + let shared = Node(value: 88) + shared.next = shared + let value = AnySerializerHolder( + value: shared, + items: [shared], + map: ["shared": shared] + ) - #expect(root != nil) - #expect(root === item) - #expect(item === mapped) - #expect(root?.next === root) + let decoded: AnySerializerHolder = try fory.deserialize(try fory.serialize(value)) + let root = decoded.value as? Node + let item = decoded.items.first as? Node + let mapped = decoded.map["shared"] as? Node + + #expect(root != nil) + #expect(root === item) + #expect(item === mapped) + #expect(root?.next === root) } @Test func macroAnyFieldsRoundTrip() throws { - let fory = Fory() - fory.register(Address.self, id: 224) - fory.register(AnyFieldHolder.self, id: 225) - - let value = AnyFieldHolder( - value: Address(street: "AnyRoot", zip: 11001), - optionalValue: nil, - list: [Int32(7), "hello", Address(street: "AnyList", zip: 11002), NSNull()], - stringMap: [ - "count": Int64(3), - "name": "map", - "address": Address(street: "AnyMap", zip: 11003), - "empty": NSNull() - ], - int32Map: [ - 1: Int32(-9), - 2: "v2", - 3: Address(street: "AnyIntMap", zip: 11004), - 4: NSNull() - ] - ) - let data = try fory.serialize(value) - let decoded: AnyFieldHolder = try fory.deserialize(data) - - #expect(decoded.value as? Address == Address(street: "AnyRoot", zip: 11001)) - #expect(decoded.optionalValue == nil) - #expect(decoded.list.count == 4) - #expect(decoded.list[0] as? Int32 == 7) - #expect(decoded.list[1] as? String == "hello") - #expect(decoded.list[2] as? Address == Address(street: "AnyList", zip: 11002)) - #expect(decoded.list[3] is NSNull) - #expect(decoded.stringMap["count"] as? Int64 == 3) - #expect(decoded.stringMap["name"] as? String == "map") - #expect(decoded.stringMap["address"] as? Address == Address(street: "AnyMap", zip: 11003)) - #expect(decoded.stringMap["empty"] is NSNull) - #expect(decoded.int32Map[1] as? Int32 == -9) - #expect(decoded.int32Map[2] as? 
String == "v2") - #expect(decoded.int32Map[3] as? Address == Address(street: "AnyIntMap", zip: 11004)) - #expect(decoded.int32Map[4] is NSNull) + let fory = Fory() + fory.register(Address.self, id: 224) + fory.register(AnyFieldHolder.self, id: 225) + + let value = AnyFieldHolder( + value: Address(street: "AnyRoot", zip: 11001), + optionalValue: nil, + list: [Int32(7), "hello", Address(street: "AnyList", zip: 11002), NSNull()], + stringMap: [ + "count": Int64(3), + "name": "map", + "address": Address(street: "AnyMap", zip: 11003), + "empty": NSNull() + ], + int32Map: [ + 1: Int32(-9), + 2: "v2", + 3: Address(street: "AnyIntMap", zip: 11004), + 4: NSNull() + ] + ) + let data = try fory.serialize(value) + let decoded: AnyFieldHolder = try fory.deserialize(data) + + #expect(decoded.value as? Address == Address(street: "AnyRoot", zip: 11001)) + #expect(decoded.optionalValue == nil) + #expect(decoded.list.count == 4) + #expect(decoded.list[0] as? Int32 == 7) + #expect(decoded.list[1] as? String == "hello") + #expect(decoded.list[2] as? Address == Address(street: "AnyList", zip: 11002)) + #expect(decoded.list[3] is NSNull) + #expect(decoded.stringMap["count"] as? Int64 == 3) + #expect(decoded.stringMap["name"] as? String == "map") + #expect(decoded.stringMap["address"] as? Address == Address(street: "AnyMap", zip: 11003)) + #expect(decoded.stringMap["empty"] is NSNull) + #expect(decoded.int32Map[1] as? Int32 == -9) + #expect(decoded.int32Map[2] as? String == "v2") + #expect(decoded.int32Map[3] as? Address == Address(street: "AnyIntMap", zip: 11004)) + #expect(decoded.int32Map[4] is NSNull) } @Test func collectionAndMapRefTracking() throws { - let fory = Fory(config: .init(xlang: true, trackRef: true)) - fory.register(Node.self, id: 200) - - let shared = Node(value: 11) - let list: [Node?] 
= [shared, shared, nil] - let listData = try fory.serialize(list) - let listReader = ByteBuffer(data: listData) - _ = try fory.readHead(buffer: listReader) - _ = try listReader.readInt8() - _ = try listReader.readVarUInt32() - _ = try listReader.readVarUInt32() - let listHeader = try listReader.readUInt8() - #expect((listHeader & 0b0000_0001) != 0) - - let decodedList: [Node?] = try fory.deserialize(listData) - #expect(decodedList.count == 3) - #expect(decodedList[0] === decodedList[1]) - #expect(decodedList[2] == nil) - - let sharedValue = Node(value: 21) - let map: [Int8: Node?] = [1: sharedValue, 2: sharedValue] - let mapData = try fory.serialize(map) - let mapReader = ByteBuffer(data: mapData) - _ = try fory.readHead(buffer: mapReader) - _ = try mapReader.readInt8() - _ = try mapReader.readVarUInt32() - _ = try mapReader.readVarUInt32() - let mapChunkHeader = try mapReader.readUInt8() - #expect((mapChunkHeader & 0b0000_1000) != 0) - - let decodedMap: [Int8: Node?] = try fory.deserialize(mapData) - let v1 = decodedMap[1] ?? nil - let v2 = decodedMap[2] ?? nil - #expect(v1 != nil) - #expect(v1 === v2) + let fory = Fory(config: .init(xlang: true, trackRef: true)) + fory.register(Node.self, id: 200) + + let shared = Node(value: 11) + let list: [Node?] = [shared, shared, nil] + let listData = try fory.serialize(list) + let listReader = ByteBuffer(data: listData) + _ = try fory.readHead(buffer: listReader) + _ = try listReader.readInt8() + _ = try listReader.readVarUInt32() + _ = try listReader.readVarUInt32() + let listHeader = try listReader.readUInt8() + #expect((listHeader & 0b0000_0001) != 0) + + let decodedList: [Node?] = try fory.deserialize(listData) + #expect(decodedList.count == 3) + #expect(decodedList[0] === decodedList[1]) + #expect(decodedList[2] == nil) + + let sharedValue = Node(value: 21) + let map: [Int8: Node?] 
= [1: sharedValue, 2: sharedValue] + let mapData = try fory.serialize(map) + let mapReader = ByteBuffer(data: mapData) + _ = try fory.readHead(buffer: mapReader) + _ = try mapReader.readInt8() + _ = try mapReader.readVarUInt32() + _ = try mapReader.readVarUInt32() + let mapChunkHeader = try mapReader.readUInt8() + #expect((mapChunkHeader & 0b0000_1000) != 0) + + let decodedMap: [Int8: Node?] = try fory.deserialize(mapData) + let v1 = decodedMap[1] ?? nil + let v2 = decodedMap[2] ?? nil + #expect(v1 != nil) + #expect(v1 === v2) } @Test func macroFieldOrderFollowsForyRules() throws { - let fory = Fory() - fory.register(FieldOrder.self, id: 300) + let fory = Fory() + fory.register(FieldOrder.self, id: 300) - let value = FieldOrder(textTail: "tail", longValue: 123456789, shortValue: 17, intValue: 99) - let data = try fory.serialize(value) + let value = FieldOrder(textTail: "tail", longValue: 123_456_789, shortValue: 17, intValue: 99) + let data = try fory.serialize(value) - let buffer = ByteBuffer(data: data) - _ = try fory.readHead(buffer: buffer) - _ = try buffer.readInt8() // root ref flag - _ = try buffer.readVarUInt32() // type id - _ = try buffer.readVarUInt32() // user type id - _ = try buffer.readInt32() // schema hash + let buffer = ByteBuffer(data: data) + _ = try fory.readHead(buffer: buffer) + _ = try buffer.readInt8() // root ref flag + _ = try buffer.readVarUInt32() // type id + _ = try buffer.readVarUInt32() // user type id + _ = try buffer.readInt32() // schema hash - let first = try buffer.readInt16() - let second = try buffer.readVarInt64() - let third = try buffer.readVarInt32() + let first = try buffer.readInt16() + let second = try buffer.readVarInt64() + let third = try buffer.readVarInt32() - let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, trackRef: false) - let fourth = try String.foryReadData(tailContext) + let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, trackRef: false) + let fourth 
= try String.foryReadData(tailContext) - #expect(first == value.shortValue) - #expect(second == value.longValue) - #expect(third == value.intValue) - #expect(fourth == value.textTail) + #expect(first == value.shortValue) + #expect(second == value.longValue) + #expect(third == value.intValue) + #expect(fourth == value.textTail) } @Test func macroTaggedFieldsKeepGroupedPayloadOrder() throws { - let fory = Fory() - fory.register(TaggedFieldOrder.self, id: 303) + let fory = Fory() + fory.register(TaggedFieldOrder.self, id: 303) - let fields = TaggedFieldOrder.foryFieldsInfo(trackRef: false) - #expect(fields.map(\.fieldName) == ["intValue", "textTail"]) - #expect(fields.map(\.fieldID) == [10, 1]) + let fields = TaggedFieldOrder.foryFieldsInfo(trackRef: false) + #expect(fields.map(\.fieldName) == ["intValue", "textTail"]) + #expect(fields.map(\.fieldID) == [10, 1]) - let value = TaggedFieldOrder(textTail: "tail", intValue: 99) - let data = try fory.serialize(value) - let buffer = ByteBuffer(data: data) - _ = try fory.readHead(buffer: buffer) - _ = try buffer.readInt8() - _ = try buffer.readVarUInt32() - _ = try buffer.readVarUInt32() - _ = try buffer.readInt32() + let value = TaggedFieldOrder(textTail: "tail", intValue: 99) + let data = try fory.serialize(value) + let buffer = ByteBuffer(data: data) + _ = try fory.readHead(buffer: buffer) + _ = try buffer.readInt8() + _ = try buffer.readVarUInt32() + _ = try buffer.readVarUInt32() + _ = try buffer.readInt32() - #expect(try buffer.readVarInt32() == value.intValue) - let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, trackRef: false) - #expect(try String.foryReadData(tailContext) == value.textTail) + #expect(try buffer.readVarInt32() == value.intValue) + let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, trackRef: false) + #expect(try String.foryReadData(tailContext) == value.textTail) } @Test func macroFieldEncodingOverridesForUnsignedTypes() throws { - let fory = 
Fory() - fory.register(EncodedNumberFields.self, id: 301) + let fory = Fory() + fory.register(EncodedNumberFields.self, id: 301) - let value = EncodedNumberFields( - u32Fixed: 0x11223344, - u64Tagged: UInt64(Int32.max) + 99 - ) - let data = try fory.serialize(value) - let decoded: EncodedNumberFields = try fory.deserialize(data) - #expect(decoded == value) + let value = EncodedNumberFields( + u32Fixed: 0x1122_3344, + u64Tagged: UInt64(Int32.max) + 99 + ) + let data = try fory.serialize(value) + let decoded: EncodedNumberFields = try fory.deserialize(data) + #expect(decoded == value) - let buffer = ByteBuffer(data: data) - _ = try fory.readHead(buffer: buffer) - _ = try buffer.readInt8() - _ = try buffer.readVarUInt32() - _ = try buffer.readVarUInt32() - _ = try buffer.readInt32() + let buffer = ByteBuffer(data: data) + _ = try fory.readHead(buffer: buffer) + _ = try buffer.readInt8() + _ = try buffer.readVarUInt32() + _ = try buffer.readVarUInt32() + _ = try buffer.readInt32() - #expect(try buffer.readUInt32() == value.u32Fixed) - #expect(try buffer.readTaggedUInt64() == value.u64Tagged) + #expect(try buffer.readUInt32() == value.u32Fixed) + #expect(try buffer.readTaggedUInt64() == value.u64Tagged) } @Test func macroEnumUsesExplicitIntegerRawValue() throws { - let fory = Fory(config: .init(xlang: true, trackRef: false)) - fory.register(SparseStatus.self, id: 302) + let fory = Fory(config: .init(xlang: true, trackRef: false)) + fory.register(SparseStatus.self, id: 302) - let data = try fory.serialize(SparseStatus.ok) - let buffer = ByteBuffer(data: data) - _ = try fory.readHead(buffer: buffer) - _ = try buffer.readInt8() - _ = try buffer.readVarUInt32() - _ = try buffer.readVarUInt32() - #expect(try buffer.readVarUInt32() == 8192) + let data = try fory.serialize(SparseStatus.ok) + let buffer = ByteBuffer(data: data) + _ = try fory.readHead(buffer: buffer) + _ = try buffer.readInt8() + _ = try buffer.readVarUInt32() + _ = try buffer.readVarUInt32() + #expect(try 
buffer.readVarUInt32() == 8192) - let decoded: SparseStatus = try fory.deserialize(data) - #expect(decoded == .ok) + let decoded: SparseStatus = try fory.deserialize(data) + #expect(decoded == .ok) } @Test func macroFieldEncodingOverridesCompatibleTypeMeta() throws { - let fields = EncodedNumberFields.foryFieldsInfo(trackRef: false) - #expect(fields.count == 2) - #expect(fields[0].fieldName == "u32Fixed") - #expect(fields[0].fieldType.typeID == TypeId.uint32.rawValue) - #expect(fields[1].fieldName == "u64Tagged") - #expect(fields[1].fieldType.typeID == TypeId.taggedUInt64.rawValue) + let fields = EncodedNumberFields.foryFieldsInfo(trackRef: false) + #expect(fields.count == 2) + #expect(fields[0].fieldName == "u32Fixed") + #expect(fields[0].fieldType.typeID == TypeId.uint32.rawValue) + #expect(fields[1].fieldName == "u64Tagged") + #expect(fields[1].fieldType.typeID == TypeId.taggedUInt64.rawValue) } @Test func macroReducedPrecisionFieldsUseXlangTypeIDs() { - let fields = ReducedPrecisionMacroFields.foryFieldsInfo(trackRef: false) - #expect(fields.count == 4) - #expect(fields.map(\.fieldName) == ["float16Value", "bfloat16Value", "float16Array", "bfloat16Array"]) - #expect(fields.map(\.fieldType.typeID) == [ - TypeId.float16.rawValue, - TypeId.bfloat16.rawValue, - TypeId.float16Array.rawValue, - TypeId.bfloat16Array.rawValue + let fields = ReducedPrecisionMacroFields.foryFieldsInfo(trackRef: false) + #expect(fields.count == 4) + #expect( + fields.map(\.fieldName) == ["float16Value", "bfloat16Value", "float16Array", "bfloat16Array"]) + #expect( + fields.map(\.fieldType.typeID) == [ + TypeId.float16.rawValue, + TypeId.bfloat16.rawValue, + TypeId.float16Array.rawValue, + TypeId.bfloat16Array.rawValue ]) } @Test func macroFieldIDsPopulateCompatibleTypeMeta() { - let fields = FieldIdConfigured.foryFieldsInfo(trackRef: false) - #expect(fields.count == 2) - - var byID: [Int16: TypeMeta.FieldInfo] = [:] - for field in fields { - if let id = field.fieldID { - byID[id] = field 
- } + let fields = FieldIdConfigured.foryFieldsInfo(trackRef: false) + #expect(fields.count == 2) + + var byID: [Int16: TypeMeta.FieldInfo] = [:] + for field in fields { + if let id = field.fieldID { + byID[id] = field } + } - #expect(byID[2]?.fieldName == "stableID") - #expect(byID[2]?.fieldType.typeID == TypeId.varint32.rawValue) - #expect(byID[5]?.fieldName == "fixedValue") - #expect(byID[5]?.fieldType.typeID == TypeId.int32.rawValue) + #expect(byID[2]?.fieldName == "stableID") + #expect(byID[2]?.fieldType.typeID == TypeId.varint32.rawValue) + #expect(byID[5]?.fieldName == "fixedValue") + #expect(byID[5]?.fieldType.typeID == TypeId.int32.rawValue) } @Test func macroFieldIDsDriveCompatibleStructDecodeAcrossRenames() throws { - let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) - writer.register(FieldIdSource.self, id: 9101) + let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + writer.register(FieldIdSource.self, id: 9101) - let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) - reader.register(FieldIdTarget.self, id: 9101) + let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + reader.register(FieldIdTarget.self, id: 9101) - let source = FieldIdSource(value: 42, label: "alpha") - let bytes = try writer.serialize(source) - let decoded: FieldIdTarget = try reader.deserialize(bytes) + let source = FieldIdSource(value: 42, label: "alpha") + let bytes = try writer.serialize(source) + let decoded: FieldIdTarget = try reader.deserialize(bytes) - #expect(decoded.renamedValue == source.value) - #expect(decoded.renamedLabel == source.label) + #expect(decoded.renamedValue == source.value) + #expect(decoded.renamedLabel == source.label) - let roundTrip = try reader.serialize(decoded) - let back: FieldIdSource = try writer.deserialize(roundTrip) - #expect(back == source) + let roundTrip = try reader.serialize(decoded) + let back: FieldIdSource = try 
writer.deserialize(roundTrip) + #expect(back == source) } @Test func macroFieldIDsDriveTaggedUnionDecodeAcrossRenames() throws { - let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) - writer.register(FieldIdUnionSource.self, id: 9102) + let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + writer.register(FieldIdUnionSource.self, id: 9102) - let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) - reader.register(FieldIdUnionTarget.self, id: 9102) + let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + reader.register(FieldIdUnionTarget.self, id: 9102) - let source = FieldIdUnionSource.number(123) - let bytes = try writer.serialize(source) - let decoded: FieldIdUnionTarget = try reader.deserialize(bytes) + let source = FieldIdUnionSource.number(123) + let bytes = try writer.serialize(source) + let decoded: FieldIdUnionTarget = try reader.deserialize(bytes) - switch decoded { - case .renamedNumber(let value): - #expect(value == 123) - default: - #expect(Bool(false)) - } + switch decoded { + case .renamedNumber(let value): + #expect(value == 123) + default: + #expect(Bool(false)) + } } @Test func compatibleNestedStructArrayRoundTrip() throws { - let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) - writer.register(CompatibleNestedItem.self, id: 9103) - writer.register(CompatibleNestedArrayHolder.self, id: 9104) - - let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) - reader.register(CompatibleNestedItem.self, id: 9103) - reader.register(CompatibleNestedArrayHolder.self, id: 9104) - - let value = CompatibleNestedArrayHolder( - items: [ - CompatibleNestedItem(id: 1, name: "alpha"), - CompatibleNestedItem(id: 2, name: "beta") - ] - ) - let bytes = try writer.serialize(value) - let decoded: CompatibleNestedArrayHolder = try reader.deserialize(bytes) - #expect(decoded == value) + let writer = 
Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + writer.register(CompatibleNestedItem.self, id: 9103) + writer.register(CompatibleNestedArrayHolder.self, id: 9104) + + let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + reader.register(CompatibleNestedItem.self, id: 9103) + reader.register(CompatibleNestedArrayHolder.self, id: 9104) + + let value = CompatibleNestedArrayHolder( + items: [ + CompatibleNestedItem(id: 1, name: "alpha"), + CompatibleNestedItem(id: 2, name: "beta") + ] + ) + let bytes = try writer.serialize(value) + let decoded: CompatibleNestedArrayHolder = try reader.deserialize(bytes) + #expect(decoded == value) } @Test func compatibleNestedStructOptionalArrayRoundTrip() throws { - let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) - writer.register(CompatibleNestedItem.self, id: 9103) - writer.register(CompatibleNestedOptionalArrayHolder.self, id: 9105) - - let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) - reader.register(CompatibleNestedItem.self, id: 9103) - reader.register(CompatibleNestedOptionalArrayHolder.self, id: 9105) - - let value = CompatibleNestedOptionalArrayHolder( - items: [ - CompatibleNestedItem(id: 1, name: "alpha"), - nil, - CompatibleNestedItem(id: 2, name: "beta") - ] - ) - let bytes = try writer.serialize(value) - let decoded: CompatibleNestedOptionalArrayHolder = try reader.deserialize(bytes) - #expect(decoded == value) + let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + writer.register(CompatibleNestedItem.self, id: 9103) + writer.register(CompatibleNestedOptionalArrayHolder.self, id: 9105) + + let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + reader.register(CompatibleNestedItem.self, id: 9103) + reader.register(CompatibleNestedOptionalArrayHolder.self, id: 9105) + + let value = CompatibleNestedOptionalArrayHolder( + items: [ + 
CompatibleNestedItem(id: 1, name: "alpha"), + nil, + CompatibleNestedItem(id: 2, name: "beta") + ] + ) + let bytes = try writer.serialize(value) + let decoded: CompatibleNestedOptionalArrayHolder = try reader.deserialize(bytes) + #expect(decoded == value) } @Test func compatibleNestedStructMapRoundTrip() throws { - let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) - writer.register(CompatibleNestedItem.self, id: 9103) - writer.register(CompatibleNestedMapHolder.self, id: 9106) - - let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) - reader.register(CompatibleNestedItem.self, id: 9103) - reader.register(CompatibleNestedMapHolder.self, id: 9106) - - let value = CompatibleNestedMapHolder( - items: [ - 1: CompatibleNestedItem(id: 10, name: "first"), - 2: CompatibleNestedItem(id: 20, name: "second") - ] - ) - let bytes = try writer.serialize(value) - let decoded: CompatibleNestedMapHolder = try reader.deserialize(bytes) - #expect(decoded == value) + let writer = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + writer.register(CompatibleNestedItem.self, id: 9103) + writer.register(CompatibleNestedMapHolder.self, id: 9106) + + let reader = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + reader.register(CompatibleNestedItem.self, id: 9103) + reader.register(CompatibleNestedMapHolder.self, id: 9106) + + let value = CompatibleNestedMapHolder( + items: [ + 1: CompatibleNestedItem(id: 10, name: "first"), + 2: CompatibleNestedItem(id: 20, name: "second") + ] + ) + let bytes = try writer.serialize(value) + let decoded: CompatibleNestedMapHolder = try reader.deserialize(bytes) + #expect(decoded == value) } @Test func pvlVarInt64AndVarUInt64Extremes() throws { - let uintValues: [UInt64] = [ - 0, - 1, - 127, - 128, - 16_383, - 16_384, - 2_097_151, - 2_097_152, - 268_435_455, - 268_435_456, - 34_359_738_367, - 34_359_738_368, - 4_398_046_511_103, - 4_398_046_511_104, - 
562_949_953_421_311, - 562_949_953_421_312, - 72_057_594_037_927_935, - 72_057_594_037_927_936, - UInt64(Int64.max), - UInt64.max - ] - let intValues: [Int64] = [ - Int64.min, - Int64.min + 1, - -1_000_000_000_000, - -1_000_000, - -1_000, - -128, - -1, - 0, - 1, - 127, - 1_000, - 1_000_000, - 1_000_000_000_000, - Int64.max - 1, - Int64.max - ] - - let writeBuffer = ByteBuffer() - for value in uintValues { - writeBuffer.writeVarUInt64(value) - } - for value in intValues { - writeBuffer.writeVarInt64(value) - } - let minBuffer = ByteBuffer() - minBuffer.writeVarInt64(Int64.min) - #expect(minBuffer.count == 9) - #expect(minBuffer.storage.prefix(minBuffer.count).allSatisfy { $0 == 0xFF }) - - let encoded = Array(writeBuffer.storage.prefix(writeBuffer.count)) - - let readBuffer = ByteBuffer(bytes: encoded) - for value in uintValues { - #expect(try readBuffer.readVarUInt64() == value) - } - for value in intValues { - #expect(try readBuffer.readVarInt64() == value) - } - #expect(readBuffer.remaining == 0) + let uintValues: [UInt64] = [ + 0, + 1, + 127, + 128, + 16_383, + 16_384, + 2_097_151, + 2_097_152, + 268_435_455, + 268_435_456, + 34_359_738_367, + 34_359_738_368, + 4_398_046_511_103, + 4_398_046_511_104, + 562_949_953_421_311, + 562_949_953_421_312, + 72_057_594_037_927_935, + 72_057_594_037_927_936, + UInt64(Int64.max), + UInt64.max + ] + let intValues: [Int64] = [ + Int64.min, + Int64.min + 1, + -1_000_000_000_000, + -1_000_000, + -1_000, + -128, + -1, + 0, + 1, + 127, + 1_000, + 1_000_000, + 1_000_000_000_000, + Int64.max - 1, + Int64.max + ] + + let writeBuffer = ByteBuffer() + for value in uintValues { + writeBuffer.writeVarUInt64(value) + } + for value in intValues { + writeBuffer.writeVarInt64(value) + } + let minBuffer = ByteBuffer() + minBuffer.writeVarInt64(Int64.min) + #expect(minBuffer.count == 9) + #expect(minBuffer.storage.prefix(minBuffer.count).allSatisfy { $0 == 0xFF }) + + let encoded = Array(writeBuffer.storage.prefix(writeBuffer.count)) + + let 
readBuffer = ByteBuffer(bytes: encoded) + for value in uintValues { + #expect(try readBuffer.readVarUInt64() == value) + } + for value in intValues { + #expect(try readBuffer.readVarInt64() == value) + } + #expect(readBuffer.remaining == 0) } @Test func metaStringEncodingRoundTrip() throws { - let encoder = MetaStringEncoder.fieldName - let decoder = MetaStringDecoder.fieldName + let encoder = MetaStringEncoder.fieldName + let decoder = MetaStringDecoder.fieldName - let lower = try encoder.encode("alpha_beta", encoding: .lowerSpecial) - #expect(lower.encoding == .lowerSpecial) - #expect(try decoder.decode(bytes: lower.bytes, encoding: lower.encoding).value == "alpha_beta") + let lower = try encoder.encode("alpha_beta", encoding: .lowerSpecial) + #expect(lower.encoding == .lowerSpecial) + #expect(try decoder.decode(bytes: lower.bytes, encoding: lower.encoding).value == "alpha_beta") - let firstLower = try encoder.encode("User_name", encoding: .firstToLowerSpecial) - #expect(firstLower.encoding == .firstToLowerSpecial) - #expect(try decoder.decode(bytes: firstLower.bytes, encoding: firstLower.encoding).value == "User_name") + let firstLower = try encoder.encode("User_name", encoding: .firstToLowerSpecial) + #expect(firstLower.encoding == .firstToLowerSpecial) + #expect( + try decoder.decode(bytes: firstLower.bytes, encoding: firstLower.encoding).value == "User_name") - let allLower = try encoder.encode("MyHTTPType", encoding: .allToLowerSpecial) - #expect(allLower.encoding == .allToLowerSpecial) - #expect(try decoder.decode(bytes: allLower.bytes, encoding: allLower.encoding).value == "MyHTTPType") + let allLower = try encoder.encode("MyHTTPType", encoding: .allToLowerSpecial) + #expect(allLower.encoding == .allToLowerSpecial) + #expect( + try decoder.decode(bytes: allLower.bytes, encoding: allLower.encoding).value == "MyHTTPType") - let lowerUpperDigit = try encoder.encode("userId2", encoding: .lowerUpperDigitSpecial) - #expect(lowerUpperDigit.encoding == 
.lowerUpperDigitSpecial) - #expect(try decoder.decode(bytes: lowerUpperDigit.bytes, encoding: lowerUpperDigit.encoding).value == "userId2") + let lowerUpperDigit = try encoder.encode("userId2", encoding: .lowerUpperDigitSpecial) + #expect(lowerUpperDigit.encoding == .lowerUpperDigitSpecial) + #expect( + try decoder.decode(bytes: lowerUpperDigit.bytes, encoding: lowerUpperDigit.encoding).value + == "userId2") - let autoUtf8 = try encoder.encode("naïve_meta") - #expect(autoUtf8.encoding == .utf8) - #expect(try decoder.decode(bytes: autoUtf8.bytes, encoding: autoUtf8.encoding).value == "naïve_meta") + let autoUtf8 = try encoder.encode("naïve_meta") + #expect(autoUtf8.encoding == .utf8) + #expect( + try decoder.decode(bytes: autoUtf8.bytes, encoding: autoUtf8.encoding).value == "naïve_meta") } @Test func typeMetaRoundTripByName() throws { - let namespace = try MetaStringEncoder.namespace.encode("com.example") - let typeName = try MetaStringEncoder.typeName.encode("UserProfile") - - let fields: [TypeMeta.FieldInfo] = [ - .init( - fieldID: nil, - fieldName: "createdAt", - fieldType: .init(typeID: TypeId.varint64.rawValue, nullable: false) - ), - .init( - fieldID: nil, - fieldName: "tags", - fieldType: .init( - typeID: TypeId.list.rawValue, - nullable: false, - generics: [.init(typeID: TypeId.string.rawValue, nullable: true)] - ) - ), - .init( - fieldID: nil, - fieldName: "attributes", - fieldType: .init( - typeID: TypeId.map.rawValue, - nullable: true, - generics: [ - .init(typeID: TypeId.string.rawValue, nullable: false), - .init(typeID: TypeId.varint32.rawValue, nullable: true) - ] - ) - ), - .init( - fieldID: 7, - fieldName: "ignored_for_tag_mode", - fieldType: .init(typeID: TypeId.varint32.rawValue, nullable: false) - ) - ] - - let meta = try TypeMeta( - typeID: nil, - userTypeID: nil, - namespace: namespace, - typeName: typeName, - registerByName: true, - fields: fields + let namespace = try MetaStringEncoder.namespace.encode("com.example") + let typeName = try 
MetaStringEncoder.typeName.encode("UserProfile") + + let fields: [TypeMeta.FieldInfo] = [ + .init( + fieldID: nil, + fieldName: "createdAt", + fieldType: .init(typeID: TypeId.varint64.rawValue, nullable: false) + ), + .init( + fieldID: nil, + fieldName: "tags", + fieldType: .init( + typeID: TypeId.list.rawValue, + nullable: false, + generics: [.init(typeID: TypeId.string.rawValue, nullable: true)] + ) + ), + .init( + fieldID: nil, + fieldName: "attributes", + fieldType: .init( + typeID: TypeId.map.rawValue, + nullable: true, + generics: [ + .init(typeID: TypeId.string.rawValue, nullable: false), + .init(typeID: TypeId.varint32.rawValue, nullable: true) + ] + ) + ), + .init( + fieldID: 7, + fieldName: "ignored_for_tag_mode", + fieldType: .init(typeID: TypeId.varint32.rawValue, nullable: false) ) - - let encoded = try meta.encode() - let decoded = try TypeMeta.decode(encoded) - - #expect(decoded.registerByName == true) - #expect(decoded.namespace.value == "com.example") - #expect(decoded.typeName.value == "UserProfile") - #expect(decoded.typeID == nil) - #expect(decoded.userTypeID == nil) - #expect(decoded.fields.count == 4) - #expect(decoded.fields[0].fieldName == "created_at") - #expect(decoded.fields[3].fieldID == 7) + ] + + let meta = try TypeMeta( + typeID: TypeId.namedStruct.rawValue, + userTypeID: nil, + namespace: namespace, + typeName: typeName, + registerByName: true, + fields: fields + ) + + let encoded = try meta.encode() + let decoded = try TypeMeta.decode(encoded) + + #expect(decoded.registerByName == true) + #expect(decoded.namespace.value == "com.example") + #expect(decoded.typeName.value == "UserProfile") + #expect(decoded.typeID == TypeId.namedStruct.rawValue) + #expect(decoded.userTypeID == nil) + #expect(decoded.fields.count == 4) + #expect(decoded.fields[0].fieldName == "created_at") + #expect(decoded.fields[3].fieldID == 7) } @Test func typeMetaRoundTripByID() throws { - let emptyNamespace = MetaString.empty(specialChar1: ".", specialChar2: "_") 
- let emptyTypeName = MetaString.empty(specialChar1: "$", specialChar2: "_") - - let meta = try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: 101, - namespace: emptyNamespace, - typeName: emptyTypeName, - registerByName: false, - fields: [] - ) - - let encoded = try meta.encode() - let decoded = try TypeMeta.decode(encoded) - - #expect(decoded.registerByName == false) - #expect(decoded.typeID == TypeId.structType.rawValue) - #expect(decoded.userTypeID == 101) - #expect(decoded.fields.isEmpty) + let emptyNamespace = MetaString.empty(specialChar1: ".", specialChar2: "_") + let emptyTypeName = MetaString.empty(specialChar1: "$", specialChar2: "_") + + let meta = try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 101, + namespace: emptyNamespace, + typeName: emptyTypeName, + registerByName: false, + fields: [] + ) + + let encoded = try meta.encode() + let decoded = try TypeMeta.decode(encoded) + + #expect(decoded.registerByName == false) + #expect(decoded.typeID == TypeId.structType.rawValue) + #expect(decoded.userTypeID == 101) + #expect(decoded.fields.isEmpty) } From ad3aa62737979ed06094aad5bfc7378161477492 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=85=95=E7=99=BD?= Date: Wed, 6 May 2026 01:28:50 +0800 Subject: [PATCH 04/10] fix: cap metadata caches and validate root headers --- README.md | 4 +- cpp/fory/serialization/BUILD | 1 + cpp/fory/serialization/serialization_test.cc | 26 ++++++++ cpp/fory/serialization/type_resolver.cc | 3 - csharp/src/Fory/ReadContext.cs | 18 ++++-- .../tests/Fory.Tests/RuntimeEdgeCaseTests.cs | 23 ++++++++ .../packages/fory/lib/src/meta/type_meta.dart | 12 +++- .../fory/test/xlang_protocol_test.dart | 49 +++++++++++++++ .../apache/fory/context/MetaStringReader.java | 4 ++ .../apache/fory/resolver/TypeResolver.java | 3 +- python/pyfory/_fory.py | 8 +-- python/pyfory/serialization.pyx | 12 ++-- python/pyfory/tests/test_root_header.py | 59 +++++++++++++++++++ rust/fory-core/src/resolver/meta_resolver.rs | 59 
++++++++++++++++++- swift/Sources/Fory/TypeResolver.swift | 10 +++- swift/Tests/ForyTests/ForySwiftTests.swift | 25 ++++++++ 16 files changed, 288 insertions(+), 28 deletions(-) create mode 100644 python/pyfory/tests/test_root_header.py diff --git a/README.md b/README.md index 2890af49b8..a75ddad9a1 100644 --- a/README.md +++ b/README.md @@ -744,7 +744,7 @@ Apache Fory™ supports class schema forward/backward compatibility across **Jav ### Binary Compatibility -**Current Status**: Binary compatibility is **not guaranteed** between Fory major releases as the protocol continues to evolve. However, compatibility **is guaranteed** between minor versions (e.g., 0.13.x). +**Current Status**: Binary compatibility is **not guaranteed** between Fory major releases as the protocol continues to evolve. Compatibility **is guaranteed** between minor versions (for example, 0.13.x). **Recommendations**: @@ -752,7 +752,7 @@ Apache Fory™ supports class schema forward/backward compatibility across **Jav - Plan migration strategies when upgrading major versions - See [upgrade guide](docs/guide/java) for details -**Future**: Binary compatibility will be guaranteed starting from Fory 1.0 release. +Major-version compatibility is the boundary for stable serialized data. 
## Security diff --git a/cpp/fory/serialization/BUILD b/cpp/fory/serialization/BUILD index 87b3ede02c..5da749d9e2 100644 --- a/cpp/fory/serialization/BUILD +++ b/cpp/fory/serialization/BUILD @@ -51,6 +51,7 @@ cc_test( srcs = ["serialization_test.cc"], deps = [ ":fory_serialization", + "//cpp/fory/thirdparty:libmmh3", "@googletest//:gtest", "@googletest//:gtest_main", ], diff --git a/cpp/fory/serialization/serialization_test.cc b/cpp/fory/serialization/serialization_test.cc index 3bfb4ef39e..4fb31c5dda 100644 --- a/cpp/fory/serialization/serialization_test.cc +++ b/cpp/fory/serialization/serialization_test.cc @@ -20,6 +20,7 @@ #include "fory/serialization/fory.h" #include "fory/serialization/ref_resolver.h" #include "fory/serialization/skip.h" +#include "fory/thirdparty/MurmurHash3.h" #include "gtest/gtest.h" #include #include @@ -80,6 +81,23 @@ namespace fory { namespace serialization { namespace test { +namespace { + +uint64_t compute_type_meta_hash_bits_for_test(const uint8_t *meta_bytes, + size_t meta_size) { + constexpr uint32_t kHashShift = 12; + constexpr uint64_t kHashBitsMask = UINT64_MAX << kHashShift; + int64_t hash_out[2] = {0, 0}; + MurmurHash3_x64_128(meta_bytes, static_cast(meta_size), 47, hash_out); + uint64_t shifted = static_cast(hash_out[0]) << kHashShift; + if (static_cast(shifted) < 0) { + shifted = ~shifted + 1; + } + return shifted & kHashBitsMask; +} + +} // namespace + // ============================================================================ // Test Helpers // ============================================================================ @@ -861,11 +879,19 @@ TEST(SerializationTest, TypeMetaRejectsNonStructReservedKindBits) { std::vector bytes = bytes_result.value(); bytes[sizeof(uint64_t)] |= 0x10; + uint64_t header = 0; + std::memcpy(&header, bytes.data(), sizeof(header)); + ASSERT_NE(header & 0xff, 0xff); + header &= ~(UINT64_MAX << 12); + header |= compute_type_meta_hash_bits_for_test( + bytes.data() + sizeof(uint64_t), bytes.size() 
- sizeof(uint64_t)); + std::memcpy(bytes.data(), &header, sizeof(header)); Buffer buffer(bytes); auto parsed = TypeMeta::from_bytes(buffer, nullptr); ASSERT_FALSE(parsed.ok()); EXPECT_EQ(parsed.error().code(), ErrorCode::InvalidData); + EXPECT_NE(parsed.error().to_string().find("kind header"), std::string::npos); } TEST(SerializationTest, TypeMetaRejectsReservedHeaderBits) { diff --git a/cpp/fory/serialization/type_resolver.cc b/cpp/fory/serialization/type_resolver.cc index 460218399d..b5ffb5917e 100644 --- a/cpp/fory/serialization/type_resolver.cc +++ b/cpp/fory/serialization/type_resolver.cc @@ -195,7 +195,6 @@ Result, Error> FieldInfo::to_bytes() const { buffer.write_bytes(encoded_name.data(), encoded_name.size()); } - // CRITICAL FIX: Use writer_index() not size() to get actual bytes written! return std::vector(buffer.data(), buffer.data() + buffer.writer_index()); } @@ -690,7 +689,6 @@ TypeMeta::from_bytes(Buffer &buffer, const TypeMeta *local_type_info) { assign_field_ids(local_type_info, field_infos); } - // CRITICAL FIX: Ensure we consume exactly meta_size bytes size_t current_pos = buffer.reader_index(); size_t expected_end_pos = start_pos + header_size + meta_size; if (FORY_PREDICT_FALSE(current_pos > expected_end_pos)) { @@ -802,7 +800,6 @@ TypeMeta::from_bytes_with_header(Buffer &buffer, int64_t header) { // NOTE: Do NOT sort remote fields! They are already in the sender's sorted // order, which matches the data order. 
- // CRITICAL FIX: Ensure we consume exactly meta_size bytes size_t current_pos = buffer.reader_index(); size_t expected_end_pos = static_cast(start_pos) + meta_size; if (FORY_PREDICT_FALSE(current_pos > expected_end_pos)) { diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index beddd01f89..2825653c76 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -161,14 +161,24 @@ internal bool TryGetCachedReadTypeMeta(ulong header, out TypeMeta typeMeta, out internal void CacheReadTypeMeta(ulong header, TypeMeta typeMeta, int skipBytesAfterHeader) { + if (_cachedTypeMetasByHeader.TryGetValue(header, out CachedTypeMetaEntry existing)) + { + _lastMetaHeader = header; + _lastTypeMeta = existing; + _hasLastMetaHeader = true; + return; + } + + if (_cachedTypeMetasByHeader.Count >= MaxParsedTypeMetaEntries) + { + return; + } + CachedTypeMetaEntry cached = new(typeMeta, skipBytesAfterHeader); _lastMetaHeader = header; _lastTypeMeta = cached; _hasLastMetaHeader = true; - if (_cachedTypeMetasByHeader.Count < MaxParsedTypeMetaEntries) - { - _cachedTypeMetasByHeader.TryAdd(header, cached); - } + _cachedTypeMetasByHeader.TryAdd(header, cached); } internal MetaString? 
GetReadMetaString(int index) diff --git a/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs b/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs index 171b4bf5d7..c8452b63a4 100644 --- a/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs +++ b/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs @@ -435,6 +435,29 @@ public void DeserializeRejectsUnsupportedRootHeaderBits() } } + [Fact] + public void TypeMetaHeaderCacheStopsPublishingAtCapacity() + { + ReadContext context = new(new ByteReader(Array.Empty()), new TypeResolver(), trackRef: false); + TypeMeta typeMeta = new( + (uint)TypeId.Struct, + 901, + MetaString.Empty('.', '_'), + MetaString.Empty('$', '_'), + registerByName: false, + []); + + for (ulong header = 1; header <= 8192; header++) + { + context.CacheReadTypeMeta(header, typeMeta, skipBytesAfterHeader: 0); + } + + Assert.True(context.TryGetCachedReadTypeMeta(8192, out _, out _)); + context.CacheReadTypeMeta(8193, typeMeta, skipBytesAfterHeader: 0); + + Assert.False(context.TryGetCachedReadTypeMeta(8193, out _, out _)); + } + [Fact] public void DynamicAnyRejectsUnknownUserTypeId() { diff --git a/dart/packages/fory/lib/src/meta/type_meta.dart b/dart/packages/fory/lib/src/meta/type_meta.dart index 0602f040be..8b32f199fa 100644 --- a/dart/packages/fory/lib/src/meta/type_meta.dart +++ b/dart/packages/fory/lib/src/meta/type_meta.dart @@ -120,9 +120,17 @@ final class ParsedTypeMetaCache { @pragma('vm:prefer-inline') void remember(TypeHeader header, TypeInfo resolved) { - if (!_entries.containsKey(header.value) && _entries.length >= maxEntries) { - _entries.remove(_entries.keys.first); + final cached = _entries[header.value]; + if (cached != null) { + _entries[header.value] = resolved; + _lastHeader = header.value; + _lastResolved = resolved; + return; } + if (_entries.length >= maxEntries) { + return; + } + _entries[header.value] = resolved; _lastHeader = header.value; _lastResolved = resolved; diff --git a/dart/packages/fory/test/xlang_protocol_test.dart 
b/dart/packages/fory/test/xlang_protocol_test.dart index 714edd2675..0f182af5e2 100644 --- a/dart/packages/fory/test/xlang_protocol_test.dart +++ b/dart/packages/fory/test/xlang_protocol_test.dart @@ -20,10 +20,28 @@ import 'dart:typed_data'; import 'package:fory/fory.dart'; +import 'package:fory/src/context/read_context.dart'; +import 'package:fory/src/context/write_context.dart'; import 'package:fory/src/meta/type_meta.dart'; +import 'package:fory/src/resolver/type_resolver.dart'; +import 'package:fory/src/serializer/serializer.dart'; +import 'package:fory/src/types/int64.dart'; import 'package:fory/src/util/hash_util.dart'; import 'package:test/test.dart'; +final class _CacheTestSerializer extends Serializer { + const _CacheTestSerializer(); + + @override + bool get supportsRef => false; + + @override + Object? read(ReadContext context) => null; + + @override + void write(WriteContext context, Object? value) {} +} + void main() { group('xlang protocol regressions', () { test('deserializes NONE wire values as null', () { @@ -102,6 +120,37 @@ void main() { ); }); + test('parsed TypeDef cache stops publishing at capacity', () { + const resolved = TypeInfo( + type: Object, + kind: RegistrationKind.builtin, + typeId: TypeIds.struct, + supportsRef: false, + serializer: _CacheTestSerializer(), + structSerializer: null, + userTypeId: null, + namespace: null, + typeName: null, + encodedNamespace: null, + encodedTypeName: null, + typeDef: null, + remoteTypeDef: null, + ); + final cache = ParsedTypeMetaCache(); + for (var i = 0; i < ParsedTypeMetaCache.maxEntries; i++) { + cache.remember(TypeHeader(Int64(i)), resolved); + } + + expect( + cache.lookup(TypeHeader(Int64(ParsedTypeMetaCache.maxEntries - 1))), + same(resolved), + ); + final uncached = TypeHeader(Int64(ParsedTypeMetaCache.maxEntries)); + cache.remember(uncached, resolved); + + expect(cache.lookup(uncached), isNull); + }); + test('validates parsed TypeDef body hash before caching', () { final body = 
Uint8List.fromList([0x80]); final header = TypeHeader(typeDefHeader(body)); diff --git a/java/fory-core/src/main/java/org/apache/fory/context/MetaStringReader.java b/java/fory-core/src/main/java/org/apache/fory/context/MetaStringReader.java index ccd85f8001..7b492ddcbb 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/MetaStringReader.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/MetaStringReader.java @@ -132,6 +132,8 @@ private EncodedMetaString readBigMetaString( MemoryBuffer buffer, EncodedMetaString cache, int len) { long hashCode = buffer.readInt64(); if (cache.hash == hashCode && cache.bytes.length == len) { + // Big meta-string hashes are the wire identity on this cache hit. The body hash is computed + // and checked before a new entry is published; later hits intentionally skip the body. buffer.checkReadableBytes(len); buffer._increaseReaderIndexUnsafe(len); return cache; @@ -143,6 +145,8 @@ private EncodedMetaString readBigMetaString(MemoryBuffer buffer, int len, long h buffer.checkReadableBytes(len); EncodedMetaString encodedMetaString = hash2MetaStringMap.get(hashCode); if (encodedMetaString != null && encodedMetaString.bytes.length == len) { + // Preserve the header-keyed fast path: entries reach this map only after their bytes matched + // the wire hash, so repeat hits advance over the redundant body without rehashing. 
buffer._increaseReaderIndexUnsafe(len); return encodedMetaString; } diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index cf5569e033..9624020c89 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -759,7 +759,8 @@ protected final TypeInfo readTypeInfoFromBytes( assert packageNameBytesCache != null; simpleClassNameBytes = metaStringReader.readMetaString(buffer, typeNameBytesCache); - // MetaStringReader returns the provided cache object only after validating the encoded body. + // MetaStringReader returns the provided cache object only when the wire identity matches. For + // big meta strings, body-hash validation happens before the entry is first cached. if (typeNameBytesCache == simpleClassNameBytes && packageNameBytesCache == namespaceBytes) { return typeInfoCache; } diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 9583b729e4..df2cd21aa0 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -558,10 +558,10 @@ def _deserialize( if bool(bitmap & 1) != self.xlang: raise ValueError("Header bitmap mismatch at xlang bit") peer_out_of_band_enabled = bool(bitmap & 2) - if peer_out_of_band_enabled: - assert buffers is not None, "buffers shouldn't be null when the serialized stream is produced with buffer_callback not null." - else: - assert buffers is None, "buffers should be null when the serialized stream is produced with buffer_callback null." 
+ if peer_out_of_band_enabled and buffers is None: + raise ValueError("Out-of-band buffers are required by the root header") + if not peer_out_of_band_enabled and buffers is not None: + raise ValueError("Out-of-band buffers were provided for an in-band root payload") read_context.prepare( buffer, buffers=buffers, diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 954afb6f22..6c685ba185 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -1058,14 +1058,10 @@ cdef class Fory: if ((bitmap & 1) != 0) != self.xlang: raise ValueError("Header bitmap mismatch at xlang bit") peer_out_of_band_enabled = (bitmap & 2) != 0 - if peer_out_of_band_enabled: - assert buffers is not None, ( - "buffers shouldn't be null when the serialized stream is produced with buffer_callback not null." - ) - else: - assert buffers is None, ( - "buffers should be null when the serialized stream is produced with buffer_callback null." - ) + if peer_out_of_band_enabled and buffers is None: + raise ValueError("Out-of-band buffers are required by the root header") + if not peer_out_of_band_enabled and buffers is not None: + raise ValueError("Out-of-band buffers were provided for an in-band root payload") # Keep the root context setup inline. Top-level deserialize is a hot path, # so it should not pay an extra method call just to bind the active buffer. read_context.buffer = read_buffer diff --git a/python/pyfory/tests/test_root_header.py b/python/pyfory/tests/test_root_header.py new file mode 100644 index 0000000000..b51597579b --- /dev/null +++ b/python/pyfory/tests/test_root_header.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from pyfory import Fory + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_root_header_rejects_reserved_bits(xlang): + fory = Fory(xlang=xlang) + data = bytearray(fory.serialize(1)) + data[0] |= 0x04 + + with pytest.raises(ValueError, match="Unsupported root header bitmap"): + fory.deserialize(bytes(data)) + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_root_header_rejects_xlang_mismatch(xlang): + fory = Fory(xlang=xlang) + data = bytearray(fory.serialize(1)) + data[0] ^= 0x01 + + with pytest.raises(ValueError, match="xlang bit"): + fory.deserialize(bytes(data)) + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_root_header_oob_flag_requires_buffers(xlang): + fory = Fory(xlang=xlang) + data = bytearray(fory.serialize(1)) + data[0] |= 0x02 + + with pytest.raises(ValueError, match="Out-of-band buffers are required"): + fory.deserialize(bytes(data)) + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_root_header_rejects_buffers_without_oob_flag(xlang): + fory = Fory(xlang=xlang) + data = fory.serialize(1) + + with pytest.raises(ValueError, match="Out-of-band buffers were provided"): + fory.deserialize(data, buffers=[]) diff --git a/rust/fory-core/src/resolver/meta_resolver.rs b/rust/fory-core/src/resolver/meta_resolver.rs index c7d649aa7a..260a651615 100644 --- a/rust/fory-core/src/resolver/meta_resolver.rs +++ 
b/rust/fory-core/src/resolver/meta_resolver.rs @@ -242,9 +242,9 @@ impl MetaReaderResolver { // avoid malicious type defs to OOM parsed_type_infos self.parsed_type_infos .insert(meta_header, type_info.clone()); + self.last_meta_header = meta_header; + self.last_type_info = Some(type_info.clone()); } - self.last_meta_header = meta_header; - self.last_type_info = Some(type_info.clone()); self.reading_type_infos.push(type_info.clone()); Ok(type_info) } @@ -256,3 +256,58 @@ impl MetaReaderResolver { self.reading_type_infos.clear(); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::meta::MetaString; + use crate::TypeId; + + #[test] + fn parsed_type_info_cache_does_not_publish_after_limit() { + let meta = TypeMeta::new( + TypeId::STRUCT as u32, + 9001, + MetaString::get_empty().clone(), + MetaString::get_empty().clone(), + false, + vec![], + ) + .unwrap(); + let type_def = meta.get_bytes().to_vec(); + let mut header_reader = Reader::new(&type_def); + let meta_header = header_reader.read_i64().unwrap(); + + let mut resolver = MetaReaderResolver::default(); + let cached_type_info = Rc::new(TypeInfo::from_remote_meta( + Rc::new(TypeMeta::empty().unwrap()), + None, + None, + None, + )); + let mut header = 0; + while resolver.parsed_type_infos.len() < MAX_PARSED_NUM_TYPE_DEFS { + if header != meta_header { + resolver + .parsed_type_infos + .insert(header, cached_type_info.clone()); + } + header += 1; + } + + let mut bytes = vec![]; + let mut writer = Writer::from_buffer(&mut bytes); + writer.write_var_u32(0); + writer.write_bytes(&type_def); + + let mut reader = Reader::new(&bytes); + let current = resolver + .read_type_meta(&mut reader, &TypeResolver::default()) + .unwrap(); + + assert_eq!(current.get_user_type_id(), 9001); + assert_eq!(resolver.parsed_type_infos.len(), MAX_PARSED_NUM_TYPE_DEFS); + assert!(!resolver.parsed_type_infos.contains_key(&meta_header)); + assert!(resolver.last_type_info.is_none()); + } +} diff --git 
a/swift/Sources/Fory/TypeResolver.swift b/swift/Sources/Fory/TypeResolver.swift index f1fd683174..a1ac088171 100644 --- a/swift/Sources/Fory/TypeResolver.swift +++ b/swift/Sources/Fory/TypeResolver.swift @@ -398,6 +398,8 @@ private struct TypeNameKey: Hashable { } final class TypeResolver { + private static let maxCachedTypeDefHeaders = 8192 + private let trackRef: Bool private var registrationFinished = false @@ -551,7 +553,9 @@ final class TypeResolver { } let localTypeInfo = try requireTypeInfo(for: typeMeta) if header == localTypeInfo.typeDefHeader { - typeInfoByHeader.set(localTypeInfo, for: header) + if typeInfoByHeader.count < Self.maxCachedTypeDefHeaders { + typeInfoByHeader.set(localTypeInfo, for: header) + } return localTypeInfo } let canonicalTypeMeta: TypeMeta @@ -562,7 +566,9 @@ final class TypeResolver { canonicalTypeMeta = typeMeta } let typeInfo = TypeInfo(dynamic: localTypeInfo, compatibleTypeMeta: canonicalTypeMeta) - typeInfoByHeader.set(typeInfo, for: header) + if typeInfoByHeader.count < Self.maxCachedTypeDefHeaders { + typeInfoByHeader.set(typeInfo, for: header) + } return typeInfo } diff --git a/swift/Tests/ForyTests/ForySwiftTests.swift b/swift/Tests/ForyTests/ForySwiftTests.swift index 48ff046f12..4b883a5dad 100644 --- a/swift/Tests/ForyTests/ForySwiftTests.swift +++ b/swift/Tests/ForyTests/ForySwiftTests.swift @@ -490,6 +490,31 @@ func primitiveArrayTypeIDs() throws { #expect(UInt32(uint8Bytes[2]) == TypeId.list.rawValue) } +@Test +func typeDefHeaderCacheStopsPublishingAtCapacity() throws { + let resolver = TypeResolver() + resolver.register(Person.self, id: 901) + let typeInfo = try resolver.requireTypeInfo(for: Person.self) + let typeMeta = try #require(typeInfo.typeMeta) + let localHeader = try #require(typeInfo.typeDefHeader) + #expect(resolver.getTypeInfo(forHeader: localHeader) != nil) + + var header = UInt64(0x0100_0000_0000_0000) + var inserted = 0 + while inserted < 8191 { + if header != localHeader { + _ = try 
resolver.cacheTypeInfo(typeMeta, forHeader: header) + inserted += 1 + } + header += 1 + } + + let uncachedHeader = header == localHeader ? header + 1 : header + let current = try resolver.cacheTypeInfo(typeMeta, forHeader: uncachedHeader) + #expect(current.compatibleTypeMeta != nil) + #expect(resolver.getTypeInfo(forHeader: uncachedHeader) == nil) +} + @Test func macroStructRoundTrip() throws { let fory = Fory() From f3f994138a1ae1d4105893d6b54eff6d8ae00fd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=85=95=E7=99=BD?= Date: Wed, 6 May 2026 01:51:24 +0800 Subject: [PATCH 05/10] fix(rust): validate utf8 string reads --- rust/fory-core/src/buffer.rs | 19 ++++++++++++- rust/fory-core/src/config.rs | 9 +++++++ rust/fory-core/src/context.rs | 8 ++++++ rust/fory-core/src/fory.rs | 19 +++++++++++++ rust/fory-core/src/serializer/string.rs | 9 ++++++- rust/tests/tests/test_buffer.rs | 36 +++++++++++++++++++++++++ 6 files changed, 98 insertions(+), 2 deletions(-) diff --git a/rust/fory-core/src/buffer.rs b/rust/fory-core/src/buffer.rs index 2d4a6b4666..21cc4377dd 100644 --- a/rust/fory-core/src/buffer.rs +++ b/rust/fory-core/src/buffer.rs @@ -950,6 +950,24 @@ impl<'a> Reader<'a> { #[inline(always)] pub fn read_utf8_string(&mut self, len: usize) -> Result { + self.check_bound(len)?; + // don't use simd for memory copy, copy_non_overlapping is faster + unsafe { + let mut vec = Vec::with_capacity(len); + let src = self.bf.as_ptr().add(self.cursor); + let dst = vec.as_mut_ptr(); + // Use fastest possible copy - copy_nonoverlapping compiles to memcpy + std::ptr::copy_nonoverlapping(src, dst, len); + vec.set_len(len); + let string = String::from_utf8(vec) + .map_err(|_| Error::encoding_error("invalid UTF-8 string"))?; + self.move_next(len); + Ok(string) + } + } + + #[inline(always)] + pub fn read_utf8_string_unchecked(&mut self, len: usize) -> Result { self.check_bound(len)?; // don't use simd for memory copy, copy_non_overlapping is faster unsafe { @@ -960,7 +978,6 @@ impl<'a> 
Reader<'a> { std::ptr::copy_nonoverlapping(src, dst, len); vec.set_len(len); self.move_next(len); - // SAFETY: Assuming valid UTF-8 bytes (responsibility of serialization protocol) Ok(String::from_utf8_unchecked(vec)) } } diff --git a/rust/fory-core/src/config.rs b/rust/fory-core/src/config.rs index 13342700de..c0f4a4328d 100644 --- a/rust/fory-core/src/config.rs +++ b/rust/fory-core/src/config.rs @@ -30,6 +30,8 @@ pub struct Config { pub share_meta: bool, /// Whether meta string compression is enabled. pub compress_string: bool, + /// Whether UTF-8 string payloads are validated before constructing Rust strings. + pub check_string_read: bool, /// Maximum depth for nested dynamic object serialization. pub max_dyn_depth: u32, /// Whether class version checking is enabled. @@ -53,6 +55,7 @@ impl Default for Config { xlang: false, share_meta: false, compress_string: false, + check_string_read: true, max_dyn_depth: 5, check_struct_version: false, track_ref: false, @@ -92,6 +95,12 @@ impl Config { self.compress_string } + /// Check if UTF-8 string payload validation is enabled. + #[inline(always)] + pub fn is_check_string_read(&self) -> bool { + self.check_string_read + } + /// Get maximum dynamic depth. 
#[inline(always)] pub fn max_dyn_depth(&self) -> u32 { diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index acd567029c..889d279659 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -357,6 +357,7 @@ pub struct ReadContext<'a> { xlang: bool, max_dyn_depth: u32, check_struct_version: bool, + check_string_read: bool, max_binary_size: u32, max_collection_size: u32, @@ -386,6 +387,7 @@ impl<'a> ReadContext<'a> { xlang: config.xlang, max_dyn_depth: config.max_dyn_depth, check_struct_version: config.check_struct_version, + check_string_read: config.check_string_read, max_binary_size: config.max_binary_size, max_collection_size: config.max_collection_size, reader: Reader::default(), @@ -426,6 +428,12 @@ impl<'a> ReadContext<'a> { self.check_struct_version } + /// Check if UTF-8 string payload validation is enabled. + #[inline(always)] + pub fn is_check_string_read(&self) -> bool { + self.check_string_read + } + /// Get maximum dynamic depth #[inline(always)] pub fn max_dyn_depth(&self) -> u32 { diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index bfe61dbd9a..2e211f5fa5 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -170,6 +170,20 @@ impl ForyBuilder { self } + /// Enables or disables checked UTF-8 string reads. + /// + /// Checked reads validate UTF-8 payload bytes before constructing Rust `String` values. + /// Disabling this keeps the faster unchecked construction path and must only be used when + /// serialized bytes are trusted to contain valid UTF-8 strings. + /// + /// # Default + /// + /// The default value is `true`. + pub fn check_string_read(mut self, check_string_read: bool) -> Self { + self.config.check_string_read = check_string_read; + self + } + /// Enables or disables class version checking for schema consistency. 
/// /// # Arguments @@ -442,6 +456,11 @@ impl Fory { self.config.compress_string } + /// Returns whether UTF-8 string payload validation is enabled. + pub fn is_check_string_read(&self) -> bool { + self.config.check_string_read + } + /// Returns whether metadata sharing is enabled. /// /// # Returns diff --git a/rust/fory-core/src/serializer/string.rs b/rust/fory-core/src/serializer/string.rs index 093f873c76..abbb82d931 100644 --- a/rust/fory-core/src/serializer/string.rs +++ b/rust/fory-core/src/serializer/string.rs @@ -49,7 +49,14 @@ impl Serializer for String { let s = match encoding { 0 => context.reader.read_latin1_string(len as usize), 1 => context.reader.read_utf16_string(len as usize), - 2 => context.reader.read_utf8_string(len as usize), + 2 => { + let len = len as usize; + if context.is_check_string_read() { + context.reader.read_utf8_string(len) + } else { + context.reader.read_utf8_string_unchecked(len) + } + } _ => { return Err(Error::encoding_error(format!( "wrong encoding value: {}", diff --git a/rust/tests/tests/test_buffer.rs b/rust/tests/tests/test_buffer.rs index ed5c3fc26b..2ef944f1c5 100644 --- a/rust/tests/tests/test_buffer.rs +++ b/rust/tests/tests/test_buffer.rs @@ -16,6 +16,7 @@ // under the License. 
use fory_core::buffer::{Reader, Writer}; +use fory_core::Fory; #[test] fn test_var_i32() { @@ -116,3 +117,38 @@ fn test_fixed_width_read_bounds_checks() { assert!(bad_cursor.read_u16().is_err()); assert!(bad_cursor.read_var_u36_small().is_err()); } + +#[test] +fn test_utf8_string_read_rejects_invalid_payload() { + let mut reader = Reader::new(&[0xff]); + let err = reader.read_utf8_string(1).unwrap_err(); + assert!( + err.to_string().contains("invalid UTF-8 string"), + "unexpected error: {err}" + ); + assert_eq!(reader.get_cursor(), 0); +} + +#[test] +fn test_fory_rejects_invalid_utf8_string_by_default() { + let fory = Fory::builder().build(); + assert!(fory.is_check_string_read()); + let mut bytes = fory.serialize(&"a".to_string()).unwrap(); + *bytes.last_mut().unwrap() = 0xff; + + let err = fory.deserialize::(&bytes).unwrap_err(); + assert!( + err.to_string().contains("invalid UTF-8 string"), + "unexpected error: {err}" + ); +} + +#[test] +fn test_fory_can_disable_checked_string_read_for_trusted_data() { + let fory = Fory::builder().check_string_read(false).build(); + assert!(!fory.is_check_string_read()); + + let bytes = fory.serialize(&"valid".to_string()).unwrap(); + let value = fory.deserialize::(&bytes).unwrap(); + assert_eq!(value, "valid"); +} From e62a4dbc45686d105cea5b3749fa37d43742ff1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=85=95=E7=99=BD?= Date: Wed, 6 May 2026 02:56:39 +0800 Subject: [PATCH 06/10] fix(java): validate typedef root kinds --- .../idl_tests/javascript/roundtrip.ts | 6 +- .../fory/meta/NativeTypeDefDecoder.java | 33 +++++-- .../fory/meta/NativeTypeDefEncoder.java | 25 +++++- .../org/apache/fory/meta/TypeDefDecoder.java | 19 ++++- .../apache/fory/resolver/ClassResolver.java | 55 ++++++++++++ .../apache/fory/resolver/TypeResolver.java | 20 ++--- .../fory/meta/NativeTypeDefEncoderTest.java | 85 +++++++++++++++++++ .../apache/fory/meta/TypeDefEncoderTest.java | 23 +++-- .../extension/meta/TypeDefEncoderTest.java | 3 +- 
rust/fory-core/src/buffer.rs | 19 ++--- 10 files changed, 242 insertions(+), 46 deletions(-) diff --git a/integration_tests/idl_tests/javascript/roundtrip.ts b/integration_tests/idl_tests/javascript/roundtrip.ts index 23af999e0b..4b77e8fc41 100644 --- a/integration_tests/idl_tests/javascript/roundtrip.ts +++ b/integration_tests/idl_tests/javascript/roundtrip.ts @@ -123,8 +123,10 @@ function resolveRootSerializer(fory: Fory, bytes: Uint8Array): Serializer { fory.readContext.reset(bytes); const reader = fory.readContext.reader; const bitmap = reader.readUint8(); - if ((bitmap & ConfigFlags.isNullFlag) === ConfigFlags.isNullFlag) { - throw new Error("IDL roundtrip does not support null root payloads"); + const supportedBitmap = + ConfigFlags.isCrossLanguageFlag | ConfigFlags.isOutOfBandFlag; + if ((bitmap & ~supportedBitmap) !== 0) { + throw new Error("unsupported root header bitmap"); } if ( (bitmap & ConfigFlags.isCrossLanguageFlag) !== diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java index 84f3c5b5f1..804a0bc85c 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java @@ -90,12 +90,14 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, String className; List classFields = new ArrayList<>(); ClassSpec classSpec = null; + Class rootClass = null; for (int i = 0; i < numClasses; i++) { // | num fields + register flag | header + package name | header + class name // | header + type id + field name | next field info | ... 
| int currentClassHeader = typeDefBuf.readVarUInt32Small7(); boolean isRegistered = (currentClassHeader & 0b1) != 0; int numFields = currentClassHeader >>> 1; + Class currentClass = null; if (isRegistered) { int typeId = typeDefBuf.readUInt8(); int userTypeId = -1; @@ -104,11 +106,16 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, } Class cls = resolver.getRegisteredClassByTypeId(typeId, userTypeId); if (cls == null) { - classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, userTypeId); + classSpec = + new ClassSpec( + UnknownClass.UnknownStruct.class, + i == numClasses - 1 ? rootTypeId : typeId, + userTypeId); className = classSpec.entireClassName; } else { className = cls.getName(); - classSpec = new ClassSpec(cls, typeId, userTypeId); + classSpec = new ClassSpec(cls, i == numClasses - 1 ? rootTypeId : typeId, userTypeId); + currentClass = cls; } } else { String pkg = readPkgName(typeDefBuf); @@ -118,9 +125,9 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, if (resolver.isRegisteredByName(className)) { Class cls = resolver.getRegisteredClass(className); className = cls.getName(); - classSpec = - new ClassSpec( - cls, resolver.getTypeIdForTypeDef(cls), resolver.getUserTypeIdForTypeDef(cls)); + int typeId = i == numClasses - 1 ? rootTypeId : resolver.getTypeIdForTypeDef(cls); + classSpec = new ClassSpec(cls, typeId, resolver.getUserTypeIdForTypeDef(cls)); + currentClass = cls; } else { Class cls = resolver.loadClassForMeta( @@ -145,22 +152,30 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, classSpec.type = cls; className = classSpec.entireClassName; } else { - int typeId = resolver.getTypeIdForTypeDef(cls); + int typeId = i == numClasses - 1 ? 
rootTypeId : resolver.getTypeIdForTypeDef(cls); classSpec = new ClassSpec(cls, typeId, resolver.getUserTypeIdForTypeDef(cls)); className = classSpec.entireClassName; + currentClass = cls; } } } - if (i == numClasses - 1 && classSpec.typeId != rootTypeId) { - throw new DeserializationException("TypeDef root kind does not match root class metadata"); + if (i == numClasses - 1) { + rootClass = currentClass; } List fieldInfos = readFieldsInfo(typeDefBuf, resolver, className, numFields); classFields.addAll(fieldInfos); } Preconditions.checkNotNull(classSpec); - if (!Types.isStructType(rootTypeId) && !classFields.isEmpty()) { + boolean hasFieldMetadata = !classFields.isEmpty(); + if (!Types.isStructType(rootTypeId) && hasFieldMetadata) { throw new DeserializationException("Non-struct TypeDef cannot carry field metadata"); } + if (rootClass != null) { + int expectedRootTypeId = resolver.getTypeDefRootTypeId(rootClass, hasFieldMetadata); + if (rootTypeId != expectedRootTypeId) { + throw new DeserializationException("TypeDef root kind does not match the decoded class"); + } + } if (typeDefBuf.remaining() != 0) { throw new DeserializationException("Invalid TypeDef metadata size"); } diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefEncoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefEncoder.java index bb2e104ee2..6173a9e88d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefEncoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefEncoder.java @@ -146,12 +146,13 @@ static TypeDef buildTypeDef(ClassResolver classResolver, Class type, List type, List fieldInfos) { + boolean hasFieldMetadata = !fieldInfos.isEmpty(); Map> classLayers = getClassFields(type, fieldInfos); fieldInfos = new ArrayList<>(fieldInfos.size()); classLayers.values().forEach(fieldInfos::addAll); - MemoryBuffer encodeTypeDef = encodeTypeDef(classResolver, type, classLayers); + MemoryBuffer encodeTypeDef = 
encodeTypeDef(classResolver, type, classLayers, hasFieldMetadata); byte[] typeDefBytes = encodeTypeDef.getBytes(0, encodeTypeDef.writerIndex()); - int typeId = classResolver.getTypeIdForTypeDef(type); + int typeId = classResolver.getTypeDefRootTypeId(type, hasFieldMetadata); int userTypeId = classResolver.getUserTypeIdForTypeDef(type); ClassSpec classSpec = new ClassSpec(type, typeId, userTypeId); return new TypeDef(classSpec, fieldInfos, encodeTypeDef.getInt64(0), typeDefBytes); @@ -161,9 +162,18 @@ public static TypeDef buildTypeDefWithFieldInfos( // https://fory.apache.org/docs/specification/fory_java_serialization_spec public static MemoryBuffer encodeTypeDef( ClassResolver classResolver, Class type, Map> classLayers) { + return encodeTypeDef(classResolver, type, classLayers, hasFieldMetadata(classLayers)); + } + + private static MemoryBuffer encodeTypeDef( + ClassResolver classResolver, + Class type, + Map> classLayers, + boolean hasFieldMetadata) { MemoryBuffer typeDefBuf = MemoryBuffer.newHeapBuffer(128); int numClasses = classLayers.size() - 1; // num class must be greater than 0 - int firstBodyByte = nativeKindCode(classResolver.getTypeIdForTypeDef(type)) << 4; + int rootTypeId = classResolver.getTypeDefRootTypeId(type, hasFieldMetadata); + int firstBodyByte = nativeKindCode(rootTypeId) << 4; if (numClasses >= NUM_CLASS_THRESHOLD) { typeDefBuf.writeByte(firstBodyByte | NUM_CLASS_THRESHOLD); typeDefBuf.writeVarUInt32Small7(numClasses - NUM_CLASS_THRESHOLD); @@ -217,6 +227,15 @@ public static MemoryBuffer encodeTypeDef( return prependHeader(typeDefBuf, isCompressed); } + private static boolean hasFieldMetadata(Map> classLayers) { + for (List fields : classLayers.values()) { + if (!fields.isEmpty()) { + return true; + } + } + return false; + } + static MemoryBuffer prependHeader(MemoryBuffer buffer, boolean isCompressed) { int metaSize = buffer.writerIndex(); long hash = MurmurHash3.murmurhash3_x64_128(buffer.getHeapMemory(), 0, metaSize, 47)[0]; diff --git 
a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java index 61ef2fb0f5..b184656c47 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java @@ -156,14 +156,29 @@ public static TypeDef decodeTypeDef(XtypeResolver resolver, MemoryBuffer inputBu } private static void validateRegisteredTypeDefKind(TypeInfo userTypeInfo, int typeId) { - if (userTypeInfo.getTypeId() != typeId) { + int registeredTypeId = userTypeInfo.getTypeId(); + if (registeredTypeId != typeId && !isStructCompatibilityVariant(registeredTypeId, typeId)) { throw new DeserializationException( String.format( "TypeDef kind %s does not match registered kind %s for %s", - typeId, userTypeInfo.getTypeId(), userTypeInfo.getType())); + typeId, registeredTypeId, userTypeInfo.getType())); } } + private static boolean isStructCompatibilityVariant(int registeredTypeId, int typeId) { + boolean registeredIdStruct = + registeredTypeId == Types.STRUCT || registeredTypeId == Types.COMPATIBLE_STRUCT; + boolean typeIdStruct = typeId == Types.STRUCT || typeId == Types.COMPATIBLE_STRUCT; + if (registeredIdStruct || typeIdStruct) { + return registeredIdStruct && typeIdStruct; + } + boolean registeredNamedStruct = + registeredTypeId == Types.NAMED_STRUCT || registeredTypeId == Types.NAMED_COMPATIBLE_STRUCT; + boolean typeIdNamedStruct = + typeId == Types.NAMED_STRUCT || typeId == Types.NAMED_COMPATIBLE_STRUCT; + return registeredNamedStruct && typeIdNamedStruct; + } + static int nonStructTypeId(int kindCode) { switch (kindCode) { case 0: diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java index 9903000d74..09c4402209 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java +++ 
b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java @@ -876,6 +876,61 @@ public int getTypeIdForTypeDef(Class cls) { return typeId; } + public int getTypeDefRootTypeId(Class cls, boolean hasFieldMetadata) { + if (hasFieldMetadata) { + // Preserve the normal TypeInfo/name cache so locally generated or dynamically registered + // classes can be resolved when the TypeDef is decoded by the same resolver. + getTypeIdForTypeDef(cls); + return getFieldMetadataTypeIdForTypeDef(cls); + } + TypeInfo typeInfo = classInfoMap.get(cls); + if (typeInfo != null) { + return normalizeTypeDefRootTypeId(cls, typeInfo.typeId); + } + Integer classId = extRegistry.registeredClassIdMap.get(cls); + if (classId != null) { + typeInfo = classInfoMap.get(cls); + if (typeInfo == null) { + typeInfo = getTypeInfo(cls); + } + return normalizeTypeDefRootTypeId(cls, typeInfo.typeId); + } + return usesNonStructTypeDef(cls) ? Types.NAMED_EXT : buildUnregisteredTypeId(cls, null); + } + + private int getFieldMetadataTypeIdForTypeDef(Class cls) { + Integer classId = extRegistry.registeredClassIdMap.get(cls); + if (classId != null && !isInternalRegisteredClassId(cls, classId)) { + return buildUserTypeId(cls, null); + } + return super.buildUnregisteredTypeId(cls, null); + } + + private int normalizeTypeDefRootTypeId(Class cls, int typeId) { + if (isSupportedTypeDefTypeId(typeId)) { + return typeId; + } + return usesNonStructTypeDef(cls) ? 
Types.NAMED_EXT : buildUnregisteredTypeId(cls, null); + } + + private static boolean isSupportedTypeDefTypeId(int typeId) { + switch (typeId) { + case Types.ENUM: + case Types.NAMED_ENUM: + case Types.STRUCT: + case Types.COMPATIBLE_STRUCT: + case Types.NAMED_STRUCT: + case Types.NAMED_COMPATIBLE_STRUCT: + case Types.EXT: + case Types.NAMED_EXT: + case Types.TYPED_UNION: + case Types.NAMED_UNION: + return true; + default: + return false; + } + } + private boolean usesNonStructTypeDef(Class cls) { return !cls.isEnum() && (isCollection(cls) diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index 9624020c89..7233414ee8 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -812,6 +812,16 @@ protected final TypeInfo readSharedClassMeta(ReadContext readContext) { return typeInfo; } + public final TypeInfo readSharedClassMeta(ReadContext readContext, Class targetClass) { + TypeInfo typeInfo = readSharedClassMeta(readContext); + Class readClass = typeInfo.getType(); + // replace target class if needed + if (targetClass != readClass) { + return getTargetTypeInfo(typeInfo, targetClass); + } + return typeInfo; + } + private static TypeInfo getMetaReadTypeInfo(MetaReadContext metaReadContext, int index) { if (index < 0 || index >= metaReadContext.readTypeInfos.size) { throw new ForyException("Invalid class metadata reference id " + index); @@ -823,16 +833,6 @@ private static TypeInfo getMetaReadTypeInfo(MetaReadContext metaReadContext, int return typeInfo; } - public final TypeInfo readSharedClassMeta(ReadContext readContext, Class targetClass) { - TypeInfo typeInfo = readSharedClassMeta(readContext); - Class readClass = typeInfo.getType(); - // replace target class if needed - if (targetClass != readClass) { - return getTargetTypeInfo(typeInfo, targetClass); 
- } - return typeInfo; - } - private TypeInfo getTargetTypeInfo(TypeInfo typeInfo, Class targetClass) { TransformedTypeInfo[] infos = extRegistry.transformedTypeInfo.get(targetClass); Class readClass = typeInfo.getType(); diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java index a78bf53e41..22d988dfcf 100644 --- a/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java @@ -23,6 +23,8 @@ import static org.apache.fory.meta.NativeTypeDefEncoder.getClassFields; import java.io.Serializable; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import lombok.Data; import org.apache.fory.Fory; @@ -30,9 +32,11 @@ import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.resolver.ClassResolver; +import org.apache.fory.serializer.ObjectStreamSerializer; import org.apache.fory.test.bean.BeanA; import org.apache.fory.test.bean.MapFields; import org.apache.fory.test.bean.Struct; +import org.apache.fory.type.Types; import org.testng.Assert; import org.testng.annotations.Test; @@ -98,6 +102,87 @@ public void testBigClassNameObject() { Assert.assertEquals(typeDef1, typeDef); } + @Data + public static class NaturalExtTypeWithFields implements Serializable { + private static final long serialVersionUID = 1L; + private int value; + } + + @Test + public void testFieldMetadataTypeDefUsesStructKindForNaturalExtSerializer() { + Fory fory = + Fory.builder() + .withXlang(false) + .withMetaShare(true) + .withCompatible(true) + .requireClassRegistration(false) + .build(); + ClassResolver resolver = (ClassResolver) fory.getTypeResolver(); + fory.registerSerializer( + NaturalExtTypeWithFields.class, + new ObjectStreamSerializer(resolver, NaturalExtTypeWithFields.class)); + 
+ TypeDef typeDef = TypeDef.buildTypeDef(resolver, NaturalExtTypeWithFields.class); + Assert.assertTrue(typeDef.isStructSchemaKind()); + + TypeDef decoded = + TypeDef.readTypeDef(resolver, MemoryBuffer.fromByteArray(typeDef.getEncoded())); + Assert.assertEquals(decoded, typeDef); + } + + @Test + public void testEmptyTypeDefKeepsNaturalExtKind() { + Fory fory = + Fory.builder() + .withXlang(false) + .withMetaShare(true) + .withCompatible(true) + .requireClassRegistration(false) + .build(); + ClassResolver resolver = (ClassResolver) fory.getTypeResolver(); + fory.getTypeResolver().getTypeInfo(java.util.ArrayList.class); + + TypeDef typeDef = + NativeTypeDefEncoder.buildTypeDefWithFieldInfos( + resolver, java.util.ArrayList.class, Collections.emptyList()); + Assert.assertEquals(typeDef.getClassSpec().typeId, Types.NAMED_EXT); + + TypeDef decoded = + TypeDef.readTypeDef(resolver, MemoryBuffer.fromByteArray(typeDef.getEncoded())); + Assert.assertEquals(decoded, typeDef); + } + + @Test + public void testDecodeRejectsKnownClassWithForgedRootKind() { + Fory fory = + Fory.builder() + .withXlang(false) + .withMetaShare(true) + .withCompatible(true) + .requireClassRegistration(false) + .build(); + ClassResolver resolver = (ClassResolver) fory.getTypeResolver(); + fory.getTypeResolver().getTypeInfo(java.util.ArrayList.class); + + TypeDef typeDef = + NativeTypeDefEncoder.buildTypeDefWithFieldInfos( + resolver, java.util.ArrayList.class, Collections.emptyList()); + byte[] encoded = typeDef.getEncoded(); + MemoryBuffer encodedBuffer = MemoryBuffer.fromByteArray(encoded); + long header = encodedBuffer.readInt64(); + Assert.assertEquals(header & TypeDef.COMPRESS_META_FLAG, 0L); + Assert.assertEquals((int) (header & TypeDef.META_SIZE_MASKS), encoded.length - Long.BYTES); + + byte[] body = Arrays.copyOfRange(encoded, Long.BYTES, encoded.length); + body[0] = + (byte) ((NativeTypeDefEncoder.nativeKindCode(Types.NAMED_STRUCT) << 4) | (body[0] & 0x0f)); + MemoryBuffer malformedBody = 
MemoryBuffer.newHeapBuffer(body.length); + malformedBody.writeBytes(body); + MemoryBuffer malformed = NativeTypeDefEncoder.prependHeader(malformedBody, false); + Assert.assertThrows( + DeserializationException.class, () -> TypeDef.readTypeDef(resolver, malformed)); + } + @Data public static class TestClassLengthTestClassLengthTestClassLengthTestClassLengthTestClassLengthTestClassLengthTestClassLength diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java index 1b256b3d73..938e632631 100644 --- a/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java @@ -29,6 +29,7 @@ import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.resolver.TypeResolver; +import org.apache.fory.type.Types; import org.testng.Assert; import org.testng.annotations.Test; @@ -414,7 +415,8 @@ public void testExtendedFieldCountHeaderDoesNotSetRegisterByName() { Assert.assertEquals(bodyHeader & TypeDefEncoder.SMALL_NUM_FIELDS_THRESHOLD, 31); Assert.assertEquals(bodyHeader & TypeDefEncoder.REGISTER_BY_NAME_FLAG, 0); TypeDef decoded = - TypeDef.readTypeDef(fory.getTypeResolver(), MemoryBuffer.fromByteArray(typeDef.getEncoded())); + TypeDef.readTypeDef( + fory.getTypeResolver(), MemoryBuffer.fromByteArray(typeDef.getEncoded())); Assert.assertEquals(decoded.getFieldsInfo().size(), 32); } @@ -429,8 +431,7 @@ public void testDecodeRejectsCompressedXlangTypeDef() { MemoryBuffer encoded = NativeTypeDefEncoder.prependHeader(compressedBody, true); Assert.assertThrows( - DeserializationException.class, - () -> TypeDef.readTypeDef(fory.getTypeResolver(), encoded)); + DeserializationException.class, () -> TypeDef.readTypeDef(fory.getTypeResolver(), encoded)); } @Test @@ -502,8 +503,20 @@ public void 
testDecodeRejectsRegisteredTypeDefKindMismatch() { MemoryBuffer encoded = NativeTypeDefEncoder.prependHeader(body, false); Assert.assertThrows( - DeserializationException.class, - () -> TypeDef.readTypeDef(fory.getTypeResolver(), encoded)); + DeserializationException.class, () -> TypeDef.readTypeDef(fory.getTypeResolver(), encoded)); + } + + @Test + public void testDecodePreservesCompatibleStructKindForRegisteredStruct() { + Fory fory = Fory.builder().withXlang(true).withMetaShare(true).withCompatible(false).build(); + fory.register(EmptyStruct.class, 6001); + MemoryBuffer body = MemoryBuffer.newHeapBuffer(8); + body.writeByte(TypeDefEncoder.STRUCT_FLAG | TypeDefEncoder.COMPATIBLE_FLAG); + body.writeVarUInt32(6001); + MemoryBuffer encoded = NativeTypeDefEncoder.prependHeader(body, false); + + TypeDef typeDef = TypeDef.readTypeDef(fory.getTypeResolver(), encoded); + Assert.assertEquals(typeDef.getClassSpec().typeId, Types.COMPATIBLE_STRUCT); } private static byte[] corruptEncodedBody(TypeDef typeDef, String needle) { diff --git a/java/fory-extensions/src/test/java/org/apache/fory/extension/meta/TypeDefEncoderTest.java b/java/fory-extensions/src/test/java/org/apache/fory/extension/meta/TypeDefEncoderTest.java index 22199c4548..6fd3c0f882 100644 --- a/java/fory-extensions/src/test/java/org/apache/fory/extension/meta/TypeDefEncoderTest.java +++ b/java/fory-extensions/src/test/java/org/apache/fory/extension/meta/TypeDefEncoderTest.java @@ -49,8 +49,7 @@ public void testBasicTypeDefZstdMetaCompressor() throws Exception { ClassResolver classResolver = (ClassResolver) fory.getTypeResolver(); List fieldsInfo = buildFieldsInfo(classResolver, type); MemoryBuffer buffer = - NativeTypeDefEncoder.encodeTypeDef( - classResolver, type, getClassFields(type, fieldsInfo)); + NativeTypeDefEncoder.encodeTypeDef(classResolver, type, getClassFields(type, fieldsInfo)); TypeDef typeDef = TypeDef.readTypeDef(classResolver, buffer); Assert.assertEquals(typeDef.getClassName(), 
type.getName()); Assert.assertEquals(typeDef.getFieldsInfo().size(), type.getDeclaredFields().length); diff --git a/rust/fory-core/src/buffer.rs b/rust/fory-core/src/buffer.rs index 21cc4377dd..cc7b389eb1 100644 --- a/rust/fory-core/src/buffer.rs +++ b/rust/fory-core/src/buffer.rs @@ -951,19 +951,12 @@ impl<'a> Reader<'a> { #[inline(always)] pub fn read_utf8_string(&mut self, len: usize) -> Result { self.check_bound(len)?; - // don't use simd for memory copy, copy_non_overlapping is faster - unsafe { - let mut vec = Vec::with_capacity(len); - let src = self.bf.as_ptr().add(self.cursor); - let dst = vec.as_mut_ptr(); - // Use fastest possible copy - copy_nonoverlapping compiles to memcpy - std::ptr::copy_nonoverlapping(src, dst, len); - vec.set_len(len); - let string = String::from_utf8(vec) - .map_err(|_| Error::encoding_error("invalid UTF-8 string"))?; - self.move_next(len); - Ok(string) - } + let src = &self.bf[self.cursor..self.cursor + len]; + let string = + std::str::from_utf8(src).map_err(|_| Error::encoding_error("invalid UTF-8 string"))?; + let string = string.to_owned(); + self.move_next(len); + Ok(string) } #[inline(always)] From debe607c49f1c5fdcfe572b5b39117f8a1ee909a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=85=95=E7=99=BD?= Date: Wed, 6 May 2026 07:07:05 +0800 Subject: [PATCH 07/10] fix: clear metadata validation CI regressions --- go/fory/map_primitive_test.go | 2 +- .../fory/meta/NativeTypeDefDecoder.java | 39 +++++++++++++- .../apache/fory/resolver/ClassResolver.java | 10 +++- javascript/test/decimal.test.ts | 51 ++++++++++--------- python/pyfory/tests/test_ref_tracking.py | 6 +-- python/pyfory/tests/test_serializer.py | 2 +- python/pyfory/tests/test_size_guardrails.py | 2 +- 7 files changed, 78 insertions(+), 34 deletions(-) diff --git a/go/fory/map_primitive_test.go b/go/fory/map_primitive_test.go index d36fb6c1f9..7decca8b0a 100644 --- a/go/fory/map_primitive_test.go +++ b/go/fory/map_primitive_test.go @@ -11,7 +11,7 @@ // Unless required by 
applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the +// KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java index 804a0bc85c..199cce1fb0 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java @@ -91,6 +91,7 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, List classFields = new ArrayList<>(); ClassSpec classSpec = null; Class rootClass = null; + boolean rootClassLayerRegistered = false; for (int i = 0; i < numClasses; i++) { // | num fields + register flag | header + package name | header + class name // | header + type id + field name | next field info | ... 
| @@ -161,6 +162,7 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, } if (i == numClasses - 1) { rootClass = currentClass; + rootClassLayerRegistered = isRegistered; } List fieldInfos = readFieldsInfo(typeDefBuf, resolver, className, numFields); classFields.addAll(fieldInfos); @@ -172,8 +174,16 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, } if (rootClass != null) { int expectedRootTypeId = resolver.getTypeDefRootTypeId(rootClass, hasFieldMetadata); - if (rootTypeId != expectedRootTypeId) { - throw new DeserializationException("TypeDef root kind does not match the decoded class"); + if (!isCompatibleRootKind(expectedRootTypeId, rootTypeId, !rootClassLayerRegistered)) { + throw new DeserializationException( + "TypeDef root kind does not match the decoded class: class=" + + rootClass.getName() + + ", expected=" + + expectedRootTypeId + + ", actual=" + + rootTypeId + + ", registeredClassLayer=" + + rootClassLayerRegistered); } } if (typeDefBuf.remaining() != 0) { @@ -183,6 +193,31 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, return new TypeDef(classSpec, classFields, id, decoded.f1); } + private static boolean isCompatibleRootKind( + int expectedTypeId, int actualTypeId, boolean allowNamednessDifference) { + if (expectedTypeId == actualTypeId) { + return true; + } + if (allowNamednessDifference) { + return Types.isStructType(expectedTypeId) && Types.isStructType(actualTypeId); + } + return isStructCompatibilityVariant(expectedTypeId, actualTypeId); + } + + private static boolean isStructCompatibilityVariant(int expectedTypeId, int actualTypeId) { + boolean expectedIdStruct = + expectedTypeId == Types.STRUCT || expectedTypeId == Types.COMPATIBLE_STRUCT; + boolean actualIdStruct = actualTypeId == Types.STRUCT || actualTypeId == Types.COMPATIBLE_STRUCT; + if (expectedIdStruct || actualIdStruct) { + return expectedIdStruct && actualIdStruct; + } + boolean 
expectedNamedStruct = + expectedTypeId == Types.NAMED_STRUCT || expectedTypeId == Types.NAMED_COMPATIBLE_STRUCT; + boolean actualNamedStruct = + actualTypeId == Types.NAMED_STRUCT || actualTypeId == Types.NAMED_COMPATIBLE_STRUCT; + return expectedNamedStruct && actualNamedStruct; + } + static int nativeTypeId(int kindCode) { switch (kindCode) { case 0: diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java index 09c4402209..38a375c60c 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java @@ -907,10 +907,15 @@ private int getFieldMetadataTypeIdForTypeDef(Class cls) { } private int normalizeTypeDefRootTypeId(Class cls, int typeId) { + if (usesNonStructTypeDef(cls)) { + // Placeholder TypeInfo can be created before the natural serializer is installed. + // The TypeDef root kind must still select the non-struct serializer family. + return Types.isExtType(typeId) ? typeId : Types.NAMED_EXT; + } if (isSupportedTypeDefTypeId(typeId)) { return typeId; } - return usesNonStructTypeDef(cls) ? 
Types.NAMED_EXT : buildUnregisteredTypeId(cls, null); + return buildUnregisteredTypeId(cls, null); } private static boolean isSupportedTypeDefTypeId(int typeId) { @@ -933,7 +938,8 @@ private static boolean isSupportedTypeDefTypeId(int typeId) { private boolean usesNonStructTypeDef(Class cls) { return !cls.isEnum() - && (isCollection(cls) + && (cls.isArray() + || isCollection(cls) || isMap(cls) || Externalizable.class.isAssignableFrom(cls) || requireJavaSerialization(cls) diff --git a/javascript/test/decimal.test.ts b/javascript/test/decimal.test.ts index 2be7e9d06e..f0ad1fa1e8 100644 --- a/javascript/test/decimal.test.ts +++ b/javascript/test/decimal.test.ts @@ -20,7 +20,10 @@ import Fory, { Decimal, Type } from "../packages/core/index"; import { describe, expect, test } from "@jest/globals"; -function decimal(unscaledValue: string | bigint | number, scale: number): Decimal { +function decimal( + unscaledValue: string | bigint | number, + scale: number, +): Decimal { return new Decimal(unscaledValue, scale); } @@ -52,18 +55,26 @@ describe("decimal", () => { test("round-trips struct decimal fields", () => { const fory = new Fory(); - const serializer = fory.register(Type.struct({ - typeName: "example.DecimalEnvelope", - }, { - amount: Type.decimal(), - note: Type.string(), - })).serializer; + const serializer = fory.register( + Type.struct( + { + typeName: "example.DecimalEnvelope", + }, + { + amount: Type.decimal(), + note: Type.string(), + }, + ), + ).serializer; const value = { amount: decimal("123456789012345678901234567890123456789", 37), note: "principal", }; - const roundTrip = fory.deserialize(fory.serialize(value, serializer), serializer) as { + const roundTrip = fory.deserialize( + fory.serialize(value, serializer), + serializer, + ) as { amount: Decimal; note: string; }; @@ -75,24 +86,16 @@ describe("decimal", () => { test("rejects non-canonical big decimal payloads", () => { const fory = new Fory(); - const zeroBigEncoding = Buffer.from([ - 0x02, - 0xff, 
- 0x28, - 0x00, - 0x01, - ]); + const zeroBigEncoding = Buffer.from([0x01, 0xff, 0x28, 0x00, 0x01]); const trailingZeroPayload = Buffer.from([ - 0x02, - 0xff, - 0x28, - 0x00, - 0x09, - 0x01, - 0x00, + 0x01, 0xff, 0x28, 0x00, 0x09, 0x01, 0x00, ]); - expect(() => fory.deserialize(zeroBigEncoding)).toThrow(/Invalid decimal magnitude length/); - expect(() => fory.deserialize(trailingZeroPayload)).toThrow(/trailing zero byte/); + expect(() => fory.deserialize(zeroBigEncoding)).toThrow( + /Invalid decimal magnitude length/, + ); + expect(() => fory.deserialize(trailingZeroPayload)).toThrow( + /trailing zero byte/, + ); }); }); diff --git a/python/pyfory/tests/test_ref_tracking.py b/python/pyfory/tests/test_ref_tracking.py index b8b57f1fd3..05d67aeba0 100644 --- a/python/pyfory/tests/test_ref_tracking.py +++ b/python/pyfory/tests/test_ref_tracking.py @@ -269,7 +269,7 @@ def test_collection_mixed_type_primitive_ref_value_regression(): write_context.prepare(buffer) # Fory payload framing + top-level list object. 
- buffer.write_int8(0b10) + buffer.write_int8(0b1) buffer.write_int8(REF_VALUE_FLAG) fory.type_resolver.write_type_info(write_context, fory.type_resolver.get_type_info(list)) @@ -295,7 +295,7 @@ def test_invalid_top_level_ref_id_raises_value_error(): fory = pyfory.Fory(xlang=True, ref=True, strict=False) buffer = pyfory.Buffer.allocate(32) - buffer.write_int8(0b10) + buffer.write_int8(0b1) buffer.write_int8(REF_FLAG) buffer.write_var_uint32(12345) @@ -310,7 +310,7 @@ def test_invalid_collection_element_ref_id_raises_value_error(): write_context = fory.write_context write_context.prepare(buffer) - buffer.write_int8(0b10) + buffer.write_int8(0b1) buffer.write_int8(REF_VALUE_FLAG) fory.type_resolver.write_type_info(write_context, fory.type_resolver.get_type_info(list)) buffer.write_var_uint32(1) diff --git a/python/pyfory/tests/test_serializer.py b/python/pyfory/tests/test_serializer.py index 6cf4a54aa2..a2396d2a85 100644 --- a/python/pyfory/tests/test_serializer.py +++ b/python/pyfory/tests/test_serializer.py @@ -333,7 +333,7 @@ def test_date_serializer_uses_xlang_varint64_and_native_int32(xlang): day = datetime.date(1969, 12, 31) payload = fory.serialize(day) buffer = Buffer(payload) - assert buffer.read_uint8() == 2 + assert buffer.read_uint8() == (1 if xlang else 0) assert buffer.read_int8() == -1 assert buffer.read_uint8() == TypeId.DATE if xlang: diff --git a/python/pyfory/tests/test_size_guardrails.py b/python/pyfory/tests/test_size_guardrails.py index 01075e9b6e..d276a2133c 100644 --- a/python/pyfory/tests/test_size_guardrails.py +++ b/python/pyfory/tests/test_size_guardrails.py @@ -198,6 +198,6 @@ def test_in_band_buffer_object_respects_limit(self): Fory(ref=True, max_binary_size=100).deserialize(data, buffers=[]) def test_malformed_metastring_ref_raises_value_error(self): - payload = bytes([2, 255, TypeId.NAMED_STRUCT, 3]) + payload = bytes([1, 255, TypeId.NAMED_STRUCT, 3]) with pytest.raises(ValueError, match="Invalid dynamic metastring id"): 
Fory(xlang=True, strict=False).deserialize(payload) From 723ea859d6322702d51a40f99c1d50ba91becaf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=85=95=E7=99=BD?= Date: Wed, 6 May 2026 07:20:04 +0800 Subject: [PATCH 08/10] fix: clear scala and style CI failures --- .../fory/meta/NativeTypeDefDecoder.java | 3 +- .../apache/fory/resolver/ClassResolver.java | 1 + javascript/packages/core/lib/context.ts | 21 ++++----- javascript/packages/core/lib/fory.ts | 8 ++-- javascript/packages/core/lib/meta/TypeMeta.ts | 44 +++++++++---------- javascript/packages/core/lib/reader/index.ts | 14 +++--- 6 files changed, 47 insertions(+), 44 deletions(-) diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java index 199cce1fb0..b4f50c2c0d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java @@ -207,7 +207,8 @@ private static boolean isCompatibleRootKind( private static boolean isStructCompatibilityVariant(int expectedTypeId, int actualTypeId) { boolean expectedIdStruct = expectedTypeId == Types.STRUCT || expectedTypeId == Types.COMPATIBLE_STRUCT; - boolean actualIdStruct = actualTypeId == Types.STRUCT || actualTypeId == Types.COMPATIBLE_STRUCT; + boolean actualIdStruct = + actualTypeId == Types.STRUCT || actualTypeId == Types.COMPATIBLE_STRUCT; if (expectedIdStruct || actualIdStruct) { return expectedIdStruct && actualIdStruct; } diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java index 38a375c60c..588756f51b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java @@ -945,6 +945,7 @@ private boolean usesNonStructTypeDef(Class cls) { || 
requireJavaSerialization(cls) || useReplaceResolveSerializer(cls) || Functions.isLambda(cls) + || (config.isScalaOptimizationEnabled() && ReflectionUtils.isScalaSingletonObject(cls)) || Calendar.class.isAssignableFrom(cls) || ZoneId.class.isAssignableFrom(cls) || TimeZone.class.isAssignableFrom(cls) diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index ce828d787a..94e4d2e068 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -198,8 +198,8 @@ export class WriteContext { checkCollectionSize(size: number) { if (size > this._maxCollectionSize) { throw new Error( - `Collection size ${size} exceeds maxCollectionSize ${this._maxCollectionSize}. ` + - "The data may be malicious, or increase maxCollectionSize if needed.", + `Collection size ${size} exceeds maxCollectionSize ${this._maxCollectionSize}. ` + + "The data may be malicious, or increase maxCollectionSize if needed.", ); } } @@ -207,8 +207,8 @@ export class WriteContext { checkBinarySize(size: number) { if (size > this._maxBinarySize) { throw new Error( - `Binary size ${size} exceeds maxBinarySize ${this._maxBinarySize}. ` + - "The data may be malicious, or increase maxBinarySize if needed.", + `Binary size ${size} exceeds maxBinarySize ${this._maxBinarySize}. ` + + "The data may be malicious, or increase maxBinarySize if needed.", ); } } @@ -396,6 +396,7 @@ export class ReadContext { bigint, { readonly typeMeta: TypeMeta; readonly skipBytesAfterHeader: number } > = new Map(); + private _depth = 0; private _maxDepth: number; private _maxBinarySize: number; @@ -429,8 +430,8 @@ export class ReadContext { this._depth++; if (this._depth > this._maxDepth) { throw new Error( - `Deserialization depth limit exceeded: ${this._depth} > ${this._maxDepth}. ` + - "The data may be malicious, or increase maxDepth if needed.", + `Deserialization depth limit exceeded: ${this._depth} > ${this._maxDepth}. 
` + + "The data may be malicious, or increase maxDepth if needed.", ); } } @@ -442,8 +443,8 @@ export class ReadContext { checkCollectionSize(size: number) { if (size > this._maxCollectionSize) { throw new Error( - `Collection size ${size} exceeds maxCollectionSize ${this._maxCollectionSize}. ` + - "The data may be malicious, or increase maxCollectionSize if needed.", + `Collection size ${size} exceeds maxCollectionSize ${this._maxCollectionSize}. ` + + "The data may be malicious, or increase maxCollectionSize if needed.", ); } } @@ -451,8 +452,8 @@ export class ReadContext { checkBinarySize(size: number) { if (size > this._maxBinarySize) { throw new Error( - `Binary size ${size} exceeds maxBinarySize ${this._maxBinarySize}. ` + - "The data may be malicious, or increase maxBinarySize if needed.", + `Binary size ${size} exceeds maxBinarySize ${this._maxBinarySize}. ` + + "The data may be malicious, or increase maxBinarySize if needed.", ); } } diff --git a/javascript/packages/core/lib/fory.ts b/javascript/packages/core/lib/fory.ts index d96222b463..ae7343c507 100644 --- a/javascript/packages/core/lib/fory.ts +++ b/javascript/packages/core/lib/fory.ts @@ -58,8 +58,8 @@ export default class Fory { `maxBinarySize must be a non-negative integer but got ${maxBinarySize}`, ); } - const maxCollectionSize = - this.config.maxCollectionSize ?? DEFAULT_MAX_COLLECTION_SIZE; + const maxCollectionSize + = this.config.maxCollectionSize ?? 
DEFAULT_MAX_COLLECTION_SIZE; if (!Number.isInteger(maxCollectionSize) || maxCollectionSize < 0) { throw new Error( `maxCollectionSize must be a non-negative integer but got ${maxCollectionSize}`, @@ -154,8 +154,8 @@ export default class Fory { } private throwInvalidRootHeader(bitmap: number): never { - const knownFlags = - ConfigFlags.isCrossLanguageFlag | ConfigFlags.isOutOfBandFlag; + const knownFlags + = ConfigFlags.isCrossLanguageFlag | ConfigFlags.isOutOfBandFlag; if ((bitmap & ~knownFlags) !== 0) { throw new Error( `unsupported root header bitmap 0x${bitmap.toString(16)}`, diff --git a/javascript/packages/core/lib/meta/TypeMeta.ts b/javascript/packages/core/lib/meta/TypeMeta.ts index 231a13671e..d42d5ef53c 100644 --- a/javascript/packages/core/lib/meta/TypeMeta.ts +++ b/javascript/packages/core/lib/meta/TypeMeta.ts @@ -74,8 +74,8 @@ export const isPrimitiveTypeId = (typeId: number): boolean => { export const refTrackingUnableTypeId = (typeId: number): boolean => { return ( - PRIMITIVE_TYPE_IDS.includes(typeId as any) || - [TypeId.DURATION, TypeId.DATE, TypeId.TIMESTAMP, TypeId.STRING].includes( + PRIMITIVE_TYPE_IDS.includes(typeId as any) + || [TypeId.DURATION, TypeId.DATE, TypeId.TIMESTAMP, TypeId.STRING].includes( typeId as any, ) ); @@ -355,10 +355,10 @@ export class TypeMeta { private fingerprintTypeId(typeId: number) { if ( - TypeId.userDefinedType(typeId) || - typeId === TypeId.UNION || - typeId === TypeId.TYPED_UNION || - typeId === TypeId.NAMED_UNION + TypeId.userDefinedType(typeId) + || typeId === TypeId.UNION + || typeId === TypeId.TYPED_UNION + || typeId === TypeId.NAMED_UNION ) { return TypeId.UNKNOWN; } @@ -386,8 +386,8 @@ export class TypeMeta { if (fieldTypeId === TypeId.NAMED_ENUM) { fieldTypeId = TypeId.ENUM; } else if ( - fieldTypeId === TypeId.NAMED_UNION || - fieldTypeId === TypeId.TYPED_UNION + fieldTypeId === TypeId.NAMED_UNION + || fieldTypeId === TypeId.TYPED_UNION ) { fieldTypeId = TypeId.UNION; } @@ -600,8 +600,8 @@ export class 
TypeMeta { if (typeId === TypeId.NAMED_ENUM) { typeId = TypeId.ENUM; } else if ( - typeId === TypeId.NAMED_UNION || - typeId === TypeId.TYPED_UNION + typeId === TypeId.NAMED_UNION + || typeId === TypeId.TYPED_UNION ) { typeId = TypeId.UNION; } @@ -750,12 +750,12 @@ export class TypeMeta { let currentClassHeader: number; if (isStruct) { - currentClassHeader = - STRUCT_TYPEDEF_FLAG | - Math.min(this.fields.length, SMALL_NUM_FIELDS_THRESHOLD); + currentClassHeader + = STRUCT_TYPEDEF_FLAG + | Math.min(this.fields.length, SMALL_NUM_FIELDS_THRESHOLD); if ( - this.type.typeId === TypeId.COMPATIBLE_STRUCT || - this.type.typeId === TypeId.NAMED_COMPATIBLE_STRUCT + this.type.typeId === TypeId.COMPATIBLE_STRUCT + || this.type.typeId === TypeId.NAMED_COMPATIBLE_STRUCT ) { currentClassHeader |= COMPATIBLE_TYPEDEF_FLAG; } @@ -995,9 +995,9 @@ export class TypeMeta { if (c >= "A" && c <= "Z") { if (i > 0) { const prevUpper = chars[i - 1] >= "A" && chars[i - 1] <= "Z"; - const nextUpperOrEnd = - i + 1 >= chars.length || - (chars[i + 1] >= "A" && chars[i + 1] <= "Z"); + const nextUpperOrEnd + = i + 1 >= chars.length + || (chars[i + 1] >= "A" && chars[i + 1] <= "Z"); if (!prevUpper || !nextUpperOrEnd) { result.push("_"); @@ -1023,10 +1023,10 @@ export class TypeMeta { b: { fieldName: string; fieldId?: number }, ) { if ( - a.fieldId !== undefined && - a.fieldId !== null && - b.fieldId !== undefined && - b.fieldId !== null + a.fieldId !== undefined + && a.fieldId !== null + && b.fieldId !== undefined + && b.fieldId !== null ) { return a.fieldId - b.fieldId; } diff --git a/javascript/packages/core/lib/reader/index.ts b/javascript/packages/core/lib/reader/index.ts index 02b7897d60..754011e9b0 100644 --- a/javascript/packages/core/lib/reader/index.ts +++ b/javascript/packages/core/lib/reader/index.ts @@ -43,10 +43,10 @@ export class BinaryReader { // Reuse DataView when the underlying ArrayBuffer, byteOffset, and byteLength are unchanged. 
const buf = this.platformBuffer.buffer; if ( - buf !== this.cachedArrayBuffer || - !this.dataView || - this.dataView.byteOffset !== this.platformBuffer.byteOffset || - this.dataView.byteLength !== this.byteLength + buf !== this.cachedArrayBuffer + || !this.dataView + || this.dataView.byteOffset !== this.platformBuffer.byteOffset + || this.dataView.byteLength !== this.byteLength ) { this.dataView = new DataView( buf, @@ -500,9 +500,9 @@ export class BinaryReader { rh28 |= (byte & 0x7f) << 21; if ((byte & 0x80) != 0) { return ( - (BigInt(this.readUint8()) << 56n) | - (BigInt(rh28) << 28n) | - BigInt(rl28) + (BigInt(this.readUint8()) << 56n) + | (BigInt(rh28) << 28n) + | BigInt(rl28) ); } } From 389f199cf3833e3496e560cabee5f59ac9bab537 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=85=95=E7=99=BD?= Date: Wed, 6 May 2026 07:30:59 +0800 Subject: [PATCH 09/10] fix: satisfy PR style checks --- python/pyfory/meta/typedef.py | 274 +++++-------------- python/pyfory/meta/typedef_decoder.py | 28 +- python/pyfory/meta/typedef_encoder.py | 20 +- python/pyfory/serializer.py | 138 +++------- python/pyfory/tests/test_policy.py | 37 +-- python/pyfory/tests/test_typedef_encoding.py | 45 +-- rust/fory-core/src/meta/type_meta.rs | 5 +- 7 files changed, 135 insertions(+), 412 deletions(-) diff --git a/python/pyfory/meta/typedef.py b/python/pyfory/meta/typedef.py index 0a444d1ce9..e13d11e18d 100644 --- a/python/pyfory/meta/typedef.py +++ b/python/pyfory/meta/typedef.py @@ -69,9 +69,7 @@ ] # TAG_ID encoding constants -TAG_ID_SIZE_THRESHOLD = ( - 0b1111 # 4-bit threshold for tag IDs (0-14 inline, 15 = overflow) -) +TAG_ID_SIZE_THRESHOLD = 0b1111 # 4-bit threshold for tag IDs (0-14 inline, 15 = overflow) def is_struct_typedef_kind(type_id: int) -> bool: @@ -154,9 +152,7 @@ def __init__( self.encoded = encoded self.is_compressed = is_compressed - def create_fields_serializer( - self, resolver, resolved_field_names=None, local_field_types=None - ): + def create_fields_serializer(self, resolver, 
resolved_field_names=None, local_field_types=None): """Create serializers for each field. Args: @@ -167,18 +163,12 @@ def create_fields_serializer( """ field_types = local_field_types if field_types is None: - field_types = infer_field_types( - self.cls, field_nullable=resolver.field_nullable - ) + field_types = infer_field_types(self.cls, field_nullable=resolver.field_nullable) serializers = [] for i, field_info in enumerate(self.fields): # Use resolved name if provided, otherwise use original name - lookup_name = ( - resolved_field_names[i] if resolved_field_names else field_info.name - ) - serializer = field_info.field_type.create_serializer( - resolver, field_types.get(lookup_name, None) - ) + lookup_name = resolved_field_names[i] if resolved_field_names else field_info.name + serializer = field_info.field_type.create_serializer(resolver, field_types.get(lookup_name, None)) serializers.append(serializer) return serializers @@ -238,18 +228,14 @@ def create_serializer(self, resolver): if not is_struct_typedef_kind(self.type_id): if is_named_typedef_kind(self.type_id): try: - return resolver.get_type_info_by_name( - self.namespace, self.typename - ).serializer + return resolver.get_type_info_by_name(self.namespace, self.typename).serializer except Exception: if self.type_id == TypeId.NAMED_ENUM: from pyfory.serializer import NonExistEnumSerializer return NonExistEnumSerializer(resolver) raise - return resolver.get_type_info_by_id( - self.type_id, user_type_id=self.user_type_id - ).serializer + return resolver.get_type_info_by_id(self.type_id, user_type_id=self.user_type_id).serializer from pyfory.struct import DataClassSerializer from pyfory.struct import FieldInfo as StructFieldInfo from pyfory.type_util import get_type_hints, unwrap_optional @@ -258,17 +244,9 @@ def create_serializer(self, resolver): field_names = self._resolve_field_names_from_tag_ids() local_field_infos = build_field_infos(resolver, self.cls) - local_infos_by_name = { - field_info.name: 
field_info for field_info in local_field_infos - } - local_infos_by_tag = { - field_info.tag_id: field_info - for field_info in local_field_infos - if field_info.tag_id >= 0 - } - local_field_types = infer_field_types( - self.cls, field_nullable=resolver.field_nullable - ) + local_infos_by_name = {field_info.name: field_info for field_info in local_field_infos} + local_infos_by_tag = {field_info.tag_id: field_info for field_info in local_field_infos if field_info.tag_id >= 0} + local_field_types = infer_field_types(self.cls, field_nullable=resolver.field_nullable) type_hints = get_type_hints(self.cls) runtime_field_infos = [] for i, field_info in enumerate(self.fields): @@ -283,9 +261,7 @@ def create_serializer(self, resolver): local_info.field_type if local_info is not None else None, ) type_hint = type_hints.get(resolved_name, typing.Any) - unwrapped_type, _ = unwrap_optional( - type_hint, field_nullable=resolver.field_nullable - ) + unwrapped_type, _ = unwrap_optional(type_hint, field_nullable=resolver.field_nullable) serializer = _create_compatible_field_serializer( resolver, resolved_name, @@ -346,9 +322,7 @@ def _snake_to_camel(s: str) -> str: class FieldInfo: - def __init__( - self, name: str, field_type: "FieldType", defined_class: str, tag_id: int = -1 - ): + def __init__(self, name: str, field_type: "FieldType", defined_class: str, tag_id: int = -1): self.name = name self.field_type = field_type self.defined_class = defined_class @@ -412,9 +386,7 @@ def read(cls, buffer: Buffer, resolver): is_tracking_ref = (xtype_id & 0b1) != 0 is_nullable = (xtype_id & 0b10) != 0 xtype_id = xtype_id >> 2 - return cls.read_with_type( - buffer, resolver, xtype_id, is_nullable, is_tracking_ref - ) + return cls.read_with_type(buffer, resolver, xtype_id, is_nullable, is_tracking_ref) @classmethod def read_with_type( @@ -428,19 +400,13 @@ def read_with_type( user_type_id = NO_USER_TYPE_ID if xtype_id in [TypeId.LIST, TypeId.SET]: element_type = cls.read(buffer, resolver) - 
return CollectionFieldType( - xtype_id, True, is_nullable, is_tracking_ref, element_type - ) + return CollectionFieldType(xtype_id, True, is_nullable, is_tracking_ref, element_type) elif xtype_id == TypeId.MAP: key_type = cls.read(buffer, resolver) value_type = cls.read(buffer, resolver) - return MapFieldType( - xtype_id, True, is_nullable, is_tracking_ref, key_type, value_type - ) + return MapFieldType(xtype_id, True, is_nullable, is_tracking_ref, key_type, value_type) elif xtype_id == TypeId.UNKNOWN: - return DynamicFieldType( - xtype_id, False, is_nullable, is_tracking_ref, user_type_id=user_type_id - ) + return DynamicFieldType(xtype_id, False, is_nullable, is_tracking_ref, user_type_id=user_type_id) else: # For primitive types, determine if they are monomorphic based on the type is_monomorphic = not is_polymorphic_type(xtype_id) @@ -617,10 +583,10 @@ def create_serializer(self, resolver, type_): # to write/read the union payload correctly. if isinstance(type_, list): type_ = type_[0] - assert not is_union_type( - self.type_id - ), "Union fields don't write field type info, \ + assert not is_union_type(self.type_id), ( + "Union fields don't write field type info, \ they are not dynamic field types" + ) if self.type_id != TypeId.UNKNOWN: return FieldType.create_serializer(self, resolver, type_) return None @@ -648,9 +614,7 @@ def __repr__(self): ) -def _payload_shape_matches( - remote_field_type: FieldType, local_field_type: FieldType -) -> bool: +def _payload_shape_matches(remote_field_type: FieldType, local_field_type: FieldType) -> bool: if local_field_type is None: return False remote_type_id = remote_field_type.type_id @@ -660,22 +624,16 @@ def _payload_shape_matches( if remote_type_id != local_type_id: return False if remote_type_id in (TypeId.LIST, TypeId.SET): - return _payload_shape_matches( - remote_field_type.element_type, local_field_type.element_type - ) + return _payload_shape_matches(remote_field_type.element_type, local_field_type.element_type) 
if remote_type_id == TypeId.MAP: - return _payload_shape_matches( - remote_field_type.key_type, local_field_type.key_type - ) and _payload_shape_matches( + return _payload_shape_matches(remote_field_type.key_type, local_field_type.key_type) and _payload_shape_matches( remote_field_type.value_type, local_field_type.value_type, ) return True -def _payload_shape_needs_local_carrier( - remote_field_type: FieldType, local_field_type: FieldType -) -> bool: +def _payload_shape_needs_local_carrier(remote_field_type: FieldType, local_field_type: FieldType) -> bool: remote_type_id = remote_field_type.type_id local_type_id = local_field_type.type_id if _is_bytes_uint8_array_pair(remote_type_id, local_type_id): @@ -685,16 +643,12 @@ def _payload_shape_needs_local_carrier( if remote_type_id in _ARRAY_TYPE_IDS: return True if remote_type_id in (TypeId.LIST, TypeId.SET): - return _payload_shape_needs_local_carrier( - remote_field_type.element_type, local_field_type.element_type - ) + return _payload_shape_needs_local_carrier(remote_field_type.element_type, local_field_type.element_type) if remote_type_id == TypeId.MAP: return _payload_shape_needs_local_carrier( remote_field_type.key_type, local_field_type.key_type, - ) or _payload_shape_needs_local_carrier( - remote_field_type.value_type, local_field_type.value_type - ) + ) or _payload_shape_needs_local_carrier(remote_field_type.value_type, local_field_type.value_type) return False @@ -702,12 +656,8 @@ def _create_local_typehint_serializer(resolver, field_name, type_hint): from pyfory.struct import StructFieldSerializerVisitor from pyfory.type_util import infer_field, unwrap_optional - unwrapped_type, _ = unwrap_optional( - type_hint, field_nullable=resolver.field_nullable - ) - return infer_field( - field_name, unwrapped_type, StructFieldSerializerVisitor(resolver) - ) + unwrapped_type, _ = unwrap_optional(type_hint, field_nullable=resolver.field_nullable) + return infer_field(field_name, unwrapped_type, 
StructFieldSerializerVisitor(resolver)) def _create_compatible_field_serializer( @@ -718,9 +668,7 @@ def _create_compatible_field_serializer( local_field_type: typing.Optional[FieldType], local_declared_type, ): - if _payload_shape_matches( - remote_field_type, local_field_type - ) and _payload_shape_needs_local_carrier(remote_field_type, local_field_type): + if _payload_shape_matches(remote_field_type, local_field_type) and _payload_shape_needs_local_carrier(remote_field_type, local_field_type): serializer = _create_local_typehint_serializer(resolver, field_name, type_hint) if serializer is not None: return serializer @@ -730,9 +678,7 @@ def _create_compatible_field_serializer( _SIGNED_INT32_TYPE_IDS = frozenset((TypeId.INT32, TypeId.VARINT32)) _SIGNED_INT64_TYPE_IDS = frozenset((TypeId.INT64, TypeId.VARINT64, TypeId.TAGGED_INT64)) _UNSIGNED_INT32_TYPE_IDS = frozenset((TypeId.UINT32, TypeId.VAR_UINT32)) -_UNSIGNED_INT64_TYPE_IDS = frozenset( - (TypeId.UINT64, TypeId.VAR_UINT64, TypeId.TAGGED_UINT64) -) +_UNSIGNED_INT64_TYPE_IDS = frozenset((TypeId.UINT64, TypeId.VAR_UINT64, TypeId.TAGGED_UINT64)) _INT_TYPE_DOMAINS = {type_id: (True, 32) for type_id in _SIGNED_INT32_TYPE_IDS} _INT_TYPE_DOMAINS.update({type_id: (True, 64) for type_id in _SIGNED_INT64_TYPE_IDS}) _INT_TYPE_DOMAINS.update({type_id: (False, 32) for type_id in _UNSIGNED_INT32_TYPE_IDS}) @@ -745,26 +691,20 @@ def _create_compatible_field_serializer( } -def _requires_nullable_validation( - remote_field_type: FieldType, local_field_type: FieldType -) -> bool: +def _requires_nullable_validation(remote_field_type: FieldType, local_field_type: FieldType) -> bool: return remote_field_type.is_nullable and not local_field_type.is_nullable def _is_bytes_uint8_array_pair(remote_type_id: int, local_type_id: int) -> bool: - return ( - remote_type_id == TypeId.BINARY and local_type_id == TypeId.UINT8_ARRAY - ) or (remote_type_id == TypeId.UINT8_ARRAY and local_type_id == TypeId.BINARY) + return (remote_type_id == 
TypeId.BINARY and local_type_id == TypeId.UINT8_ARRAY) or ( + remote_type_id == TypeId.UINT8_ARRAY and local_type_id == TypeId.BINARY + ) -def _field_type_assignment( - remote_field_type: FieldType, local_field_type: FieldType -) -> typing.Tuple[bool, bool]: +def _field_type_assignment(remote_field_type: FieldType, local_field_type: FieldType) -> typing.Tuple[bool, bool]: if local_field_type is None: return False, False - needs_validation = _requires_nullable_validation( - remote_field_type, local_field_type - ) + needs_validation = _requires_nullable_validation(remote_field_type, local_field_type) remote_type_id = remote_field_type.type_id local_type_id = local_field_type.type_id if local_type_id == TypeId.UNKNOWN: @@ -814,9 +754,7 @@ def _field_type_assignment( def plan_field_assignment( remote_field_type: FieldType, local_field_type: typing.Optional[FieldType] ) -> typing.Tuple[bool, typing.Optional[FieldType]]: - assignable, needs_validation = _field_type_assignment( - remote_field_type, local_field_type - ) + assignable, needs_validation = _field_type_assignment(remote_field_type, local_field_type) if not assignable: return False, None return True, local_field_type if needs_validation else None @@ -860,12 +798,7 @@ def _is_uint8_array_like(value) -> bool: if isinstance(value, array.array): return value.typecode == "B" np, ndarray, uint8_dtype = _numpy_uint8_type() - return ( - np is not None - and isinstance(value, ndarray) - and value.ndim == 1 - and value.dtype == uint8_dtype - ) + return np is not None and isinstance(value, ndarray) and value.ndim == 1 and value.dtype == uint8_dtype def _bytes_from_uint8_value(value) -> bytes: @@ -881,9 +814,7 @@ def _bytes_from_uint8_value(value) -> bytes: return value.tobytes() if _is_uint8_array_like(value): return value.tobytes() - raise TypeError( - f"Expected bytes or array compatible value, got {type(value)!r}" - ) + raise TypeError(f"Expected bytes or array compatible value, got {type(value)!r}") def 
_uint8_array_from_bytes(value): @@ -902,16 +833,12 @@ def is_value_assignable(value, local_field_type: FieldType) -> bool: if type_id in (TypeId.LIST, TypeId.SET): if not isinstance(value, (list, tuple, set)): return False - return all( - is_value_assignable(element, local_field_type.element_type) - for element in value - ) + return all(is_value_assignable(element, local_field_type.element_type) for element in value) if type_id == TypeId.MAP: if not isinstance(value, dict): return False return all( - is_value_assignable(key, local_field_type.key_type) - and is_value_assignable(map_value, local_field_type.value_type) + is_value_assignable(key, local_field_type.key_type) and is_value_assignable(map_value, local_field_type.value_type) for key, map_value in value.items() ) if type_id in _INT_TYPE_DOMAINS: @@ -938,20 +865,12 @@ def coerce_assignable_value(value, local_field_type: FieldType): if type_id == TypeId.UINT8_ARRAY and _is_bytes_like(value): return _uint8_array_from_bytes(value) if type_id == TypeId.LIST: - return [ - coerce_assignable_value(element, local_field_type.element_type) - for element in value - ] + return [coerce_assignable_value(element, local_field_type.element_type) for element in value] if type_id == TypeId.SET: - return { - coerce_assignable_value(element, local_field_type.element_type) - for element in value - } + return {coerce_assignable_value(element, local_field_type.element_type) for element in value} if type_id == TypeId.MAP: return { - coerce_assignable_value( - key, local_field_type.key_type - ): coerce_assignable_value(map_value, local_field_type.value_type) + coerce_assignable_value(key, local_field_type.key_type): coerce_assignable_value(map_value, local_field_type.value_type) for key, map_value in value.items() } return value @@ -989,9 +908,7 @@ def build_field_infos(type_resolver, cls): for field_name in field_names: field_type_hint = type_hints.get(field_name, typing.Any) - unwrapped_type, is_optional = unwrap_optional( - 
field_type_hint, field_nullable=field_nullable - ) + unwrapped_type, is_optional = unwrap_optional(field_type_hint, field_nullable=field_nullable) # Get field metadata if available fory_meta = field_metas.get(field_name) @@ -1031,12 +948,7 @@ def build_field_infos(type_resolver, cls): field_infos.append(field_info) field_types = infer_field_types(cls, field_nullable=field_nullable) - serializers = [ - field_info.field_type.create_serializer( - type_resolver, field_types.get(field_info.name, None) - ) - for field_info in field_infos - ] + serializers = [field_info.field_type.create_serializer(type_resolver, field_types.get(field_info.name, None)) for field_info in field_infos] # Get just the field names for sorting current_field_names = [fi.name for fi in field_infos] @@ -1076,9 +988,7 @@ def build_field_type_with_ref( type_hint=type_hint, ) except Exception as e: - raise TypeError( - f"Error building field type for field: {field_name} with type hint: {type_hint} in class: {visitor.cls}" - ) from e + raise TypeError(f"Error building field type for field: {field_name} with type hint: {type_hint} in class: {visitor.cls}") from e def build_field_type_from_type_ids_with_ref( @@ -1108,17 +1018,9 @@ def build_field_type_from_type_ids_with_ref( elem_nullable = False elem_ref_override = None if type_hint is not None: - origin = ( - typing.get_origin(type_hint) - if hasattr(typing, "get_origin") - else getattr(type_hint, "__origin__", None) - ) + origin = typing.get_origin(type_hint) if hasattr(typing, "get_origin") else getattr(type_hint, "__origin__", None) if origin in (list, typing.List, set, typing.Set): - args = ( - typing.get_args(type_hint) - if hasattr(typing, "get_args") - else getattr(type_hint, "__args__", ()) - ) + args = typing.get_args(type_hint) if hasattr(typing, "get_args") else getattr(type_hint, "__args__", ()) if args: elem_hint, elem_ref_override = unwrap_ref(args[0]) elem_hint, elem_nullable = unwrap_optional(elem_hint) @@ -1141,9 +1043,7 @@ def 
build_field_type_from_type_ids_with_ref( ) if elem_ref_override is not None: elem_type.tracking_ref_override = elem_ref_override - return CollectionFieldType( - type_id, morphic, is_nullable, is_tracking_ref, elem_type - ) + return CollectionFieldType(type_id, morphic, is_nullable, is_tracking_ref, elem_type) elif type_id == TypeId.MAP: key_hint = None value_hint = None @@ -1152,17 +1052,9 @@ def build_field_type_from_type_ids_with_ref( key_ref_override = None value_ref_override = None if type_hint is not None: - origin = ( - typing.get_origin(type_hint) - if hasattr(typing, "get_origin") - else getattr(type_hint, "__origin__", None) - ) + origin = typing.get_origin(type_hint) if hasattr(typing, "get_origin") else getattr(type_hint, "__origin__", None) if origin in (dict, typing.Dict): - args = ( - typing.get_args(type_hint) - if hasattr(typing, "get_args") - else getattr(type_hint, "__args__", ()) - ) + args = typing.get_args(type_hint) if hasattr(typing, "get_args") else getattr(type_hint, "__args__", ()) if len(args) >= 2: key_hint, key_ref_override = unwrap_ref(args[0]) key_hint, key_nullable = unwrap_optional(key_hint) @@ -1196,9 +1088,7 @@ def build_field_type_from_type_ids_with_ref( key_type.tracking_ref_override = key_ref_override if value_ref_override is not None: value_type.tracking_ref_override = value_ref_override - return MapFieldType( - type_id, morphic, is_nullable, is_tracking_ref, key_type, value_type - ) + return MapFieldType(type_id, morphic, is_nullable, is_tracking_ref, key_type, value_type) elif type_id in [ TypeId.UNKNOWN, TypeId.EXT, @@ -1207,21 +1097,15 @@ def build_field_type_from_type_ids_with_ref( TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT, ]: - return DynamicFieldType( - type_id, False, is_nullable, is_tracking_ref, user_type_id=NO_USER_TYPE_ID - ) + return DynamicFieldType(type_id, False, is_nullable, is_tracking_ref, user_type_id=NO_USER_TYPE_ID) else: if type_id <= 0 or type_id >= TypeId.BOUND: raise TypeError(f"Unknown 
type: {type_id} for field: {field_name}") # union/enum go here too - return FieldType( - type_id, morphic, is_nullable, is_tracking_ref, user_type_id=NO_USER_TYPE_ID - ) + return FieldType(type_id, morphic, is_nullable, is_tracking_ref, user_type_id=NO_USER_TYPE_ID) -def build_field_type( - type_resolver, field_name: str, type_hint, visitor, is_nullable=False -): +def build_field_type(type_resolver, field_name: str, type_hint, visitor, is_nullable=False): """Build field type from type hint.""" type_ids = infer_field(field_name, type_hint, visitor) try: @@ -1234,14 +1118,10 @@ def build_field_type( type_hint=type_hint, ) except Exception as e: - raise TypeError( - f"Error building field type for field: {field_name} with type hint: {type_hint} in class: {visitor.cls}" - ) from e + raise TypeError(f"Error building field type for field: {field_name} with type hint: {type_hint} in class: {visitor.cls}") from e -def build_field_type_from_type_ids( - type_resolver, field_name: str, type_ids, visitor, is_nullable=False, type_hint=None -): +def build_field_type_from_type_ids(type_resolver, field_name: str, type_ids, visitor, is_nullable=False, type_hint=None): from pyfory.type_util import unwrap_optional, unwrap_ref tracking_ref = type_resolver.track_ref @@ -1258,17 +1138,9 @@ def build_field_type_from_type_ids( elem_hint = None elem_nullable = False if type_hint is not None: - origin = ( - typing.get_origin(type_hint) - if hasattr(typing, "get_origin") - else getattr(type_hint, "__origin__", None) - ) + origin = typing.get_origin(type_hint) if hasattr(typing, "get_origin") else getattr(type_hint, "__origin__", None) if origin in (list, typing.List, set, typing.Set): - args = ( - typing.get_args(type_hint) - if hasattr(typing, "get_args") - else getattr(type_hint, "__args__", ()) - ) + args = typing.get_args(type_hint) if hasattr(typing, "get_args") else getattr(type_hint, "__args__", ()) if args: elem_hint, _ = unwrap_ref(args[0]) elem_hint, elem_nullable = 
unwrap_optional(elem_hint) @@ -1285,26 +1157,16 @@ def build_field_type_from_type_ids( is_nullable=elem_nullable, type_hint=elem_hint, ) - return CollectionFieldType( - type_id, morphic, is_nullable, tracking_ref, elem_type - ) + return CollectionFieldType(type_id, morphic, is_nullable, tracking_ref, elem_type) elif type_id == TypeId.MAP: key_hint = None value_hint = None key_nullable = False value_nullable = False if type_hint is not None: - origin = ( - typing.get_origin(type_hint) - if hasattr(typing, "get_origin") - else getattr(type_hint, "__origin__", None) - ) + origin = typing.get_origin(type_hint) if hasattr(typing, "get_origin") else getattr(type_hint, "__origin__", None) if origin in (dict, typing.Dict): - args = ( - typing.get_args(type_hint) - if hasattr(typing, "get_args") - else getattr(type_hint, "__args__", ()) - ) + args = typing.get_args(type_hint) if hasattr(typing, "get_args") else getattr(type_hint, "__args__", ()) if len(args) >= 2: key_hint, _ = unwrap_ref(args[0]) key_hint, key_nullable = unwrap_optional(key_hint) @@ -1326,9 +1188,7 @@ def build_field_type_from_type_ids( is_nullable=value_nullable, type_hint=value_hint, ) - return MapFieldType( - type_id, morphic, is_nullable, tracking_ref, key_type, value_type - ) + return MapFieldType(type_id, morphic, is_nullable, tracking_ref, key_type, value_type) elif type_id in [ TypeId.UNKNOWN, TypeId.EXT, @@ -1337,12 +1197,8 @@ def build_field_type_from_type_ids( TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT, ]: - return DynamicFieldType( - type_id, False, is_nullable, tracking_ref, user_type_id=NO_USER_TYPE_ID - ) + return DynamicFieldType(type_id, False, is_nullable, tracking_ref, user_type_id=NO_USER_TYPE_ID) else: if type_id <= 0 or type_id >= TypeId.BOUND: raise TypeError(f"Unknown type: {type_id} for field: {field_name}") - return FieldType( - type_id, morphic, is_nullable, tracking_ref, user_type_id=NO_USER_TYPE_ID - ) + return FieldType(type_id, morphic, is_nullable, 
tracking_ref, user_type_id=NO_USER_TYPE_ID) diff --git a/python/pyfory/meta/typedef_decoder.py b/python/pyfory/meta/typedef_decoder.py index 6a5670f4f4..f221b41469 100644 --- a/python/pyfory/meta/typedef_decoder.py +++ b/python/pyfory/meta/typedef_decoder.py @@ -124,18 +124,14 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: is_registered_by_name = (meta_header & REGISTER_BY_NAME_FLAG) != 0 compatible = (meta_header & COMPATIBLE_TYPEDEF_FLAG) != 0 if is_registered_by_name: - type_id = ( - TypeId.NAMED_COMPATIBLE_STRUCT if compatible else TypeId.NAMED_STRUCT - ) + type_id = TypeId.NAMED_COMPATIBLE_STRUCT if compatible else TypeId.NAMED_STRUCT else: type_id = TypeId.COMPATIBLE_STRUCT if compatible else TypeId.STRUCT num_fields = meta_header & SMALL_NUM_FIELDS_THRESHOLD if num_fields == SMALL_NUM_FIELDS_THRESHOLD: num_fields += meta_buffer.read_var_uint32() if num_fields > MAX_FIELDS_PER_CLASS: - raise ValueError( - f"Class has {num_fields} fields, exceeding the maximum allowed {MAX_FIELDS_PER_CLASS} fields." 
- ) + raise ValueError(f"Class has {num_fields} fields, exceeding the maximum allowed {MAX_FIELDS_PER_CLASS} fields.") else: if meta_header & 0b01110000: raise ValueError("Invalid TypeDef kind header") @@ -171,9 +167,7 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: raise ValueError("Invalid TypeDef metadata size") _validate_parsed_typedef_hash(header, encoded_meta_data) if type_cls is None and is_struct_typedef_kind(type_id): - if getattr(resolver, "strict", False) and not getattr( - resolver, "_allow_unregistered_typedef", False - ): + if getattr(resolver, "strict", False) and not getattr(resolver, "_allow_unregistered_typedef", False): raise ValueError(f"TypeDef {name} is not registered in strict mode") # Check generated class count limit if _generated_class_count >= MAX_GENERATED_CLASSES: @@ -224,9 +218,7 @@ def read_typename(buffer: Buffer) -> str: return read_meta_string(buffer, TYPENAME_DECODER, TYPE_NAME_ENCODINGS) -def read_meta_string( - buffer: Buffer, decoder: MetaStringDecoder, encodings: List[Encoding] -) -> str: +def read_meta_string(buffer: Buffer, decoder: MetaStringDecoder, encodings: List[Encoding]) -> str: """Read a big meta string (namespace/typename) from the buffer using 6-bit size field.""" # Read encoding and length combined in first byte header = buffer.read_uint8() @@ -252,9 +244,7 @@ def read_meta_string( return "" -def read_fields_info( - buffer: Buffer, resolver, defined_class: str, num_fields: int -) -> List[FieldInfo]: +def read_fields_info(buffer: Buffer, resolver, defined_class: str, num_fields: int) -> List[FieldInfo]: """Read field information from the buffer.""" field_infos = [] for _ in range(num_fields): @@ -300,9 +290,7 @@ def read_field_info(buffer: Buffer, resolver, defined_class: str) -> FieldInfo: # Read field type info (no field name to read for TAG_ID) xtype_id = buffer.read_uint8() - field_type = FieldType.read_with_type( - buffer, resolver, xtype_id, is_nullable, is_tracking_ref - ) + 
field_type = FieldType.read_with_type(buffer, resolver, xtype_id, is_nullable, is_tracking_ref) # For TAG_ID encoding, use tag_id as field name placeholder field_name = f"__tag_{tag_id}__" @@ -317,9 +305,7 @@ def read_field_info(buffer: Buffer, resolver, defined_class: str) -> FieldInfo: # Read field type info BEFORE field name (matching Java TypeDefDecoder order) xtype_id = buffer.read_uint8() - field_type = FieldType.read_with_type( - buffer, resolver, xtype_id, is_nullable, is_tracking_ref - ) + field_type = FieldType.read_with_type(buffer, resolver, xtype_id, is_nullable, is_tracking_ref) # Read field name meta string # Keep the wire field name as-is; TypeDef._resolve_field_names_from_tag_ids() diff --git a/python/pyfory/meta/typedef_encoder.py b/python/pyfory/meta/typedef_encoder.py index 4877d73cbf..69d83ba615 100644 --- a/python/pyfory/meta/typedef_encoder.py +++ b/python/pyfory/meta/typedef_encoder.py @@ -84,9 +84,7 @@ def encode_typedef(type_resolver, cls, include_fields: bool = True): buffer.write_uint8(header) else: if field_infos: - raise ValueError( - f"Non-struct TypeDef {type_id} cannot carry field metadata" - ) + raise ValueError(f"Non-struct TypeDef {type_id} cannot carry field metadata") buffer.write_uint8(xlang_non_struct_kind_code(type_id)) # Write type info @@ -95,9 +93,7 @@ def encode_typedef(type_resolver, cls, include_fields: bool = True): write_namespace(buffer, namespace) write_typename(buffer, typename) else: - assert type_resolver.is_registered_by_id( - cls=cls - ), "Class must be registered by name or id" + assert type_resolver.is_registered_by_id(cls=cls), "Class must be registered by name or id" if user_type_id in {None, NO_USER_TYPE_ID}: raise ValueError(f"user_type_id required for type_id {type_id}") buffer.write_var_uint32(user_type_id) @@ -160,9 +156,7 @@ def write_namespace(buffer: Buffer, namespace: str): # The `6 bits size: 0~63` will be used to indicate size `0~62`, # the value `63` the size need more byte to read, the 
encoding will encode `size - 62` as a varint next. meta_string = NAMESPACE_ENCODER.encode(namespace, NAMESPACE_ENCODINGS) - write_meta_string( - buffer, meta_string, NAMESPACE_ENCODINGS.index(meta_string.encoding) - ) + write_meta_string(buffer, meta_string, NAMESPACE_ENCODINGS.index(meta_string.encoding)) def write_typename(buffer: Buffer, typename: str): @@ -174,9 +168,7 @@ def write_typename(buffer: Buffer, typename: str): # The `6 bits size: 0~63` will be used to indicate size `1~64`, # the value `63` the size need more byte to read, the encoding will encode `size - 63` as a varint next. meta_string = TYPENAME_ENCODER.encode(typename, TYPE_NAME_ENCODINGS) - write_meta_string( - buffer, meta_string, TYPE_NAME_ENCODINGS.index(meta_string.encoding) - ) + write_meta_string(buffer, meta_string, TYPE_NAME_ENCODINGS.index(meta_string.encoding)) def write_meta_string(buffer: Buffer, meta_string, encoding_value: int): @@ -248,9 +240,7 @@ def write_field_info(buffer: Buffer, field_info: FieldInfo): field_info.field_type.write(buffer, False) else: # Field name encoding - encoding = FIELD_NAME_ENCODER.compute_encoding( - field_info.name, FIELD_NAME_ENCODINGS - ) + encoding = FIELD_NAME_ENCODER.compute_encoding(field_info.name, FIELD_NAME_ENCODINGS) meta_string = FIELD_NAME_ENCODER.encode_with_encoding(field_info.name, encoding) # Store (length - 1) in size field, matching Java TypeDefEncoder field_name_binary_size = len(meta_string.encoded_data) - 1 diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index b4d9732454..509f7a1a4c 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -51,9 +51,7 @@ def _import_validated_module(policy, module_name): if result is not None: if isinstance(result, types.ModuleType): return result - assert isinstance( - result, str - ), f"validate_module must return module, str, or None, got {type(result)}" + assert isinstance(result, str), f"validate_module must return module, str, or None, got 
{type(result)}" module_name = result return importlib.import_module(module_name) @@ -74,9 +72,7 @@ def _check_collection_size(read_context, size, kind): if size < 0: raise ValueError(f"{kind} size {size} must be non-negative") if size > read_context.max_collection_size: - raise ValueError( - f"{kind} size {size} exceeds the configured limit of {read_context.max_collection_size}" - ) + raise ValueError(f"{kind} size {size} exceeds the configured limit of {read_context.max_collection_size}") def _is_local_qualname(module_name, qualname): @@ -119,9 +115,7 @@ def _validate_function_value(policy, func, is_local): if result is not None: func = result if isinstance(func, type): - raise TypeError( - f"Function serializer resolved class {func.__module__}.{func.__qualname__}" - ) + raise TypeError(f"Function serializer resolved class {func.__module__}.{func.__qualname__}") if _is_bound_method_value(func): result = policy.validate_method(func, is_local=is_local) if result is not None: @@ -381,9 +375,7 @@ def _write_decimal_parts(write_context, scale: int, unscaled: int): magnitude = abs(unscaled) if magnitude == 0: raise ValueError("Zero must use the small decimal encoding") - payload = magnitude.to_bytes( - (magnitude.bit_length() + 7) // 8, "little", signed=False - ) + payload = magnitude.to_bytes((magnitude.bit_length() + 7) // 8, "little", signed=False) meta = (len(payload) << 1) | (1 if unscaled < 0 else 0) _write_var_uint64(write_context, (meta << 1) | 1) write_context.write_bytes(payload) @@ -554,10 +546,7 @@ def _build_pyarray_typecode_tables(): class PyArraySerializer(Serializer): typecode_dict = typecode_dict - typecodearray_type = { - typecode: ftype - for typecode, (_itemsize, ftype, _type_id) in typecode_dict.items() - } + typecodearray_type = {typecode: ftype for typecode, (_itemsize, ftype, _type_id) in typecode_dict.items()} def __init__(self, type_resolver, ftype, type_id: str): super().__init__(type_resolver, ftype) @@ -570,17 +559,13 @@ def 
_array_type_id(self, value): raise TypeError(f"Unsupported array.array typecode {value.typecode!r}") itemsize, _ftype, type_id = entry if value.itemsize != itemsize: - raise TypeError( - f"array.array typecode {value.typecode!r} has itemsize {value.itemsize}, expected {itemsize}" - ) + raise TypeError(f"array.array typecode {value.typecode!r} has itemsize {value.itemsize}, expected {itemsize}") return type_id def write(self, buffer, value): actual_type_id = self._array_type_id(value) if actual_type_id != self.type_id: - raise TypeError( - f"array.array typecode {value.typecode!r} maps to type id {actual_type_id}, expected {self.type_id}" - ) + raise TypeError(f"array.array typecode {value.typecode!r} maps to type id {actual_type_id}, expected {self.type_id}") view = memoryview(value) assert view.itemsize == self.itemsize assert view.c_contiguous # TODO handle contiguous @@ -646,9 +631,7 @@ def fory_array_serializer_type(type_id): class ForyArrayListAdapterSerializer(Serializer): - def __init__( - self, type_resolver, wrapper_type, wrapper_serializer, field_name=None - ): + def __init__(self, type_resolver, wrapper_type, wrapper_serializer, field_name=None): super().__init__(type_resolver, wrapper_type) self.wrapper_type = wrapper_type self.wrapper_serializer = wrapper_serializer @@ -657,17 +640,13 @@ def __init__( def _copy_list_to_wrapper(self, value): if type(value) is not list: - raise TypeError( - f"pyfory.Array list adapter for {self.field_name!r} requires list, got {type(value)!r}" - ) + raise TypeError(f"pyfory.Array list adapter for {self.field_name!r} requires list, got {type(value)!r}") wrapper = self.wrapper_type() for index, item in enumerate(value): try: wrapper.append(item) except (TypeError, ValueError, OverflowError) as exc: - raise type(exc)( - f"{self.field_name}[{index}] invalid for {self.wrapper_type.__name__}: {exc}" - ) from exc + raise type(exc)(f"{self.field_name}[{index}] invalid for {self.wrapper_type.__name__}: {exc}") from exc return 
wrapper def write(self, buffer, value): @@ -725,9 +704,7 @@ def write(self, buffer, value): return if value_type is array.array: if self.pyarray_serializer is None: - raise TypeError( - f"pyfory.Array field {self.field_name!r} does not support array.array for this element type" - ) + raise TypeError(f"pyfory.Array field {self.field_name!r} does not support array.array for this element type") actual_type_id = self.pyarray_serializer._array_type_id(value) if actual_type_id != self.type_id: raise TypeError( @@ -737,9 +714,7 @@ def write(self, buffer, value): return if np is not None and value_type is np.ndarray: if self.ndarray_serializer is None: - raise TypeError( - f"pyfory.Array field {self.field_name!r} does not support numpy ndarray for this element type" - ) + raise TypeError(f"pyfory.Array field {self.field_name!r} does not support numpy ndarray for this element type") if value.dtype != self.ndarray_serializer.dtype or value.ndim != 1: raise TypeError( f"pyfory.Array field {self.field_name!r} requires 1D ndarray with dtype " @@ -748,9 +723,7 @@ def write(self, buffer, value): self.ndarray_serializer.write(buffer, value) return if value is None: - raise TypeError( - f"pyfory.Array field {self.field_name!r} value must not be None" - ) + raise TypeError(f"pyfory.Array field {self.field_name!r} value must not be None") raise TypeError( f"pyfory.Array field {self.field_name!r} requires {self.wrapper_type.__name__}, list, numpy.ndarray, or array.array, got {type(value)!r}" ) @@ -769,13 +742,9 @@ def write(self, buffer, value): try: itemsize, ftype, type_id = typecode_dict[value.typecode] except KeyError as exc: - raise TypeError( - f"Unsupported array.array typecode {value.typecode!r}" - ) from exc + raise TypeError(f"Unsupported array.array typecode {value.typecode!r}") from exc if value.itemsize != itemsize: - raise TypeError( - f"array.array typecode {value.typecode!r} has itemsize {value.itemsize}, expected {itemsize}" - ) + raise TypeError(f"array.array 
typecode {value.typecode!r} has itemsize {value.itemsize}, expected {itemsize}") view = memoryview(value) nbytes = len(value) * itemsize buffer.write_uint8(type_id) @@ -842,9 +811,7 @@ def read(self, buffer): ) else: _np_dtypes_dict = {} -_np_typeid_to_dtype = { - type_id: dtype for dtype, (_, _, _, type_id) in _np_dtypes_dict.items() -} +_np_typeid_to_dtype = {type_id: dtype for dtype, (_, _, _, type_id) in _np_dtypes_dict.items()} class Numpy1DArraySerializer(Serializer): @@ -868,9 +835,7 @@ def write(self, buffer, value): if self.dtype == np.dtype("bool") or not view.c_contiguous: if not is_little_endian and self.itemsize > 1: # Swap bytes on big-endian machines for multi-byte types - buffer.write_bytes( - value.astype(value.dtype.newbyteorder("<")).tobytes() - ) + buffer.write_bytes(value.astype(value.dtype.newbyteorder("<")).tobytes()) else: buffer.write_bytes(value.tobytes()) elif is_little_endian or self.itemsize == 1: @@ -897,9 +862,7 @@ def write(self, buffer, value): # Write concrete 1D primitive ndarray using type id + bytes payload. 
dtype_info = _np_dtypes_dict.get(value.dtype) if dtype_info is None or value.ndim != 1: - raise NotImplementedError( - f"Unsupported ndarray: dtype={value.dtype}, ndim={value.ndim}" - ) + raise NotImplementedError(f"Unsupported ndarray: dtype={value.dtype}, ndim={value.ndim}") itemsize, _typecode, _ftype, type_id = dtype_info view = memoryview(value) nbytes = len(value) * itemsize @@ -907,9 +870,7 @@ def write(self, buffer, value): buffer.write_var_uint32(nbytes) if value.dtype == np.dtype("bool") or not view.c_contiguous: if not is_little_endian and itemsize > 1: - buffer.write_bytes( - value.astype(value.dtype.newbyteorder("<")).tobytes() - ) + buffer.write_bytes(value.astype(value.dtype.newbyteorder("<")).tobytes()) else: buffer.write_bytes(value.tobytes()) elif is_little_endian or itemsize == 1: @@ -1189,10 +1150,7 @@ def _resolve_global_name(self, read_context, global_name): def write(self, write_context, value): # Try __reduce_ex__ first (with protocol 5 for pickle5 out-of-band buffer support), then __reduce__ # Check if the object has a custom __reduce_ex__ method (not just the default from object) - if ( - hasattr(value, "__reduce_ex__") - and value.__class__.__reduce_ex__ is not object.__reduce_ex__ - ): + if hasattr(value, "__reduce_ex__") and value.__class__.__reduce_ex__ is not object.__reduce_ex__: try: reduce_result = value.__reduce_ex__(5) except TypeError: @@ -1201,9 +1159,7 @@ def write(self, write_context, value): elif hasattr(value, "__reduce__"): reduce_result = value.__reduce__() else: - raise ValueError( - f"Object {value} has no __reduce__ or __reduce_ex__ method" - ) + raise ValueError(f"Object {value} has no __reduce__ or __reduce_ex__ method") # Handle different __reduce__ return formats if isinstance(reduce_result, str): @@ -1234,9 +1190,7 @@ def write(self, write_context, value): dictitems, ) else: - raise ValueError( - f"Invalid __reduce__ result length: {len(reduce_result)}" - ) + raise ValueError(f"Invalid __reduce__ result length: 
{len(reduce_result)}") else: raise ValueError(f"Invalid __reduce__ result type: {type(reduce_result)}") write_context.write_var_uint32(len(reduce_data)) @@ -1325,9 +1279,7 @@ def read(self, read_context): return self._deserialize_local_class(read_context) module_name = read_context.read_string() qualname = read_context.read_string() - cls = _resolve_validated_module_qualname( - read_context.policy, module_name, qualname - ) + cls = _resolve_validated_module_qualname(read_context.policy, module_name, qualname) result = read_context.policy.validate_class(cls, is_local=_is_local_class(cls)) if result is not None: cls = result @@ -1335,9 +1287,7 @@ def read(self, read_context): def _serialize_local_class(self, write_context, cls): """Serialize a local class by capturing its creation context.""" - assert ( - self.type_resolver.track_ref - ), "Reference tracking must be enabled for local classes serialization" + assert self.type_resolver.track_ref, "Reference tracking must be enabled for local classes serialization" module = cls.__module__ qualname = cls.__qualname__ write_context.write_string(module) @@ -1370,9 +1320,7 @@ def _serialize_local_class(self, write_context, cls): def _deserialize_local_class(self, read_context): """Deserialize a local class by recreating it with the captured context.""" - assert ( - self.type_resolver.track_ref - ), "Reference tracking must be enabled for local classes deserialization" + assert self.type_resolver.track_ref, "Reference tracking must be enabled for local classes deserialization" module = read_context.read_string() qualname = read_context.read_string() name = qualname.rsplit(".", 1)[-1] @@ -1381,9 +1329,7 @@ def _deserialize_local_class(self, read_context): num_bases = read_context.read_var_uint32() _check_collection_size(read_context, num_bases, "local class base") bases = tuple(read_context.read_ref() for _ in range(num_bases)) - read_context.policy.authorize_instantiation( - type, module=module, qualname=qualname, 
bases=bases - ) + read_context.policy.authorize_instantiation(type, module=module, qualname=qualname, bases=bases) cls = type(name, bases, {}) read_context.set_read_ref(ref_id, cls) result = read_context.policy.validate_class(cls, is_local=True) @@ -1544,9 +1490,7 @@ def _serialize_function(self, write_context, func): global_names.add(name) # Create and serialize a dictionary with only the necessary globals - globals_to_serialize = { - name: globals_dict[name] for name in global_names if name in globals_dict - } + globals_to_serialize = {name: globals_dict[name] for name in global_names if name in globals_dict} write_context.write_ref(globals_to_serialize) # Handle additional attributes @@ -1573,19 +1517,13 @@ def _deserialize_function(self, read_context): policy = read_context.policy if policy is DEFAULT_POLICY: return getattr(self_obj, method_name) - return _resolve_validated_bound_method( - policy, self_obj, method_name, is_local=_is_local_receiver(self_obj) - ) + return _resolve_validated_bound_method(policy, self_obj, method_name, is_local=_is_local_receiver(self_obj)) if func_type_id == 1: module = read_context.read_string() qualname = read_context.read_string() - mod = _resolve_validated_module_qualname( - read_context.policy, module, qualname - ) - return _validate_function_value( - read_context.policy, mod, is_local=_is_local_callable(mod) - ) + mod = _resolve_validated_module_qualname(read_context.policy, module, qualname) + return _validate_function_value(read_context.policy, mod, is_local=_is_local_callable(mod)) module = read_context.read_string() qualname = read_context.read_string() @@ -1672,18 +1610,14 @@ def read(self, read_context): if read_context.read_bool(): module = read_context.read_string() func = _resolve_validated_module_attr(read_context.policy, module, name) - func = _validate_function_value( - read_context.policy, func, is_local=_is_local_callable(func) - ) + func = _validate_function_value(read_context.policy, func, 
is_local=_is_local_callable(func)) else: obj = read_context.read_ref() policy = read_context.policy if policy is DEFAULT_POLICY: func = getattr(obj, name) else: - func = _resolve_validated_bound_method( - policy, obj, name, is_local=_is_local_receiver(obj) - ) + func = _resolve_validated_bound_method(policy, obj, name, is_local=_is_local_receiver(obj)) return func @@ -1752,9 +1686,7 @@ def read(self, read_context): read_context.reference(obj) num_fields = read_context.read_var_uint32() if num_fields > read_context.max_collection_size: - raise ValueError( - f"object field size {num_fields} exceeds the configured limit of {read_context.max_collection_size}" - ) + raise ValueError(f"object field size {num_fields} exceeds the configured limit of {read_context.max_collection_size}") state = {} for _ in range(num_fields): field_name = read_context.read_string() @@ -1772,9 +1704,7 @@ def read(self, read_context): read_context.reference(obj) num_fields = read_context.read_var_uint32() if num_fields > read_context.max_collection_size: - raise ValueError( - f"object field size {num_fields} exceeds the configured limit of {read_context.max_collection_size}" - ) + raise ValueError(f"object field size {num_fields} exceeds the configured limit of {read_context.max_collection_size}") for _ in range(num_fields): field_name = read_context.read_string() field_value = read_context.read_ref() diff --git a/python/pyfory/tests/test_policy.py b/python/pyfory/tests/test_policy.py index e659ce30d3..fe1be69556 100644 --- a/python/pyfory/tests/test_policy.py +++ b/python/pyfory/tests/test_policy.py @@ -116,10 +116,7 @@ def __init__(self, blocked_names): self.blocked_names = blocked_names def intercept_reduce_call(self, callable_obj, args, **kwargs): - if ( - hasattr(callable_obj, "__name__") - and callable_obj.__name__ in self.blocked_names - ): + if hasattr(callable_obj, "__name__") and callable_obj.__name__ in self.blocked_names: raise ValueError(f"Callable {callable_obj.__name__} is 
blocked") return None @@ -349,9 +346,7 @@ def validate_class(self, cls, is_local, **kwargs): def intercept_reduce_call(self, callable_obj, args, **kwargs): if hasattr(callable_obj, "__name__"): - self.hooks_called.append( - ("intercept_reduce_call", callable_obj.__name__) - ) + self.hooks_called.append(("intercept_reduce_call", callable_obj.__name__)) return None def inspect_reduced_object(self, obj, **kwargs): @@ -659,9 +654,7 @@ def validate_method(self, method, is_local, **kwargs): policy = CaptureMethodPolicy() fory = Fory(ref=True, strict=False, policy=policy) - serializer = NativeFuncMethodSerializer( - fory.type_resolver, type(policy_global_function) - ) + serializer = NativeFuncMethodSerializer(fory.type_resolver, type(policy_global_function)) read_context = FakeReadContext(policy, ["run", False, LocalReceiver()]) with pytest.raises(ValueError, match="method blocked"): @@ -735,9 +728,7 @@ def validate_function(self, func, is_local, **kwargs): try: policy = CaptureFunctionPolicy() fory = Fory(ref=True, strict=False, policy=policy) - serializer = FunctionSerializer( - fory.type_resolver, type(policy_global_function) - ) + serializer = FunctionSerializer(fory.type_resolver, type(policy_global_function)) read_context = FakeReadContext(policy, [1, __name__, "policy_global_function"]) assert serializer._deserialize_function(read_context) is policy_global_function @@ -764,9 +755,7 @@ def validate_function(self, func, is_local, **kwargs): policy = BlockClassPolicy() fory = Fory(ref=True, strict=False, policy=policy) - serializer = NativeFuncMethodSerializer( - fory.type_resolver, type(policy_global_function) - ) + serializer = NativeFuncMethodSerializer(fory.type_resolver, type(policy_global_function)) read_context = FakeReadContext(policy, ["Popen", True, "subprocess"]) with pytest.raises(ValueError, match="class blocked"): @@ -791,12 +780,8 @@ def validate_function(self, func, is_local, **kwargs): policy = MethodPolicy() fory = Fory(ref=True, strict=False, 
policy=policy) - serializer = NativeFuncMethodSerializer( - fory.type_resolver, type(policy_global_function) - ) - read_context = FakeReadContext( - policy, ["policy_global_bound_method", True, __name__] - ) + serializer = NativeFuncMethodSerializer(fory.type_resolver, type(policy_global_function)) + read_context = FakeReadContext(policy, ["policy_global_bound_method", True, __name__]) with pytest.raises(ValueError, match="method blocked"): serializer.read(read_context) @@ -818,12 +803,8 @@ def validate_function(self, func, is_local, **kwargs): try: policy = CaptureFunctionPolicy() fory = Fory(ref=True, strict=False, policy=policy) - serializer = NativeFuncMethodSerializer( - fory.type_resolver, type(policy_global_function) - ) - read_context = FakeReadContext( - policy, ["policy_global_function", True, __name__] - ) + serializer = NativeFuncMethodSerializer(fory.type_resolver, type(policy_global_function)) + read_context = FakeReadContext(policy, ["policy_global_function", True, __name__]) assert serializer.read(read_context) is policy_global_function assert policy.is_local_values == [True] diff --git a/python/pyfory/tests/test_typedef_encoding.py b/python/pyfory/tests/test_typedef_encoding.py index 0af48dd991..a66aab57e1 100644 --- a/python/pyfory/tests/test_typedef_encoding.py +++ b/python/pyfory/tests/test_typedef_encoding.py @@ -130,9 +130,7 @@ def test_typedef_creation(): FieldInfo("age", FieldType(TypeId.INT32, True, True, False), "TestTypeDef"), ] - typedef = TypeDef( - "", "TestTypeDef", None, TypeId.STRUCT, fields, b"encoded_data", False - ) + typedef = TypeDef("", "TestTypeDef", None, TypeId.STRUCT, fields, b"encoded_data", False) assert typedef.namespace == "" assert typedef.typename == "TestTypeDef" @@ -194,9 +192,7 @@ def test_encode_decode_typedef(): for i, field in enumerate(decoded_typedef.fields): assert field.name == typedef.fields[i].name assert field.field_type.type_id == typedef.fields[i].field_type.type_id - assert ( - 
field.field_type.is_nullable == typedef.fields[i].field_type.is_nullable - ) + assert field.field_type.is_nullable == typedef.fields[i].field_type.is_nullable def test_decode_typedef_rejects_parsed_body_with_mismatched_hash(): @@ -232,9 +228,7 @@ def test_decode_typedef_rejects_compressed_xlang_metadata(): def test_id_registered_typedef_extended_field_count_header(): - many_fields_type = make_dataclass( - "ManyTypeDefFields", [(f"field_{i}", int) for i in range(32)] - ) + many_fields_type = make_dataclass("ManyTypeDefFields", [(f"field_{i}", int) for i in range(32)]) fory = Fory(xlang=True) fory.register(many_fields_type, type_id=701) typedef = encode_typedef(fory.type_resolver, many_fields_type) @@ -285,37 +279,26 @@ def _typedef_body_offset(encoded): def test_nested_container_typedef_preserves_declared_encoding(): fory = Fory(xlang=True) - fory.register( - NestedEncodingTypeDef, namespace="example", typename="NestedEncodingTypeDef" - ) + fory.register(NestedEncodingTypeDef, namespace="example", typename="NestedEncodingTypeDef") typedef = encode_typedef(fory.type_resolver, NestedEncodingTypeDef) values_field = next(field for field in typedef.fields if field.name == "values") assert values_field.field_type.type_id == TypeId.MAP assert values_field.field_type.key_type.type_id == TypeId.INT32 assert values_field.field_type.value_type.type_id == TypeId.LIST - assert ( - values_field.field_type.value_type.element_type.type_id == TypeId.TAGGED_INT64 - ) + assert values_field.field_type.value_type.element_type.type_id == TypeId.TAGGED_INT64 decoded_typedef = decode_typedef(Buffer(typedef.encoded), fory.type_resolver) - decoded_values_field = next( - field for field in decoded_typedef.fields if field.name == "values" - ) + decoded_values_field = next(field for field in decoded_typedef.fields if field.name == "values") assert decoded_values_field.field_type.type_id == TypeId.MAP assert decoded_values_field.field_type.key_type.type_id == TypeId.INT32 assert 
decoded_values_field.field_type.value_type.type_id == TypeId.LIST - assert ( - decoded_values_field.field_type.value_type.element_type.type_id - == TypeId.TAGGED_INT64 - ) + assert decoded_values_field.field_type.value_type.element_type.type_id == TypeId.TAGGED_INT64 def test_python_array_typehint_lowering_keeps_list_schema_distinct(): fory = Fory(xlang=True) - fory.register( - PythonArrayTypeHints, namespace="example", typename="PythonArrayTypeHints" - ) + fory.register(PythonArrayTypeHints, namespace="example", typename="PythonArrayTypeHints") typedef = encode_typedef(fory.type_resolver, PythonArrayTypeHints) fields = {field.name: field.field_type for field in typedef.fields} @@ -341,9 +324,7 @@ def test_python_array_typehint_rejects_scalar_encoding_modifier(): namespace="example", typename="InvalidArrayModifierTypeDef", ) - with pytest.raises( - TypeError, match="array does not allow scalar encoding modifier" - ): + with pytest.raises(TypeError, match="array does not allow scalar encoding modifier"): encode_typedef(fory.type_resolver, InvalidArrayModifierTypeDef) @@ -368,9 +349,7 @@ def test_compatible_bytes_assigns_to_uint8_array(): _register_byte_sequence(writer, BytesPayload) _register_byte_sequence(reader, UInt8ArrayPayload) - decoded = reader.deserialize( - writer.serialize(BytesPayload(payload=b"\x01\x02\xff")) - ) + decoded = reader.deserialize(writer.serialize(BytesPayload(payload=b"\x01\x02\xff"))) assert isinstance(decoded, UInt8ArrayPayload) _assert_uint8_array_value(decoded.payload, [1, 2, 255]) @@ -382,9 +361,7 @@ def test_compatible_uint8_array_assigns_to_bytes(): _register_byte_sequence(writer, UInt8ArrayPayload) _register_byte_sequence(reader, BytesPayload) - decoded = reader.deserialize( - writer.serialize(UInt8ArrayPayload(payload=_uint8_array_value([1, 2, 255]))) - ) + decoded = reader.deserialize(writer.serialize(UInt8ArrayPayload(payload=_uint8_array_value([1, 2, 255])))) assert isinstance(decoded, BytesPayload) assert decoded.payload == 
b"\x01\x02\xff" diff --git a/rust/fory-core/src/meta/type_meta.rs b/rust/fory-core/src/meta/type_meta.rs index ff40dd4a2b..4338c758fb 100644 --- a/rust/fory-core/src/meta/type_meta.rs +++ b/rust/fory-core/src/meta/type_meta.rs @@ -1096,7 +1096,10 @@ impl TypeMeta { reader: &mut Reader, header: i64, ) -> Result<(), Error> { - let meta_size = read_type_meta_body_size(reader, header)?; + let mut meta_size = (header & META_SIZE_MASK) as usize; + if meta_size == META_SIZE_MASK as usize { + meta_size += reader.read_var_u32()? as usize; + } reader.skip(meta_size) } From 4f9335959a96ce7c89eb9050aa6f3ae63c4b5b52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=85=95=E7=99=BD?= Date: Wed, 6 May 2026 13:21:03 +0800 Subject: [PATCH 10/10] fix: harden deserialization metadata and length checks --- .agents/languages/java.md | 3 + .gitignore | 1 + README.md | 2 +- csharp/src/Fory/MetaString.cs | 2 + csharp/src/Fory/ReadContext.cs | 40 ++-- csharp/src/Fory/StringSerializer.cs | 3 +- csharp/src/Fory/TypeMeta.cs | 5 + .../tests/Fory.Tests/RuntimeEdgeCaseTests.cs | 34 +++- go/fory/buffer.go | 37 +++- go/fory/buffer_test.go | 27 +++ go/fory/meta_string_resolver.go | 122 +++++++----- go/fory/meta_string_resolver_test.go | 125 ++++++++++++ go/fory/stream.go | 26 +-- go/fory/stream_test.go | 23 +++ go/fory/string.go | 4 +- go/fory/struct.go | 4 +- go/fory/tests/metastring_resolver_test.go | 2 +- go/fory/type_def.go | 7 +- go/fory/type_def_test.go | 30 +++ go/fory/type_resolver.go | 19 +- .../java/org/apache/fory/config/Config.java | 18 ++ .../org/apache/fory/config/ForyBuilder.java | 26 +++ .../org/apache/fory/context/ReadContext.java | 10 +- .../org/apache/fory/memory/MemoryBuffer.java | 80 ++++++-- .../java/org/apache/fory/meta/TypeDef.java | 15 +- .../apache/fory/resolver/TypeResolver.java | 6 +- .../fory/serializer/ArraySerializers.java | 35 ++++ .../fory/serializer/BigIntegerSerializer.java | 17 +- .../fory/serializer/DecimalSerializer.java | 36 +++- 
.../serializer/PrimitiveArraySerializers.java | 183 ++++++++++++++++-- .../collection/ChildContainerSerializers.java | 6 +- .../collection/CollectionLikeSerializer.java | 29 ++- .../collection/CollectionSerializers.java | 93 +++++++-- .../GuavaCollectionSerializers.java | 14 +- .../ImmutableCollectionSerializers.java | 6 +- .../collection/MapLikeSerializer.java | 28 ++- .../serializer/collection/MapSerializers.java | 16 +- .../collection/PrimitiveListSerializers.java | 93 ++++++++- .../collection/SubListSerializers.java | 2 +- .../apache/fory/memory/MemoryBufferTest.java | 19 ++ .../apache/fory/meta/TypeDefEncoderTest.java | 9 + .../fory/serializer/ArraySerializersTest.java | 119 ++++++++++++ .../serializer/BufferSerializersTest.java | 21 ++ .../serializer/JdkProxySerializerTest.java | 26 +++ .../serializer/PrimitiveSerializersTest.java | 61 ++++++ .../fory/serializer/SerializersTest.java | 51 +++++ .../collection/CollectionSerializersTest.java | 69 +++++++ .../collection/MapSerializersTest.java | 20 ++ javascript/packages/core/lib/context.ts | 18 +- javascript/packages/core/lib/reader/index.ts | 41 +++- javascript/test/hps.test.ts | 2 +- javascript/test/io.test.ts | 23 ++- javascript/test/typemeta.test.ts | 35 ++++ rust/fory-core/src/buffer.rs | 2 + rust/fory-core/src/meta/meta_string.rs | 20 +- rust/fory-core/src/meta/type_meta.rs | 3 + rust/fory-core/src/util/string_util.rs | 20 +- 57 files changed, 1541 insertions(+), 247 deletions(-) diff --git a/.agents/languages/java.md b/.agents/languages/java.md index 1d19574423..5283d5a4cf 100644 --- a/.agents/languages/java.md +++ b/.agents/languages/java.md @@ -22,6 +22,9 @@ Load this file when changing anything under `java/` or when Java drives a cross- - Do not add normal-JVM process-global caches keyed by user classes, generated classes, serializer classes, classloaders, or class-bound method handles. Prefer per-runtime state, immutable shared metadata, or build-time-only template data. 
- Concrete serializers may opt into sharing only after auditing retained fields. Treat serializers retaining `TypeResolver`, `RefResolver`, mutable scratch buffers, runtime state, or classloader-sensitive state as non-shareable unless that state is externalized. - Resolver and serializer hot paths should keep the fast-path/null-slow-path shape obvious. Hoist repeated buffer or cache-state access into locals for multi-step operations and keep rebuild/restoration logic cold. +- In Java codec hot paths, avoid `Preconditions.checkArgument` for attacker-controlled primitive + validation. Use direct primitive branches and throw on the cold error path to preserve inlining and + avoid varargs/helper overhead. - Keep GraalVM feature code as a thin metadata/registration layer. Build time should publish metadata needed for runtime reconstruction, not retain concrete generated or user serializer instances in the image heap. - If changes touch GraalVM bootstrap, serializer retention, native-image metadata, or `ObjectStreamSerializer` GraalVM behavior, verify the native-image build and run the produced binary; a plain Java compile is insufficient. - Put latest-JDK or virtual-thread tests in the latest-JDK test modules with the matching compiler/profile floor, and centralize runtime-version probing in existing compatibility utilities. diff --git a/.gitignore b/.gitignore index c979f24d9c..daaab45bd3 100644 --- a/.gitignore +++ b/.gitignore @@ -133,6 +133,7 @@ test.md benchmarks/dart/profile_output **/*.fory.dart +integration_tests/idl_tests/dart/.dart_tool/ **/pubspec.lock **/tmp/* \ No newline at end of file diff --git a/README.md b/README.md index a75ddad9a1..ba7a38769c 100644 --- a/README.md +++ b/README.md @@ -118,7 +118,7 @@ For more detailed benchmarks and methodology, see [Go Benchmark](benchmarks/go).

-For more detailed benchmarks and methodology, see [Pythonk](benchmarks/python). +For more detailed benchmarks and methodology, see [Python](benchmarks/python). ### JavaScript/NodeJS Serialization Performance diff --git a/csharp/src/Fory/MetaString.cs b/csharp/src/Fory/MetaString.cs index 34d3f07b7d..df49932fd7 100644 --- a/csharp/src/Fory/MetaString.cs +++ b/csharp/src/Fory/MetaString.cs @@ -430,6 +430,8 @@ public MetaString Decode(byte[] bytes, MetaStringEncoding encoding) { string value = encoding switch { + // C# intentionally preserves platform UTF-8 replacement behavior; Rust is the runtime + // that provides checked UTF-8 string reads by default. MetaStringEncoding.Utf8 => Encoding.UTF8.GetString(bytes), MetaStringEncoding.LowerSpecial => DecodeGeneric(bytes, 5, UnmapLowerSpecial), MetaStringEncoding.LowerUpperDigitSpecial => DecodeGeneric(bytes, 6, UnmapLowerUpperDigitSpecial), diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index 2825653c76..6444500c40 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -21,14 +21,12 @@ public sealed class ReadContext { private const int MaxParsedTypeMetaEntries = 8192; - private readonly record struct CachedTypeMetaEntry(TypeMeta TypeMeta, int SkipBytesAfterHeader); - private readonly ReusableArray _readTypeMetas = new(); - private readonly Dictionary _cachedTypeMetasByHeader = []; + private readonly Dictionary _cachedTypeMetasByHeader = []; private TypeMeta? _firstReadTypeMeta; private bool _hasFirstReadTypeMeta; private ulong _lastMetaHeader; - private CachedTypeMetaEntry _lastTypeMeta; + private TypeMeta? 
_lastTypeMeta; private bool _hasLastMetaHeader; private readonly List _readMetaStrings = []; @@ -135,33 +133,30 @@ internal void StoreReadTypeMeta(TypeMeta typeMeta, int index) $"type meta index gap: index={index}, count={_readTypeMetas.Count + 1}"); } - internal bool TryGetCachedReadTypeMeta(ulong header, out TypeMeta typeMeta, out int skipBytesAfterHeader) + internal bool TryGetCachedReadTypeMeta(ulong header, out TypeMeta typeMeta) { - if (_hasLastMetaHeader && _lastMetaHeader == header) + if (_hasLastMetaHeader && _lastMetaHeader == header && _lastTypeMeta is not null) { - typeMeta = _lastTypeMeta.TypeMeta; - skipBytesAfterHeader = _lastTypeMeta.SkipBytesAfterHeader; + typeMeta = _lastTypeMeta; return true; } - if (_cachedTypeMetasByHeader.TryGetValue(header, out CachedTypeMetaEntry cached)) + if (_cachedTypeMetasByHeader.TryGetValue(header, out TypeMeta? cached) && cached is not null) { _lastMetaHeader = header; _lastTypeMeta = cached; _hasLastMetaHeader = true; - typeMeta = cached.TypeMeta; - skipBytesAfterHeader = cached.SkipBytesAfterHeader; + typeMeta = cached; return true; } typeMeta = null!; - skipBytesAfterHeader = 0; return false; } - internal void CacheReadTypeMeta(ulong header, TypeMeta typeMeta, int skipBytesAfterHeader) + internal void CacheReadTypeMeta(ulong header, TypeMeta typeMeta) { - if (_cachedTypeMetasByHeader.TryGetValue(header, out CachedTypeMetaEntry existing)) + if (_cachedTypeMetasByHeader.TryGetValue(header, out TypeMeta? existing) && existing is not null) { _lastMetaHeader = header; _lastTypeMeta = existing; @@ -174,11 +169,10 @@ internal void CacheReadTypeMeta(ulong header, TypeMeta typeMeta, int skipBytesAf return; } - CachedTypeMetaEntry cached = new(typeMeta, skipBytesAfterHeader); _lastMetaHeader = header; - _lastTypeMeta = cached; + _lastTypeMeta = typeMeta; _hasLastMetaHeader = true; - _cachedTypeMetasByHeader.TryAdd(header, cached); + _cachedTypeMetasByHeader.TryAdd(header, typeMeta); } internal MetaString? 
GetReadMetaString(int index) @@ -208,22 +202,20 @@ internal TypeMeta ReadTypeMeta() } ulong header = Reader.ReadUInt64(); - if (TryGetCachedReadTypeMeta(header, out TypeMeta cachedTypeMeta, out int skipBytesAfterHeader)) + if (TryGetCachedReadTypeMeta(header, out TypeMeta cachedTypeMeta)) { // Header-cache hits intentionally skip without rehashing. Entries reach this cache only - // after a successful TypeMeta parse and 52-bit body-hash validation. - Reader.Skip(skipBytesAfterHeader); + // after a successful TypeMeta parse and 52-bit body-hash validation. The current body + // size still comes from the current header bytes, not from the cached TypeMeta. + TypeMeta.SkipBody(Reader, header); StoreReadTypeMeta(cachedTypeMeta, index); return cachedTypeMeta; } - int headerStartCursor = Reader.Cursor - sizeof(ulong); Reader.MoveBack(sizeof(ulong)); TypeMeta typeMeta = TypeMeta.Decode(Reader); - int consumedTypeMetaBytes = Reader.Cursor - headerStartCursor; - int parsedSkipBytesAfterHeader = consumedTypeMetaBytes - sizeof(ulong); StoreReadTypeMeta(typeMeta, index); - CacheReadTypeMeta(header, typeMeta, parsedSkipBytesAfterHeader); + CacheReadTypeMeta(header, typeMeta); return typeMeta; } diff --git a/csharp/src/Fory/StringSerializer.cs b/csharp/src/Fory/StringSerializer.cs index 78a822eb1a..c176348c83 100644 --- a/csharp/src/Fory/StringSerializer.cs +++ b/csharp/src/Fory/StringSerializer.cs @@ -23,7 +23,6 @@ public sealed class StringSerializer : Serializer { private const int MaxVarUInt36SmallBytes = 6; - public override string DefaultValue => null!; public override void WriteData(WriteContext context, in string value, bool hasGenerics) @@ -65,6 +64,8 @@ public static string ReadString(ReadContext context) ReadOnlySpan bytes = context.Reader.ReadSpan(byteLength); return encoding switch { + // C# intentionally preserves platform UTF-8 replacement behavior; Rust is the runtime + // that provides checked UTF-8 string reads by default. 
(ulong)ForyStringEncoding.Utf8 => Encoding.UTF8.GetString(bytes), (ulong)ForyStringEncoding.Latin1 => DecodeLatin1(bytes), (ulong)ForyStringEncoding.Utf16 => DecodeUtf16(bytes), diff --git a/csharp/src/Fory/TypeMeta.cs b/csharp/src/Fory/TypeMeta.cs index 597cffe473..9073b5f1f4 100644 --- a/csharp/src/Fory/TypeMeta.cs +++ b/csharp/src/Fory/TypeMeta.cs @@ -604,6 +604,11 @@ private static int ReadBodySize(ByteReader reader, ulong header) return metaSize; } + internal static void SkipBody(ByteReader reader, ulong header) + { + reader.Skip(ReadBodySize(reader, header)); + } + private static ulong ComputeHeaderHashBits(ReadOnlySpan body) { (ulong bodyHash, _) = MurmurHash3.X64_128(body, TypeMetaConstants.TypeMetaHashSeed); diff --git a/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs b/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs index c8452b63a4..3b3011a3fa 100644 --- a/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs +++ b/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs @@ -449,13 +449,39 @@ public void TypeMetaHeaderCacheStopsPublishingAtCapacity() for (ulong header = 1; header <= 8192; header++) { - context.CacheReadTypeMeta(header, typeMeta, skipBytesAfterHeader: 0); + context.CacheReadTypeMeta(header, typeMeta); } - Assert.True(context.TryGetCachedReadTypeMeta(8192, out _, out _)); - context.CacheReadTypeMeta(8193, typeMeta, skipBytesAfterHeader: 0); + Assert.True(context.TryGetCachedReadTypeMeta(8192, out _)); + context.CacheReadTypeMeta(8193, typeMeta); - Assert.False(context.TryGetCachedReadTypeMeta(8193, out _, out _)); + Assert.False(context.TryGetCachedReadTypeMeta(8193, out _)); + } + + [Fact] + public void TypeMetaHeaderCacheHitSkipsCurrentBodySize() + { + const ulong header = 0xffUL; + TypeMeta typeMeta = new( + (uint)TypeId.Struct, + 902, + MetaString.Empty('.', '_'), + MetaString.Empty('$', '_'), + registerByName: false, + []); + + ByteWriter writer = new(); + writer.WriteVarUInt32(0); + writer.WriteUInt64(header); + writer.WriteVarUInt32(0); + 
writer.WriteBytes(new byte[0xff]); + writer.WriteUInt8(0x7b); + + ReadContext context = new(new ByteReader(writer.ToArray()), new TypeResolver(), trackRef: false); + context.CacheReadTypeMeta(header, typeMeta); + + Assert.Same(typeMeta, context.ReadTypeMeta()); + Assert.Equal(0x7b, context.Reader.ReadUInt8()); } [Fact] diff --git a/go/fory/buffer.go b/go/fory/buffer.go index c51071fd14..2e0f13655e 100644 --- a/go/fory/buffer.go +++ b/go/fory/buffer.go @@ -1224,8 +1224,8 @@ func (b *ByteBuffer) ReadVarint32(err *Error) int32 { // UnsafeReadVarint32 reads a varint32 without bounds checking. // Caller must ensure remaining() >= 5 before calling. -func (b *ByteBuffer) UnsafeReadVarint32() int32 { - u := b.readVarUint32Fast() +func (b *ByteBuffer) UnsafeReadVarint32(err *Error) int32 { + u := b.readVarUint32Fast(err) v := int32(u >> 1) if u&1 != 0 { v = ^v @@ -1246,8 +1246,8 @@ func (b *ByteBuffer) UnsafeReadVarint64() int64 { // UnsafeReadVarUint32 reads a VarUint32 without bounds checking. // Caller must ensure remaining() >= 5 before calling. -func (b *ByteBuffer) UnsafeReadVarUint32() uint32 { - return b.readVarUint32Fast() +func (b *ByteBuffer) UnsafeReadVarUint32(err *Error) uint32 { + return b.readVarUint32Fast(err) } // UnsafeReadVarUint64 reads a VarUint64 without bounds checking. 
@@ -1259,13 +1259,13 @@ func (b *ByteBuffer) UnsafeReadVarUint64() uint64 { // ReadVarUint32 reads a VarUint32 and sets error on bounds violation func (b *ByteBuffer) ReadVarUint32(err *Error) uint32 { if b.remaining() >= 8 { // Need 8 bytes for bulk uint64 read in fast path - return b.readVarUint32Fast() + return b.readVarUint32Fast(err) } return b.readVarUint32Slow(err) } // Fast path reading (when the remaining bytes are sufficient) -func (b *ByteBuffer) readVarUint32Fast() uint32 { +func (b *ByteBuffer) readVarUint32Fast(err *Error) uint32 { // Single instruction load using unsafe pointer cast (little-endian only) // On big-endian systems, use binary.LittleEndian which the compiler optimizes var bulk uint64 @@ -1288,6 +1288,13 @@ func (b *ByteBuffer) readVarUint32Fast() uint32 { result |= uint32((bulk >> 3) & 0xFE00000) readLength = 4 if (bulk & 0x80000000) != 0 { + fifth := byte(bulk >> 32) + if fifth > 0x0F { + if err != nil { + *err = DeserializationError("VarUint32 overflow") + } + return 0 + } result |= uint32((bulk >> 4) & 0xF0000000) readLength = 5 } @@ -1310,6 +1317,12 @@ func (b *ByteBuffer) readVarUint32Slow(err *Error) uint32 { } byteVal := b.data[b.readerIndex] b.readerIndex++ + if shift == 28 && byteVal > 0x0F { + if err != nil { + *err = DeserializationError("VarUint32 overflow") + } + return 0 + } result |= (uint32(byteVal) & 0x7F) << shift if byteVal < 0x80 { break @@ -1475,16 +1488,16 @@ func (b *ByteBuffer) readVarUint32Small14(err *Error) uint32 { readIdx++ value |= (four >> 1) & 0x3f80 if four&0x8000 != 0 { - return b.continueReadVarUint32(readIdx, four, value) + return b.continueReadVarUint32(readIdx, four, value, err) } } b.readerIndex = readIdx return value } - return uint32(b.readVaruint36Slow(err)) + return b.readVarUint32Slow(err) } -func (b *ByteBuffer) continueReadVarUint32(readIdx int, bulkRead, value uint32) uint32 { +func (b *ByteBuffer) continueReadVarUint32(readIdx int, bulkRead, value uint32, err *Error) uint32 { readIdx++ 
value |= (bulkRead >> 2) & 0x1fc000 if bulkRead&0x800000 != 0 { @@ -1492,6 +1505,12 @@ func (b *ByteBuffer) continueReadVarUint32(readIdx int, bulkRead, value uint32) value |= (bulkRead >> 3) & 0xfe00000 if bulkRead&0x80000000 != 0 { v := b.data[readIdx] + if v > 0x0F { + if err != nil { + *err = DeserializationError("VarUint32 overflow") + } + return 0 + } readIdx++ value |= uint32(v&0x7F) << 28 } diff --git a/go/fory/buffer_test.go b/go/fory/buffer_test.go index a65d49a7a9..b4f2022389 100644 --- a/go/fory/buffer_test.go +++ b/go/fory/buffer_test.go @@ -18,6 +18,7 @@ package fory import ( + "bytes" "testing" "github.com/stretchr/testify/require" @@ -111,3 +112,29 @@ func TestUnsafePutVarUint32PhysicalWriteWidth(t *testing.T) { "byte at index %d is outside the 8-byte reserved window and must not be written", i) } } + +func TestReadVarUint32RejectsOverflowFifthByte(t *testing.T) { + for _, data := range [][]byte{ + {0x80, 0x80, 0x80, 0x80, 0x10}, + {0x80, 0x80, 0x80, 0x80, 0x10, 0, 0, 0}, + } { + buf := NewByteBuffer(data) + var err Error + _ = buf.ReadVarUint32(&err) + require.True(t, err.HasError(), "expected overflow error for %v", data) + } +} + +func TestReadVarUint32Small7RejectsOverflowFifthByte(t *testing.T) { + buf := NewByteBuffer([]byte{0x80, 0x80, 0x80, 0x80, 0x10}) + var err Error + _ = buf.ReadVarUint32Small7(&err) + require.True(t, err.HasError()) +} + +func TestReadVarUint32Small7StreamRejectsOverflowFifthByte(t *testing.T) { + buf := NewByteBufferFromReader(bytes.NewReader([]byte{0x80, 0x80, 0x80, 0x80, 0x10}), 4) + var err Error + _ = buf.ReadVarUint32Small7(&err) + require.True(t, err.HasError()) +} diff --git a/go/fory/meta_string_resolver.go b/go/fory/meta_string_resolver.go index 6df61d0c71..5a9e1cb8ff 100644 --- a/go/fory/meta_string_resolver.go +++ b/go/fory/meta_string_resolver.go @@ -18,7 +18,6 @@ package fory import ( - "bytes" "encoding/binary" "fmt" "github.com/apache/fory/go/fory/meta" @@ -28,10 +27,10 @@ import ( const ( 
SmallStringThreshold = 16 // Maximum length for "small" strings DefaultDynamicWriteMetaStrID = -1 // Default ID for dynamic strings + maxCachedMetaStrings = 8192 + smallMetaStringEncodingBits = 4 ) -type Encoding int8 - type MetaStringBytes struct { Data []byte Length int16 @@ -60,14 +59,22 @@ func (a *MetaStringBytes) Hash() int64 { type pair [2]int64 +// Mirrors Java's small MetaString read cache key: two packed byte words plus one +// compact length/encoding byte. The packed words are zero-padded and are not +// exact byte identity by themselves. +type smallMetaStringKey struct { + v1 int64 + v2 int64 + compactKey byte +} + type MetaStringResolver struct { - dynamicWriteStringID int16 // Counter for dynamic string IDs - dynamicWrittenEnumString []*MetaStringBytes // Cache of written strings - dynamicIDToEnumString []*MetaStringBytes // Cache of read strings by ID - hashToMetaStrBytes map[int64]*MetaStringBytes // Large string lookup - smallHashToMetaStrBytes map[pair]*MetaStringBytes // Small string lookup - enumStrSet map[*MetaStringBytes]struct{} // String set for deduplication - metaStrToMetaStrBytes map[*meta.MetaString]*MetaStringBytes // Conversion cache + dynamicWriteStringID int16 // Counter for dynamic string IDs + dynamicWrittenEnumString []*MetaStringBytes // Cache of written strings + dynamicIDToEnumString []*MetaStringBytes // Cache of read strings by ID + hashToMetaStrBytes map[int64]*MetaStringBytes // Large string lookup + smallHashToMetaStrBytes map[smallMetaStringKey]*MetaStringBytes // Small string lookup + metaStrToMetaStrBytes map[*meta.MetaString]*MetaStringBytes // Conversion cache } var emptyMetaStringBytes = NewMetaStringBytes([]byte{}, 256) @@ -75,8 +82,7 @@ var emptyMetaStringBytes = NewMetaStringBytes([]byte{}, 256) func NewMetaStringResolver() *MetaStringResolver { return &MetaStringResolver{ hashToMetaStrBytes: make(map[int64]*MetaStringBytes), - smallHashToMetaStrBytes: make(map[pair]*MetaStringBytes), - enumStrSet: 
make(map[*MetaStringBytes]struct{}), + smallHashToMetaStrBytes: make(map[smallMetaStringKey]*MetaStringBytes), metaStrToMetaStrBytes: make(map[*meta.MetaString]*MetaStringBytes), } } @@ -121,20 +127,27 @@ func (r *MetaStringResolver) ReadMetaStringBytes(buf *ByteBuffer, ctxErr *Error) return nil, *ctxErr } - length := int16(header >> 1) + lengthValue := header >> 1 if header&1 != 0 { - index := int(length) - 1 + if lengthValue == 0 || uint64(lengthValue) > uint64(MaxInt) { + return nil, fmt.Errorf("invalid dynamic index: %d", lengthValue) + } + index := int(lengthValue) - 1 if index < 0 || index >= len(r.dynamicIDToEnumString) { return nil, fmt.Errorf("invalid dynamic index: %d", index) } return r.dynamicIDToEnumString[index], nil } + if lengthValue > uint32(MaxInt16) { + return nil, fmt.Errorf("meta string length %d exceeds maximum supported length %d", lengthValue, MaxInt16) + } + length := int(lengthValue) var ( hashcode int64 - key pair + key smallMetaStringKey data []byte - encoding Encoding + encoding meta.Encoding ) // Small string optimization @@ -143,9 +156,12 @@ func (r *MetaStringResolver) ReadMetaStringBytes(buf *ByteBuffer, ctxErr *Error) r.dynamicIDToEnumString = append(r.dynamicIDToEnumString, emptyMetaStringBytes) return emptyMetaStringBytes, nil } - // ReadData encoding and data encByte := buf.ReadByte(ctxErr) - encoding = Encoding(encByte) + var encErr error + encoding, encErr = meta.EncodingFromByte(encByte) + if encErr != nil { + return nil, encErr + } data = make([]byte, length) _, err := buf.Read(data) @@ -153,29 +169,33 @@ func (r *MetaStringResolver) ReadMetaStringBytes(buf *ByteBuffer, ctxErr *Error) return nil, err } - // Compute composite hash key - if length <= 8 { - key[0] = bytesToInt64(data) - } else { - err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &key[0]) - if err != nil { - return nil, err - } - key[1] = bytesToInt64(data[8:]) + words := smallMetaStringWords(data) + key = smallMetaStringKey{ + v1: words[0], + 
v2: words[1], + compactKey: byte(((length - 1) << smallMetaStringEncodingBits) | int(encoding)), } - hashcode = ((key[0]*31 + key[1]) >> 8 << 8) | int64(encoding) + hashcode = computeSmallMetaStringHash(words, length, encoding) } else { // Large string handling err := binary.Read(buf, binary.LittleEndian, &hashcode) if err != nil { return nil, err } - encoding = Encoding(hashcode & 0xFF) + var encErr error + encoding, encErr = meta.EncodingFromByte(byte(hashcode & 0xFF)) + if encErr != nil { + return nil, encErr + } data = make([]byte, length) _, err = buf.Read(data) if err != nil { return nil, err } + canonicalHashcode := ComputeMetaStringHash(data, encoding) + if canonicalHashcode != hashcode { + return nil, fmt.Errorf("meta string body hash mismatch") + } } // Check string caches for existing instance @@ -191,14 +211,18 @@ func (r *MetaStringResolver) ReadMetaStringBytes(buf *ByteBuffer, ctxErr *Error) } } - // Create and cache new string instance + // Cache only after the current body has been parsed and, for large bodies, hash-validated. + // Header-keyed hits stay on the fast path; forged headers cannot poison the shared cache. 
m := NewMetaStringBytes(data, hashcode) if length <= SmallStringThreshold { - r.smallHashToMetaStrBytes[key] = m + if len(r.smallHashToMetaStrBytes) < maxCachedMetaStrings { + r.smallHashToMetaStrBytes[key] = m + } } else { - r.hashToMetaStrBytes[hashcode] = m + if len(r.hashToMetaStrBytes) < maxCachedMetaStrings { + r.hashToMetaStrBytes[hashcode] = m + } } - r.enumStrSet[m] = struct{}{} r.dynamicIDToEnumString = append(r.dynamicIDToEnumString, m) return m, nil @@ -222,14 +246,8 @@ func (r *MetaStringResolver) GetMetaStrBytes(metastr *meta.MetaString) *MetaStri } if length <= SmallStringThreshold { // Small string: use direct bytes as hash components - var v1, v2 int64 - if length <= 8 { - v1 = bytesToInt64(data) - } else { - binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &v1) - v2 = bytesToInt64(data[8:]) - } - hashcode = ((v1*31 + v2) >> 8 << 8) | int64(metastr.GetEncoding()) + words := smallMetaStringWords(data) + hashcode = computeSmallMetaStringHash(words, length, metastr.GetEncoding()) } else { // Large string: use MurmurHash3 h64 := Murmur3Sum64WithSeed(data, 47) @@ -253,14 +271,8 @@ func ComputeMetaStringHash(data []byte, encoding meta.Encoding) int64 { hashcode |= int64(encoding) } else if length <= SmallStringThreshold { // Small string: use direct bytes as hash components - var v1, v2 int64 - if length <= 8 { - v1 = bytesToInt64(data) - } else { - binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &v1) - v2 = bytesToInt64(data[8:]) - } - hashcode = ((v1*31 + v2) >> 8 << 8) | int64(encoding) + words := smallMetaStringWords(data) + hashcode = computeSmallMetaStringHash(words, length, encoding) } else { // Large string: use MurmurHash3 h64 := Murmur3Sum64WithSeed(data, 47) @@ -316,3 +328,15 @@ func bytesToInt64(b []byte) int64 { } return v } + +func smallMetaStringWords(data []byte) pair { + if len(data) <= 8 { + return pair{bytesToInt64(data), 0} + } + return pair{int64(binary.LittleEndian.Uint64(data[:8])), bytesToInt64(data[8:])} +} 
+ +func computeSmallMetaStringHash(words pair, length int, encoding meta.Encoding) int64 { + hash := uint64(words[0]*31+words[1]) ^ (uint64(length) << 56) + return int64((hash & 0xffffffffffffff00) | uint64(encoding)) +} diff --git a/go/fory/meta_string_resolver_test.go b/go/fory/meta_string_resolver_test.go index bbf9e39020..5a15be8024 100644 --- a/go/fory/meta_string_resolver_test.go +++ b/go/fory/meta_string_resolver_test.go @@ -18,8 +18,11 @@ package fory import ( + "encoding/binary" "github.com/stretchr/testify/require" "testing" + + "github.com/apache/fory/go/fory/meta" ) // TestMetaStringResolverNegativeIndexPanic reproduces the CRITICAL security bug @@ -65,3 +68,125 @@ func TestMetaStringResolverBoundaryRegression(t *testing.T) { require.NoError(t, err) require.Equal(t, m, result, "Should correctly resolve the first dynamic string (index 0)") } + +func TestMetaStringResolverRejectsLargeBodyHashMismatch(t *testing.T) { + resolver := NewMetaStringResolver() + buffer := NewByteBuffer(nil) + data := []byte("0123456789abcdefg") + + buffer.WriteVarUint32Small7(uint32(len(data)) << 1) + buffer.WriteInt64(int64(meta.UTF_8)) + buffer.Write(data) + buffer.SetReaderIndex(0) + + var ctxErr Error + _, err := resolver.ReadMetaStringBytes(buffer, &ctxErr) + require.Error(t, err) + require.Empty(t, resolver.hashToMetaStrBytes) + require.Empty(t, resolver.dynamicIDToEnumString) +} + +func TestMetaStringResolverRejectsOversizedLengthBeforeAllocation(t *testing.T) { + resolver := NewMetaStringResolver() + buffer := NewByteBuffer(nil) + buffer.WriteVarUint32Small7(uint32(MaxInt16+1) << 1) + buffer.SetReaderIndex(0) + + var ctxErr Error + _, err := resolver.ReadMetaStringBytes(buffer, &ctxErr) + require.Error(t, err) + require.Contains(t, err.Error(), "meta string length") + require.Empty(t, resolver.hashToMetaStrBytes) + require.Empty(t, resolver.smallHashToMetaStrBytes) + require.Empty(t, resolver.dynamicIDToEnumString) +} + +func 
TestMetaStringResolverSmallCacheKeyIncludesLengthAndEncoding(t *testing.T) { + resolver := NewMetaStringResolver() + + oneByte := NewByteBuffer(nil) + oneByte.WriteVarUint32Small7(1 << 1) + oneByte.WriteByte(byte(meta.UTF_8)) + oneByte.WriteByte(1) + oneByte.SetReaderIndex(0) + var oneErr Error + first, err := resolver.ReadMetaStringBytes(oneByte, &oneErr) + require.NoError(t, err) + require.Equal(t, []byte{1}, first.Data) + + twoBytes := NewByteBuffer(nil) + twoBytes.WriteVarUint32Small7(2 << 1) + twoBytes.WriteByte(byte(meta.UTF_8)) + twoBytes.Write([]byte{1, 0}) + twoBytes.SetReaderIndex(0) + var twoErr Error + second, err := resolver.ReadMetaStringBytes(twoBytes, &twoErr) + require.NoError(t, err) + require.Equal(t, []byte{1, 0}, second.Data) + require.NotSame(t, first, second) + + differentEncoding := NewByteBuffer(nil) + differentEncoding.WriteVarUint32Small7(1 << 1) + differentEncoding.WriteByte(byte(meta.LOWER_SPECIAL)) + differentEncoding.WriteByte(1) + differentEncoding.SetReaderIndex(0) + var encodingErr Error + third, err := resolver.ReadMetaStringBytes(differentEncoding, &encodingErr) + require.NoError(t, err) + require.Equal(t, []byte{1}, third.Data) + require.Equal(t, meta.LOWER_SPECIAL, third.Encoding) + require.NotSame(t, first, third) + require.Len(t, resolver.smallHashToMetaStrBytes, 3) +} + +func TestComputeMetaStringHashIncludesSmallLength(t *testing.T) { + require.NotEqual( + t, + ComputeMetaStringHash([]byte{1}, meta.UTF_8), + ComputeMetaStringHash([]byte{1, 0}, meta.UTF_8), + ) + require.NotEqual( + t, + ComputeMetaStringHash([]byte{1}, meta.UTF_8), + ComputeMetaStringHash([]byte{1}, meta.LOWER_SPECIAL), + ) +} + +func TestMetaStringResolverReadCachesAreCapped(t *testing.T) { + resolver := NewMetaStringResolver() + smallKey := smallMetaStringKey{v1: 1, v2: 0, compactKey: byte(meta.UTF_8)} + for i := 0; i < maxCachedMetaStrings; i++ { + resolver.smallHashToMetaStrBytes[smallMetaStringKey{ + v1: int64(i + 2), + v2: 0, + compactKey: 
byte(meta.UTF_8), + }] = + NewMetaStringBytes([]byte{byte(i)}, int64(i+2)<<8) + resolver.hashToMetaStrBytes[int64(i+2)<<8] = + NewMetaStringBytes([]byte("0123456789abcdefg"), int64(i+2)<<8) + } + + smallBuffer := NewByteBuffer(nil) + smallBuffer.WriteVarUint32Small7(1 << 1) + smallBuffer.WriteByte(byte(meta.UTF_8)) + smallBuffer.WriteByte(1) + smallBuffer.SetReaderIndex(0) + var smallErr Error + _, err := resolver.ReadMetaStringBytes(smallBuffer, &smallErr) + require.NoError(t, err) + require.Len(t, resolver.smallHashToMetaStrBytes, maxCachedMetaStrings) + require.NotContains(t, resolver.smallHashToMetaStrBytes, smallKey) + + largeData := []byte("0123456789abcdefg") + largeHash := ComputeMetaStringHash(largeData, meta.UTF_8) + largeBuffer := NewByteBuffer(nil) + largeBuffer.WriteVarUint32Small7(uint32(len(largeData)) << 1) + require.NoError(t, binary.Write(largeBuffer, binary.LittleEndian, largeHash)) + largeBuffer.Write(largeData) + largeBuffer.SetReaderIndex(0) + var largeErr Error + _, err = resolver.ReadMetaStringBytes(largeBuffer, &largeErr) + require.NoError(t, err) + require.Len(t, resolver.hashToMetaStrBytes, maxCachedMetaStrings) + require.NotContains(t, resolver.hashToMetaStrBytes, largeHash) +} diff --git a/go/fory/stream.go b/go/fory/stream.go index 00de53abe2..bb86689598 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -92,42 +92,26 @@ func (is *InputStream) Shrink() { } // DeserializeFromStream reads the next object from the stream into the provided value. -// It uses a shared ReadContext for the lifetime of the InputStream, clearing -// temporary state between calls but preserving the buffer and TypeResolver state. +// It preserves the stream buffer while clearing root-scoped read metadata between calls. func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { - - // We only reset the temporary read state (like refTracker and outOfBand buffers), - // NOT the buffer or the type mapping, which must persist. 
- defer func() { - f.readCtx.refReader.Reset() - f.readCtx.outOfBandBuffers = nil - f.readCtx.outOfBandIndex = 0 - f.readCtx.err = Error{} - if f.readCtx.refResolver != nil { - f.readCtx.refResolver.resetRead() - } - }() - - // Temporarily swap buffer origBuffer := f.readCtx.buffer f.readCtx.buffer = is.buffer + defer func() { + f.readCtx.buffer = origBuffer + f.resetReadState() + }() readHeader(f.readCtx) if f.readCtx.HasError() { - f.readCtx.buffer = origBuffer return f.readCtx.TakeError() } target := reflect.ValueOf(v).Elem() f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { - f.readCtx.buffer = origBuffer return f.readCtx.TakeError() } - // Restore original buffer - f.readCtx.buffer = origBuffer - return nil } diff --git a/go/fory/stream_test.go b/go/fory/stream_test.go index 098254587c..0c2503f3cb 100644 --- a/go/fory/stream_test.go +++ b/go/fory/stream_test.go @@ -135,6 +135,29 @@ func TestStreamDeserializationEOF(t *testing.T) { } } +func TestDeserializeFromStreamClearsReadMetadataOnError(t *testing.T) { + f := New(WithCompatible(true)) + f.typeResolver.metaStringResolver.dynamicIDToEnumString = + append(f.typeResolver.metaStringResolver.dynamicIDToEnumString, emptyMetaStringBytes) + f.metaContext.readTypeInfos = append(f.metaContext.readTypeInfos, &TypeInfo{}) + + stream := NewInputStream(bytes.NewReader(nil)) + var out int32 + err := f.DeserializeFromStream(stream, &out) + if err == nil { + t.Fatal("Expected error on empty stream, got nil") + } + if len(f.typeResolver.metaStringResolver.dynamicIDToEnumString) != 0 { + t.Fatalf( + "expected stream root cleanup to clear metastring refs, got %d", + len(f.typeResolver.metaStringResolver.dynamicIDToEnumString), + ) + } + if len(f.metaContext.readTypeInfos) != 0 { + t.Fatalf("expected stream root cleanup to clear type metadata, got %d", len(f.metaContext.readTypeInfos)) + } +} + func TestInputStreamSequential(t *testing.T) { f := New() // Register type in compatible mode to test Meta 
Sharing across sequential reads diff --git a/go/fory/string.go b/go/fory/string.go index e20ac25af5..9165aaf66c 100644 --- a/go/fory/string.go +++ b/go/fory/string.go @@ -106,7 +106,9 @@ func readUTF16LE(buf *ByteBuffer, byteCount int, err *Error) string { func readUTF8(buf *ByteBuffer, size int, err *Error) string { data := buf.ReadBinary(size, err) - return string(data) // Direct UTF-8 conversion + // Go intentionally keeps direct string conversion here. Rust is the runtime that checks UTF-8 + // string payloads by default; Go preserves its platform behavior for invalid byte sequences. + return string(data) } // ============================================================================ diff --git a/go/fory/struct.go b/go/fory/struct.go index a609113cbb..b2d718db86 100644 --- a/go/fory/struct.go +++ b/go/fory/struct.go @@ -1541,13 +1541,13 @@ func (s *structSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } switch field.DispatchId { case PrimitiveVarint32DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarint32()) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarint32(err)) case PrimitiveVarint64DispatchId: storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarint64()) case PrimitiveIntDispatchId: storeFieldValue(field.Kind, fieldPtr, optInfo, int(buf.UnsafeReadVarint64())) case PrimitiveVarUint32DispatchId: - storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarUint32()) + storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarUint32(err)) case PrimitiveVarUint64DispatchId: storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarUint64()) case PrimitiveUintDispatchId: diff --git a/go/fory/tests/metastring_resolver_test.go b/go/fory/tests/metastring_resolver_test.go index fa833b8242..b0b8202204 100644 --- a/go/fory/tests/metastring_resolver_test.go +++ b/go/fory/tests/metastring_resolver_test.go @@ -46,7 +46,7 @@ func TestMetaStringResolver(t *testing.T) { // Test 2: 
Manually constructed MetaStringBytes data2 := []byte{0xBF, 0x05, 0xA4, 0x71, 0xA9, 0x92, 0x53, 0x96, 0xA6, 0x49, 0x4F, 0x72, 0x9C, 0x68, 0x29, 0x80} - metaBytes2 := fory.NewMetaStringBytes(data2, int64(-5456063526933366015)) + metaBytes2 := fory.NewMetaStringBytes(data2, fory.ComputeMetaStringHash(data2, meta.LOWER_SPECIAL)) resolver.WriteMetaStringBytes(buffer, metaBytes2, &bufErr) if bufErr.HasError() { t.Fatalf("write failed: %v", bufErr.Error()) diff --git a/go/fory/type_def.go b/go/fory/type_def.go index 70ea7e1a03..e1ab7c3eb4 100644 --- a/go/fory/type_def.go +++ b/go/fory/type_def.go @@ -284,6 +284,9 @@ func readTypeDef(fory *Fory, buffer *ByteBuffer, header int64, err *Error) *Type } func skipTypeDef(buffer *ByteBuffer, header int64, err *Error) { + // Header-cache hits intentionally treat the current body as opaque bytes and skip by the size in + // the current header. Parsed TypeDefs are published to the cache only after successful body parse + // and 52-bit body-hash validation; cache hits must not reparse or rehash that body. 
sz := int(header & META_SIZE_MASK) if sz == META_SIZE_MASK { sz += int(buffer.ReadVarUint32(err)) @@ -1178,7 +1181,9 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro if fallbackInfo, fallbackExists := fory.typeResolver.namedTypeToTypeInfo[nameKey]; fallbackExists { info = fallbackInfo exists = true - fory.typeResolver.nsTypeToTypeInfo[nsTypeKey{nsBytes.Hashcode, nameBytes.Hashcode}] = info + if len(fory.typeResolver.nsTypeToTypeInfo) < maxCachedNamedTypeInfos { + fory.typeResolver.nsTypeToTypeInfo[nsTypeKey{nsBytes.Hashcode, nameBytes.Hashcode}] = info + } } } if exists { diff --git a/go/fory/type_def_test.go b/go/fory/type_def_test.go index 55a8a4efeb..37b465001a 100644 --- a/go/fory/type_def_test.go +++ b/go/fory/type_def_test.go @@ -465,6 +465,36 @@ func TestReadSharedTypeMetaCapsParsedTypeDefCache(t *testing.T) { require.NotContains(t, fory.typeResolver.defIdToTypeDef, header) } +func TestDecodeTypeDefFallbackNamedTypeCacheRespectsCap(t *testing.T) { + fory := NewFory(WithCompatible(true)) + require.NoError(t, fory.RegisterNamedStruct(SimpleStruct{}, "example.SimpleStruct")) + typeDef, err := buildTypeDef(fory, reflect.ValueOf(SimpleStruct{})) + require.NoError(t, err) + require.NotNil(t, typeDef.nsName) + require.NotNil(t, typeDef.typeName) + + nameKey := nsTypeKey{typeDef.nsName.Hashcode, typeDef.typeName.Hashcode} + delete(fory.typeResolver.nsTypeToTypeInfo, nameKey) + info := fory.typeResolver.namedTypeToTypeInfo[[2]string{"example", "SimpleStruct"}] + require.NotNil(t, info) + for i := 0; len(fory.typeResolver.nsTypeToTypeInfo) < maxCachedNamedTypeInfos; i++ { + fory.typeResolver.nsTypeToTypeInfo[nsTypeKey{int64(i + 1), int64(i + 2)}] = info + } + require.NotContains(t, fory.typeResolver.nsTypeToTypeInfo, nameKey) + + buffer := NewByteBuffer(nil) + readErr := &Error{} + typeDef.writeTypeDef(buffer, readErr) + require.NoError(t, readErr.CheckError()) + header := buffer.ReadInt64(readErr) + require.NoError(t, 
readErr.CheckError()) + decoded := readTypeDef(fory, buffer, header, readErr) + require.NoError(t, readErr.CheckError()) + require.NotNil(t, decoded) + require.Len(t, fory.typeResolver.nsTypeToTypeInfo, maxCachedNamedTypeInfos) + require.NotContains(t, fory.typeResolver.nsTypeToTypeInfo, nameKey) +} + func TestTypeDefRejectsNamespaceLengthBeyondMetadata(t *testing.T) { fory := NewFory() meta := NewByteBuffer(nil) diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index a42b01ec95..3838e67e25 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -52,10 +52,11 @@ const ( useStringId = 1 SMALL_STRING_THRESHOLD = 16 // 0xffffffff is reserved for "unset". - maxUserTypeID uint32 = 0xfffffffe - invalidUserTypeID uint32 = 0xffffffff - internalTypeIDLimit = 0xFF - maxCachedTypeDefs = 8192 + maxUserTypeID uint32 = 0xfffffffe + invalidUserTypeID uint32 = 0xffffffff + internalTypeIDLimit = 0xFF + maxCachedTypeDefs = 8192 + maxCachedNamedTypeInfos = 8192 ) var ( @@ -1643,9 +1644,6 @@ func (r *TypeResolver) readSharedTypeMeta(buffer *ByteBuffer, err *Error) *TypeI if err.HasError() { return nil } - if len(r.defIdToTypeDef) < maxCachedTypeDefs { - r.defIdToTypeDef[id] = newTd - } td = newTd } @@ -1654,6 +1652,9 @@ func (r *TypeResolver) readSharedTypeMeta(buffer *ByteBuffer, err *Error) *TypeI err.SetError(typeInfoErr) return nil } + if _, exists := r.defIdToTypeDef[id]; !exists && len(r.defIdToTypeDef) < maxCachedTypeDefs { + r.defIdToTypeDef[id] = td + } context.readTypeInfos = append(context.readTypeInfos, typeInfo) return typeInfo @@ -2167,7 +2168,9 @@ func (r *TypeResolver) resolveTypeInfoByMetaBytes(nsBytes, typeBytes *MetaString nameKey := [2]string{ns, typeName} if typeInfo, exists := r.namedTypeToTypeInfo[nameKey]; exists { - r.nsTypeToTypeInfo[compositeKey] = typeInfo + if len(r.nsTypeToTypeInfo) < maxCachedNamedTypeInfos { + r.nsTypeToTypeInfo[compositeKey] = typeInfo + } return typeInfo } diff --git 
a/java/fory-core/src/main/java/org/apache/fory/config/Config.java b/java/fory-core/src/main/java/org/apache/fory/config/Config.java index c820986e42..bd7c3c710a 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/Config.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/Config.java @@ -65,6 +65,8 @@ public class Config implements Serializable { private final boolean serializeEnumByName; private final int bufferSizeLimitBytes; private final int maxDepth; + private final int maxBinarySize; + private final int maxCollectionSize; private final float mapRefLoadFactor; private final boolean forVirtualThread; @@ -107,6 +109,8 @@ public Config(ForyBuilder builder) { serializeEnumByName = builder.serializeEnumByName; bufferSizeLimitBytes = builder.bufferSizeLimitBytes; maxDepth = builder.maxDepth; + maxBinarySize = builder.maxBinarySize; + maxCollectionSize = builder.maxCollectionSize; mapRefLoadFactor = builder.mapRefLoadFactor; forVirtualThread = builder.forVirtualThread; } @@ -298,6 +302,16 @@ public int maxDepth() { return maxDepth; } + /** Returns max binary payload size for attacker-controlled binary and primitive-array lengths. */ + public int maxBinarySize() { + return maxBinarySize; + } + + /** Returns max collection allocation size for attacker-controlled collection lengths. */ + public int maxCollectionSize() { + return maxCollectionSize; + } + /** Returns loadFactor of MacRef's writtenObjects. 
*/ public float mapRefLoadFactor() { return mapRefLoadFactor; @@ -332,6 +346,8 @@ public boolean equals(Object o) { && compressIntArray == config.compressIntArray && compressLongArray == config.compressLongArray && bufferSizeLimitBytes == config.bufferSizeLimitBytes + && maxBinarySize == config.maxBinarySize + && maxCollectionSize == config.maxCollectionSize && requireClassRegistration == config.requireClassRegistration && suppressClassRegistrationWarnings == config.suppressClassRegistrationWarnings && registerGuavaTypes == config.registerGuavaTypes @@ -371,6 +387,8 @@ public int hashCode() { compressIntArray, compressLongArray, bufferSizeLimitBytes, + maxBinarySize, + maxCollectionSize, requireClassRegistration, suppressClassRegistrationWarnings, registerGuavaTypes, diff --git a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java index 5c3d6a45dd..52bd6b08ba 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java @@ -96,6 +96,8 @@ public final class ForyBuilder { Integer bufferSizeLimitBytes = -1; MetaCompressor metaCompressor = new DeflaterMetaCompressor(); int maxDepth = 50; + int maxBinarySize = 64 * 1024 * 1024; + int maxCollectionSize = 1_000_000; float mapRefLoadFactor = 0.51f; boolean forVirtualThread = false; TypeChecker typeChecker; @@ -473,6 +475,30 @@ public ForyBuilder withMaxDepth(int maxDepth) { return this; } + /** + * Set max binary payload size for deserialization. Binary and primitive-array byte lengths above + * this limit are rejected before allocation. Default max binary size is 64 MiB. 
+ */ + public ForyBuilder withMaxBinarySize(int maxBinarySize) { + Preconditions.checkArgument( + maxBinarySize >= 0, "maxBinarySize must >= 0 but got %s", maxBinarySize); + this.maxBinarySize = maxBinarySize; + recordAction(b -> b.withMaxBinarySize(maxBinarySize)); + return this; + } + + /** + * Set max collection size for deserialization. Collection lengths and collection capacity fields + * above this limit are rejected before allocation. Default max collection size is 1,000,000. + */ + public ForyBuilder withMaxCollectionSize(int maxCollectionSize) { + Preconditions.checkArgument( + maxCollectionSize >= 0, "maxCollectionSize must >= 0 but got %s", maxCollectionSize); + this.maxCollectionSize = maxCollectionSize; + recordAction(b -> b.withMaxCollectionSize(maxCollectionSize)); + return this; + } + /** Set loadFactor of MapRefResolver writtenObjects. Default value is 0.51 */ public ForyBuilder withMapRefLoadFactor(float loadFactor) { Preconditions.checkArgument( diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index 5def56c660..0b03800dd6 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -467,11 +467,13 @@ public MemoryBuffer readBufferObject() { } else { size = buffer.readVarUInt32(); } - if (buffer.readerIndex() + size > buffer.size() && buffer.getStreamReader() != null) { - buffer.getStreamReader().fillBuffer(buffer.readerIndex() + size - buffer.size()); + if (size < 0) { + throw new IllegalArgumentException("Buffer object size must be non-negative: " + size); } - MemoryBuffer slice = buffer.slice(buffer.readerIndex(), size); - buffer.readerIndex(buffer.readerIndex() + size); + buffer.checkReadableBytes(size); + int readerIndex = buffer.readerIndex(); + MemoryBuffer slice = buffer.slice(readerIndex, size); + buffer.readerIndex(readerIndex + 
size); return slice; } Preconditions.checkArgument(outOfBandBuffers.hasNext()); diff --git a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java index 2ae29eaab1..89a0c96f8f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java +++ b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java @@ -827,15 +827,15 @@ public int writeVarUInt32Small7(int value) { } private int continueWriteVarUInt32Small7(int value) { - long encoded = (value & 0x7F); + int encoded = (value & 0x7F); encoded |= (((value & 0x3f80) << 1) | 0x80); int writerIdx = writerIndex; if (value >>> 14 == 0) { - _unsafePutInt32(writerIdx, (int) encoded); + _unsafePutInt32(writerIdx, encoded); writerIndex += 2; return 2; } - int diff = continuePutVarInt36(writerIdx, encoded, value); + int diff = continuePutVarUInt32(writerIdx, encoded, value); writerIndex += diff; return diff; } @@ -1811,7 +1811,7 @@ public int _readVarInt32OnLE() { int readIdx = readerIndex; int result; if (size - readIdx < 5) { - result = (int) readVarUint36Slow(); + result = readVarUInt32Slow(); } else { long address = this.address; // | 1bit + 7bits | 1bit + 7bits | 1bit + 7bits | 1bit + 7bits | @@ -1835,7 +1835,11 @@ public int _readVarInt32OnLE() { // 0xfe00000: 0b1111111 << 21 result |= (fourByteValue >>> 3) & 0xfe00000; if ((fourByteValue & 0x80000000) != 0) { - result |= (UNSAFE.getByte(heapMemory, address + readIdx++) & 0x7F) << 28; + int fifthByte = UNSAFE.getByte(heapMemory, address + readIdx++) & 0xFF; + if ((fifthByte & 0xF0) != 0) { + throwMalformedVarUInt32(fifthByte); + } + result |= fifthByte << 28; } } } @@ -1854,7 +1858,7 @@ public int _readVarInt32OnBE() { int readIdx = readerIndex; int result; if (size - readIdx < 5) { - result = (int) readVarUint36Slow(); + result = readVarUInt32Slow(); } else { long address = this.address; int fourByteValue = 
Integer.reverseBytes(UNSAFE.getInt(heapMemory, address + readIdx)); @@ -1877,7 +1881,11 @@ public int _readVarInt32OnBE() { // 0xfe00000: 0b1111111 << 21 result |= (fourByteValue >>> 3) & 0xfe00000; if ((fourByteValue & 0x80000000) != 0) { - result |= (UNSAFE.getByte(heapMemory, address + readIdx++) & 0x7F) << 28; + int fifthByte = UNSAFE.getByte(heapMemory, address + readIdx++) & 0xFF; + if ((fifthByte & 0xF0) != 0) { + throwMalformedVarUInt32(fifthByte); + } + result |= fifthByte << 28; } } } @@ -1956,11 +1964,45 @@ private long readVarUint36Slow() { return result; } + private int readVarUInt32Slow() { + int b = readByte() & 0xFF; + int result = b & 0x7F; + // Note: + // Loop are not used here to improve performance. + // We manually unroll the loop for better performance. + // noinspection Duplicates + if ((b & 0x80) != 0) { + b = readByte() & 0xFF; + result |= (b & 0x7F) << 7; + if ((b & 0x80) != 0) { + b = readByte() & 0xFF; + result |= (b & 0x7F) << 14; + if ((b & 0x80) != 0) { + b = readByte() & 0xFF; + result |= (b & 0x7F) << 21; + if ((b & 0x80) != 0) { + b = readByte() & 0xFF; + if ((b & 0xF0) != 0) { + throwMalformedVarUInt32(b); + } + result |= b << 28; + } + } + } + } + return result; + } + + private static void throwMalformedVarUInt32(int fifthByte) { + throw new IllegalArgumentException( + "Malformed varuint32 fifth byte " + fifthByte + " exceeds 32 bits"); + } + /** Reads the 1-5 byte int part of a non-negative varint. 
*/ public int readVarUInt32() { int readIdx = readerIndex; if (size - readIdx < 5) { - return (int) readVarUint36Slow(); + return readVarUInt32Slow(); } // | 1bit + 7bits | 1bit + 7bits | 1bit + 7bits | 1bit + 7bits | int fourByteValue = _unsafeGetInt32(readIdx); @@ -1983,7 +2025,11 @@ public int readVarUInt32() { // 0xfe00000: 0b1111111 << 21 result |= (fourByteValue >>> 3) & 0xfe00000; if ((fourByteValue & 0x80000000) != 0) { - result |= (UNSAFE.getByte(heapMemory, address + readIdx++) & 0x7F) << 28; + int fifthByte = UNSAFE.getByte(heapMemory, address + readIdx++) & 0xFF; + if ((fifthByte & 0xF0) != 0) { + throwMalformedVarUInt32(fifthByte); + } + result |= fifthByte << 28; } } } @@ -2031,7 +2077,7 @@ public int readVarUInt32Small14() { readerIndex = readIdx; return value; } else { - return (int) readVarUint36Slow(); + return readVarUInt32Slow(); } } @@ -2044,7 +2090,11 @@ private int continueReadVarUInt32(int readIdx, int bulkRead, int value) { readIdx++; value |= (bulkRead >>> 3) & 0xfe00000; if ((bulkRead & 0x80000000) != 0) { - value |= (UNSAFE.getByte(heapMemory, address + readIdx++) & 0x7F) << 28; + int fifthByte = UNSAFE.getByte(heapMemory, address + readIdx++) & 0xFF; + if ((fifthByte & 0xF0) != 0) { + throwMalformedVarUInt32(fifthByte); + } + value |= fifthByte << 28; } } readerIndex = readIdx; @@ -2440,7 +2490,7 @@ public int readBinarySize() { } readerIndex = readIdx; } else { - binarySize = (int) readVarUint36Slow(); + binarySize = readVarUInt32Slow(); readIdx = readerIndex; } int diff = size - readIdx; @@ -2459,7 +2509,11 @@ private int continueReadBinarySize(int readIdx, int bulkRead, int binarySize) { readIdx++; binarySize |= (bulkRead >>> 3) & 0xfe00000; if ((bulkRead & 0x80000000) != 0) { - binarySize |= (UNSAFE.getByte(heapMemory, address + readIdx++) & 0x7F) << 28; + int fifthByte = UNSAFE.getByte(heapMemory, address + readIdx++) & 0xFF; + if ((fifthByte & 0xF0) != 0) { + throwMalformedVarUInt32(fifthByte); + } + binarySize |= fifthByte << 28; 
} } int diff = size - readIdx; diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java index abfa826fcc..6069472e91 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java @@ -34,6 +34,7 @@ import java.util.stream.Collectors; import org.apache.fory.builder.MetaSharedCodecBuilder; import org.apache.fory.config.ForyBuilder; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.logging.Logger; import org.apache.fory.logging.LoggerFactory; import org.apache.fory.memory.MemoryBuffer; @@ -114,12 +115,20 @@ public class TypeDef implements Serializable { } public static void skipTypeDef(MemoryBuffer buffer, long id) { - // Header-cache hits use the validated header as the cache key. The current body is skipped by - // its declared size; body hash validation belongs to the parse-before-cache-publication path. + // Header-cache hits intentionally treat the current body as opaque bytes and skip by the size + // in + // the current header. Parsed TypeDefs are published to the cache only after successful body + // parse + // and 52-bit body-hash validation; cache hits must not reparse or rehash that body. 
int size = (int) (id & META_SIZE_MASKS); if (size == META_SIZE_MASKS) { - size += buffer.readVarUInt32Small14(); + int extendedSize = buffer.readVarUInt32Small14(); + if (extendedSize < 0 || extendedSize > Integer.MAX_VALUE - size) { + throw new DeserializationException("Invalid TypeDef metadata size " + extendedSize); + } + size += extendedSize; } + buffer.checkReadableBytes(size); buffer.increaseReaderIndex(size); } diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index 7233414ee8..c5856d6d21 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -64,6 +64,7 @@ import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; import org.apache.fory.exception.ForyException; +import org.apache.fory.exception.InsecureException; import org.apache.fory.exception.SerializerUnregisteredException; import org.apache.fory.logging.Logger; import org.apache.fory.logging.LoggerFactory; @@ -1113,7 +1114,10 @@ final Class loadClass(String className) { final Class loadClass( String className, boolean isEnum, int arrayDims, boolean deserializeUnknownClass) { - extRegistry.typeChecker.checkType(this, className); + if (!extRegistry.typeChecker.checkType(this, className)) { + throw new InsecureException( + String.format("Class %s is forbidden for serialization.", className)); + } Class cls = extRegistry.registeredClasses.get(className); if (cls != null) { return cls; diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java index b5bb42eb22..24a6a251a9 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java @@ -24,6 
+24,7 @@ import org.apache.fory.context.CopyContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.resolver.ClassResolver; import org.apache.fory.resolver.RefMode; @@ -46,6 +47,19 @@ public final class ArraySerializers { private ArraySerializers() {} + private static void throwObjectArraySizeLimitExceeded(int size, int maxCollectionSize) { + throw new DeserializationException( + "Object array size " + size + " exceeds max collection size " + maxCollectionSize); + } + + private static void throwInvalidObjectArraySize(int size, int maxCollectionSize) { + if (size < 0) { + throw new DeserializationException("Object array size must be non-negative: " + size); + } else { + throwObjectArraySizeLimitExceeded(size, maxCollectionSize); + } + } + /** * Returns the object-array serializer for {@code cls}. * @@ -85,6 +99,7 @@ public static Serializer newObjectArraySerializer(TypeResolver typeResolver, public static final class ObjectArraySerializer extends Serializer { private final TypeResolver typeResolver; private final TypeInfoHolder elementTypeInfoHolder; + private final int maxCollectionSize; public ObjectArraySerializer(TypeResolver typeResolver, Class cls) { super(typeResolver.getConfig(), (Class) cls); @@ -94,6 +109,7 @@ public ObjectArraySerializer(TypeResolver typeResolver, Class cls) { } Preconditions.checkArgument(cls.isArray() && !cls.getComponentType().isPrimitive()); elementTypeInfoHolder = typeResolver.nilTypeInfoHolder(); + maxCollectionSize = typeResolver.getConfig().maxCollectionSize(); } @Override @@ -125,6 +141,11 @@ public Object[] copy(CopyContext copyContext, Object[] originArray) { public Object[] read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); int numElements = buffer.readVarUInt32Small7(); + // Keep this as direct primitive branches. 
Object-array reads allocate immediately; using + // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. + if (numElements < 0 || numElements > maxCollectionSize) { + throwInvalidObjectArraySize(numElements, maxCollectionSize); + } Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { @@ -158,6 +179,7 @@ private abstract static class SameTypeObjectArraySerializer extends Serializer componentType; private final Serializer elementSerializer; private final TypeInfoHolder elementTypeInfoHolder; + private final int maxCollectionSize; SameTypeObjectArraySerializer( TypeResolver typeResolver, Class arrayType, Class componentType) { @@ -169,6 +191,7 @@ private abstract static class SameTypeObjectArraySerializer extends Serializer maxCollectionSize) { + throwInvalidObjectArraySize(numElements, maxCollectionSize); + } Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { @@ -616,6 +644,7 @@ public static final class UnknownArraySerializer extends Serializer { private final String className; private final TypeResolver typeResolver; private final TypeInfoHolder elementTypeInfoHolder; + private final int maxCollectionSize; public UnknownArraySerializer(TypeResolver typeResolver, Class cls) { this(typeResolver, "Unknown", cls); @@ -627,6 +656,7 @@ public UnknownArraySerializer(TypeResolver typeResolver, String className, Class this.className = className; this.typeResolver = typeResolver; elementTypeInfoHolder = typeResolver.nilTypeInfoHolder(); + maxCollectionSize = typeResolver.getConfig().maxCollectionSize(); } @Override @@ -639,6 +669,11 @@ public void write(WriteContext writeContext, Object[] value) { public Object[] read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); int numElements = buffer.readVarUInt32Small7(); + // Keep this as direct primitive branches. 
Object-array reads allocate immediately; using + // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. + if (numElements < 0 || numElements > maxCollectionSize) { + throwInvalidObjectArraySize(numElements, maxCollectionSize); + } Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/BigIntegerSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/BigIntegerSerializer.java index f5641d7701..6378e5e72f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/BigIntegerSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/BigIntegerSerializer.java @@ -23,16 +23,19 @@ import org.apache.fory.config.Config; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; /** Serializer for {@link BigInteger} in native and xlang modes. 
*/ public final class BigIntegerSerializer extends ImmutableSerializer implements Shareable { private final boolean xlang; + private final int maxBinarySize; public BigIntegerSerializer(Config config) { super(config, BigInteger.class); xlang = config.isXlang(); + maxBinarySize = config.maxBinarySize(); } @Override @@ -62,6 +65,8 @@ private void writeNative(WriteContext writeContext, BigInteger value) { private BigInteger readNative(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); int len = buffer.readVarUInt32Small7(); + checkBinaryPayloadLength(len, maxBinarySize); + buffer.checkReadableBytes(len); byte[] bytes = buffer.readBytes(len); return new BigInteger(bytes); } @@ -71,6 +76,16 @@ private void writeXlang(WriteContext writeContext, BigInteger value) { } private BigInteger readXlang(ReadContext readContext) { - return DecimalSerializer.readXlangBigInteger(readContext.getBuffer()); + return DecimalSerializer.readXlangBigInteger(readContext.getBuffer(), maxBinarySize); + } + + private static void checkBinaryPayloadLength(int len, int maxBinarySize) { + if (len <= 0) { + throw new DeserializationException("BigInteger payload length must be positive: " + len); + } + if (len > maxBinarySize) { + throw new DeserializationException( + "BigInteger payload length " + len + " exceeds max binary size " + maxBinarySize); + } } } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/DecimalSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/DecimalSerializer.java index 796d679010..05043113f4 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/DecimalSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/DecimalSerializer.java @@ -25,6 +25,7 @@ import org.apache.fory.config.Config; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; /** 
Serializer for {@link BigDecimal} in native and xlang modes. */ @@ -32,10 +33,12 @@ public final class DecimalSerializer extends ImmutableSerializer imp private static final BigInteger LONG_MIN = BigInteger.valueOf(Long.MIN_VALUE); private static final BigInteger LONG_MAX = BigInteger.valueOf(Long.MAX_VALUE); private final boolean xlang; + private final int maxBinarySize; public DecimalSerializer(Config config) { super(config, BigDecimal.class); xlang = config.isXlang(); + maxBinarySize = config.maxBinarySize(); } @Override @@ -69,6 +72,8 @@ private BigDecimal readNative(ReadContext readContext) { int scale = buffer.readVarUInt32Small7(); int precision = buffer.readVarUInt32Small7(); int len = buffer.readVarUInt32Small7(); + checkBinaryPayloadLength(len, maxBinarySize); + buffer.checkReadableBytes(len); byte[] bytes = buffer.readBytes(len); BigInteger bigInteger = new BigInteger(bytes); return new BigDecimal(bigInteger, scale, new MathContext(precision)); @@ -79,7 +84,7 @@ private void writeXlang(WriteContext writeContext, BigDecimal value) { } private BigDecimal readXlang(ReadContext readContext) { - return readXlangDecimal(readContext.getBuffer()); + return readXlangDecimal(readContext.getBuffer(), maxBinarySize); } static void writeXlangDecimal(MemoryBuffer buffer, int scale, BigInteger unscaled) { @@ -100,13 +105,21 @@ static void writeXlangDecimal(MemoryBuffer buffer, int scale, BigInteger unscale } static BigDecimal readXlangDecimal(MemoryBuffer buffer) { + return readXlangDecimal(buffer, Integer.MAX_VALUE); + } + + static BigDecimal readXlangDecimal(MemoryBuffer buffer, int maxBinarySize) { int scale = buffer.readVarInt32(); - return new BigDecimal(readXlangUnscaled(buffer), scale); + return new BigDecimal(readXlangUnscaled(buffer, maxBinarySize), scale); } static BigInteger readXlangBigInteger(MemoryBuffer buffer) { + return readXlangBigInteger(buffer, Integer.MAX_VALUE); + } + + static BigInteger readXlangBigInteger(MemoryBuffer buffer, int maxBinarySize) 
{ int scale = buffer.readVarInt32(); - BigInteger unscaled = readXlangUnscaled(buffer); + BigInteger unscaled = readXlangUnscaled(buffer, maxBinarySize); if (scale != 0) { throw new IllegalArgumentException( "Cannot deserialize xlang decimal with scale " + scale + " into BigInteger"); @@ -114,7 +127,7 @@ static BigInteger readXlangBigInteger(MemoryBuffer buffer) { return unscaled; } - private static BigInteger readXlangUnscaled(MemoryBuffer buffer) { + private static BigInteger readXlangUnscaled(MemoryBuffer buffer, int maxBinarySize) { long header = buffer.readVarUInt64(); if ((header & 1L) == 0L) { return BigInteger.valueOf(decodeZigZag64(header >>> 1)); @@ -126,7 +139,12 @@ private static BigInteger readXlangUnscaled(MemoryBuffer buffer) { throw new IllegalArgumentException( "Invalid decimal magnitude length " + lenLong + " in xlang payload"); } + if (lenLong > maxBinarySize) { + throw new DeserializationException( + "Decimal magnitude length " + lenLong + " exceeds max binary size " + maxBinarySize); + } int len = (int) lenLong; + buffer.checkReadableBytes(len); byte[] payload = buffer.readBytes(len); if (payload[len - 1] == 0) { throw new IllegalArgumentException("Non-canonical decimal payload: trailing zero byte"); @@ -139,6 +157,16 @@ private static BigInteger readXlangUnscaled(MemoryBuffer buffer) { return sign == 0 ? 
abs : abs.negate(); } + private static void checkBinaryPayloadLength(int len, int maxBinarySize) { + if (len <= 0) { + throw new DeserializationException("Decimal payload length must be positive: " + len); + } + if (len > maxBinarySize) { + throw new DeserializationException( + "Decimal payload length " + len + " exceeds max binary size " + maxBinarySize); + } + } + private static boolean canUseSmallEncoding(BigInteger value) { if (value.compareTo(LONG_MIN) < 0 || value.compareTo(LONG_MAX) > 0) { return false; diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveArraySerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveArraySerializers.java index 8b5a32b483..54ef7809cb 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveArraySerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveArraySerializers.java @@ -25,6 +25,7 @@ import org.apache.fory.context.CopyContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.Platform; import org.apache.fory.resolver.TypeResolver; @@ -74,13 +75,49 @@ public MemoryBuffer toBuffer() { public abstract static class PrimitiveArraySerializer extends Serializer implements Shareable { protected final Config config; + protected final int maxBinarySize; public PrimitiveArraySerializer(TypeResolver typeResolver, Class cls) { super(typeResolver.getConfig(), cls); this.config = typeResolver.getConfig(); + maxBinarySize = config.maxBinarySize(); } } + private static void throwBinarySizeLimitExceeded(long size, int maxBinarySize) { + throw new DeserializationException( + "Binary payload size " + size + " exceeds max binary size " + maxBinarySize); + } + + private static void throwNegativeBinarySize(int size) { + throw new DeserializationException("Binary payload 
size must be non-negative: " + size); + } + + private static void throwNegativeElementCount(int numElements) { + throw new DeserializationException("Element count must be non-negative: " + numElements); + } + + private static void throwInvalidBinarySize(int size, int maxBinarySize) { + if (size < 0) { + throwNegativeBinarySize(size); + } else { + throwBinarySizeLimitExceeded(size, maxBinarySize); + } + } + + private static void throwInvalidElementCount(int numElements, int maxBinarySize, int elemSize) { + if (numElements < 0) { + throwNegativeElementCount(numElements); + } else { + throwBinarySizeLimitExceeded((long) numElements * elemSize, maxBinarySize); + } + } + + private static void throwUnalignedBinarySize(int size, int elemSize) { + throw new DeserializationException( + "Binary payload size " + size + " is not aligned to element size " + elemSize); + } + public static final class BooleanArraySerializer extends PrimitiveArraySerializer { public BooleanArraySerializer(TypeResolver typeResolver) { super(typeResolver, boolean[].class); @@ -109,11 +146,20 @@ public boolean[] read(ReadContext readContext) { if (readContext.isPeerOutOfBandEnabled()) { MemoryBuffer buf = readContext.readBufferObject(); int size = buf.remaining(); + if (size > maxBinarySize) { + throwBinarySizeLimitExceeded(size, maxBinarySize); + } boolean[] values = new boolean[size]; buf.copyToUnsafe(0, values, Platform.BOOLEAN_ARRAY_OFFSET, size); return values; } int size = buffer.readVarUInt32Small7(); + if (size < 0 || size > maxBinarySize) { + throwInvalidBinarySize(size, maxBinarySize); + } + if (size > buffer.remaining()) { + buffer.checkReadableBytes(size); + } boolean[] values = new boolean[size]; buffer.readToUnsafe(values, Platform.BOOLEAN_ARRAY_OFFSET, size); return values; @@ -148,11 +194,20 @@ public byte[] read(ReadContext readContext) { if (readContext.isPeerOutOfBandEnabled()) { MemoryBuffer buf = readContext.readBufferObject(); int size = buf.remaining(); + if (size > 
maxBinarySize) { + throwBinarySizeLimitExceeded(size, maxBinarySize); + } byte[] values = new byte[size]; buf.copyToUnsafe(0, values, Platform.BYTE_ARRAY_OFFSET, size); return values; } int size = buffer.readVarUInt32Small7(); + if (size < 0 || size > maxBinarySize) { + throwInvalidBinarySize(size, maxBinarySize); + } + if (size > buffer.remaining()) { + buffer.checkReadableBytes(size); + } byte[] values = new byte[size]; buffer.readToUnsafe(values, Platform.BYTE_ARRAY_OFFSET, size); return values; @@ -208,7 +263,13 @@ public char[] read(ReadContext readContext) { if (readContext.isPeerOutOfBandEnabled()) { MemoryBuffer buf = readContext.readBufferObject(); int size = buf.remaining(); - int numElements = size / 2; + if ((size & 1) != 0) { + throwUnalignedBinarySize(size, 2); + } + if (size > maxBinarySize) { + throwBinarySizeLimitExceeded(size, maxBinarySize); + } + int numElements = size >>> 1; char[] values = new char[numElements]; if (Platform.IS_LITTLE_ENDIAN) { buf.copyToUnsafe(0, values, Platform.CHAR_ARRAY_OFFSET, size); @@ -218,7 +279,16 @@ public char[] read(ReadContext readContext) { return values; } int size = buffer.readVarUInt32Small7(); - int numElements = size / 2; + if ((size & 1) != 0) { + throwUnalignedBinarySize(size, 2); + } + if (size < 0 || size > maxBinarySize) { + throwInvalidBinarySize(size, maxBinarySize); + } + if (size > buffer.remaining()) { + buffer.checkReadableBytes(size); + } + int numElements = size >>> 1; char[] values = new char[numElements]; if (Platform.IS_LITTLE_ENDIAN) { buffer.readToUnsafe(values, Platform.CHAR_ARRAY_OFFSET, size); @@ -256,7 +326,7 @@ public short[] copy(CopyContext copyContext, short[] originArray) { @Override public short[] read(ReadContext readContext) { - return readShortBits(readContext); + return readShortBits(readContext, maxBinarySize); } } @@ -307,7 +377,13 @@ public int[] read(ReadContext readContext) { if (readContext.isPeerOutOfBandEnabled()) { MemoryBuffer buf = readContext.readBufferObject(); 
int size = buf.remaining(); - int numElements = size / 4; + if ((size & 3) != 0) { + throwUnalignedBinarySize(size, 4); + } + if (size > maxBinarySize) { + throwBinarySizeLimitExceeded(size, maxBinarySize); + } + int numElements = size >>> 2; int[] values = new int[numElements]; if (size > 0) { if (Platform.IS_LITTLE_ENDIAN) { @@ -322,7 +398,16 @@ public int[] read(ReadContext readContext) { return readInt32Compressed(buffer); } int size = buffer.readVarUInt32Small7(); - int numElements = size / 4; + if ((size & 3) != 0) { + throwUnalignedBinarySize(size, 4); + } + if (size < 0 || size > maxBinarySize) { + throwInvalidBinarySize(size, maxBinarySize); + } + if (size > buffer.remaining()) { + buffer.checkReadableBytes(size); + } + int numElements = size >>> 2; int[] values = new int[numElements]; if (size > 0) { if (Platform.IS_LITTLE_ENDIAN) { @@ -353,6 +438,9 @@ private void writeInt32Compressed(MemoryBuffer buffer, int[] value) { private int[] readInt32Compressed(MemoryBuffer buffer) { int numElements = buffer.readVarUInt32Small7(); + if (numElements < 0 || numElements > maxBinarySize / 4) { + throwInvalidElementCount(numElements, maxBinarySize, 4); + } int[] values = new int[numElements]; for (int i = 0; i < numElements; i++) { values[i] = buffer.readVarInt32(); @@ -414,7 +502,13 @@ public long[] read(ReadContext readContext) { if (readContext.isPeerOutOfBandEnabled()) { MemoryBuffer buf = readContext.readBufferObject(); int size = buf.remaining(); - int numElements = size / 8; + if ((size & 7) != 0) { + throwUnalignedBinarySize(size, 8); + } + if (size > maxBinarySize) { + throwBinarySizeLimitExceeded(size, maxBinarySize); + } + int numElements = size >>> 3; long[] values = new long[numElements]; if (size > 0) { if (Platform.IS_LITTLE_ENDIAN) { @@ -429,7 +523,16 @@ public long[] read(ReadContext readContext) { return readInt64Compressed(buffer, config.longEncoding()); } int size = buffer.readVarUInt32Small7(); - int numElements = size / 8; + if ((size & 7) != 0) 
{ + throwUnalignedBinarySize(size, 8); + } + if (size < 0 || size > maxBinarySize) { + throwInvalidBinarySize(size, maxBinarySize); + } + if (size > buffer.remaining()) { + buffer.checkReadableBytes(size); + } + int numElements = size >>> 3; long[] values = new long[numElements]; if (size > 0) { if (Platform.IS_LITTLE_ENDIAN) { @@ -468,6 +571,9 @@ private void writeInt64Compressed( private long[] readInt64Compressed(MemoryBuffer buffer, Int64Encoding longEncoding) { int numElements = buffer.readVarUInt32Small7(); + if (numElements < 0 || numElements > maxBinarySize / 8) { + throwInvalidElementCount(numElements, maxBinarySize, 8); + } long[] values = new long[numElements]; if (longEncoding == Int64Encoding.TAGGED) { for (int i = 0; i < numElements; i++) { @@ -525,7 +631,13 @@ public float[] read(ReadContext readContext) { if (readContext.isPeerOutOfBandEnabled()) { MemoryBuffer buf = readContext.readBufferObject(); int size = buf.remaining(); - int numElements = size / 4; + if ((size & 3) != 0) { + throwUnalignedBinarySize(size, 4); + } + if (size > maxBinarySize) { + throwBinarySizeLimitExceeded(size, maxBinarySize); + } + int numElements = size >>> 2; float[] values = new float[numElements]; if (Platform.IS_LITTLE_ENDIAN) { buf.copyToUnsafe(0, values, Platform.FLOAT_ARRAY_OFFSET, size); @@ -535,7 +647,16 @@ public float[] read(ReadContext readContext) { return values; } int size = buffer.readVarUInt32Small7(); - int numElements = size / 4; + if ((size & 3) != 0) { + throwUnalignedBinarySize(size, 4); + } + if (size < 0 || size > maxBinarySize) { + throwInvalidBinarySize(size, maxBinarySize); + } + if (size > buffer.remaining()) { + buffer.checkReadableBytes(size); + } + int numElements = size >>> 2; float[] values = new float[numElements]; if (Platform.IS_LITTLE_ENDIAN) { buffer.readToUnsafe(values, Platform.FLOAT_ARRAY_OFFSET, size); @@ -599,7 +720,13 @@ public double[] read(ReadContext readContext) { if (readContext.isPeerOutOfBandEnabled()) { MemoryBuffer buf = 
readContext.readBufferObject(); int size = buf.remaining(); - int numElements = size / 8; + if ((size & 7) != 0) { + throwUnalignedBinarySize(size, 8); + } + if (size > maxBinarySize) { + throwBinarySizeLimitExceeded(size, maxBinarySize); + } + int numElements = size >>> 3; double[] values = new double[numElements]; if (Platform.IS_LITTLE_ENDIAN) { buf.copyToUnsafe(0, values, Platform.DOUBLE_ARRAY_OFFSET, size); @@ -609,7 +736,16 @@ public double[] read(ReadContext readContext) { return values; } int size = buffer.readVarUInt32Small7(); - int numElements = size / 8; + if ((size & 7) != 0) { + throwUnalignedBinarySize(size, 8); + } + if (size < 0 || size > maxBinarySize) { + throwInvalidBinarySize(size, maxBinarySize); + } + if (size > buffer.remaining()) { + buffer.checkReadableBytes(size); + } + int numElements = size >>> 3; double[] values = new double[numElements]; if (Platform.IS_LITTLE_ENDIAN) { buffer.readToUnsafe(values, Platform.DOUBLE_ARRAY_OFFSET, size); @@ -647,7 +783,7 @@ public Float16Array copy(CopyContext copyContext, Float16Array originArray) { @Override public Float16Array read(ReadContext readContext) { - return Float16Array.wrapBits(readShortBits(readContext)); + return Float16Array.wrapBits(readShortBits(readContext, maxBinarySize)); } } @@ -669,7 +805,7 @@ public BFloat16Array copy(CopyContext copyContext, BFloat16Array originArray) { @Override public BFloat16Array read(ReadContext readContext) { - return BFloat16Array.wrapBits(readShortBits(readContext)); + return BFloat16Array.wrapBits(readShortBits(readContext, maxBinarySize)); } } @@ -699,12 +835,18 @@ private static void writeInt16BySwapEndian(MemoryBuffer buffer, short[] value) { buffer._unsafeWriterIndex(idx + length * 2); } - private static short[] readShortBits(ReadContext readContext) { + private static short[] readShortBits(ReadContext readContext, int maxBinarySize) { MemoryBuffer buffer = readContext.getBuffer(); if (readContext.isPeerOutOfBandEnabled()) { MemoryBuffer buf = 
readContext.readBufferObject(); int size = buf.remaining(); - int numElements = size / 2; + if ((size & 1) != 0) { + throwUnalignedBinarySize(size, 2); + } + if (size > maxBinarySize) { + throwBinarySizeLimitExceeded(size, maxBinarySize); + } + int numElements = size >>> 1; short[] values = new short[numElements]; if (Platform.IS_LITTLE_ENDIAN) { buf.copyToUnsafe(0, values, Platform.SHORT_ARRAY_OFFSET, size); @@ -714,7 +856,16 @@ private static short[] readShortBits(ReadContext readContext) { return values; } int size = buffer.readVarUInt32Small7(); - int numElements = size / 2; + if ((size & 1) != 0) { + throwUnalignedBinarySize(size, 2); + } + if (size < 0 || size > maxBinarySize) { + throwInvalidBinarySize(size, maxBinarySize); + } + if (size > buffer.remaining()) { + buffer.checkReadableBytes(size); + } + int numElements = size >>> 1; short[] values = new short[numElements]; if (Platform.IS_LITTLE_ENDIAN) { buffer.readToUnsafe(values, Platform.SHORT_ARRAY_OFFSET, size); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java index 39ca5c3356..a61f114f54 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java @@ -254,7 +254,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); @@ -302,7 +302,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T 
newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); @@ -414,7 +414,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readMapSize(buffer); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java index f07b40531a..8fe2c721a8 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java @@ -27,6 +27,7 @@ import org.apache.fory.context.CopyContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.resolver.ClassResolver; @@ -46,6 +47,7 @@ public abstract class CollectionLikeSerializer extends Serializer { private MethodHandle constructor; private int numElements; protected final Config config; + protected final int maxCollectionSize; protected final boolean supportCodegenHook; protected final TypeInfoHolder elementTypeInfoHolder; protected final TypeResolver typeResolver; @@ -67,6 +69,7 @@ public CollectionLikeSerializer( TypeResolver typeResolver, Class cls, boolean 
supportCodegenHook) { super(typeResolver.getConfig(), cls); this.config = typeResolver.getConfig(); + maxCollectionSize = config.maxCollectionSize(); this.supportCodegenHook = supportCodegenHook; elementTypeInfoHolder = typeResolver.nilTypeInfoHolder(); this.typeResolver = typeResolver; @@ -76,6 +79,7 @@ public CollectionLikeSerializer( TypeResolver typeResolver, Class cls, boolean supportCodegenHook, boolean immutable) { super(typeResolver.getConfig(), cls, immutable); this.config = typeResolver.getConfig(); + maxCollectionSize = config.maxCollectionSize(); this.supportCodegenHook = supportCodegenHook; elementTypeInfoHolder = typeResolver.nilTypeInfoHolder(); this.typeResolver = typeResolver; @@ -458,7 +462,7 @@ public T read(ReadContext readContext) { */ public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - numElements = buffer.readVarUInt32Small7(); + numElements = readCollectionSize(buffer); if (constructor == null) { constructor = ReflectionUtils.getCtrHandle(type, true); } @@ -537,6 +541,29 @@ protected void setNumElements(int numElements) { this.numElements = numElements; } + protected final int readCollectionSize(MemoryBuffer buffer) { + int numElements = buffer.readVarUInt32Small7(); + checkCollectionSize(numElements); + return numElements; + } + + protected final void checkCollectionSize(int numElements) { + // Keep this as direct primitive branches. Collection reads are hot enough that + // Preconditions.checkArgument would add helper/varargs overhead on the valid path. 
+ if (numElements < 0 || numElements > maxCollectionSize) { + throwInvalidCollectionSize(numElements); + } + } + + private void throwInvalidCollectionSize(int numElements) { + if (numElements < 0) { + throw new DeserializationException("Collection size must be non-negative: " + numElements); + } else { + throw new DeserializationException( + "Collection size " + numElements + " exceeds max collection size " + maxCollectionSize); + } + } + public abstract T onCollectionRead(Collection collection); protected void readElements(ReadContext readContext, Collection collection, int numElements) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java index 80c8d163b7..f76b2286aa 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java @@ -50,9 +50,11 @@ import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.LinkedBlockingQueue; import org.apache.fory.collection.CollectionSnapshot; +import org.apache.fory.config.Config; import org.apache.fory.context.CopyContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.exception.ForyException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.Platform; @@ -74,6 +76,39 @@ */ @SuppressWarnings({"unchecked", "rawtypes"}) public class CollectionSerializers { + private static void throwBinarySizeLimitExceeded(long size, int maxBinarySize) { + throw new DeserializationException( + "Binary payload size " + size + " exceeds max binary size " + maxBinarySize); + } + + private static void throwNegativeBinarySize(int size) { + throw new DeserializationException("Binary payload size 
must be non-negative: " + size); + } + + private static void throwUnalignedBinarySize(int size, int elemSize) { + throw new DeserializationException( + "Binary payload size " + size + " is not aligned to element size " + elemSize); + } + + private static void checkBoundedQueueCapacity(Config config, int numElements, int capacity) { + // Keep these as direct primitive branches. This collection read path is JIT-sensitive; using + // Preconditions.checkArgument here adds helper/varargs overhead and hurts inlining. + if (numElements < 0) { + throw new DeserializationException("Queue size must be non-negative: " + numElements); + } + if (capacity <= 0) { + throw new DeserializationException("Queue capacity must be positive: " + capacity); + } + if (capacity < numElements) { + throw new DeserializationException( + "Queue capacity " + capacity + " is smaller than serialized size " + numElements); + } + int maxCollectionSize = config.maxCollectionSize(); + if (capacity > maxCollectionSize) { + throw new DeserializationException( + "Queue capacity " + capacity + " exceeds max collection size " + maxCollectionSize); + } + } public static final class ArrayListSerializer extends CollectionSerializer { public ArrayListSerializer(TypeResolver typeResolver) { @@ -83,7 +118,7 @@ public ArrayListSerializer(TypeResolver typeResolver) { @Override public ArrayList newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); ArrayList arrayList = new ArrayList(numElements); readContext.reference(arrayList); @@ -141,7 +176,7 @@ public List read(ReadContext readContext) { @Override public ArrayList newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); ArrayList arrayList = new 
ArrayList(numElements); readContext.reference(arrayList); @@ -157,7 +192,7 @@ public HashSetSerializer(TypeResolver typeResolver) { @Override public HashSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); HashSet hashSet = new HashSet(numElements); readContext.reference(hashSet); @@ -173,7 +208,7 @@ public LinkedHashSetSerializer(TypeResolver typeResolver) { @Override public LinkedHashSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); LinkedHashSet hashSet = new LinkedHashSet(numElements); readContext.reference(hashSet); @@ -220,7 +255,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { public T newCollection(ReadContext readContext) { assert !config.isXlang(); MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); T collection; Comparator comparator = (Comparator) readContext.readRef(); @@ -284,7 +319,11 @@ public void write(WriteContext writeContext, List value) { @Override public List read(ReadContext readContext) { if (config.isXlang()) { - readContext.getBuffer().readVarUInt32Small7(); + int numElements = readCollectionSize(readContext.getBuffer()); + if (numElements != 0) { + throw new DeserializationException( + "Empty list payload must have zero elements but got " + numElements); + } } return Collections.EMPTY_LIST; } @@ -301,7 +340,7 @@ public CopyOnWriteArrayListSerializer( @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); 
setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -335,7 +374,7 @@ public CopyOnWriteArraySetSerializer( @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -484,7 +523,7 @@ public CollectionSnapshot onCollectionWrite( @Override public ConcurrentSkipListSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); assert !config.isXlang(); int refId = readContext.lastPreservedRefId(); @@ -631,7 +670,7 @@ public VectorSerializer(TypeResolver typeResolver, Class cls) { @Override public Vector newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); Vector vector = new Vector<>(numElements); readContext.reference(vector); @@ -648,7 +687,7 @@ public ArrayDequeSerializer(TypeResolver typeResolver, Class cls) { @Override public ArrayDeque newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); ArrayDeque deque = new ArrayDeque(numElements); readContext.reference(deque); @@ -693,7 +732,7 @@ public EnumSet read(ReadContext readContext) { Class elemClass = typeResolver.readTypeInfo(readContext).getType(); EnumSet object = EnumSet.noneOf(elemClass); Serializer elemSerializer = typeResolver.getSerializer(elemClass); - int length = buffer.readVarUInt32Small7(); + int length = readCollectionSize(buffer); for (int i = 0; i < length; i++) { 
object.add(elemSerializer.read(readContext)); } @@ -707,8 +746,11 @@ public EnumSet copy(CopyContext copyContext, EnumSet originCollection) { } public static class BitSetSerializer extends Serializer { + private final int maxBinarySize; + public BitSetSerializer(TypeResolver typeResolver, Class type) { super(typeResolver.getConfig(), type); + maxBinarySize = typeResolver.getConfig().maxBinarySize(); } @Override @@ -727,7 +769,20 @@ public BitSet copy(CopyContext copyContext, BitSet originCollection) { @Override public BitSet read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - long[] values = buffer.readLongs(buffer.readVarUInt32Small7()); + int size = buffer.readVarUInt32Small7(); + if (size < 0) { + throwNegativeBinarySize(size); + } + if ((size & 7) != 0) { + throwUnalignedBinarySize(size, 8); + } + if (size > maxBinarySize) { + throwBinarySizeLimitExceeded(size, maxBinarySize); + } + if (size > buffer.remaining()) { + buffer.checkReadableBytes(size); + } + long[] values = buffer.readLongs(size); return BitSet.valueOf(values); } } @@ -758,7 +813,7 @@ public Collection newCollection(CopyContext copyContext, Collection collection) public PriorityQueue newCollection(ReadContext readContext) { assert !config.isXlang(); MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); PriorityQueue queue = new PriorityQueue(comparator); @@ -813,9 +868,10 @@ public CollectionSnapshot onCollectionWrite( @Override public ArrayBlockingQueue newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); + checkBoundedQueueCapacity(config, numElements, capacity); ArrayBlockingQueue queue 
= new ArrayBlockingQueue<>(capacity); readContext.reference(queue); return queue; @@ -873,9 +929,10 @@ public CollectionSnapshot onCollectionWrite( @Override public LinkedBlockingQueue newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); + checkBoundedQueueCapacity(config, numElements, capacity); LinkedBlockingQueue queue = new LinkedBlockingQueue<>(capacity); readContext.reference(queue); return queue; @@ -1012,7 +1069,7 @@ public XlangListDefaultSerializer(TypeResolver typeResolver, Class cls) { @Override public List newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); ArrayList list = new ArrayList(numElements); readContext.reference(list); @@ -1028,7 +1085,7 @@ public XlangSetDefaultSerializer(TypeResolver typeResolver, Class cls) { @Override public Set newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); HashSet set = new HashSet(numElements); readContext.reference(set); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java index 5e77a63c7c..4799c7d2ad 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java @@ -79,7 +79,7 @@ public ImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection 
newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -112,7 +112,7 @@ public RegularImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); return new CollectionContainer(numElements); } @@ -146,7 +146,7 @@ public ImmutableSetSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -188,7 +188,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); return new SortedCollectionContainer(comparator, numElements); @@ -221,7 +221,7 @@ public GuavaMapSerializer(TypeResolver typeResolver, Class cls) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readMapSize(buffer); setNumElements(numElements); return new MapContainer(numElements); } @@ -249,7 +249,7 @@ public T onMapRead(Map map) { @Override public T read(ReadContext readContext) { - int size = 
readContext.getBuffer().readVarUInt32Small7(); + int size = readMapSize(readContext.getBuffer()); Map map = new HashMap(); readElements(readContext, size, map); return xnewInstance(map); @@ -372,7 +372,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readMapSize(buffer); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); return new SortedMapContainer<>(comparator, numElements); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java index 7eb3d5408a..22419687a0 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java @@ -112,7 +112,7 @@ public ImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); if (Platform.JAVA_VERSION > 8) { return new CollectionContainer<>(numElements); @@ -162,7 +162,7 @@ public ImmutableSetSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); if (Platform.JAVA_VERSION > 8) { return new CollectionContainer<>(numElements); @@ -212,7 +212,7 @@ public ImmutableMapSerializer(TypeResolver typeResolver, Class cls) { 
@Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readMapSize(buffer); setNumElements(numElements); if (Platform.JAVA_VERSION > 8) { return new JDKImmutableMapContainer(numElements); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java index e5c7f8fd91..a837dce42c 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java @@ -42,6 +42,7 @@ import org.apache.fory.context.CopyContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.reflect.TypeRef; @@ -80,6 +81,7 @@ private MapTypeCache(TypeResolver typeResolver) { protected MethodHandle constructor; protected final Config config; + protected final int maxCollectionSize; protected final boolean supportCodegenHook; private final GenericType objType; // For subclass whose kv type are instantiated already, such as @@ -107,6 +109,7 @@ public MapLikeSerializer( TypeResolver typeResolver, Class cls, boolean supportCodegenHook, boolean immutable) { super(typeResolver.getConfig(), cls, immutable); this.config = typeResolver.getConfig(); + maxCollectionSize = config.maxCollectionSize(); this.typeResolver = typeResolver; trackRef = typeResolver.getConfig().trackingRef(); this.supportCodegenHook = supportCodegenHook; @@ -954,7 +957,7 @@ public void onMapWriteFinish(Map map) {} */ public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - numElements = buffer.readVarUInt32Small7(); + 
numElements = readMapSize(buffer); if (constructor == null) { constructor = ReflectionUtils.getCtrHandle(type, true); } @@ -1001,6 +1004,29 @@ public void setNumElements(int numElements) { this.numElements = numElements; } + protected final int readMapSize(MemoryBuffer buffer) { + int numElements = buffer.readVarUInt32Small7(); + checkMapSize(numElements); + return numElements; + } + + protected final void checkMapSize(int numElements) { + // Keep this as direct primitive branches. Map reads are hot enough that + // Preconditions.checkArgument would add helper/varargs overhead on the valid path. + if (numElements < 0 || numElements > maxCollectionSize) { + throwInvalidMapSize(numElements); + } + } + + private void throwInvalidMapSize(int numElements) { + if (numElements < 0) { + throw new DeserializationException("Map size must be non-negative: " + numElements); + } else { + throw new DeserializationException( + "Map size " + numElements + " exceeds max collection size " + maxCollectionSize); + } + } + public abstract T onMapCopy(Map map); public abstract T onMapRead(Map map); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java index 6ca463a727..2ddb849058 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java @@ -65,7 +65,7 @@ public HashMapSerializer(TypeResolver typeResolver) { @Override public HashMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readMapSize(buffer); setNumElements(numElements); HashMap hashMap = new HashMap(numElements); readContext.reference(hashMap); @@ -86,7 +86,7 @@ public LinkedHashMapSerializer(TypeResolver typeResolver) { @Override public LinkedHashMap 
newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readMapSize(buffer); setNumElements(numElements); LinkedHashMap hashMap = new LinkedHashMap(numElements); readContext.reference(hashMap); @@ -107,7 +107,7 @@ public LazyMapSerializer(TypeResolver typeResolver) { @Override public LazyMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readMapSize(buffer); setNumElements(numElements); LazyMap map = new LazyMap(numElements); readContext.reference(map); @@ -159,7 +159,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { public Map newMap(ReadContext readContext) { assert !config.isXlang(); MemoryBuffer buffer = readContext.getBuffer(); - setNumElements(buffer.readVarUInt32Small7()); + setNumElements(readMapSize(buffer)); T map; Comparator comparator = (Comparator) readContext.readRef(); if (type == TreeMap.class) { @@ -280,7 +280,7 @@ public ConcurrentHashMapSerializer(TypeResolver typeResolver, Class keyType = typeResolver.readTypeInfo(readContext).getType(); return new EnumMap(keyType); } @@ -520,7 +520,7 @@ public Object onMapCopy(Map map) { public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readMapSize(buffer); setNumElements(numElements); HashMap map = new HashMap<>(numElements); readContext.reference(map); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java index 56e23e6c87..036ee7110a 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java +++ 
b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java @@ -39,6 +39,7 @@ import org.apache.fory.context.CopyContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.Platform; import org.apache.fory.resolver.TypeResolver; @@ -52,10 +53,28 @@ /** Serializers for primitive list types. */ @SuppressWarnings({"rawtypes", "unchecked"}) public class PrimitiveListSerializers { + private static void throwBinarySizeLimitExceeded(long size, int maxBinarySize) { + throw new DeserializationException( + "Binary payload size " + size + " exceeds max binary size " + maxBinarySize); + } + + private static void throwNegativeBinarySize(int size) { + throw new DeserializationException("Binary payload size must be non-negative: " + size); + } + + private static void throwNegativeElementCount(int size) { + throw new DeserializationException("Primitive list size must be non-negative: " + size); + } + + private static void throwUnalignedBinarySize(int size, int elemSize) { + throw new DeserializationException( + "Binary payload size " + size + " is not aligned to element size " + elemSize); + } private abstract static class PrimitiveListSerializer extends CollectionLikeSerializer implements Shareable { private final boolean denseArrayPayload; + protected final int maxBinarySize; private PrimitiveListSerializer(TypeResolver typeResolver, Class cls) { this(typeResolver, cls, false); @@ -65,6 +84,7 @@ private PrimitiveListSerializer( TypeResolver typeResolver, Class cls, boolean denseArrayPayload) { super(typeResolver, cls, false, false); this.denseArrayPayload = denseArrayPayload; + maxBinarySize = config.maxBinarySize(); } @Override @@ -104,6 +124,9 @@ protected final void writeFixedWidthHeader(MemoryBuffer buffer, int size, int el protected final int 
readXlangListHeader(MemoryBuffer buffer) { int size = buffer.readVarUInt32Small7(); + if (size < 0) { + throwNegativeElementCount(size); + } if (config.isXlang() && size > 0) { int flags = buffer.readByte(); if (flags != CollectionFlags.DECL_SAME_TYPE_NOT_HAS_NULL) { @@ -114,21 +137,53 @@ protected final int readXlangListHeader(MemoryBuffer buffer) { } protected final int readOneByteHeader(MemoryBuffer buffer) { + int size; if (denseArrayPayload) { - return buffer.readVarUInt32Small7(); + size = buffer.readVarUInt32Small7(); + } else { + size = readXlangListHeader(buffer); + } + if (size < 0) { + throwNegativeBinarySize(size); + } + if (size > maxBinarySize) { + throwBinarySizeLimitExceeded(size, maxBinarySize); } - return readXlangListHeader(buffer); + if (size > buffer.remaining()) { + buffer.checkReadableBytes(size); + } + return size; } protected final int readFixedWidthHeader(MemoryBuffer buffer, int elemSize) { + int byteSize; if (denseArrayPayload) { - int byteSize = buffer.readVarUInt32Small7(); - return byteSize / elemSize; + byteSize = buffer.readVarUInt32Small7(); + } else if (config.isXlang()) { + int size = readXlangListHeader(buffer); + if (size > maxBinarySize / elemSize) { + throwBinarySizeLimitExceeded((long) size * elemSize, maxBinarySize); + } + byteSize = size * elemSize; + if (byteSize > buffer.remaining()) { + buffer.checkReadableBytes(byteSize); + } + return size; + } else { + byteSize = buffer.readVarUInt32Small7(); + } + if (byteSize < 0) { + throwNegativeBinarySize(byteSize); } - if (config.isXlang()) { - return readXlangListHeader(buffer); + if (byteSize % elemSize != 0) { + throwUnalignedBinarySize(byteSize, elemSize); + } + if (byteSize > maxBinarySize) { + throwBinarySizeLimitExceeded(byteSize, maxBinarySize); + } + if (byteSize > buffer.remaining()) { + buffer.checkReadableBytes(byteSize); } - int byteSize = buffer.readVarUInt32Small7(); return byteSize / elemSize; } } @@ -306,6 +361,12 @@ private void 
writeInt32Compressed(MemoryBuffer buffer, Int32List value) { private Int32List readInt32Compressed(MemoryBuffer buffer) { int size = buffer.readVarUInt32Small7(); + if (size < 0) { + throwNegativeElementCount(size); + } + if (size > maxBinarySize / 4) { + throwBinarySizeLimitExceeded((long) size * 4, maxBinarySize); + } Int32List list = new Int32List(size); for (int i = 0; i < size; i++) { list.add(buffer.readVarInt32()); @@ -395,6 +456,12 @@ private void writeInt64Compressed( private Int64List readInt64Compressed(MemoryBuffer buffer, Int64Encoding longEncoding) { int size = buffer.readVarUInt32Small7(); + if (size < 0) { + throwNegativeElementCount(size); + } + if (size > maxBinarySize / 8) { + throwBinarySizeLimitExceeded((long) size * 8, maxBinarySize); + } Int64List list = new Int64List(size); if (longEncoding == Int64Encoding.TAGGED) { for (int i = 0; i < size; i++) { @@ -551,6 +618,12 @@ private void writeUInt32Compressed(MemoryBuffer buffer, UInt32List value) { private UInt32List readUInt32Compressed(MemoryBuffer buffer) { int size = buffer.readVarUInt32Small7(); + if (size < 0) { + throwNegativeElementCount(size); + } + if (size > maxBinarySize / 4) { + throwBinarySizeLimitExceeded((long) size * 4, maxBinarySize); + } UInt32List list = new UInt32List(size); for (int i = 0; i < size; i++) { list.add(buffer.readVarInt32()); @@ -640,6 +713,12 @@ private void writeUInt64Compressed( private UInt64List readUInt64Compressed(MemoryBuffer buffer, Int64Encoding longEncoding) { int size = buffer.readVarUInt32Small7(); + if (size < 0) { + throwNegativeElementCount(size); + } + if (size > maxBinarySize / 8) { + throwBinarySizeLimitExceeded((long) size * 8, maxBinarySize); + } UInt64List list = new UInt64List(size); if (longEncoding == Int64Encoding.TAGGED) { for (int i = 0; i < size; i++) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java 
b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java index 75ec47a6ee..1af1fa26a4 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java @@ -157,7 +157,7 @@ public SubListSerializer(TypeResolver typeResolver, Class type) { @Override public Collection newCollection(ReadContext readContext) { org.apache.fory.memory.MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); + int numElements = readCollectionSize(buffer); setNumElements(numElements); return new ArrayList(numElements); } diff --git a/java/fory-core/src/test/java/org/apache/fory/memory/MemoryBufferTest.java b/java/fory-core/src/test/java/org/apache/fory/memory/MemoryBufferTest.java index 19bec9b934..cd49a1945f 100644 --- a/java/fory-core/src/test/java/org/apache/fory/memory/MemoryBufferTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/memory/MemoryBufferTest.java @@ -20,6 +20,7 @@ package org.apache.fory.memory; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; import java.nio.ByteBuffer; @@ -275,6 +276,24 @@ public void testWriteVarUInt32() { } } + @Test + public void testReadVarUInt32RejectsMalformedFifthByte() { + byte[] malformed = new byte[] {(byte) 0x80, (byte) 0x80, (byte) 0x80, (byte) 0x80, 0x10}; + assertThrows(IllegalArgumentException.class, () -> MemoryUtils.wrap(malformed).readVarUInt32()); + assertThrows( + IllegalArgumentException.class, () -> MemoryUtils.wrap(malformed).readVarUInt32Small7()); + assertThrows( + IllegalArgumentException.class, () -> MemoryUtils.wrap(malformed).readVarUInt32Small14()); + assertThrows(IllegalArgumentException.class, () -> MemoryUtils.wrap(malformed).readVarInt32()); + assertThrows( + IllegalArgumentException.class, () -> 
MemoryUtils.wrap(malformed).readBinarySize()); + + byte[] maxUInt32 = new byte[] {(byte) 0xff, (byte) 0xff, (byte) 0xff, (byte) 0xff, 0x0f}; + assertEquals(MemoryUtils.wrap(maxUInt32).readVarUInt32(), -1); + assertEquals(MemoryUtils.wrap(maxUInt32).readVarUInt32Small7(), -1); + assertEquals(MemoryUtils.wrap(maxUInt32).readVarUInt32Small14(), -1); + } + private void checkVarUInt32(MemoryBuffer buf, int value, int bytesWritten) { assertEquals(buf.writerIndex(), buf.readerIndex()); int actualBytesWritten = buf.writeVarUInt32(value); diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java index 938e632631..8a1e196950 100644 --- a/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java @@ -493,6 +493,15 @@ public void testDecodeRejectsHashConsistentMalformedTypeDefBody() { RuntimeException.class, () -> TypeDef.readTypeDef(fory.getTypeResolver(), encoded)); } + @Test + public void testSkipTypeDefRejectsExtendedSizeOverflow() { + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); + buffer.writeVarUInt32(-1); + + Assert.assertThrows( + DeserializationException.class, () -> TypeDef.skipTypeDef(buffer, TypeDef.META_SIZE_MASKS)); + } + @Test public void testDecodeRejectsRegisteredTypeDefKindMismatch() { Fory fory = Fory.builder().withXlang(true).withMetaShare(true).build(); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java index 8d9da4e805..6b49e00555 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java @@ -20,11 +20,13 @@ package org.apache.fory.serializer; import static org.testng.Assert.assertEquals; +import static 
org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; import java.lang.reflect.Array; import java.lang.reflect.Field; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; @@ -37,6 +39,10 @@ import org.apache.fory.config.Int64Encoding; import org.apache.fory.context.MetaReadContext; import org.apache.fory.context.MetaWriteContext; +import org.apache.fory.context.ReadContext; +import org.apache.fory.exception.DeserializationException; +import org.apache.fory.memory.MemoryBuffer; +import org.apache.fory.memory.MemoryUtils; import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.test.bean.ArraysData; import org.apache.fory.type.BFloat16; @@ -196,6 +202,21 @@ public void testObjectArrayCopy(Fory fory) { copyCheck(fory, new Object[] {"str", 1}); } + @Test + public void testObjectArrayReadRejectsOversizedElementCount() { + Fory fory = + Fory.builder() + .withXlang(false) + .withRefTracking(true) + .requireClassRegistration(false) + .withMaxCollectionSize(1) + .build(); + assertThrows( + DeserializationException.class, () -> readObjectArrayPayload(fory, Object[].class, 2)); + assertThrows( + DeserializationException.class, () -> readObjectArrayPayload(fory, String[].class, 2)); + } + @Test(dataProvider = "crossLanguageReferenceTrackingConfig") public void testMultiArraySerialization(boolean referenceTracking, boolean xlang) { if (xlang) { @@ -279,6 +300,104 @@ public static void testPrimitiveArray(Fory fory1, Fory fory2) { new double[] {1.0, 1.0}, (double[]) serDe(fory1, fory2, new double[] {1.0, 1.0}))); } + @Test + public void testPrimitiveArrayReadRejectsOversizedBinaryPayload() { + Fory fory = + Fory.builder() + .withMaxBinarySize(4) + .withIntArrayCompressed(true) + .withLongArrayCompressed(true) + .build(); + for (Class arrayType : + new Class[] { + boolean[].class, + byte[].class, + char[].class, + short[].class, + int[].class, + 
long[].class, + float[].class, + double[].class + }) { + assertThrows( + DeserializationException.class, + () -> readPrimitiveArrayPayload(fory, arrayType, 8, false)); + } + assertThrows( + DeserializationException.class, + () -> readPrimitiveArrayPayload(fory, byte[].class, 5, true)); + } + + @Test + public void testPrimitiveArrayReadRejectsUnalignedBinaryPayload() { + Fory fory = Fory.builder().withMaxBinarySize(64).build(); + for (Class arrayType : + new Class[] { + char[].class, short[].class, int[].class, long[].class, float[].class, double[].class + }) { + assertThrows( + DeserializationException.class, + () -> readPrimitiveArrayPayload(fory, arrayType, 3, false)); + } + } + + @Test + public void testPrimitiveArrayReadRejectsNegativeDecodedBinaryPayload() { + Fory fixedWidthFory = Fory.builder().build(); + assertThrows( + DeserializationException.class, + () -> readPrimitiveArrayRawPayload(fixedWidthFory, char[].class)); + + Fory compressedFory = + Fory.builder().withIntArrayCompressed(true).withLongArrayCompressed(true).build(); + assertThrows( + DeserializationException.class, + () -> readPrimitiveArrayRawPayload(compressedFory, int[].class)); + assertThrows( + DeserializationException.class, + () -> readPrimitiveArrayRawPayload(compressedFory, long[].class)); + } + + private static Object readPrimitiveArrayPayload( + Fory fory, Class arrayType, int byteSize, boolean outOfBand) { + ReadContext readContext = fory.getReadContext(); + if (outOfBand) { + MemoryBuffer control = MemoryBuffer.newHeapBuffer(1); + control.writeBoolean(false); + readContext.prepare( + control, Collections.singletonList(MemoryUtils.wrap(new byte[byteSize])), true); + } else { + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); + buffer.writeVarUInt32Small7(byteSize); + readContext.prepare(buffer, null, false); + } + return fory.getSerializer(arrayType).read(readContext); + } + + private static Object readPrimitiveArrayRawPayload(Fory fory, Class arrayType) { + ReadContext 
readContext = fory.getReadContext(); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); + writeNegativeDecodedVarUInt32(buffer); + readContext.prepare(buffer, null, false); + return fory.getSerializer(arrayType).read(readContext); + } + + private static Object readObjectArrayPayload(Fory fory, Class arrayType, int numElements) { + ReadContext readContext = fory.getReadContext(); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); + buffer.writeVarUInt32Small7(numElements); + readContext.prepare(buffer, null, false); + return fory.getSerializer(arrayType).read(readContext); + } + + private static void writeNegativeDecodedVarUInt32(MemoryBuffer buffer) { + buffer.writeByte(0x80); + buffer.writeByte(0x80); + buffer.writeByte(0x80); + buffer.writeByte(0x80); + buffer.writeByte(0x08); + } + @Test(dataProvider = "referenceTrackingConfig") public void testArrayZeroCopy(boolean referenceTracking) { ForyBuilder builder = diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/BufferSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/BufferSerializersTest.java index 740f0bd6f1..e5c7e4a14b 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/BufferSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/BufferSerializersTest.java @@ -73,4 +73,25 @@ public void testByteBufferRejectsMalformedPayload() { org.testng.Assert.assertThrows( DeserializationException.class, () -> readSerializer(fory, serializer, invalidOrder)); } + + @Test + public void testBufferObjectRejectsInvalidInBandSizeWithoutBinaryCap() { + Fory fory = Fory.builder().withXlang(true).build(); + Serializer serializer = + new BufferSerializers.ByteBufferSerializer(fory.getTypeResolver(), ByteBuffer.class); + + MemoryBuffer negativeSize = MemoryBuffer.newHeapBuffer(16); + negativeSize.writeBoolean(true); + negativeSize.writeVarUInt32(-1); + org.testng.Assert.assertThrows( + IllegalArgumentException.class, () -> 
readSerializer(fory, serializer, negativeSize)); + + MemoryBuffer truncated = MemoryBuffer.newHeapBuffer(16); + truncated.writeBoolean(true); + truncated.writeVarUInt32(2); + truncated.writeByte(0); + MemoryBuffer truncatedPayload = truncated.slice(0, truncated.writerIndex()); + org.testng.Assert.assertThrows( + IndexOutOfBoundsException.class, () -> readSerializer(fory, serializer, truncatedPayload)); + } } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/JdkProxySerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/JdkProxySerializerTest.java index 3d008942e7..e51f84475a 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/JdkProxySerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/JdkProxySerializerTest.java @@ -21,6 +21,7 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotSame; +import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; import java.io.ObjectStreamException; @@ -31,6 +32,7 @@ import java.util.function.Function; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; +import org.apache.fory.exception.InsecureException; import org.apache.fory.reflect.ReflectionUtils; import org.testng.annotations.Test; @@ -61,6 +63,30 @@ public void testJdkProxy(boolean referenceTracking) { assertEquals(deserializedFunction.apply(null), 1); } + @Test + public void testJdkProxyInterfaceClassHonorsTypeCheckerFalse() { + Fory writer = + Fory.builder() + .withXlang(false) + .withRefTracking(true) + .requireClassRegistration(false) + .build(); + Function function = + (Function) + Proxy.newProxyInstance( + writer.getClassLoader(), new Class[] {Function.class}, new TestInvocationHandler()); + byte[] bytes = writer.serialize(function); + + Fory reader = + Fory.builder() + .withXlang(false) + .withRefTracking(true) + .requireClassRegistration(false) + .withTypeChecker((resolver, className) -> 
!className.equals(Function.class.getName())) + .build(); + assertThrows(InsecureException.class, () -> reader.deserialize(bytes)); + } + @Test(dataProvider = "foryCopyConfig") public void testJdkProxy(Fory fory) { Function function = diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java index b466d38f72..2956fb583f 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java @@ -42,6 +42,8 @@ import org.apache.fory.collection.UInt8List; import org.apache.fory.config.ForyBuilder; import org.apache.fory.config.Int64Encoding; +import org.apache.fory.context.ReadContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -246,6 +248,65 @@ public void testPrimitiveListAsCollectionFieldWithCodegen() { assertEquals(((Int8List) roundTrip.int8Values).copyArray(), new byte[] {1, -2, 3}); } + @Test + public void testPrimitiveListReadRejectsMalformedBinaryPayloadSize() { + Fory fory = + Fory.builder() + .withMaxBinarySize(4) + .withIntArrayCompressed(true) + .withLongArrayCompressed(true) + .build(); + assertThrows( + DeserializationException.class, () -> readPrimitiveListPayload(fory, Int8List.class, 5)); + assertThrows( + DeserializationException.class, () -> readPrimitiveListPayload(fory, Int16List.class, 3)); + assertThrows( + DeserializationException.class, () -> readPrimitiveListPayload(fory, Int32List.class, 2)); + assertThrows( + DeserializationException.class, () -> readPrimitiveListPayload(fory, Int64List.class, 1)); + } + + @Test + public void testPrimitiveListReadRejectsNegativeDecodedBinaryPayload() { + Fory fixedWidthFory = Fory.builder().build(); + assertThrows( + 
DeserializationException.class, + () -> readPrimitiveListRawPayload(fixedWidthFory, Int16List.class)); + + Fory compressedFory = + Fory.builder().withIntArrayCompressed(true).withLongArrayCompressed(true).build(); + assertThrows( + DeserializationException.class, + () -> readPrimitiveListRawPayload(compressedFory, Int32List.class)); + assertThrows( + DeserializationException.class, + () -> readPrimitiveListRawPayload(compressedFory, Int64List.class)); + } + + private static Object readPrimitiveListPayload(Fory fory, Class listType, int headerSize) { + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); + buffer.writeVarUInt32Small7(headerSize); + ReadContext readContext = fory.getReadContext(); + readContext.prepare(buffer, null, false); + return fory.getSerializer(listType).read(readContext); + } + + private static Object readPrimitiveListRawPayload(Fory fory, Class listType) { + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); + writeNegativeDecodedVarUInt32(buffer); + ReadContext readContext = fory.getReadContext(); + readContext.prepare(buffer, null, false); + return fory.getSerializer(listType).read(readContext); + } + + private static void writeNegativeDecodedVarUInt32(MemoryBuffer buffer) { + buffer.writeByte(0x80); + buffer.writeByte(0x80); + buffer.writeByte(0x80); + buffer.writeByte(0x80); + buffer.writeByte(0x08); + } + @Test public void testPrimitiveListCopyTracksReferences() { Fory fory = diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java index c97ef68686..4e1adfe6c2 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java @@ -42,6 +42,7 @@ import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; import org.apache.fory.config.ForyBuilder; +import org.apache.fory.exception.DeserializationException; import 
org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.MemoryUtils; import org.testng.Assert; @@ -78,6 +79,56 @@ public void testBigInt(boolean referenceTracking) { fory1, new BigInteger("11111111110101010000283895380202208220050200000000111111111")); } + @Test + public void testBigNumberReadsRejectOversizedBinaryPayload() { + Fory fory = Fory.builder().withMaxBinarySize(1).requireClassRegistration(false).build(); + + assertThrows( + DeserializationException.class, + () -> readSerializer(fory, fory.getSerializer(BigInteger.class), bigIntegerPayload(2))); + assertThrows( + DeserializationException.class, + () -> readSerializer(fory, fory.getSerializer(BigDecimal.class), bigDecimalPayload(2))); + + Fory xlangFory = + Fory.builder().withXlang(true).withMaxBinarySize(1).requireClassRegistration(false).build(); + assertThrows( + DeserializationException.class, + () -> + readSerializer( + xlangFory, xlangFory.getSerializer(BigInteger.class), xlangDecimalPayload(2))); + assertThrows( + DeserializationException.class, + () -> + readSerializer( + xlangFory, xlangFory.getSerializer(BigDecimal.class), xlangDecimalPayload(2))); + } + + private static MemoryBuffer bigIntegerPayload(int len) { + MemoryBuffer buffer = MemoryUtils.buffer(16); + buffer.writeVarUInt32Small7(len); + buffer.writeBytes(new byte[len]); + return buffer; + } + + private static MemoryBuffer bigDecimalPayload(int len) { + MemoryBuffer buffer = MemoryUtils.buffer(16); + buffer.writeVarUInt32Small7(0); + buffer.writeVarUInt32Small7(1); + buffer.writeVarUInt32Small7(len); + buffer.writeBytes(new byte[len]); + return buffer; + } + + private static MemoryBuffer xlangDecimalPayload(int len) { + MemoryBuffer buffer = MemoryUtils.buffer(16); + buffer.writeVarInt32(0); + long meta = (long) len << 1; + buffer.writeVarUInt64((meta << 1) | 1L); + buffer.writeBytes(new byte[len]); + return buffer; + } + @Test(dataProvider = "referenceTrackingConfig") public void testXlangDecimalRoundTrip(boolean 
referenceTracking) { ForyBuilder builder = diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java index bda683d46e..8c6c96ed9d 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java @@ -67,6 +67,7 @@ import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; import org.apache.fory.context.ReadContext; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.MemoryUtils; import org.apache.fory.reflect.TypeRef; @@ -975,6 +976,74 @@ public void testSerializeJavaBlockingQueue() { } } + @Test + public void testDeserializeJavaBlockingQueueRejectsMalformedCapacity() { + Fory fory = + Fory.builder() + .withXlang(false) + .withRefTracking(true) + .requireClassRegistration(false) + .withMaxCollectionSize(4) + .build(); + CollectionSerializers.ArrayBlockingQueueSerializer arraySerializer = + new CollectionSerializers.ArrayBlockingQueueSerializer( + fory.getTypeResolver(), ArrayBlockingQueue.class); + CollectionSerializers.LinkedBlockingQueueSerializer linkedSerializer = + new CollectionSerializers.LinkedBlockingQueueSerializer( + fory.getTypeResolver(), LinkedBlockingQueue.class); + + MemoryBuffer oversizedCapacity = MemoryUtils.buffer(8); + oversizedCapacity.writeVarUInt32Small7(2); + oversizedCapacity.writeVarUInt32Small7(5); + Assert.expectThrows( + DeserializationException.class, + () -> withReadContext(fory, oversizedCapacity, arraySerializer::newCollection)); + + MemoryBuffer undersizedCapacity = MemoryUtils.buffer(8); + undersizedCapacity.writeVarUInt32Small7(2); + undersizedCapacity.writeVarUInt32Small7(1); + Assert.expectThrows( + DeserializationException.class, + () -> 
withReadContext(fory, undersizedCapacity, linkedSerializer::newCollection)); + } + + @Test + public void testCollectionReadRejectsOversizedElementCount() { + Fory fory = + Fory.builder() + .withXlang(false) + .withRefTracking(true) + .requireClassRegistration(false) + .withMaxCollectionSize(1) + .build(); + CollectionSerializers.ArrayListSerializer serializer = + new CollectionSerializers.ArrayListSerializer(fory.getTypeResolver()); + MemoryBuffer buffer = MemoryUtils.buffer(8); + buffer.writeVarUInt32Small7(2); + Assert.expectThrows( + DeserializationException.class, + () -> withReadContext(fory, buffer, serializer::newCollection)); + } + + @Test + public void testBitSetReadRejectsNegativeDecodedBinaryPayload() { + Fory fory = Fory.builder().build(); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); + writeNegativeDecodedVarUInt32(buffer); + ReadContext readContext = fory.getReadContext(); + readContext.prepare(buffer, null, false); + Assert.expectThrows( + DeserializationException.class, () -> fory.getSerializer(BitSet.class).read(readContext)); + } + + private static void writeNegativeDecodedVarUInt32(MemoryBuffer buffer) { + buffer.writeByte(0x80); + buffer.writeByte(0x80); + buffer.writeByte(0x80); + buffer.writeByte(0x80); + buffer.writeByte(0x08); + } + @Test(dataProvider = "foryCopyConfig") public void testSerializeJavaBlockingQueue(Fory fory) { { diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java index 3e6a8c2204..501ca74df9 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java @@ -57,6 +57,9 @@ import org.apache.fory.annotation.Ref; import org.apache.fory.collection.LazyMap; import org.apache.fory.collection.MapEntry; +import 
org.apache.fory.exception.DeserializationException; +import org.apache.fory.memory.MemoryBuffer; +import org.apache.fory.memory.MemoryUtils; import org.apache.fory.reflect.TypeRef; import org.apache.fory.serializer.Serializer; import org.apache.fory.serializer.collection.CollectionSerializersTest.TestEnum; @@ -555,6 +558,23 @@ public void testEmptyMap() { serDeCheckSerializer(getJavaFory(), Collections.emptySortedMap(), "EmptySortedMap"); } + @Test + public void testMapReadRejectsOversizedElementCount() { + Fory fory = + Fory.builder() + .withXlang(false) + .withRefTracking(true) + .requireClassRegistration(false) + .withMaxCollectionSize(1) + .build(); + MapSerializers.HashMapSerializer serializer = + new MapSerializers.HashMapSerializer(fory.getTypeResolver()); + MemoryBuffer buffer = MemoryUtils.buffer(8); + buffer.writeVarUInt32Small7(2); + Assert.expectThrows( + DeserializationException.class, () -> withReadContext(fory, buffer, serializer::newMap)); + } + @Test(dataProvider = "foryCopyConfig") public void testEmptyMap(Fory fory) { copyCheckWithoutSame(fory, Collections.EMPTY_MAP); diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index 94e4d2e068..53a9cc1820 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -392,10 +392,7 @@ export class ReadContext { private typeMeta: TypeMeta[] = []; /** Persistent cross-message cache keyed by 8-byte type meta header. */ - private typeMetaCache: Map< - bigint, - { readonly typeMeta: TypeMeta; readonly skipBytesAfterHeader: number } - > = new Map(); + private typeMetaCache: Map = new Map(); private _depth = 0; private _maxDepth: number; @@ -486,17 +483,14 @@ export class ReadContext { let typeMeta: TypeMeta; if (cached) { // Header-cache hits intentionally skip without rehashing. Entries reach this cache only - // after a successful TypeMeta parse and 52-bit body-hash validation. 
- this.reader.readSkip(cached.skipBytesAfterHeader); - typeMeta = cached.typeMeta; + // after a successful TypeMeta parse and 52-bit body-hash validation. The current body + // size still comes from the current header bytes, not from the cached TypeMeta. + TypeMeta.skipBody(this.reader, header); + typeMeta = cached; } else { - const bodyStart = this.reader.readGetCursor(); typeMeta = TypeMeta.fromBytesAfterHeader(this.reader, header); if (this.typeMetaCache.size < ReadContext.MAX_CACHED_TYPE_META) { - this.typeMetaCache.set(header, { - typeMeta, - skipBytesAfterHeader: this.reader.readGetCursor() - bodyStart, - }); + this.typeMetaCache.set(header, typeMeta); } } this.typeMeta[dynamicTypeId] = typeMeta; diff --git a/javascript/packages/core/lib/reader/index.ts b/javascript/packages/core/lib/reader/index.ts index 754011e9b0..daef64375d 100644 --- a/javascript/packages/core/lib/reader/index.ts +++ b/javascript/packages/core/lib/reader/index.ts @@ -86,6 +86,9 @@ export class BinaryReader { } readSkip(len: number) { + if (len < 0 || len > this.byteLength - this.cursor) { + throw new Error("Insufficient bytes to skip"); + } this.cursor += len; } @@ -188,20 +191,32 @@ export class BinaryReader { } stringUtf8At(start: number, len: number) { - return this.platformBuffer.toString("utf8", start, start + len); + if (start < 0 || len < 0 || start > this.byteLength - len) { + throw new Error("Insufficient bytes for UTF-8 string"); + } + const end = start + len; + return this.platformBuffer.toString("utf8", start, end); } stringUtf8(len: number) { - const result = this.platformBuffer.toString( - "utf8", - this.cursor, - this.cursor + len, - ); + if (len < 0 || len > this.byteLength - this.cursor) { + throw new Error("Insufficient bytes for UTF-8 string"); + } + const end = this.cursor + len; + // JavaScript intentionally preserves platform UTF-8 replacement behavior; Rust is the runtime + // that provides checked UTF-8 string reads by default. 
+ const result = this.platformBuffer.toString("utf8", this.cursor, end); this.cursor += len; return result; } stringUtf16LE(len: number) { + if (len < 0 || len > this.byteLength - this.cursor) { + throw new Error("Insufficient bytes for UTF-16LE string"); + } + if ((len & 1) !== 0) { + throw new Error("UTF-16LE string length must be even"); + } const result = this.platformBuffer.toString( "utf16le", this.cursor, @@ -223,11 +238,14 @@ export class BinaryReader { case UTF16: return this.stringUtf16LE(len); default: - break; + throw new Error(`Unsupported string encoding: ${type}`); } } stringLatin1(len: number) { + if (len < 0 || len > this.byteLength - this.cursor) { + throw new Error("Insufficient bytes for Latin1 string"); + } if (this.sliceStringEnable) { return this.stringLatin1Fast(len); } @@ -247,6 +265,9 @@ export class BinaryReader { } buffer(len: number) { + if (len < 0 || len > this.byteLength - this.cursor) { + throw new Error("Insufficient bytes for buffer"); + } const result = alloc(len); this.platformBuffer.copy(result, 0, this.cursor, this.cursor + len); this.cursor += len; @@ -254,12 +275,18 @@ export class BinaryReader { } bufferRef(len: number) { + if (len < 0 || len > this.byteLength - this.cursor) { + throw new Error("Insufficient bytes for buffer reference"); + } const result = this.platformBuffer.subarray(this.cursor, this.cursor + len); this.cursor += len; return result; } bufferRefAt(start: number, len: number) { + if (start < 0 || len < 0 || start > this.byteLength - len) { + throw new Error("Insufficient bytes for buffer reference"); + } return this.platformBuffer.subarray(start, start + len); } diff --git a/javascript/test/hps.test.ts b/javascript/test/hps.test.ts index 87a52a8c00..189210ae1c 100644 --- a/javascript/test/hps.test.ts +++ b/javascript/test/hps.test.ts @@ -25,7 +25,7 @@ import { describe, expect, test } from '@jest/globals'; const skipableDescribe = (hps ? 
describe : describe.skip); skipableDescribe('hps', () => { - test.only('should isLatin1 work', () => { + test('should isLatin1 work', () => { const { serializeString } = hps!; for (let index = 0; index < 10000; index++) { const bf = Buffer.alloc(100); diff --git a/javascript/test/io.test.ts b/javascript/test/io.test.ts index 3b76ecc18d..32049120c9 100644 --- a/javascript/test/io.test.ts +++ b/javascript/test/io.test.ts @@ -19,7 +19,7 @@ import { fromUint8Array } from '../packages/core/lib/platformBuffer'; import { BinaryReader } from '../packages/core/lib/reader'; -import { Config, RefFlags } from '../packages/core/lib/type'; +import { Config, RefFlags, UTF8, UTF16 } from '../packages/core/lib/type'; import { BinaryWriter } from '../packages/core/lib/writer'; import { describe, expect, test } from '@jest/globals'; @@ -254,6 +254,19 @@ function num2Bin(num: number) { } }); + test('should reject malformed string payloads', () => { + const reader = new BinaryReader(config); + + reader.reset(new Uint8Array([(4 << 2) | UTF8, 0x61])); + expect(() => reader.stringWithHeader()).toThrow(/Insufficient bytes for UTF-8 string/); + + reader.reset(new Uint8Array([(1 << 2) | UTF16, 0])); + expect(() => reader.stringWithHeader()).toThrow(/UTF-16LE string length must be even/); + + reader.reset(new Uint8Array([3])); + expect(() => reader.stringWithHeader()).toThrow(/Unsupported string encoding/); + }); + test('should buffer work', () => { const writer = new BinaryWriter(config); writer.buffer(new Uint8Array([1, 2, 3, 4, 5])); @@ -268,6 +281,14 @@ function num2Bin(num: number) { expect(ab[4]).toBe(5); }); + test('should reject truncated buffer payloads', () => { + const reader = new BinaryReader(config); + reader.reset(new Uint8Array([1, 2])); + expect(() => reader.buffer(3)).toThrow(/Insufficient bytes for buffer/); + expect(() => reader.bufferRef(3)).toThrow(/Insufficient bytes for buffer reference/); + expect(() => reader.bufferRefAt(1, 2)).toThrow(/Insufficient bytes for buffer 
reference/); + }); + test('should bufferWithoutMemCheck work', () => { const writer = new BinaryWriter(config); writer.bufferWithoutMemCheck(fromUint8Array(new Uint8Array([1, 2, 3, 4, 5])), 5); diff --git a/javascript/test/typemeta.test.ts b/javascript/test/typemeta.test.ts index f5040f7572..473ec38418 100644 --- a/javascript/test/typemeta.test.ts +++ b/javascript/test/typemeta.test.ts @@ -18,8 +18,10 @@ */ import Fory, { Type } from "../packages/core/index"; +import { ReadContext } from "../packages/core/lib/context"; import { TypeMeta } from "../packages/core/lib/meta/TypeMeta"; import { BinaryReader } from "../packages/core/lib/reader"; +import { BinaryWriter } from "../packages/core/lib/writer"; import { describe, expect, test } from "@jest/globals"; const COMPRESS_META_FLAG = 1n << 8n; @@ -101,6 +103,39 @@ describe("typemeta", () => { expect(skipReader.readGetCursor()).toBe(bytes.length); }); + test("TypeMeta header cache hit skips the current body size", () => { + const header = 0xffn; + const typeMeta = TypeMeta.fromTypeInfo(Type.struct(7010, {})); + const writer = new BinaryWriter({}); + writer.writeVarUInt32(0); + writer.writeUint64(header); + writer.writeVarUInt32(0); + writer.buffer(new Uint8Array(0xff)); + writer.buffer(new Uint8Array([0x7b])); + + const config = { ref: false, useSliceString: false, hooks: {} } as any; + const context = new ReadContext( + { + config, + trackingRef: false, + computeTypeId: (typeInfo: any) => typeInfo.typeId, + getSerializerById: () => undefined, + getSerializerByName: () => undefined, + getSerializerByData: () => undefined, + isCompatible: () => false, + regenerateReadSerializer: () => { + throw new Error("unused"); + }, + } as any, + config, + ); + (context as any).typeMetaCache.set(header, typeMeta); + context.reset(writer.dump()); + + expect(context.readTypeMeta()).toBe(typeMeta); + expect(context.reader.readUint8()).toBe(0x7b); + }); + test("encodes extended id-registered struct field counts without the name bit", () 
=> { const fields: Record = {}; for (let i = 0; i < 32; i++) { diff --git a/rust/fory-core/src/buffer.rs b/rust/fory-core/src/buffer.rs index cc7b389eb1..33a96d9344 100644 --- a/rust/fory-core/src/buffer.rs +++ b/rust/fory-core/src/buffer.rs @@ -952,6 +952,8 @@ impl<'a> Reader<'a> { pub fn read_utf8_string(&mut self, len: usize) -> Result { self.check_bound(len)?; let src = &self.bf[self.cursor..self.cursor + len]; + // Rust is the only runtime that checks UTF-8 string payloads by default; other runtimes + // preserve their platform replacement behavior for invalid byte sequences. let string = std::str::from_utf8(src).map_err(|_| Error::encoding_error("invalid UTF-8 string"))?; let string = string.to_owned(); diff --git a/rust/fory-core/src/meta/meta_string.rs b/rust/fory-core/src/meta/meta_string.rs index 962ac4b437..53a62d8abe 100644 --- a/rust/fory-core/src/meta/meta_string.rs +++ b/rust/fory-core/src/meta/meta_string.rs @@ -67,6 +67,22 @@ impl std::hash::Hash for MetaString { } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rejects_invalid_utf8_meta_string() { + let err = TYPE_NAME_DECODER + .decode(&[0xff], Encoding::Utf8) + .unwrap_err(); + assert!( + err.to_string().contains("invalid UTF-8 meta string"), + "unexpected error: {err}" + ); + } +} + static EMPTY: OnceLock = OnceLock::new(); impl MetaString { @@ -478,7 +494,9 @@ impl MetaStringDecoder { Encoding::AllToLowerSpecial => { self.decode_rep_all_to_lower_special(encoded_data) } - Encoding::Utf8 => Ok(String::from_utf8_lossy(encoded_data).into_owned()), + Encoding::Utf8 => std::str::from_utf8(encoded_data) + .map(str::to_owned) + .map_err(|_| Error::encoding_error("invalid UTF-8 meta string")), } } }?; diff --git a/rust/fory-core/src/meta/type_meta.rs b/rust/fory-core/src/meta/type_meta.rs index 4338c758fb..7b3b8ef378 100644 --- a/rust/fory-core/src/meta/type_meta.rs +++ b/rust/fory-core/src/meta/type_meta.rs @@ -1096,6 +1096,9 @@ impl TypeMeta { reader: &mut Reader, header: i64, ) -> 
Result<(), Error> { + // Header-cache hits intentionally treat the current body as opaque bytes and skip by the + // current header size. Parsed TypeMeta entries are cached only after body parse and hash + // validation; cache hits must not reparse or rehash that body. let mut meta_size = (header & META_SIZE_MASK) as usize; if meta_size == META_SIZE_MASK as usize { meta_size += reader.read_var_u32()? as usize; diff --git a/rust/fory-core/src/util/string_util.rs b/rust/fory-core/src/util/string_util.rs index 26bb768ca3..5e0d1ab35c 100644 --- a/rust/fory-core/src/util/string_util.rs +++ b/rust/fory-core/src/util/string_util.rs @@ -771,18 +771,14 @@ pub mod buffer_rw_string { #[inline] pub fn read_utf8_standard(reader: &mut Reader, len: usize) -> Result { - unsafe { - let mut vec = Vec::with_capacity(len); - let src = reader.bf.as_ptr().add(reader.cursor); - let dst = vec.as_mut_ptr(); - // Use fastest possible copy - copy_nonoverlapping compiles to memcpy - std::ptr::copy_nonoverlapping(src, dst, len); - vec.set_len(len); - reader.move_next(len); - // Use from_utf8_lossy for safety - handles invalid UTF-8 gracefully - // If you're certain the data is valid UTF-8, use from_utf8_unchecked for more performance - Ok(String::from_utf8_lossy(&vec).into_owned()) - } + let slice = reader.sub_slice(reader.get_cursor(), reader.get_cursor() + len)?; + // Rust is the only runtime that checks UTF-8 string payloads by default; borrow first so + // the check adds no temporary Vec before constructing the final String. + let value = std::str::from_utf8(slice) + .map_err(|_| Error::encoding_error("invalid UTF-8 string"))? + .to_owned(); + reader.move_next(len); + Ok(value) } #[inline]