@@ -1131,12 +1131,31 @@ cdef class _BaseContext:
11311131 raise NotImplementedError
11321132
11331133
1134+ TLS_BUFFER_CAPACITY = 2 << 14
1135+ # 32K (MBEDTLS_SSL_DTLS_MAX_BUFFERING )
1136+
1137+
11341138cdef class MbedTLSBuffer:
1135- def __init__(self , _BaseContext context ):
1139+ def __init__(self , _BaseContext context , server_hostname = None ):
11361140 self ._context = context
11371141 _exc.check_error(_tls.mbedtls_ssl_setup(& self ._ctx, & self ._context._conf._ctx))
1142+ self ._output_buffer = _rb.RingBuffer(TLS_BUFFER_CAPACITY)
1143+ self ._input_buffer = _rb.RingBuffer(TLS_BUFFER_CAPACITY)
1144+ self ._c_buffers = _tls._C_Buffers(
1145+ & self ._output_buffer._ctx,
1146+ & self ._input_buffer._ctx
1147+ )
1148+ self ._reset()
1149+ _tls.mbedtls_ssl_set_bio(
1150+ & self ._ctx,
1151+ & self ._c_buffers,
1152+ buffer_write,
1153+ buffer_read,
1154+ NULL
1155+ )
1156+ self ._set_hostname(server_hostname)
11381157
1139- def __cinit__ (self , _BaseContext context ):
1158+ def __cinit__ (self , _BaseContext context , server_hostname = None ):
11401159 """ Initialize an `ssl_context`."""
11411160 _tls.mbedtls_ssl_init(& self ._ctx)
11421161 _tls.mbedtls_ssl_set_timer_cb(
@@ -1149,17 +1168,8 @@ cdef class MbedTLSBuffer:
11491168 """ Free and clear the internal structures of ctx."""
11501169 _tls.mbedtls_ssl_free(& self ._ctx)
11511170
1152- cpdef set_bio(self , _rb.RingBuffer output, _rb.RingBuffer input ):
1153- self ._reset()
1154- self ._c_buffers = _tls._C_Buffers(& output._ctx, & input ._ctx)
1155- _tls.mbedtls_ssl_set_bio(
1156- & self ._ctx,
1157- & self ._c_buffers,
1158- buffer_write,
1159- buffer_read,
1160- NULL )
1161-
11621171 def __getstate__ (self ):
1172+ # We could make this pickable by copying the buffers.
11631173 raise TypeError (f" cannot pickle {self.__class__.__name__!r} object" )
11641174
11651175 @property
@@ -1218,6 +1228,17 @@ cdef class MbedTLSBuffer:
12181228 def _close (self ):
12191229 self .shutdown()
12201230
1231+ def read (self , amt ):
1232+ # PEP 543
1233+ if amt <= 0 :
1234+ return b" "
1235+ buffer = bytearray(amt)
1236+ view = memoryview(buffer )
1237+ nread = 0
1238+ while nread != amt and not self ._input_buffer.empty():
1239+ nread += self .readinto(view[nread:], amt - nread)
1240+ return bytes(buffer [:nread])
1241+
12211242 def readinto (self , unsigned char[:] buffer not None , size_t amt ):
12221243 if buffer .size == 0 :
12231244 return 0
@@ -1255,7 +1276,25 @@ cdef class MbedTLSBuffer:
12551276 else :
12561277 self ._reset()
12571278 _exc.check_error(ret)
1258- return written
1279+ assert written == len (buffer )
1280+ return len (self ._output_buffer)
1281+
1282+ def receive_from_network (self , data ):
1283+ # PEP 543
1284+ # Append data to input buffer.
1285+ self ._input_buffer.write(data, len (data))
1286+
1287+ def peek_outgoing (self , amt ):
1288+ # PEP 543
1289+ # Read from output buffer.
1290+ if amt == 0 :
1291+ return b" "
1292+ return self ._output_buffer.peek(amt)
1293+
1294+ def consume_outgoing (self , amt ):
1295+ """ Consume `amt` bytes from the output buffer."""
1296+ # PEP 543
1297+ self ._output_buffer.consume(amt)
12591298
12601299 def getpeercert (self , binary_form = False ):
12611300 """ Return the peer certificate, or None."""
0 commit comments