().having(
+ (error) => error.toString(),
+ 'message',
+ contains('metadata hash'),
+ ),
+ ),
+ );
+ });
});
}
diff --git a/docs/specification/java_serialization_spec.md b/docs/specification/java_serialization_spec.md
index 2bac7bc0da..c52695f1cc 100644
--- a/docs/specification/java_serialization_spec.md
+++ b/docs/specification/java_serialization_spec.md
@@ -44,14 +44,14 @@ Java native serialization writes a one byte bitmap header. The header layout mir
bitmap and uses the same flag bits.
```
-| 5 bits | 1 bit | 1 bit | 1 bit |
-+--------------+-------+-------+-------+
-| reserved | oob | xlang | null |
+| 6 bits | 1 bit | 1 bit |
++---------------+-------+-------+
+| reserved | oob | xlang |
```
-- null flag: 1 when object is null, 0 otherwise. If object is null, other bits are not set.
-- xlang flag: 1 when serialization uses xlang format, 0 when serialization uses Java native format.
-- oob flag: 1 when `BufferCallback` is not null, 0 otherwise.
+- xlang flag: bit 0, set when serialization uses xlang format and clear for Java native format.
+- oob flag: bit 1, set when `BufferCallback` is not null.
+- reserved bits: bits 2-7, must be zero.
The header is always a single byte; no language ID is written.
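The revised bitmap above (xlang in bit 0, oob in bit 1, bits 2-7 reserved) can be sketched as a small encoder/validator. This is an illustrative standalone sketch, not Fory's actual API; the constant names are assumptions taken from the spec text.

```go
package main

import "fmt"

// Flag bits per the revised one-byte header layout described above.
const (
	xlangFlag byte = 1 << 0
	oobFlag   byte = 1 << 1
	flagMask  byte = xlangFlag | oobFlag
)

// buildHeader packs the bitmap; reserved bits 2-7 stay zero.
func buildHeader(xlang, oob bool) byte {
	var b byte
	if xlang {
		b |= xlangFlag
	}
	if oob {
		b |= oobFlag
	}
	return b
}

// validateHeader rejects any set reserved bit, as the spec requires.
func validateHeader(b byte) error {
	if b&^flagMask != 0 {
		return fmt.Errorf("unsupported header bitmap 0x%02x", b)
	}
	return nil
}

func main() {
	fmt.Println(buildHeader(true, true))
	fmt.Println(validateHeader(0x80))
}
```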
@@ -202,32 +202,37 @@ when shared meta is enabled, or referenced by index when already seen.
Header layout (lower bits on the right):
```
-| 50-bit hash | 4 bits reserved | 1 bit compress | 1 bit has_fields_meta | 8-bit size |
+| 52-bit hash | 3 bits reserved | 1 bit compress | 8-bit size |
```
- size: lower 8 bits. If size equals the mask (0xFF), write extra size as varuint32 and add it.
-- compress: set when payload is compressed.
-- has_fields_meta: set when field metadata is present.
-- reserved: bits 10-13 are reserved for future use and must be zero.
-- hash: 50-bit hash of the payload and flags.
+- compress: bit 8, set when payload is compressed.
+- reserved: bits 9-11 are reserved for future use and must be zero.
+- hash: 52-bit hash of the payload.
### Class meta bytes
Class meta encodes a linearized class hierarchy (from parent to leaf) and field metadata:
```
-| num_classes | class_layer_0 | class_layer_1 | ... |
+| root_kind_and_num_classes | class_layer_0 | class_layer_1 | ... |
class_layer:
| num_fields << 1 | registered_flag | [type_id if registered] |
| namespace | type_name | field_infos |
```
-- `num_classes` stores `(num_layers - 1)` in a single byte.
- - If it equals `0b1111`, read an extra varuint32 small7 and add it.
+- `root_kind_and_num_classes` stores the root TypeDef kind in the high four bits and
+ `(num_layers - 1)` in the low four bits.
+ - Root kind codes are `STRUCT=0`, `COMPATIBLE_STRUCT=1`, `NAMED_STRUCT=2`,
+ `NAMED_COMPATIBLE_STRUCT=3`, `ENUM=4`, `NAMED_ENUM=5`, `EXT=6`, `NAMED_EXT=7`,
+ `TYPED_UNION=8`, and `NAMED_UNION=9`.
+  - Kind codes `10-14` are reserved; `15` is an extended-kind escape that readers must reject until it is defined.
+ - If the low four bits equal `0b1111`, read an extra varuint32 small7 and add it.
- The actual number of layers is `num_classes + 1`.
- `registered_flag` is 1 if the class is registered by numeric ID.
-- If registered by ID, the class type ID follows (varuint32 small7).
+- If registered by ID, the one-byte class type ID follows. For user-registered ID kinds, the
+ user type ID follows as varuint32.
- If registered by name or unregistered, namespace and type name are written as meta strings.
### Field info
diff --git a/docs/specification/xlang_serialization_spec.md b/docs/specification/xlang_serialization_spec.md
index 24adbe3331..b1743918a9 100644
--- a/docs/specification/xlang_serialization_spec.md
+++ b/docs/specification/xlang_serialization_spec.md
@@ -323,15 +323,14 @@ Detailed byte layout:
```
Byte 0: Bitmap flags
- - Bit 0: null flag (0x01)
- - Bit 1: xlang flag (0x02)
- - Bit 2: oob flag (0x04)
- - Bits 3-7: reserved
+ - Bit 0: xlang flag (0x01)
+ - Bit 1: oob flag (0x02)
+ - Bits 2-7: reserved
```
-- **null flag** (bit 0): 1 when object is null, 0 otherwise. If an object is null, only this flag is set.
-- **xlang flag** (bit 1): 1 when serialization uses Fory xlang format, 0 when serialization uses Fory language-native format.
-- **oob flag** (bit 2): 1 when out-of-band serialization is enabled (BufferCallback is not null), 0 otherwise.
+- **xlang flag** (bit 0): 1 when serialization uses Fory xlang format, 0 when serialization uses Fory language-native format.
+- **oob flag** (bit 1): 1 when out-of-band serialization is enabled (BufferCallback is not null), 0 otherwise.
+- **reserved bits** (bits 2-7): must be zero.
All data is encoded in little-endian format.
@@ -536,12 +535,11 @@ The 8-byte header is a little-endian uint64:
- Low 8 bits: meta size (number of bytes in the TypeDef body).
- If meta size >= 0xFF, the low 8 bits are set to 0xFF and an extra
`varuint32(meta_size - 0xFF)` follows immediately after the header.
-- Bit 8: `HAS_FIELDS_META` (1 = fields metadata present).
-- Bit 9: `COMPRESS_META` is reserved for a future xlang metadata-compression extension.
+- Bit 8: `COMPRESS_META` is reserved for a future xlang metadata-compression extension.
Current xlang writers MUST leave this bit unset and current xlang readers MUST treat a set bit
as unsupported.
-- Bits 10-13: reserved for future extension (must be zero).
-- High 50 bits: hash of the TypeDef body.
+- Bits 9-11: reserved for future extension (must be zero).
+- High 52 bits: hash of the TypeDef body.
#### TypeDef body
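The revised 8-byte header split (meta size in the low 8 bits, compress at bit 8, bits 9-11 reserved, hash in the high 52 bits) can be sketched as a field extractor. A minimal sketch assuming the bit positions stated above; it omits the extra-size varuint32 follow-up.

```go
package main

import "fmt"

// splitTypeDefHeader decomposes the little-endian uint64 TypeDef header
// into its fields per the layout above.
func splitTypeDefHeader(h uint64) (metaSize uint8, compress bool, hash uint64) {
	metaSize = uint8(h & 0xFF) // low 8 bits; 0xFF signals an extra varuint32 size
	compress = h&(1<<8) != 0   // bit 8: COMPRESS_META
	hash = h >> 12             // high 52 bits
	return
}

func main() {
	h := uint64(0xABC)<<12 | 1<<8 | 0x2A
	size, compress, hash := splitTypeDefHeader(h)
	fmt.Println(size, compress, hash)
}
```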
@@ -551,12 +549,30 @@ TypeDef body has a single layer (fields are flattened in class hierarchy order):
| meta header (1 byte) | type spec | field info ... |
```
-Meta header byte:
+Meta header byte for struct TypeDefs:
+- Bit 7: `IS_STRUCT` (1).
+- Bit 6: `COMPATIBLE`.
+- Bit 5: `REGISTER_BY_NAME` (1 = namespace + type name, 0 = numeric user type ID).
- Bits 0-4: `num_fields` (0-30).
- If `num_fields == 31`, read an extra `varuint32` and add it.
-- Bit 5: `REGISTER_BY_NAME` (1 = namespace + type name, 0 = numeric type ID).
-- Bits 6-7: reserved.
+
+Meta header byte for non-struct TypeDefs:
+
+- Bit 7: `IS_STRUCT` (0).
+- Bits 4-6: reserved (must be zero).
+- Bits 0-3: kind code.
+
+Non-struct kind codes:
+
+- `0`: `ENUM`
+- `1`: `NAMED_ENUM`
+- `2`: `EXT`
+- `3`: `NAMED_EXT`
+- `4`: `TYPED_UNION`
+- `5`: `NAMED_UNION`
+- `6-14`: reserved
+- `15`: extended-kind escape, rejected until defined
Type spec:
@@ -564,7 +580,7 @@ Type spec:
- `namespace` meta string
- `type_name` meta string
- Otherwise:
- - `type_id` as `varuint32` (small7)
+ - user type ID as `varuint32`
Field info list:
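The struct/non-struct meta header byte described above can be sketched as a decoder. This is an illustrative sketch of the bit layout only; the num_fields extension (extra varuint32 when the low bits are 31) and the extended-kind escape are left out.

```go
package main

import "fmt"

// decodeMetaHeader interprets the one-byte meta header per the layout above.
func decodeMetaHeader(b byte) string {
	if b&0x80 != 0 { // bit 7: IS_STRUCT
		numFields := b & 0x1F     // bits 0-4
		byName := b&0x20 != 0     // bit 5: REGISTER_BY_NAME
		compatible := b&0x40 != 0 // bit 6: COMPATIBLE
		return fmt.Sprintf("struct fields=%d byName=%v compatible=%v",
			numFields, byName, compatible)
	}
	kind := b & 0x0F // bits 0-3: non-struct kind code
	return fmt.Sprintf("non-struct kind=%d", kind)
}

func main() {
	fmt.Println(decodeMetaHeader(0xA3)) // register-by-name struct, 3 fields
	fmt.Println(decodeMetaHeader(0x01)) // kind code 1 = NAMED_ENUM
}
```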
diff --git a/go/fory/buffer.go b/go/fory/buffer.go
index c51071fd14..2e0f13655e 100644
--- a/go/fory/buffer.go
+++ b/go/fory/buffer.go
@@ -1224,8 +1224,8 @@ func (b *ByteBuffer) ReadVarint32(err *Error) int32 {
// UnsafeReadVarint32 reads a varint32 without bounds checking.
// Caller must ensure remaining() >= 5 before calling.
-func (b *ByteBuffer) UnsafeReadVarint32() int32 {
- u := b.readVarUint32Fast()
+func (b *ByteBuffer) UnsafeReadVarint32(err *Error) int32 {
+ u := b.readVarUint32Fast(err)
v := int32(u >> 1)
if u&1 != 0 {
v = ^v
@@ -1246,8 +1246,8 @@ func (b *ByteBuffer) UnsafeReadVarint64() int64 {
// UnsafeReadVarUint32 reads a VarUint32 without bounds checking.
// Caller must ensure remaining() >= 5 before calling.
-func (b *ByteBuffer) UnsafeReadVarUint32() uint32 {
- return b.readVarUint32Fast()
+func (b *ByteBuffer) UnsafeReadVarUint32(err *Error) uint32 {
+ return b.readVarUint32Fast(err)
}
// UnsafeReadVarUint64 reads a VarUint64 without bounds checking.
@@ -1259,13 +1259,13 @@ func (b *ByteBuffer) UnsafeReadVarUint64() uint64 {
// ReadVarUint32 reads a VarUint32 and sets error on bounds violation
func (b *ByteBuffer) ReadVarUint32(err *Error) uint32 {
if b.remaining() >= 8 { // Need 8 bytes for bulk uint64 read in fast path
- return b.readVarUint32Fast()
+ return b.readVarUint32Fast(err)
}
return b.readVarUint32Slow(err)
}
// Fast path reading (when the remaining bytes are sufficient)
-func (b *ByteBuffer) readVarUint32Fast() uint32 {
+func (b *ByteBuffer) readVarUint32Fast(err *Error) uint32 {
// Single instruction load using unsafe pointer cast (little-endian only)
// On big-endian systems, use binary.LittleEndian which the compiler optimizes
var bulk uint64
@@ -1288,6 +1288,13 @@ func (b *ByteBuffer) readVarUint32Fast() uint32 {
result |= uint32((bulk >> 3) & 0xFE00000)
readLength = 4
if (bulk & 0x80000000) != 0 {
+ fifth := byte(bulk >> 32)
+ if fifth > 0x0F {
+ if err != nil {
+ *err = DeserializationError("VarUint32 overflow")
+ }
+ return 0
+ }
result |= uint32((bulk >> 4) & 0xF0000000)
readLength = 5
}
@@ -1310,6 +1317,12 @@ func (b *ByteBuffer) readVarUint32Slow(err *Error) uint32 {
}
byteVal := b.data[b.readerIndex]
b.readerIndex++
+ if shift == 28 && byteVal > 0x0F {
+ if err != nil {
+ *err = DeserializationError("VarUint32 overflow")
+ }
+ return 0
+ }
result |= (uint32(byteVal) & 0x7F) << shift
if byteVal < 0x80 {
break
@@ -1475,16 +1488,16 @@ func (b *ByteBuffer) readVarUint32Small14(err *Error) uint32 {
readIdx++
value |= (four >> 1) & 0x3f80
if four&0x8000 != 0 {
- return b.continueReadVarUint32(readIdx, four, value)
+ return b.continueReadVarUint32(readIdx, four, value, err)
}
}
b.readerIndex = readIdx
return value
}
- return uint32(b.readVaruint36Slow(err))
+ return b.readVarUint32Slow(err)
}
-func (b *ByteBuffer) continueReadVarUint32(readIdx int, bulkRead, value uint32) uint32 {
+func (b *ByteBuffer) continueReadVarUint32(readIdx int, bulkRead, value uint32, err *Error) uint32 {
readIdx++
value |= (bulkRead >> 2) & 0x1fc000
if bulkRead&0x800000 != 0 {
@@ -1492,6 +1505,12 @@ func (b *ByteBuffer) continueReadVarUint32(readIdx int, bulkRead, value uint32)
value |= (bulkRead >> 3) & 0xfe00000
if bulkRead&0x80000000 != 0 {
v := b.data[readIdx]
+ if v > 0x0F {
+ if err != nil {
+ *err = DeserializationError("VarUint32 overflow")
+ }
+ return 0
+ }
readIdx++
value |= uint32(v&0x7F) << 28
}
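The overflow guard added in both varint paths above rests on one fact: after four continuation bytes (28 bits consumed), only 4 bits of a uint32 remain, so a fifth byte above 0x0F cannot fit. A minimal slow-path sketch of that rule, independent of Fory's ByteBuffer:

```go
package main

import (
	"errors"
	"fmt"
)

// readVarUint32 decodes a LEB128-style varuint32 and rejects a fifth byte
// whose high nibble would overflow 32 bits, mirroring the check above.
func readVarUint32(data []byte) (uint32, int, error) {
	var result uint32
	for i, shift := 0, 0; i < len(data) && shift <= 28; i, shift = i+1, shift+7 {
		b := data[i]
		if shift == 28 && b > 0x0F {
			return 0, 0, errors.New("VarUint32 overflow")
		}
		result |= uint32(b&0x7F) << shift
		if b < 0x80 { // high bit clear: last byte
			return result, i + 1, nil
		}
	}
	return 0, 0, errors.New("truncated VarUint32")
}

func main() {
	v, n, err := readVarUint32([]byte{0x80, 0x01})
	fmt.Println(v, n, err)
	_, _, err = readVarUint32([]byte{0x80, 0x80, 0x80, 0x80, 0x10})
	fmt.Println(err)
}
```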
diff --git a/go/fory/buffer_test.go b/go/fory/buffer_test.go
index a65d49a7a9..b4f2022389 100644
--- a/go/fory/buffer_test.go
+++ b/go/fory/buffer_test.go
@@ -18,6 +18,7 @@
package fory
import (
+ "bytes"
"testing"
"github.com/stretchr/testify/require"
@@ -111,3 +112,29 @@ func TestUnsafePutVarUint32PhysicalWriteWidth(t *testing.T) {
"byte at index %d is outside the 8-byte reserved window and must not be written", i)
}
}
+
+func TestReadVarUint32RejectsOverflowFifthByte(t *testing.T) {
+ for _, data := range [][]byte{
+ {0x80, 0x80, 0x80, 0x80, 0x10},
+ {0x80, 0x80, 0x80, 0x80, 0x10, 0, 0, 0},
+ } {
+ buf := NewByteBuffer(data)
+ var err Error
+ _ = buf.ReadVarUint32(&err)
+ require.True(t, err.HasError(), "expected overflow error for %v", data)
+ }
+}
+
+func TestReadVarUint32Small7RejectsOverflowFifthByte(t *testing.T) {
+ buf := NewByteBuffer([]byte{0x80, 0x80, 0x80, 0x80, 0x10})
+ var err Error
+ _ = buf.ReadVarUint32Small7(&err)
+ require.True(t, err.HasError())
+}
+
+func TestReadVarUint32Small7StreamRejectsOverflowFifthByte(t *testing.T) {
+ buf := NewByteBufferFromReader(bytes.NewReader([]byte{0x80, 0x80, 0x80, 0x80, 0x10}), 4)
+ var err Error
+ _ = buf.ReadVarUint32Small7(&err)
+ require.True(t, err.HasError())
+}
diff --git a/go/fory/decimal.go b/go/fory/decimal.go
index 92d6e15bea..5bc3a0d540 100644
--- a/go/fory/decimal.go
+++ b/go/fory/decimal.go
@@ -77,8 +77,8 @@ func (s decimalSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(DECIMAL) {
+ return
}
if ctx.HasError() {
return
diff --git a/go/fory/enum.go b/go/fory/enum.go
index 1a1e63c617..b8f615b977 100644
--- a/go/fory/enum.go
+++ b/go/fory/enum.go
@@ -85,7 +85,27 @@ func (s *enumSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool,
}
}
if readType {
- _ = ctx.buffer.ReadUint8(err)
+ typeID := uint32(ctx.buffer.ReadUint8(err))
+ if ctx.HasError() {
+ return
+ }
+ internalID := TypeId(typeID)
+ if internalID != ENUM && internalID != NAMED_ENUM {
+ ctx.SetError(TypeMismatchError(internalID, ENUM))
+ return
+ }
+ typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, err)
+ if ctx.HasError() {
+ return
+ }
+ if typeInfo == nil || typeInfo.Type != s.type_ {
+ var actualType reflect.Type
+ if typeInfo != nil {
+ actualType = typeInfo.Type
+ }
+ ctx.SetError(DeserializationErrorf("enum type mismatch: expected %v, got %v", s.type_, actualType))
+ return
+ }
}
if ctx.HasError() {
return
diff --git a/go/fory/enum_test.go b/go/fory/enum_test.go
new file mode 100644
index 0000000000..f92c4d086d
--- /dev/null
+++ b/go/fory/enum_test.go
@@ -0,0 +1,98 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package fory
+
+import (
+ "reflect"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+type auditEnum int32
+type otherAuditEnum int32
+type namedAuditEnum int32
+
+func TestEnumReadConsumesRegisteredTypeInfo(t *testing.T) {
+ f := NewFory(WithXlang(true))
+ require.NoError(t, f.RegisterEnum(auditEnum(0), 101))
+ enumType := reflect.TypeOf(auditEnum(0))
+ serializer, err := f.typeResolver.getSerializerByType(enumType, false)
+ require.NoError(t, err)
+ typeInfo := f.typeResolver.typesInfo[enumType]
+ require.NotNil(t, typeInfo)
+
+ buf := NewByteBuffer(nil)
+ bufErr := &Error{}
+ f.typeResolver.WriteTypeInfo(buf, typeInfo, bufErr)
+ require.NoError(t, bufErr.CheckError())
+ buf.WriteVarUint32Small7(2)
+
+ f.readCtx.SetData(buf.Bytes())
+ var result auditEnum
+ serializer.Read(f.readCtx, RefModeNone, true, false, reflect.ValueOf(&result).Elem())
+ require.NoError(t, f.readCtx.CheckError())
+ require.Equal(t, auditEnum(2), result)
+ require.Equal(t, buf.WriterIndex(), f.readCtx.Buffer().ReaderIndex())
+}
+
+func TestEnumReadRejectsMismatchedRegisteredTypeInfo(t *testing.T) {
+ f := NewFory(WithXlang(true))
+ require.NoError(t, f.RegisterEnum(auditEnum(0), 101))
+ require.NoError(t, f.RegisterEnum(otherAuditEnum(0), 102))
+ enumType := reflect.TypeOf(auditEnum(0))
+ otherType := reflect.TypeOf(otherAuditEnum(0))
+ serializer, err := f.typeResolver.getSerializerByType(enumType, false)
+ require.NoError(t, err)
+ otherTypeInfo := f.typeResolver.typesInfo[otherType]
+ require.NotNil(t, otherTypeInfo)
+
+ buf := NewByteBuffer(nil)
+ bufErr := &Error{}
+ f.typeResolver.WriteTypeInfo(buf, otherTypeInfo, bufErr)
+ require.NoError(t, bufErr.CheckError())
+ buf.WriteVarUint32Small7(2)
+
+ f.readCtx.SetData(buf.Bytes())
+ var result auditEnum
+ serializer.Read(f.readCtx, RefModeNone, true, false, reflect.ValueOf(&result).Elem())
+ require.Error(t, f.readCtx.CheckError())
+}
+
+func TestNamedEnumReadConsumesNamedTypeInfo(t *testing.T) {
+ f := NewFory(WithXlang(true))
+ require.NoError(t, f.RegisterNamedEnum(namedAuditEnum(0), "example.NamedAuditEnum"))
+ enumType := reflect.TypeOf(namedAuditEnum(0))
+ serializer, err := f.typeResolver.getSerializerByType(enumType, false)
+ require.NoError(t, err)
+ typeInfo := f.typeResolver.typesInfo[enumType]
+ require.NotNil(t, typeInfo)
+
+ buf := NewByteBuffer(nil)
+ bufErr := &Error{}
+ f.typeResolver.WriteTypeInfo(buf, typeInfo, bufErr)
+ require.NoError(t, bufErr.CheckError())
+ buf.WriteVarUint32Small7(3)
+
+ f.readCtx.SetData(buf.Bytes())
+ var result namedAuditEnum
+ serializer.Read(f.readCtx, RefModeNone, true, false, reflect.ValueOf(&result).Elem())
+ require.NoError(t, f.readCtx.CheckError())
+ require.Equal(t, namedAuditEnum(3), result)
+ require.Equal(t, buf.WriterIndex(), f.readCtx.Buffer().ReaderIndex())
+}
diff --git a/go/fory/fory.go b/go/fory/fory.go
index 61723d190d..f50fe1f770 100644
--- a/go/fory/fory.go
+++ b/go/fory/fory.go
@@ -39,9 +39,9 @@ var ErrNoSerializer = errors.New("fory: no serializer registered for type")
// Bitmap flags for protocol header
const (
- IsNilFlag = 1 << 0
- XLangFlag = 1 << 1
- OutOfBandFlag = 1 << 2
+ XLangFlag = 1 << 0
+ OutOfBandFlag = 1 << 1
+ headerFlagMask = XLangFlag | OutOfBandFlag
)
// ============================================================================
@@ -185,6 +185,9 @@ func New(opts ...Option) *Fory {
f.readCtx.refResolver = f.refResolver
f.readCtx.compatible = f.config.Compatible
f.readCtx.xlang = f.config.IsXlang
+ if f.config.IsXlang {
+ f.readCtx.rootHeader = XLangFlag
+ }
return f
}
@@ -487,13 +490,6 @@ func (f *Fory) Reset() {
// For thread-safe usage, use threadsafe.Fory which copies the data internally.
func (f *Fory) Serialize(value any) ([]byte, error) {
defer f.resetWriteState()
- // Check if value is nil interface OR a nil pointer/slice/map/etc.
- // In Go, `*int32(nil)` wrapped in `any` is NOT equal to `nil`, but we need to serialize it as null.
- if isNilValue(value) {
- // Use Java-compatible null format: 3 bytes (magic + bitmap with isNilFlag)
- writeNullHeader(f.writeCtx)
- return f.writeCtx.buffer.GetByteSlice(0, f.writeCtx.buffer.writerIndex), nil
- }
// WriteData protocol header
writeHeader(f.writeCtx, f.config)
@@ -521,16 +517,11 @@ func (f *Fory) Deserialize(data []byte, v any) error {
defer f.resetReadState()
f.readCtx.SetData(data)
- isNull := readHeader(f.readCtx)
+ readHeader(f.readCtx)
if f.readCtx.HasError() {
return f.readCtx.TakeError()
}
- // Check if the serialized object is null
- if isNull {
- return nil
- }
-
// Deserialize the value - TypeMeta is read inline using streaming protocol
target := reflect.ValueOf(v).Elem()
f.readCtx.ReadValue(target, RefModeTracking, true)
@@ -561,13 +552,6 @@ func (f *Fory) resetWriteState() {
// This is useful when you need to write multiple serialized values to the same buffer.
// Returns error if serialization fails.
func (f *Fory) SerializeTo(buf *ByteBuffer, value any) error {
- // Handle nil values
- if isNilValue(value) {
- // Use Java-compatible null format: 1 byte (bitmap with isNilFlag)
- buf.WriteByte_(IsNilFlag)
- return nil
- }
-
defer f.resetWriteState()
// Temporarily swap buffer
@@ -625,18 +609,12 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error {
origBuffer := f.readCtx.buffer
f.readCtx.buffer = buf
- isNull := readHeader(f.readCtx)
+ readHeader(f.readCtx)
if f.readCtx.HasError() {
f.readCtx.buffer = origBuffer
return f.readCtx.TakeError()
}
- // Check if the serialized object is null
- if isNull {
- f.readCtx.buffer = origBuffer
- return nil
- }
-
// Deserialize the value - TypeMeta is read inline using streaming protocol
target := reflect.ValueOf(v).Elem()
f.readCtx.ReadValue(target, RefModeTracking, true)
@@ -731,21 +709,11 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers
}
// ReadData and validate header
- isNull := readHeader(f.readCtx)
+ readHeader(f.readCtx)
if f.readCtx.HasError() {
return f.readCtx.TakeError()
}
- // Check if the serialized object is null
- if isNull {
- // v must be a pointer so we can set it to nil
- rv := reflect.ValueOf(v)
- if rv.Kind() == reflect.Ptr && !rv.IsNil() {
- rv.Elem().Set(reflect.Zero(rv.Elem().Type()))
- }
- return nil
- }
-
// v must be a pointer so we can deserialize into it
if v == nil {
return fmt.Errorf("v cannot be nil")
@@ -803,50 +771,34 @@ func writeHeader(ctx *WriteContext, config Config) {
ctx.buffer.WriteByte_(bitmap)
}
-// isNilValue checks if a value is nil, including nil pointers wrapped in any
-// In Go, `*int32(nil)` wrapped in `any` is NOT equal to `nil`, but we need to treat it as null.
-//
-//go:noinline
-func isNilValue(value any) bool {
- if value == nil {
- return true
- }
- rv := reflect.ValueOf(value)
- switch rv.Kind() {
- case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func, reflect.Interface:
- return rv.IsNil()
- }
- return false
-}
-
-// writeNullHeader writes a null object header (1 byte: bitmap with isNilFlag)
-// This is compatible with Java's null serialization format
-//
-//go:noinline
-func writeNullHeader(ctx *WriteContext) {
- ctx.buffer.WriteByte_(IsNilFlag) // bitmap with only isNilFlag set
-}
-
-// Special return value indicating null object in readHeader
-// Using math.MinInt32 to avoid conflict with -1 which is used for "no meta offset"
-const NullObjectMetaOffset int32 = -0x7FFFFFFF
-
// readHeader reads and validates the Fory protocol header
-// Returns true if the serialized object is null
// Sets error on ctx if header is invalid (use ctx.HasError() to check)
-func readHeader(ctx *ReadContext) bool {
+func readHeader(ctx *ReadContext) {
err := ctx.Err()
bitmap := ctx.buffer.ReadByte(err)
if ctx.HasError() {
- return false
+ return
}
-
- // Check if this is a null object - only bitmap with isNilFlag was written
- if (bitmap & IsNilFlag) != 0 {
- return true // is null
+ if bitmap == ctx.rootHeader {
+ return
}
+ readHeaderSlow(ctx, bitmap)
+}
- return false // not null
+//go:noinline
+func readHeaderSlow(ctx *ReadContext, bitmap byte) {
+ if bitmap&^headerFlagMask != 0 {
+ ctx.SetError(DeserializationErrorf("unsupported root header bitmap 0x%02x", bitmap))
+ return
+ }
+ if ((bitmap & XLangFlag) != 0) != ctx.xlang {
+ ctx.SetError(DeserializationErrorf("header bitmap mismatch at xlang bit"))
+ return
+ }
+ if (bitmap&OutOfBandFlag) != 0 && ctx.outOfBandBuffers == nil {
+ ctx.SetError(DeserializationErrorf("out-of-band buffers are required by root header"))
+ return
+ }
}
// ============================================================================
@@ -1025,66 +977,84 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error {
f.readCtx.SetData(data)
// ReadData and validate header
- isNull := readHeader(f.readCtx)
+ readHeader(f.readCtx)
if f.readCtx.HasError() {
return f.readCtx.TakeError()
}
- // Check if the serialized object is null
- if isNull {
- var zero T
- *target = zero
- return nil
- }
-
// Fast path: type switch for common types (Go compiler can optimize this)
// For primitives, read null flag, skip type ID, then read value from buffer
buf := f.readCtx.buffer
err := f.readCtx.Err()
switch t := any(target).(type) {
case *bool:
- _ = buf.ReadInt8(err) // null flag
- _ = buf.ReadUint8(err) // type ID
+ _ = buf.ReadInt8(err) // null flag
+ if !f.readCtx.readExpectedTypeID(BOOL) {
+ return f.readCtx.CheckError()
+ }
*t = buf.ReadBool(err)
return f.readCtx.CheckError()
case *int8:
_ = buf.ReadInt8(err)
- _ = buf.ReadUint8(err)
+ if !f.readCtx.readExpectedTypeID(INT8) {
+ return f.readCtx.CheckError()
+ }
*t = buf.ReadInt8(err)
return f.readCtx.CheckError()
case *int16:
_ = buf.ReadInt8(err)
- _ = buf.ReadUint8(err)
+ if !f.readCtx.readExpectedTypeID(INT16) {
+ return f.readCtx.CheckError()
+ }
*t = buf.ReadInt16(err)
return f.readCtx.CheckError()
case *int32:
_ = buf.ReadInt8(err)
- _ = buf.ReadUint8(err)
+ if !f.readCtx.readExpectedTypeID(VARINT32) {
+ return f.readCtx.CheckError()
+ }
*t = buf.ReadVarint32(err)
return f.readCtx.CheckError()
case *int64:
_ = buf.ReadInt8(err)
- _ = buf.ReadUint8(err)
+ if !f.readCtx.readExpectedTypeID(VARINT64) {
+ return f.readCtx.CheckError()
+ }
*t = buf.ReadVarint64(err)
return f.readCtx.CheckError()
case *int:
_ = buf.ReadInt8(err)
- _ = buf.ReadUint8(err)
+ if strconv.IntSize == 32 {
+ if !f.readCtx.readExpectedTypeID(VARINT32) {
+ return f.readCtx.CheckError()
+ }
+ *t = int(buf.ReadVarint32(err))
+ return f.readCtx.CheckError()
+ }
+ if !f.readCtx.readExpectedTypeID(VARINT64) {
+ return f.readCtx.CheckError()
+ }
*t = int(buf.ReadVarint64(err))
return f.readCtx.CheckError()
case *float32:
_ = buf.ReadInt8(err)
- _ = buf.ReadUint8(err)
+ if !f.readCtx.readExpectedTypeID(FLOAT32) {
+ return f.readCtx.CheckError()
+ }
*t = buf.ReadFloat32(err)
return f.readCtx.CheckError()
case *float64:
_ = buf.ReadInt8(err)
- _ = buf.ReadUint8(err)
+ if !f.readCtx.readExpectedTypeID(FLOAT64) {
+ return f.readCtx.CheckError()
+ }
*t = buf.ReadFloat64(err)
return f.readCtx.CheckError()
case *string:
- _ = buf.ReadInt8(err) // null flag
- _ = buf.ReadUint8(err) // type ID
+ _ = buf.ReadInt8(err) // null flag
+ if !f.readCtx.readExpectedTypeID(STRING) {
+ return f.readCtx.CheckError()
+ }
*t = f.readCtx.ReadString()
return f.readCtx.CheckError()
case *[]byte:
@@ -1116,31 +1086,31 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error {
return f.readCtx.CheckError()
case *map[string]string:
*t = f.readCtx.ReadStringStringMap(RefModeNullOnly, true)
- return nil
+ return f.readCtx.CheckError()
case *map[string]int64:
*t = f.readCtx.ReadStringInt64Map(RefModeNullOnly, true)
- return nil
+ return f.readCtx.CheckError()
case *map[string]int32:
*t = f.readCtx.ReadStringInt32Map(RefModeNullOnly, true)
- return nil
+ return f.readCtx.CheckError()
case *map[string]int:
*t = f.readCtx.ReadStringIntMap(RefModeNullOnly, true)
- return nil
+ return f.readCtx.CheckError()
case *map[string]float64:
*t = f.readCtx.ReadStringFloat64Map(RefModeNullOnly, true)
- return nil
+ return f.readCtx.CheckError()
case *map[string]bool:
*t = f.readCtx.ReadStringBoolMap(RefModeNullOnly, true)
- return nil
+ return f.readCtx.CheckError()
case *map[int32]int32:
*t = f.readCtx.ReadInt32Int32Map(RefModeNullOnly, true)
- return nil
+ return f.readCtx.CheckError()
case *map[int64]int64:
*t = f.readCtx.ReadInt64Int64Map(RefModeNullOnly, true)
- return nil
+ return f.readCtx.CheckError()
case *map[int]int:
*t = f.readCtx.ReadIntIntMap(RefModeNullOnly, true)
- return nil
+ return f.readCtx.CheckError()
default:
// Slow path: use serializer-based deserialization
targetVal := reflect.ValueOf(target).Elem()
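The fast paths above replace blind `ReadUint8` skips with `readExpectedTypeID` checks. A standalone sketch of that pattern, with illustrative type-ID values rather than the real Fory constants:

```go
package main

import "fmt"

type typeID uint8

// Illustrative IDs only; not the actual Fory type-ID table.
const (
	boolID   typeID = 1
	stringID typeID = 13
)

// readExpectedTypeID consumes one byte and fails loudly when it is not
// the type ID the caller expects, instead of silently skipping it.
func readExpectedTypeID(data []byte, pos *int, want typeID) error {
	if *pos >= len(data) {
		return fmt.Errorf("unexpected end of buffer")
	}
	got := typeID(data[*pos])
	*pos++
	if got != want {
		return fmt.Errorf("type mismatch: expected %d, got %d", want, got)
	}
	return nil
}

func main() {
	data := []byte{byte(boolID), 1}
	pos := 0
	fmt.Println(readExpectedTypeID(data, &pos, boolID))
	fmt.Println(readExpectedTypeID(data, &pos, stringID))
}
```

Validating instead of skipping turns a payload whose root type disagrees with the target into an immediate error rather than misinterpreted bytes.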
diff --git a/go/fory/fory_typed_test.go b/go/fory/fory_typed_test.go
index 8108076572..127c05c891 100644
--- a/go/fory/fory_typed_test.go
+++ b/go/fory/fory_typed_test.go
@@ -132,6 +132,44 @@ func TestSerializeGenericPrimitives(t *testing.T) {
})
}
+func TestDeserializeRejectsRootTypeMismatch(t *testing.T) {
+ f := NewFory()
+
+ data := []byte{0, 0xff, byte(STRING)}
+ var result bool
+ require.Error(t, Deserialize(f, data, &result))
+
+ data = []byte{0, 0xff, byte(BOOL)}
+ var mapResult map[string]string
+ require.Error(t, Deserialize(f, data, &mapResult))
+}
+
+func TestDeserializeRejectsRootPrimitiveSliceTypeMismatch(t *testing.T) {
+ f := NewFory()
+
+ data := []byte{0, 0xff, byte(BINARY)}
+ var int32Result []int32
+ require.Error(t, Deserialize(f, data, &int32Result))
+
+ data = []byte{0, 0xff, byte(INT8_ARRAY)}
+ var byteResult []byte
+ require.Error(t, Deserialize(f, data, &byteResult))
+}
+
+func TestDeserializeByteSliceAcceptsUint8ArrayRootType(t *testing.T) {
+ f := NewFory()
+ buf := NewByteBuffer(nil)
+ buf.WriteByte(0)
+ buf.WriteInt8(NotNullValueFlag)
+ buf.WriteUint8(uint8(UINT8_ARRAY))
+ buf.WriteLength(3)
+ buf.WriteBinary([]byte{1, 2, 3})
+
+ var result []byte
+ require.NoError(t, Deserialize(f, buf.Bytes(), &result))
+ require.Equal(t, []byte{1, 2, 3}, result)
+}
+
// TestSerializeGenericComplex tests Serialize[T]/DeserializeWithCallbackBuffers[T] with complex types.
// These fall back to reflection-based serialization.
func TestSerializeGenericComplex(t *testing.T) {
diff --git a/go/fory/map.go b/go/fory/map.go
index cbce365e2b..a59fd8e652 100644
--- a/go/fory/map.go
+++ b/go/fory/map.go
@@ -484,6 +484,10 @@ func (s mapSerializer) readChunk(ctx *ReadContext, mapVal reflect.Value, header
if ctx.HasError() {
return 0
}
+ if chunkSize == 0 || chunkSize > size {
+ ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size))
+ return 0
+ }
// Read type info if not declared
var keyTypeInfo, valueTypeInfo *TypeInfo
@@ -618,7 +622,9 @@ func readMapRefAndType(ctx *ReadContext, refMode RefMode, readType bool, value r
}
}
if readType {
- buf.ReadUint8(ctxErr)
+ if !ctx.readExpectedTypeID(MAP) {
+ return false
+ }
}
return false
}
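The chunk-size guard added above exists because a zero-sized chunk makes no progress (the `size > 0` loop would spin forever) and an oversized chunk overruns the declared map length. A minimal sketch of the invariant, detached from the serializer:

```go
package main

import (
	"errors"
	"fmt"
)

// consumeChunks walks chunk sizes against a declared map length, applying
// the same guard as above: each chunk must be non-zero and fit in what remains.
func consumeChunks(total int, chunkSizes []int) error {
	remaining := total
	for _, c := range chunkSizes {
		if c == 0 || c > remaining {
			return fmt.Errorf("invalid map chunk size %d for remaining length %d", c, remaining)
		}
		remaining -= c
	}
	if remaining != 0 {
		return errors.New("chunks did not cover declared length")
	}
	return nil
}

func main() {
	fmt.Println(consumeChunks(5, []int{3, 2}))
	fmt.Println(consumeChunks(5, []int{0}))
}
```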
diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go
index 16e40f5aef..287777eaea 100644
--- a/go/fory/map_primitive.go
+++ b/go/fory/map_primitive.go
@@ -81,50 +81,32 @@ func readMapStringString(ctx *ReadContext) map[string]string {
for size > 0 {
chunkHeader := buf.ReadUint8(err)
- // Handle null key/value cases
- keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0
- valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0
-
- if keyHasNull && valueHasNull {
- // Both null - use empty strings for key and value
- result[""] = ""
- size--
- continue
- } else if keyHasNull {
- // Null key with non-null value
- valueDeclared := (chunkHeader & VALUE_DECL_TYPE) != 0
- if !valueDeclared {
- buf.ReadUint8(err) // skip value type
- }
- v := readString(buf, err)
- result[""] = v // empty string as null key
- size--
- continue
- } else if valueHasNull {
- // Non-null key with null value
- keyDeclared := (chunkHeader & KEY_DECL_TYPE) != 0
- if !keyDeclared {
- buf.ReadUint8(err) // skip key type
- }
- k := readString(buf, err)
- result[k] = "" // empty string as null value
- size--
- continue
+ if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 {
+ ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks"))
+ return result
}
- // ReadData chunk size
chunkSize := int(buf.ReadUint8(err))
+ if ctx.HasError() {
+ return result
+ }
+ if chunkSize == 0 || chunkSize > size {
+ ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size))
+ return result
+ }
- // Read type info if not DECL_TYPE
if (chunkHeader & KEY_DECL_TYPE) == 0 {
- buf.ReadUint8(err) // skip key type
+ if !ctx.readExpectedTypeID(STRING) {
+ return result
+ }
}
if (chunkHeader & VALUE_DECL_TYPE) == 0 {
- buf.ReadUint8(err) // skip value type
+ if !ctx.readExpectedTypeID(STRING) {
+ return result
+ }
}
- // ReadData chunk entries
- for i := 0; i < chunkSize && size > 0; i++ {
+ for i := 0; i < chunkSize; i++ {
k := readString(buf, err)
v := readString(buf, err)
result[k] = v
@@ -185,22 +167,30 @@ func readMapStringInt64(ctx *ReadContext) map[string]int64 {
for size > 0 {
chunkHeader := buf.ReadUint8(err)
- keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0
- valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0
-
- if keyHasNull || valueHasNull {
- size--
- continue
+ if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 {
+ ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks"))
+ return result
}
chunkSize := int(buf.ReadUint8(err))
+ if ctx.HasError() {
+ return result
+ }
+ if chunkSize == 0 || chunkSize > size {
+ ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size))
+ return result
+ }
if (chunkHeader & KEY_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(STRING) {
+ return result
+ }
}
if (chunkHeader & VALUE_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(VARINT64) {
+ return result
+ }
}
- for i := 0; i < chunkSize && size > 0; i++ {
+ for i := 0; i < chunkSize; i++ {
k := readString(buf, err)
v := buf.ReadVarint64(err)
result[k] = v
@@ -261,22 +251,30 @@ func readMapStringInt32(ctx *ReadContext) map[string]int32 {
for size > 0 {
chunkHeader := buf.ReadUint8(err)
- keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0
- valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0
-
- if keyHasNull || valueHasNull {
- size--
- continue
+ if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 {
+ ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks"))
+ return result
}
chunkSize := int(buf.ReadUint8(err))
+ if ctx.HasError() {
+ return result
+ }
+ if chunkSize == 0 || chunkSize > size {
+ ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size))
+ return result
+ }
if (chunkHeader & KEY_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(STRING) {
+ return result
+ }
}
if (chunkHeader & VALUE_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(VARINT32) {
+ return result
+ }
}
- for i := 0; i < chunkSize && size > 0; i++ {
+ for i := 0; i < chunkSize; i++ {
k := readString(buf, err)
v := buf.ReadVarint32(err)
result[k] = v
@@ -337,22 +335,30 @@ func readMapStringInt(ctx *ReadContext) map[string]int {
for size > 0 {
chunkHeader := buf.ReadUint8(err)
- keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0
- valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0
-
- if keyHasNull || valueHasNull {
- size--
- continue
+ if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 {
+ ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks"))
+ return result
}
chunkSize := int(buf.ReadUint8(err))
+ if ctx.HasError() {
+ return result
+ }
+ if chunkSize == 0 || chunkSize > size {
+ ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size))
+ return result
+ }
if (chunkHeader & KEY_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(STRING) {
+ return result
+ }
}
if (chunkHeader & VALUE_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(VARINT64) {
+ return result
+ }
}
- for i := 0; i < chunkSize && size > 0; i++ {
+ for i := 0; i < chunkSize; i++ {
k := readString(buf, err)
v := buf.ReadVarint64(err)
result[k] = int(v)
@@ -413,22 +419,30 @@ func readMapStringFloat64(ctx *ReadContext) map[string]float64 {
for size > 0 {
chunkHeader := buf.ReadUint8(err)
- keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0
- valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0
-
- if keyHasNull || valueHasNull {
- size--
- continue
+ if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 {
+ ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks"))
+ return result
}
chunkSize := int(buf.ReadUint8(err))
+ if ctx.HasError() {
+ return result
+ }
+ if chunkSize == 0 || chunkSize > size {
+ ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size))
+ return result
+ }
if (chunkHeader & KEY_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(STRING) {
+ return result
+ }
}
if (chunkHeader & VALUE_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(FLOAT64) {
+ return result
+ }
}
- for i := 0; i < chunkSize && size > 0; i++ {
+ for i := 0; i < chunkSize; i++ {
k := readString(buf, err)
v := buf.ReadFloat64(err)
result[k] = v
@@ -489,27 +503,34 @@ func readMapStringBool(ctx *ReadContext) map[string]bool {
for size > 0 {
chunkHeader := buf.ReadUint8(err)
- keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0
- valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0
-
- if keyHasNull || valueHasNull {
- size--
- continue
+ if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 {
+ ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks"))
+ return result
}
chunkSize := int(buf.ReadUint8(err))
+ if ctx.HasError() {
+ return result
+ }
+ if chunkSize == 0 || chunkSize > size {
+ ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size))
+ return result
+ }
- // Read type info (written by writeMapStringBool)
keyDeclType := (chunkHeader & KEY_DECL_TYPE) != 0
valDeclType := (chunkHeader & VALUE_DECL_TYPE) != 0
if !keyDeclType {
- buf.ReadUint8(err) // skip key type info
+ if !ctx.readExpectedTypeID(STRING) {
+ return result
+ }
}
if !valDeclType {
- buf.ReadUint8(err) // skip value type info
+ if !ctx.readExpectedTypeID(BOOL) {
+ return result
+ }
}
- for i := 0; i < chunkSize && size > 0; i++ {
+ for i := 0; i < chunkSize; i++ {
k := readString(buf, err)
v := buf.ReadBool(err)
result[k] = v
@@ -570,22 +591,30 @@ func readMapInt32Int32(ctx *ReadContext) map[int32]int32 {
for size > 0 {
chunkHeader := buf.ReadUint8(err)
- keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0
- valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0
-
- if keyHasNull || valueHasNull {
- size--
- continue
+ if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 {
+ ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks"))
+ return result
}
chunkSize := int(buf.ReadUint8(err))
+ if ctx.HasError() {
+ return result
+ }
+ if chunkSize == 0 || chunkSize > size {
+ ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size))
+ return result
+ }
if (chunkHeader & KEY_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(VARINT32) {
+ return result
+ }
}
if (chunkHeader & VALUE_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(VARINT32) {
+ return result
+ }
}
- for i := 0; i < chunkSize && size > 0; i++ {
+ for i := 0; i < chunkSize; i++ {
k := buf.ReadVarint32(err)
v := buf.ReadVarint32(err)
result[k] = v
@@ -646,22 +675,30 @@ func readMapInt64Int64(ctx *ReadContext) map[int64]int64 {
for size > 0 {
chunkHeader := buf.ReadUint8(err)
- keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0
- valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0
-
- if keyHasNull || valueHasNull {
- size--
- continue
+ if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 {
+ ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks"))
+ return result
}
chunkSize := int(buf.ReadUint8(err))
+ if ctx.HasError() {
+ return result
+ }
+ if chunkSize == 0 || chunkSize > size {
+ ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size))
+ return result
+ }
if (chunkHeader & KEY_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(VARINT64) {
+ return result
+ }
}
if (chunkHeader & VALUE_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(VARINT64) {
+ return result
+ }
}
- for i := 0; i < chunkSize && size > 0; i++ {
+ for i := 0; i < chunkSize; i++ {
k := buf.ReadVarint64(err)
v := buf.ReadVarint64(err)
result[k] = v
@@ -722,22 +759,30 @@ func readMapIntInt(ctx *ReadContext) map[int]int {
for size > 0 {
chunkHeader := buf.ReadUint8(err)
- keyHasNull := (chunkHeader & KEY_HAS_NULL) != 0
- valueHasNull := (chunkHeader & VALUE_HAS_NULL) != 0
-
- if keyHasNull || valueHasNull {
- size--
- continue
+ if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 {
+ ctx.SetError(DeserializationError("typed map reader does not support ref/null chunks"))
+ return result
}
chunkSize := int(buf.ReadUint8(err))
+ if ctx.HasError() {
+ return result
+ }
+ if chunkSize == 0 || chunkSize > size {
+ ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, size))
+ return result
+ }
if (chunkHeader & KEY_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(VARINT64) {
+ return result
+ }
}
if (chunkHeader & VALUE_DECL_TYPE) == 0 {
- buf.ReadUint8(err)
+ if !ctx.readExpectedTypeID(VARINT64) {
+ return result
+ }
}
- for i := 0; i < chunkSize && size > 0; i++ {
+ for i := 0; i < chunkSize; i++ {
k := buf.ReadVarint64(err)
v := buf.ReadVarint64(err)
result[int(k)] = int(v)
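Every typed map reader in the hunks above now applies the same two guards before decoding a chunk: reject any chunk whose header carries ref/null flags, then bound-check the declared chunk size against the entries still expected. A minimal standalone sketch of that shared logic (flag bit values and names here are illustrative assumptions, not the library's actual constants):

```go
package main

import "fmt"

// Illustrative flag bits mirroring the chunk-header layout used by the
// typed map readers in this patch (actual values live in the fory package).
const (
	TRACKING_KEY_REF   = 1 << 0
	KEY_HAS_NULL       = 1 << 1
	KEY_DECL_TYPE      = 1 << 2
	TRACKING_VALUE_REF = 1 << 3
	VALUE_HAS_NULL     = 1 << 4
	VALUE_DECL_TYPE    = 1 << 5
)

// validateChunk reproduces the guard sequence: ref/null chunks are rejected
// outright, and a zero or oversized chunk size fails fast instead of letting
// the entry loop run past the map's declared length.
func validateChunk(chunkHeader byte, chunkSize, remaining int) error {
	if chunkHeader&(TRACKING_KEY_REF|KEY_HAS_NULL|TRACKING_VALUE_REF|VALUE_HAS_NULL) != 0 {
		return fmt.Errorf("typed map reader does not support ref/null chunks")
	}
	if chunkSize == 0 || chunkSize > remaining {
		return fmt.Errorf("invalid map chunk size %d for remaining length %d", chunkSize, remaining)
	}
	return nil
}

func main() {
	fmt.Println(validateChunk(KEY_DECL_TYPE|VALUE_DECL_TYPE, 2, 1) != nil) // oversized chunk rejected
	fmt.Println(validateChunk(KEY_HAS_NULL, 1, 1) != nil)                  // null chunk rejected
	fmt.Println(validateChunk(KEY_DECL_TYPE|VALUE_DECL_TYPE, 1, 1) == nil) // well-formed chunk accepted
}
```

Because the guard runs before any entry is read, a forged `chunkSize` can no longer drive the inner loop beyond the declared map length; the old `i < chunkSize && size > 0` condition masked that rather than reporting it.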
diff --git a/go/fory/map_primitive_test.go b/go/fory/map_primitive_test.go
new file mode 100644
index 0000000000..7decca8b0a
--- /dev/null
+++ b/go/fory/map_primitive_test.go
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package fory
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestPrimitiveMapReaderRejectsInvalidChunkSize(t *testing.T) {
+ f := NewFory()
+ buf := NewByteBuffer(nil)
+ buf.WriteLength(1)
+ buf.WriteUint8(KEY_DECL_TYPE | VALUE_DECL_TYPE)
+ buf.WriteUint8(2)
+
+ f.readCtx.SetData(buf.Bytes())
+ _ = f.readCtx.ReadStringStringMap(RefModeNone, false)
+ require.Error(t, f.readCtx.CheckError())
+}
+
+func TestPrimitiveMapReaderRejectsUnexpectedTypeInfo(t *testing.T) {
+ f := NewFory()
+ buf := NewByteBuffer(nil)
+ buf.WriteLength(1)
+ buf.WriteUint8(0)
+ buf.WriteUint8(1)
+ buf.WriteUint8(uint8(STRING))
+ buf.WriteUint8(uint8(BOOL))
+
+ f.readCtx.SetData(buf.Bytes())
+ _ = f.readCtx.ReadStringStringMap(RefModeNone, false)
+ require.Error(t, f.readCtx.CheckError())
+}
+
+func TestPrimitiveMapReaderRejectsNullChunks(t *testing.T) {
+ f := NewFory()
+ buf := NewByteBuffer(nil)
+ buf.WriteLength(1)
+ buf.WriteUint8(KEY_HAS_NULL)
+
+ f.readCtx.SetData(buf.Bytes())
+ _ = f.readCtx.ReadStringStringMap(RefModeNone, false)
+ require.Error(t, f.readCtx.CheckError())
+}
diff --git a/go/fory/meta_string_resolver.go b/go/fory/meta_string_resolver.go
index 6df61d0c71..5a9e1cb8ff 100644
--- a/go/fory/meta_string_resolver.go
+++ b/go/fory/meta_string_resolver.go
@@ -18,7 +18,6 @@
package fory
import (
- "bytes"
"encoding/binary"
"fmt"
"github.com/apache/fory/go/fory/meta"
@@ -28,10 +27,10 @@ import (
const (
SmallStringThreshold = 16 // Maximum length for "small" strings
DefaultDynamicWriteMetaStrID = -1 // Default ID for dynamic strings
+ maxCachedMetaStrings = 8192
+ smallMetaStringEncodingBits = 4
)
-type Encoding int8
-
type MetaStringBytes struct {
Data []byte
Length int16
@@ -60,14 +59,22 @@ func (a *MetaStringBytes) Hash() int64 {
type pair [2]int64
+// Mirrors Java's small MetaString read cache key: two packed byte words plus one
+// compact length/encoding byte. The packed words are zero-padded, so on their
+// own they do not uniquely identify the original bytes; the length/encoding
+// byte disambiguates.
+type smallMetaStringKey struct {
+ v1 int64
+ v2 int64
+ compactKey byte
+}
+
type MetaStringResolver struct {
- dynamicWriteStringID int16 // Counter for dynamic string IDs
- dynamicWrittenEnumString []*MetaStringBytes // Cache of written strings
- dynamicIDToEnumString []*MetaStringBytes // Cache of read strings by ID
- hashToMetaStrBytes map[int64]*MetaStringBytes // Large string lookup
- smallHashToMetaStrBytes map[pair]*MetaStringBytes // Small string lookup
- enumStrSet map[*MetaStringBytes]struct{} // String set for deduplication
- metaStrToMetaStrBytes map[*meta.MetaString]*MetaStringBytes // Conversion cache
+ dynamicWriteStringID int16 // Counter for dynamic string IDs
+ dynamicWrittenEnumString []*MetaStringBytes // Cache of written strings
+ dynamicIDToEnumString []*MetaStringBytes // Cache of read strings by ID
+ hashToMetaStrBytes map[int64]*MetaStringBytes // Large string lookup
+ smallHashToMetaStrBytes map[smallMetaStringKey]*MetaStringBytes // Small string lookup
+ metaStrToMetaStrBytes map[*meta.MetaString]*MetaStringBytes // Conversion cache
}
var emptyMetaStringBytes = NewMetaStringBytes([]byte{}, 256)
@@ -75,8 +82,7 @@ var emptyMetaStringBytes = NewMetaStringBytes([]byte{}, 256)
func NewMetaStringResolver() *MetaStringResolver {
return &MetaStringResolver{
hashToMetaStrBytes: make(map[int64]*MetaStringBytes),
- smallHashToMetaStrBytes: make(map[pair]*MetaStringBytes),
- enumStrSet: make(map[*MetaStringBytes]struct{}),
+ smallHashToMetaStrBytes: make(map[smallMetaStringKey]*MetaStringBytes),
metaStrToMetaStrBytes: make(map[*meta.MetaString]*MetaStringBytes),
}
}
@@ -121,20 +127,27 @@ func (r *MetaStringResolver) ReadMetaStringBytes(buf *ByteBuffer, ctxErr *Error)
return nil, *ctxErr
}
- length := int16(header >> 1)
+ lengthValue := header >> 1
if header&1 != 0 {
- index := int(length) - 1
+ if lengthValue == 0 || uint64(lengthValue) > uint64(MaxInt) {
+ return nil, fmt.Errorf("invalid dynamic index: %d", lengthValue)
+ }
+ index := int(lengthValue) - 1
if index < 0 || index >= len(r.dynamicIDToEnumString) {
return nil, fmt.Errorf("invalid dynamic index: %d", index)
}
return r.dynamicIDToEnumString[index], nil
}
+ if lengthValue > uint32(MaxInt16) {
+ return nil, fmt.Errorf("meta string length %d exceeds maximum supported length %d", lengthValue, MaxInt16)
+ }
+ length := int(lengthValue)
var (
hashcode int64
- key pair
+ key smallMetaStringKey
data []byte
- encoding Encoding
+ encoding meta.Encoding
)
// Small string optimization
@@ -143,9 +156,12 @@ func (r *MetaStringResolver) ReadMetaStringBytes(buf *ByteBuffer, ctxErr *Error)
r.dynamicIDToEnumString = append(r.dynamicIDToEnumString, emptyMetaStringBytes)
return emptyMetaStringBytes, nil
}
- // ReadData encoding and data
encByte := buf.ReadByte(ctxErr)
- encoding = Encoding(encByte)
+ var encErr error
+ encoding, encErr = meta.EncodingFromByte(encByte)
+ if encErr != nil {
+ return nil, encErr
+ }
data = make([]byte, length)
_, err := buf.Read(data)
@@ -153,29 +169,33 @@ func (r *MetaStringResolver) ReadMetaStringBytes(buf *ByteBuffer, ctxErr *Error)
return nil, err
}
- // Compute composite hash key
- if length <= 8 {
- key[0] = bytesToInt64(data)
- } else {
- err := binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &key[0])
- if err != nil {
- return nil, err
- }
- key[1] = bytesToInt64(data[8:])
+ words := smallMetaStringWords(data)
+ key = smallMetaStringKey{
+ v1: words[0],
+ v2: words[1],
+ compactKey: byte(((length - 1) << smallMetaStringEncodingBits) | int(encoding)),
}
- hashcode = ((key[0]*31 + key[1]) >> 8 << 8) | int64(encoding)
+ hashcode = computeSmallMetaStringHash(words, length, encoding)
} else {
// Large string handling
err := binary.Read(buf, binary.LittleEndian, &hashcode)
if err != nil {
return nil, err
}
- encoding = Encoding(hashcode & 0xFF)
+ var encErr error
+ encoding, encErr = meta.EncodingFromByte(byte(hashcode & 0xFF))
+ if encErr != nil {
+ return nil, encErr
+ }
data = make([]byte, length)
_, err = buf.Read(data)
if err != nil {
return nil, err
}
+ canonicalHashcode := ComputeMetaStringHash(data, encoding)
+ if canonicalHashcode != hashcode {
+ return nil, fmt.Errorf("meta string body hash mismatch")
+ }
}
// Check string caches for existing instance
@@ -191,14 +211,18 @@ func (r *MetaStringResolver) ReadMetaStringBytes(buf *ByteBuffer, ctxErr *Error)
}
}
- // Create and cache new string instance
+ // Cache only after the current body has been parsed and, for large bodies, hash-validated.
+ // Header-keyed hits stay on the fast path; forged headers cannot poison the shared cache.
m := NewMetaStringBytes(data, hashcode)
if length <= SmallStringThreshold {
- r.smallHashToMetaStrBytes[key] = m
+ if len(r.smallHashToMetaStrBytes) < maxCachedMetaStrings {
+ r.smallHashToMetaStrBytes[key] = m
+ }
} else {
- r.hashToMetaStrBytes[hashcode] = m
+ if len(r.hashToMetaStrBytes) < maxCachedMetaStrings {
+ r.hashToMetaStrBytes[hashcode] = m
+ }
}
- r.enumStrSet[m] = struct{}{}
r.dynamicIDToEnumString = append(r.dynamicIDToEnumString, m)
return m, nil
@@ -222,14 +246,8 @@ func (r *MetaStringResolver) GetMetaStrBytes(metastr *meta.MetaString) *MetaStri
}
if length <= SmallStringThreshold {
// Small string: use direct bytes as hash components
- var v1, v2 int64
- if length <= 8 {
- v1 = bytesToInt64(data)
- } else {
- binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &v1)
- v2 = bytesToInt64(data[8:])
- }
- hashcode = ((v1*31 + v2) >> 8 << 8) | int64(metastr.GetEncoding())
+ words := smallMetaStringWords(data)
+ hashcode = computeSmallMetaStringHash(words, length, metastr.GetEncoding())
} else {
// Large string: use MurmurHash3
h64 := Murmur3Sum64WithSeed(data, 47)
@@ -253,14 +271,8 @@ func ComputeMetaStringHash(data []byte, encoding meta.Encoding) int64 {
hashcode |= int64(encoding)
} else if length <= SmallStringThreshold {
// Small string: use direct bytes as hash components
- var v1, v2 int64
- if length <= 8 {
- v1 = bytesToInt64(data)
- } else {
- binary.Read(bytes.NewReader(data[:8]), binary.LittleEndian, &v1)
- v2 = bytesToInt64(data[8:])
- }
- hashcode = ((v1*31 + v2) >> 8 << 8) | int64(encoding)
+ words := smallMetaStringWords(data)
+ hashcode = computeSmallMetaStringHash(words, length, encoding)
} else {
// Large string: use MurmurHash3
h64 := Murmur3Sum64WithSeed(data, 47)
@@ -316,3 +328,15 @@ func bytesToInt64(b []byte) int64 {
}
return v
}
+
+func smallMetaStringWords(data []byte) pair {
+ if len(data) <= 8 {
+ return pair{bytesToInt64(data), 0}
+ }
+ return pair{int64(binary.LittleEndian.Uint64(data[:8])), bytesToInt64(data[8:])}
+}
+
+func computeSmallMetaStringHash(words pair, length int, encoding meta.Encoding) int64 {
+ hash := uint64(words[0]*31+words[1]) ^ (uint64(length) << 56)
+ return int64((hash & 0xffffffffffffff00) | uint64(encoding))
+}
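The new small-string hash folds the length into the top byte range and keeps the encoding in the low byte, which is exactly why `[]byte{1}` and `[]byte{1, 0}` no longer collide: their zero-padded packed words are identical, and only the length term separates them. A self-contained sketch (the encoding is modeled as a plain `byte` here, standing in for `meta.Encoding`):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

type pair [2]int64

// bytesToInt64 packs up to 8 bytes little-endian into an int64,
// zero-padding the tail (mirrors the helper in the patch).
func bytesToInt64(b []byte) int64 {
	var buf [8]byte
	copy(buf[:], b)
	return int64(binary.LittleEndian.Uint64(buf[:]))
}

func smallMetaStringWords(data []byte) pair {
	if len(data) <= 8 {
		return pair{bytesToInt64(data), 0}
	}
	return pair{int64(binary.LittleEndian.Uint64(data[:8])), bytesToInt64(data[8:])}
}

// computeSmallMetaStringHash XORs the length into the top byte of the folded
// words and reserves the low byte for the encoding, so two inputs whose
// zero-padded words match still hash differently when length or encoding differ.
func computeSmallMetaStringHash(words pair, length int, encoding byte) int64 {
	hash := uint64(words[0]*31+words[1]) ^ (uint64(length) << 56)
	return int64((hash & 0xffffffffffffff00) | uint64(encoding))
}

func main() {
	h1 := computeSmallMetaStringHash(smallMetaStringWords([]byte{1}), 1, 0)
	h2 := computeSmallMetaStringHash(smallMetaStringWords([]byte{1, 0}), 2, 0)
	fmt.Println(h1 != h2) // true: length disambiguates zero-padded words
}
```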
diff --git a/go/fory/meta_string_resolver_test.go b/go/fory/meta_string_resolver_test.go
index bbf9e39020..5a15be8024 100644
--- a/go/fory/meta_string_resolver_test.go
+++ b/go/fory/meta_string_resolver_test.go
@@ -18,8 +18,11 @@
package fory
import (
+ "encoding/binary"
"github.com/stretchr/testify/require"
"testing"
+
+ "github.com/apache/fory/go/fory/meta"
)
// TestMetaStringResolverNegativeIndexPanic reproduces the CRITICAL security bug
@@ -65,3 +68,125 @@ func TestMetaStringResolverBoundaryRegression(t *testing.T) {
require.NoError(t, err)
require.Equal(t, m, result, "Should correctly resolve the first dynamic string (index 0)")
}
+
+func TestMetaStringResolverRejectsLargeBodyHashMismatch(t *testing.T) {
+ resolver := NewMetaStringResolver()
+ buffer := NewByteBuffer(nil)
+ data := []byte("0123456789abcdefg")
+
+ buffer.WriteVarUint32Small7(uint32(len(data)) << 1)
+ buffer.WriteInt64(int64(meta.UTF_8))
+ buffer.Write(data)
+ buffer.SetReaderIndex(0)
+
+ var ctxErr Error
+ _, err := resolver.ReadMetaStringBytes(buffer, &ctxErr)
+ require.Error(t, err)
+ require.Empty(t, resolver.hashToMetaStrBytes)
+ require.Empty(t, resolver.dynamicIDToEnumString)
+}
+
+func TestMetaStringResolverRejectsOversizedLengthBeforeAllocation(t *testing.T) {
+ resolver := NewMetaStringResolver()
+ buffer := NewByteBuffer(nil)
+ buffer.WriteVarUint32Small7(uint32(MaxInt16+1) << 1)
+ buffer.SetReaderIndex(0)
+
+ var ctxErr Error
+ _, err := resolver.ReadMetaStringBytes(buffer, &ctxErr)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "meta string length")
+ require.Empty(t, resolver.hashToMetaStrBytes)
+ require.Empty(t, resolver.smallHashToMetaStrBytes)
+ require.Empty(t, resolver.dynamicIDToEnumString)
+}
+
+func TestMetaStringResolverSmallCacheKeyIncludesLengthAndEncoding(t *testing.T) {
+ resolver := NewMetaStringResolver()
+
+ oneByte := NewByteBuffer(nil)
+ oneByte.WriteVarUint32Small7(1 << 1)
+ oneByte.WriteByte(byte(meta.UTF_8))
+ oneByte.WriteByte(1)
+ oneByte.SetReaderIndex(0)
+ var oneErr Error
+ first, err := resolver.ReadMetaStringBytes(oneByte, &oneErr)
+ require.NoError(t, err)
+ require.Equal(t, []byte{1}, first.Data)
+
+ twoBytes := NewByteBuffer(nil)
+ twoBytes.WriteVarUint32Small7(2 << 1)
+ twoBytes.WriteByte(byte(meta.UTF_8))
+ twoBytes.Write([]byte{1, 0})
+ twoBytes.SetReaderIndex(0)
+ var twoErr Error
+ second, err := resolver.ReadMetaStringBytes(twoBytes, &twoErr)
+ require.NoError(t, err)
+ require.Equal(t, []byte{1, 0}, second.Data)
+ require.NotSame(t, first, second)
+
+ differentEncoding := NewByteBuffer(nil)
+ differentEncoding.WriteVarUint32Small7(1 << 1)
+ differentEncoding.WriteByte(byte(meta.LOWER_SPECIAL))
+ differentEncoding.WriteByte(1)
+ differentEncoding.SetReaderIndex(0)
+ var encodingErr Error
+ third, err := resolver.ReadMetaStringBytes(differentEncoding, &encodingErr)
+ require.NoError(t, err)
+ require.Equal(t, []byte{1}, third.Data)
+ require.Equal(t, meta.LOWER_SPECIAL, third.Encoding)
+ require.NotSame(t, first, third)
+ require.Len(t, resolver.smallHashToMetaStrBytes, 3)
+}
+
+func TestComputeMetaStringHashIncludesSmallLength(t *testing.T) {
+ require.NotEqual(
+ t,
+ ComputeMetaStringHash([]byte{1}, meta.UTF_8),
+ ComputeMetaStringHash([]byte{1, 0}, meta.UTF_8),
+ )
+ require.NotEqual(
+ t,
+ ComputeMetaStringHash([]byte{1}, meta.UTF_8),
+ ComputeMetaStringHash([]byte{1}, meta.LOWER_SPECIAL),
+ )
+}
+
+func TestMetaStringResolverReadCachesAreCapped(t *testing.T) {
+ resolver := NewMetaStringResolver()
+ smallKey := smallMetaStringKey{v1: 1, v2: 0, compactKey: byte(meta.UTF_8)}
+ for i := 0; i < maxCachedMetaStrings; i++ {
+ resolver.smallHashToMetaStrBytes[smallMetaStringKey{
+ v1: int64(i + 2),
+ v2: 0,
+ compactKey: byte(meta.UTF_8),
+ }] =
+ NewMetaStringBytes([]byte{byte(i)}, int64(i+2)<<8)
+ resolver.hashToMetaStrBytes[int64(i+2)<<8] =
+ NewMetaStringBytes([]byte("0123456789abcdefg"), int64(i+2)<<8)
+ }
+
+ smallBuffer := NewByteBuffer(nil)
+ smallBuffer.WriteVarUint32Small7(1 << 1)
+ smallBuffer.WriteByte(byte(meta.UTF_8))
+ smallBuffer.WriteByte(1)
+ smallBuffer.SetReaderIndex(0)
+ var smallErr Error
+ _, err := resolver.ReadMetaStringBytes(smallBuffer, &smallErr)
+ require.NoError(t, err)
+ require.Len(t, resolver.smallHashToMetaStrBytes, maxCachedMetaStrings)
+ require.NotContains(t, resolver.smallHashToMetaStrBytes, smallKey)
+
+ largeData := []byte("0123456789abcdefg")
+ largeHash := ComputeMetaStringHash(largeData, meta.UTF_8)
+ largeBuffer := NewByteBuffer(nil)
+ largeBuffer.WriteVarUint32Small7(uint32(len(largeData)) << 1)
+ require.NoError(t, binary.Write(largeBuffer, binary.LittleEndian, largeHash))
+ largeBuffer.Write(largeData)
+ largeBuffer.SetReaderIndex(0)
+ var largeErr Error
+ _, err = resolver.ReadMetaStringBytes(largeBuffer, &largeErr)
+ require.NoError(t, err)
+ require.Len(t, resolver.hashToMetaStrBytes, maxCachedMetaStrings)
+ require.NotContains(t, resolver.hashToMetaStrBytes, largeHash)
+}
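The cap test above exercises the resolver's insert-time bound: entries are added only while the cache holds fewer than `maxCachedMetaStrings`, so a hostile stream can fill the cache but never grow it without limit. The policy in isolation (constant value shrunk for illustration):

```go
package main

import "fmt"

const maxCached = 4 // illustrative stand-in for maxCachedMetaStrings (8192 in the patch)

// put caches a value only while the map is below its cap, matching the
// resolver's "insert if len < max" policy: existing entries keep serving
// cache hits, but new inserts are dropped once the bound is reached.
func put(cache map[int64]string, hash int64, v string) bool {
	if len(cache) >= maxCached {
		return false
	}
	cache[hash] = v
	return true
}

func main() {
	cache := make(map[int64]string)
	for i := int64(0); i < 10; i++ {
		put(cache, i, fmt.Sprintf("entry-%d", i))
	}
	fmt.Println(len(cache)) // 4: capped
}
```

Note that once at the cap, even a re-insert of an existing key is skipped; decoding still succeeds, it just bypasses the cache, which is the trade-off the patch makes for bounded memory.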
diff --git a/go/fory/primitive.go b/go/fory/primitive.go
index f5a7e40550..1748bc27a8 100644
--- a/go/fory/primitive.go
+++ b/go/fory/primitive.go
@@ -56,8 +56,8 @@ func (s boolSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(BOOL) {
+ return
}
if ctx.HasError() {
return
@@ -103,8 +103,8 @@ func (s int8Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(INT8) {
+ return
}
if ctx.HasError() {
return
@@ -149,8 +149,8 @@ func (s byteSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(UINT8) {
+ return
}
if ctx.HasError() {
return
@@ -195,8 +195,8 @@ func (s uint16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool,
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(UINT16) {
+ return
}
if ctx.HasError() {
return
@@ -241,8 +241,8 @@ func (s uint32Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool,
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(VAR_UINT32) {
+ return
}
if ctx.HasError() {
return
@@ -287,8 +287,8 @@ func (s uint64Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool,
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(VAR_UINT64) {
+ return
}
if ctx.HasError() {
return
@@ -331,8 +331,8 @@ func (s uintSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(VAR_UINT64) {
+ return
}
if ctx.HasError() {
return
@@ -375,8 +375,8 @@ func (s int16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool,
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(INT16) {
+ return
}
if ctx.HasError() {
return
@@ -419,8 +419,8 @@ func (s int32Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool,
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(VARINT32) {
+ return
}
if ctx.HasError() {
return
@@ -463,8 +463,8 @@ func (s int64Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool,
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(VARINT64) {
+ return
}
if ctx.HasError() {
return
@@ -505,8 +505,8 @@ func (s intSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, ha
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(VARINT64) {
+ return
}
if ctx.HasError() {
return
@@ -549,8 +549,8 @@ func (s float32Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(FLOAT32) {
+ return
}
if ctx.HasError() {
return
@@ -593,8 +593,8 @@ func (s float64Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(FLOAT64) {
+ return
}
if ctx.HasError() {
return
@@ -651,8 +651,8 @@ func (s float16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(FLOAT16) {
+ return
}
if ctx.HasError() {
return
@@ -704,8 +704,8 @@ func (s bfloat16Serializer) Read(ctx *ReadContext, refMode RefMode, readType boo
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(BFLOAT16) {
+ return
}
if ctx.HasError() {
return
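Every primitive serializer above replaces a silent `_ = buf.ReadUint8(err)` with `readExpectedTypeID`, which consumes the type-ID byte and fails fast on mismatch. A minimal model of that helper against a toy reader (type-ID values and the `reader` struct are illustrative, not the library's `ReadContext`):

```go
package main

import "fmt"

type TypeId uint8

// Illustrative type-ID values; the real constants live in the fory package.
const (
	BOOL     TypeId = 1
	VARINT64 TypeId = 4
)

// reader is a tiny stand-in for ReadContext: one byte of type info is read
// and validated instead of being skipped, so a stream carrying the wrong
// type ID surfaces an error rather than misinterpreted payload bytes.
type reader struct {
	data []byte
	pos  int
	err  error
}

func (r *reader) readExpectedTypeID(expected TypeId) bool {
	if r.pos >= len(r.data) {
		r.err = fmt.Errorf("buffer underflow reading type ID")
		return false
	}
	actual := TypeId(r.data[r.pos])
	r.pos++
	if actual != expected {
		r.err = fmt.Errorf("type mismatch: expected %d, got %d", expected, actual)
		return false
	}
	return true
}

func main() {
	ok := &reader{data: []byte{byte(BOOL)}}
	fmt.Println(ok.readExpectedTypeID(BOOL)) // true

	bad := &reader{data: []byte{byte(VARINT64)}}
	fmt.Println(bad.readExpectedTypeID(BOOL), bad.err != nil) // false true
}
```

The `if readType && !ctx.readExpectedTypeID(X) { return }` shape keeps each serializer's fast path untouched while turning a previously ignored byte into a checked invariant.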
diff --git a/go/fory/reader.go b/go/fory/reader.go
index 9c8b049ad2..d0f89fa81c 100644
--- a/go/fory/reader.go
+++ b/go/fory/reader.go
@@ -31,11 +31,12 @@ import (
type ReadContext struct {
buffer *ByteBuffer
refReader *RefReader
- trackRef bool // Cached flag to avoid indirection
- xlang bool // Cross-language serialization mode
+ trackRef bool // Cached flag to avoid indirection
+ xlang bool // Cross-language serialization mode
+ rootHeader byte
compatible bool // Schema evolution compatibility mode
typeResolver *TypeResolver // For complex type deserialization
- refResolver *RefResolver // For reference tracking (legacy)
+ refResolver *RefResolver // For reference tracking in native-mode paths
outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization
outOfBandIndex int // Current index into out-of-band buffers
depth int // Current nesting depth for cycle detection
@@ -109,7 +110,7 @@ func (c *ReadContext) TypeResolver() *TypeResolver {
return c.typeResolver
}
-// RefResolver returns the reference resolver (legacy)
+// RefResolver returns the reference resolver.
func (c *ReadContext) RefResolver() *RefResolver {
return c.refResolver
}
@@ -151,6 +152,18 @@ func (c *ReadContext) CheckError() error {
return nil
}
+func (c *ReadContext) readExpectedTypeID(expected TypeId) bool {
+ actual := TypeId(c.buffer.ReadUint8(c.Err()))
+ if c.HasError() {
+ return false
+ }
+ if actual != expected {
+ c.SetError(TypeMismatchError(actual, expected))
+ return false
+ }
+ return true
+}
+
// Inline primitive reads
func (c *ReadContext) RawBool() bool { return c.buffer.ReadBool(c.Err()) }
func (c *ReadContext) RawInt8() int8 { return int8(c.buffer.ReadByte(c.Err())) }
@@ -288,7 +301,11 @@ func (c *ReadContext) ReadBoolSlice(refMode RefMode, readType bool) []bool {
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ if actual != BOOL_ARRAY {
+ c.SetError(TypeMismatchError(actual, BOOL_ARRAY))
+ return nil
+ }
}
return ReadBoolSlice(c.buffer, err)
}
@@ -302,7 +319,11 @@ func (c *ReadContext) ReadInt8Slice(refMode RefMode, readType bool) []int8 {
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ if actual != INT8_ARRAY {
+ c.SetError(TypeMismatchError(actual, INT8_ARRAY))
+ return nil
+ }
}
return ReadInt8Slice(c.buffer, err)
}
@@ -316,7 +337,11 @@ func (c *ReadContext) ReadInt16Slice(refMode RefMode, readType bool) []int16 {
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ if actual != INT16_ARRAY {
+ c.SetError(TypeMismatchError(actual, INT16_ARRAY))
+ return nil
+ }
}
return ReadInt16Slice(c.buffer, err)
}
@@ -330,7 +355,11 @@ func (c *ReadContext) ReadInt32Slice(refMode RefMode, readType bool) []int32 {
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ if actual != INT32_ARRAY {
+ c.SetError(TypeMismatchError(actual, INT32_ARRAY))
+ return nil
+ }
}
return ReadInt32Slice(c.buffer, err)
}
@@ -344,7 +373,11 @@ func (c *ReadContext) ReadInt64Slice(refMode RefMode, readType bool) []int64 {
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ if actual != INT64_ARRAY {
+ c.SetError(TypeMismatchError(actual, INT64_ARRAY))
+ return nil
+ }
}
return ReadInt64Slice(c.buffer, err)
}
@@ -358,7 +391,11 @@ func (c *ReadContext) ReadUint16Slice(refMode RefMode, readType bool) []uint16 {
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ if actual != UINT16_ARRAY {
+ c.SetError(TypeMismatchError(actual, UINT16_ARRAY))
+ return nil
+ }
}
return ReadUint16Slice(c.buffer, err)
}
@@ -372,7 +409,11 @@ func (c *ReadContext) ReadUint32Slice(refMode RefMode, readType bool) []uint32 {
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ if actual != UINT32_ARRAY {
+ c.SetError(TypeMismatchError(actual, UINT32_ARRAY))
+ return nil
+ }
}
return ReadUint32Slice(c.buffer, err)
}
@@ -386,7 +427,11 @@ func (c *ReadContext) ReadUint64Slice(refMode RefMode, readType bool) []uint64 {
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ if actual != UINT64_ARRAY {
+ c.SetError(TypeMismatchError(actual, UINT64_ARRAY))
+ return nil
+ }
}
return ReadUint64Slice(c.buffer, err)
}
@@ -400,7 +445,15 @@ func (c *ReadContext) ReadIntSlice(refMode RefMode, readType bool) []int {
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ expected := TypeId(INT64_ARRAY)
+ if strconv.IntSize == 32 {
+ expected = INT32_ARRAY
+ }
+ if actual != expected {
+ c.SetError(TypeMismatchError(actual, expected))
+ return nil
+ }
}
return ReadIntSlice(c.buffer, err)
}
@@ -414,7 +467,15 @@ func (c *ReadContext) ReadUintSlice(refMode RefMode, readType bool) []uint {
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ expected := TypeId(UINT64_ARRAY)
+ if strconv.IntSize == 32 {
+ expected = UINT32_ARRAY
+ }
+ if actual != expected {
+ c.SetError(TypeMismatchError(actual, expected))
+ return nil
+ }
}
return ReadUintSlice(c.buffer, err)
}
@@ -428,7 +489,11 @@ func (c *ReadContext) ReadFloat32Slice(refMode RefMode, readType bool) []float32
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ if actual != FLOAT32_ARRAY {
+ c.SetError(TypeMismatchError(actual, FLOAT32_ARRAY))
+ return nil
+ }
}
return ReadFloat32Slice(c.buffer, err)
}
@@ -442,7 +507,11 @@ func (c *ReadContext) ReadFloat64Slice(refMode RefMode, readType bool) []float64
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ if actual != FLOAT64_ARRAY {
+ c.SetError(TypeMismatchError(actual, FLOAT64_ARRAY))
+ return nil
+ }
}
return ReadFloat64Slice(c.buffer, err)
}
@@ -456,7 +525,11 @@ func (c *ReadContext) ReadByteSlice(refMode RefMode, readType bool) []byte {
}
}
if readType {
- _ = c.buffer.ReadUint8(err)
+ actual := TypeId(c.buffer.ReadUint8(err))
+ if actual != BINARY && actual != UINT8_ARRAY {
+ c.SetError(DeserializationErrorf("slice type mismatch: expected BINARY (%d) or UINT8_ARRAY (%d), got %d", BINARY, UINT8_ARRAY, actual))
+ return nil
+ }
}
size := c.ReadBinaryLength()
return c.buffer.ReadBinary(size, err)
@@ -484,8 +557,8 @@ func (c *ReadContext) ReadStringStringMap(refMode RefMode, readType bool) map[st
return nil
}
}
- if readType {
- _ = c.buffer.ReadUint8(err)
+ if readType && !c.readExpectedTypeID(MAP) {
+ return nil
}
return readMapStringString(c)
}
@@ -498,8 +571,8 @@ func (c *ReadContext) ReadStringInt64Map(refMode RefMode, readType bool) map[str
return nil
}
}
- if readType {
- _ = c.buffer.ReadUint8(err)
+ if readType && !c.readExpectedTypeID(MAP) {
+ return nil
}
return readMapStringInt64(c)
}
@@ -512,8 +585,8 @@ func (c *ReadContext) ReadStringInt32Map(refMode RefMode, readType bool) map[str
return nil
}
}
- if readType {
- _ = c.buffer.ReadUint8(err)
+ if readType && !c.readExpectedTypeID(MAP) {
+ return nil
}
return readMapStringInt32(c)
}
@@ -526,8 +599,8 @@ func (c *ReadContext) ReadStringIntMap(refMode RefMode, readType bool) map[strin
return nil
}
}
- if readType {
- _ = c.buffer.ReadUint8(err)
+ if readType && !c.readExpectedTypeID(MAP) {
+ return nil
}
return readMapStringInt(c)
}
@@ -540,8 +613,8 @@ func (c *ReadContext) ReadStringFloat64Map(refMode RefMode, readType bool) map[s
return nil
}
}
- if readType {
- _ = c.buffer.ReadUint8(err)
+ if readType && !c.readExpectedTypeID(MAP) {
+ return nil
}
return readMapStringFloat64(c)
}
@@ -554,8 +627,8 @@ func (c *ReadContext) ReadStringBoolMap(refMode RefMode, readType bool) map[stri
return nil
}
}
- if readType {
- _ = c.buffer.ReadUint8(err)
+ if readType && !c.readExpectedTypeID(MAP) {
+ return nil
}
return readMapStringBool(c)
}
@@ -568,8 +641,8 @@ func (c *ReadContext) ReadInt32Int32Map(refMode RefMode, readType bool) map[int3
return nil
}
}
- if readType {
- _ = c.buffer.ReadUint8(err)
+ if readType && !c.readExpectedTypeID(MAP) {
+ return nil
}
return readMapInt32Int32(c)
}
@@ -582,8 +655,8 @@ func (c *ReadContext) ReadInt64Int64Map(refMode RefMode, readType bool) map[int6
return nil
}
}
- if readType {
- _ = c.buffer.ReadUint8(err)
+ if readType && !c.readExpectedTypeID(MAP) {
+ return nil
}
return readMapInt64Int64(c)
}
@@ -596,8 +669,8 @@ func (c *ReadContext) ReadIntIntMap(refMode RefMode, readType bool) map[int]int
return nil
}
}
- if readType {
- _ = c.buffer.ReadUint8(err)
+ if readType && !c.readExpectedTypeID(MAP) {
+ return nil
}
return readMapIntInt(c)
}
diff --git a/go/fory/skip.go b/go/fory/skip.go
index 9d08f465c4..abc7466449 100644
--- a/go/fory/skip.go
+++ b/go/fory/skip.go
@@ -56,7 +56,10 @@ func SkipFieldValueWithTypeFlag(ctx *ReadContext, fieldDef FieldDef, readRefFlag
// Check if it's an EXT type first - EXT types don't have meta info like structs
if internalID == EXT {
- typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, wroteTypeID, err)
+ typeInfo := readKnownTypeInfoForSkip(ctx, wroteTypeID)
+ if ctx.HasError() {
+ return
+ }
if typeInfo != nil && typeInfo.Serializer != nil {
// Use the serializer to read and discard the value
var dummy any
@@ -71,8 +74,11 @@ func SkipFieldValueWithTypeFlag(ctx *ReadContext, fieldDef FieldDef, readRefFlag
// Check if it's a NAMED_EXT type - need to read type info to find serializer
if internalID == NAMED_EXT {
- typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, wroteTypeID, err)
- if typeInfo.Serializer != nil {
+ typeInfo := readKnownTypeInfoForSkip(ctx, wroteTypeID)
+ if ctx.HasError() {
+ return
+ }
+ if typeInfo != nil && typeInfo.Serializer != nil {
// Use the serializer to read and discard the value
var dummy any
dummyVal := reflect.ValueOf(&dummy).Elem()
@@ -86,14 +92,20 @@ func SkipFieldValueWithTypeFlag(ctx *ReadContext, fieldDef FieldDef, readRefFlag
// Check if it's a struct type - need to read type info and skip struct data
if internalID == COMPATIBLE_STRUCT || internalID == STRUCT ||
internalID == NAMED_STRUCT || internalID == NAMED_COMPATIBLE_STRUCT {
- typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, wroteTypeID, err)
+ typeInfo := readKnownTypeInfoForSkip(ctx, wroteTypeID)
+ if ctx.HasError() {
+ return
+ }
// Now skip the struct data using the typeInfo from the written type
skipStruct(ctx, typeInfo)
return
}
if IsNamespacedType(internalID) {
- typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, wroteTypeID, err)
+ typeInfo := readKnownTypeInfoForSkip(ctx, wroteTypeID)
+ if ctx.HasError() {
+ return
+ }
// Now skip the struct data using the typeInfo from the written type
skipStruct(ctx, typeInfo)
return
@@ -153,38 +165,22 @@ func SkipAnyValue(ctx *ReadContext, readRefFlag bool) {
typeSpec: NewMapTypeSpec(TypeId(typeID), NewSimpleTypeSpec(UNKNOWN), NewSimpleTypeSpec(UNKNOWN)),
nullable: true,
}
- case NAMED_UNION:
- resolver := ctx.TypeResolver()
- _, _ = resolver.metaStringResolver.ReadMetaStringBytes(ctx.buffer, err)
- if ctx.HasError() {
- return
- }
- _, _ = resolver.metaStringResolver.ReadMetaStringBytes(ctx.buffer, err)
- if ctx.HasError() {
- return
- }
- fieldDef = FieldDef{
- typeSpec: NewSimpleTypeSpec(TypeId(typeID)),
- nullable: true,
- }
- case COMPATIBLE_STRUCT, NAMED_COMPATIBLE_STRUCT, STRUCT, NAMED_STRUCT, EXT, TYPED_UNION:
+ case ENUM, NAMED_ENUM, COMPATIBLE_STRUCT, NAMED_COMPATIBLE_STRUCT, STRUCT, NAMED_STRUCT,
+ EXT, NAMED_EXT, TYPED_UNION, NAMED_UNION:
// Read type info using the shared meta reader when enabled.
typeInfo = ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, err)
if ctx.HasError() {
return
}
+ if typeInfo == nil {
+ ctx.SetError(DeserializationErrorf("cannot skip type %d: type info not found", typeID))
+ return
+ }
fieldDef = FieldDef{
typeSpec: NewSimpleTypeSpec(TypeId(typeID)),
nullable: true,
}
default:
- if internalID == ENUM || internalID == STRUCT ||
- internalID == EXT || internalID == TYPED_UNION {
- ctx.buffer.ReadVarUint32(err)
- if ctx.HasError() {
- return
- }
- }
fieldDef = FieldDef{
typeSpec: NewSimpleTypeSpec(TypeId(typeID)),
nullable: true,
@@ -206,7 +202,25 @@ func readTypeInfoForSkip(ctx *ReadContext, fieldTypeId TypeId) *TypeInfo {
return nil
}
// Use readTypeInfoWithTypeID which handles both namespaced and non-namespaced types correctly
- return ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, err)
+ typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, err)
+ if ctx.HasError() {
+ return nil
+ }
+ if typeInfo == nil {
+ ctx.SetError(DeserializationErrorf("cannot skip type %d: type info not found", typeID))
+ }
+ return typeInfo
+}
+
+func readKnownTypeInfoForSkip(ctx *ReadContext, typeID uint32) *TypeInfo {
+ typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, ctx.Err())
+ if ctx.HasError() {
+ return nil
+ }
+ if typeInfo == nil {
+ ctx.SetError(DeserializationErrorf("cannot skip type %d: type info not found", typeID))
+ }
+ return typeInfo
}
// skipCollection skips a collection (list/set) value
@@ -236,7 +250,10 @@ func skipCollection(ctx *ReadContext, fieldDef FieldDef) {
if ctx.HasError() {
return
}
- elemTypeInfo = ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, err)
+ elemTypeInfo = readKnownTypeInfoForSkip(ctx, typeID)
+ if ctx.HasError() {
+ return
+ }
elemDef = FieldDef{
typeSpec: NewSimpleTypeSpec(TypeId(elemTypeInfo.TypeID)),
nullable: hasNull,
@@ -336,7 +353,10 @@ func skipMap(ctx *ReadContext, fieldDef FieldDef) {
if ctx.HasError() {
return
}
- valueTypeInfo = ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, bufErr)
+ valueTypeInfo = readKnownTypeInfoForSkip(ctx, typeID)
+ if ctx.HasError() {
+ return
+ }
valueDef = FieldDef{
typeSpec: NewSimpleTypeSpec(TypeId(valueTypeInfo.TypeID)),
nullable: true,
@@ -368,7 +388,10 @@ func skipMap(ctx *ReadContext, fieldDef FieldDef) {
if ctx.HasError() {
return
}
- keyTypeInfo = ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, bufErr)
+ keyTypeInfo = readKnownTypeInfoForSkip(ctx, typeID)
+ if ctx.HasError() {
+ return
+ }
keyDef = FieldDef{
typeSpec: NewSimpleTypeSpec(TypeId(keyTypeInfo.TypeID)),
nullable: true,
@@ -395,6 +418,10 @@ func skipMap(ctx *ReadContext, fieldDef FieldDef) {
if ctx.HasError() {
return
}
+ if chunkSize == 0 || uint32(chunkSize) > length-lenCounter {
+ ctx.SetError(DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, length-lenCounter))
+ return
+ }
keyDeclared := (header & KEY_DECL_TYPE) != 0
valueDeclared := (header & VALUE_DECL_TYPE) != 0
@@ -406,7 +433,10 @@ func skipMap(ctx *ReadContext, fieldDef FieldDef) {
if ctx.HasError() {
return
}
- keyTypeInfo = ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, bufErr)
+ keyTypeInfo = readKnownTypeInfoForSkip(ctx, typeID)
+ if ctx.HasError() {
+ return
+ }
keyDef = FieldDef{
typeSpec: NewSimpleTypeSpec(TypeId(keyTypeInfo.TypeID)),
nullable: true,
@@ -420,7 +450,10 @@ func skipMap(ctx *ReadContext, fieldDef FieldDef) {
if ctx.HasError() {
return
}
- valueTypeInfo = ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeID, bufErr)
+ valueTypeInfo = readKnownTypeInfoForSkip(ctx, typeID)
+ if ctx.HasError() {
+ return
+ }
valueDef = FieldDef{
typeSpec: NewSimpleTypeSpec(TypeId(valueTypeInfo.TypeID)),
nullable: true,
@@ -527,25 +560,32 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo
internalID := TypeId(typeIDNum)
// Handle struct-like types
if internalID == COMPATIBLE_STRUCT || internalID == STRUCT ||
- internalID == NAMED_STRUCT || internalID == NAMED_COMPATIBLE_STRUCT ||
- internalID == UNKNOWN {
+ internalID == NAMED_STRUCT || internalID == NAMED_COMPATIBLE_STRUCT {
// If type_info is provided (from SkipAnyValue), use skipStruct directly
if typeInfo != nil {
skipStruct(ctx, typeInfo)
return
}
// Otherwise we need to read type info
- ti := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeIDNum, err)
+ ti := readKnownTypeInfoForSkip(ctx, typeIDNum)
+ if ctx.HasError() {
+ return
+ }
skipStruct(ctx, ti)
return
}
- if internalID == ENUM {
+ if internalID == ENUM || internalID == NAMED_ENUM {
// Enum values are encoded as ordinal only (VarUint32Small7) for xlang.
- _ = ctx.buffer.ReadUint8(err)
+ _ = ctx.buffer.ReadVarUint32Small7(err)
return
}
if internalID == EXT || internalID == NAMED_EXT || internalID == TYPED_UNION || internalID == NAMED_UNION {
- typeInfo := ctx.TypeResolver().readTypeInfoWithTypeID(ctx.buffer, typeIDNum, err)
+ if typeInfo == nil {
+ typeInfo = readKnownTypeInfoForSkip(ctx, typeIDNum)
+ if ctx.HasError() {
+ return
+ }
+ }
if typeInfo != nil && typeInfo.Serializer != nil {
// Use the serializer to read and discard the value
var dummy any
@@ -569,11 +609,15 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo
case INT16:
_ = ctx.buffer.ReadInt16(err)
case INT32:
- _ = ctx.buffer.ReadUint8(err)
+ _ = ctx.buffer.ReadInt32(err)
case VARINT32:
- _ = ctx.buffer.ReadUint8(err)
- case INT64, VARINT64, TAGGED_INT64:
+ _ = ctx.buffer.ReadVarint32(err)
+ case INT64:
+ _ = ctx.buffer.ReadInt64(err)
+ case VARINT64:
_ = ctx.buffer.ReadVarint64(err)
+ case TAGGED_INT64:
+ _ = ctx.buffer.ReadTaggedInt64(err)
// Floating point types
case BFLOAT16, FLOAT16:
@@ -676,7 +720,7 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo
// Enum types
case ENUM:
- _ = ctx.buffer.ReadVarUint32(err)
+ _ = ctx.buffer.ReadVarUint32Small7(err)
// Unsigned integer types
case UINT8:
@@ -692,12 +736,7 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo
case VAR_UINT64:
_ = ctx.buffer.ReadVarUint64(err)
case TAGGED_UINT64:
- firstInt32 := ctx.buffer.ReadInt32(err)
- if (firstInt32 & 1) != 0 {
- // 9-byte encoding
- _ = ctx.buffer.ReadUint64(err)
- }
- // Otherwise it's 4-byte encoding, already read
+ _ = ctx.buffer.ReadTaggedUint64(err)
// Unknown (polymorphic) type - read type info and skip dynamically
case UNKNOWN:
diff --git a/go/fory/skip_test.go b/go/fory/skip_test.go
new file mode 100644
index 0000000000..c60a40e355
--- /dev/null
+++ b/go/fory/skip_test.go
@@ -0,0 +1,132 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package fory
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestSkipEnumConsumesSmall7Ordinal(t *testing.T) {
+ f := New(WithXlang(true))
+ buf := NewByteBuffer(nil)
+ buf.WriteVarUint32Small7(128)
+ buf.WriteByte(0x7f)
+
+ f.readCtx.SetData(buf.Bytes())
+ skipValue(
+ f.readCtx,
+ FieldDef{typeSpec: NewSimpleTypeSpec(ENUM), nullable: true},
+ false,
+ false,
+ nil,
+ )
+ require.NoError(t, f.readCtx.CheckError())
+ require.Equal(t, 2, f.readCtx.Buffer().ReaderIndex())
+ require.Equal(t, byte(0x7f), f.readCtx.Buffer().ReadByte(f.readCtx.Err()))
+}
+
+func TestSkipPrimitiveConsumesExactEncoding(t *testing.T) {
+ tests := []struct {
+ name string
+ typeID TypeId
+ write func(*ByteBuffer)
+ }{
+ {
+ name: "int32",
+ typeID: INT32,
+ write: func(buf *ByteBuffer) { buf.WriteInt32(0x01020304) },
+ },
+ {
+ name: "varint32",
+ typeID: VARINT32,
+ write: func(buf *ByteBuffer) { buf.WriteVarint32(300) },
+ },
+ {
+ name: "int64",
+ typeID: INT64,
+ write: func(buf *ByteBuffer) { buf.WriteInt64(0x0102030405060708) },
+ },
+ {
+ name: "varint64",
+ typeID: VARINT64,
+ write: func(buf *ByteBuffer) { buf.WriteVarint64(1 << 35) },
+ },
+ {
+ name: "tagged_int64_small",
+ typeID: TAGGED_INT64,
+ write: func(buf *ByteBuffer) { buf.WriteTaggedInt64(1073741823) },
+ },
+ {
+ name: "tagged_int64_large",
+ typeID: TAGGED_INT64,
+ write: func(buf *ByteBuffer) { buf.WriteTaggedInt64(1 << 40) },
+ },
+ {
+ name: "tagged_uint64_small",
+ typeID: TAGGED_UINT64,
+ write: func(buf *ByteBuffer) { buf.WriteTaggedUint64(0x7fffffff) },
+ },
+ {
+ name: "tagged_uint64_large",
+ typeID: TAGGED_UINT64,
+ write: func(buf *ByteBuffer) { buf.WriteTaggedUint64(1 << 40) },
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ f := New(WithXlang(true))
+ buf := NewByteBuffer(nil)
+ tc.write(buf)
+ wantIndex := buf.WriterIndex()
+ buf.WriteByte(0x7f)
+
+ f.readCtx.SetData(buf.Bytes())
+ skipValue(
+ f.readCtx,
+ FieldDef{typeSpec: NewSimpleTypeSpec(tc.typeID), nullable: true},
+ false,
+ false,
+ nil,
+ )
+ require.NoError(t, f.readCtx.CheckError())
+ require.Equal(t, wantIndex, f.readCtx.Buffer().ReaderIndex())
+ require.Equal(t, byte(0x7f), f.readCtx.Buffer().ReadByte(f.readCtx.Err()))
+ })
+ }
+}
+
+func TestSkipMapRejectsInvalidChunkSize(t *testing.T) {
+ f := New(WithXlang(true))
+ buf := NewByteBuffer(nil)
+ buf.WriteLength(1)
+ buf.WriteByte(KEY_DECL_TYPE | VALUE_DECL_TYPE)
+ buf.WriteByte(2)
+
+ f.readCtx.SetData(buf.Bytes())
+ skipMap(
+ f.readCtx,
+ FieldDef{
+ typeSpec: NewMapTypeSpec(MAP, NewSimpleTypeSpec(INT32), NewSimpleTypeSpec(INT32)),
+ nullable: true,
+ },
+ )
+ require.Error(t, f.readCtx.CheckError())
+}
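The new chunk-size guard in `skipMap` rejects chunks that are empty or claim more entries than remain in the declared map length, which is what `TestSkipMapRejectsInvalidChunkSize` above exercises. The predicate reduces to a sketch like:

```go
package main

import "fmt"

// validChunkSize mirrors the guard added in skipMap: every chunk must make
// progress (non-zero) and must not overrun the remaining declared length.
func validChunkSize(chunkSize, remaining uint32) bool {
	return chunkSize != 0 && chunkSize <= remaining
}

func main() {
	fmt.Println(validChunkSize(2, 1)) // false: chunk overruns declared length
	fmt.Println(validChunkSize(0, 5)) // false: zero-size chunk never progresses
	fmt.Println(validChunkSize(3, 5)) // true
}
```

Without the zero-size check, a malicious payload could keep the chunk loop spinning without ever advancing `lenCounter`.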
diff --git a/go/fory/stream.go b/go/fory/stream.go
index 420a6d4825..bb86689598 100644
--- a/go/fory/stream.go
+++ b/go/fory/stream.go
@@ -92,47 +92,26 @@ func (is *InputStream) Shrink() {
}
// DeserializeFromStream reads the next object from the stream into the provided value.
-// It uses a shared ReadContext for the lifetime of the InputStream, clearing
-// temporary state between calls but preserving the buffer and TypeResolver state.
+// It preserves the stream buffer while clearing root-scoped read metadata between calls.
func (f *Fory) DeserializeFromStream(is *InputStream, v any) error {
-
- // We only reset the temporary read state (like refTracker and outOfBand buffers),
- // NOT the buffer or the type mapping, which must persist.
- defer func() {
- f.readCtx.refReader.Reset()
- f.readCtx.outOfBandBuffers = nil
- f.readCtx.outOfBandIndex = 0
- f.readCtx.err = Error{}
- if f.readCtx.refResolver != nil {
- f.readCtx.refResolver.resetRead()
- }
- }()
-
- // Temporarily swap buffer
origBuffer := f.readCtx.buffer
f.readCtx.buffer = is.buffer
+ defer func() {
+ f.readCtx.buffer = origBuffer
+ f.resetReadState()
+ }()
- isNull := readHeader(f.readCtx)
+ readHeader(f.readCtx)
if f.readCtx.HasError() {
- f.readCtx.buffer = origBuffer
return f.readCtx.TakeError()
}
- if isNull {
- f.readCtx.buffer = origBuffer
- return nil
- }
-
target := reflect.ValueOf(v).Elem()
f.readCtx.ReadValue(target, RefModeTracking, true)
if f.readCtx.HasError() {
- f.readCtx.buffer = origBuffer
return f.readCtx.TakeError()
}
- // Restore original buffer
- f.readCtx.buffer = origBuffer
-
return nil
}
@@ -145,15 +124,11 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error {
// Always reset to enforce stateless semantics.
f.readCtx.buffer.ResetWithReader(r, 0)
- isNull := readHeader(f.readCtx)
+ readHeader(f.readCtx)
if f.readCtx.HasError() {
return f.readCtx.TakeError()
}
- if isNull {
- return nil
- }
-
target := reflect.ValueOf(v).Elem()
f.readCtx.ReadValue(target, RefModeTracking, true)
if f.readCtx.HasError() {
diff --git a/go/fory/stream_test.go b/go/fory/stream_test.go
index 098254587c..0c2503f3cb 100644
--- a/go/fory/stream_test.go
+++ b/go/fory/stream_test.go
@@ -135,6 +135,29 @@ func TestStreamDeserializationEOF(t *testing.T) {
}
}
+func TestDeserializeFromStreamClearsReadMetadataOnError(t *testing.T) {
+ f := New(WithCompatible(true))
+ f.typeResolver.metaStringResolver.dynamicIDToEnumString =
+ append(f.typeResolver.metaStringResolver.dynamicIDToEnumString, emptyMetaStringBytes)
+ f.metaContext.readTypeInfos = append(f.metaContext.readTypeInfos, &TypeInfo{})
+
+ stream := NewInputStream(bytes.NewReader(nil))
+ var out int32
+ err := f.DeserializeFromStream(stream, &out)
+ if err == nil {
+ t.Fatal("Expected error on empty stream, got nil")
+ }
+ if len(f.typeResolver.metaStringResolver.dynamicIDToEnumString) != 0 {
+ t.Fatalf(
+ "expected stream root cleanup to clear metastring refs, got %d",
+ len(f.typeResolver.metaStringResolver.dynamicIDToEnumString),
+ )
+ }
+ if len(f.metaContext.readTypeInfos) != 0 {
+ t.Fatalf("expected stream root cleanup to clear type metadata, got %d", len(f.metaContext.readTypeInfos))
+ }
+}
+
func TestInputStreamSequential(t *testing.T) {
f := New()
// Register type in compatible mode to test Meta Sharing across sequential reads
diff --git a/go/fory/string.go b/go/fory/string.go
index 7c7a41bba3..9165aaf66c 100644
--- a/go/fory/string.go
+++ b/go/fory/string.go
@@ -106,7 +106,9 @@ func readUTF16LE(buf *ByteBuffer, byteCount int, err *Error) string {
func readUTF8(buf *ByteBuffer, size int, err *Error) string {
data := buf.ReadBinary(size, err)
- return string(data) // Direct UTF-8 conversion
+ // Go intentionally keeps direct string conversion here. Rust is the runtime that checks UTF-8
+ // string payloads by default; Go preserves its platform behavior for invalid byte sequences.
+ return string(data)
}
// ============================================================================
@@ -152,8 +154,8 @@ func (s stringSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool,
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(STRING) {
+ return
}
if ctx.HasError() {
return
@@ -194,8 +196,8 @@ func (s ptrToStringSerializer) Read(ctx *ReadContext, refMode RefMode, readType
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(STRING) {
+ return
}
if ctx.HasError() {
return
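For contrast with the `readUTF8` comment above: Go's `string(data)` conversion never fails, so a Rust-style validity check has to be explicit. A hedged sketch of what that validating read would look like — this is deliberately not part of the patch, since the Go runtime keeps its platform behavior:

```go
package main

import (
	"fmt"
	"unicode/utf8"
)

// readUTF8Strict shows the validating alternative: reject payloads that are
// not well-formed UTF-8 instead of passing them through unchanged.
func readUTF8Strict(data []byte) (string, error) {
	if !utf8.Valid(data) {
		return "", fmt.Errorf("invalid UTF-8 payload")
	}
	return string(data), nil
}

func main() {
	s, err := readUTF8Strict([]byte("héllo"))
	fmt.Println(s, err) // héllo <nil>

	_, err = readUTF8Strict([]byte{0xff, 0xfe})
	fmt.Println(err != nil) // true
}
```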
diff --git a/go/fory/struct.go b/go/fory/struct.go
index a609113cbb..b2d718db86 100644
--- a/go/fory/struct.go
+++ b/go/fory/struct.go
@@ -1541,13 +1541,13 @@ func (s *structSerializer) ReadData(ctx *ReadContext, value reflect.Value) {
}
switch field.DispatchId {
case PrimitiveVarint32DispatchId:
- storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarint32())
+ storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarint32(err))
case PrimitiveVarint64DispatchId:
storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarint64())
case PrimitiveIntDispatchId:
storeFieldValue(field.Kind, fieldPtr, optInfo, int(buf.UnsafeReadVarint64()))
case PrimitiveVarUint32DispatchId:
- storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarUint32())
+ storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarUint32(err))
case PrimitiveVarUint64DispatchId:
storeFieldValue(field.Kind, fieldPtr, optInfo, buf.UnsafeReadVarUint64())
case PrimitiveUintDispatchId:
diff --git a/go/fory/struct_test.go b/go/fory/struct_test.go
index 9dc03d1611..eb9287a0c2 100644
--- a/go/fory/struct_test.go
+++ b/go/fory/struct_test.go
@@ -604,14 +604,12 @@ func TestSkipAnyValueReadsSharedTypeMeta(t *testing.T) {
f.resetReadState()
f.readCtx.SetData(buf.Bytes())
- isNull := readHeader(f.readCtx)
- require.False(t, isNull)
+ readHeader(f.readCtx)
SkipAnyValue(f.readCtx, true)
require.NoError(t, f.readCtx.CheckError())
f.resetReadState()
- isNull = readHeader(f.readCtx)
- require.False(t, isNull)
+ readHeader(f.readCtx)
var out any
f.readCtx.ReadValue(reflect.ValueOf(&out).Elem(), RefModeTracking, true)
@@ -622,6 +620,17 @@ func TestSkipAnyValueReadsSharedTypeMeta(t *testing.T) {
require.Equal(t, "ok", result.Name)
}
+func TestReadHeaderRejectsOutOfBandWithoutBuffers(t *testing.T) {
+ f := New(WithXlang(true))
+ f.readCtx.SetData([]byte{XLangFlag | OutOfBandFlag})
+
+ readHeader(f.readCtx)
+
+ err := f.readCtx.TakeError()
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "out-of-band buffers")
+}
+
func TestFloat16StructField(t *testing.T) {
type StructWithFloat16 struct {
F16 float16.Float16
diff --git a/go/fory/tests/metastring_resolver_test.go b/go/fory/tests/metastring_resolver_test.go
index fa833b8242..b0b8202204 100644
--- a/go/fory/tests/metastring_resolver_test.go
+++ b/go/fory/tests/metastring_resolver_test.go
@@ -46,7 +46,7 @@ func TestMetaStringResolver(t *testing.T) {
// Test 2: Manually constructed MetaStringBytes
data2 := []byte{0xBF, 0x05, 0xA4, 0x71, 0xA9, 0x92, 0x53, 0x96, 0xA6, 0x49, 0x4F, 0x72, 0x9C, 0x68, 0x29, 0x80}
- metaBytes2 := fory.NewMetaStringBytes(data2, int64(-5456063526933366015))
+ metaBytes2 := fory.NewMetaStringBytes(data2, fory.ComputeMetaStringHash(data2, meta.LOWER_SPECIAL))
resolver.WriteMetaStringBytes(buffer, metaBytes2, &bufErr)
if bufErr.HasError() {
t.Fatalf("write failed: %v", bufErr.Error())
diff --git a/go/fory/time.go b/go/fory/time.go
index 676efe00e5..1a389ccbbf 100644
--- a/go/fory/time.go
+++ b/go/fory/time.go
@@ -159,8 +159,8 @@ func (s dateSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(DATE) {
+ return
}
if ctx.HasError() {
return
@@ -268,8 +268,8 @@ func (s durationSerializer) Read(ctx *ReadContext, refMode RefMode, readType boo
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(DURATION) {
+ return
}
if ctx.HasError() {
return
@@ -311,8 +311,8 @@ func (s timeSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h
return
}
}
- if readType {
- _ = ctx.buffer.ReadUint8(err)
+ if readType && !ctx.readExpectedTypeID(TIMESTAMP) {
+ return
}
if ctx.HasError() {
return
diff --git a/go/fory/type_def.go b/go/fory/type_def.go
index 81bfa76e87..e1ab7c3eb4 100644
--- a/go/fory/type_def.go
+++ b/go/fory/type_def.go
@@ -18,10 +18,7 @@
package fory
import (
- "bytes"
- "compress/zlib"
"fmt"
- "io"
"reflect"
"strings"
@@ -29,17 +26,17 @@ import (
)
const (
- META_SIZE_MASK = 0xFF
- COMPRESS_META_FLAG = 0b1 << 9
- HAS_FIELDS_META_FLAG = 0b1 << 8
- NUM_HASH_BITS = 50
+ META_SIZE_MASK = 0xFF
+ COMPRESS_META_FLAG = 0b1 << 8
+ RESERVED_META_BITS = 0b111 << 9
+ NUM_HASH_BITS = 52
)
/*
TypeDef represents a transportable value object containing type information and field definitions.
typeDef are layout as following:
- - first 8 bytes: global header (50 bits hash + 1 bit compress flag + write fields meta + 8 bits meta size)
- - next 1 byte: meta header (2 bits reserved + 1 bit register by name flag + 5 bits num fields)
+ - first 8 bytes: global header (52 bits hash + 1 bit compress flag + 8 bits meta size)
+ - next 1 byte: kind header
- next variable bytes: type id (varint) or ns name + type name
- next variable bytes: field definitions (see below)
*/
@@ -218,21 +215,21 @@ func (td *TypeDef) buildTypeInfoWithResolver(resolver *TypeResolver) (TypeInfo,
type_ := td.type_
var serializer Serializer
- // For extension types, use the registered serializer if available
- if type_ != nil && resolver != nil {
- if existingSerializer, ok := resolver.typeToSerializers[type_]; ok {
- // Only use registered serializer for extension types (not struct types)
- if _, isExt := existingSerializer.(*extensionSerializerAdapter); isExt {
- serializer = existingSerializer
- } else if ptrSer, isPtrSer := existingSerializer.(*ptrToValueSerializer); isPtrSer {
- if _, isExtInner := ptrSer.valueSerializer.(*extensionSerializerAdapter); isExtInner {
- serializer = existingSerializer
- }
+ if !isStructTypeId(TypeId(td.typeId)) {
+ if type_ != nil && resolver != nil {
+ var err error
+ serializer, err = resolver.getSerializerByType(type_, false)
+ if err != nil {
+ return TypeInfo{}, err
}
}
- }
- // If no extension serializer, create struct serializer
- if serializer == nil {
+ if serializer == nil && resolver != nil {
+ serializer = resolver.getSerializerByTypeID(td.typeId)
+ }
+ if serializer == nil {
+ return TypeInfo{}, fmt.Errorf("no serializer registered for TypeDef kind %d", td.typeId)
+ }
+ } else {
if type_ == nil {
// Unknown struct type - use skipStructSerializer to skip data
serializer = &skipStructSerializer{
@@ -287,6 +284,9 @@ func readTypeDef(fory *Fory, buffer *ByteBuffer, header int64, err *Error) *Type
}
func skipTypeDef(buffer *ByteBuffer, header int64, err *Error) {
+	// Header-cache hits intentionally treat the body as opaque bytes and skip it using the size
+	// from the current header. Parsed TypeDefs are published to the cache only after a successful
+	// body parse and 52-bit body-hash validation; cache hits must not reparse or rehash that body.
sz := int(header & META_SIZE_MASK)
if sz == META_SIZE_MASK {
sz += int(buffer.ReadVarUint32(err))
@@ -304,7 +304,14 @@ func readPkgName(buffer *ByteBuffer, namespaceDecoder *meta.Decoder, err *Error)
encodingFlags := header & 0b11 // 2 bits for encoding
size := header >> 2 // 6 bits for size
if size == BIG_NAME_THRESHOLD {
- size = int(buffer.ReadVarUint32Small7(err)) + BIG_NAME_THRESHOLD
+ extra := buffer.ReadVarUint32Small7(err)
+ if err.HasError() {
+ return "", err.TakeError()
+ }
+ if uint64(extra) > uint64(MaxInt-BIG_NAME_THRESHOLD) {
+ return "", fmt.Errorf("invalid TypeDef namespace length")
+ }
+ size = int(extra) + BIG_NAME_THRESHOLD
}
var encoding meta.Encoding
@@ -319,6 +326,9 @@ func readPkgName(buffer *ByteBuffer, namespaceDecoder *meta.Decoder, err *Error)
return "", fmt.Errorf("invalid package encoding flags: %d", encodingFlags)
}
+ if size > buffer.remaining() {
+ return "", fmt.Errorf("TypeDef namespace length %d exceeds remaining metadata %d", size, buffer.remaining())
+ }
data := make([]byte, size)
if _, err := buffer.Read(data); err != nil {
return "", err
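`readPkgName` and `readTypeName` share a compact size header: 2 low bits of encoding flags, 6 high bits of size, with the all-ones size escaping to an extra varint-encoded remainder. A sketch of the decode, assuming `BIG_NAME_THRESHOLD` is the 6-bit maximum 63 (an inference from the field width, not confirmed by this patch):

```go
package main

import "fmt"

const bigNameThreshold = 63 // assumed escape value for the 6-bit size field

// decodeNameSize splits the one-byte name header and applies the big-name
// escape, mirroring the structure of readPkgName/readTypeName.
func decodeNameSize(header byte, extra uint32) (encoding byte, size int) {
	encoding = header & 0b11
	size = int(header >> 2)
	if size == bigNameThreshold {
		size = int(extra) + bigNameThreshold
	}
	return
}

func main() {
	enc, size := decodeNameSize(0b00001001, 0) // size 2, encoding 1
	fmt.Println(enc, size) // 1 2

	_, size = decodeNameSize(0xFD, 10) // 6-bit size == 63 -> escape
	fmt.Println(size) // 73
}
```

The overflow and remaining-bytes checks added in the hunks above bound both the escaped `size` arithmetic and the subsequent `make([]byte, size)` against hostile metadata.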
@@ -335,7 +345,14 @@ func readTypeName(buffer *ByteBuffer, typeNameDecoder *meta.Decoder, err *Error)
encodingFlags := header & 0b11 // 2 bits for encoding
size := header >> 2 // 6 bits for size
if size == BIG_NAME_THRESHOLD {
- size = int(buffer.ReadVarUint32Small7(err)) + BIG_NAME_THRESHOLD
+ extra := buffer.ReadVarUint32Small7(err)
+ if err.HasError() {
+ return "", err.TakeError()
+ }
+ if uint64(extra) > uint64(MaxInt-BIG_NAME_THRESHOLD) {
+ return "", fmt.Errorf("invalid TypeDef typename length")
+ }
+ size = int(extra) + BIG_NAME_THRESHOLD
}
var encoding meta.Encoding
@@ -352,6 +369,9 @@ func readTypeName(buffer *ByteBuffer, typeNameDecoder *meta.Decoder, err *Error)
return "", fmt.Errorf("invalid typename encoding flags: %d", encodingFlags)
}
+ if size > buffer.remaining() {
+ return "", fmt.Errorf("TypeDef typename length %d exceeds remaining metadata %d", size, buffer.remaining())
+ }
data := make([]byte, size)
if _, err := buffer.Read(data); err != nil {
return "", err
@@ -362,16 +382,18 @@ func readTypeName(buffer *ByteBuffer, typeNameDecoder *meta.Decoder, err *Error)
// buildTypeDef constructs a TypeDef from a value
func buildTypeDef(fory *Fory, value reflect.Value) (*TypeDef, error) {
- fieldDefs, err := buildFieldDefs(fory, value)
- if err != nil {
- return nil, fmt.Errorf("failed to extract field infos: %w", err)
- }
-
infoPtr, err := fory.typeResolver.getTypeInfo(value, true)
if err != nil {
return nil, fmt.Errorf("failed to get type info for value %v: %w", value, err)
}
typeId := uint32(infoPtr.TypeID)
+ var fieldDefs []FieldDef
+ if isStructTypeId(TypeId(typeId)) {
+ fieldDefs, err = buildFieldDefs(fory, value)
+ if err != nil {
+ return nil, fmt.Errorf("failed to extract field infos: %w", err)
+ }
+ }
registerByName := IsNamespacedType(TypeId(typeId))
typeDef := NewTypeDef(typeId, infoPtr.UserTypeID, infoPtr.PkgPathBytes, infoPtr.NameBytes, registerByName, false, fieldDefs)
@@ -616,7 +638,9 @@ func newTypeSpecForTypeID(typeID TypeId) (*TypeSpec, error) {
const (
SmallNumFieldsThreshold = 31
- REGISTER_BY_NAME_FLAG = 0b1 << 5
+ RegisterByNameFlag = 0b0010_0000
+ CompatibleTypeDefFlag = 0b0100_0000
+ StructTypeDefFlag = 0b1000_0000
FieldNameSizeThreshold = 15
)
@@ -648,8 +672,8 @@ func getFieldNameEncodingIndex(encoding meta.Encoding) int {
/*
encodingTypeDef encodes a TypeDef into binary format according to the specification
typeDef are layout as following:
-- first 8 bytes: global header (50 bits hash + 1 bit compress flag + write fields meta + 8 bits meta size)
-- next 1 byte: meta header (2 bits reserved + 1 bit register by name flag + 5 bits num fields)
+- first 8 bytes: global header (52 bits hash + 1 bit compress flag + 8 bits meta size)
+- next 1 byte: kind header
- next variable bytes: type id (varint) or ns name + type name
- next variable bytes: field defs (see below)
*/
@@ -758,20 +782,21 @@ func encodingTypeDef(typeResolver *TypeResolver, typeDef *TypeDef) ([]byte, erro
return nil, fmt.Errorf("failed to write typename: %w", err)
}
} else {
- buffer.WriteUint8(uint8(typeDef.typeId))
if typeDef.userTypeId == invalidUserTypeID {
return nil, fmt.Errorf("missing user type ID for typeID %d", typeDef.typeId)
}
buffer.WriteVarUint32(typeDef.userTypeId)
}
- if err := writeFieldDefs(typeResolver, buffer, typeDef.fieldDefs); err != nil {
- return nil, fmt.Errorf("failed to write fields def: %w", err)
+ if isStructTypeId(TypeId(typeDef.typeId)) {
+ if err := writeFieldDefs(typeResolver, buffer, typeDef.fieldDefs); err != nil {
+ return nil, fmt.Errorf("failed to write fields def: %w", err)
+ }
+ } else if len(typeDef.fieldDefs) != 0 {
+ return nil, fmt.Errorf("non-struct TypeDef %d cannot carry field metadata", typeDef.typeId)
}
- // Temporary xlang behavior: keep TypeMeta uncompressed.
- // Some runtimes still do not support TypeMeta decompression.
- result, err := prependGlobalHeader(buffer, false, len(typeDef.fieldDefs) > 0)
+ result, err := prependGlobalHeader(buffer, false)
if err != nil {
return nil, fmt.Errorf("failed to write global binary header: %w", err)
}
@@ -780,16 +805,11 @@ func encodingTypeDef(typeResolver *TypeResolver, typeDef *TypeDef) ([]byte, erro
}
// prependGlobalHeader writes the 8-byte global header
-func prependGlobalHeader(buffer *ByteBuffer, isCompressed bool, hasFieldsMeta bool) (*ByteBuffer, error) {
+func prependGlobalHeader(buffer *ByteBuffer, isCompressed bool) (*ByteBuffer, error) {
var header uint64
metaSize := buffer.WriterIndex()
- hashValue := Murmur3Sum64WithSeed(buffer.GetByteSlice(0, metaSize), 47)
- header |= hashValue << (64 - NUM_HASH_BITS)
-
- if hasFieldsMeta {
- header |= HAS_FIELDS_META_FLAG
- }
+ header |= typeDefHeaderHash(buffer.GetByteSlice(0, metaSize))
if isCompressed {
header |= COMPRESS_META_FLAG
@@ -814,25 +834,82 @@ func prependGlobalHeader(buffer *ByteBuffer, isCompressed bool, hasFieldsMeta bo
// writeMetaHeader writes the 1-byte meta header
func writeMetaHeader(buffer *ByteBuffer, typeDef *TypeDef) error {
- // 2 bits reserved + 1 bit register by name flag + 5 bits num fields
offset := buffer.writerIndex
if err := buffer.WriteByte(0xFF); err != nil {
return err
}
fieldInfos := typeDef.fieldDefs
- header := len(fieldInfos)
- if header > SmallNumFieldsThreshold {
- header = SmallNumFieldsThreshold
- buffer.WriteVarUint32(uint32(len(fieldInfos) - SmallNumFieldsThreshold))
- }
- if typeDef.registerByName {
- header |= REGISTER_BY_NAME_FLAG
+ typeID := TypeId(typeDef.typeId)
+ var header int
+ if isStructTypeId(typeID) {
+ fieldCount := len(fieldInfos)
+ inlineFieldCount := fieldCount
+ if inlineFieldCount > SmallNumFieldsThreshold {
+ inlineFieldCount = SmallNumFieldsThreshold
+ }
+ header = StructTypeDefFlag | inlineFieldCount
+ if typeID == COMPATIBLE_STRUCT || typeID == NAMED_COMPATIBLE_STRUCT {
+ header |= CompatibleTypeDefFlag
+ }
+ if fieldCount >= SmallNumFieldsThreshold {
+ buffer.WriteVarUint32(uint32(fieldCount - SmallNumFieldsThreshold))
+ }
+ if typeDef.registerByName {
+ header |= RegisterByNameFlag
+ }
+ } else {
+ if len(fieldInfos) != 0 {
+ return fmt.Errorf("non-struct TypeDef %d cannot carry field metadata", typeDef.typeId)
+ }
+ kindCode, err := xlangNonStructKindCode(typeID)
+ if err != nil {
+ return err
+ }
+ header = kindCode
}
buffer.PutUint8(offset, uint8(header))
return nil
}
+func xlangNonStructKindCode(typeID TypeId) (int, error) {
+ switch typeID {
+ case ENUM:
+ return 0, nil
+ case NAMED_ENUM:
+ return 1, nil
+ case EXT:
+ return 2, nil
+ case NAMED_EXT:
+ return 3, nil
+ case TYPED_UNION:
+ return 4, nil
+ case NAMED_UNION:
+ return 5, nil
+ default:
+ return 0, fmt.Errorf("unsupported TypeDef kind %d", typeID)
+ }
+}
+
+func xlangNonStructTypeID(kindCode int) (TypeId, error) {
+ switch kindCode {
+ case 0:
+ return ENUM, nil
+ case 1:
+ return NAMED_ENUM, nil
+ case 2:
+ return EXT, nil
+ case 3:
+ return NAMED_EXT, nil
+ case 4:
+ return TYPED_UNION, nil
+ case 5:
+ return NAMED_UNION, nil
+ default:
+ return UNKNOWN, fmt.Errorf("unsupported TypeDef kind code %d", kindCode)
+ }
+}
+
// writeFieldDefs writes field definitions according to the specification
// field def layout as following:
// - first 1 byte: header (2 bits field name encoding + 4 bits size + nullability flag + ref tracking flag)
@@ -913,8 +990,8 @@ func writeFieldDef(typeResolver *TypeResolver, buffer *ByteBuffer, field FieldDe
/*
decodeTypeDef decodes a TypeDef from the buffer
typeDef are layout as following:
- - first 8 bytes: global header (50 bits hash + 1 bit compress flag + write fields meta + 8 bits meta size)
- - next 1 byte: meta header (2 bits reserved + 1 bit register by name flag + 5 bits num fields)
+ - first 8 bytes: global header (52 bits hash + 3 bits reserved + 1 bit compress flag + 8 bits meta size)
+ - next 1 byte: kind header
- next variable bytes: type id (varint) or ns name + type name
- next variable bytes: field definitions (see below)
*/
@@ -922,49 +999,84 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro
// ReadData 8-byte global header
var bufErr Error
globalHeader := header
- hasFieldsMeta := (globalHeader & HAS_FIELDS_META_FLAG) != 0
+ if (globalHeader & RESERVED_META_BITS) != 0 {
+ return nil, fmt.Errorf("reserved TypeDef global header bits must be zero")
+ }
isCompressed := (globalHeader & COMPRESS_META_FLAG) != 0
+ if isCompressed {
+ return nil, fmt.Errorf("compressed xlang TypeDef is not supported")
+ }
metaSizeBits := int(globalHeader & META_SIZE_MASK)
metaSize := metaSizeBits
extraMetaSize := 0
if metaSizeBits == META_SIZE_MASK {
- extraMetaSize = int(buffer.ReadVarUint32(&bufErr))
+ extra := buffer.ReadVarUint32(&bufErr)
+ if bufErr.HasError() {
+ return nil, bufErr.TakeError()
+ }
+ if uint64(extra) > uint64(MaxInt-metaSize) {
+ return nil, fmt.Errorf("invalid TypeDef metadata size")
+ }
+ extraMetaSize = int(extra)
metaSize += extraMetaSize
}
+ if metaSize > fory.config.MaxBinarySize {
+ return nil, MaxBinarySizeExceededError(metaSize, fory.config.MaxBinarySize)
+ }
// Store the encoded bytes for the TypeDef (including meta header and metadata)
encodedMeta := buffer.ReadBinary(metaSize, &bufErr)
if bufErr.HasError() {
return nil, bufErr.TakeError()
}
- decodedMeta := encodedMeta
- if isCompressed {
- decodedMetaBytes, err := decompressMeta(encodedMeta)
- if err != nil {
- return nil, err
- }
- decodedMeta = decodedMetaBytes
- }
- metaBuffer := NewByteBuffer(decodedMeta)
+ metaBuffer := NewByteBuffer(encodedMeta)
var metaErr Error
// ReadData 1-byte meta header
metaHeaderByte := metaBuffer.ReadByte(&metaErr)
- // Extract field count from lower 5 bits
- fieldCount := int(metaHeaderByte & SmallNumFieldsThreshold)
- if fieldCount == SmallNumFieldsThreshold {
- fieldCount += int(metaBuffer.ReadVarUint32(&metaErr))
- }
- if fieldCount > fory.config.MaxTypeFields || fieldCount > metaBuffer.remaining() {
- return nil, fmt.Errorf("field count exceeds maximum allowed limit or available buffer size")
- }
- registeredByName := (metaHeaderByte & REGISTER_BY_NAME_FLAG) != 0
+ isStruct := (metaHeaderByte & StructTypeDefFlag) != 0
+ fieldCount := 0
+ registeredByName := false
// ReadData name or type ID according to the registerByName flag
var typeId uint32
userTypeId := invalidUserTypeID
var nsBytes, nameBytes *MetaStringBytes
var type_ reflect.Type
+ if isStruct {
+ registeredByName = (metaHeaderByte & RegisterByNameFlag) != 0
+ fieldCount = int(metaHeaderByte & SmallNumFieldsThreshold)
+ if fieldCount == SmallNumFieldsThreshold {
+ fieldCount += int(metaBuffer.ReadVarUint32(&metaErr))
+ }
+ if metaErr.HasError() {
+ return nil, metaErr.TakeError()
+ }
+ if fieldCount > fory.config.MaxTypeFields || fieldCount > metaBuffer.remaining() {
+ return nil, fmt.Errorf("field count exceeds maximum allowed limit or available buffer size")
+ }
+ if registeredByName {
+ if (metaHeaderByte & CompatibleTypeDefFlag) != 0 {
+ typeId = uint32(NAMED_COMPATIBLE_STRUCT)
+ } else {
+ typeId = uint32(NAMED_STRUCT)
+ }
+ } else if (metaHeaderByte & CompatibleTypeDefFlag) != 0 {
+ typeId = uint32(COMPATIBLE_STRUCT)
+ } else {
+ typeId = uint32(STRUCT)
+ }
+ } else {
+ if (metaHeaderByte & 0b0111_0000) != 0 {
+ return nil, fmt.Errorf("invalid TypeDef kind header")
+ }
+ kindType, err := xlangNonStructTypeID(int(metaHeaderByte & 0b1111))
+ if err != nil {
+ return nil, err
+ }
+ typeId = uint32(kindType)
+ registeredByName = IsNamespacedType(kindType)
+ }
if registeredByName {
// ReadData namespace and type name for namespaced types
// NOTE: TypeDefs use simple name format, not meta string format with dynamic IDs
@@ -974,7 +1086,14 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro
nsEncodingFlags := nsHeader & 0b11 // 2 bits for encoding
nsSize := nsHeader >> 2 // 6 bits for size
if nsSize == BIG_NAME_THRESHOLD {
- nsSize = int(metaBuffer.ReadVarUint32Small7(&metaErr)) + BIG_NAME_THRESHOLD
+ extra := metaBuffer.ReadVarUint32Small7(&metaErr)
+ if metaErr.HasError() {
+ return nil, metaErr.TakeError()
+ }
+ if uint64(extra) > uint64(MaxInt-BIG_NAME_THRESHOLD) {
+ return nil, fmt.Errorf("invalid TypeDef namespace length")
+ }
+ nsSize = int(extra) + BIG_NAME_THRESHOLD
}
// Java pkg encoding: 0=UTF_8, 1=ALL_TO_LOWER_SPECIAL, 2=LOWER_UPPER_DIGIT_SPECIAL
@@ -989,6 +1108,9 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro
default:
return nil, fmt.Errorf("invalid package encoding flags: %d", nsEncodingFlags)
}
+ if nsSize > metaBuffer.remaining() {
+ return nil, fmt.Errorf("TypeDef namespace length %d exceeds remaining metadata %d", nsSize, metaBuffer.remaining())
+ }
nsData := make([]byte, nsSize)
if _, err := metaBuffer.Read(nsData); err != nil {
return nil, fmt.Errorf("failed to read namespace data: %w", err)
@@ -1000,7 +1122,14 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro
typeEncodingFlags := typeHeader & 0b11 // 2 bits for encoding
typeSize := typeHeader >> 2 // 6 bits for size
if typeSize == BIG_NAME_THRESHOLD {
- typeSize = int(metaBuffer.ReadVarUint32Small7(&metaErr)) + BIG_NAME_THRESHOLD
+ extra := metaBuffer.ReadVarUint32Small7(&metaErr)
+ if metaErr.HasError() {
+ return nil, metaErr.TakeError()
+ }
+ if uint64(extra) > uint64(MaxInt-BIG_NAME_THRESHOLD) {
+ return nil, fmt.Errorf("invalid TypeDef typename length")
+ }
+ typeSize = int(extra) + BIG_NAME_THRESHOLD
}
// Java typename encoding: 0=UTF_8, 1=ALL_TO_LOWER_SPECIAL, 2=LOWER_UPPER_DIGIT_SPECIAL, 3=FIRST_TO_LOWER_SPECIAL
@@ -1017,6 +1146,9 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro
default:
return nil, fmt.Errorf("invalid typename encoding flags: %d", typeEncodingFlags)
}
+ if typeSize > metaBuffer.remaining() {
+ return nil, fmt.Errorf("TypeDef typename length %d exceeds remaining metadata %d", typeSize, metaBuffer.remaining())
+ }
typeData := make([]byte, typeSize)
if _, err := metaBuffer.Read(typeData); err != nil {
return nil, fmt.Errorf("failed to read typename data: %w", err)
@@ -1049,7 +1181,9 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro
if fallbackInfo, fallbackExists := fory.typeResolver.namedTypeToTypeInfo[nameKey]; fallbackExists {
info = fallbackInfo
exists = true
- fory.typeResolver.nsTypeToTypeInfo[nsTypeKey{nsBytes.Hashcode, nameBytes.Hashcode}] = info
+ if len(fory.typeResolver.nsTypeToTypeInfo) < maxCachedNamedTypeInfos {
+ fory.typeResolver.nsTypeToTypeInfo[nsTypeKey{nsBytes.Hashcode, nameBytes.Hashcode}] = info
+ }
}
}
if exists {
@@ -1059,18 +1193,19 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro
if type_.Kind() == reflect.Ptr {
type_ = type_.Elem()
}
- typeId = uint32(info.TypeID)
+ if uint32(info.TypeID) != typeId {
+ return nil, fmt.Errorf("TypeDef kind does not match registered type metadata")
+ }
userTypeId = info.UserTypeID
} else {
- // Type not registered - use NAMED_STRUCT as default typeId
- // The type_ will remain nil and will be set from field definitions later
- typeId = uint32(NAMED_STRUCT)
type_ = nil
}
} else {
- typeId = uint32(metaBuffer.ReadUint8(&metaErr))
userTypeId = metaBuffer.ReadVarUint32(&metaErr)
if info, exists := fory.typeResolver.userTypeIdToTypeInfo[userTypeId]; exists {
+ if uint32(info.TypeID) != typeId {
+ return nil, fmt.Errorf("TypeDef kind does not match registered type metadata")
+ }
type_ = info.Type
} else if info, exists := fory.typeResolver.typeIDToTypeInfo[typeId]; exists {
type_ = info.Type
@@ -1083,14 +1218,24 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro
// ReadData fields information
fieldInfos := make([]FieldDef, fieldCount)
- if hasFieldsMeta {
- for i := 0; i < fieldCount; i++ {
- fieldInfo, err := readFieldDef(fory.typeResolver, metaBuffer)
- if err != nil {
- return nil, fmt.Errorf("failed to read field def %d: %w", i, err)
- }
- fieldInfos[i] = fieldInfo
+ for i := 0; i < fieldCount; i++ {
+ fieldInfo, err := readFieldDef(fory.typeResolver, metaBuffer)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read field def %d: %w", i, err)
}
+ fieldInfos[i] = fieldInfo
+ }
+ if !isStruct && len(fieldInfos) != 0 {
+ return nil, fmt.Errorf("non-struct TypeDef cannot carry field metadata")
+ }
+ if metaErr.HasError() {
+ return nil, metaErr.TakeError()
+ }
+ if remaining := metaBuffer.remaining(); remaining != 0 {
+ return nil, fmt.Errorf("TypeDef metadata body has %d trailing bytes", remaining)
+ }
+ if err := validateParsedTypeDefHash(globalHeader, metaSizeBits, extraMetaSize, encodedMeta); err != nil {
+ return nil, err
}
encoded := buildTypeDefEncoded(globalHeader, metaSizeBits, extraMetaSize, encodedMeta)
@@ -1130,17 +1275,32 @@ func buildTypeDefEncoded(header int64, metaSizeBits, extraMetaSize int, metaByte
return buffer.Bytes()
}
-func decompressMeta(encoded []byte) ([]byte, error) {
- reader, err := zlib.NewReader(bytes.NewReader(encoded))
- if err != nil {
- return nil, fmt.Errorf("failed to create meta decompressor: %w", err)
+func typeDefHeaderHash(data []byte) uint64 {
+ hash := int64(Murmur3Sum64WithSeed(data, 47) << (64 - NUM_HASH_BITS))
+ if hash < 0 {
+ hash = -hash
}
- defer reader.Close()
- decoded, err := io.ReadAll(reader)
- if err != nil {
- return nil, fmt.Errorf("failed to decompress meta: %w", err)
+ hashMask := ^uint64(0)
+ hashMask <<= uint(64 - NUM_HASH_BITS)
+ return uint64(hash) & hashMask
+}
+
+func validateParsedTypeDefHash(header int64, metaSizeBits, extraMetaSize int, encoded []byte) error {
+ size := metaSizeBits
+ if size == META_SIZE_MASK {
+ size += extraMetaSize
}
- return decoded, nil
+ if len(encoded) != size {
+ return fmt.Errorf("invalid TypeDef encoded size")
+ }
+ hashMask := ^uint64(0)
+ hashMask <<= uint(64 - NUM_HASH_BITS)
+ expectedHeaderHash := typeDefHeaderHash(encoded)
+ actualHeaderHash := uint64(header) & hashMask
+ if expectedHeaderHash != actualHeaderHash {
+ return fmt.Errorf("invalid TypeDef metadata hash")
+ }
+ return nil
}
/*
@@ -1174,7 +1334,14 @@ func readFieldDef(typeResolver *TypeResolver, buffer *ByteBuffer) (FieldDef, err
// Read tag ID
tagID := sizeBits
if sizeBits == 0x0F {
- tagID = FieldNameSizeThreshold + int(buffer.ReadVarUint32(&bufErr))
+ extra := buffer.ReadVarUint32(&bufErr)
+ if bufErr.HasError() {
+ return FieldDef{}, bufErr.TakeError()
+ }
+ if uint64(extra) > uint64(MaxInt-FieldNameSizeThreshold) {
+ return FieldDef{}, fmt.Errorf("invalid TypeDef field tag ID")
+ }
+ tagID = FieldNameSizeThreshold + int(extra)
}
// Read field type
@@ -1182,6 +1349,9 @@ func readFieldDef(typeResolver *TypeResolver, buffer *ByteBuffer) (FieldDef, err
if err != nil {
return FieldDef{}, err
}
+ if bufErr.HasError() {
+ return FieldDef{}, bufErr.TakeError()
+ }
return FieldDef{
name: "", // No field name when using tag ID
@@ -1197,7 +1367,14 @@ func readFieldDef(typeResolver *TypeResolver, buffer *ByteBuffer) (FieldDef, err
nameEncoding := fieldNameEncodings[nameEncodingFlag]
nameLen := sizeBits
if nameLen == 0x0F {
- nameLen = FieldNameSizeThreshold + int(buffer.ReadVarUint32(&bufErr))
+ extra := buffer.ReadVarUint32(&bufErr)
+ if bufErr.HasError() {
+ return FieldDef{}, bufErr.TakeError()
+ }
+ if uint64(extra) > uint64(MaxInt-FieldNameSizeThreshold) {
+ return FieldDef{}, fmt.Errorf("invalid TypeDef field name length")
+ }
+ nameLen = FieldNameSizeThreshold + int(extra)
} else {
nameLen++ // Adjust for 1-based encoding
}
@@ -1207,9 +1384,18 @@ func readFieldDef(typeResolver *TypeResolver, buffer *ByteBuffer) (FieldDef, err
if err != nil {
return FieldDef{}, err
}
+ if bufErr.HasError() {
+ return FieldDef{}, bufErr.TakeError()
+ }
// Read field name based on encoding
+ if nameLen > buffer.remaining() {
+ return FieldDef{}, fmt.Errorf("TypeDef field name length %d exceeds remaining metadata %d", nameLen, buffer.remaining())
+ }
nameBytes := buffer.ReadBinary(nameLen, &bufErr)
+ if bufErr.HasError() {
+ return FieldDef{}, bufErr.TakeError()
+ }
fieldName, err := typeResolver.typeNameDecoder.Decode(nameBytes, nameEncoding)
if err != nil {
return FieldDef{}, fmt.Errorf("failed to decode field name: %w", err)
diff --git a/go/fory/type_def_test.go b/go/fory/type_def_test.go
index 5e88bd1744..37b465001a 100644
--- a/go/fory/type_def_test.go
+++ b/go/fory/type_def_test.go
@@ -18,10 +18,13 @@
package fory
import (
+ "bytes"
+ "compress/zlib"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
// Test structs for encoding/decoding
@@ -176,6 +179,35 @@ func checkTypeSpecRecursivelyOrNil(t *testing.T, original, decoded *TypeSpec, pa
checkTypeSpecRecursively(t, original, decoded, path, compareRootFlags)
}
+func typeDefTestBodyWithoutFields() []byte {
+ buffer := NewByteBuffer(nil)
+ buffer.WriteByte(StructTypeDefFlag)
+ buffer.WriteVarUint32(0)
+ return buffer.Bytes()
+}
+
+func typeDefTestFrame(t *testing.T, body []byte, compressed bool) (*ByteBuffer, int64) {
+ t.Helper()
+ bodyBuffer := NewByteBuffer(nil)
+ bodyBuffer.WriteBinary(body)
+ frame, err := prependGlobalHeader(bodyBuffer, compressed)
+ require.NoError(t, err)
+ readErr := &Error{}
+ header := frame.ReadInt64(readErr)
+ require.NoError(t, readErr.CheckError())
+ return frame, header
+}
+
+func deflateTypeDefTestBody(t *testing.T, body []byte) []byte {
+ t.Helper()
+ var compressed bytes.Buffer
+ writer := zlib.NewWriter(&compressed)
+ _, err := writer.Write(body)
+ require.NoError(t, err)
+ require.NoError(t, writer.Close())
+ return compressed.Bytes()
+}
+
// Item1 struct with mixed nullable (pointer) and non-nullable (primitive) fields
type Item1 struct {
F1 int32
@@ -301,22 +333,196 @@ func TestTypeDefNullableFields(t *testing.T) {
// allocation that would OOM-crash the process.
func TestTypeDefFieldCountOOMPanic(t *testing.T) {
fory := NewFory()
- header := int64(HAS_FIELDS_META_FLAG | 8)
- // metaHeaderByte value of 31 triggers the extended VarUint32 field-count path.
buffer := NewByteBuffer(make([]byte, 0, 8))
- buffer.WriteByte(31)
+ buffer.WriteByte(StructTypeDefFlag | SmallNumFieldsThreshold)
buffer.WriteVarUint32(2000000000)
- buffer.WriteUint8(0)
- buffer.WriteVarUint32(0)
buffer.SetReaderIndex(0)
- _, err := decodeTypeDef(fory, buffer, header)
+ _, err := decodeTypeDef(fory, buffer, int64(buffer.WriterIndex()))
if err == nil {
t.Fatal("expected error for oversized fieldCount, got nil")
}
}
+func TestTypeDefRejectsReservedGlobalHeaderBits(t *testing.T) {
+ fory := NewFory()
+ buffer := NewByteBuffer(nil)
+ buffer.WriteByte(StructTypeDefFlag)
+ buffer.WriteVarUint32(0)
+ buffer.SetReaderIndex(0)
+
+ _, err := decodeTypeDef(fory, buffer, int64(RESERVED_META_BITS|uint64(buffer.WriterIndex())))
+ if err == nil {
+ t.Fatal("expected error for reserved TypeDef global header bits")
+ }
+}
+
+func TestTypeDefRejectsReservedNonStructKindBits(t *testing.T) {
+ fory := NewFory()
+ body := []byte{0b0001_0000}
+ frame, header := typeDefTestFrame(t, body, false)
+
+ _, err := decodeTypeDef(fory, frame, header)
+ if err == nil {
+ t.Fatal("expected error for reserved non-struct TypeDef kind bits")
+ }
+}
+
+func TestTypeDefRejectsTrailingMetadataBytes(t *testing.T) {
+ fory := NewFory()
+ meta := NewByteBuffer(nil)
+ meta.WriteByte(StructTypeDefFlag)
+ meta.WriteVarUint32(0)
+ meta.WriteByte(0xff)
+
+ buffer := NewByteBuffer(nil)
+ _, writeErr := buffer.Write(meta.Bytes())
+ if writeErr != nil {
+ t.Fatalf("failed to write type def metadata: %v", writeErr)
+ }
+ buffer.SetReaderIndex(0)
+
+ _, err := decodeTypeDef(fory, buffer, int64(len(meta.Bytes())))
+ if err == nil {
+ t.Fatal("expected error for trailing TypeDef metadata bytes")
+ }
+}
+
+func TestTypeDefExtendedFieldCountHeaderDoesNotSetRegisterByName(t *testing.T) {
+ fields := make([]FieldDef, 32)
+ for i := range fields {
+ fields[i] = FieldDef{
+ typeSpec: NewSimpleTypeSpec(INT32),
+ tagID: i + 1,
+ }
+ }
+ typeDef := NewTypeDef(uint32(STRUCT), 700, nil, nil, false, false, fields)
+ buffer := NewByteBuffer(nil)
+
+ require.NoError(t, writeMetaHeader(buffer, typeDef))
+ header := buffer.Bytes()[0]
+ require.Equal(t, byte(StructTypeDefFlag|SmallNumFieldsThreshold), header)
+ require.Zero(t, header&RegisterByNameFlag)
+}
+
+func TestTypeDefRejectsMetadataHashMismatch(t *testing.T) {
+ fory := NewFory()
+ body := typeDefTestBodyWithoutFields()
+ buffer := NewByteBuffer(nil)
+ buffer.WriteBinary(body)
+ buffer.SetReaderIndex(0)
+
+ _, err := decodeTypeDef(fory, buffer, int64(len(body)))
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "metadata hash")
+}
+
+func TestTypeDefRejectsEncodedMetadataAboveMaxBinarySize(t *testing.T) {
+ fory := NewFory(WithMaxBinarySize(1))
+ body := typeDefTestBodyWithoutFields()
+ frame, header := typeDefTestFrame(t, body, false)
+
+ _, err := decodeTypeDef(fory, frame, header)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "max binary size exceeded")
+}
+
+func TestTypeDefRejectsCompressedMetadata(t *testing.T) {
+ decoded := typeDefTestBodyWithoutFields()
+ compressed := deflateTypeDefTestBody(t, decoded)
+ fory := NewFory(WithMaxBinarySize(4096))
+ frame, header := typeDefTestFrame(t, compressed, true)
+
+ _, err := decodeTypeDef(fory, frame, header)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "compressed xlang TypeDef")
+}
+
+func TestReadSharedTypeMetaCapsParsedTypeDefCache(t *testing.T) {
+ fory := NewFory(WithCompatible(true))
+ require.NoError(t, fory.RegisterNamedStruct(SimpleStruct{}, "example.SimpleStruct"))
+ typeDef, err := buildTypeDef(fory, reflect.ValueOf(SimpleStruct{}))
+ require.NoError(t, err)
+ require.NotEmpty(t, typeDef.encoded)
+
+ for i := 0; i < maxCachedTypeDefs; i++ {
+ fory.typeResolver.defIdToTypeDef[int64(i)] = typeDef
+ }
+ headerErr := &Error{}
+ header := NewByteBuffer(typeDef.encoded).ReadInt64(headerErr)
+ require.NoError(t, headerErr.CheckError())
+ require.NotContains(t, fory.typeResolver.defIdToTypeDef, header)
+
+ buffer := NewByteBuffer(nil)
+ buffer.WriteVarUint32(0)
+ buffer.WriteBinary(typeDef.encoded)
+ readErr := &Error{}
+ typeInfo := fory.typeResolver.readSharedTypeMeta(buffer, readErr)
+ require.NoError(t, readErr.CheckError())
+ require.NotNil(t, typeInfo)
+ require.Len(t, fory.typeResolver.defIdToTypeDef, maxCachedTypeDefs)
+ require.NotContains(t, fory.typeResolver.defIdToTypeDef, header)
+}
+
+func TestDecodeTypeDefFallbackNamedTypeCacheRespectsCap(t *testing.T) {
+ fory := NewFory(WithCompatible(true))
+ require.NoError(t, fory.RegisterNamedStruct(SimpleStruct{}, "example.SimpleStruct"))
+ typeDef, err := buildTypeDef(fory, reflect.ValueOf(SimpleStruct{}))
+ require.NoError(t, err)
+ require.NotNil(t, typeDef.nsName)
+ require.NotNil(t, typeDef.typeName)
+
+ nameKey := nsTypeKey{typeDef.nsName.Hashcode, typeDef.typeName.Hashcode}
+ delete(fory.typeResolver.nsTypeToTypeInfo, nameKey)
+ info := fory.typeResolver.namedTypeToTypeInfo[[2]string{"example", "SimpleStruct"}]
+ require.NotNil(t, info)
+ for i := 0; len(fory.typeResolver.nsTypeToTypeInfo) < maxCachedNamedTypeInfos; i++ {
+ fory.typeResolver.nsTypeToTypeInfo[nsTypeKey{int64(i + 1), int64(i + 2)}] = info
+ }
+ require.NotContains(t, fory.typeResolver.nsTypeToTypeInfo, nameKey)
+
+ buffer := NewByteBuffer(nil)
+ readErr := &Error{}
+ typeDef.writeTypeDef(buffer, readErr)
+ require.NoError(t, readErr.CheckError())
+ header := buffer.ReadInt64(readErr)
+ require.NoError(t, readErr.CheckError())
+ decoded := readTypeDef(fory, buffer, header, readErr)
+ require.NoError(t, readErr.CheckError())
+ require.NotNil(t, decoded)
+ require.Len(t, fory.typeResolver.nsTypeToTypeInfo, maxCachedNamedTypeInfos)
+ require.NotContains(t, fory.typeResolver.nsTypeToTypeInfo, nameKey)
+}
+
+func TestTypeDefRejectsNamespaceLengthBeyondMetadata(t *testing.T) {
+ fory := NewFory()
+ meta := NewByteBuffer(nil)
+ meta.WriteByte(StructTypeDefFlag | RegisterByNameFlag)
+ meta.WriteByte(byte(BIG_NAME_THRESHOLD << 2))
+ meta.WriteVarUint32Small7(100)
+ frame, header := typeDefTestFrame(t, meta.Bytes(), false)
+
+ _, err := decodeTypeDef(fory, frame, header)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "namespace length")
+}
+
+func TestTypeDefRejectsFieldNameLengthBeyondMetadata(t *testing.T) {
+ fory := NewFory()
+ meta := NewByteBuffer(nil)
+ meta.WriteByte(StructTypeDefFlag | 1)
+ meta.WriteVarUint32(0)
+ meta.WriteByte(0x0F << 2)
+ meta.WriteVarUint32(100)
+ meta.WriteUint8(uint8(INT32))
+ frame, header := typeDefTestFrame(t, meta.Bytes(), false)
+
+ _, err := decodeTypeDef(fory, frame, header)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "field name length")
+}
+
// TestTypeDefNestedRecursionStackOverflowPanic verifies that readFieldTypeWithFlags
// rejects a crafted payload with 20 million nested LIST types, returning an error
// at depth 64 instead of recursing until a goroutine stack overflow crashes the process.
diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go
index 3f0e7ee84c..3838e67e25 100644
--- a/go/fory/type_resolver.go
+++ b/go/fory/type_resolver.go
@@ -52,9 +52,11 @@ const (
useStringId = 1
SMALL_STRING_THRESHOLD = 16
// 0xffffffff is reserved for "unset".
- maxUserTypeID uint32 = 0xfffffffe
- invalidUserTypeID uint32 = 0xffffffff
- internalTypeIDLimit = 0xFF
+ maxUserTypeID uint32 = 0xfffffffe
+ invalidUserTypeID uint32 = 0xffffffff
+ internalTypeIDLimit = 0xFF
+ maxCachedTypeDefs = 8192
+ maxCachedNamedTypeInfos = 8192
)
var (
@@ -1633,6 +1635,8 @@ func (r *TypeResolver) readSharedTypeMeta(buffer *ByteBuffer, err *Error) *TypeI
var td *TypeDef
if existingTd, exists := r.defIdToTypeDef[id]; exists {
+ // Header-cache hits intentionally skip without rehashing. Entries reach this cache only
+ // after a successful TypeDef parse and 52-bit body-hash validation.
skipTypeDef(buffer, id, err)
td = existingTd
} else {
@@ -1640,7 +1644,6 @@ func (r *TypeResolver) readSharedTypeMeta(buffer *ByteBuffer, err *Error) *TypeI
if err.HasError() {
return nil
}
- r.defIdToTypeDef[id] = newTd
td = newTd
}
@@ -1649,6 +1652,9 @@ func (r *TypeResolver) readSharedTypeMeta(buffer *ByteBuffer, err *Error) *TypeI
err.SetError(typeInfoErr)
return nil
}
+ if _, exists := r.defIdToTypeDef[id]; !exists && len(r.defIdToTypeDef) < maxCachedTypeDefs {
+ r.defIdToTypeDef[id] = td
+ }
context.readTypeInfos = append(context.readTypeInfos, typeInfo)
return typeInfo
@@ -2162,7 +2168,9 @@ func (r *TypeResolver) resolveTypeInfoByMetaBytes(nsBytes, typeBytes *MetaString
nameKey := [2]string{ns, typeName}
if typeInfo, exists := r.namedTypeToTypeInfo[nameKey]; exists {
- r.nsTypeToTypeInfo[compositeKey] = typeInfo
+ if len(r.nsTypeToTypeInfo) < maxCachedNamedTypeInfos {
+ r.nsTypeToTypeInfo[compositeKey] = typeInfo
+ }
return typeInfo
}
diff --git a/go/fory/writer.go b/go/fory/writer.go
index 4fccf37112..41ac503000 100644
--- a/go/fory/writer.go
+++ b/go/fory/writer.go
@@ -38,7 +38,7 @@ type WriteContext struct {
depth int
maxDepth int
typeResolver *TypeResolver // For complex type serialization
- refResolver *RefResolver // For reference tracking (legacy)
+ refResolver *RefResolver // For reference tracking in native-mode paths
bufferCallback func(BufferObject) bool // Callback for out-of-band buffers
outOfBand bool // Whether out-of-band serialization is enabled
err Error // Accumulated error state for deferred checking
@@ -108,7 +108,7 @@ func (c *WriteContext) TypeResolver() *TypeResolver {
return c.typeResolver
}
-// RefResolver returns the reference resolver (legacy)
+// RefResolver returns the reference resolver.
func (c *WriteContext) RefResolver() *RefResolver {
return c.refResolver
}
diff --git a/integration_tests/idl_tests/javascript/roundtrip.ts b/integration_tests/idl_tests/javascript/roundtrip.ts
index 23af999e0b..4b77e8fc41 100644
--- a/integration_tests/idl_tests/javascript/roundtrip.ts
+++ b/integration_tests/idl_tests/javascript/roundtrip.ts
@@ -123,8 +123,10 @@ function resolveRootSerializer(fory: Fory, bytes: Uint8Array): Serializer {
fory.readContext.reset(bytes);
const reader = fory.readContext.reader;
const bitmap = reader.readUint8();
- if ((bitmap & ConfigFlags.isNullFlag) === ConfigFlags.isNullFlag) {
- throw new Error("IDL roundtrip does not support null root payloads");
+ const supportedBitmap =
+ ConfigFlags.isCrossLanguageFlag | ConfigFlags.isOutOfBandFlag;
+ if ((bitmap & ~supportedBitmap) !== 0) {
+ throw new Error("unsupported root header bitmap");
}
if (
(bitmap & ConfigFlags.isCrossLanguageFlag) !==
diff --git a/java/fory-core/src/main/java/org/apache/fory/Fory.java b/java/fory-core/src/main/java/org/apache/fory/Fory.java
index adf1e8c6d9..56f6addbd0 100644
--- a/java/fory-core/src/main/java/org/apache/fory/Fory.java
+++ b/java/fory-core/src/main/java/org/apache/fory/Fory.java
@@ -68,7 +68,7 @@
/**
* Cross-language header layout: 1-byte bitmap.
*
- * Bit 0: null flag, Bit 1: xlang flag, Bit 2: out-of-band flag, Bits 3-7 reserved.
+ * Bit 0: xlang flag, Bit 1: out-of-band flag, Bits 2-7 reserved.
  *
  * serialize/deserialize are the root object APIs. Nested serialization and deserialization go
  * through {@link WriteContext} and {@link ReadContext}.
@@ -86,9 +86,9 @@ public final class Fory implements BaseFory {
// this flag indicates that the object is a referencable and first write.
public static final byte REF_VALUE_FLAG = 0;
public static final byte NOT_SUPPORT_XLANG = 0;
- private static final byte isNilFlag = 1;
- private static final byte isCrossLanguageFlag = 1 << 1;
- private static final byte isOutOfBandFlag = 1 << 2;
+ private static final byte isCrossLanguageFlag = 1;
+ private static final byte isOutOfBandFlag = 1 << 1;
+ private static final byte reservedBitmapFlags = (byte) ~0b11;
private final Config config;
private final TypeResolver typeResolver;
@@ -98,6 +98,7 @@ public final class Fory implements BaseFory {
private final WriteContext writeContext;
private final ReadContext readContext;
private final CopyContext copyContext;
+ private final byte headerBitmap;
private MemoryBuffer buffer;
public Fory(ForyBuilder builder, ClassLoader classLoader) {
@@ -119,6 +120,7 @@ public Fory(ForyBuilder builder, ClassLoader classLoader, SharedRegistry sharedR
this.sharedRegistry = sharedRegistry;
this.classLoader = classLoader;
config = new Config(builder);
+ headerBitmap = config.isXlang() ? isCrossLanguageFlag : 0;
RefWriter refWriter;
RefReader refReader;
if (config.trackingRef()) {
@@ -293,15 +295,7 @@ public MemoryBuffer serialize(MemoryBuffer buffer, Object obj, BufferCallback ca
ensureRegistrationFinished();
writeContext.prepare(buffer, callback);
try {
- byte bitmap = 0;
- if (config.isXlang()) {
- bitmap |= isCrossLanguageFlag;
- }
- if (obj == null) {
- bitmap |= isNilFlag;
- buffer.writeByte(bitmap);
- return buffer;
- }
+ byte bitmap = headerBitmap;
if (callback != null) {
bitmap |= isOutOfBandFlag;
}
@@ -379,12 +373,9 @@ public T deserialize(byte[] bytes, Class type) {
public T deserialize(MemoryBuffer buffer, Class type) {
ensureRegistrationFinished();
byte bitmap = buffer.readByte();
- if ((bitmap & isNilFlag) == isNilFlag) {
- return null;
+ if (bitmap != headerBitmap) {
+ checkHeaderBitmapWithoutOutOfBand(bitmap);
}
- boolean peerOutOfBandEnabled = (bitmap & isOutOfBandFlag) == isOutOfBandFlag;
- assert !peerOutOfBandEnabled : "Out of band buffers not passed in when deserializing";
- checkXlangBitmap(bitmap);
readContext.prepare(buffer, null, false);
try {
try {
@@ -449,11 +440,10 @@ public Object deserialize(MemoryBuffer buffer) {
public Object deserialize(MemoryBuffer buffer, Iterable outOfBandBuffers) {
ensureRegistrationFinished();
byte bitmap = buffer.readByte();
- if ((bitmap & isNilFlag) == isNilFlag) {
- return null;
+ boolean peerOutOfBandEnabled = false;
+ if (bitmap != headerBitmap) {
+ peerOutOfBandEnabled = checkHeaderBitmap(bitmap);
}
- checkXlangBitmap(bitmap);
- boolean peerOutOfBandEnabled = (bitmap & isOutOfBandFlag) == isOutOfBandFlag;
if (peerOutOfBandEnabled) {
Preconditions.checkNotNull(
outOfBandBuffers,
@@ -530,13 +520,24 @@ private T deserializeByType(MemoryBuffer buffer, Class type) {
}
}
- private void checkXlangBitmap(byte bitmap) {
+ private void checkHeaderBitmapWithoutOutOfBand(byte bitmap) {
+ if (checkHeaderBitmap(bitmap)) {
+ throw new IllegalArgumentException("Out of band buffers not passed in when deserializing");
+ }
+ }
+
+ private boolean checkHeaderBitmap(byte bitmap) {
+ Preconditions.checkArgument(
+ (bitmap & reservedBitmapFlags) == 0,
+ "Serialized payload uses reserved header bitmap flags 0x%s",
+ Integer.toHexString(Byte.toUnsignedInt((byte) (bitmap & reservedBitmapFlags))));
boolean payloadCrossLanguage = (bitmap & isCrossLanguageFlag) == isCrossLanguageFlag;
Preconditions.checkArgument(
payloadCrossLanguage == config.isXlang(),
"Serialized payload xlang flag %s does not match this Fory mode %s",
payloadCrossLanguage,
config.isXlang());
+ return (bitmap & isOutOfBandFlag) == isOutOfBandFlag;
}
@Override
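The replaced `checkXlangBitmap`/`checkHeaderBitmap` flow above can be modeled standalone. This is an illustrative sketch, not Fory's API: the flag values follow the spec hunk in this patch (xlang = bit 0, oob = bit 1, bits 2-7 reserved), and the class and method names are ours.

```java
// Hypothetical model of the header-bitmap validation in this patch.
public class HeaderBitmapCheck {
  static final byte XLANG_FLAG = 0b01;                  // bit 0 per the spec hunk
  static final byte OOB_FLAG = 0b10;                    // bit 1 per the spec hunk
  static final byte RESERVED_FLAGS = (byte) 0b1111_1100; // bits 2-7 must be zero

  /** Returns whether the peer enabled out-of-band buffers; throws on invalid headers. */
  static boolean checkHeaderBitmap(byte bitmap, boolean expectXlang) {
    if ((bitmap & RESERVED_FLAGS) != 0) {
      throw new IllegalArgumentException(
          "Reserved header bitmap flags set: 0x"
              + Integer.toHexString(Byte.toUnsignedInt(bitmap)));
    }
    boolean payloadXlang = (bitmap & XLANG_FLAG) != 0;
    if (payloadXlang != expectXlang) {
      throw new IllegalArgumentException(
          "Serialized payload xlang flag " + payloadXlang
              + " does not match this mode " + expectXlang);
    }
    return (bitmap & OOB_FLAG) != 0;
  }
}
```

Because the expected bitmap is precomputed per-instance, the fast path in the patch is a single equality check; the full validation above only runs on a mismatch.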
diff --git a/java/fory-core/src/main/java/org/apache/fory/collection/LongLongByteMap.java b/java/fory-core/src/main/java/org/apache/fory/collection/LongLongByteMap.java
index 3df143179b..26bc53190b 100644
--- a/java/fory-core/src/main/java/org/apache/fory/collection/LongLongByteMap.java
+++ b/java/fory-core/src/main/java/org/apache/fory/collection/LongLongByteMap.java
@@ -19,6 +19,7 @@
package org.apache.fory.collection;
+import java.util.Arrays;
import org.apache.fory.annotation.Internal;
import org.apache.fory.util.Preconditions;
@@ -129,6 +130,25 @@ public V get(long k1, long k2, byte k3) {
}
}
+ public void clear() {
+ if (size == 0) {
+ return;
+ }
+ size = 0;
+ Arrays.fill(keyTable, null);
+ ObjectArray.clearObjectArray(valueTable, 0, valueTable.length);
+ }
+
+ public void clear(int maximumCapacity) {
+ int tableSize = ForyObjectMap.tableSize(maximumCapacity, loadFactor);
+ if (keyTable.length <= tableSize) {
+ clear();
+ return;
+ }
+ size = 0;
+ resize(tableSize);
+ }
+
private void resize(int newSize) {
int oldCapacity = keyTable.length;
threshold = (int) (newSize * loadFactor);
diff --git a/java/fory-core/src/main/java/org/apache/fory/config/Config.java b/java/fory-core/src/main/java/org/apache/fory/config/Config.java
index c820986e42..bd7c3c710a 100644
--- a/java/fory-core/src/main/java/org/apache/fory/config/Config.java
+++ b/java/fory-core/src/main/java/org/apache/fory/config/Config.java
@@ -65,6 +65,8 @@ public class Config implements Serializable {
private final boolean serializeEnumByName;
private final int bufferSizeLimitBytes;
private final int maxDepth;
+ private final int maxBinarySize;
+ private final int maxCollectionSize;
private final float mapRefLoadFactor;
private final boolean forVirtualThread;
@@ -107,6 +109,8 @@ public Config(ForyBuilder builder) {
serializeEnumByName = builder.serializeEnumByName;
bufferSizeLimitBytes = builder.bufferSizeLimitBytes;
maxDepth = builder.maxDepth;
+ maxBinarySize = builder.maxBinarySize;
+ maxCollectionSize = builder.maxCollectionSize;
mapRefLoadFactor = builder.mapRefLoadFactor;
forVirtualThread = builder.forVirtualThread;
}
@@ -298,6 +302,16 @@ public int maxDepth() {
return maxDepth;
}
+ /** Returns the max byte length accepted for attacker-controlled binary and primitive-array payloads. */
+ public int maxBinarySize() {
+ return maxBinarySize;
+ }
+
+ /** Returns the max allocation size accepted for attacker-controlled collection lengths. */
+ public int maxCollectionSize() {
+ return maxCollectionSize;
+ }
+
/** Returns loadFactor of MapRefResolver's writtenObjects. */
public float mapRefLoadFactor() {
return mapRefLoadFactor;
@@ -332,6 +346,8 @@ public boolean equals(Object o) {
&& compressIntArray == config.compressIntArray
&& compressLongArray == config.compressLongArray
&& bufferSizeLimitBytes == config.bufferSizeLimitBytes
+ && maxBinarySize == config.maxBinarySize
+ && maxCollectionSize == config.maxCollectionSize
&& requireClassRegistration == config.requireClassRegistration
&& suppressClassRegistrationWarnings == config.suppressClassRegistrationWarnings
&& registerGuavaTypes == config.registerGuavaTypes
@@ -371,6 +387,8 @@ public int hashCode() {
compressIntArray,
compressLongArray,
bufferSizeLimitBytes,
+ maxBinarySize,
+ maxCollectionSize,
requireClassRegistration,
suppressClassRegistrationWarnings,
registerGuavaTypes,
diff --git a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java
index 5c3d6a45dd..52bd6b08ba 100644
--- a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java
+++ b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java
@@ -96,6 +96,8 @@ public final class ForyBuilder {
Integer bufferSizeLimitBytes = -1;
MetaCompressor metaCompressor = new DeflaterMetaCompressor();
int maxDepth = 50;
+ int maxBinarySize = 64 * 1024 * 1024;
+ int maxCollectionSize = 1_000_000;
float mapRefLoadFactor = 0.51f;
boolean forVirtualThread = false;
TypeChecker typeChecker;
@@ -473,6 +475,30 @@ public ForyBuilder withMaxDepth(int maxDepth) {
return this;
}
+ /**
+ * Set max binary payload size for deserialization. Binary and primitive-array byte lengths above
+ * this limit are rejected before allocation. Default max binary size is 64 MiB.
+ */
+ public ForyBuilder withMaxBinarySize(int maxBinarySize) {
+ Preconditions.checkArgument(
+ maxBinarySize >= 0, "maxBinarySize must be >= 0 but got %s", maxBinarySize);
+ this.maxBinarySize = maxBinarySize;
+ recordAction(b -> b.withMaxBinarySize(maxBinarySize));
+ return this;
+ }
+
+ /**
+ * Set max collection size for deserialization. Collection lengths and collection capacity fields
+ * above this limit are rejected before allocation. Default max collection size is 1,000,000.
+ */
+ public ForyBuilder withMaxCollectionSize(int maxCollectionSize) {
+ Preconditions.checkArgument(
+ maxCollectionSize >= 0, "maxCollectionSize must be >= 0 but got %s", maxCollectionSize);
+ this.maxCollectionSize = maxCollectionSize;
+ recordAction(b -> b.withMaxCollectionSize(maxCollectionSize));
+ return this;
+ }
+
/** Set loadFactor of MapRefResolver writtenObjects. Default value is 0.51 */
public ForyBuilder withMapRefLoadFactor(float loadFactor) {
Preconditions.checkArgument(
diff --git a/java/fory-core/src/main/java/org/apache/fory/context/MetaStringReader.java b/java/fory-core/src/main/java/org/apache/fory/context/MetaStringReader.java
index 23b8d444ee..7b492ddcbb 100644
--- a/java/fory-core/src/main/java/org/apache/fory/context/MetaStringReader.java
+++ b/java/fory-core/src/main/java/org/apache/fory/context/MetaStringReader.java
@@ -23,11 +23,12 @@
import org.apache.fory.annotation.Internal;
import org.apache.fory.collection.LongLongByteMap;
import org.apache.fory.collection.LongMap;
+import org.apache.fory.exception.ForyException;
import org.apache.fory.memory.LittleEndian;
import org.apache.fory.memory.MemoryBuffer;
import org.apache.fory.meta.EncodedMetaString;
+import org.apache.fory.meta.MetaString;
import org.apache.fory.resolver.SharedRegistry;
-import org.apache.fory.util.MurmurHash3;
/**
* Read-side state for meta-string references.
@@ -40,6 +41,9 @@ public final class MetaStringReader {
private static final int INITIAL_CAPACITY = 2;
private static final float LOAD_FACTOR = 0.5f;
private static final int SMALL_STRING_THRESHOLD = 16;
+ private static final int ENCODING_BITS = 4;
+ private static final int MAX_CACHED_READ_META_STRINGS = 8192;
+ private static final int MAX_CACHED_READ_META_STRING_LENGTH = 2048;
private final LongMap<EncodedMetaString> hash2MetaStringMap =
new LongMap<>(INITIAL_CAPACITY, LOAD_FACTOR);
@@ -47,7 +51,7 @@ public final class MetaStringReader {
new LongLongByteMap<>(INITIAL_CAPACITY, LOAD_FACTOR);
private final SharedRegistry sharedRegistry;
private EncodedMetaString[] dynamicReadStringIds = new EncodedMetaString[INITIAL_CAPACITY];
- private short dynamicReadStringId;
+ private int dynamicReadStringId;
/** Creates an empty reader state for one deserialization stream. */
public MetaStringReader(SharedRegistry sharedRegistry) {
@@ -68,7 +72,7 @@ public EncodedMetaString readMetaStringWithFlag(MemoryBuffer buffer, int header)
updateDynamicString(encodedMetaString);
return encodedMetaString;
}
- return dynamicReadStringIds[len - 1];
+ return readDynamicString(len);
}
/**
@@ -88,7 +92,7 @@ public EncodedMetaString readMetaStringWithFlag(
updateDynamicString(encodedMetaString);
return encodedMetaString;
}
- return dynamicReadStringIds[len - 1];
+ return readDynamicString(len);
}
/** Reads a meta string from the current buffer, including any dynamic-id indirection. */
@@ -103,7 +107,7 @@ public EncodedMetaString readMetaString(MemoryBuffer buffer) {
updateDynamicString(encodedMetaString);
return encodedMetaString;
}
- return dynamicReadStringIds[len - 1];
+ return readDynamicString(len);
}
/**
@@ -121,31 +125,75 @@ public EncodedMetaString readMetaString(MemoryBuffer buffer, EncodedMetaString c
updateDynamicString(encodedMetaString);
return encodedMetaString;
}
- return dynamicReadStringIds[len - 1];
+ return readDynamicString(len);
}
private EncodedMetaString readBigMetaString(
MemoryBuffer buffer, EncodedMetaString cache, int len) {
long hashCode = buffer.readInt64();
- if (cache.hash == hashCode) {
- buffer.increaseReaderIndex(len);
+ if (cache.hash == hashCode && cache.bytes.length == len) {
+ // Big meta-string hashes are the wire identity on this cache hit. The body hash is computed
+ // and checked before a new entry is published; later hits intentionally skip the body.
+ buffer.checkReadableBytes(len);
+ buffer._increaseReaderIndexUnsafe(len);
return cache;
}
return readBigMetaString(buffer, len, hashCode);
}
private EncodedMetaString readBigMetaString(MemoryBuffer buffer, int len, long hashCode) {
+ buffer.checkReadableBytes(len);
EncodedMetaString encodedMetaString = hash2MetaStringMap.get(hashCode);
- if (encodedMetaString == null) {
- encodedMetaString =
- sharedRegistry.getOrCreateEncodedMetaString(buffer.readBytes(len), hashCode);
- hash2MetaStringMap.put(hashCode, encodedMetaString);
+ if (encodedMetaString != null && encodedMetaString.bytes.length == len) {
+ // Preserve the header-keyed fast path: entries reach this map only after their bytes matched
+ // the wire hash, so repeat hits advance over the redundant body without rehashing.
+ buffer._increaseReaderIndexUnsafe(len);
return encodedMetaString;
}
- buffer.increaseReaderIndex(len);
+ byte[] bytes = readAndValidateBigMetaString(buffer, len, hashCode);
+ EncodedMetaString canonicalMetaString =
+ sharedRegistry.getOrCreateEncodedMetaString(bytes, hashCode);
+ if (encodedMetaString == null
+ && len <= MAX_CACHED_READ_META_STRING_LENGTH
+ && hash2MetaStringMap.size < MAX_CACHED_READ_META_STRINGS) {
+ hash2MetaStringMap.put(hashCode, canonicalMetaString);
+ }
+ return canonicalMetaString;
+ }
+
+ private byte[] readAndValidateBigMetaString(MemoryBuffer buffer, int len, long hashCode) {
+ byte[] bytes = buffer.readBytes(len);
+ MetaString.Encoding encoding = MetaString.Encoding.fromInt((int) (hashCode & 0xff));
+ long canonicalHash = EncodedMetaString.computeHash(bytes, encoding);
+ if (canonicalHash != hashCode) {
+ throw new ForyException("Malformed meta string hash");
+ }
+ return bytes;
+ }
+
+ private boolean shouldCacheSmallMetaString() {
+ return longLongMetaStringMap.size < MAX_CACHED_READ_META_STRINGS;
+ }
+
+ private EncodedMetaString cacheSmallMetaString(
+ long v1, long v2, byte key, EncodedMetaString encodedMetaString) {
+ if (shouldCacheSmallMetaString()) {
+ longLongMetaStringMap.put(v1, v2, key, encodedMetaString);
+ }
return encodedMetaString;
}
+ private EncodedMetaString createSmallMetaString(
+ int len, MetaString.Encoding encoding, byte key, long v1, long v2) {
+ byte[] data = new byte[16];
+ LittleEndian.putInt64(data, 0, v1);
+ LittleEndian.putInt64(data, 8, v2);
+ byte[] bytes = Arrays.copyOf(data, len);
+ long hashCode = EncodedMetaString.computeHash(bytes, encoding);
+ return cacheSmallMetaString(
+ v1, v2, key, sharedRegistry.getOrCreateEncodedMetaString(bytes, hashCode));
+ }
+
private EncodedMetaString readSmallMetaString(MemoryBuffer buffer, int len) {
if (len == 0) {
return EncodedMetaString.EMPTY;
@@ -159,9 +207,11 @@ private EncodedMetaString readSmallMetaString(MemoryBuffer buffer, int len) {
v1 = buffer.readInt64();
v2 = buffer.readBytesAsInt64(len - 8);
}
- EncodedMetaString encodedMetaString = longLongMetaStringMap.get(v1, v2, encoding);
+ int encodingValue = encoding & 0xff;
+ byte key = smallMetaStringKey(len, encodingValue);
+ EncodedMetaString encodedMetaString = longLongMetaStringMap.get(v1, v2, key);
if (encodedMetaString == null) {
- return createSmallMetaString(len, encoding, v1, v2);
+ return createSmallMetaString(len, MetaString.Encoding.fromInt(encodingValue), key, v1, v2);
}
return encodedMetaString;
}
@@ -180,39 +230,45 @@ private EncodedMetaString readSmallMetaString(
v1 = buffer.readInt64();
v2 = buffer.readBytesAsInt64(len - 8);
}
- if (cache.first8Bytes == v1 && cache.second8Bytes == v2) {
+ int encodingValue = encoding & 0xff;
+ if (cache.bytes.length == len
+ && cache.encodingValue == encodingValue
+ && cache.first8Bytes == v1
+ && cache.second8Bytes == v2) {
return cache;
}
- EncodedMetaString encodedMetaString = longLongMetaStringMap.get(v1, v2, encoding);
+ byte key = smallMetaStringKey(len, encodingValue);
+ EncodedMetaString encodedMetaString = longLongMetaStringMap.get(v1, v2, key);
if (encodedMetaString == null) {
- return createSmallMetaString(len, encoding, v1, v2);
+ return createSmallMetaString(len, MetaString.Encoding.fromInt(encodingValue), key, v1, v2);
}
return encodedMetaString;
}
- private EncodedMetaString createSmallMetaString(int len, byte encoding, long v1, long v2) {
- byte[] data = new byte[16];
- LittleEndian.putInt64(data, 0, v1);
- LittleEndian.putInt64(data, 8, v2);
- long hashCode = MurmurHash3.murmurhash3_x64_128(data, 0, len, 47)[0];
- hashCode = Math.abs(hashCode);
- hashCode = (hashCode & 0xffffffffffffff00L) | encoding;
- EncodedMetaString encodedMetaString =
- sharedRegistry.getOrCreateEncodedMetaString(Arrays.copyOf(data, len), hashCode);
- longLongMetaStringMap.put(v1, v2, encoding, encodedMetaString);
- return encodedMetaString;
+ private static byte smallMetaStringKey(int len, int encodingValue) {
+ return (byte) (((len - 1) << ENCODING_BITS) | encodingValue);
+ }
+
+ private EncodedMetaString readDynamicString(int dynamicId) {
+ if (dynamicId <= 0 || dynamicId > dynamicReadStringId) {
+ throw new ForyException("Invalid meta string reference id " + dynamicId);
+ }
+ return dynamicReadStringIds[dynamicId - 1];
}
private void updateDynamicString(EncodedMetaString encodedMetaString) {
- short currentDynamicReadId = dynamicReadStringId++;
+ int currentDynamicReadId = dynamicReadStringId++;
EncodedMetaString[] readStringIds = dynamicReadStringIds;
if (readStringIds.length <= currentDynamicReadId) {
+ if (currentDynamicReadId >= MAX_CACHED_READ_META_STRINGS) {
+ throw new ForyException("Too many meta string references in payload");
+ }
readStringIds = dynamicReadStringIds = growRead(readStringIds, currentDynamicReadId);
}
readStringIds[currentDynamicReadId] = encodedMetaString;
}
- private EncodedMetaString[] growRead(EncodedMetaString[] current, int id) {
+ private static EncodedMetaString[] growRead(EncodedMetaString[] current, int id) {
int newLength = current.length;
while (newLength <= id) {
newLength <<= 1;
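The hardened dynamic-reference handling above (validated back-reference ids plus a cap on table growth) can be sketched as a small standalone structure. The names here are ours, not Fory's; the cap mirrors `MAX_CACHED_READ_META_STRINGS` from this patch.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative sketch of a bounded, validated back-reference table:
// a hostile stream can neither reference an unpublished id nor force
// unbounded allocation by emitting endless new entries.
public class DynamicRefTable<T> {
  private static final int MAX_ENTRIES = 8192; // mirrors MAX_CACHED_READ_META_STRINGS

  private final List<T> entries = new ArrayList<>();

  /** Publishes a new entry and returns its 1-based wire id. */
  int publish(T value) {
    if (entries.size() >= MAX_ENTRIES) {
      throw new IllegalStateException("Too many references in payload");
    }
    entries.add(value);
    return entries.size();
  }

  /** Resolves a back-reference, rejecting ids that were never published. */
  T resolve(int id) {
    if (id <= 0 || id > entries.size()) {
      throw new IllegalArgumentException("Invalid reference id " + id);
    }
    return entries.get(id - 1);
  }
}
```

Before this patch, an out-of-range id would either index past the live portion of the array or throw an opaque `ArrayIndexOutOfBoundsException`; the explicit range check turns that into a deterministic deserialization error.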
diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java
index 5def56c660..0b03800dd6 100644
--- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java
+++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java
@@ -467,11 +467,13 @@ public MemoryBuffer readBufferObject() {
} else {
size = buffer.readVarUInt32();
}
- if (buffer.readerIndex() + size > buffer.size() && buffer.getStreamReader() != null) {
- buffer.getStreamReader().fillBuffer(buffer.readerIndex() + size - buffer.size());
+ if (size < 0) {
+ throw new IllegalArgumentException("Buffer object size must be non-negative: " + size);
}
- MemoryBuffer slice = buffer.slice(buffer.readerIndex(), size);
- buffer.readerIndex(buffer.readerIndex() + size);
+ buffer.checkReadableBytes(size);
+ int readerIndex = buffer.readerIndex();
+ MemoryBuffer slice = buffer.slice(readerIndex, size);
+ buffer.readerIndex(readerIndex + size);
return slice;
}
Preconditions.checkArgument(outOfBandBuffers.hasNext());
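The `readBufferObject` fix above validates the length before computing `readerIndex + size`, so an attacker-supplied size can neither be negative nor overflow past the buffer. A minimal sketch of the same pattern on a plain `ByteBuffer` (not Fory's `MemoryBuffer`; names are ours):

```java
import java.nio.ByteBuffer;

// Illustrative bounds-checked slice: reject the length up front instead of
// trusting readerIndex + size arithmetic on attacker-controlled input.
public class SafeSlice {
  static ByteBuffer slice(ByteBuffer buffer, int size) {
    if (size < 0) {
      throw new IllegalArgumentException("Buffer object size must be non-negative: " + size);
    }
    if (size > buffer.remaining()) {
      throw new IllegalArgumentException(
          "Buffer object size " + size + " exceeds readable bytes " + buffer.remaining());
    }
    ByteBuffer slice = buffer.slice(); // view over the readable region
    slice.limit(size);                 // restrict the view to the object body
    buffer.position(buffer.position() + size); // advance past the consumed body
    return slice;
  }
}
```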
diff --git a/java/fory-core/src/main/java/org/apache/fory/io/ForyReadableChannel.java b/java/fory-core/src/main/java/org/apache/fory/io/ForyReadableChannel.java
index 1afbcbf7b3..ff8d9f85ac 100644
--- a/java/fory-core/src/main/java/org/apache/fory/io/ForyReadableChannel.java
+++ b/java/fory-core/src/main/java/org/apache/fory/io/ForyReadableChannel.java
@@ -66,9 +66,9 @@ public int fillBuffer(int minFillSize) {
memoryBuf.initDirectBuffer(ByteBufferUtil.getAddress(byteBuf), position, byteBuf);
}
byteBuf.limit(newLimit);
- int readCount = channel.read(byteBuf);
- memoryBuf.increaseSize(readCount);
- return readCount;
+ readFully(byteBuf, minFillSize);
+ memoryBuf.increaseSize(minFillSize);
+ return minFillSize;
} catch (IOException e) {
throw new DeserializationException("Failed to read the provided byte channel", e);
}
@@ -98,7 +98,7 @@ public void readTo(byte[] dst, int dstIndex, int length) {
buf.readBytes(dst, dstIndex, remaining);
try {
ByteBuffer buffer = ByteBuffer.wrap(dst, dstIndex + remaining, length - remaining);
- channel.read(buffer);
+ readFully(buffer, length - remaining);
} catch (IOException e) {
throw new DeserializationException("Failed to read the provided byte channel", e);
}
@@ -130,10 +130,13 @@ public void readToByteBuffer(ByteBuffer dst, int length) {
int newLimit = dst.position() + length - remaining;
if (dstLimit > newLimit) {
dst.limit(newLimit);
- channel.read(dst);
- dst.limit(dstLimit);
+ try {
+ readFully(dst, length - remaining);
+ } finally {
+ dst.limit(dstLimit);
+ }
} else {
- channel.read(dst);
+ readFully(dst, length - remaining);
}
} catch (IOException e) {
throw new DeserializationException("Failed to read the provided byte channel", e);
@@ -169,4 +172,15 @@ public void close() throws IOException {
public MemoryBuffer getBuffer() {
return memoryBuffer;
}
+
+ private void readFully(ByteBuffer dst, int length) throws IOException {
+ int remaining = length;
+ while (remaining > 0) {
+ int read = channel.read(dst);
+ if (read <= 0) {
+ throw new DeserializationException("Unexpected end of byte channel");
+ }
+ remaining -= read;
+ }
+ }
}
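The `readFully` change above exists because `ReadableByteChannel.read` may return fewer bytes than requested (and -1 at end of stream), so a single read cannot be trusted to fill the buffer. A self-contained demonstration with a channel that deliberately trickles bytes; the class names are illustrative, not Fory's:

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;

public class ReadFullyDemo {
  // A channel that returns at most 2 bytes per read() call, as real
  // sockets and pipes are allowed to do.
  static final class TrickleChannel implements ReadableByteChannel {
    private final byte[] data;
    private int pos;

    TrickleChannel(byte[] data) { this.data = data; }

    public int read(ByteBuffer dst) {
      if (pos >= data.length) {
        return -1; // end of stream
      }
      int n = Math.min(2, Math.min(dst.remaining(), data.length - pos));
      dst.put(data, pos, n);
      pos += n;
      return n;
    }

    public boolean isOpen() { return true; }

    public void close() {}
  }

  /** Loops until exactly length bytes arrive, mirroring the fix in this patch. */
  static void readFully(ReadableByteChannel channel, ByteBuffer dst, int length) {
    int remaining = length;
    try {
      while (remaining > 0) {
        int read = channel.read(dst);
        if (read <= 0) {
          throw new IOException("Unexpected end of byte channel");
        }
        remaining -= read;
      }
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }
}
```

The pre-patch code called `channel.read` once and assumed the result covered the request; against a slow or fragmented channel that silently truncated the deserialized payload.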
diff --git a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java
index 48d37e97d8..89a0c96f8f 100644
--- a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java
+++ b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java
@@ -827,15 +827,15 @@ public int writeVarUInt32Small7(int value) {
}
private int continueWriteVarUInt32Small7(int value) {
- long encoded = (value & 0x7F);
+ int encoded = (value & 0x7F);
encoded |= (((value & 0x3f80) << 1) | 0x80);
int writerIdx = writerIndex;
if (value >>> 14 == 0) {
- _unsafePutInt32(writerIdx, (int) encoded);
+ _unsafePutInt32(writerIdx, encoded);
writerIndex += 2;
return 2;
}
- int diff = continuePutVarInt36(writerIdx, encoded, value);
+ int diff = continuePutVarUInt32(writerIdx, encoded, value);
writerIndex += diff;
return diff;
}
@@ -1811,7 +1811,7 @@ public int _readVarInt32OnLE() {
int readIdx = readerIndex;
int result;
if (size - readIdx < 5) {
- result = (int) readVarUint36Slow();
+ result = readVarUInt32Slow();
} else {
long address = this.address;
// | 1bit + 7bits | 1bit + 7bits | 1bit + 7bits | 1bit + 7bits |
@@ -1835,7 +1835,11 @@ public int _readVarInt32OnLE() {
// 0xfe00000: 0b1111111 << 21
result |= (fourByteValue >>> 3) & 0xfe00000;
if ((fourByteValue & 0x80000000) != 0) {
- result |= (UNSAFE.getByte(heapMemory, address + readIdx++) & 0x7F) << 28;
+ int fifthByte = UNSAFE.getByte(heapMemory, address + readIdx++) & 0xFF;
+ if ((fifthByte & 0xF0) != 0) {
+ throwMalformedVarUInt32(fifthByte);
+ }
+ result |= fifthByte << 28;
}
}
}
@@ -1854,7 +1858,7 @@ public int _readVarInt32OnBE() {
int readIdx = readerIndex;
int result;
if (size - readIdx < 5) {
- result = (int) readVarUint36Slow();
+ result = readVarUInt32Slow();
} else {
long address = this.address;
int fourByteValue = Integer.reverseBytes(UNSAFE.getInt(heapMemory, address + readIdx));
@@ -1877,7 +1881,11 @@ public int _readVarInt32OnBE() {
// 0xfe00000: 0b1111111 << 21
result |= (fourByteValue >>> 3) & 0xfe00000;
if ((fourByteValue & 0x80000000) != 0) {
- result |= (UNSAFE.getByte(heapMemory, address + readIdx++) & 0x7F) << 28;
+ int fifthByte = UNSAFE.getByte(heapMemory, address + readIdx++) & 0xFF;
+ if ((fifthByte & 0xF0) != 0) {
+ throwMalformedVarUInt32(fifthByte);
+ }
+ result |= fifthByte << 28;
}
}
}
@@ -1956,11 +1964,45 @@ private long readVarUint36Slow() {
return result;
}
+ private int readVarUInt32Slow() {
+ int b = readByte() & 0xFF;
+ int result = b & 0x7F;
+ // Note:
+ // We avoid a loop here and manually unroll the byte
+ // reads, which performs better for short varints.
+ // noinspection Duplicates
+ if ((b & 0x80) != 0) {
+ b = readByte() & 0xFF;
+ result |= (b & 0x7F) << 7;
+ if ((b & 0x80) != 0) {
+ b = readByte() & 0xFF;
+ result |= (b & 0x7F) << 14;
+ if ((b & 0x80) != 0) {
+ b = readByte() & 0xFF;
+ result |= (b & 0x7F) << 21;
+ if ((b & 0x80) != 0) {
+ b = readByte() & 0xFF;
+ if ((b & 0xF0) != 0) {
+ throwMalformedVarUInt32(b);
+ }
+ result |= b << 28;
+ }
+ }
+ }
+ }
+ return result;
+ }
+
+ private static void throwMalformedVarUInt32(int fifthByte) {
+ throw new IllegalArgumentException(
+ "Malformed varuint32 fifth byte " + fifthByte + " exceeds 32 bits");
+ }
+
/** Reads the 1-5 byte int part of a non-negative varint. */
public int readVarUInt32() {
int readIdx = readerIndex;
if (size - readIdx < 5) {
- return (int) readVarUint36Slow();
+ return readVarUInt32Slow();
}
// | 1bit + 7bits | 1bit + 7bits | 1bit + 7bits | 1bit + 7bits |
int fourByteValue = _unsafeGetInt32(readIdx);
@@ -1983,7 +2025,11 @@ public int readVarUInt32() {
// 0xfe00000: 0b1111111 << 21
result |= (fourByteValue >>> 3) & 0xfe00000;
if ((fourByteValue & 0x80000000) != 0) {
- result |= (UNSAFE.getByte(heapMemory, address + readIdx++) & 0x7F) << 28;
+ int fifthByte = UNSAFE.getByte(heapMemory, address + readIdx++) & 0xFF;
+ if ((fifthByte & 0xF0) != 0) {
+ throwMalformedVarUInt32(fifthByte);
+ }
+ result |= fifthByte << 28;
}
}
}
@@ -2031,7 +2077,7 @@ public int readVarUInt32Small14() {
readerIndex = readIdx;
return value;
} else {
- return (int) readVarUint36Slow();
+ return readVarUInt32Slow();
}
}
@@ -2044,7 +2090,11 @@ private int continueReadVarUInt32(int readIdx, int bulkRead, int value) {
readIdx++;
value |= (bulkRead >>> 3) & 0xfe00000;
if ((bulkRead & 0x80000000) != 0) {
- value |= (UNSAFE.getByte(heapMemory, address + readIdx++) & 0x7F) << 28;
+ int fifthByte = UNSAFE.getByte(heapMemory, address + readIdx++) & 0xFF;
+ if ((fifthByte & 0xF0) != 0) {
+ throwMalformedVarUInt32(fifthByte);
+ }
+ value |= fifthByte << 28;
}
}
readerIndex = readIdx;
@@ -2440,7 +2490,7 @@ public int readBinarySize() {
}
readerIndex = readIdx;
} else {
- binarySize = (int) readVarUint36Slow();
+ binarySize = readVarUInt32Slow();
readIdx = readerIndex;
}
int diff = size - readIdx;
@@ -2459,7 +2509,11 @@ private int continueReadBinarySize(int readIdx, int bulkRead, int binarySize) {
readIdx++;
binarySize |= (bulkRead >>> 3) & 0xfe00000;
if ((bulkRead & 0x80000000) != 0) {
- binarySize |= (UNSAFE.getByte(heapMemory, address + readIdx++) & 0x7F) << 28;
+ int fifthByte = UNSAFE.getByte(heapMemory, address + readIdx++) & 0xFF;
+ if ((fifthByte & 0xF0) != 0) {
+ throwMalformedVarUInt32(fifthByte);
+ }
+ binarySize |= fifthByte << 28;
}
}
int diff = size - readIdx;
@@ -2716,6 +2770,28 @@ public boolean equalTo(MemoryBuffer buf2, int offset1, int offset2, int len) {
return Platform.arrayEquals(heapMemory, pos1, buf2.heapMemory, pos2, len);
}
+ /**
+ * Compares a region of this buffer with a region of a byte array.
+ *
+ * @param bytes Array to compare with
+ * @param bytesOffset Offset of bytes to start comparing
+ * @param offset Offset of this buffer to start comparing
+ * @param len Length of the compared memory region
+ * @return true if the regions are equal or len is zero, false otherwise
+ */
+ public boolean equalTo(byte[] bytes, int bytesOffset, int offset, int len) {
+ checkArgument(bytes != null);
+ checkArgument(len >= 0);
+ checkArgument(bytesOffset >= 0 && bytesOffset <= bytes.length - len);
+ checkArgument(offset >= 0 && offset <= size - len);
+ if (len == 0) {
+ return true;
+ }
+ final long pos = address + offset;
+ return Platform.arrayEquals(
+ heapMemory, pos, bytes, Platform.BYTE_ARRAY_OFFSET + bytesOffset, len);
+ }
+
@Override
public String toString() {
return "MemoryBuffer{"
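The repeated fifth-byte check added throughout `MemoryBuffer` enforces one rule: a 5-byte varuint32 may only contribute the 4 low bits of its final byte, so any of the high bits (`0xF0`) means the encoded value does not fit in 32 bits. A compact loop-based sketch of the same decoder rule (the production code unrolls it; names are ours):

```java
// Illustrative hardened varuint32 decoder: rejects 5-byte encodings whose
// final byte carries more than 4 significant bits (value wider than 32 bits).
public class VarUInt32 {
  static int read(byte[] in) {
    int result = 0;
    int shift = 0;
    int i = 0;
    while (true) {
      int b = in[i++] & 0xFF;
      if (shift == 28) {
        // Fifth byte: only 4 bits of room remain in a 32-bit value.
        if ((b & 0xF0) != 0) {
          throw new IllegalArgumentException(
              "Malformed varuint32 fifth byte " + b + " exceeds 32 bits");
        }
        return result | (b << 28);
      }
      result |= (b & 0x7F) << shift;
      if ((b & 0x80) == 0) {
        return result;
      }
      shift += 7;
    }
  }
}
```

The pre-patch code masked the fifth byte with `0x7F` and shifted it, silently discarding bits 32-34 of an over-long encoding; the check turns that ambiguity into a hard error.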
diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/DeflaterMetaCompressor.java b/java/fory-core/src/main/java/org/apache/fory/meta/DeflaterMetaCompressor.java
index 7611741124..cf8cee2682 100644
--- a/java/fory-core/src/main/java/org/apache/fory/meta/DeflaterMetaCompressor.java
+++ b/java/fory-core/src/main/java/org/apache/fory/meta/DeflaterMetaCompressor.java
@@ -43,6 +43,11 @@ public byte[] compress(byte[] input, int offset, int size) {
@Override
public byte[] decompress(byte[] input, int offset, int size) {
+ return decompress(input, offset, size, Integer.MAX_VALUE);
+ }
+
+ @Override
+ public byte[] decompress(byte[] input, int offset, int size, int maxOutputSize) {
Inflater inflater = new Inflater();
inflater.setInput(input, offset, size);
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
@@ -51,6 +56,10 @@ public byte[] decompress(byte[] input, int offset, int size) {
while (!inflater.finished()) {
int decompressedSize = inflater.inflate(buffer);
if (decompressedSize > 0) {
+ if (outputStream.size() > maxOutputSize - decompressedSize) {
+ throw new InvalidDataException(
+ "Decompressed TypeDef metadata exceeds the maximum size.");
+ }
outputStream.write(buffer, 0, decompressedSize);
continue;
}
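The `maxOutputSize` guard added to `decompress` above bounds inflation before each chunk is buffered, which is the standard defense against zip bombs. A self-contained sketch of the same pattern using only `java.util.zip` (the class and exception choices here are ours, not Fory's):

```java
import java.io.ByteArrayOutputStream;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;

public class BoundedInflate {
  /** Helper so the sketch is self-contained: deflate-compresses data. */
  static byte[] deflate(byte[] data) {
    Deflater deflater = new Deflater();
    deflater.setInput(data);
    deflater.finish();
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    byte[] buf = new byte[512];
    while (!deflater.finished()) {
      out.write(buf, 0, deflater.deflate(buf));
    }
    deflater.end();
    return out.toByteArray();
  }

  /** Inflates input, rejecting output that would exceed maxOutputSize. */
  static byte[] decompress(byte[] input, int maxOutputSize) {
    Inflater inflater = new Inflater();
    inflater.setInput(input);
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    byte[] buf = new byte[512];
    try {
      while (!inflater.finished()) {
        int n = inflater.inflate(buf);
        if (n > 0) {
          // Overflow-safe bound check BEFORE buffering the chunk.
          if (out.size() > maxOutputSize - n) {
            throw new IllegalStateException("Decompressed data exceeds the maximum size");
          }
          out.write(buf, 0, n);
        } else if (inflater.needsInput()) {
          throw new IllegalStateException("Truncated compressed input");
        }
      }
    } catch (DataFormatException e) {
      throw new IllegalStateException("Malformed compressed input", e);
    } finally {
      inflater.end();
    }
    return out.toByteArray();
  }
}
```

Checking `out.size() > maxOutputSize - n` rather than `out.size() + n > maxOutputSize` keeps the comparison safe from integer overflow, matching the shape of the check in the patch.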
diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/EncodedMetaString.java b/java/fory-core/src/main/java/org/apache/fory/meta/EncodedMetaString.java
index b13c1630b7..1692aa63d0 100644
--- a/java/fory-core/src/main/java/org/apache/fory/meta/EncodedMetaString.java
+++ b/java/fory-core/src/main/java/org/apache/fory/meta/EncodedMetaString.java
@@ -33,6 +33,7 @@ public final class EncodedMetaString {
public final byte[] bytes;
public final long hash;
+ public final int encodingValue;
public final MetaString.Encoding encoding;
public final long first8Bytes;
public final long second8Bytes;
@@ -45,7 +46,8 @@ public EncodedMetaString(byte[] bytes, long hash) {
assert hash != 0;
this.bytes = bytes;
this.hash = hash;
- this.encoding = MetaString.Encoding.fromInt((int) (hash & HEADER_MASK));
+ this.encodingValue = (int) (hash & HEADER_MASK);
+ this.encoding = MetaString.Encoding.fromInt(encodingValue);
byte[] data = bytes;
if (bytes.length < 16) {
data = new byte[16];
@@ -55,7 +57,7 @@ public EncodedMetaString(byte[] bytes, long hash) {
second8Bytes = LittleEndian.getInt64(data, Platform.BYTE_ARRAY_OFFSET + 8);
}
- private static long computeHash(byte[] bytes, MetaString.Encoding encoding) {
+ public static long computeHash(byte[] bytes, MetaString.Encoding encoding) {
long hash = MurmurHash3.murmurhash3_x64_128(bytes, 0, bytes.length, 47)[0];
hash = Math.abs(hash);
if (hash == 0) {
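`computeHash` above packs the encoding id into the hash's low byte, so one `long` identifies both the meta-string bytes and their encoding (the reader recovers the encoding via `hash & HEADER_MASK`). A sketch of that packing contract; the real code uses Fory's `MurmurHash3`, for which an FNV-1a stand-in is substituted here, and the zero-hash fallback value is our assumption:

```java
// Illustrative model of the hash/encoding packing used by EncodedMetaString.
public class MetaStringHash {
  static long computeHash(byte[] bytes, int encodingValue) {
    // FNV-1a stand-in for MurmurHash3 (not the production hash).
    long hash = 0xcbf29ce484222325L;
    for (byte b : bytes) {
      hash = (hash ^ (b & 0xff)) * 0x100000001b3L;
    }
    hash = Math.abs(hash);
    if (hash == 0) {
      hash = 256; // assumed fallback: keep 0 reserved for "no hash"
    }
    // Replace the low byte with the encoding id so hash & 0xff recovers it.
    return (hash & 0xffffffffffffff00L) | encodingValue;
  }
}
```

Making `computeHash` public in this patch lets `MetaStringReader` recompute the body hash on first sight of a big meta string and reject payloads whose declared hash does not match.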
diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/MetaCompressor.java b/java/fory-core/src/main/java/org/apache/fory/meta/MetaCompressor.java
index 0ed3974b1c..cf74e20c2d 100644
--- a/java/fory-core/src/main/java/org/apache/fory/meta/MetaCompressor.java
+++ b/java/fory-core/src/main/java/org/apache/fory/meta/MetaCompressor.java
@@ -19,6 +19,8 @@
package org.apache.fory.meta;
+import org.apache.fory.exception.InvalidDataException;
+
/**
* An interface used to compress class metadata such as field names and types. The implementation of
* this interface should be thread safe.
@@ -28,6 +30,14 @@ public interface MetaCompressor {
byte[] decompress(byte[] data, int offset, int size);
+ default byte[] decompress(byte[] data, int offset, int size, int maxOutputSize) {
+ byte[] decompressed = decompress(data, offset, size);
+ if (decompressed.length > maxOutputSize) {
+ throw new InvalidDataException("Decompressed TypeDef metadata exceeds the maximum size.");
+ }
+ return decompressed;
+ }
+
/**
* Check whether {@link MetaCompressor} implements `equals/hashCode` method. If not implemented,
* return {@link TypeEqualMetaCompressor} instead which compare equality by the compressor type
diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java
index f3c02af179..b4f50c2c0d 100644
--- a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java
+++ b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java
@@ -25,12 +25,12 @@
import static org.apache.fory.meta.NativeTypeDefEncoder.BIG_NAME_THRESHOLD;
import static org.apache.fory.meta.NativeTypeDefEncoder.NUM_CLASS_THRESHOLD;
import static org.apache.fory.meta.TypeDef.COMPRESS_META_FLAG;
-import static org.apache.fory.meta.TypeDef.HAS_FIELDS_META_FLAG;
import static org.apache.fory.meta.TypeDef.META_SIZE_MASKS;
import java.util.ArrayList;
import java.util.List;
import org.apache.fory.collection.Tuple2;
+import org.apache.fory.exception.DeserializationException;
import org.apache.fory.memory.MemoryBuffer;
import org.apache.fory.meta.FieldTypes.FieldType;
import org.apache.fory.meta.MetaString.Encoding;
@@ -38,6 +38,7 @@
import org.apache.fory.resolver.TypeResolver;
import org.apache.fory.serializer.UnknownClass;
import org.apache.fory.type.Types;
+import org.apache.fory.util.MurmurHash3;
import org.apache.fory.util.Preconditions;
/**
@@ -46,8 +47,13 @@
* href="https://fory.apache.org/docs/specification/fory_java_serialization_spec">...
*/
class NativeTypeDefDecoder {
+ private static final int MAX_TYPE_DEF_SIZE_BYTES = 16 * 1024 * 1024;
+
static Tuple2<byte[], byte[]> decodeTypeDefBuf(
MemoryBuffer inputBuffer, TypeResolver resolver, long id) {
+ if ((id & TypeDef.RESERVED_META_FLAGS) != 0) {
+ throw new DeserializationException("Invalid TypeDef global header");
+ }
MemoryBuffer encoded = MemoryBuffer.newHeapBuffer(64);
encoded.writeInt64(id);
int size = (int) (id & META_SIZE_MASKS);
@@ -56,10 +62,17 @@ static Tuple2 decodeTypeDefBuf(
encoded.writeVarUInt32(moreSize);
size += moreSize;
}
+ if (size > MAX_TYPE_DEF_SIZE_BYTES) {
+ throw new DeserializationException("TypeDef metadata size exceeds the maximum size");
+ }
byte[] encodedTypeDef = inputBuffer.readBytes(size);
encoded.writeBytes(encodedTypeDef);
if ((id & COMPRESS_META_FLAG) != 0) {
- encodedTypeDef = resolver.getConfig().getMetaCompressor().decompress(encodedTypeDef, 0, size);
+ encodedTypeDef =
+ resolver
+ .getConfig()
+ .getMetaCompressor()
+ .decompress(encodedTypeDef, 0, size, MAX_TYPE_DEF_SIZE_BYTES);
}
return Tuple2.of(encodedTypeDef, encoded.getBytes(0, encoded.writerIndex()));
}
@@ -67,7 +80,9 @@ static Tuple2 decodeTypeDefBuf(
public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, long id) {
Tuple2<byte[], byte[]> decoded = decodeTypeDefBuf(buffer, resolver, id);
MemoryBuffer typeDefBuf = MemoryBuffer.fromByteArray(decoded.f0);
- int numClasses = typeDefBuf.readByte();
+ int bodyHeader = typeDefBuf.readByte() & 0xff;
+ int rootTypeId = nativeTypeId(bodyHeader >>> 4);
+ int numClasses = bodyHeader & NUM_CLASS_THRESHOLD;
if (numClasses == NUM_CLASS_THRESHOLD) {
numClasses += typeDefBuf.readVarUInt32Small7();
}
@@ -75,12 +90,15 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer,
String className;
List classFields = new ArrayList<>();
ClassSpec classSpec = null;
+ Class<?> rootClass = null;
+ boolean rootClassLayerRegistered = false;
for (int i = 0; i < numClasses; i++) {
// | num fields + register flag | header + package name | header + class name
// | header + type id + field name | next field info | ... |
int currentClassHeader = typeDefBuf.readVarUInt32Small7();
boolean isRegistered = (currentClassHeader & 0b1) != 0;
int numFields = currentClassHeader >>> 1;
+ Class<?> currentClass = null;
if (isRegistered) {
int typeId = typeDefBuf.readUInt8();
int userTypeId = -1;
@@ -89,11 +107,16 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer,
}
Class<?> cls = resolver.getRegisteredClassByTypeId(typeId, userTypeId);
if (cls == null) {
- classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, userTypeId);
+ classSpec =
+ new ClassSpec(
+ UnknownClass.UnknownStruct.class,
+ i == numClasses - 1 ? rootTypeId : typeId,
+ userTypeId);
className = classSpec.entireClassName;
} else {
className = cls.getName();
- classSpec = new ClassSpec(cls, typeId, userTypeId);
+ classSpec = new ClassSpec(cls, i == numClasses - 1 ? rootTypeId : typeId, userTypeId);
+ currentClass = cls;
}
} else {
String pkg = readPkgName(typeDefBuf);
@@ -103,20 +126,22 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer,
if (resolver.isRegisteredByName(className)) {
Class<?> cls = resolver.getRegisteredClass(className);
className = cls.getName();
- classSpec =
- new ClassSpec(
- cls, resolver.getTypeIdForTypeDef(cls), resolver.getUserTypeIdForTypeDef(cls));
+ int typeId = i == numClasses - 1 ? rootTypeId : resolver.getTypeIdForTypeDef(cls);
+ classSpec = new ClassSpec(cls, typeId, resolver.getUserTypeIdForTypeDef(cls));
+ currentClass = cls;
} else {
Class<?> cls =
resolver.loadClassForMeta(
decodedSpec.entireClassName, decodedSpec.isEnum, decodedSpec.dimension);
if (UnknownClass.isUnknowClass(cls)) {
- int typeId;
+ int decodedTypeId;
if (decodedSpec.isEnum) {
- typeId = Types.NAMED_ENUM;
+ decodedTypeId = Types.NAMED_ENUM;
} else {
- typeId = resolver.isCompatible() ? Types.NAMED_COMPATIBLE_STRUCT : Types.NAMED_STRUCT;
+ decodedTypeId =
+ resolver.isCompatible() ? Types.NAMED_COMPATIBLE_STRUCT : Types.NAMED_STRUCT;
}
+ int typeId = i == numClasses - 1 ? rootTypeId : decodedTypeId;
classSpec =
new ClassSpec(
decodedSpec.entireClassName,
@@ -128,18 +153,120 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer,
classSpec.type = cls;
className = classSpec.entireClassName;
} else {
- int typeId = resolver.getTypeIdForTypeDef(cls);
+ int typeId = i == numClasses - 1 ? rootTypeId : resolver.getTypeIdForTypeDef(cls);
classSpec = new ClassSpec(cls, typeId, resolver.getUserTypeIdForTypeDef(cls));
className = classSpec.entireClassName;
+ currentClass = cls;
}
}
}
+ if (i == numClasses - 1) {
+ rootClass = currentClass;
+ rootClassLayerRegistered = isRegistered;
+ }
List<FieldInfo> fieldInfos = readFieldsInfo(typeDefBuf, resolver, className, numFields);
classFields.addAll(fieldInfos);
}
Preconditions.checkNotNull(classSpec);
- boolean hasFieldsMeta = (id & HAS_FIELDS_META_FLAG) != 0;
- return new TypeDef(classSpec, classFields, hasFieldsMeta, id, decoded.f1);
+ boolean hasFieldMetadata = !classFields.isEmpty();
+ if (!Types.isStructType(rootTypeId) && hasFieldMetadata) {
+ throw new DeserializationException("Non-struct TypeDef cannot carry field metadata");
+ }
+ if (rootClass != null) {
+ int expectedRootTypeId = resolver.getTypeDefRootTypeId(rootClass, hasFieldMetadata);
+ if (!isCompatibleRootKind(expectedRootTypeId, rootTypeId, !rootClassLayerRegistered)) {
+ throw new DeserializationException(
+ "TypeDef root kind does not match the decoded class: class="
+ + rootClass.getName()
+ + ", expected="
+ + expectedRootTypeId
+ + ", actual="
+ + rootTypeId
+ + ", registeredClassLayer="
+ + rootClassLayerRegistered);
+ }
+ }
+ if (typeDefBuf.remaining() != 0) {
+ throw new DeserializationException("Invalid TypeDef metadata size");
+ }
+ validateParsedTypeDefHash(id, decoded.f1);
+ return new TypeDef(classSpec, classFields, id, decoded.f1);
+ }
+
+ private static boolean isCompatibleRootKind(
+ int expectedTypeId, int actualTypeId, boolean allowNamednessDifference) {
+ if (expectedTypeId == actualTypeId) {
+ return true;
+ }
+ if (allowNamednessDifference) {
+ return Types.isStructType(expectedTypeId) && Types.isStructType(actualTypeId);
+ }
+ return isStructCompatibilityVariant(expectedTypeId, actualTypeId);
+ }
+
+ private static boolean isStructCompatibilityVariant(int expectedTypeId, int actualTypeId) {
+ boolean expectedIdStruct =
+ expectedTypeId == Types.STRUCT || expectedTypeId == Types.COMPATIBLE_STRUCT;
+ boolean actualIdStruct =
+ actualTypeId == Types.STRUCT || actualTypeId == Types.COMPATIBLE_STRUCT;
+ if (expectedIdStruct || actualIdStruct) {
+ return expectedIdStruct && actualIdStruct;
+ }
+ boolean expectedNamedStruct =
+ expectedTypeId == Types.NAMED_STRUCT || expectedTypeId == Types.NAMED_COMPATIBLE_STRUCT;
+ boolean actualNamedStruct =
+ actualTypeId == Types.NAMED_STRUCT || actualTypeId == Types.NAMED_COMPATIBLE_STRUCT;
+ return expectedNamedStruct && actualNamedStruct;
+ }
+
+ static int nativeTypeId(int kindCode) {
+ switch (kindCode) {
+ case 0:
+ return Types.STRUCT;
+ case 1:
+ return Types.COMPATIBLE_STRUCT;
+ case 2:
+ return Types.NAMED_STRUCT;
+ case 3:
+ return Types.NAMED_COMPATIBLE_STRUCT;
+ case 4:
+ return Types.ENUM;
+ case 5:
+ return Types.NAMED_ENUM;
+ case 6:
+ return Types.EXT;
+ case 7:
+ return Types.NAMED_EXT;
+ case 8:
+ return Types.TYPED_UNION;
+ case 9:
+ return Types.NAMED_UNION;
+ default:
+ throw new DeserializationException("Unsupported TypeDef kind code " + kindCode);
+ }
+ }
+
+ static void validateParsedTypeDefHash(long id, byte[] encoded) {
+ int size = (int) (id & META_SIZE_MASKS);
+ int bodyOffset = Long.BYTES;
+ if (size == META_SIZE_MASKS) {
+ MemoryBuffer encodedBuffer = MemoryBuffer.fromByteArray(encoded);
+ encodedBuffer.readerIndex(Long.BYTES);
+ int moreSize = encodedBuffer.readVarUInt32Small14();
+ size += moreSize;
+ bodyOffset = encodedBuffer.readerIndex();
+ }
+ if (encoded.length - bodyOffset != size) {
+ throw new DeserializationException("Invalid TypeDef encoded size");
+ }
+ long hash = MurmurHash3.murmurhash3_x64_128(encoded, bodyOffset, size, 47)[0];
+ hash <<= (Long.SIZE - TypeDef.NUM_HASH_BITS);
+ long hashMask = -1L << (Long.SIZE - TypeDef.NUM_HASH_BITS);
+ long expectedHeaderHash = Math.abs(hash) & hashMask;
+ long actualHeaderHash = id & hashMask;
+ if (expectedHeaderHash != actualHeaderHash) {
+ throw new DeserializationException("Invalid TypeDef metadata hash");
+ }
}
private static List<FieldInfo> readFieldsInfo(
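The 52-bit header-hash check added in `validateParsedTypeDefHash` above can be modeled in isolation. This is an illustrative sketch, not Fory code: `packHeader` and `headerHashMatches` are hypothetical helper names, and the MurmurHash3 body hash is replaced by an arbitrary caller-supplied 64-bit value; the constants and bit arithmetic are taken from the diff.

```java
// Illustrative model of the TypeDef global header from the diff:
// | 52-bit hash | 3 bits reserved | 1 bit compress | 8-bit size |
public class TypeDefHeaderSketch {
    static final int NUM_HASH_BITS = 52;       // from TypeDef in the diff
    static final long META_SIZE_MASKS = 0xff;  // low 8 bits carry the size
    static final long COMPRESS_META_FLAG = 1L << 8;

    // Pack a header the way prependHeader does: keep only the top 52 bits of
    // the body hash, then OR in the compress flag and the (capped) size.
    static long packHeader(long bodyHash, boolean compressed, int size) {
        long hash = bodyHash << (Long.SIZE - NUM_HASH_BITS);
        long header = Math.abs(hash) & (-1L << (Long.SIZE - NUM_HASH_BITS));
        if (compressed) {
            header |= COMPRESS_META_FLAG;
        }
        header |= Math.min(size, META_SIZE_MASKS);
        return header;
    }

    // Recompute the expected 52-bit header hash from a body hash and compare
    // it against the hash bits stored in a parsed header.
    static boolean headerHashMatches(long header, long bodyHash) {
        long hashMask = -1L << (Long.SIZE - NUM_HASH_BITS);
        long expected = Math.abs(bodyHash << (Long.SIZE - NUM_HASH_BITS)) & hashMask;
        return (header & hashMask) == expected;
    }
}
```

Because the size and flag bits occupy the low 12 bits and the hash mask zeroes them, the hash comparison is unaffected by size or compression.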
diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefEncoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefEncoder.java
index 3f7998184d..6173a9e88d 100644
--- a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefEncoder.java
+++ b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefEncoder.java
@@ -23,7 +23,6 @@
import static org.apache.fory.meta.Encoders.pkgEncodingsList;
import static org.apache.fory.meta.Encoders.typeNameEncodingsList;
import static org.apache.fory.meta.TypeDef.COMPRESS_META_FLAG;
-import static org.apache.fory.meta.TypeDef.HAS_FIELDS_META_FLAG;
import static org.apache.fory.meta.TypeDef.META_SIZE_MASKS;
import static org.apache.fory.meta.TypeDef.NUM_HASH_BITS;
@@ -56,7 +55,6 @@
*/
@Internal
public class NativeTypeDefEncoder {
- // a flag to mark a type is not struct.
static final int NUM_CLASS_THRESHOLD = 0b1111;
private static final java.util.function.Function<Descriptor, Descriptor> IDENTITY_DESCRIPTOR =
descriptor -> descriptor;
@@ -142,43 +140,45 @@ public static List buildFieldsInfo(TypeResolver resolver, List
}
/** Build class definition from fields of class. */
- static TypeDef buildTypeDef(
- ClassResolver classResolver, Class<?> type, List<Field> fields, boolean hasFieldsMeta) {
- return buildTypeDefWithFieldInfos(
- classResolver, type, buildFieldsInfo(classResolver, fields), hasFieldsMeta);
+ static TypeDef buildTypeDef(ClassResolver classResolver, Class<?> type, List<Field> fields) {
+ return buildTypeDefWithFieldInfos(classResolver, type, buildFieldsInfo(classResolver, fields));
}
public static TypeDef buildTypeDefWithFieldInfos(
- ClassResolver classResolver,
- Class<?> type,
- List<FieldInfo> fieldInfos,
- boolean hasFieldsMeta) {
+ ClassResolver classResolver, Class<?> type, List<FieldInfo> fieldInfos) {
+ boolean hasFieldMetadata = !fieldInfos.isEmpty();
Map<String, List<FieldInfo>> classLayers = getClassFields(type, fieldInfos);
fieldInfos = new ArrayList<>(fieldInfos.size());
classLayers.values().forEach(fieldInfos::addAll);
- MemoryBuffer encodeTypeDef = encodeTypeDef(classResolver, type, classLayers, hasFieldsMeta);
+ MemoryBuffer encodeTypeDef = encodeTypeDef(classResolver, type, classLayers, hasFieldMetadata);
byte[] typeDefBytes = encodeTypeDef.getBytes(0, encodeTypeDef.writerIndex());
- int typeId = classResolver.getTypeIdForTypeDef(type);
+ int typeId = classResolver.getTypeDefRootTypeId(type, hasFieldMetadata);
int userTypeId = classResolver.getUserTypeIdForTypeDef(type);
ClassSpec classSpec = new ClassSpec(type, typeId, userTypeId);
- return new TypeDef(
- classSpec, fieldInfos, hasFieldsMeta, encodeTypeDef.getInt64(0), typeDefBytes);
+ return new TypeDef(classSpec, fieldInfos, encodeTypeDef.getInt64(0), typeDefBytes);
}
// see spec documentation: docs/specification/java_serialization_spec.md
// https://fory.apache.org/docs/specification/fory_java_serialization_spec
public static MemoryBuffer encodeTypeDef(
+ ClassResolver classResolver, Class<?> type, Map<String, List<FieldInfo>> classLayers) {
+ return encodeTypeDef(classResolver, type, classLayers, hasFieldMetadata(classLayers));
+ }
+
+ private static MemoryBuffer encodeTypeDef(
ClassResolver classResolver,
Class<?> type,
Map<String, List<FieldInfo>> classLayers,
- boolean hasFieldsMeta) {
+ boolean hasFieldMetadata) {
MemoryBuffer typeDefBuf = MemoryBuffer.newHeapBuffer(128);
int numClasses = classLayers.size() - 1; // num class must be greater than 0
+ int rootTypeId = classResolver.getTypeDefRootTypeId(type, hasFieldMetadata);
+ int firstBodyByte = nativeKindCode(rootTypeId) << 4;
if (numClasses >= NUM_CLASS_THRESHOLD) {
- typeDefBuf.writeByte(NUM_CLASS_THRESHOLD);
+ typeDefBuf.writeByte(firstBodyByte | NUM_CLASS_THRESHOLD);
typeDefBuf.writeVarUInt32Small7(numClasses - NUM_CLASS_THRESHOLD);
} else {
- typeDefBuf.writeByte(numClasses);
+ typeDefBuf.writeByte(firstBodyByte | numClasses);
}
for (Map.Entry<String, List<FieldInfo>> entry : classLayers.entrySet()) {
String className = entry.getKey();
@@ -224,11 +224,19 @@ public static MemoryBuffer encodeTypeDef(
typeDefBuf = MemoryBuffer.fromByteArray(compressed);
typeDefBuf.writerIndex(compressed.length);
}
- return prependHeader(typeDefBuf, isCompressed, hasFieldsMeta);
+ return prependHeader(typeDefBuf, isCompressed);
+ }
+
+ private static boolean hasFieldMetadata(Map<String, List<FieldInfo>> classLayers) {
+ for (List<FieldInfo> fields : classLayers.values()) {
+ if (!fields.isEmpty()) {
+ return true;
+ }
+ }
+ return false;
}
- static MemoryBuffer prependHeader(
- MemoryBuffer buffer, boolean isCompressed, boolean hasFieldsMeta) {
+ static MemoryBuffer prependHeader(MemoryBuffer buffer, boolean isCompressed) {
int metaSize = buffer.writerIndex();
long hash = MurmurHash3.murmurhash3_x64_128(buffer.getHeapMemory(), 0, metaSize, 47)[0];
hash <<= (64 - NUM_HASH_BITS);
@@ -237,9 +245,6 @@ static MemoryBuffer prependHeader(
if (isCompressed) {
header |= COMPRESS_META_FLAG;
}
- if (hasFieldsMeta) {
- header |= HAS_FIELDS_META_FLAG;
- }
header |= Math.min(metaSize, META_SIZE_MASKS);
MemoryBuffer result = MemoryUtils.buffer(metaSize + 8);
result.writeInt64(header);
@@ -250,6 +255,33 @@ static MemoryBuffer prependHeader(
return result;
}
+ static int nativeKindCode(int typeId) {
+ switch (typeId) {
+ case Types.STRUCT:
+ return 0;
+ case Types.COMPATIBLE_STRUCT:
+ return 1;
+ case Types.NAMED_STRUCT:
+ return 2;
+ case Types.NAMED_COMPATIBLE_STRUCT:
+ return 3;
+ case Types.ENUM:
+ return 4;
+ case Types.NAMED_ENUM:
+ return 5;
+ case Types.EXT:
+ return 6;
+ case Types.NAMED_EXT:
+ return 7;
+ case Types.TYPED_UNION:
+ return 8;
+ case Types.NAMED_UNION:
+ return 9;
+ default:
+ throw new IllegalArgumentException("Unsupported TypeDef kind " + typeId);
+ }
+ }
+
private static Class<?> getType(Class<?> cls, String type) {
Class<?> c = cls;
while (cls != null) {
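The native encoder above packs a 4-bit root kind code into the high nibble of the first body byte and the class-layer count into the low nibble, with `0b1111` acting as an overflow sentinel. A minimal standalone sketch of that packing, with hypothetical helper names mirroring `encodeTypeDef` and the decoder's `bodyHeader >>> 4` / `bodyHeader & NUM_CLASS_THRESHOLD` reads:

```java
// Sketch of the first TypeDef body byte written by encodeTypeDef in the diff:
// high nibble = root kind code, low nibble = numClasses (0b1111 means
// "read a varuint for the remainder"). Helper names are illustrative.
public class KindNibbleSketch {
    static final int NUM_CLASS_THRESHOLD = 0b1111; // from NativeTypeDefEncoder

    static int firstBodyByte(int kindCode, int numClasses) {
        return (kindCode << 4) | Math.min(numClasses, NUM_CLASS_THRESHOLD);
    }

    static int kindCodeOf(int bodyByte) {
        // Decoder side: nativeTypeId(bodyHeader >>> 4)
        return (bodyByte & 0xff) >>> 4;
    }

    static int smallNumClassesOf(int bodyByte) {
        // Decoder side: bodyHeader & NUM_CLASS_THRESHOLD
        return bodyByte & NUM_CLASS_THRESHOLD;
    }
}
```

With 10 kind codes defined (0 through 9), the high nibble still has headroom for future kinds while keeping the byte self-describing.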
diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java
index 4fd91857c9..6069472e91 100644
--- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java
+++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java
@@ -34,6 +34,7 @@
import java.util.stream.Collectors;
import org.apache.fory.builder.MetaSharedCodecBuilder;
import org.apache.fory.config.ForyBuilder;
+import org.apache.fory.exception.DeserializationException;
import org.apache.fory.logging.Logger;
import org.apache.fory.logging.LoggerFactory;
import org.apache.fory.memory.MemoryBuffer;
@@ -69,11 +70,11 @@
public class TypeDef implements Serializable {
private static final Logger LOG = LoggerFactory.getLogger(TypeDef.class);
- static final int COMPRESS_META_FLAG = 0b1 << 9;
- static final int HAS_FIELDS_META_FLAG = 0b1 << 8;
+ static final int COMPRESS_META_FLAG = 0b1 << 8;
+ static final long RESERVED_META_FLAGS = 0b111L << 9;
// low 8 bits
static final int META_SIZE_MASKS = 0xff;
- static final int NUM_HASH_BITS = 50;
+ static final int NUM_HASH_BITS = 52;
// TODO use field offset to sort field, which will hit l1-cache more. Since
// `objectFieldOffset` is not part of jvm-specification, it may change between different jdk
@@ -101,30 +102,33 @@ public class TypeDef implements Serializable {
private final ClassSpec classSpec;
private final List<FieldInfo> fieldsInfo;
- private final boolean hasFieldsMeta;
// Unique id for the class def. If class defs are the same between processes, the id will
// be the same too.
private final long id;
private final byte[] encoded;
- TypeDef(
- ClassSpec classSpec,
- List<FieldInfo> fieldsInfo,
- boolean hasFieldsMeta,
- long id,
- byte[] encoded) {
+ TypeDef(ClassSpec classSpec, List<FieldInfo> fieldsInfo, long id, byte[] encoded) {
this.classSpec = classSpec;
this.fieldsInfo = fieldsInfo;
- this.hasFieldsMeta = hasFieldsMeta;
this.id = id;
this.encoded = encoded;
}
public static void skipTypeDef(MemoryBuffer buffer, long id) {
+ // Header-cache hits intentionally treat the current body as opaque bytes and skip by the size
+ // in the current header. Parsed TypeDefs are published to the cache only after successful body
+ // parse and 52-bit body-hash validation; cache hits must not reparse or rehash that body.
int size = (int) (id & META_SIZE_MASKS);
if (size == META_SIZE_MASKS) {
- size += buffer.readVarUInt32Small14();
+ int extendedSize = buffer.readVarUInt32Small14();
+ if (extendedSize < 0 || extendedSize > Integer.MAX_VALUE - size) {
+ throw new DeserializationException("Invalid TypeDef metadata size " + extendedSize);
+ }
+ size += extendedSize;
}
+ buffer.checkReadableBytes(size);
buffer.increaseReaderIndex(size);
}
@@ -146,11 +150,6 @@ public List getFieldsInfo() {
return fieldsInfo;
}
- /** Returns ext meta for the class. */
- public boolean hasFieldsMeta() {
- return hasFieldsMeta;
- }
-
/**
* Returns a unique id for the class def. If class defs are the same between processes, the id
* will be the same too.
@@ -179,6 +178,10 @@ public boolean isCompatible() {
|| classSpec.typeId == Types.NAMED_COMPATIBLE_STRUCT;
}
+ public boolean isStructSchemaKind() {
+ return Types.isStructType(classSpec.typeId);
+ }
+
public int getUserTypeId() {
Preconditions.checkArgument(!isNamed(), "Named types don't have user type id");
return classSpec.userTypeId;
@@ -190,8 +193,7 @@ public boolean equals(Object o) {
return false;
}
TypeDef typeDef = (TypeDef) o;
- return hasFieldsMeta == typeDef.hasFieldsMeta
- && id == typeDef.id
+ return id == typeDef.id
&& Objects.equals(classSpec, typeDef.classSpec)
&& Objects.equals(fieldsInfo, typeDef.fieldsInfo);
}
@@ -209,8 +211,6 @@ public String toString() {
+ '\''
+ ", fieldsInfo="
+ fieldsInfo
- + ", hasFieldsMeta="
- + hasFieldsMeta
+ ", id="
+ id
+ '}';
@@ -429,17 +429,12 @@ public static TypeDef buildTypeDef(TypeResolver resolver, Class<?> cls, boolean
return TypeDefEncoder.buildTypeDef((XtypeResolver) resolver, cls);
}
return NativeTypeDefEncoder.buildTypeDef(
- (ClassResolver) resolver, cls, buildFields(resolver, cls, resolveParent), true);
+ (ClassResolver) resolver, cls, buildFields(resolver, cls, resolveParent));
}
/** Build class definition from fields of class. */
static TypeDef buildTypeDef(ClassResolver classResolver, Class<?> type, List<Field> fields) {
- return buildTypeDef(classResolver, type, fields, true);
- }
-
- public static TypeDef buildTypeDef(
- ClassResolver classResolver, Class<?> type, List<Field> fields, boolean hasFieldsMeta) {
- return NativeTypeDefEncoder.buildTypeDef(classResolver, type, fields, hasFieldsMeta);
+ return NativeTypeDefEncoder.buildTypeDef(classResolver, type, fields);
}
public TypeDef replaceRootClassTo(TypeResolver resolver, Class<?> targetCls) {
@@ -460,6 +455,6 @@ public TypeDef replaceRootClassTo(TypeResolver resolver, Class<?> targetCls) {
(XtypeResolver) resolver, targetCls, fieldInfos);
}
return NativeTypeDefEncoder.buildTypeDefWithFieldInfos(
- (ClassResolver) resolver, targetCls, fieldInfos, hasFieldsMeta);
+ (ClassResolver) resolver, targetCls, fieldInfos);
}
}
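The hardened `skipTypeDef` above reads the body size from the low 8 bits of the header and, when the sentinel `0xff` appears, adds a varuint-encoded extension with an overflow guard. The size arithmetic alone can be sketched like this; `decodeBodySize` is a hypothetical name, and the varuint read is replaced by a plain parameter:

```java
// Sketch of the size arithmetic used by skipTypeDef in the diff: the low
// 8 bits of the header carry the size; the sentinel 0xff means an extra
// varuint-encoded size follows and is added on top.
public class SizeFieldSketch {
    static final int META_SIZE_MASKS = 0xff;

    static int decodeBodySize(long id, int extendedSize) {
        int size = (int) (id & META_SIZE_MASKS);
        if (size == META_SIZE_MASKS) {
            // Mirrors the overflow guard the diff adds before size += extendedSize.
            if (extendedSize < 0 || extendedSize > Integer.MAX_VALUE - size) {
                throw new IllegalArgumentException(
                    "Invalid TypeDef metadata size " + extendedSize);
            }
            size += extendedSize;
        }
        return size;
    }
}
```

Note the extension is additive, so a body of exactly 255 bytes is encoded as `0xff` plus an extension of 0.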
diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java
index 296316b8fa..b184656c47 100644
--- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java
+++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java
@@ -23,14 +23,18 @@
import static org.apache.fory.meta.NativeTypeDefDecoder.decodeTypeDefBuf;
import static org.apache.fory.meta.NativeTypeDefDecoder.readPkgName;
import static org.apache.fory.meta.NativeTypeDefDecoder.readTypeName;
-import static org.apache.fory.meta.TypeDef.HAS_FIELDS_META_FLAG;
+import static org.apache.fory.meta.NativeTypeDefDecoder.validateParsedTypeDefHash;
+import static org.apache.fory.meta.TypeDef.COMPRESS_META_FLAG;
+import static org.apache.fory.meta.TypeDefEncoder.COMPATIBLE_FLAG;
import static org.apache.fory.meta.TypeDefEncoder.FIELD_NAME_SIZE_THRESHOLD;
import static org.apache.fory.meta.TypeDefEncoder.REGISTER_BY_NAME_FLAG;
import static org.apache.fory.meta.TypeDefEncoder.SMALL_NUM_FIELDS_THRESHOLD;
+import static org.apache.fory.meta.TypeDefEncoder.STRUCT_FLAG;
import java.util.ArrayList;
import java.util.List;
import org.apache.fory.collection.Tuple2;
+import org.apache.fory.exception.DeserializationException;
import org.apache.fory.logging.Logger;
import org.apache.fory.logging.LoggerFactory;
import org.apache.fory.memory.MemoryBuffer;
@@ -39,12 +43,12 @@
import org.apache.fory.resolver.TypeInfo;
import org.apache.fory.resolver.XtypeResolver;
import org.apache.fory.serializer.UnknownClass;
+import org.apache.fory.type.Types;
import org.apache.fory.util.StringUtils;
import org.apache.fory.util.Utils;
/**
- * A decoder which decode binary into {@link TypeDef}. Global header layout follows the xlang spec
- * with an 8-bit meta size and flags at bits 8/9. See spec documentation:
+ * A decoder which decodes binary into {@link TypeDef}. See spec documentation:
* docs/specification/fory_xlang_serialization_spec.md ...
*/
@@ -52,40 +56,88 @@ class TypeDefDecoder {
private static final Logger LOG = LoggerFactory.getLogger(TypeDefDecoder.class);
public static TypeDef decodeTypeDef(XtypeResolver resolver, MemoryBuffer inputBuffer, long id) {
+ if ((id & COMPRESS_META_FLAG) != 0) {
+ throw new DeserializationException("Compressed xlang TypeDef is not supported");
+ }
Tuple2<byte[], byte[]> decoded = decodeTypeDefBuf(inputBuffer, resolver, id);
MemoryBuffer buffer = MemoryBuffer.fromByteArray(decoded.f0);
- byte header = buffer.readByte();
- int numFields = header & SMALL_NUM_FIELDS_THRESHOLD;
- if (numFields == SMALL_NUM_FIELDS_THRESHOLD) {
- numFields += buffer.readVarUInt32Small7();
- }
+ int header = buffer.readByte() & 0xff;
+ boolean isStruct = (header & STRUCT_FLAG) != 0;
+ int numFields = 0;
ClassSpec classSpec;
- if ((header & REGISTER_BY_NAME_FLAG) != 0) {
- String namespace = readPkgName(buffer);
- String typeName = readTypeName(buffer);
- if (Utils.DEBUG_OUTPUT_ENABLED) {
- LOG.info("Decode class {} using namespace {}", typeName, namespace);
+ if (isStruct) {
+ boolean named = (header & REGISTER_BY_NAME_FLAG) != 0;
+ boolean compatible = (header & COMPATIBLE_FLAG) != 0;
+ int typeId;
+ if (named) {
+ typeId = compatible ? Types.NAMED_COMPATIBLE_STRUCT : Types.NAMED_STRUCT;
+ } else {
+ typeId = compatible ? Types.COMPATIBLE_STRUCT : Types.STRUCT;
+ }
+ numFields = header & SMALL_NUM_FIELDS_THRESHOLD;
+ if (numFields == SMALL_NUM_FIELDS_THRESHOLD) {
+ numFields += buffer.readVarUInt32Small7();
}
- TypeInfo userTypeInfo = resolver.getUserTypeInfo(namespace, typeName);
- if (userTypeInfo == null) {
- classSpec = new ClassSpec(UnknownClass.UnknownStruct.class);
+ if (named) {
+ String namespace = readPkgName(buffer);
+ String typeName = readTypeName(buffer);
+ if (Utils.DEBUG_OUTPUT_ENABLED) {
+ LOG.info("Decode class {} using namespace {}", typeName, namespace);
+ }
+ TypeInfo userTypeInfo = resolver.getUserTypeInfo(namespace, typeName);
+ if (userTypeInfo == null) {
+ classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, -1);
+ } else {
+ validateRegisteredTypeDefKind(userTypeInfo, typeId);
+ classSpec = new ClassSpec(userTypeInfo.getType(), typeId, userTypeInfo.getUserTypeId());
+ }
} else {
- classSpec = new ClassSpec(userTypeInfo.getType());
+ int userTypeId = buffer.readVarUInt32();
+ TypeInfo userTypeInfo = resolver.getUserTypeInfo(userTypeId);
+ if (userTypeInfo == null) {
+ classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, userTypeId);
+ } else {
+ validateRegisteredTypeDefKind(userTypeInfo, typeId);
+ classSpec = new ClassSpec(userTypeInfo.getType(), typeId, userTypeId);
+ }
}
} else {
- int typeId = buffer.readUInt8();
- int userTypeId = buffer.readVarUInt32();
- TypeInfo userTypeInfo = resolver.getUserTypeInfo(userTypeId);
- if (userTypeInfo == null) {
- classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, userTypeId);
+ if ((header & 0b0111_0000) != 0) {
+ throw new DeserializationException("Invalid TypeDef kind header");
+ }
+ int typeId = nonStructTypeId(header & 0b1111);
+ boolean named = Types.isNamedType(typeId);
+ if (named) {
+ String namespace = readPkgName(buffer);
+ String typeName = readTypeName(buffer);
+ TypeInfo userTypeInfo = resolver.getUserTypeInfo(namespace, typeName);
+ if (userTypeInfo == null) {
+ classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, -1);
+ } else {
+ validateRegisteredTypeDefKind(userTypeInfo, typeId);
+ classSpec = new ClassSpec(userTypeInfo.getType(), typeId, userTypeInfo.getUserTypeId());
+ }
} else {
- classSpec = new ClassSpec(userTypeInfo.getType(), typeId, userTypeId);
+ int userTypeId = buffer.readVarUInt32();
+ TypeInfo userTypeInfo = resolver.getUserTypeInfo(userTypeId);
+ if (userTypeInfo == null) {
+ classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, userTypeId);
+ } else {
+ validateRegisteredTypeDefKind(userTypeInfo, typeId);
+ classSpec = new ClassSpec(userTypeInfo.getType(), typeId, userTypeId);
+ }
}
}
List<FieldInfo> classFields =
readFieldsInfo(buffer, resolver, classSpec.entireClassName, numFields);
- boolean hasFieldsMeta = (id & HAS_FIELDS_META_FLAG) != 0;
- TypeDef typeDef = new TypeDef(classSpec, classFields, hasFieldsMeta, id, decoded.f1);
+ if (!isStruct && !classFields.isEmpty()) {
+ throw new DeserializationException("Non-struct TypeDef cannot carry field metadata");
+ }
+ if (buffer.remaining() != 0) {
+ throw new DeserializationException("Invalid TypeDef metadata size");
+ }
+ validateParsedTypeDefHash(id, decoded.f1);
+ TypeDef typeDef = new TypeDef(classSpec, classFields, id, decoded.f1);
if (Utils.DEBUG_OUTPUT_ENABLED) {
LOG.info("[Java TypeDef DECODED] " + typeDef);
// Compute and print diff with local TypeDef
@@ -103,6 +155,49 @@ public static TypeDef decodeTypeDef(XtypeResolver resolver, MemoryBuffer inputBu
return typeDef;
}
+ private static void validateRegisteredTypeDefKind(TypeInfo userTypeInfo, int typeId) {
+ int registeredTypeId = userTypeInfo.getTypeId();
+ if (registeredTypeId != typeId && !isStructCompatibilityVariant(registeredTypeId, typeId)) {
+ throw new DeserializationException(
+ String.format(
+ "TypeDef kind %s does not match registered kind %s for %s",
+ typeId, registeredTypeId, userTypeInfo.getType()));
+ }
+ }
+
+ private static boolean isStructCompatibilityVariant(int registeredTypeId, int typeId) {
+ boolean registeredIdStruct =
+ registeredTypeId == Types.STRUCT || registeredTypeId == Types.COMPATIBLE_STRUCT;
+ boolean typeIdStruct = typeId == Types.STRUCT || typeId == Types.COMPATIBLE_STRUCT;
+ if (registeredIdStruct || typeIdStruct) {
+ return registeredIdStruct && typeIdStruct;
+ }
+ boolean registeredNamedStruct =
+ registeredTypeId == Types.NAMED_STRUCT || registeredTypeId == Types.NAMED_COMPATIBLE_STRUCT;
+ boolean typeIdNamedStruct =
+ typeId == Types.NAMED_STRUCT || typeId == Types.NAMED_COMPATIBLE_STRUCT;
+ return registeredNamedStruct && typeIdNamedStruct;
+ }
+
+ static int nonStructTypeId(int kindCode) {
+ switch (kindCode) {
+ case 0:
+ return Types.ENUM;
+ case 1:
+ return Types.NAMED_ENUM;
+ case 2:
+ return Types.EXT;
+ case 3:
+ return Types.NAMED_EXT;
+ case 4:
+ return Types.TYPED_UNION;
+ case 5:
+ return Types.NAMED_UNION;
+ default:
+ throw new DeserializationException("Unsupported TypeDef kind code " + kindCode);
+ }
+ }
+
// | header + type info + field name | ... | header + type info + field name |
private static List<FieldInfo> readFieldsInfo(
MemoryBuffer buffer, XtypeResolver resolver, String className, int numFields) {
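The xlang body header byte decoded above packs four pieces of information: bit 7 marks a struct, bit 6 the compatible variant, bit 5 registration by name, and the low 5 bits a small field count with `0b11111` as the overflow sentinel. A standalone sketch of that byte, with illustrative helper names mirroring the encoder and decoder in the diff:

```java
// Sketch of the xlang TypeDef body header byte from the diff. The flag
// constants match TypeDefEncoder; structHeader/isStruct/smallNumFields
// are hypothetical helper names.
public class XlangHeaderSketch {
    static final int SMALL_NUM_FIELDS_THRESHOLD = 0b11111;
    static final int REGISTER_BY_NAME_FLAG = 0b0010_0000;
    static final int COMPATIBLE_FLAG = 0b0100_0000;
    static final int STRUCT_FLAG = 0b1000_0000;

    // Build a struct header byte; field counts at or above the threshold
    // are capped, and the remainder is written separately as a varuint.
    static int structHeader(boolean compatible, boolean byName, int numFields) {
        int header = STRUCT_FLAG | Math.min(numFields, SMALL_NUM_FIELDS_THRESHOLD);
        if (compatible) {
            header |= COMPATIBLE_FLAG;
        }
        if (byName) {
            header |= REGISTER_BY_NAME_FLAG;
        }
        return header;
    }

    static boolean isStruct(int header) {
        return (header & STRUCT_FLAG) != 0;
    }

    static int smallNumFields(int header) {
        return header & SMALL_NUM_FIELDS_THRESHOLD;
    }
}
```

For non-struct kinds the decoder instead requires bits 4-6 to be zero and reads a 4-bit kind code from the low nibble, which is why the two layouts cannot be confused.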
diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefEncoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefEncoder.java
index d5eada0f0a..4bed7c13d6 100644
--- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefEncoder.java
+++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefEncoder.java
@@ -49,8 +49,7 @@
import org.apache.fory.util.Utils;
/**
- * An encoder which encode {@link TypeDef} into binary. Global header layout follows the xlang spec
- * with an 8-bit meta size and flags at bits 8/9. See spec documentation:
+ * An encoder which encodes {@link TypeDef} into binary. See spec documentation:
* docs/specification/fory_xlang_serialization_spec.md ...
*/
@@ -115,7 +114,6 @@ static TypeDef buildTypeDefWithFieldInfos(
new TypeDef(
new ClassSpec(type, typeInfo.getTypeId(), typeInfo.getUserTypeId()),
fieldInfos,
- true,
encodeTypeDef.getInt64(0),
typeDefBytes);
if (Utils.DEBUG_OUTPUT_ENABLED) {
@@ -125,43 +123,82 @@ static TypeDef buildTypeDefWithFieldInfos(
}
static final int SMALL_NUM_FIELDS_THRESHOLD = 0b11111;
- static final int REGISTER_BY_NAME_FLAG = 0b100000;
+ static final int REGISTER_BY_NAME_FLAG = 0b0010_0000;
+ static final int COMPATIBLE_FLAG = 0b0100_0000;
+ static final int STRUCT_FLAG = 0b1000_0000;
static final int FIELD_NAME_SIZE_THRESHOLD = 0b1111;
// see spec documentation: docs/specification/xlang_serialization_spec.md
// https://fory.apache.org/docs/specification/fory_xlang_serialization_spec
static MemoryBuffer encodeTypeDef(XtypeResolver resolver, Class<?> type, List<FieldInfo> fields) {
TypeInfo typeInfo = resolver.getTypeInfo(type);
+ int typeId = typeInfo.getTypeId();
+ boolean isStruct = Types.isStructType(typeId);
+ Preconditions.checkArgument(
+ isStruct || fields.isEmpty(), "Non-struct TypeDef %s cannot carry field metadata", typeId);
MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(128);
buffer.writeByte(-1); // placeholder for header, update later
- int currentClassHeader = fields.size();
- if (fields.size() >= SMALL_NUM_FIELDS_THRESHOLD) {
- currentClassHeader = SMALL_NUM_FIELDS_THRESHOLD;
- buffer.writeVarUInt32(fields.size() - SMALL_NUM_FIELDS_THRESHOLD);
- }
- if (resolver.isRegisteredById(type)) {
- buffer.writeUInt8(typeInfo.getTypeId());
- Preconditions.checkArgument(
- typeInfo.getUserTypeId() != -1,
- "User type id is required for typeId %s",
- typeInfo.getTypeId());
- buffer.writeVarUInt32(typeInfo.getUserTypeId());
+ if (isStruct) {
+ int fieldCount = fields.size();
+ int currentClassHeader = STRUCT_FLAG | Math.min(fieldCount, SMALL_NUM_FIELDS_THRESHOLD);
+ if (typeId == Types.COMPATIBLE_STRUCT || typeId == Types.NAMED_COMPATIBLE_STRUCT) {
+ currentClassHeader |= COMPATIBLE_FLAG;
+ }
+ if (fieldCount >= SMALL_NUM_FIELDS_THRESHOLD) {
+ buffer.writeVarUInt32(fieldCount - SMALL_NUM_FIELDS_THRESHOLD);
+ }
+ if (resolver.isRegisteredById(type)) {
+ Preconditions.checkArgument(
+ typeInfo.getUserTypeId() != -1,
+ "User type id is required for typeId %s",
+ typeInfo.getTypeId());
+ buffer.writeVarUInt32(typeInfo.getUserTypeId());
+ } else {
+ Preconditions.checkArgument(resolver.isRegisteredByName(type));
+ currentClassHeader |= REGISTER_BY_NAME_FLAG;
+ String ns = typeInfo.decodeNamespace();
+ String typename = typeInfo.decodeTypeName();
+ writePkgName(buffer, ns);
+ writeTypeName(buffer, typename);
+ }
+ buffer.putByte(0, currentClassHeader);
+ writeFieldsInfo(resolver, buffer, fields);
} else {
- Preconditions.checkArgument(resolver.isRegisteredByName(type));
- currentClassHeader |= REGISTER_BY_NAME_FLAG;
- String ns = typeInfo.decodeNamespace();
- String typename = typeInfo.decodeTypeName();
- writePkgName(buffer, ns);
- writeTypeName(buffer, typename);
+ buffer.putByte(0, nonStructKindCode(typeId));
+ if (resolver.isRegisteredById(type)) {
+ Preconditions.checkArgument(
+ typeInfo.getUserTypeId() != -1,
+ "User type id is required for typeId %s",
+ typeInfo.getTypeId());
+ buffer.writeVarUInt32(typeInfo.getUserTypeId());
+ } else {
+ Preconditions.checkArgument(resolver.isRegisteredByName(type));
+ String ns = typeInfo.decodeNamespace();
+ String typename = typeInfo.decodeTypeName();
+ writePkgName(buffer, ns);
+ writeTypeName(buffer, typename);
+ }
}
- buffer.putByte(0, currentClassHeader);
- writeFieldsInfo(resolver, buffer, fields);
+ return prependHeader(buffer, false);
+ }
- // Temporary xlang behavior: always write TypeMeta uncompressed.
- // Some runtimes still don't support TypeMeta decompression, so we must avoid emitting
- // compressed xlang TypeMeta until all xlang implementations support decompress.
- // Note: native mode is unchanged and still uses NativeTypeDefEncoder compression flow.
- return prependHeader(buffer, false, !fields.isEmpty());
+ static int nonStructKindCode(int typeId) {
+ switch (typeId) {
+ case Types.ENUM:
+ return 0;
+ case Types.NAMED_ENUM:
+ return 1;
+ case Types.EXT:
+ return 2;
+ case Types.NAMED_EXT:
+ return 3;
+ case Types.TYPED_UNION:
+ return 4;
+ case Types.NAMED_UNION:
+ return 5;
+ default:
+ throw new IllegalArgumentException("Unsupported TypeDef kind " + typeId);
+ }
}
static Map<String, List<FieldInfo>> getClassFields(Class<?> type, List<FieldInfo> fieldsInfo) {
diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeEqualMetaCompressor.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeEqualMetaCompressor.java
index 1c21296ce6..34eb2fe2dd 100644
--- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeEqualMetaCompressor.java
+++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeEqualMetaCompressor.java
@@ -52,6 +52,11 @@ public byte[] decompress(byte[] data, int offset, int size) {
return compressor.decompress(data, offset, size);
}
+ @Override
+ public byte[] decompress(byte[] data, int offset, int size, int maxOutputSize) {
+ return compressor.decompress(data, offset, size, maxOutputSize);
+ }
+
@Override
public boolean equals(Object obj) {
if (obj == null || obj.getClass() != getClass()) {
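Both decoders in this change share the same leniency rule via `isStructCompatibilityVariant`: the id-registered struct kinds (`STRUCT`/`COMPATIBLE_STRUCT`) may substitute for each other, as may the name-registered kinds, but an id-registered kind never matches a name-registered one. A standalone model of that predicate, using placeholder type-id values rather than Fory's real `Types` constants:

```java
// Model of isStructCompatibilityVariant from the diff. The type-id values
// here are placeholders; only the grouping logic matters.
public class StructVariantSketch {
    // Placeholder ids, not Fory's actual Types values.
    static final int STRUCT = 100;
    static final int COMPATIBLE_STRUCT = 101;
    static final int NAMED_STRUCT = 102;
    static final int NAMED_COMPATIBLE_STRUCT = 103;

    static boolean isStructCompatibilityVariant(int expected, int actual) {
        boolean expectedIdStruct = expected == STRUCT || expected == COMPATIBLE_STRUCT;
        boolean actualIdStruct = actual == STRUCT || actual == COMPATIBLE_STRUCT;
        if (expectedIdStruct || actualIdStruct) {
            // If either side is id-registered, both must be.
            return expectedIdStruct && actualIdStruct;
        }
        boolean expectedNamed =
            expected == NAMED_STRUCT || expected == NAMED_COMPATIBLE_STRUCT;
        boolean actualNamed =
            actual == NAMED_STRUCT || actual == NAMED_COMPATIBLE_STRUCT;
        return expectedNamed && actualNamed;
    }
}
```

This asymmetry is why the decoder passes `allowNamednessDifference` only for unregistered class layers, where the peer's registration mode cannot be trusted to match the local one.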
diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java
index 6705de531b..588756f51b 100644
--- a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java
+++ b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java
@@ -41,6 +41,7 @@
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collection;
+import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.EnumSet;
@@ -101,6 +102,7 @@
import org.apache.fory.meta.ClassSpec;
import org.apache.fory.meta.EncodedMetaString;
import org.apache.fory.meta.Encoders;
+import org.apache.fory.meta.NativeTypeDefEncoder;
import org.apache.fory.meta.TypeDef;
import org.apache.fory.reflect.ObjectCreators;
import org.apache.fory.reflect.ReflectionUtils;
@@ -537,7 +539,7 @@ public void register(Class<?> cls, String namespace, String name) {
buildUnregisteredTypeId(cls, existingInfo == null ? null : existingInfo.serializer);
TypeInfo typeInfo = new TypeInfo(cls, nsBytes, nameBytes, null, typeId, -1);
classInfoMap.put(cls, typeInfo);
- compositeNameBytes2TypeInfo.put(new TypeNameBytes(nsBytes.hash, nameBytes.hash), typeInfo);
+ compositeNameBytes2TypeInfo.put(new TypeNameBytes(nsBytes, nameBytes), typeInfo);
extRegistry.registeredClasses.put(fullname, cls);
registerGraalvmClass(cls);
}
@@ -583,7 +585,7 @@ public void registerUnion(Class<?> cls, String namespace, String name, Serialize
TypeInfo typeInfo = new TypeInfo(cls, nsBytes, nameBytes, serializer, typeId, -1);
typeInfo.setSerializer(this, serializer);
classInfoMap.put(cls, typeInfo);
- compositeNameBytes2TypeInfo.put(new TypeNameBytes(nsBytes.hash, nameBytes.hash), typeInfo);
+ compositeNameBytes2TypeInfo.put(new TypeNameBytes(nsBytes, nameBytes), typeInfo);
extRegistry.registeredClasses.put(fullname, cls);
registerGraalvmClass(cls);
}
@@ -864,17 +866,92 @@ public int getTypeIdForTypeDef(Class<?> cls) {
}
return typeInfo.typeId;
}
- int typeId = buildUnregisteredTypeId(cls, null);
+ int typeId = usesNonStructTypeDef(cls) ? Types.NAMED_EXT : buildUnregisteredTypeId(cls, null);
typeInfo = new TypeInfo(this, cls, null, typeId, INVALID_USER_TYPE_ID);
classInfoMap.put(cls, typeInfo);
if (typeInfo.namespace != null && typeInfo.typeName != null) {
- TypeNameBytes typeNameBytes =
- new TypeNameBytes(typeInfo.namespace.hash, typeInfo.typeName.hash);
+ TypeNameBytes typeNameBytes = new TypeNameBytes(typeInfo.namespace, typeInfo.typeName);
compositeNameBytes2TypeInfo.put(typeNameBytes, typeInfo);
}
return typeId;
}
+ public int getTypeDefRootTypeId(Class<?> cls, boolean hasFieldMetadata) {
+ if (hasFieldMetadata) {
+ // Preserve the normal TypeInfo/name cache so locally generated or dynamically registered
+ // classes can be resolved when the TypeDef is decoded by the same resolver.
+ getTypeIdForTypeDef(cls);
+ return getFieldMetadataTypeIdForTypeDef(cls);
+ }
+ TypeInfo typeInfo = classInfoMap.get(cls);
+ if (typeInfo != null) {
+ return normalizeTypeDefRootTypeId(cls, typeInfo.typeId);
+ }
+ Integer classId = extRegistry.registeredClassIdMap.get(cls);
+ if (classId != null) {
+ typeInfo = classInfoMap.get(cls);
+ if (typeInfo == null) {
+ typeInfo = getTypeInfo(cls);
+ }
+ return normalizeTypeDefRootTypeId(cls, typeInfo.typeId);
+ }
+ return usesNonStructTypeDef(cls) ? Types.NAMED_EXT : buildUnregisteredTypeId(cls, null);
+ }
+
+ private int getFieldMetadataTypeIdForTypeDef(Class<?> cls) {
+ Integer classId = extRegistry.registeredClassIdMap.get(cls);
+ if (classId != null && !isInternalRegisteredClassId(cls, classId)) {
+ return buildUserTypeId(cls, null);
+ }
+ return super.buildUnregisteredTypeId(cls, null);
+ }
+
+ private int normalizeTypeDefRootTypeId(Class<?> cls, int typeId) {
+ if (usesNonStructTypeDef(cls)) {
+ // Placeholder TypeInfo can be created before the natural serializer is installed.
+ // The TypeDef root kind must still select the non-struct serializer family.
+ return Types.isExtType(typeId) ? typeId : Types.NAMED_EXT;
+ }
+ if (isSupportedTypeDefTypeId(typeId)) {
+ return typeId;
+ }
+ return buildUnregisteredTypeId(cls, null);
+ }
+
+ private static boolean isSupportedTypeDefTypeId(int typeId) {
+ switch (typeId) {
+ case Types.ENUM:
+ case Types.NAMED_ENUM:
+ case Types.STRUCT:
+ case Types.COMPATIBLE_STRUCT:
+ case Types.NAMED_STRUCT:
+ case Types.NAMED_COMPATIBLE_STRUCT:
+ case Types.EXT:
+ case Types.NAMED_EXT:
+ case Types.TYPED_UNION:
+ case Types.NAMED_UNION:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ private boolean usesNonStructTypeDef(Class<?> cls) {
+ return !cls.isEnum()
+ && (cls.isArray()
+ || isCollection(cls)
+ || isMap(cls)
+ || Externalizable.class.isAssignableFrom(cls)
+ || requireJavaSerialization(cls)
+ || useReplaceResolveSerializer(cls)
+ || Functions.isLambda(cls)
+ || (config.isScalaOptimizationEnabled() && ReflectionUtils.isScalaSingletonObject(cls))
+ || Calendar.class.isAssignableFrom(cls)
+ || ZoneId.class.isAssignableFrom(cls)
+ || TimeZone.class.isAssignableFrom(cls)
+ || ByteBuffer.class.isAssignableFrom(cls));
+ }
+
/**
* Compute the user type id used in TypeDef without forcing serializer creation. Returns -1 when
* the class isn't registered by numeric id.
@@ -1079,8 +1156,7 @@ private void registerSerializerImpl(Class<?> type, Serializer<?> serializer) {
typeInfo = sharedTypeInfo;
updateTypeInfo(type, typeInfo);
if (typeInfo.namespace != null && typeInfo.typeName != null) {
- TypeNameBytes typeNameBytes =
- new TypeNameBytes(typeInfo.namespace.hash, typeInfo.typeName.hash);
+ TypeNameBytes typeNameBytes = new TypeNameBytes(typeInfo.namespace, typeInfo.typeName);
compositeNameBytes2TypeInfo.put(typeNameBytes, typeInfo);
}
if (typeInfoCache.type == type) {
@@ -1226,8 +1302,7 @@ public void addSerializer(Class<?> type, Serializer<?> serializer) {
// readTypeInfo can find the TypeInfo by name bytes during deserialization.
// This is important for dynamically created classes that can't be loaded by name.
if (typeInfo.namespace != null && typeInfo.typeName != null) {
- TypeNameBytes typeNameBytes =
- new TypeNameBytes(typeInfo.namespace.hash, typeInfo.typeName.hash);
+ TypeNameBytes typeNameBytes = new TypeNameBytes(typeInfo.namespace, typeInfo.typeName);
compositeNameBytes2TypeInfo.put(typeNameBytes, typeInfo);
}
}
@@ -1708,7 +1783,6 @@ private void registerGraalvmSerializerClass(Class<?> cls) {
getGraalvmClassRegistry()
.putDeserializerClass(
typeDef.getId(), getMetaSharedDeserializerClassForGraalvmBuild(cls, typeDef));
- extRegistry.typeInfoByTypeDefId.remove(typeDef.getId());
}
typeInfoCache = NIL_TYPE_INFO;
if (RecordUtils.isRecord(cls)) {
@@ -1809,15 +1883,15 @@ private TypeDef buildTypeDef(TypeInfo typeInfo, Class<? extends Serializer> seri
Preconditions.checkArgument(
serializerClass != UnknownClassSerializers.UnknownStructSerializer.class);
if (needToWriteTypeDef(serializerClass)) {
- typeDef =
- cacheTypeDef(
- typeDefMap.computeIfAbsent(typeInfo.type, cls -> TypeDef.buildTypeDef(this, cls)));
+ typeDef = typeDefMap.computeIfAbsent(typeInfo.type, cls -> TypeDef.buildTypeDef(this, cls));
} else {
// Some types use other serializers, such as MapSerializer.
typeDef =
- cacheTypeDef(
- typeDefMap.computeIfAbsent(
- typeInfo.type, cls -> TypeDef.buildTypeDef(this, cls, new ArrayList<>(), false)));
+ typeDefMap.computeIfAbsent(
+ typeInfo.type,
+ cls ->
+ NativeTypeDefEncoder.buildTypeDefWithFieldInfos(
+ this, cls, Collections.emptyList()));
}
typeInfo.typeDef = typeDef;
return typeDef;
@@ -1924,7 +1998,7 @@ private TypeInfo getTypeInfoByTypeIdForReadClassInternal(int typeId, int userTyp
@Override
protected TypeInfo loadBytesToTypeInfo(
EncodedMetaString packageBytes, EncodedMetaString simpleClassNameBytes) {
- TypeNameBytes typeNameBytes = new TypeNameBytes(packageBytes.hash, simpleClassNameBytes.hash);
+ TypeNameBytes typeNameBytes = new TypeNameBytes(packageBytes, simpleClassNameBytes);
TypeInfo typeInfo = compositeNameBytes2TypeInfo.get(typeNameBytes);
if (typeInfo == null) {
typeInfo = populateBytesToTypeInfo(typeNameBytes, packageBytes, simpleClassNameBytes);
@@ -1946,8 +2020,7 @@ protected TypeInfo ensureSerializerForTypeInfo(TypeInfo typeInfo) {
TypeInfo newTypeInfo = getTypeInfo(typeInfo.type);
// Update the cache with the correct TypeInfo that has a serializer
if (typeInfo.typeName != null) {
- TypeNameBytes typeNameBytes =
- new TypeNameBytes(typeInfo.namespace.hash, typeInfo.typeName.hash);
+ TypeNameBytes typeNameBytes = new TypeNameBytes(typeInfo.namespace, typeInfo.typeName);
compositeNameBytes2TypeInfo.put(typeNameBytes, newTypeInfo);
}
return newTypeInfo;
@@ -1985,8 +2058,7 @@ public Class<?> loadClassForMeta(String className, boolean isEnum, int arrayDims
String typeName = ReflectionUtils.getClassNameWithoutPackage(className);
EncodedMetaString pkgBytes = sharedRegistry.getPackageEncodedMetaString(pkg);
EncodedMetaString typeBytes = sharedRegistry.getTypeNameEncodedMetaString(typeName);
- TypeInfo cachedInfo =
- compositeNameBytes2TypeInfo.get(new TypeNameBytes(pkgBytes.hash, typeBytes.hash));
+ TypeInfo cachedInfo = compositeNameBytes2TypeInfo.get(new TypeNameBytes(pkgBytes, typeBytes));
if (cachedInfo != null) {
return cachedInfo.type;
}
diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/SharedRegistry.java b/java/fory-core/src/main/java/org/apache/fory/resolver/SharedRegistry.java
index 8e6b50484f..5dec2d5ff6 100644
--- a/java/fory-core/src/main/java/org/apache/fory/resolver/SharedRegistry.java
+++ b/java/fory-core/src/main/java/org/apache/fory/resolver/SharedRegistry.java
@@ -27,6 +27,7 @@
import java.util.Objects;
import java.util.SortedMap;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
import org.apache.fory.annotation.Internal;
import org.apache.fory.codegen.CodeGenerator;
import org.apache.fory.collection.BiMap;
@@ -54,6 +55,7 @@
public final class SharedRegistry {
private static final int MAX_CACHED_ENCODED_META_STRINGS = 32768;
private static final int MAX_CACHED_ENCODED_META_STRING_LENGTH = 2048;
+ private static final int MAX_CACHED_TYPE_DEFS = 8192;
final ConcurrentIdentityMap