Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 127 additions & 21 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
from ssl import Purpose, TLSVersion, _TLSContentType, _TLSMessageType, _TLSAlertType

Py_DEBUG_WIN32 = support.Py_DEBUG and sys.platform == 'win32'
HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
requires_keylog = unittest.skipUnless(
HAS_KEYLOG, 'test requires OpenSSL 1.1.1 with keylog callback')

PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = socket_helper.HOST
Expand Down Expand Up @@ -576,6 +579,53 @@ def test_refcycle(self):
del ss
self.assertEqual(wr(), None)

@support.cpython_only
def test_sslsocket_ctx_refcycle(self):
# SSLSocket doesn't leak when it has a reference cycle with its context
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
s = socket.socket(socket.AF_INET)
ss = ctx.wrap_socket(s)
# Create a cycle: ctx -> callback -> ss -> ctx
def msg_cb(conn, direction, version, content_type, msg_type, data):
pass
msg_cb.ss = ss
ctx._msg_callback = msg_cb

ctx_wr = weakref.ref(ctx)
ss_wr = weakref.ref(ss)
ss.close()
del ctx, s, ss, msg_cb
gc.collect()
self.assertIs(ctx_wr(), None)
self.assertIs(ss_wr(), None)

@support.cpython_only
def test_sslsocket_owner_refcycle(self):
# SSLSocket doesn't leak when it has a reference cycle with its owner
class Owner:
pass
owner = Owner()
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.check_hostname = False
s = socket.socket(socket.AF_INET)
# owner is only available in SSLObject.wrap_bio or _ssl._SSLSocket directly
# but SSLSocket doesn't expose owner in wrap_socket.
# We can use _sslobj.owner if we want to test the C-level leak.
ss = ctx.wrap_socket(s)
# SSLSocket._sslobj is None if wrap_socket failed or was not called correctly
# but here it should be a _ssl._SSLSocket
ss.owner = owner
owner.ss = ss

owner_wr = weakref.ref(owner)
ss_wr = weakref.ref(ss)
ss.close()
del owner, ctx, s, ss
gc.collect()
self.assertIs(owner_wr(), None)
self.assertIs(ss_wr(), None)

def test_wrapped_unconnected(self):
# Methods on an unconnected SSLSocket propagate the original
# OSError raise by the underlying socket object.
Expand Down Expand Up @@ -1488,6 +1538,49 @@ def dummycallback(sock, servername, ctx, cycle=ctx):
gc.collect()
self.assertIs(wr(), None)

