Skip to content

Commit 018b16f

Browse files
committed
test_pk: Add type annotations
The file now passes `mypy --strict`.
1 parent 9c55a4f commit 018b16f

File tree

2 files changed

+85
-44
lines changed

2 files changed

+85
-44
lines changed

src/mbedtls/pk.pyi

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,13 @@ class RSA(CipherBase):
107107
@overload
108108
def export_key(self, format: Literal["PEM"]) -> _PEM: ...
109109
@overload
110+
def export_key(self) -> Union[_DER, _PEM]: ...
111+
@overload
110112
def export_public_key(self, format: Literal["DER"]) -> _DER: ...
111113
@overload
112114
def export_public_key(self, format: Literal["PEM"]) -> _PEM: ...
115+
@overload
116+
def export_public_key(self) -> Union[_DER, _PEM]: ...
113117

114118
class ECPoint:
115119
def __init__(self, x: _MPI, y: _MPI, z: _MPI) -> None: ...
@@ -132,6 +136,8 @@ class ECC(CipherBase):
132136
key: Optional[bytes] = ...,
133137
password: Optional[bytes] = ...,
134138
) -> None: ...
139+
@property
140+
def curve(self) -> Curve: ...
135141
def generate(self) -> bytes: ...
136142
@overload
137143
def export_key(self, format: Literal["DER"]) -> _DER: ...
@@ -140,11 +146,15 @@ class ECC(CipherBase):
140146
@overload
141147
def export_key(self, format: Literal["NUM"]) -> _NUM: ...
142148
@overload
149+
def export_key(self) -> Union[_DER, _PEM, _NUM]: ...
150+
@overload
143151
def export_public_key(self, format: Literal["DER"]) -> _DER: ...
144152
@overload
145153
def export_public_key(self, format: Literal["PEM"]) -> _PEM: ...
146154
@overload
147155
def export_public_key(self, format: Literal["POINT"]) -> ECPoint: ...
156+
@overload
157+
def export_public_key(self) -> Union[_DER, _PEM, ECPoint]: ...
148158

149159
class DHServer:
150160
def __init__(self, modulus: _MPI, generator: _MPI) -> None: ...

tests/test_pk.py

Lines changed: 75 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import numbers
44
import pickle
5-
from typing import Any, List, Union, cast
5+
import sys
6+
from typing import Any, Callable, List, Tuple, Type, TypeVar, Union, cast
67

78
import pytest
89

@@ -13,7 +14,6 @@
1314
from mbedtls.pk import (
1415
ECC,
1516
RSA,
16-
CipherBase,
1717
Curve,
1818
DHClient,
1919
DHServer,
@@ -26,6 +26,14 @@
2626
get_supported_curves,
2727
)
2828

29+
if sys.version_info < (3, 8):
30+
from typing_extensions import Literal
31+
else:
32+
from typing import Literal
33+
34+
35+
T = TypeVar("T")
36+
2937

3038
def test_supported_curves() -> None:
3139
assert sorted(get_supported_curves()) == [
@@ -55,14 +63,14 @@ def test_get_supported_ciphers() -> None:
5563
]
5664

5765

58-
def test_rsa_encryp_decrypt(randbytes) -> None:
66+
def test_rsa_encryp_decrypt(randbytes: Callable[[int], bytes]) -> None:
5967
rsa = RSA()
6068
rsa.generate(1024)
6169
msg = randbytes(rsa.key_size - 11)
6270
assert rsa.decrypt(rsa.encrypt(msg)) == msg
6371

6472

65-
def do_generate(cipher: CipherBase):
73+
def do_generate(cipher: T) -> bytes:
6674
if isinstance(cipher, RSA):
6775
return cipher.generate(1024)
6876
if isinstance(cipher, ECC):
@@ -72,30 +80,38 @@ def do_generate(cipher: CipherBase):
7280

7381
class TestECPoint:
7482
@pytest.fixture(params=[(MPI(1), MPI(2), MPI(3)), (1, 2, 3)])
75-
def xyz(self, request):
76-
return request.param
83+
def xyz(self, request: Any) -> Tuple[int, int, int]:
84+
return request.param # type: ignore[no-any-return]
7785

7886
@pytest.fixture()
79-
def point(self, xyz):
87+
def point(self, xyz: Tuple[int, int, int]) -> ECPoint:
8088
return ECPoint(*xyz)
8189

