Skip to content

Commit 1b0f267

Browse files
committed
tls: Fix WantReadError/WantWriteError signaling
This lets us further simplify TLSWrappedSocket.do_handshake().
1 parent 69d5c85 commit 1b0f267

File tree

3 files changed

+23
-19
lines changed

3 files changed

+23
-19
lines changed

src/mbedtls/_tls.pyx

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ cdef int buffer_read(void *ctx, unsigned char *buf, const size_t len) nogil:
8484
"""Read from input buffer."""
8585
c_buf = <_tls._C_Buffers *> ctx
8686
if _rb.c_len(c_buf.in_ctx) == 0:
87-
return _tls.MBEDTLS_ERR_SSL_WANT_WRITE
87+
return _tls.MBEDTLS_ERR_SSL_WANT_READ
8888
return _rb.c_readinto(c_buf.in_ctx, buf, len)
8989

9090

@@ -1247,7 +1247,6 @@ cdef class MbedTLSBuffer:
12471247
return 0
12481248
if amt <= 0:
12491249
return 0
1250-
# cdef size_t avail = _tls.mbedtls_ssl_get_bytes_avail(&self._ctx)
12511250
read = _tls.mbedtls_ssl_read(&self._ctx, &buffer[0], amt)
12521251
if read > 0:
12531252
return read
@@ -1342,26 +1341,28 @@ cdef class MbedTLSBuffer:
13421341
def do_handshake(self):
13431342
if self._handshake_state is HandshakeStep.HANDSHAKE_OVER:
13441343
raise ValueError("handshake already over")
1345-
self._handle_handshake_response(_tls.mbedtls_ssl_handshake_step(&self._ctx))
1344+
self._handle_handshake_response(
1345+
_tls.mbedtls_ssl_handshake_step(&self._ctx)
1346+
)
13461347

13471348
def _renegotiate(self):
13481349
"""Initialize an SSL renegotiation on the running connection."""
13491350
self._handle_handshake_response(_tls.mbedtls_ssl_renegotiate(&self._ctx))
13501351

13511352
def _handle_handshake_response(self, ret):
1352-
if ret == 0:
1353-
return
1354-
elif ret == _tls.MBEDTLS_ERR_SSL_WANT_READ:
1353+
if ret == _tls.MBEDTLS_ERR_SSL_WANT_READ:
13551354
raise WantReadError()
1356-
elif ret == _tls.MBEDTLS_ERR_SSL_WANT_WRITE:
1355+
if ret == _tls.MBEDTLS_ERR_SSL_WANT_WRITE:
13571356
raise WantWriteError()
1358-
elif ret == _tls.MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
1357+
if ret == _tls.MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
13591358
self._reset()
13601359
raise HelloVerifyRequest()
1361-
else:
1362-
assert ret < 0
1360+
if ret < 0:
13631361
self._reset()
13641362
_exc.check_error(ret)
1363+
if ret == 0 and self._output_buffer:
1364+
raise WantWriteError
1365+
assert ret == 0
13651366

13661367
def _get_channel_binding(self, cb_type="tls-unique"):
13671368
return None

src/mbedtls/tls.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,10 @@ def __exit__(self, *exc_info):
186186
def __str__(self):
187187
return str(self._socket)
188188

189+
@property
190+
def _handshake_state(self):
191+
return self._buffer._handshake_state
192+
189193
# PEP 543 requires the full socket API.
190194

191195
@property
@@ -333,19 +337,16 @@ def shutdown(self, how):
333337
# PEP 543 adds the following methods.
334338

335339
def do_handshake(self):
336-
while (
337-
self._buffer._handshake_state is not HandshakeStep.HANDSHAKE_OVER
338-
):
340+
while self._handshake_state is not HandshakeStep.HANDSHAKE_OVER:
339341
try:
340342
self._buffer.do_handshake()
341-
amt = self._socket.send(self._buffer.peek_outgoing(1024))
342-
self._buffer.consume_outgoing(amt)
343343
except WantReadError:
344-
amt = self._socket.send(self._buffer.peek_outgoing(1024))
345-
self._buffer.consume_outgoing(amt)
346-
except WantWriteError:
347344
data = self._socket.recv(1024)
348345
self._buffer.receive_from_network(data)
346+
except WantWriteError:
347+
in_transit = self._buffer.peek_outgoing(1024)
348+
amt = self._socket.send(in_transit)
349+
self._buffer.consume_outgoing(amt)
349350

350351
def setcookieparam(self, param):
351352
self._buffer.setcookieparam(param)

tests/test_tls.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime as dt
22
import pickle
33
import socket
4+
from contextlib import suppress
45

56
import pytest
67

@@ -515,7 +516,8 @@ def do_io(*, src, dst, amt=1024):
515516

516517
def do_handshake(*end_state_pair):
517518
for end, state in end_state_pair:
518-
end.do_handshake()
519+
with suppress(WantReadError, WantWriteError):
520+
end.do_handshake()
519521
assert end._handshake_state is state
520522

521523
do_handshake(

0 commit comments

Comments
 (0)