diff --git a/swift/Arrow/Sources/Arrow/ArrowReader.swift b/swift/Arrow/Sources/Arrow/ArrowReader.swift index ae187e22eef..8515a782afa 100644 --- a/swift/Arrow/Sources/Arrow/ArrowReader.swift +++ b/swift/Arrow/Sources/Arrow/ArrowReader.swift @@ -19,7 +19,7 @@ import FlatBuffers import Foundation let FILEMARKER = "ARROW1" -let CONTINUATIONMARKER = -1 +let CONTINUATIONMARKER = UInt32(0xFFFFFFFF) public class ArrowReader { // swiftlint:disable:this type_body_length private class RecordBatchData { @@ -216,7 +216,77 @@ public class ArrowReader { // swiftlint:disable:this type_body_length return .success(RecordBatch(arrowSchema, columns: columns)) } - public func fromStream( // swiftlint:disable:this function_body_length + /* + This is for reading the Arrow streaming format. The Arrow streaming format + is slightly different from the Arrow File format as it doesn't contain a header + and footer. + */ + public func readStreaming( // swiftlint:disable:this function_body_length + _ input: Data, + useUnalignedBuffers: Bool = false + ) -> Result<ArrowReaderResult, ArrowError> { + let result = ArrowReaderResult() + var offset: Int = 0 + var length = getUInt32(input, offset: offset) + var streamData = input + var schemaMessage: org_apache_arrow_flatbuf_Schema? + while length != 0 { + if length == CONTINUATIONMARKER { + offset += Int(MemoryLayout<UInt32>.size) + length = getUInt32(input, offset: offset) + if length == 0 { + return .success(result) + } + } + + offset += Int(MemoryLayout<UInt32>.size) + streamData = input[offset...] + let dataBuffer = ByteBuffer( + data: streamData, + allowReadingUnalignedBuffers: true) + let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: dataBuffer) + switch message.headerType { + case .recordbatch: + do { + let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)! 
+ let recordBatch = try loadRecordBatch( + rbMessage, + schema: schemaMessage!, + arrowSchema: result.schema!, + data: input, + messageEndOffset: (Int64(offset) + Int64(length))).get() + result.batches.append(recordBatch) + offset += Int(message.bodyLength + Int64(length)) + length = getUInt32(input, offset: offset) + } catch let error as ArrowError { + return .failure(error) + } catch { + return .failure(.unknownError("Unexpected error: \(error)")) + } + case .schema: + schemaMessage = message.header(type: org_apache_arrow_flatbuf_Schema.self)! + let schemaResult = loadSchema(schemaMessage!) + switch schemaResult { + case .success(let schema): + result.schema = schema + case .failure(let error): + return .failure(error) + } + offset += Int(message.bodyLength + Int64(length)) + length = getUInt32(input, offset: offset) + default: + return .failure(.unknownError("Unhandled header type: \(message.headerType)")) + } + } + return .success(result) + } + + /* + This is for reading the Arrow file format. The Arrow file format supports + random accessing the data. The Arrow file format contains a header and + footer around the Arrow streaming format. + */ + public func readFile( // swiftlint:disable:this function_body_length _ fileData: Data, useUnalignedBuffers: Bool = false ) -> Result<ArrowReaderResult, ArrowError> { @@ -242,7 +312,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length for index in 0 ..< footer.recordBatchesCount { let recordBatch = footer.recordBatches(at: index)! 
var messageLength = fileData.withUnsafeBytes { rawBuffer in - rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: Int32.self) + rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: UInt32.self) } var messageOffset: Int64 = 1 @@ -251,7 +321,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length messageLength = fileData.withUnsafeBytes { rawBuffer in rawBuffer.loadUnaligned( fromByteOffset: Int(recordBatch.offset + Int64(MemoryLayout<Int32>.size)), - as: Int32.self) + as: UInt32.self) } } @@ -296,7 +366,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length let markerLength = FILEMARKER.utf8.count let footerLengthEnd = Int(fileData.count - markerLength) let data = fileData[..<(footerLengthEnd)] - return fromStream(data) + return readFile(data) } catch { return .failure(.unknownError("Error loading file: \(error)")) } @@ -340,10 +410,10 @@ public class ArrowReader { // swiftlint:disable:this type_body_length } catch { return .failure(.unknownError("Unexpected error: \(error)")) } - default: return .failure(.unknownError("Unhandled header type: \(message.headerType)")) } } } +// swiftlint:disable:this file_length diff --git a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift index 48c6fd85507..18cf41ad25a 100644 --- a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift +++ b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift @@ -289,3 +289,10 @@ func validateFileData(_ data: Data) -> Bool { let endString = String(decoding: data[(data.count - markerLength)...], as: UTF8.self) return startString == FILEMARKER && endString == FILEMARKER } + +func getUInt32(_ data: Data, offset: Int) -> UInt32 { + let token = data.withUnsafeBytes { rawBuffer in + rawBuffer.loadUnaligned(fromByteOffset: offset, as: UInt32.self) + } + return token +} diff --git a/swift/Arrow/Sources/Arrow/ArrowWriter.swift b/swift/Arrow/Sources/Arrow/ArrowWriter.swift index 
eec4c858276..54581ba396f 100644 --- a/swift/Arrow/Sources/Arrow/ArrowWriter.swift +++ b/swift/Arrow/Sources/Arrow/ArrowWriter.swift @@ -123,6 +123,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length let startIndex = writer.count switch writeRecordBatch(batch: batch) { case .success(let rbResult): + withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))} withUnsafeBytes(of: rbResult.1.o.littleEndian) {writer.append(Data($0))} writer.append(rbResult.0) switch writeRecordBatchData(&writer, batch: batch) { @@ -232,7 +233,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length return .success(fbb.data) } - private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> { + private func writeFile(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> { var fbb: FlatBufferBuilder = FlatBufferBuilder() switch writeSchema(&fbb, schema: info.schema) { case .success(let schemaOffset): @@ -264,9 +265,41 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length return .success(true) } - public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> { + public func writeStreaming(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> { + let writer: any DataWriter = InMemDataWriter() + switch toMessage(info.schema) { + case .success(let schemaData): + withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))} + withUnsafeBytes(of: UInt32(schemaData.count).littleEndian) {writer.append(Data($0))} + writer.append(schemaData) + case .failure(let error): + return .failure(error) + } + + for batch in info.batches { + switch toMessage(batch) { + case .success(let batchData): + withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))} + withUnsafeBytes(of: UInt32(batchData[0].count).littleEndian) {writer.append(Data($0))} + writer.append(batchData[0]) + writer.append(batchData[1]) + case .failure(let error): + return .failure(error) + } + } + + withUnsafeBytes(of: 
CONTINUATIONMARKER.littleEndian) {writer.append(Data($0))} + withUnsafeBytes(of: UInt32(0).littleEndian) {writer.append(Data($0))} + if let memWriter = writer as? InMemDataWriter { + return .success(memWriter.data) + } else { + return .failure(.invalid("Unable to cast writer")) + } + } + + public func writeFile(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> { var writer: any DataWriter = InMemDataWriter() - switch writeStream(&writer, info: info) { + switch writeFile(&writer, info: info) { case .success: if let memWriter = writer as? InMemDataWriter { return .success(memWriter.data) @@ -293,7 +326,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length var writer: any DataWriter = FileDataWriter(fileHandle) writer.append(FILEMARKER.data(using: .utf8)!) - switch writeStream(&writer, info: info) { + switch writeFile(&writer, info: info) { case .success: writer.append(FILEMARKER.data(using: .utf8)!) case .failure(let error): diff --git a/swift/Arrow/Tests/ArrowTests/IPCTests.swift b/swift/Arrow/Tests/ArrowTests/IPCTests.swift index 4f56f5fdabb..703490d2b24 100644 --- a/swift/Arrow/Tests/ArrowTests/IPCTests.swift +++ b/swift/Arrow/Tests/ArrowTests/IPCTests.swift @@ -118,6 +118,60 @@ func makeRecordBatch() throws -> RecordBatch { } } +final class IPCStreamReaderTests: XCTestCase { + func testRBInMemoryToFromStream() throws { + let schema = makeSchema() + let recordBatch = try makeRecordBatch() + let arrowWriter = ArrowWriter() + let writerInfo = ArrowWriter.Info(.recordbatch, schema: schema, batches: [recordBatch]) + switch arrowWriter.writeStreaming(writerInfo) { + case .success(let writeData): + let arrowReader = ArrowReader() + switch arrowReader.readStreaming(writeData) { + case .success(let result): + let recordBatches = result.batches + XCTAssertEqual(recordBatches.count, 1) + for recordBatch in recordBatches { + XCTAssertEqual(recordBatch.length, 4) + XCTAssertEqual(recordBatch.columns.count, 5) + XCTAssertEqual(recordBatch.schema.fields.count, 5) + 
XCTAssertEqual(recordBatch.schema.fields[0].name, "col1") + XCTAssertEqual(recordBatch.schema.fields[0].type.info, ArrowType.ArrowUInt8) + XCTAssertEqual(recordBatch.schema.fields[1].name, "col2") + XCTAssertEqual(recordBatch.schema.fields[1].type.info, ArrowType.ArrowString) + XCTAssertEqual(recordBatch.schema.fields[2].name, "col3") + XCTAssertEqual(recordBatch.schema.fields[2].type.info, ArrowType.ArrowDate32) + XCTAssertEqual(recordBatch.schema.fields[3].name, "col4") + XCTAssertEqual(recordBatch.schema.fields[3].type.info, ArrowType.ArrowInt32) + XCTAssertEqual(recordBatch.schema.fields[4].name, "col5") + XCTAssertEqual(recordBatch.schema.fields[4].type.info, ArrowType.ArrowFloat) + let columns = recordBatch.columns + XCTAssertEqual(columns[0].nullCount, 2) + let dateVal = + "\((columns[2].array as! AsString).asString(0))" // swiftlint:disable:this force_cast + XCTAssertEqual(dateVal, "2014-09-10 00:00:00 +0000") + let stringVal = + "\((columns[1].array as! AsString).asString(1))" // swiftlint:disable:this force_cast + XCTAssertEqual(stringVal, "test22") + let uintVal = + "\((columns[0].array as! AsString).asString(0))" // swiftlint:disable:this force_cast + XCTAssertEqual(uintVal, "10") + let stringVal2 = + "\((columns[1].array as! AsString).asString(3))" // swiftlint:disable:this force_cast + XCTAssertEqual(stringVal2, "test44") + let uintVal2 = + "\((columns[0].array as! 
AsString).asString(3))" // swiftlint:disable:this force_cast + XCTAssertEqual(uintVal2, "44") + } + case.failure(let error): + throw error + } + case .failure(let error): + throw error + } + } +} + final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body_length func testFileReader_struct() throws { let fileURL = currentDirectory().appendingPathComponent("../../testdata_struct.arrow") @@ -204,10 +258,10 @@ final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body let arrowWriter = ArrowWriter() // write data from file to a stream let writerInfo = ArrowWriter.Info(.recordbatch, schema: fileRBs[0].schema, batches: fileRBs) - switch arrowWriter.toStream(writerInfo) { + switch arrowWriter.writeFile(writerInfo) { case .success(let writeData): // read stream back into recordbatches - try checkBoolRecordBatch(arrowReader.fromStream(writeData)) + try checkBoolRecordBatch(arrowReader.readFile(writeData)) case .failure(let error): throw error } @@ -227,10 +281,10 @@ final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body let recordBatch = try makeRecordBatch() let arrowWriter = ArrowWriter() let writerInfo = ArrowWriter.Info(.recordbatch, schema: schema, batches: [recordBatch]) - switch arrowWriter.toStream(writerInfo) { + switch arrowWriter.writeFile(writerInfo) { case .success(let writeData): let arrowReader = ArrowReader() - switch arrowReader.fromStream(writeData) { + switch arrowReader.readFile(writeData) { case .success(let result): let recordBatches = result.batches XCTAssertEqual(recordBatches.count, 1) @@ -279,10 +333,10 @@ final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body let schema = makeSchema() let arrowWriter = ArrowWriter() let writerInfo = ArrowWriter.Info(.schema, schema: schema) - switch arrowWriter.toStream(writerInfo) { + switch arrowWriter.writeFile(writerInfo) { case .success(let writeData): let arrowReader = ArrowReader() - switch 
arrowReader.fromStream(writeData) { + switch arrowReader.readFile(writeData) { case .success(let result): XCTAssertNotNil(result.schema) let schema = result.schema! @@ -362,10 +416,10 @@ final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body let dataset = try makeBinaryDataset() let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, batches: [dataset.1]) let arrowWriter = ArrowWriter() - switch arrowWriter.toStream(writerInfo) { + switch arrowWriter.writeFile(writerInfo) { case .success(let writeData): let arrowReader = ArrowReader() - switch arrowReader.fromStream(writeData) { + switch arrowReader.readFile(writeData) { case .success(let result): XCTAssertNotNil(result.schema) let schema = result.schema! @@ -391,10 +445,10 @@ final class IPCFileReaderTests: XCTestCase { // swiftlint:disable:this type_body let dataset = try makeTimeDataset() let writerInfo = ArrowWriter.Info(.recordbatch, schema: dataset.0, batches: [dataset.1]) let arrowWriter = ArrowWriter() - switch arrowWriter.toStream(writerInfo) { + switch arrowWriter.writeFile(writerInfo) { case .success(let writeData): let arrowReader = ArrowReader() - switch arrowReader.fromStream(writeData) { + switch arrowReader.readFile(writeData) { case .success(let result): XCTAssertNotNil(result.schema) let schema = result.schema!