6 changes: 6 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -2007,6 +2007,8 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
                                               CMemoryPool* pool)
 
     CStatus GetRecordBatchSize(const CRecordBatch& batch, int64_t* size)
+    CStatus GetRecordBatchSize(const CRecordBatch& batch,
+                               const CIpcWriteOptions& options, int64_t* size)
     CStatus GetTensorSize(const CTensor& tensor, int64_t* size)
 
     CStatus WriteTensor(const CTensor& tensor, COutputStream* dst,
@@ -2026,6 +2028,10 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
     CResult[shared_ptr[CBuffer]] SerializeRecordBatch(
         const CRecordBatch& schema, const CIpcWriteOptions& options)
 
+    CStatus SerializeRecordBatch(const CRecordBatch& batch,
+                                 const CIpcWriteOptions& options,
+                                 COutputStream* out)
+
     CResult[shared_ptr[CSchema]] ReadSchema(const CMessage& message,
                                             CDictionaryMemo* dictionary_memo)
 
44 changes: 38 additions & 6 deletions python/pyarrow/table.pxi
@@ -3082,7 +3082,7 @@ cdef class RecordBatch(_Tabular):
 
         return pyarrow_wrap_batch(c_batch)
 
-    def serialize(self, memory_pool=None):
+    def serialize(self, memory_pool=None, Buffer buffer=None):
         """
         Write RecordBatch to Buffer as encapsulated IPC message, which does not
         include a Schema.
@@ -3095,6 +3095,13 @@
         ----------
         memory_pool : MemoryPool, default None
             Uses default memory pool if not specified
+        buffer : Buffer, default None
+            If provided, serialize into this pre-allocated buffer instead of
+            allocating a new one. The buffer must be mutable and large enough
+            to hold the serialized data. Use
+            :func:`pyarrow.ipc.get_record_batch_size` to determine the
+            required size. A slice of the buffer with the exact serialized
+            size is returned.
 
         Returns
         -------
@@ -3122,14 +3129,39 @@ cdef class RecordBatch(_Tabular):
         animals: ["Flamingo","Parrot","Dog","Horse","Brittle stars","Centipede"]
         """
         self._assert_cpu()
-        cdef shared_ptr[CBuffer] buffer
+        cdef shared_ptr[CBuffer] c_buffer
         cdef CIpcWriteOptions options = CIpcWriteOptions.Defaults()
+        cdef int64_t size
+        cdef CFixedSizeBufferWriter* stream
         options.memory_pool = maybe_unbox_memory_pool(memory_pool)
 
-        with nogil:
-            buffer = GetResultValue(
-                SerializeRecordBatch(deref(self.batch), options))
-        return pyarrow_wrap_buffer(buffer)
+        if buffer is not None:
+            if not buffer.is_mutable:
+                raise ValueError("buffer is not mutable")
+
+            with nogil:
+                check_status(GetRecordBatchSize(
+                    deref(self.batch), options, &size))
+
+            if buffer.size < size:
+                raise ValueError(
+                    f"buffer is too small: {buffer.size} < {size}")
+
+            stream = new CFixedSizeBufferWriter(buffer.buffer)
+            try:
+                with nogil:
+                    check_status(SerializeRecordBatch(
+                        deref(self.batch), options,
+                        <COutputStream*>stream))
+            finally:
+                del stream
+
+            return buffer.slice(0, size)
+        else:
+            with nogil:
+                c_buffer = GetResultValue(
+                    SerializeRecordBatch(deref(self.batch), options))
+            return pyarrow_wrap_buffer(c_buffer)
 
     def slice(self, offset=0, length=None):
         """
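Taken together, the table.pxi changes turn serialization into a two-step, caller-allocated protocol: query the exact encapsulated-message size, allocate once, then write in place. A minimal usage sketch of the round trip, mirroring the tests below (all names are from pyarrow or this PR):

import pyarrow as pa

batch = pa.RecordBatch.from_pydict({'ints': [1, 2, 3]})

# Ask the IPC layer for the exact serialized size, then allocate a
# mutable buffer of at least that many bytes up front.
size = pa.ipc.get_record_batch_size(batch)
buf = pa.allocate_buffer(size)

# serialize() writes into `buf` and returns a slice of it trimmed to
# the exact serialized length.
msg = batch.serialize(buffer=buf)
assert msg.size == size

# The returned slice round-trips like any other serialized batch.
recons = pa.ipc.read_record_batch(msg, batch.schema)
assert recons.equals(batch)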
41 changes: 41 additions & 0 deletions python/pyarrow/tests/test_ipc.py
@@ -1121,6 +1121,47 @@ def test_schema_batch_serialize_methods():
     assert recons_batch.equals(batch)
 
 
+def test_serialize_record_batch_to_buffer():
+    batch = pa.RecordBatch.from_pydict({
+        'ints': [1, 2, 3],
+        'strs': ['a', 'b', 'c'],
+    })
+    schema = batch.schema
+
+    # Round-trip with externally allocated buffer
+    size = pa.ipc.get_record_batch_size(batch)
+    buf = pa.allocate_buffer(size * 2)
+    result = batch.serialize(buffer=buf)
+    assert result.size == size
+    recons = pa.ipc.read_record_batch(result, schema)
+    assert recons.equals(batch)
+
+    # Round-trip with oversized buffer
+    big_buf = pa.allocate_buffer(size * 10)
+    result = batch.serialize(buffer=big_buf)
+    assert result.size == size
+    recons = pa.ipc.read_record_batch(result, schema)
+    assert recons.equals(batch)
+
+    # Exact size buffer
+    exact_buf = pa.allocate_buffer(size)
+    result = batch.serialize(buffer=exact_buf)
+    assert result.size == size
+    recons = pa.ipc.read_record_batch(result, schema)
+    assert recons.equals(batch)
+
+    # Buffer too small
+    small_buf = pa.allocate_buffer(8)
+    with pytest.raises(ValueError, match="buffer is too small"):
+        batch.serialize(buffer=small_buf)
+
+    # Immutable buffer
+    immutable_buf = pa.py_buffer(b'\x00' * size)
+    assert not immutable_buf.is_mutable
+    with pytest.raises(ValueError, match="buffer is not mutable"):
+        batch.serialize(buffer=immutable_buf)
+
+
 def test_schema_serialization_with_metadata():
     field_metadata = {b'foo': b'bar', b'kind': b'field'}
     schema_metadata = {b'foo': b'bar', b'kind': b'schema'}
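One pattern the new keyword makes possible, though not exercised directly by the tests above, so treat it as an illustrative sketch rather than a documented use: reuse one scratch buffer across many batches, so steady-state serialization performs no new output allocations. Each returned slice aliases the scratch buffer and must be consumed before the next call overwrites the memory:

import pyarrow as pa

batches = [
    pa.RecordBatch.from_pydict({'x': [i, i + 1, i + 2]})
    for i in range(3)
]

# Size the scratch buffer once for the largest batch; every call
# below then writes into the same memory.
scratch = pa.allocate_buffer(
    max(pa.ipc.get_record_batch_size(b) for b in batches))

for b in batches:
    # `msg` is a zero-copy slice of `scratch`; ship or copy it before
    # the next iteration reuses the underlying memory.
    msg = b.serialize(buffer=scratch)
    assert pa.ipc.read_record_batch(msg, b.schema).equals(b)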