diff --git a/common/src/main/java/com/pedro/common/nal/NalReader.kt b/common/src/main/java/com/pedro/common/nal/NalReader.kt new file mode 100644 index 0000000000..b4cf5ba542 --- /dev/null +++ b/common/src/main/java/com/pedro/common/nal/NalReader.kt @@ -0,0 +1,43 @@ +package com.pedro.common.nal + +import java.nio.ByteBuffer + +object NalReader { + private const val ZERO: Byte = 0x00 + private const val ONE: Byte = 0x01 + + fun extractNals(buffer: ByteBuffer): List { + val units = ArrayList() + + val array = buffer.array() + val offset = buffer.arrayOffset() + val start = offset + buffer.position() + val limit = offset + buffer.limit() + var payloadStart = -1 + + var i = start + while (i < limit - 2) { + if (array[i] == ZERO && array[i + 1] == ZERO && array[i + 2] == ONE) { + val previousPayloadEnd = if (i > start && array[i - 1] == ZERO) i - 1 else i + if (payloadStart != -1 && previousPayloadEnd > payloadStart) { + val duplicate = buffer.duplicate() + duplicate.position(payloadStart - offset) + duplicate.limit(previousPayloadEnd - offset) + units.add(duplicate.slice()) + } + payloadStart = i + 3 + i += 3 + } else { + i++ + } + } + if (payloadStart != -1 && payloadStart < limit) { + val duplicate = buffer.duplicate() + duplicate.position(payloadStart - offset) + duplicate.limit(limit - offset) + units.add(duplicate.slice()) + } + if (units.isEmpty()) units.add(buffer) + return units + } +} \ No newline at end of file diff --git a/common/src/test/java/com/pedro/common/NalReaderTest.kt b/common/src/test/java/com/pedro/common/NalReaderTest.kt new file mode 100644 index 0000000000..fe5fa19fb4 --- /dev/null +++ b/common/src/test/java/com/pedro/common/NalReaderTest.kt @@ -0,0 +1,34 @@ +package com.pedro.common + +import com.pedro.common.nal.NalReader +import junit.framework.TestCase.assertEquals +import org.junit.Test +import java.nio.ByteBuffer + +class NalReaderTest { + + private val header = byteArrayOf( + 0x00, 0x00, 0x01 + ) + private val nal = header.plus(ByteArray(10_000) { 0x08 }) + + @Test + fun testReadMultipleNalsFromBuffer() { + val buffer = ByteBuffer.wrap(nal.plus(nal).plus(nal)) + val nals = NalReader.extractNals(buffer) + assertEquals(3, nals.size) + nals.forEach { + assertEquals(10_000, it.capacity()) + } + } + + @Test + fun testReadSingleNalFromBuffer() { + val buffer = ByteBuffer.wrap(nal) + val nals = NalReader.extractNals(buffer) + assertEquals(1, nals.size) + nals.forEach { + assertEquals(10_000, it.capacity()) + } + } +} \ No newline at end of file diff --git a/rtmp/src/main/java/com/pedro/rtmp/flv/video/packet/H264Packet.kt b/rtmp/src/main/java/com/pedro/rtmp/flv/video/packet/H264Packet.kt index 4d4e18c0cc..cffd86f7be 100644 --- a/rtmp/src/main/java/com/pedro/rtmp/flv/video/packet/H264Packet.kt +++ b/rtmp/src/main/java/com/pedro/rtmp/flv/video/packet/H264Packet.kt @@ -19,6 +19,7 @@ package com.pedro.rtmp.flv.video.packet import android.util.Log import com.pedro.common.frame.MediaFrame import com.pedro.common.getStartCodeSize +import com.pedro.common.nal.NalReader import com.pedro.common.removeInfo import com.pedro.rtmp.flv.BasePacket import com.pedro.rtmp.flv.FlvPacket @@ -102,11 +103,11 @@ class H264Packet: BasePacket() { val headerSize = getHeaderSize(fixedBuffer) if (headerSize == 0) return //invalid buffer or waiting for sps/pps fixedBuffer.rewind() - val validBuffer = removeHeader(fixedBuffer, headerSize) - val size = validBuffer.remaining() - buffer = ByteArray(header.size + size + naluSize) + val nals = NalReader.extractNals(fixedBuffer) + val size = nals.sumOf { it.capacity() } + buffer = ByteArray(header.size + size + naluSize * nals.size) - val type: Int = (validBuffer.get(0) and 0x1F).toInt() + val type: Int = (nals[0].get(0) and 0x1F).toInt() var nalType = VideoDataType.INTER_FRAME.value if (type == VideoNalType.IDR.value || mediaFrame.info.isKeyFrame) { nalType = VideoDataType.KEYFRAME.value @@ -116,9 +117,13 @@ class H264Packet: BasePacket() { } header[0] = ((nalType shl 4) or VideoFormat.AVC.value).toByte() header[1] = Type.NALU.value - writeNaluSize(buffer, header.size, size) - validBuffer.get(buffer, header.size + naluSize, size) - + var offset = header.size + nals.forEach { + val nalSize = it.capacity() + writeNaluSize(buffer, offset, nalSize) + it.get(buffer, offset + naluSize, nalSize) + offset += naluSize + nalSize + } System.arraycopy(header, 0, buffer, 0, header.size) callback(FlvPacket(buffer, ts, buffer.size, FlvType.VIDEO)) } diff --git a/rtmp/src/main/java/com/pedro/rtmp/flv/video/packet/H265Packet.kt b/rtmp/src/main/java/com/pedro/rtmp/flv/video/packet/H265Packet.kt index 2a9cd9febb..d39a7b1dcf 100644 --- a/rtmp/src/main/java/com/pedro/rtmp/flv/video/packet/H265Packet.kt +++ b/rtmp/src/main/java/com/pedro/rtmp/flv/video/packet/H265Packet.kt @@ -19,13 +19,14 @@ package com.pedro.rtmp.flv.video.packet import android.util.Log import com.pedro.common.frame.MediaFrame import com.pedro.common.getStartCodeSize +import com.pedro.common.nal.NalReader import com.pedro.common.removeInfo import com.pedro.rtmp.flv.BasePacket import com.pedro.rtmp.flv.FlvPacket import com.pedro.rtmp.flv.FlvType -import com.pedro.rtmp.flv.video.VideoFourCCPacketType import com.pedro.rtmp.flv.video.VideoDataType import com.pedro.rtmp.flv.video.VideoFormat +import com.pedro.rtmp.flv.video.VideoFourCCPacketType import com.pedro.rtmp.flv.video.VideoNalType import com.pedro.rtmp.flv.video.config.VideoSpecificConfigHEVC import java.nio.ByteBuffer @@ -109,11 +110,11 @@ class H265Packet: BasePacket() { val headerSize = getHeaderSize(fixedBuffer) if (headerSize == 0) return //invalid buffer or waiting for sps/pps fixedBuffer.rewind() - val validBuffer = removeHeader(fixedBuffer, headerSize) - val size = validBuffer.remaining() - buffer = ByteArray(header.size + size + naluSize) + val nals = NalReader.extractNals(fixedBuffer) + val size = nals.sumOf { it.capacity() } + buffer = ByteArray(header.size + size + naluSize * nals.size) - val type: Int = validBuffer.get(0).toInt().shr(1 and 0x3f) + val type: Int = nals[0].get(0).toInt().shr(1 and 0x3f) var nalType = VideoDataType.INTER_FRAME.value if (type == VideoNalType.IDR_N_LP.value || type == VideoNalType.IDR_W_DLP.value || mediaFrame.info.isKeyFrame) { nalType = VideoDataType.KEYFRAME.value @@ -122,9 +123,13 @@ class H265Packet: BasePacket() { return } header[0] = (0b10000000 or (nalType shl 4) or VideoFourCCPacketType.CODED_FRAMES.value).toByte() - writeNaluSize(buffer, header.size, size) - validBuffer.get(buffer, header.size + naluSize, size) - + var offset = header.size + nals.forEach { + val nalSize = it.capacity() + writeNaluSize(buffer, offset, nalSize) + it.get(buffer, offset + naluSize, nalSize) + offset += naluSize + nalSize + } System.arraycopy(header, 0, buffer, 0, header.size) callback(FlvPacket(buffer, ts, buffer.size, FlvType.VIDEO)) } diff --git a/rtmp/src/test/java/com/pedro/rtmp/amf/AmfObjectTest.kt b/rtmp/src/test/java/com/pedro/rtmp/amf/AmfObjectTest.kt index 08d1e06054..b88cd194ab 100644 --- a/rtmp/src/test/java/com/pedro/rtmp/amf/AmfObjectTest.kt +++ b/rtmp/src/test/java/com/pedro/rtmp/amf/AmfObjectTest.kt @@ -70,7 +70,6 @@ class AmfObjectTest { amfObject.writeHeader(output) amfObject.writeBody(output) -println(output.toByteArray().contentToString()) assertArrayEquals(expectedBuffer, output.toByteArray()) } } \ No newline at end of file diff --git a/rtsp/src/main/java/com/pedro/rtsp/rtp/packets/H264Packet.kt b/rtsp/src/main/java/com/pedro/rtsp/rtp/packets/H264Packet.kt index 7203382349..ac136ea6a7 100644 --- a/rtsp/src/main/java/com/pedro/rtsp/rtp/packets/H264Packet.kt +++ b/rtsp/src/main/java/com/pedro/rtsp/rtp/packets/H264Packet.kt @@ -18,6 +18,7 @@ package com.pedro.rtsp.rtp.packets import android.util.Log import com.pedro.common.frame.MediaFrame +import com.pedro.common.nal.NalReader import com.pedro.common.removeInfo import com.pedro.common.toByteArray import com.pedro.rtsp.rtsp.RtpFrame @@ -39,6 +40,7 @@ class H264Packet: BasePacket(RtpConstants.clockVideoFrequency, private var sendKeyFrame = false private var sps: ByteArray? = null private var pps: ByteArray? = null + private val header = ByteArray(2) init { channelIdentifier = RtpConstants.trackVideo @@ -55,13 +57,13 @@ class H264Packet: BasePacket(RtpConstants.clockVideoFrequency, val fixedBuffer = mediaFrame.data.removeInfo(mediaFrame.info) // We read a NAL units from ByteBuffer and we send them // NAL units are preceded with 0x00000001 - val header = ByteArray(getHeaderSize(fixedBuffer) + 1) - if (header.size == 1) return //invalid buffer or waiting for sps/pps + if (getHeaderSize(fixedBuffer) == 0) return //invalid buffer or waiting for sps/pps fixedBuffer.rewind() - fixedBuffer.get(header, 0, header.size) + val nals = NalReader.extractNals(fixedBuffer) val ts = mediaFrame.info.timestamp * 1000L - val naluLength = fixedBuffer.remaining() - val type: Int = (header[header.size - 1] and 0x1F).toInt() + val nalType = nals[0].get() + val naluLength = nals.sumOf { it.remaining() } + val type: Int = (nalType and 0x1F).toInt() val frames = mutableListOf() if (type == RtpConstants.IDR || mediaFrame.info.isKeyFrame) { stapA?.let { @@ -79,10 +81,10 @@ class H264Packet: BasePacket(RtpConstants.clockVideoFrequency, } if (sendKeyFrame) { // Small NAL unit => Single NAL unit - if (naluLength <= maxPacketSize - RtpConstants.RTP_HEADER_LENGTH - 1) { + if (nals.size == 1 && naluLength <= maxPacketSize - RtpConstants.RTP_HEADER_LENGTH - 1) { val buffer = getBuffer(naluLength + RtpConstants.RTP_HEADER_LENGTH + 1) - buffer[RtpConstants.RTP_HEADER_LENGTH] = header[header.size - 1] - fixedBuffer.get(buffer, RtpConstants.RTP_HEADER_LENGTH + 1, naluLength) + buffer[RtpConstants.RTP_HEADER_LENGTH] = nalType + nals[0].get(buffer, RtpConstants.RTP_HEADER_LENGTH + 1, naluLength) val rtpTs = updateTimeStamp(buffer, ts) markPacket(buffer) //mark end frame updateSeq(buffer) @@ -90,35 +92,37 @@ class H264Packet: BasePacket(RtpConstants.clockVideoFrequency, frames.add(rtpFrame) } else { // Set FU-A header - header[1] = header[header.size - 1] and 0x1F // FU header type + header[1] = nalType and 0x1F // FU header type header[1] = header[1].plus(0x80).toByte() // set start bit to 1 // Set FU-A indicator - header[0] = header[header.size - 1] and 0x60 and 0xFF.toByte() // FU indicator NRI + header[0] = nalType and 0x60 and 0xFF.toByte() // FU indicator NRI header[0] = header[0].plus(28).toByte() - var sum = 0 - while (sum < naluLength) { - val length = if (naluLength - sum > maxPacketSize - RtpConstants.RTP_HEADER_LENGTH - 2) { - maxPacketSize - RtpConstants.RTP_HEADER_LENGTH - 2 - } else { - fixedBuffer.remaining() + nals.forEachIndexed { index, data -> + var sum = 0 + val nalSize = data.remaining() + while (sum < nalSize) { + val length = if (nalSize - sum > maxPacketSize - RtpConstants.RTP_HEADER_LENGTH - 2) { + maxPacketSize - RtpConstants.RTP_HEADER_LENGTH - 2 + } else { + data.remaining() + } + val buffer = getBuffer(length + RtpConstants.RTP_HEADER_LENGTH + 2) + buffer[RtpConstants.RTP_HEADER_LENGTH] = header[0] + // Switch start bit + buffer[RtpConstants.RTP_HEADER_LENGTH + 1] = if (sum > 0) header[1] and 0x7F else header[1] + val rtpTs = updateTimeStamp(buffer, ts) + data.get(buffer, RtpConstants.RTP_HEADER_LENGTH + 2, length) + sum += length + // Last packet before next NAL + if (sum >= nalSize) { + // End bit on + buffer[RtpConstants.RTP_HEADER_LENGTH + 1] = buffer[RtpConstants.RTP_HEADER_LENGTH + 1].plus(0x40).toByte() + if (index == nals.size - 1) markPacket(buffer) //mark end frame + } + updateSeq(buffer) + val rtpFrame = RtpFrame(buffer, rtpTs, buffer.size, channelIdentifier) + frames.add(rtpFrame) } - val buffer = getBuffer(length + RtpConstants.RTP_HEADER_LENGTH + 2) - buffer[RtpConstants.RTP_HEADER_LENGTH] = header[0] - buffer[RtpConstants.RTP_HEADER_LENGTH + 1] = header[1] - val rtpTs = updateTimeStamp(buffer, ts) - fixedBuffer.get(buffer, RtpConstants.RTP_HEADER_LENGTH + 2, length) - sum += length - // Last packet before next NAL - if (sum >= naluLength) { - // End bit on - buffer[RtpConstants.RTP_HEADER_LENGTH + 1] = buffer[RtpConstants.RTP_HEADER_LENGTH + 1].plus(0x40).toByte() - markPacket(buffer) //mark end frame - } - updateSeq(buffer) - val rtpFrame = RtpFrame(buffer, rtpTs, buffer.size, channelIdentifier) - frames.add(rtpFrame) - // Switch start bit - header[1] = header[1] and 0x7F } } } else { diff --git a/rtsp/src/main/java/com/pedro/rtsp/rtp/packets/H265Packet.kt b/rtsp/src/main/java/com/pedro/rtsp/rtp/packets/H265Packet.kt index 2e2894dab3..303073d520 100644 --- a/rtsp/src/main/java/com/pedro/rtsp/rtp/packets/H265Packet.kt +++ b/rtsp/src/main/java/com/pedro/rtsp/rtp/packets/H265Packet.kt @@ -17,6 +17,7 @@ package com.pedro.rtsp.rtp.packets import com.pedro.common.frame.MediaFrame +import com.pedro.common.nal.NalReader import com.pedro.common.removeInfo import com.pedro.rtsp.rtsp.RtpFrame import com.pedro.rtsp.utils.RtpConstants @@ -33,6 +34,8 @@ class H265Packet: BasePacket( RtpConstants.payloadType + RtpConstants.trackVideo ) { + private val header = ByteArray(3) + init { channelIdentifier = RtpConstants.trackVideo } @@ -44,20 +47,21 @@ class H265Packet: BasePacket( val fixedBuffer = mediaFrame.data.removeInfo(mediaFrame.info) // We read a NAL units from ByteBuffer and we send them // NAL units are preceded with 0x00000001 - val header = ByteArray(fixedBuffer.getVideoStartCodeSize() + 2) - if (header.size == 2) return //invalid buffer or waiting for sps/pps/vps - fixedBuffer.get(header, 0, header.size) + if (fixedBuffer.getVideoStartCodeSize() == 0) return //invalid buffer or waiting for sps/pps/vps + val nals = NalReader.extractNals(fixedBuffer) val ts = mediaFrame.info.timestamp * 1000L - val naluLength = fixedBuffer.remaining() - val type: Int = header[header.size - 2].toInt().shr(1 and 0x3f) + val nalType = nals[0].get() + val nalType2 = nals[0].get() + val naluLength = nals.sumOf { it.remaining() } + val type: Int = nalType.toInt().shr(1 and 0x3f) val frames = mutableListOf() // Small NAL unit => Single NAL unit - if (naluLength <= maxPacketSize - RtpConstants.RTP_HEADER_LENGTH - 2) { + if (nals.size == 1 && naluLength <= maxPacketSize - RtpConstants.RTP_HEADER_LENGTH - 2) { val buffer = getBuffer(naluLength + RtpConstants.RTP_HEADER_LENGTH + 2) //Set PayloadHdr (exact copy of nal unit header) - buffer[RtpConstants.RTP_HEADER_LENGTH] = header[header.size - 2] - buffer[RtpConstants.RTP_HEADER_LENGTH + 1] = header[header.size - 1] - fixedBuffer.get(buffer, RtpConstants.RTP_HEADER_LENGTH + 2, naluLength) + buffer[RtpConstants.RTP_HEADER_LENGTH] = nalType + buffer[RtpConstants.RTP_HEADER_LENGTH + 1] = nalType2 + nals[0].get(buffer, RtpConstants.RTP_HEADER_LENGTH + 2, naluLength) val rtpTs = updateTimeStamp(buffer, ts) markPacket(buffer) //mark end frame updateSeq(buffer) @@ -75,31 +79,33 @@ class H265Packet: BasePacket( // +---------------+ header[2] = type.toByte() // FU header type header[2] = header[2].plus(0x80).toByte() // Start bit - var sum = 0 - while (sum < naluLength) { - val length = if (naluLength - sum > maxPacketSize - RtpConstants.RTP_HEADER_LENGTH - 3) { - maxPacketSize - RtpConstants.RTP_HEADER_LENGTH - 3 - } else { - fixedBuffer.remaining() - } - val buffer = getBuffer(length + RtpConstants.RTP_HEADER_LENGTH + 3) - buffer[RtpConstants.RTP_HEADER_LENGTH] = header[0] - buffer[RtpConstants.RTP_HEADER_LENGTH + 1] = header[1] - buffer[RtpConstants.RTP_HEADER_LENGTH + 2] = header[2] - val rtpTs = updateTimeStamp(buffer, ts) - fixedBuffer.get(buffer, RtpConstants.RTP_HEADER_LENGTH + 3, length) - sum += length - // Last packet before next NAL - if (sum >= naluLength) { - // End bit on - buffer[RtpConstants.RTP_HEADER_LENGTH + 2] = buffer[RtpConstants.RTP_HEADER_LENGTH + 2].plus(0x40).toByte() - markPacket(buffer) //mark end frame + nals.forEachIndexed { index, data -> + var sum = 0 + val nalSize = data.remaining() + while (sum < nalSize) { + val length = if (nalSize - sum > maxPacketSize - RtpConstants.RTP_HEADER_LENGTH - 3) { + maxPacketSize - RtpConstants.RTP_HEADER_LENGTH - 3 + } else { + data.remaining() + } + val buffer = getBuffer(length + RtpConstants.RTP_HEADER_LENGTH + 3) + buffer[RtpConstants.RTP_HEADER_LENGTH] = header[0] + buffer[RtpConstants.RTP_HEADER_LENGTH + 1] = header[1] + // Switch start bit + buffer[RtpConstants.RTP_HEADER_LENGTH + 2] = if (sum > 0) header[2] and 0x7F else header[2] + val rtpTs = updateTimeStamp(buffer, ts) + data.get(buffer, RtpConstants.RTP_HEADER_LENGTH + 3, length) + sum += length + // Last packet before next NAL + if (sum >= nalSize) { + // End bit on + buffer[RtpConstants.RTP_HEADER_LENGTH + 2] = buffer[RtpConstants.RTP_HEADER_LENGTH + 2].plus(0x40).toByte() + if (index == nals.size - 1) markPacket(buffer) //mark end frame + } + updateSeq(buffer) + val rtpFrame = RtpFrame(buffer, rtpTs, buffer.size, channelIdentifier) + frames.add(rtpFrame) } - updateSeq(buffer) - val rtpFrame = RtpFrame(buffer, rtpTs, buffer.size, channelIdentifier) - frames.add(rtpFrame) - // Switch start bit - header[2] = header[2] and 0x7F } } if (frames.isNotEmpty()) callback(frames)