Skip to content

Commit 930f219

Browse files
committed
tls: Add end-to-end test with TLSWrappedBuffer
Closes #55.
1 parent 385fd86 commit 930f219

File tree

1 file changed

+112
-1
lines changed

1 file changed

+112
-1
lines changed

tests/test_tls.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mbedtls._tls import _PSKSToreProxy as PSKStoreProxy
1010
from mbedtls.pk import RSA
1111
from mbedtls.tls import *
12-
from mbedtls.tls import TLSSession
12+
from mbedtls.tls import HandshakeStep, TLSSession
1313
from mbedtls.x509 import CRT, CSR, BasicConstraints
1414

1515

@@ -487,3 +487,114 @@ def test_context(self, context):
487487

488488
def test_wrap_buffers(self, context):
489489
assert isinstance(context.wrap_buffers(), TLSWrappedBuffer)
490+
491+
492+
class TestTLSWrappedBuffer:
493+
@pytest.fixture
494+
def server_hostname(self):
495+
return "hostname"
496+
497+
@pytest.fixture
498+
def psk(self):
499+
return ("cli", b"secret")
500+
501+
@pytest.fixture
502+
def cli_conf(self, psk):
503+
return TLSConfiguration(
504+
pre_shared_key=psk, validate_certificates=False
505+
)
506+
507+
@pytest.fixture
508+
def srv_conf(self, psk):
509+
return TLSConfiguration(
510+
pre_shared_key_store=dict((psk,)), validate_certificates=False
511+
)
512+
513+
@pytest.fixture
514+
def client(Self, cli_conf, server_hostname):
515+
ctx = ClientContext(cli_conf).wrap_buffers(server_hostname)
516+
yield ctx
517+
ctx.shutdown()
518+
519+
@pytest.fixture
520+
def server(self, srv_conf, server_hostname):
521+
ctx = ServerContext(srv_conf).wrap_buffers()
522+
yield ctx
523+
ctx.shutdown()
524+
525+
def test_e2e_handshake_and_communicate(self, client, server):
526+
def do_io(*, src, dst, amt=1024):
527+
in_transit = src.peek_outgoing(amt)
528+
src.consume_outgoing(len(in_transit))
529+
dst.receive_from_network(in_transit)
530+
531+
def do_handshake(*end_state_pair):
532+
for end, state in end_state_pair:
533+
end.do_handshake()
534+
assert end.context._state is state
535+
536+
do_handshake(
537+
(server, HandshakeStep.CLIENT_HELLO),
538+
(client, HandshakeStep.CLIENT_HELLO),
539+
(client, HandshakeStep.SERVER_HELLO),
540+
)
541+
542+
do_io(src=client, dst=server)
543+
do_handshake(
544+
(server, HandshakeStep.SERVER_HELLO),
545+
(server, HandshakeStep.SERVER_CERTIFICATE),
546+
(server, HandshakeStep.SERVER_KEY_EXCHANGE),
547+
(server, HandshakeStep.CERTIFICATE_REQUEST),
548+
(server, HandshakeStep.SERVER_HELLO_DONE),
549+
(server, HandshakeStep.CLIENT_CERTIFICATE),
550+
(server, HandshakeStep.CLIENT_KEY_EXCHANGE),
551+
)
552+
assert client.negotiated_protocol() == server.negotiated_protocol()
553+
554+
do_io(src=server, dst=client)
555+
do_handshake(
556+
(client, HandshakeStep.SERVER_CERTIFICATE),
557+
(client, HandshakeStep.SERVER_KEY_EXCHANGE),
558+
(client, HandshakeStep.CERTIFICATE_REQUEST),
559+
(client, HandshakeStep.SERVER_HELLO_DONE),
560+
(client, HandshakeStep.CLIENT_CERTIFICATE),
561+
(client, HandshakeStep.CLIENT_KEY_EXCHANGE),
562+
(client, HandshakeStep.CERTIFICATE_VERIFY),
563+
(client, HandshakeStep.CLIENT_CHANGE_CIPHER_SPEC),
564+
(client, HandshakeStep.CLIENT_FINISHED),
565+
(client, HandshakeStep.SERVER_CHANGE_CIPHER_SPEC),
566+
)
567+
assert (
568+
client.negotiated_tls_version() == server.negotiated_tls_version()
569+
)
570+
571+
do_io(src=client, dst=server)
572+
do_handshake(
573+
(server, HandshakeStep.CERTIFICATE_VERIFY),
574+
(server, HandshakeStep.CLIENT_CHANGE_CIPHER_SPEC),
575+
(server, HandshakeStep.CLIENT_FINISHED),
576+
(server, HandshakeStep.SERVER_CHANGE_CIPHER_SPEC),
577+
(server, HandshakeStep.SERVER_FINISHED),
578+
(server, HandshakeStep.FLUSH_BUFFERS),
579+
(server, HandshakeStep.HANDSHAKE_WRAPUP),
580+
(server, HandshakeStep.HANDSHAKE_OVER),
581+
)
582+
583+
do_io(src=server, dst=client)
584+
do_handshake(
585+
(client, HandshakeStep.SERVER_FINISHED),
586+
(client, HandshakeStep.FLUSH_BUFFERS),
587+
(client, HandshakeStep.HANDSHAKE_WRAPUP),
588+
(client, HandshakeStep.HANDSHAKE_OVER),
589+
)
590+
assert client.cipher() == server.cipher()
591+
592+
secret = b"a very secret message"
593+
594+
amt = client.write(secret)
595+
do_io(src=client, dst=server)
596+
assert server.read(amt) == secret
597+
598+
amt = server.write(secret)
599+
do_io(src=server, dst=client)
600+
assert client.read(amt) == secret

0 commit comments

Comments
 (0)