diff --git a/astrbot/core/star/updator.py b/astrbot/core/star/updator.py index cb462d8b2c..7d08ece4cb 100644 --- a/astrbot/core/star/updator.py +++ b/astrbot/core/star/updator.py @@ -1,6 +1,8 @@ import os import shutil +import tempfile import zipfile +from pathlib import Path, PurePosixPath from astrbot.core import logger from astrbot.core.utils.astrbot_path import get_astrbot_plugin_path @@ -9,6 +11,9 @@ from ..star.star import StarMetadata from ..updator import RepoZipUpdator +ARCHIVE_METADATA_ROOT_DIRS = {"__MACOSX"} +ARCHIVE_METADATA_FILE_NAMES = {".DS_Store"} + class PluginUpdator(RepoZipUpdator): def __init__(self, repo_mirror: str = "", verify: str | bool | None = None) -> None: @@ -71,29 +76,129 @@ async def update( return plugin_path def unzip_file(self, zip_path: str, target_dir: str) -> None: - ensure_dir(target_dir) - update_dir = "" + target_path = Path(target_dir) + ensure_dir(target_path) logger.info(f"Extracting archive: {zip_path}") - with zipfile.ZipFile(zip_path, "r") as z: - update_dir = z.namelist()[0] - z.extractall(target_dir) - - files = os.listdir(os.path.join(target_dir, update_dir)) - for f in files: - if os.path.isdir(os.path.join(target_dir, update_dir, f)): - if os.path.exists(os.path.join(target_dir, f)): - shutil.rmtree(os.path.join(target_dir, f), onerror=on_error) - elif os.path.exists(os.path.join(target_dir, f)): - os.remove(os.path.join(target_dir, f)) - shutil.move(os.path.join(target_dir, update_dir, f), target_dir) + staging_path = self._create_extract_temp_dir(target_path) try: - logger.info( - f"Removing temporary files: {zip_path} and {os.path.join(target_dir, update_dir)}", + archive_root_dir = None + with zipfile.ZipFile(zip_path, "r") as z: + members = [ + member + for member in z.infolist() + if not self._is_archive_metadata_member(member.filename) + ] + archive_root_dir = self._get_archive_root_dir(members) + for member in members: + z.extract(member, staging_path) + + source_path = ( + staging_path / archive_root_dir if archive_root_dir else staging_path ) - shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error) + self._move_extracted_children(source_path, target_path) + self._remove_update_files(zip_path, staging_path) + if not staging_path.exists(): + staging_path = None + finally: + if staging_path: + self._remove_staging_path_safely(staging_path) + + @staticmethod + def _create_extract_temp_dir(target_path: Path) -> Path: + return Path( + tempfile.mkdtemp( + prefix=f".{target_path.name}.", + suffix=".extract", + dir=target_path.parent, + ) + ) + + def _move_extracted_children(self, source_path: Path, target_path: Path) -> None: + for child in source_path.iterdir(): + destination = target_path / child.name + self._remove_existing_path(destination) + shutil.move(str(child), str(target_path)) + + @staticmethod + def _remove_update_files(zip_path: str, staging_path: Path) -> None: + try: + logger.info(f"Removing temporary files: {zip_path} and {staging_path}") + shutil.rmtree(staging_path, onerror=on_error) os.remove(zip_path) - except BaseException: + except Exception: logger.warning( - f"Failed to remove update files; you can manually delete {zip_path} and {os.path.join(target_dir, update_dir)}", + f"Failed to remove update files; you can manually delete {zip_path} " + f"and {staging_path}", ) + + @staticmethod + def _remove_staging_path_safely(staging_path: Path) -> None: + if not staging_path.exists(): + return + try: + shutil.rmtree(staging_path, onerror=on_error) + except Exception: + logger.warning( + f"Failed to remove temporary extract directory; " + f"you can manually delete {staging_path}", + ) + + @staticmethod + def _remove_existing_path(path: Path) -> None: + if path.is_dir() and not path.is_symlink(): + shutil.rmtree(path, onerror=on_error) + elif path.exists() or path.is_symlink(): + path.unlink() + + @staticmethod + def _get_archive_root_dir(members: list[zipfile.ZipInfo]) -> str | None: + root_dir = None + has_file = False + has_root_file = False + has_multiple_roots = False + for member in members: + parts = PluginUpdator._get_safe_member_parts(member.filename) + if not parts: + continue + if not member.is_dir(): + has_file = True + if len(parts) == 1 and not member.is_dir(): + has_root_file = True + continue + if root_dir is None: + root_dir = parts[0] + elif root_dir != parts[0]: + has_multiple_roots = True + if not has_file: + raise ValueError("Empty plugin archive") + if has_root_file or has_multiple_roots: + return None + return root_dir + + @staticmethod + def _is_archive_metadata_member(member_name: str) -> bool: + parts = PluginUpdator._get_safe_member_parts(member_name) + if not parts: + return False + return ( + parts[0] in ARCHIVE_METADATA_ROOT_DIRS + or parts[-1] in ARCHIVE_METADATA_FILE_NAMES + ) + + @staticmethod + def _get_safe_member_parts(member_name: str) -> tuple[str, ...]: + if not member_name: + return () + if "\\" in member_name: + raise ValueError(f"Unsafe path in zip archive: {member_name}") + + member_path = PurePosixPath(member_name) + parts = tuple(part for part in member_path.parts if part) + if ( + member_path.is_absolute() + or any(part in {".", ".."} for part in parts) + or any(":" in part for part in parts) + ): + raise ValueError(f"Unsafe path in zip archive: {member_name}") + return parts diff --git a/tests/test_updator_socks.py b/tests/test_updator_socks.py index 11009ff85c..db1fee7564 100644 --- a/tests/test_updator_socks.py +++ b/tests/test_updator_socks.py @@ -1,3 +1,4 @@ +import zipfile from dataclasses import dataclass, field from pathlib import Path from types import SimpleNamespace @@ -177,6 +178,160 @@ def fake_unzip_file(zip_path: str, target_dir: str): assert calls["unzip"] == (str(expected_path) + ".zip", str(expected_path)) +def test_plugin_updator_unzip_file_accepts_flat_plugin_archive(tmp_path: Path) -> None: + archive_path = tmp_path / "flat_plugin.zip" + target_path = tmp_path / "plugin_upload" + with zipfile.ZipFile(archive_path, "w") as archive: + archive.writestr("main.py", "print('loaded')\n") + archive.writestr("metadata.yaml", "name: flat_plugin\n") + archive.writestr("commands/__init__.py", "") + + PluginUpdator().unzip_file(str(archive_path), str(target_path)) + + assert (target_path / "main.py").read_text(encoding="utf-8") == "print('loaded')\n" + assert (target_path / "metadata.yaml").read_text(encoding="utf-8") == ( + "name: flat_plugin\n" + ) + assert (target_path / "commands" / "__init__.py").exists() + assert not archive_path.exists() + + +def test_plugin_updator_unzip_file_rejects_empty_archive(tmp_path: Path) -> None: + archive_path = tmp_path / "empty_plugin.zip" + target_path = tmp_path / "plugin_upload" + with zipfile.ZipFile(archive_path, "w"): + pass + + with pytest.raises(ValueError, match="Empty plugin archive"): + PluginUpdator().unzip_file(str(archive_path), str(target_path)) + + assert not any(target_path.iterdir()) + + +def test_plugin_updator_unzip_file_flattens_single_root_dir(tmp_path: Path) -> None: + archive_path = tmp_path / "rooted_plugin.zip" + target_path = tmp_path / "plugin_upload" + with zipfile.ZipFile(archive_path, "w") as archive: + archive.writestr("astrbot_plugin_demo-main/main.py", "print('loaded')\n") + archive.writestr("astrbot_plugin_demo-main/metadata.yaml", "name: demo\n") + archive.writestr("astrbot_plugin_demo-main/services/__init__.py", "") + + PluginUpdator().unzip_file(str(archive_path), str(target_path)) + + assert (target_path / "main.py").exists() + assert (target_path / "metadata.yaml").exists() + assert (target_path / "services" / "__init__.py").exists() + assert not (target_path / "astrbot_plugin_demo-main").exists() + assert not archive_path.exists() + + +def test_plugin_updator_unzip_file_ignores_macos_metadata_when_flattening( + tmp_path: Path, +) -> None: + archive_path = tmp_path / "rooted_plugin_with_macos_metadata.zip" + target_path = tmp_path / "plugin_upload" + with zipfile.ZipFile(archive_path, "w") as archive: + archive.writestr("astrbot_plugin_demo-main/main.py", "print('loaded')\n") + archive.writestr("astrbot_plugin_demo-main/metadata.yaml", "name: demo\n") + archive.writestr("astrbot_plugin_demo-main/.DS_Store", "") + archive.writestr("__MACOSX/._astrbot_plugin_demo-main", "") + + PluginUpdator().unzip_file(str(archive_path), str(target_path)) + + assert (target_path / "main.py").exists() + assert (target_path / "metadata.yaml").exists() + assert not (target_path / "astrbot_plugin_demo-main").exists() + assert not (target_path / "__MACOSX").exists() + assert not (target_path / ".DS_Store").exists() + assert not archive_path.exists() + + +def test_plugin_updator_unzip_file_keeps_multiple_root_entries( + tmp_path: Path, +) -> None: + archive_path = tmp_path / "multi_root.zip" + target_path = tmp_path / "plugin_upload" + with zipfile.ZipFile(archive_path, "w") as archive: + archive.writestr("plugin_a/main.py", "print('a')\n") + archive.writestr("plugin_b/main.py", "print('b')\n") + + PluginUpdator().unzip_file(str(archive_path), str(target_path)) + + assert (target_path / "plugin_a" / "main.py").exists() + assert (target_path / "plugin_b" / "main.py").exists() + assert not (target_path / "main.py").exists() + assert not archive_path.exists() + + +def test_plugin_updator_unzip_file_keeps_root_dir_with_extra_empty_root_dir( + tmp_path: Path, +) -> None: + archive_path = tmp_path / "rooted_plugin_with_empty_dir.zip" + target_path = tmp_path / "plugin_upload" + with zipfile.ZipFile(archive_path, "w") as archive: + archive.writestr("plugin/main.py", "print('loaded')\n") + archive.writestr("docs/", "") + + PluginUpdator().unzip_file(str(archive_path), str(target_path)) + + assert (target_path / "plugin" / "main.py").exists() + assert (target_path / "docs").is_dir() + assert not (target_path / "main.py").exists() + assert not archive_path.exists() + + +def test_plugin_updator_unzip_file_flattens_root_dir_with_same_named_child( + tmp_path: Path, +) -> None: + archive_path = tmp_path / "same_named_child.zip" + target_path = tmp_path / "plugin_upload" + with zipfile.ZipFile(archive_path, "w") as archive: + archive.writestr("my_plugin/main.py", "print('loaded')\n") + archive.writestr("my_plugin/my_plugin/__init__.py", "") + + PluginUpdator().unzip_file(str(archive_path), str(target_path)) + + assert (target_path / "main.py").exists() + assert (target_path / "my_plugin" / "__init__.py").exists() + assert not any( + path.name.startswith(".my_plugin.") and path.name.endswith(".tmp") + for path in target_path.iterdir() + ) + assert not archive_path.exists() + + +@pytest.mark.parametrize( + "member_name", + [ + "../escape.py", + "nested/../../escape.py", + "/absolute.py", + "C:/absolute.py", + "nested/colon:name.py", + ], +) +def test_plugin_updator_unzip_file_rejects_unsafe_member_paths( + tmp_path: Path, + member_name: str, +) -> None: + archive_path = tmp_path / "unsafe_plugin.zip" + target_path = tmp_path / "plugin_upload" + with zipfile.ZipFile(archive_path, "w") as archive: + archive.writestr("main.py", "print('safe')\n") + archive.writestr(member_name, "print('escape')\n") + + with pytest.raises(ValueError, match="Unsafe path in zip archive"): + PluginUpdator().unzip_file(str(archive_path), str(target_path)) + + assert not (target_path / "main.py").exists() + assert not (tmp_path / "escape.py").exists() + + +def test_plugin_updator_rejects_backslash_member_path() -> None: + with pytest.raises(ValueError, match="Unsafe path in zip archive"): + PluginUpdator._get_safe_member_parts(r"nested\windows.py") + + @pytest.mark.asyncio async def test_fetch_release_info_uses_httpx_client_with_env_proxy_support( monkeypatch: pytest.MonkeyPatch,