|
9 | 9 | from mbedtls._tls import _PSKSToreProxy as PSKStoreProxy |
10 | 10 | from mbedtls.pk import RSA |
11 | 11 | from mbedtls.tls import * |
12 | | -from mbedtls.tls import TLSSession |
| 12 | +from mbedtls.tls import HandshakeStep, TLSSession |
13 | 13 | from mbedtls.x509 import CRT, CSR, BasicConstraints |
14 | 14 |
|
15 | 15 |
|
@@ -487,3 +487,114 @@ def test_context(self, context): |
487 | 487 |
|
488 | 488 | def test_wrap_buffers(self, context): |
489 | 489 | assert isinstance(context.wrap_buffers(), TLSWrappedBuffer) |
| 490 | + |
| 491 | + |
| 492 | +class TestTLSWrappedBuffer: |
| 493 | + @pytest.fixture |
| 494 | + def server_hostname(self): |
| 495 | + return "hostname" |
| 496 | + |
| 497 | + @pytest.fixture |
| 498 | + def psk(self): |
| 499 | + return ("cli", b"secret") |
| 500 | + |
| 501 | + @pytest.fixture |
| 502 | + def cli_conf(self, psk): |
| 503 | + return TLSConfiguration( |
| 504 | + pre_shared_key=psk, validate_certificates=False |
| 505 | + ) |
| 506 | + |
| 507 | + @pytest.fixture |
| 508 | + def srv_conf(self, psk): |
| 509 | + return TLSConfiguration( |
| 510 | + pre_shared_key_store=dict((psk,)), validate_certificates=False |
| 511 | + ) |
| 512 | + |
| 513 | + @pytest.fixture |
| 514 | + def client(Self, cli_conf, server_hostname): |
| 515 | + ctx = ClientContext(cli_conf).wrap_buffers(server_hostname) |
| 516 | + yield ctx |
| 517 | + ctx.shutdown() |
| 518 | + |
| 519 | + @pytest.fixture |
| 520 | + def server(self, srv_conf, server_hostname): |
| 521 | + ctx = ServerContext(srv_conf).wrap_buffers() |
| 522 | + yield ctx |
| 523 | + ctx.shutdown() |
| 524 | + |
| 525 | + def test_e2e_handshake_and_communicate(self, client, server): |
| 526 | + def do_io(*, src, dst, amt=1024): |
| 527 | + in_transit = src.peek_outgoing(amt) |
| 528 | + src.consume_outgoing(len(in_transit)) |
| 529 | + dst.receive_from_network(in_transit) |
| 530 | + |
| 531 | + def do_handshake(*end_state_pair): |
| 532 | + for end, state in end_state_pair: |
| 533 | + end.do_handshake() |
| 534 | + assert end.context._state is state |
| 535 | + |
| 536 | + do_handshake( |
| 537 | + (server, HandshakeStep.CLIENT_HELLO), |
| 538 | + (client, HandshakeStep.CLIENT_HELLO), |
| 539 | + (client, HandshakeStep.SERVER_HELLO), |
| 540 | + ) |
| 541 | + |
| 542 | + do_io(src=client, dst=server) |
| 543 | + do_handshake( |
| 544 | + (server, HandshakeStep.SERVER_HELLO), |
| 545 | + (server, HandshakeStep.SERVER_CERTIFICATE), |
| 546 | + (server, HandshakeStep.SERVER_KEY_EXCHANGE), |
| 547 | + (server, HandshakeStep.CERTIFICATE_REQUEST), |
| 548 | + (server, HandshakeStep.SERVER_HELLO_DONE), |
| 549 | + (server, HandshakeStep.CLIENT_CERTIFICATE), |
| 550 | + (server, HandshakeStep.CLIENT_KEY_EXCHANGE), |
| 551 | + ) |
| 552 | + assert client.negotiated_protocol() == server.negotiated_protocol() |
| 553 | + |
| 554 | + do_io(src=server, dst=client) |
| 555 | + do_handshake( |
| 556 | + (client, HandshakeStep.SERVER_CERTIFICATE), |
| 557 | + (client, HandshakeStep.SERVER_KEY_EXCHANGE), |
| 558 | + (client, HandshakeStep.CERTIFICATE_REQUEST), |
| 559 | + (client, HandshakeStep.SERVER_HELLO_DONE), |
| 560 | + (client, HandshakeStep.CLIENT_CERTIFICATE), |
| 561 | + (client, HandshakeStep.CLIENT_KEY_EXCHANGE), |
| 562 | + (client, HandshakeStep.CERTIFICATE_VERIFY), |
| 563 | + (client, HandshakeStep.CLIENT_CHANGE_CIPHER_SPEC), |
| 564 | + (client, HandshakeStep.CLIENT_FINISHED), |
| 565 | + (client, HandshakeStep.SERVER_CHANGE_CIPHER_SPEC), |
| 566 | + ) |
| 567 | + assert ( |
| 568 | + client.negotiated_tls_version() == server.negotiated_tls_version() |
| 569 | + ) |
| 570 | + |
| 571 | + do_io(src=client, dst=server) |
| 572 | + do_handshake( |
| 573 | + (server, HandshakeStep.CERTIFICATE_VERIFY), |
| 574 | + (server, HandshakeStep.CLIENT_CHANGE_CIPHER_SPEC), |
| 575 | + (server, HandshakeStep.CLIENT_FINISHED), |
| 576 | + (server, HandshakeStep.SERVER_CHANGE_CIPHER_SPEC), |
| 577 | + (server, HandshakeStep.SERVER_FINISHED), |
| 578 | + (server, HandshakeStep.FLUSH_BUFFERS), |
| 579 | + (server, HandshakeStep.HANDSHAKE_WRAPUP), |
| 580 | + (server, HandshakeStep.HANDSHAKE_OVER), |
| 581 | + ) |
| 582 | + |
| 583 | + do_io(src=server, dst=client) |
| 584 | + do_handshake( |
| 585 | + (client, HandshakeStep.SERVER_FINISHED), |
| 586 | + (client, HandshakeStep.FLUSH_BUFFERS), |
| 587 | + (client, HandshakeStep.HANDSHAKE_WRAPUP), |
| 588 | + (client, HandshakeStep.HANDSHAKE_OVER), |
| 589 | + ) |
| 590 | + assert client.cipher() == server.cipher() |
| 591 | + |
| 592 | + secret = b"a very secret message" |
| 593 | + |
| 594 | + amt = client.write(secret) |
| 595 | + do_io(src=client, dst=server) |
| 596 | + assert server.read(amt) == secret |
| 597 | + |
| 598 | + amt = server.write(secret) |
| 599 | + do_io(src=server, dst=client) |
| 600 | + assert client.read(amt) == secret |
0 commit comments