@unittest.skipUnless(ssl.HAS_PSK, 'TLS-PSK disabled on this OpenSSL build')
def test_psk_client_callback_refcycle(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
def psk_cb(hint, cycle=ctx):
return (None, b"psk")
ctx.set_psk_client_callback(psk_cb)
wr = weakref.ref(ctx)
del ctx, psk_cb
gc.collect()
self.assertIs(wr(), None)

@unittest.skipUnless(ssl.HAS_PSK, 'TLS-PSK disabled on this OpenSSL build')
def test_psk_server_callback_refcycle(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
def psk_cb(identity, cycle=ctx):
return b"psk"
ctx.set_psk_server_callback(psk_cb)
wr = weakref.ref(ctx)
del ctx, psk_cb
gc.collect()
self.assertIs(wr(), None)

def test_msg_callback_refcycle(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
def msg_cb(conn, direction, version, content_type, msg_type, data, cycle=ctx):
pass
ctx._msg_callback = msg_cb
wr = weakref.ref(ctx)
del ctx, msg_cb
gc.collect()
self.assertIs(wr(), None)

@requires_keylog
def test_keylog_filename_refcycle(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.keylog_filename = os_helper.TESTFN
# keylog_filename is a string, so it can't create a cycle itself,
# but we check that SSLContext still clears it.
ctx_wr = weakref.ref(ctx)
del ctx
gc.collect()
self.assertIs(ctx_wr(), None)

def test_cert_store_stats(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.assertEqual(ctx.cert_store_stats(),
Expand Down Expand Up @@ -4709,6 +4802,36 @@ def test_session_handling(self):
self.assertEqual(str(e.exception),
'Session refers to a different SSLContext.')

@support.cpython_only
def test_session_refcycle(self):
# SSLSession doesn't leak when it has a reference cycle with its context
client_context, server_context, hostname = testing_context()
client_context.maximum_version = ssl.TLSVersion.TLSv1_2
server = ThreadedEchoServer(context=server_context, chatty=False)
with server:
with client_context.wrap_socket(socket.socket(),
server_hostname=hostname) as s:
s.connect((HOST, server.port))
session = s.session

# Create a cycle: session -> ctx -> callback -> session
def msg_cb(conn, direction, version, content_type, msg_type, data):
pass
msg_cb.session = session
client_context._msg_callback = msg_cb

# _ssl.SSLSession doesn't support weakrefs, so we use gc.get_referrers
# to check if it's still alive.
import gc
del session, client_context, server_context, s, msg_cb
gc.collect()
# If SSLSession is still alive, it should be in gc.get_objects()
# but that's a bit unreliable. Better check that there are no
# SSLSession objects left.
sessions = [obj for obj in gc.get_objects()
if type(obj).__name__ == 'SSLSession']
self.assertEqual(sessions, [])

@requires_tls_version('TLSv1_2')
@unittest.skipUnless(ssl.HAS_PSK, 'TLS-PSK disabled on this OpenSSL build')
def test_psk(self):
Expand Down Expand Up @@ -5163,9 +5286,6 @@ def test_internal_chain_server(self):
self.assertEqual(res, b'\x02\n')


HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
requires_keylog = unittest.skipUnless(
HAS_KEYLOG, 'test requires OpenSSL 1.1.1 with keylog callback')

class TestSSLDebug(unittest.TestCase):

Expand All @@ -5175,17 +5295,13 @@ def keylog_lines(self, fname=os_helper.TESTFN):

@requires_keylog
def test_keylog_defaults(self):
os_helper.unlink(os_helper.TESTFN)
self.addCleanup(os_helper.unlink, os_helper.TESTFN)
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.assertEqual(ctx.keylog_filename, None)

self.assertFalse(os.path.isfile(os_helper.TESTFN))
try:
ctx.keylog_filename = os_helper.TESTFN
except RuntimeError:
if Py_DEBUG_WIN32:
self.skipTest("not supported on Win32 debug build")
raise
ctx.keylog_filename = os_helper.TESTFN
self.assertEqual(ctx.keylog_filename, os_helper.TESTFN)
self.assertTrue(os.path.isfile(os_helper.TESTFN))
self.assertEqual(self.keylog_lines(), 1)
Expand All @@ -5206,12 +5322,7 @@ def test_keylog_filename(self):
self.addCleanup(os_helper.unlink, os_helper.TESTFN)
client_context, server_context, hostname = testing_context()

try:
client_context.keylog_filename = os_helper.TESTFN
except RuntimeError:
if Py_DEBUG_WIN32:
self.skipTest("not supported on Win32 debug build")
raise
client_context.keylog_filename = os_helper.TESTFN

server = ThreadedEchoServer(context=server_context, chatty=False)
with server:
Expand Down Expand Up @@ -5254,12 +5365,7 @@ def test_keylog_env(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.assertEqual(ctx.keylog_filename, None)

try:
ctx = ssl.create_default_context()
except RuntimeError:
if Py_DEBUG_WIN32:
self.skipTest("not supported on Win32 debug build")
raise
ctx = ssl.create_default_context()
self.assertEqual(ctx.keylog_filename, os_helper.TESTFN)

ctx = ssl._create_stdlib_context()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix memory leak in SSLContext
Loading
Loading