Skip to content

Commit ffe092d

Browse files
committed
tls: Move buffers- and encryption-handling to cython
1 parent 78731b5 commit ffe092d

File tree

3 files changed

+60
-44
lines changed

3 files changed

+60
-44
lines changed

src/mbedtls/_tls.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,8 @@ cdef class _BaseContext:
456456
cdef class MbedTLSBuffer:
457457
cdef _BaseContext _context
458458
cdef mbedtls_ssl_context _ctx
459+
cdef _rb.RingBuffer _output_buffer
460+
cdef _rb.RingBuffer _input_buffer
459461
cdef _C_Buffers _c_buffers
460-
cpdef set_bio(self, _rb.RingBuffer input, _rb.RingBuffer output)
461462
# DTLS only:
462463
cdef mbedtls_timing_delay_context _timer

src/mbedtls/_tls.pyx

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,12 +1131,31 @@ cdef class _BaseContext:
11311131
raise NotImplementedError
11321132

11331133

1134+
TLS_BUFFER_CAPACITY = 2 << 14
1135+
# 32K (MBEDTLS_SSL_DTLS_MAX_BUFFERING)
1136+
1137+
11341138
cdef class MbedTLSBuffer:
1135-
def __init__(self, _BaseContext context):
1139+
def __init__(self, _BaseContext context, server_hostname=None):
11361140
self._context = context
11371141
_exc.check_error(_tls.mbedtls_ssl_setup(&self._ctx, &self._context._conf._ctx))
1142+
self._output_buffer = _rb.RingBuffer(TLS_BUFFER_CAPACITY)
1143+
self._input_buffer = _rb.RingBuffer(TLS_BUFFER_CAPACITY)
1144+
self._c_buffers = _tls._C_Buffers(
1145+
&self._output_buffer._ctx,
1146+
&self._input_buffer._ctx
1147+
)
1148+
self._reset()
1149+
_tls.mbedtls_ssl_set_bio(
1150+
&self._ctx,
1151+
&self._c_buffers,
1152+
buffer_write,
1153+
buffer_read,
1154+
NULL
1155+
)
1156+
self._set_hostname(server_hostname)
11381157

