def _load_platform_stats_invalid_count_warn_limit() -> int:
    """Read the invalid-count warning cap from the environment.

    Returns:
        The non-negative integer stored in the environment variable named by
        ``PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV``. When the variable is
        unset, the default is returned silently. When it is set but is not an
        integer, or is negative, a warning is logged and the default is
        returned.
    """
    env_name = PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV
    fallback = DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT

    configured = os.getenv(env_name)
    if configured is None:
        return fallback

    try:
        parsed = int(configured)
    except (TypeError, ValueError):
        parsed = None

    # Zero is a legal setting; only unparseable or negative values are rejected.
    if parsed is not None and parsed >= 0:
        return parsed

    logger.warning(
        "Invalid env %s=%r, fallback to default %d",
        env_name,
        configured,
        fallback,
    )
    return fallback
def _preprocess_main_table_rows(
    self, table_name: str, rows: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Apply table-specific normalization to rows before they are imported.

    Only ``platform_stats`` is preprocessed: its rows are passed through
    ``_merge_platform_stats_rows`` and, if that shortened the list, a warning
    reports how many rows were folded away. Rows of every other table are
    returned unchanged (the very same list object).

    Args:
        table_name: Name of the table the rows belong to.
        rows: Row dictionaries for that table.

    Returns:
        The rows that should actually be inserted.
    """
    if table_name != "platform_stats":
        return rows

    merged_rows = self._merge_platform_stats_rows(rows)
    collapsed = len(rows) - len(merged_rows)
    if collapsed > 0:
        logger.warning(
            "检测到 %s 重复键 %d 条,已在导入前聚合",
            table_name,
            collapsed,
        )
    return merged_rows
def _merge_platform_stats_rows(
    self, rows: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """Collapse platform_stats rows that share a (timestamp, platform_id, platform_type) key.

    Each row is first normalized by ``_normalize_platform_stats_entry``. Rows
    whose normalized timestamp and string platform fields match an earlier row
    have their ``count`` added to that earlier row and are dropped; the
    surviving row keeps the position of its first occurrence.

    Rows are never merged when their timestamp could not be normalized or when
    ``platform_id`` / ``platform_type`` is not a string, so malformed data is
    not accidentally combined. Invalid-count warnings are rate-limited by a
    limiter created fresh for every call.

    Args:
        rows: Raw platform_stats row dictionaries.

    Returns:
        Normalized rows, in input order, with duplicates aggregated.
    """
    first_by_key: dict[tuple[str, str, str], dict[str, Any]] = {}
    output: list[dict[str, Any]] = []
    limiter = _InvalidCountWarnLimiter(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT)

    for raw in rows:
        entry, utc_timestamp, amount = self._normalize_platform_stats_entry(
            raw, limiter
        )
        platform_id = entry.get("platform_id")
        platform_type = entry.get("platform_type")

        can_merge = (
            utc_timestamp is not None
            and isinstance(platform_id, str)
            and isinstance(platform_type, str)
        )
        if can_merge:
            key = (utc_timestamp, platform_id, platform_type)
            if key in first_by_key:
                # Duplicate key: fold into the first occurrence and drop this row.
                first_by_key[key]["count"] += amount
                continue
            first_by_key[key] = entry

        output.append(entry)

    return output
key_for_log = ( + normalized_row.get("timestamp"), + repr(normalized_row.get("platform_id")), + repr(normalized_row.get("platform_type")), + ) + warn_limiter.warn_invalid_count(raw_count, key_for_log) + count = 0 + + normalized_row["count"] = count + return normalized_row, normalized_timestamp, count + + def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: + if isinstance(value, datetime): + dt = value + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + else: + dt = dt.astimezone(timezone.utc) + return dt.isoformat() + if isinstance(value, str): + timestamp = value.strip() + if not timestamp: + return None + if timestamp.endswith("Z"): + timestamp = f"{timestamp[:-1]}+00:00" + try: + dt = datetime.fromisoformat(timestamp) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + else: + dt = dt.astimezone(timezone.utc) + return dt.isoformat() + except ValueError: + return None + return None + async def _import_knowledge_bases( self, zf: zipfile.ZipFile, diff --git a/tests/test_backup.py b/tests/test_backup.py index 91db470098..cf3c4d9494 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -5,7 +5,7 @@ import re import zipfile from datetime import datetime -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -17,6 +17,8 @@ ) from astrbot.core.backup.exporter import AstrBotExporter from astrbot.core.backup.importer import ( + DatabaseClearError, + PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, AstrBotImporter, ImportResult, _get_major_version, @@ -308,6 +310,298 @@ def test_convert_datetime_fields(self): assert isinstance(result["created_at"], datetime) assert isinstance(result["updated_at"], datetime) + def test_merge_platform_stats_rows(self): + """测试 platform_stats 重复键会在导入前聚合""" + importer = AstrBotImporter(main_db=MagicMock()) + rows = [ + { + "id": 1, + "timestamp": "2025-12-13T20:00:00Z", + "platform_id": "webchat", + "platform_type": "unknown", + 
def test_merge_platform_stats_rows_normalizes_naive_timestamp_to_utc(self):
    """Naive timestamps (string or datetime) gain an explicit UTC offset before merging."""
    naive_inputs = (
        ("webchat", "2025-12-13T21:00:00"),
        ("telegram", datetime(2025, 12, 13, 22, 0, 0)),
    )
    rows = [
        {
            "timestamp": naive_timestamp,
            "platform_id": platform,
            "platform_type": "unknown",
            "count": 1,
        }
        for platform, naive_timestamp in naive_inputs
    ]

    importer = AstrBotImporter(main_db=MagicMock())
    merged_rows = importer._merge_platform_stats_rows(rows)

    # Distinct platforms, so nothing is merged; only the timestamps change.
    assert len(merged_rows) == 2
    timestamp_of = {row["platform_id"]: row["timestamp"] for row in merged_rows}
    assert timestamp_of["webchat"] == "2025-12-13T21:00:00+00:00"
    assert timestamp_of["telegram"] == "2025-12-13T22:00:00+00:00"
def test_merge_platform_stats_rows_keeps_invalid_timestamps_distinct(self):
    """Rows with an empty or unparseable timestamp are never merged, even when identical."""
    unparseable = [("", 2), ("not-a-datetime", 3), ("not-a-datetime", 4)]
    rows = [
        {
            "timestamp": bad_timestamp,
            "platform_id": "webchat",
            "platform_type": "unknown",
            "count": amount,
        }
        for bad_timestamp, amount in unparseable
    ]

    importer = AstrBotImporter(main_db=MagicMock())
    merged_rows = importer._merge_platform_stats_rows(rows)

    # No row was folded into another, and order plus counts are untouched.
    assert len(rows) - len(merged_rows) == 0
    assert len(merged_rows) == 3
    assert [row["count"] for row in merged_rows] == [2, 3, 4]
@pytest.mark.asyncio
async def test_import_replace_fails_when_clear_main_db_fails(
    self, mock_main_db, tmp_path
):
    """In replace mode, a failure while clearing the main DB aborts the import."""
    archive_members = {
        "manifest.json": {
            "version": "1.1",
            "astrbot_version": VERSION,
            "tables": {"platform_stats": 0},
        },
        "databases/main_db.json": {"platform_stats": []},
    }
    zip_path = tmp_path / "valid_backup.zip"
    with zipfile.ZipFile(zip_path, "w") as zf:
        for member_name, payload in archive_members.items():
            zf.writestr(member_name, json.dumps(payload))

    failing_clear = AsyncMock(
        side_effect=DatabaseClearError("清空表 platform_stats 失败: db locked")
    )
    import_spy = AsyncMock(return_value={})
    importer = AstrBotImporter(main_db=mock_main_db)
    importer._clear_main_db = failing_clear
    importer._import_main_database = import_spy

    result = await importer.import_all(str(zip_path), mode="replace")

    assert result.success is False
    assert any("清空主数据库失败" in err for err in result.errors)
    assert any("清空表 platform_stats 失败" in err for err in result.errors)
    # The import step must not run once clearing has failed.
    importer._import_main_database.assert_not_awaited()