82-
@pytest.mark.parametrize("repr_", [repr, str], ids=lambda f: f.__name__)
83-
def test_repr(self, repr_, point) -> None:
90+
@pytest.mark.parametrize(
91+
"repr_",
92+
[repr, str],
93+
ids=lambda f: f.__name__, # type: ignore[no-any-return]
94+
)
95+
def test_repr(
96+
self, repr_: Callable[[object], str], point: ECPoint
97+
) -> None:
8498
assert isinstance(repr_(point), str)
8599

86-
def test_pickle(self, point) -> None:
100+
def test_pickle(self, point: ECPoint) -> None:
87101
assert point == pickle.loads(pickle.dumps(point))
88102

89-
def test_hash(self, point) -> None:
103+
def test_hash(self, point: ECPoint) -> None:
90104
assert isinstance(hash(point), int)
91105

92-
def test_accessors(self, point, xyz) -> None:
106+
def test_accessors(
107+
self, point: ECPoint, xyz: Tuple[int, int, int]
108+
) -> None:
93109
x, y, z = xyz
94110
assert point.x == x
95111
assert point.y == y
96112
assert point.z == z
97113

98-
def test_eq_point(self, point, xyz) -> None:
114+
def test_eq_point(self, point: ECPoint, xyz: Tuple[int, int, int]) -> None:
99115
assert (point == ECPoint(*xyz)) is True
100116
assert (point == ECPoint(0, 0, 0)) is False
101117

@@ -105,7 +121,7 @@ def test_eq_zero(self) -> None:
105121
assert (zero == 0) is True
106122
assert (zero == ECPoint(0, 0, 0)) is True
107123

108-
def test_bool(self, point) -> None:
124+
def test_bool(self, point: ECPoint) -> None:
109125
assert bool(point) is True
110126
assert bool(ECPoint(0, 0, 0)) is False
111127

@@ -127,13 +143,13 @@ def cipher(self, request: Any) -> Union[ECC, RSA]:
127143
return cast(RSA, request.param())
128144
return ECC(request.param)
129145

130-
def test_pickle(self, cipher) -> None:
146+
def test_pickle(self, cipher: Union[ECC, RSA]) -> None:
131147
assert cipher == pickle.loads(pickle.dumps(cipher))
132148

133-
def test_hash(self, cipher) -> None:
149+
def test_hash(self, cipher: Union[ECC, RSA]) -> None:
134150
assert isinstance(hash(cipher), int)
135151

136-
def test_export_private_key(self, cipher) -> None:
152+
def test_export_private_key(self, cipher: Union[ECC, RSA]) -> None:
137153
assert cipher.export_key("DER") == b""
138154
assert cipher.export_key("PEM") == ""
139155
assert cipher.key_size == 0
@@ -150,7 +166,7 @@ def test_export_private_key(self, cipher) -> None:
150166
assert type(cipher).from_PEM(cipher.export_key("PEM")) == cipher
151167
assert cipher.key_size > 0
152168

153-
def test_import_private_key(self, cipher) -> None:
169+
def test_import_private_key(self, cipher: Union[ECC, RSA]) -> None:
154170
assert not cipher.export_key()
155171
assert not cipher.export_public_key()
156172

@@ -164,7 +180,7 @@ def test_import_private_key(self, cipher) -> None:
164180
assert check_pair(other, other) is True
165181
assert cipher == other
166182

167-
def test_export_public_key(self, cipher) -> None:
183+
def test_export_public_key(self, cipher: Union[ECC, RSA]) -> None:
168184
assert cipher.export_public_key("DER") == b""
169185
assert cipher.export_public_key("PEM") == ""
170186

@@ -178,13 +194,13 @@ def test_export_public_key(self, cipher) -> None:
178194
cipher.export_public_key("PEM")
179195
) == cipher.export_public_key("PEM")
180196

181-
def test_import_public_key(self, cipher) -> None:
197+
def test_import_public_key(self, cipher: Union[ECC, RSA]) -> None:
182198
assert not cipher.export_key()
183199
assert not cipher.export_public_key()
184200

185201
do_generate(cipher)
186202

