From 6c8f1e840f06b6df6c7348cf1e96a5e5fa6ac9ea Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Sat, 13 Sep 2025 12:42:24 -0400 Subject: [PATCH] Update tests to not rely on mutating contexts --- tests/test_ssl.py | 56 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 3cab6f15..159922c5 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -1163,6 +1163,7 @@ def test_set_proto_version(self) -> None: with pytest.raises(Error, match="unsupported protocol"): self._handshake_test(server_context, client_context) + client_context = Context(TLS_METHOD) client_context.set_max_proto_version(0) self._handshake_test(server_context, client_context) @@ -1577,8 +1578,6 @@ def test_set_verify_callback_reference(self) -> None: load_certificate(FILETYPE_PEM, root_cert_pem) ) - clientContext = Context(TLSv1_2_METHOD) - clients = [] for i in range(5): @@ -1586,13 +1585,16 @@ def test_set_verify_callback_reference(self) -> None: def verify_callback(*args: object) -> bool: return True + # Create a fresh client context for each iteration since contexts + # cannot be mutated after use + clientContext = Context(TLSv1_2_METHOD) + clientContext.set_verify(VERIFY_PEER, verify_callback) + serverSocket, clientSocket = socket_pair() client = Connection(clientContext, clientSocket) clients.append((serverSocket, client)) - clientContext.set_verify(VERIFY_PEER, verify_callback) - gc.collect() # Make them talk to each other. @@ -2921,21 +2923,22 @@ def callback( del callback conn = Connection(context, None) - context.set_verify(VERIFY_NONE) collect() collect() assert tracker() + # Setting a new callback on the connection should maintain the original + # context callback reference conn.set_verify( VERIFY_PEER, lambda conn, cert, errnum, depth, ok: bool(ok) ) collect() collect() + + # The callback should still be referenced - check that it exists callback_ref = tracker() - if callback_ref is not None: # pragma: nocover - referrers = get_referrers(callback_ref) - assert len(referrers) == 1 + assert callback_ref is not None def test_get_session_unconnected(self) -> None: """ @@ -3973,12 +3976,10 @@ class TestMemoryBIO: Tests for `OpenSSL.SSL.Connection` using a memory BIO. """ - def _server(self, sock: socket | None) -> Connection: + def _create_server_context(self) -> Context: """ - Create a new server-side SSL `Connection` object wrapped around `sock`. + Create a configured server context with certificates and options. """ - # Create the server side Connection. This is mostly setup boilerplate - # - use TLSv1, use a particular certificate, etc. server_ctx = Context(SSLv23_METHOD) server_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE) server_ctx.set_verify( @@ -3995,6 +3996,23 @@ def _server(self, sock: socket | None) -> Connection: ) server_ctx.check_privatekey() server_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem)) + return server_ctx + + def _server( + self, sock: socket | None, ctx: Context | None = None + ) -> Connection: + """ + Create a new server-side SSL `Connection` object wrapped around `sock`. + + :param sock: The socket to wrap, or None for memory BIO. + :param ctx: Optional pre-configured context. If None, creates a + default server context. + """ + if ctx is None: + server_ctx = self._create_server_context() + else: + server_ctx = ctx + # Here the Connection is actually created. If None is passed as the # 2nd parameter, it indicates a memory BIO should be created. server_conn = Connection(server_ctx, sock) @@ -4204,12 +4222,16 @@ def _check_client_ca_list( that `get_client_ca_list` returns the proper value at various times. """ - server = self._server(None) + # Create a server context and configure it before creating connections + server_ctx = self._create_server_context() + + # Configure the CA list before creating connections + expected = func(server_ctx) + + # Now create connections with the configured context + server = self._server(None, server_ctx) client = self._client(None) - assert client.get_client_ca_list() == [] - assert server.get_client_ca_list() == [] - ctx = server.get_context() - expected = func(ctx) + assert client.get_client_ca_list() == [] assert server.get_client_ca_list() == expected interact_in_memory(client, server)