Skip to content

Commit 3e3b077

Browse files
committed
refactor(cargo_provider): cleanup and get rid of potential type errors
1 parent 9f3ec86 commit 3e3b077

File tree

2 files changed

+198
-45
lines changed

2 files changed

+198
-45
lines changed

commitizen/providers/cargo_provider.py

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,15 @@
33
import fnmatch
44
import glob
55
from pathlib import Path
6+
from typing import TYPE_CHECKING
67

7-
import tomlkit
8+
from tomlkit import TOMLDocument, dumps, parse
9+
from tomlkit.exceptions import NonExistentKey
10+
from tomlkit.items import AoT
811

912
from commitizen.providers.base_provider import TomlProvider
1013

1114

12-
def matches_exclude(path: str, exclude_patterns: list[str]) -> bool:
13-
for pattern in exclude_patterns:
14-
if fnmatch.fnmatch(path, pattern):
15-
return True
16-
return False
17-
18-
1915
class CargoProvider(TomlProvider):
2016
"""
2117
Cargo version management
@@ -30,65 +26,75 @@ class CargoProvider(TomlProvider):
3026
def lock_file(self) -> Path:
3127
return Path() / self.lock_filename
3228

33-
def get(self, document: tomlkit.TOMLDocument) -> str:
34-
# If there is a root package, change its version (but not the workspace version)
35-
try:
36-
return document["package"]["version"] # type: ignore[index,return-value]
37-
# Else, bump the workspace version
38-
except tomlkit.exceptions.NonExistentKey:
39-
...
40-
return document["workspace"]["package"]["version"] # type: ignore[index,return-value]
29+
def get(self, document: TOMLDocument) -> str:
30+
out = _try_get_workspace(document)["package"]["version"]
31+
assert isinstance(out, str)
32+
return out
4133

42-
def set(self, document: tomlkit.TOMLDocument, version: str) -> None:
43-
try:
44-
document["workspace"]["package"]["version"] = version # type: ignore[index]
45-
return
46-
except tomlkit.exceptions.NonExistentKey:
47-
...
48-
document["package"]["version"] = version # type: ignore[index]
34+
def set(self, document: TOMLDocument, version: str) -> None:
35+
_try_get_workspace(document)["package"]["version"] = version
4936

5037
def set_version(self, version: str) -> None:
5138
super().set_version(version)
5239
if self.lock_file.exists():
5340
self.set_lock_version(version)
5441

5542
def set_lock_version(self, version: str) -> None:
56-
cargo_toml_content = tomlkit.parse(self.file.read_text())
57-
cargo_lock_content = tomlkit.parse(self.lock_file.read_text())
58-
packages: tomlkit.items.AoT = cargo_lock_content["package"] # type: ignore[assignment]
43+
cargo_toml_content = parse(self.file.read_text())
44+
cargo_lock_content = parse(self.lock_file.read_text())
45+
packages = cargo_lock_content["package"]
46+
47+
assert isinstance(packages, AoT)
48+
5949
try:
60-
package_name = cargo_toml_content["package"]["name"] # type: ignore[index]
50+
cargo_package_name = cargo_toml_content["package"]["name"] # type: ignore[index]
51+
if TYPE_CHECKING:
52+
assert isinstance(cargo_package_name, str)
6153
for i, package in enumerate(packages):
62-
if package["name"] == package_name:
54+
if package["name"] == cargo_package_name:
6355
cargo_lock_content["package"][i]["version"] = version # type: ignore[index]
6456
break
65-
except tomlkit.exceptions.NonExistentKey:
66-
workspace_members = cargo_toml_content.get("workspace", {}).get(
67-
"members", []
68-
)
69-
excluded_workspace_members = cargo_toml_content.get("workspace", {}).get(
70-
"exclude", []
71-
)
72-
members_inheriting = []
57+
except NonExistentKey:
58+
workspace = cargo_toml_content.get("workspace", {})
59+
assert isinstance(workspace, dict)
60+
workspace_members = workspace.get("members", [])
61+
excluded_workspace_members = workspace.get("exclude", [])
62+
members_inheriting: list[str] = []
7363

7464
for member in workspace_members:
7565
for path in glob.glob(member, recursive=True):
76-
if matches_exclude(path, excluded_workspace_members):
66+
if any(
67+
fnmatch.fnmatch(path, pattern)
68+
for pattern in excluded_workspace_members
69+
):
7770
continue
71+
7872
cargo_file = Path(path) / "Cargo.toml"
79-
cargo_toml_content = tomlkit.parse(cargo_file.read_text())
73+
package_content = parse(cargo_file.read_text()).get("package", {})
74+
if TYPE_CHECKING:
75+
assert isinstance(package_content, dict)
8076
try:
81-
version_workspace = cargo_toml_content["package"]["version"][ # type: ignore[index]
82-
"workspace"
83-
]
77+
version_workspace = package_content["version"]["workspace"]
8478
if version_workspace is True:
85-
package_name = cargo_toml_content["package"]["name"] # type: ignore[index]
79+
package_name = package_content["name"]
80+
if TYPE_CHECKING:
81+
assert isinstance(package_name, str)
8682
members_inheriting.append(package_name)
87-
except tomlkit.exceptions.NonExistentKey:
88-
continue
83+
except NonExistentKey:
84+
pass
8985

9086
for i, package in enumerate(packages):
9187
if package["name"] in members_inheriting:
9288
cargo_lock_content["package"][i]["version"] = version # type: ignore[index]
9389

94-
self.lock_file.write_text(tomlkit.dumps(cargo_lock_content))
90+
self.lock_file.write_text(dumps(cargo_lock_content))
91+
92+
93+
def _try_get_workspace(document: TOMLDocument) -> dict:
94+
try:
95+
workspace = document["workspace"]
96+
if TYPE_CHECKING:
97+
assert isinstance(workspace, dict)
98+
return workspace
99+
except NonExistentKey:
100+
return document

tests/providers/test_cargo_provider.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,150 @@ def test_cargo_provider_with_lock(
301301
provider.set_version("42.1")
302302
assert file.read_text() == dedent(toml_expected)
303303
assert lock_file.read_text() == dedent(lock_expected)
304+
305+
306+
def test_cargo_provider_workspace_member_without_version_key(
307+
config: BaseConfig,
308+
chdir: Path,
309+
):
310+
"""Test workspace member that has no version key at all (should not crash)."""
311+
workspace_toml = """\
312+
[workspace]
313+
members = ["member_without_version"]
314+
315+
[workspace.package]
316+
version = "0.1.0"
317+
"""
318+
319+
# Create a member that has no version key at all
320+
member_content = """\
321+
[package]
322+
name = "member_without_version"
323+
# No version key - this should trigger NonExistentKey exception
324+
"""
325+
326+
lock_content = """\
327+
[[package]]
328+
name = "member_without_version"
329+
version = "0.1.0"
330+
source = "registry+https://github.com/rust-lang/crates.io-index"
331+
checksum = "123abc"
332+
"""
333+
334+
expected_workspace_toml = """\
335+
[workspace]
336+
members = ["member_without_version"]
337+
338+
[workspace.package]
339+
version = "42.1"
340+
"""
341+
342+
expected_lock_content = """\
343+
[[package]]
344+
name = "member_without_version"
345+
version = "0.1.0"
346+
source = "registry+https://github.com/rust-lang/crates.io-index"
347+
checksum = "123abc"
348+
"""
349+
350+
# Create the workspace file
351+
filename = CargoProvider.filename
352+
file = chdir / filename
353+
file.write_text(dedent(workspace_toml))
354+
355+
# Create the member directory and file
356+
os.mkdir(chdir / "member_without_version")
357+
member_file = chdir / "member_without_version" / "Cargo.toml"
358+
member_file.write_text(dedent(member_content))
359+
360+
# Create the lock file
361+
lock_filename = CargoProvider.lock_filename
362+
lock_file = chdir / lock_filename
363+
lock_file.write_text(dedent(lock_content))
364+
365+
config.settings["version_provider"] = "cargo"
366+
367+
provider = get_provider(config)
368+
assert isinstance(provider, CargoProvider)
369+
assert provider.get_version() == "0.1.0"
370+
371+
# This should not crash even though the member has no version key
372+
provider.set_version("42.1")
373+
assert file.read_text() == dedent(expected_workspace_toml)
374+
# The lock file should remain unchanged since the member doesn't inherit workspace version
375+
assert lock_file.read_text() == dedent(expected_lock_content)
376+
377+
378+
def test_cargo_provider_workspace_member_without_workspace_key(
379+
config: BaseConfig,
380+
chdir: Path,
381+
):
382+
"""Test workspace member that has version key but no workspace subkey."""
383+
workspace_toml = """\
384+
[workspace]
385+
members = ["member_without_workspace"]
386+
387+
[workspace.package]
388+
version = "0.1.0"
389+
"""
390+
391+
# Create a member that has version as a table but no workspace subkey
392+
# This should trigger NonExistentKey when trying to access version["workspace"]
393+
member_content = """\
394+
[package]
395+
name = "member_without_workspace"
396+
397+
[package.version]
398+
# Has version table but no workspace key - should trigger NonExistentKey
399+
"""
400+
401+
lock_content = """\
402+
[[package]]
403+
name = "member_without_workspace"
404+
version = "0.1.0"
405+
source = "registry+https://github.com/rust-lang/crates.io-index"
406+
checksum = "123abc"
407+
"""
408+
409+
expected_workspace_toml = """\
410+
[workspace]
411+
members = ["member_without_workspace"]
412+
413+
[workspace.package]
414+
version = "42.1"
415+
"""
416+
417+
expected_lock_content = """\
418+
[[package]]
419+
name = "member_without_workspace"
420+
version = "0.1.0"
421+
source = "registry+https://github.com/rust-lang/crates.io-index"
422+
checksum = "123abc"
423+
"""
424+
425+
# Create the workspace file
426+
filename = CargoProvider.filename
427+
file = chdir / filename
428+
file.write_text(dedent(workspace_toml))
429+
430+
# Create the member directory and file
431+
os.mkdir(chdir / "member_without_workspace")
432+
member_file = chdir / "member_without_workspace" / "Cargo.toml"
433+
member_file.write_text(dedent(member_content))
434+
435+
# Create the lock file
436+
lock_filename = CargoProvider.lock_filename
437+
lock_file = chdir / lock_filename
438+
lock_file.write_text(dedent(lock_content))
439+
440+
config.settings["version_provider"] = "cargo"
441+
442+
provider = get_provider(config)
443+
assert isinstance(provider, CargoProvider)
444+
assert provider.get_version() == "0.1.0"
445+
446+
# This should not crash even though the member has no version.workspace key
447+
provider.set_version("42.1")
448+
assert file.read_text() == dedent(expected_workspace_toml)
449+
# The lock file should remain unchanged since the member doesn't inherit workspace version
450+
assert lock_file.read_text() == dedent(expected_lock_content)

0 commit comments

Comments
 (0)