187-
pub = cipher.export_public_key()
203+
pub = cipher.export_public_key("DER")
188204
other = type(cipher).from_buffer(pub)
189205
assert not other.export_key()
190206
assert other.export_public_key()
@@ -196,9 +212,14 @@ def test_import_public_key(self, cipher) -> None:
196212
@pytest.mark.parametrize(
197213
"digestmod",
198214
[_get_md_alg(name) for name in hashlib.algorithms_guaranteed],
199-
ids=lambda dm: dm().name,
215+
ids=lambda dm: dm().name, # type: ignore[no-any-return]
200216
)
201-
def test_sign_verify(self, cipher, digestmod, randbytes) -> None:
217+
def test_sign_verify(
218+
self,
219+
cipher: Union[ECC, RSA],
220+
digestmod: str,
221+
randbytes: Callable[[int], bytes],
222+
) -> None:
202223
msg = randbytes(4096)
203224
assert cipher.sign(msg, digestmod) is None
204225

@@ -212,17 +233,18 @@ def test_sign_verify(self, cipher, digestmod, randbytes) -> None:
212233

213234
class TestECCExportKey:
214235
@pytest.fixture(params=get_supported_curves())
215-
def curve(self, request):
236+
def curve(self, request: Any) -> Curve:
237+
assert isinstance(request.param, Curve)
216238
return request.param
217239

218-
def test_export_private_key(self, curve) -> None:
240+
def test_export_private_key(self, curve: Curve) -> None:
219241
ecc = ECC(curve)
220242
assert ecc.export_key("NUM") == 0
221243

222244
ecc.generate()
223245
assert ecc.export_key("NUM") != 0
224246

225-
def test_export_public_key_to_point(self, curve) -> None:
247+
def test_export_public_key_to_point(self, curve: Curve) -> None:
226248
ecc = ECC(curve)
227249
assert ecc.export_public_key("POINT") == 0
228250
assert ecc.export_public_key("POINT") == ECPoint(0, 0, 0)
@@ -244,7 +266,7 @@ def test_export_public_key_to_point(self, curve) -> None:
244266

245267
class TestECCtoECDH:
246268
@pytest.mark.parametrize("curve", get_supported_curves())
247-
def test_exchange(self, curve) -> None:
269+
def test_exchange(self, curve: Curve) -> None:
248270
ecp = ECC(curve)
249271
ecp.generate()
250272

@@ -263,7 +285,9 @@ def test_exchange(self, curve) -> None:
263285

264286
class TestDH:
265287
@pytest.mark.parametrize("dh_cls", [DHServer, DHClient])
266-
def test_pickle(self, dh_cls) -> None:
288+
def test_pickle(
289+
self, dh_cls: Union[Type[DHServer], Type[DHClient]]
290+
) -> None:
267291
dhentity = dh_cls(MPI.prime(64), MPI.prime(20))
268292

269293
with pytest.raises(TypeError) as excinfo:
@@ -272,7 +296,9 @@ def test_pickle(self, dh_cls) -> None:
272296
assert str(excinfo.value).startswith("cannot pickle")
273297

274298
@pytest.mark.parametrize("dh_cls", [DHServer, DHClient])
275-
def test_accessors(self, dh_cls) -> None:
299+
def test_accessors(
300+
self, dh_cls: Union[Type[DHServer], Type[DHClient]]
301+
) -> None:
276302
modulus = MPI.prime(64)
277303
generator = MPI.prime(20)
278304

@@ -301,12 +327,14 @@ def test_exchange(self) -> None:
301327

302328
class TestECDH:
303329
@pytest.fixture(params=get_supported_curves())
304-
def key(self, request):
330+
def key(self, request: Any) -> ECC:
305331
key = ECC(request.param)
306332
return key
307333

308334
@pytest.mark.parametrize("peer_cls", [ECDHServer, ECDHClient])
309-
def test_pickle(self, key, peer_cls) -> None:
335+
def test_pickle(
336+
self, key: ECC, peer_cls: Union[Type[ECDHServer], Type[ECDHClient]]
337+
) -> None:
310338
key.generate()
311339
peer = peer_cls(key)
312340

@@ -316,7 +344,9 @@ def test_pickle(self, key, peer_cls) -> None:
316344
assert str(excinfo.value).startswith("cannot pickle")
317345

318346
@pytest.mark.parametrize("peer_cls", [ECDHServer, ECDHClient])
319-
def test_key_accessors_without_key(self, key, peer_cls) -> None:
347+
def test_key_accessors_without_key(
348+
self, key: ECC, peer_cls: Union[Type[ECDHServer], Type[ECDHClient]]
349+
) -> None:
320350
peer = peer_cls(key)
321351