1139-
def __cinit__(self, _BaseContext context):
1158+
def __cinit__(self, _BaseContext context, server_hostname=None):
11401159
"""Initialize an `ssl_context`."""
11411160
_tls.mbedtls_ssl_init(&self._ctx)
11421161
_tls.mbedtls_ssl_set_timer_cb(
@@ -1149,17 +1168,8 @@ cdef class MbedTLSBuffer:
11491168
"""Free and clear the internal structures of ctx."""
11501169
_tls.mbedtls_ssl_free(&self._ctx)
11511170

1152-
cpdef set_bio(self, _rb.RingBuffer output, _rb.RingBuffer input):
1153-
self._reset()
1154-
self._c_buffers = _tls._C_Buffers(&output._ctx, &input._ctx)
1155-
_tls.mbedtls_ssl_set_bio(
1156-
&self._ctx,
1157-
&self._c_buffers,
1158-
buffer_write,
1159-
buffer_read,
1160-
NULL)
1161-
11621171
def __getstate__(self):
1172+
# We could make this pickable by copying the buffers.
11631173
raise TypeError(f"cannot pickle {self.__class__.__name__!r} object")
11641174

11651175
@property
@@ -1218,6 +1228,17 @@ cdef class MbedTLSBuffer:
12181228
def _close(self):
12191229
self.shutdown()
12201230

1231+
def read(self, amt):
1232+
# PEP 543
1233+
if amt <= 0:
1234+
return b""
1235+
buffer = bytearray(amt)
1236+
view = memoryview(buffer)
1237+
nread = 0
1238+
while nread != amt and not self._input_buffer.empty():
1239+
nread += self.readinto(view[nread:], amt - nread)
1240+
return bytes(buffer[:nread])
1241+
12211242
def readinto(self, unsigned char[:] buffer not None, size_t amt):
12221243
if buffer.size == 0:
12231244
return 0
@@ -1255,7 +1276,25 @@ cdef class MbedTLSBuffer:
12551276
else:
12561277
self._reset()
12571278
_exc.check_error(ret)
1258-
return written
1279+
assert written == len(buffer)
1280+
return len(self._output_buffer)
1281+
1282+
def receive_from_network(self, data):
1283+
# PEP 543
1284+
# Append data to input buffer.
1285+
self._input_buffer.write(data, len(data))
1286+
1287+
def peek_outgoing(self, amt):
1288+
# PEP 543
1289+
# Read from output buffer.
1290+
if amt == 0:
1291+
return b""
1292+
return self._output_buffer.peek(amt)
1293+
1294+
def consume_outgoing(self, amt):
1295+
"""Consume `amt` bytes from the output buffer."""
1296+
# PEP 543
1297+
self._output_buffer.consume(amt)
12591298

12601299
def getpeercert(self, binary_form=False):
12611300
"""Return the peer certificate, or None."""

src/mbedtls/tls.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@
77
import sys
88
from contextlib import suppress
99

10-
if sys.version_info < (3, 8):
11-
from typing_extensions import Final
12-
else:
13-
from typing import Final
14-
1510
import mbedtls._ringbuf as _rb
1611

1712
from ._tls import (
@@ -54,9 +49,6 @@
5449
"ciphers_available",
5550
)
5651

57-
TLS_BUFFER_CAPACITY: Final = 2 << 14
58-
# 32K (MBEDTLS_SSL_DTLS_MAX_BUFFERING)
59-
6052

6153
class TLSRecordHeader:
6254
"""Encode/decode TLS record protocol format."""
@@ -173,17 +165,12 @@ def wrap_buffers(self):
173165
class TLSWrappedBuffer:
174166
# _pep543.TLSWrappedBuffer
175167
def __init__(self, context, server_hostname=None):
176-
self._output_buffer = _rb.RingBuffer(TLS_BUFFER_CAPACITY)
177-
self._input_buffer = _rb.RingBuffer(TLS_BUFFER_CAPACITY)
178-
self._tlsbuf = MbedTLSBuffer(context)
179-
self._tlsbuf.set_bio(self._output_buffer, self._input_buffer)
180-
self._tlsbuf._set_hostname(server_hostname)
168+
self._tlsbuf = MbedTLSBuffer(context, server_hostname)
181169

182170
def __repr__(self):
183171
return "%s(%r)" % (type(self).__name__, self.context)
184172

185173
def __getstate__(self):
186-
# We could make this pickable by copying the buffers.
187174
raise TypeError(f"cannot pickle {self.__class__.__name__!r} object")
188175

189176
@property
@@ -196,24 +183,15 @@ def _handshake_state(self):
196183

197184
def read(self, amt):
198185
# PEP 543
199-
if amt <= 0:
200-
return b""
201-
buffer = bytearray(amt)
202-
view = memoryview(buffer)
203-
nread = 0
204-
while nread != amt and not self._input_buffer.empty():
205-
nread += self.readinto(view[nread:], amt - nread)
206-
return bytes(buffer[:nread])
186+
return self._tlsbuf.read(amt)
207187

208188
def readinto(self, buffer, amt):
209189
# PEP 543
210190
return self._tlsbuf.readinto(buffer, amt)
211191

212192
def write(self, buffer):
213193
# PEP 543
214-
amt = self._tlsbuf.write(buffer)
215-
assert amt == len(buffer)
216-
return len(self._output_buffer)
194+
return self._tlsbuf.write(buffer)
217195

218196
def do_handshake(self):
219197
# PEP 543
@@ -247,19 +225,17 @@ def shutdown(self):
247225
def receive_from_network(self, data):
248226
# PEP 543
249227
# Append data to input buffer.
250-
self._input_buffer.write(data, len(data))
228+
self._tlsbuf.receive_from_network(data)
251229

252230
def peek_outgoing(self, amt):
253231
# PEP 543
254232
# Read from output buffer.
255-
if amt == 0:
256-
return b""
257-
return self._output_buffer.peek(amt)
233+
return self._tlsbuf.peek_outgoing(amt)
258234

259235
def consume_outgoing(self, amt):
260236
"""Consume `amt` bytes from the output buffer."""
261237
# PEP 543
262-
self._output_buffer.consume(amt)
238+
self._tlsbuf.consume_outgoing(amt)
263239

264240

265241
class TLSWrappedSocket:

0 commit comments

Comments
 (0)