22
33import numbers
44import pickle
5- from typing import Any , List , Union , cast
5+ import sys
6+ from typing import Any , Callable , List , Tuple , Type , TypeVar , Union , cast
67
78import pytest
89
1314from mbedtls .pk import (
1415 ECC ,
1516 RSA ,
16- CipherBase ,
1717 Curve ,
1818 DHClient ,
1919 DHServer ,
2626 get_supported_curves ,
2727)
2828
29+ if sys .version_info < (3 , 8 ):
30+ from typing_extensions import Literal
31+ else :
32+ from typing import Literal
33+
34+
35+ T = TypeVar ("T" )
36+
2937
3038def test_supported_curves () -> None :
3139 assert sorted (get_supported_curves ()) == [
@@ -55,14 +63,14 @@ def test_get_supported_ciphers() -> None:
5563 ]
5664
5765
58- def test_rsa_encryp_decrypt (randbytes ) -> None :
66+ def test_rsa_encryp_decrypt (randbytes : Callable [[ int ], bytes ] ) -> None :
5967 rsa = RSA ()
6068 rsa .generate (1024 )
6169 msg = randbytes (rsa .key_size - 11 )
6270 assert rsa .decrypt (rsa .encrypt (msg )) == msg
6371
6472
65- def do_generate (cipher : CipherBase ) :
73+ def do_generate (cipher : T ) -> bytes :
6674 if isinstance (cipher , RSA ):
6775 return cipher .generate (1024 )
6876 if isinstance (cipher , ECC ):
@@ -72,30 +80,38 @@ def do_generate(cipher: CipherBase):
7280
7381class TestECPoint :
7482 @pytest .fixture (params = [(MPI (1 ), MPI (2 ), MPI (3 )), (1 , 2 , 3 )])
75- def xyz (self , request ) :
76- return request .param
83+ def xyz (self , request : Any ) -> Tuple [ int , int , int ] :
84+ return request .param # type: ignore[no-any-return]
7785
7886 @pytest .fixture ()
79- def point (self , xyz ) :
87+ def point (self , xyz : Tuple [ int , int , int ]) -> ECPoint :
8088 return ECPoint (* xyz )
8189
82- @pytest .mark .parametrize ("repr_" , [repr , str ], ids = lambda f : f .__name__ )
83- def test_repr (self , repr_ , point ) -> None :
90+ @pytest .mark .parametrize (
91+ "repr_" ,
92+ [repr , str ],
93+ ids = lambda f : f .__name__ , # type: ignore[no-any-return]
94+ )
95+ def test_repr (
96+ self , repr_ : Callable [[object ], str ], point : ECPoint
97+ ) -> None :
8498 assert isinstance (repr_ (point ), str )
8599
86- def test_pickle (self , point ) -> None :
100+ def test_pickle (self , point : ECPoint ) -> None :
87101 assert point == pickle .loads (pickle .dumps (point ))
88102
89- def test_hash (self , point ) -> None :
103+ def test_hash (self , point : ECPoint ) -> None :
90104 assert isinstance (hash (point ), int )
91105
92- def test_accessors (self , point , xyz ) -> None :
106+ def test_accessors (
107+ self , point : ECPoint , xyz : Tuple [int , int , int ]
108+ ) -> None :
93109 x , y , z = xyz
94110 assert point .x == x
95111 assert point .y == y
96112 assert point .z == z
97113
98- def test_eq_point (self , point , xyz ) -> None :
114+ def test_eq_point (self , point : ECPoint , xyz : Tuple [ int , int , int ] ) -> None :
99115 assert (point == ECPoint (* xyz )) is True
100116 assert (point == ECPoint (0 , 0 , 0 )) is False
101117
@@ -105,7 +121,7 @@ def test_eq_zero(self) -> None:
105121 assert (zero == 0 ) is True
106122 assert (zero == ECPoint (0 , 0 , 0 )) is True
107123
108- def test_bool (self , point ) -> None :
124+ def test_bool (self , point : ECPoint ) -> None :
109125 assert bool (point ) is True
110126 assert bool (ECPoint (0 , 0 , 0 )) is False
111127
@@ -127,13 +143,13 @@ def cipher(self, request: Any) -> Union[ECC, RSA]:
127143 return cast (RSA , request .param ())
128144 return ECC (request .param )
129145
130- def test_pickle (self , cipher ) -> None :
146+ def test_pickle (self , cipher : Union [ ECC , RSA ] ) -> None :
131147 assert cipher == pickle .loads (pickle .dumps (cipher ))
132148
133- def test_hash (self , cipher ) -> None :
149+ def test_hash (self , cipher : Union [ ECC , RSA ] ) -> None :
134150 assert isinstance (hash (cipher ), int )
135151
136- def test_export_private_key (self , cipher ) -> None :
152+ def test_export_private_key (self , cipher : Union [ ECC , RSA ] ) -> None :
137153 assert cipher .export_key ("DER" ) == b""
138154 assert cipher .export_key ("PEM" ) == ""
139155 assert cipher .key_size == 0
@@ -150,7 +166,7 @@ def test_export_private_key(self, cipher) -> None:
150166 assert type (cipher ).from_PEM (cipher .export_key ("PEM" )) == cipher
151167 assert cipher .key_size > 0
152168
153- def test_import_private_key (self , cipher ) -> None :
169+ def test_import_private_key (self , cipher : Union [ ECC , RSA ] ) -> None :
154170 assert not cipher .export_key ()
155171 assert not cipher .export_public_key ()
156172
@@ -164,7 +180,7 @@ def test_import_private_key(self, cipher) -> None:
164180 assert check_pair (other , other ) is True
165181 assert cipher == other
166182
167- def test_export_public_key (self , cipher ) -> None :
183+ def test_export_public_key (self , cipher : Union [ ECC , RSA ] ) -> None :
168184 assert cipher .export_public_key ("DER" ) == b""
169185 assert cipher .export_public_key ("PEM" ) == ""
170186
@@ -178,13 +194,13 @@ def test_export_public_key(self, cipher) -> None:
178194 cipher .export_public_key ("PEM" )
179195 ) == cipher .export_public_key ("PEM" )
180196
181- def test_import_public_key (self , cipher ) -> None :
197+ def test_import_public_key (self , cipher : Union [ ECC , RSA ] ) -> None :
182198 assert not cipher .export_key ()
183199 assert not cipher .export_public_key ()
184200
185201 do_generate (cipher )
186202
187- pub = cipher .export_public_key ()
203+ pub = cipher .export_public_key ("DER" )
188204 other = type (cipher ).from_buffer (pub )
189205 assert not other .export_key ()
190206 assert other .export_public_key ()
@@ -196,9 +212,14 @@ def test_import_public_key(self, cipher) -> None:
196212 @pytest .mark .parametrize (
197213 "digestmod" ,
198214 [_get_md_alg (name ) for name in hashlib .algorithms_guaranteed ],
199- ids = lambda dm : dm ().name ,
215+ ids = lambda dm : dm ().name , # type: ignore[no-any-return]
200216 )
201- def test_sign_verify (self , cipher , digestmod , randbytes ) -> None :
217+ def test_sign_verify (
218+ self ,
219+ cipher : Union [ECC , RSA ],
220+ digestmod : str ,
221+ randbytes : Callable [[int ], bytes ],
222+ ) -> None :
202223 msg = randbytes (4096 )
203224 assert cipher .sign (msg , digestmod ) is None
204225
@@ -212,17 +233,18 @@ def test_sign_verify(self, cipher, digestmod, randbytes) -> None:
212233
213234class TestECCExportKey :
214235 @pytest .fixture (params = get_supported_curves ())
215- def curve (self , request ):
236+ def curve (self , request : Any ) -> Curve :
237+ assert isinstance (request .param , Curve )
216238 return request .param
217239
218- def test_export_private_key (self , curve ) -> None :
240+ def test_export_private_key (self , curve : Curve ) -> None :
219241 ecc = ECC (curve )
220242 assert ecc .export_key ("NUM" ) == 0
221243
222244 ecc .generate ()
223245 assert ecc .export_key ("NUM" ) != 0
224246
225- def test_export_public_key_to_point (self , curve ) -> None :
247+ def test_export_public_key_to_point (self , curve : Curve ) -> None :
226248 ecc = ECC (curve )
227249 assert ecc .export_public_key ("POINT" ) == 0
228250 assert ecc .export_public_key ("POINT" ) == ECPoint (0 , 0 , 0 )
@@ -244,7 +266,7 @@ def test_export_public_key_to_point(self, curve) -> None:
244266
245267class TestECCtoECDH :
246268 @pytest .mark .parametrize ("curve" , get_supported_curves ())
247- def test_exchange (self , curve ) -> None :
269+ def test_exchange (self , curve : Curve ) -> None :
248270 ecp = ECC (curve )
249271 ecp .generate ()
250272
@@ -263,7 +285,9 @@ def test_exchange(self, curve) -> None:
263285
264286class TestDH :
265287 @pytest .mark .parametrize ("dh_cls" , [DHServer , DHClient ])
266- def test_pickle (self , dh_cls ) -> None :
288+ def test_pickle (
289+ self , dh_cls : Union [Type [DHServer ], Type [DHClient ]]
290+ ) -> None :
267291 dhentity = dh_cls (MPI .prime (64 ), MPI .prime (20 ))
268292
269293 with pytest .raises (TypeError ) as excinfo :
@@ -272,7 +296,9 @@ def test_pickle(self, dh_cls) -> None:
272296 assert str (excinfo .value ).startswith ("cannot pickle" )
273297
274298 @pytest .mark .parametrize ("dh_cls" , [DHServer , DHClient ])
275- def test_accessors (self , dh_cls ) -> None :
299+ def test_accessors (
300+ self , dh_cls : Union [Type [DHServer ], Type [DHClient ]]
301+ ) -> None :
276302 modulus = MPI .prime (64 )
277303 generator = MPI .prime (20 )
278304
@@ -301,12 +327,14 @@ def test_exchange(self) -> None:
301327
302328class TestECDH :
303329 @pytest .fixture (params = get_supported_curves ())
304- def key (self , request ) :
330+ def key (self , request : Any ) -> ECC :
305331 key = ECC (request .param )
306332 return key
307333
308334 @pytest .mark .parametrize ("peer_cls" , [ECDHServer , ECDHClient ])
309- def test_pickle (self , key , peer_cls ) -> None :
335+ def test_pickle (
336+ self , key : ECC , peer_cls : Union [Type [ECDHServer ], Type [ECDHClient ]]
337+ ) -> None :
310338 key .generate ()
311339 peer = peer_cls (key )
312340
@@ -316,7 +344,9 @@ def test_pickle(self, key, peer_cls) -> None:
316344 assert str (excinfo .value ).startswith ("cannot pickle" )
317345
318346 @pytest .mark .parametrize ("peer_cls" , [ECDHServer , ECDHClient ])
319- def test_key_accessors_without_key (self , key , peer_cls ) -> None :
347+ def test_key_accessors_without_key (
348+ self , key : ECC , peer_cls : Union [Type [ECDHServer ], Type [ECDHClient ]]
349+ ) -> None :
320350 peer = peer_cls (key )
321351
322352 assert not peer ._has_private ()
@@ -327,9 +357,9 @@ def test_key_accessors_without_key(self, key, peer_cls) -> None:
327357 assert peer .peers_public_key == 0
328358 assert peer .shared_secret == 0
329359
330- def test_client_accessors_with_key (self , key ) -> None :
360+ def test_client_accessors_with_key (self , key : ECC ) -> None :
331361 der = key .generate ()
332- format = (
362+ format : Literal [ "NUM" , "DER" ] = (
333363 "NUM" if key .curve in (Curve .CURVE25519 , Curve .CURVE448 ) else "DER"
334364 )
335365 assert der == key .export_key (format )
@@ -344,9 +374,9 @@ def test_client_accessors_with_key(self, key) -> None:
344374 assert peer .peers_public_key == key .export_public_key ("POINT" )
345375 assert peer .shared_secret == 0
346376
347- def test_server_accessors_with_key (self , key ) -> None :
377+ def test_server_accessors_with_key (self , key : ECC ) -> None :
348378 der = key .generate ()
349- format = (
379+ format : Literal [ "NUM" , "DER" ] = (
350380 "NUM" if key .curve in (Curve .CURVE25519 , Curve .CURVE448 ) else "DER"
351381 )
352382 assert der == key .export_key (format )
@@ -361,7 +391,7 @@ def test_server_accessors_with_key(self, key) -> None:
361391 assert peer .peers_public_key == 0
362392 assert peer .shared_secret == 0
363393
364- def test_exchange (self , key ) -> None :
394+ def test_exchange (self , key : ECC ) -> None :
365395 srv , cli = ECDHServer (key ), ECDHClient (key )
366396
367397 ske = srv .generate ()
@@ -383,7 +413,7 @@ def test_exchange(self, key) -> None:
383413 assert srv_sec == cli_sec
384414 assert srv .shared_secret == cli .shared_secret
385415
386- def test_generate_public (self , key ) -> None :
416+ def test_generate_public (self , key : ECC ) -> None :
387417 srv , cli = ECDHServer (key ), ECDHClient (key )
388418
389419 srv .generate ()
@@ -393,7 +423,7 @@ def test_generate_public(self, key) -> None:
393423 assert cli .public_key == srv .public_key
394424
395425
396- def do_exchange (alice , bob ) -> None :
426+ def do_exchange (alice : ECDHNaive , bob : ECDHNaive ) -> None :
397427 alice_to_bob = alice .generate ()
398428 bob_to_alice = bob .generate ()
399429 alice .import_peers_public (bob_to_alice )
@@ -404,10 +434,11 @@ def do_exchange(alice, bob) -> None:
404434
405435class TestECDHNaive :
406436 @pytest .fixture (params = [Curve .CURVE448 , Curve .CURVE25519 ])
407- def curve (self , request ):
437+ def curve (self , request : Any ) -> Curve :
438+ assert isinstance (request .param , Curve )
408439 return request .param
409440
410- def test_key_accessors_without_key (self , curve ) -> None :
441+ def test_key_accessors_without_key (self , curve : Curve ) -> None :
411442 peer = ECDHNaive (curve )
412443
413444 assert not peer ._has_private ()
@@ -417,7 +448,7 @@ def test_key_accessors_without_key(self, curve) -> None:
417448 assert peer .public_key == 0
418449 assert peer .peers_public_key == 0
419450
420- def test_exchange (self , curve ) -> None :
451+ def test_exchange (self , curve : Curve ) -> None :
421452 alice , bob = ECDHNaive (curve ), ECDHNaive (curve )
422453
423454 alice_to_bob = alice .generate ()
@@ -448,7 +479,7 @@ def test_exchange(self, curve) -> None:
448479 assert alice_secret == alice .shared_secret
449480 assert bob_secret == bob .shared_secret
450481
451- def test_attacker_fails_with_public_keys (self , curve ) -> None :
482+ def test_attacker_fails_with_public_keys (self , curve : Curve ) -> None :
452483 alice , bob = ECDHNaive (curve ), ECDHNaive (curve )
453484 do_exchange (alice , bob )
454485
@@ -458,7 +489,7 @@ def test_attacker_fails_with_public_keys(self, curve) -> None:
458489 with pytest .raises (TLSError ):
459490 eve .generate_secret ()
460491
461- def test_attacker_succeeds_with_private_key (self , curve ) -> None :
492+ def test_attacker_succeeds_with_private_key (self , curve : Curve ) -> None :
462493 alice , bob = ECDHNaive (curve ), ECDHNaive (curve )
463494 do_exchange (alice , bob )
464495
0 commit comments