322352
assert not peer._has_private()
@@ -327,9 +357,9 @@ def test_key_accessors_without_key(self, key, peer_cls) -> None:
327357
assert peer.peers_public_key == 0
328358
assert peer.shared_secret == 0
329359

330-
def test_client_accessors_with_key(self, key) -> None:
360+
def test_client_accessors_with_key(self, key: ECC) -> None:
331361
der = key.generate()
332-
format = (
362+
format: Literal["NUM", "DER"] = (
333363
"NUM" if key.curve in (Curve.CURVE25519, Curve.CURVE448) else "DER"
334364
)
335365
assert der == key.export_key(format)
@@ -344,9 +374,9 @@ def test_client_accessors_with_key(self, key) -> None:
344374
assert peer.peers_public_key == key.export_public_key("POINT")
345375
assert peer.shared_secret == 0
346376

347-
def test_server_accessors_with_key(self, key) -> None:
377+
def test_server_accessors_with_key(self, key: ECC) -> None:
348378
der = key.generate()
349-
format = (
379+
format: Literal["NUM", "DER"] = (
350380
"NUM" if key.curve in (Curve.CURVE25519, Curve.CURVE448) else "DER"
351381
)
352382
assert der == key.export_key(format)
@@ -361,7 +391,7 @@ def test_server_accessors_with_key(self, key) -> None:
361391
assert peer.peers_public_key == 0
362392
assert peer.shared_secret == 0
363393

364-
def test_exchange(self, key) -> None:
394+
def test_exchange(self, key: ECC) -> None:
365395
srv, cli = ECDHServer(key), ECDHClient(key)
366396

367397
ske = srv.generate()
@@ -383,7 +413,7 @@ def test_exchange(self, key) -> None:
383413
assert srv_sec == cli_sec
384414
assert srv.shared_secret == cli.shared_secret
385415

386-
def test_generate_public(self, key) -> None:
416+
def test_generate_public(self, key: ECC) -> None:
387417
srv, cli = ECDHServer(key), ECDHClient(key)
388418

389419
srv.generate()
@@ -393,7 +423,7 @@ def test_generate_public(self, key) -> None:
393423
assert cli.public_key == srv.public_key
394424

395425

396-
def do_exchange(alice, bob) -> None:
426+
def do_exchange(alice: ECDHNaive, bob: ECDHNaive) -> None:
397427
alice_to_bob = alice.generate()
398428
bob_to_alice = bob.generate()
399429
alice.import_peers_public(bob_to_alice)
@@ -404,10 +434,11 @@ def do_exchange(alice, bob) -> None:
404434

405435
class TestECDHNaive:
406436
@pytest.fixture(params=[Curve.CURVE448, Curve.CURVE25519])
407-
def curve(self, request):
437+
def curve(self, request: Any) -> Curve:
438+
assert isinstance(request.param, Curve)
408439
return request.param
409440

410-
def test_key_accessors_without_key(self, curve) -> None:
441+
def test_key_accessors_without_key(self, curve: Curve) -> None:
411442
peer = ECDHNaive(curve)
412443

413444
assert not peer._has_private()
@@ -417,7 +448,7 @@ def test_key_accessors_without_key(self, curve) -> None:
417448
assert peer.public_key == 0
418449
assert peer.peers_public_key == 0
419450

420-
def test_exchange(self, curve) -> None:
451+
def test_exchange(self, curve: Curve) -> None:
421452
alice, bob = ECDHNaive(curve), ECDHNaive(curve)
422453

423454
alice_to_bob = alice.generate()
@@ -448,7 +479,7 @@ def test_exchange(self, curve) -> None:
448479
assert alice_secret == alice.shared_secret
449480
assert bob_secret == bob.shared_secret
450481

451-
def test_attacker_fails_with_public_keys(self, curve) -> None:
482+
def test_attacker_fails_with_public_keys(self, curve: Curve) -> None:
452483
alice, bob = ECDHNaive(curve), ECDHNaive(curve)
453484
do_exchange(alice, bob)
454485

@@ -458,7 +489,7 @@ def test_attacker_fails_with_public_keys(self, curve) -> None:
458489
with pytest.raises(TLSError):
459490
eve.generate_secret()
460491

461-
def test_attacker_succeeds_with_private_key(self, curve) -> None:
492+
def test_attacker_succeeds_with_private_key(self, curve: Curve) -> None:
462493
alice, bob = ECDHNaive(curve), ECDHNaive(curve)
463494
do_exchange(alice, bob)
464495

0 commit comments

Comments
 (0)