From fc2af1f615f76068d1955dc4e34a5aa1526fbc00 Mon Sep 17 00:00:00 2001 From: Clhikari Date: Sun, 1 Mar 2026 16:50:55 +0800 Subject: [PATCH 01/16] fix: harden backup import for duplicate platform stats MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复 replace 模式下主库清空失败仍继续导入的问题。 - 导入前对 platform_stats 重复键做聚合(count 累加),并统一时间戳判重格式。 - 非法 count 按 0 处理并告警(限流),补充对应测试。 --- astrbot/core/backup/importer.py | 95 +++++++++++++++++++++++++++++++-- tests/test_backup.py | 91 ++++++++++++++++++++++++++++++- 2 files changed, 182 insertions(+), 4 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 2e67f85e5c..6102dbf92a 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -12,7 +12,7 @@ import shutil import zipfile from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path from typing import TYPE_CHECKING, Any @@ -452,7 +452,7 @@ async def _clear_main_db(self) -> None: await session.execute(delete(model_class)) logger.debug(f"已清空表 {table_name}") except Exception as e: - logger.warning(f"清空表 {table_name} 失败: {e}") + raise RuntimeError(f"清空表 {table_name} 失败: {e}") from e async def _clear_kb_data(self) -> None: """清空知识库数据""" @@ -494,9 +494,18 @@ async def _import_main_database( if not model_class: logger.warning(f"未知的表: {table_name}") continue + normalized_rows = rows + if table_name == "platform_stats": + normalized_rows, duplicate_count = ( + self._merge_platform_stats_rows(rows) + ) + if duplicate_count > 0: + logger.warning( + f"检测到 platform_stats 重复键 {duplicate_count} 条,已在导入前聚合" + ) count = 0 - for row in rows: + for row in normalized_rows: try: # 转换 datetime 字符串为 datetime 对象 row = self._convert_datetime_fields(row, model_class) @@ -511,6 +520,86 @@ async def _import_main_database( return imported + def _merge_platform_stats_rows( + self, rows: list[dict[str, Any]] + ) -> tuple[list[dict[str, Any]], int]: + merged: dict[tuple[str, str, str], dict[str, Any]] = {} + timestamp_cache: dict[str, str] = {} + invalid_count_warned = 0 + invalid_count_warn_limit = 5 + duplicate_count = 0 + for row in rows: + raw_timestamp = row.get("timestamp") + if isinstance(raw_timestamp, str): + normalized_timestamp = timestamp_cache.get(raw_timestamp) + if normalized_timestamp is None: + normalized_timestamp = self._normalize_platform_stats_timestamp( + raw_timestamp + ) + timestamp_cache[raw_timestamp] = normalized_timestamp + else: + normalized_timestamp = self._normalize_platform_stats_timestamp( + raw_timestamp + ) + key = ( + normalized_timestamp, + str(row.get("platform_id")), + str(row.get("platform_type")), + ) + existing = merged.get(key) + if existing is None: + merged[key] = dict(row) + continue + duplicate_count += 1 + existing_raw_count = existing.get("count", 0) + try: + existing_count = int(existing_raw_count) + except (TypeError, ValueError): + existing_count = 0 + if invalid_count_warned < invalid_count_warn_limit: + logger.warning( + "platform_stats count 非法,已按 0 处理: " + f"value={existing_raw_count!r}, key={key}" + ) + invalid_count_warned += 1 + + incoming_raw_count = row.get("count", 0) + try: + incoming_count = int(incoming_raw_count) + except (TypeError, ValueError): + incoming_count = 0 + if invalid_count_warned < invalid_count_warn_limit: + logger.warning( + "platform_stats count 非法,已按 0 处理: " + f"value={incoming_raw_count!r}, key={key}" + ) + invalid_count_warned += 1 + existing["count"] = existing_count + incoming_count + return list(merged.values()), duplicate_count + + def _normalize_platform_stats_timestamp(self, value: Any) -> str: + if isinstance(value, datetime): + dt = value + if dt.tzinfo is not None: + dt = dt.astimezone(timezone.utc) + return dt.isoformat() + if isinstance(value, str): + timestamp = value.strip() + if not timestamp: + return "" + if timestamp.endswith("Z"): + timestamp = f"{timestamp[:-1]}+00:00" + try: + dt = datetime.fromisoformat(timestamp) + if dt.tzinfo is not None: + dt = dt.astimezone(timezone.utc) + return dt.isoformat() + except ValueError: + return value.strip() + if value is None: + return "" + return str(value) + async def _import_knowledge_bases( self, zf: zipfile.ZipFile, diff --git a/tests/test_backup.py b/tests/test_backup.py index 91db470098..ec77af8514 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -5,7 +5,7 @@ import re import zipfile from datetime import datetime -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -308,6 +308,69 @@ def test_convert_datetime_fields(self): assert isinstance(result["created_at"], datetime) assert isinstance(result["updated_at"], datetime) + def test_merge_platform_stats_rows(self): + """测试 platform_stats 重复键会在导入前聚合""" + importer = AstrBotImporter(main_db=MagicMock()) + rows = [ + { + "id": 1, + "timestamp": "2025-12-13T20:00:00Z", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 14, + }, + { + "id": 80, + "timestamp": "2025-12-13T20:00:00+00:00", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 3, + }, + { + "id": 2, + "timestamp": "2025-12-13T21:00:00", + "platform_id": "aiocqhttp", + "platform_type": "unknown", + "count": 1, + }, + ] + + merged_rows, duplicate_count = importer._merge_platform_stats_rows(rows) + + assert duplicate_count == 1 + assert len(merged_rows) == 2 + first = merged_rows[0] + assert first["timestamp"] == "2025-12-13T20:00:00Z" + assert first["platform_id"] == "webchat" + assert first["platform_type"] == "unknown" + assert first["count"] == 17 + + def test_merge_platform_stats_rows_warns_on_invalid_count(self): + """测试 platform_stats count 非法时会告警并按 0 处理""" + importer = AstrBotImporter(main_db=MagicMock()) + rows = [ + { + "timestamp": "2025-12-13T20:00:00+00:00", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 5, + }, + { + "timestamp": "2025-12-13T20:00:00Z", + "platform_id": "webchat", + "platform_type": "unknown", + "count": "bad-count", + }, + ] + + with patch("astrbot.core.backup.importer.logger.warning") as warning_mock: + merged_rows, duplicate_count = importer._merge_platform_stats_rows(rows) + + assert duplicate_count == 1 + assert len(merged_rows) == 1 + assert merged_rows[0]["count"] == 5 + assert warning_mock.called + @pytest.mark.asyncio async def test_import_file_not_exists(self, mock_main_db, tmp_path): """测试导入不存在的文件""" @@ -365,6 +428,32 @@ async def test_import_major_version_mismatch(self, mock_main_db, tmp_path): assert result.success is False assert any("主版本不兼容" in err for err in result.errors) + @pytest.mark.asyncio + async def test_import_replace_fails_when_clear_main_db_fails( + self, mock_main_db, tmp_path + ): + """测试 replace 模式下主库清空失败会直接终止导入""" + zip_path = tmp_path / "valid_backup.zip" + manifest = { + "version": "1.1", + "astrbot_version": VERSION, + "tables": {"platform_stats": 0}, + } + main_data = {"platform_stats": []} + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("manifest.json", json.dumps(manifest)) + zf.writestr("databases/main_db.json", json.dumps(main_data)) + + importer = AstrBotImporter(main_db=mock_main_db) + importer._clear_main_db = AsyncMock( + side_effect=RuntimeError("清空表 platform_stats 失败: db locked") + ) + + result = await importer.import_all(str(zip_path), mode="replace") + + assert result.success is False + assert any("清空表 platform_stats 失败" in err for err in result.errors) + class TestSecureFilename: """安全文件名函数测试""" From f71990adae08df5d007f4671f2886407962045c3 Mon Sep 17 00:00:00 2001 From: Clhikari Date: Sun, 1 Mar 2026 18:03:32 +0800 Subject: [PATCH 02/16] refactor: improve robustness and readability of platform stats import MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 告警上限魔法数字提取为模块常量 PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT - 抽取 parse_count 内联函数,消除重复的 try/except 分支 - 存储行的 timestamp 同步写入规范化值,避免落库格式混用 - 补充测试:已有行 count 非法、告警限流、replace 模式中断断言 --- astrbot/core/backup/importer.py | 45 ++++++------- tests/test_backup.py | 109 ++++++++++++++++++++++++-------- 2 files changed, 103 insertions(+), 51 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 6102dbf92a..0f8291f49d 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -61,6 +61,7 @@ def _get_major_version(version_str: str) -> str: CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json") KB_PATH = get_astrbot_knowledge_base_path() +PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5 @dataclass @@ -526,8 +527,21 @@ def _merge_platform_stats_rows( merged: dict[tuple[str, str, str], dict[str, Any]] = {} timestamp_cache: dict[str, str] = {} invalid_count_warned = 0 - invalid_count_warn_limit = 5 duplicate_count = 0 + + def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int: + nonlocal invalid_count_warned + try: + return int(raw_count) + except (TypeError, ValueError): + if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: + logger.warning( + "platform_stats count 非法,已按 0 处理: " + f"value={raw_count!r}, key={key}" + ) + invalid_count_warned += 1 + return 0 + for row in rows: raw_timestamp = row.get("timestamp") if isinstance(raw_timestamp, str): @@ -548,32 +562,13 @@ def _merge_platform_stats_rows( ) existing = merged.get(key) if existing is None: - merged[key] = dict(row) + normalized_row = dict(row) + normalized_row["timestamp"] = normalized_timestamp + merged[key] = normalized_row continue duplicate_count += 1 - existing_raw_count = existing.get("count", 0) - try: - existing_count = int(existing_raw_count) - except (TypeError, ValueError): - existing_count = 0 - if invalid_count_warned < invalid_count_warn_limit: - logger.warning( - "platform_stats count 非法,已按 0 处理: " - f"value={existing_raw_count!r}, key={key}" - ) - invalid_count_warned += 1 - - incoming_raw_count = row.get("count", 0) - try: - incoming_count = int(incoming_raw_count) - except (TypeError, ValueError): - incoming_count = 0 - if invalid_count_warned < invalid_count_warn_limit: - logger.warning( - "platform_stats count 非法,已按 0 处理: " - f"value={incoming_raw_count!r}, key={key}" - ) - invalid_count_warned += 1 + existing_count = parse_count(existing.get("count", 0), key) + incoming_count = parse_count(row.get("count", 0), key) existing["count"] = existing_count + incoming_count return list(merged.values()), duplicate_count diff --git a/tests/test_backup.py b/tests/test_backup.py index ec77af8514..4ee6641c8c 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -17,6 +17,7 @@ ) from astrbot.core.backup.exporter import AstrBotExporter from astrbot.core.backup.importer import ( + PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, AstrBotImporter, ImportResult, _get_major_version, @@ -339,37 +340,91 @@ def test_merge_platform_stats_rows(self): assert duplicate_count == 1 assert len(merged_rows) == 2 - first = merged_rows[0] - assert first["timestamp"] == "2025-12-13T20:00:00Z" - assert first["platform_id"] == "webchat" - assert first["platform_type"] == "unknown" - assert first["count"] == 17 + webchat_row = next( + ( + r + for r in merged_rows + if r.get("timestamp") == "2025-12-13T20:00:00+00:00" + and r.get("platform_id") == "webchat" + and r.get("platform_type") == "unknown" + ), + None, + ) + assert webchat_row is not None + assert webchat_row["timestamp"] == "2025-12-13T20:00:00+00:00" + assert webchat_row["platform_id"] == "webchat" + assert webchat_row["platform_type"] == "unknown" + assert webchat_row["count"] == 17 def test_merge_platform_stats_rows_warns_on_invalid_count(self): - """测试 platform_stats count 非法时会告警并按 0 处理""" + """测试 platform_stats count 非法时会告警并按 0 处理(含上限)""" importer = AstrBotImporter(main_db=MagicMock()) - rows = [ - { - "timestamp": "2025-12-13T20:00:00+00:00", - "platform_id": "webchat", - "platform_type": "unknown", - "count": 5, - }, - { - "timestamp": "2025-12-13T20:00:00Z", - "platform_id": "webchat", - "platform_type": "unknown", - "count": "bad-count", - }, - ] - with patch("astrbot.core.backup.importer.logger.warning") as warning_mock: + rows = [ + { + "timestamp": "2025-12-13T20:00:00+00:00", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 5, + }, + { + "timestamp": "2025-12-13T20:00:00Z", + "platform_id": "webchat", + "platform_type": "unknown", + "count": "bad-count", + }, + ] merged_rows, duplicate_count = importer._merge_platform_stats_rows(rows) - - assert duplicate_count == 1 - assert len(merged_rows) == 1 - assert merged_rows[0]["count"] == 5 - assert warning_mock.called + assert duplicate_count == 1 + assert len(merged_rows) == 1 + assert merged_rows[0]["count"] == 5 + assert warning_mock.call_count == 1 + + warning_mock.reset_mock() + + rows_existing_invalid = [ + { + "timestamp": "2025-12-13T21:00:00+00:00", + "platform_id": "webchat", + "platform_type": "unknown", + "count": "bad-count", + }, + { + "timestamp": "2025-12-13T21:00:00Z", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 7, + }, + ] + merged_rows, duplicate_count = importer._merge_platform_stats_rows( + rows_existing_invalid + ) + assert duplicate_count == 1 + assert len(merged_rows) == 1 + assert merged_rows[0]["count"] == 7 + assert warning_mock.call_count == 1 + + warning_mock.reset_mock() + + many_invalid_rows = [ + { + "timestamp": "2025-12-13T22:00:00+00:00", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 1, + }, + *[ + { + "timestamp": "2025-12-13T22:00:00Z", + "platform_id": "webchat", + "platform_type": "unknown", + "count": "bad-count", + } + for _ in range(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + 5) + ], + ] + importer._merge_platform_stats_rows(many_invalid_rows) + assert warning_mock.call_count == PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT @pytest.mark.asyncio async def test_import_file_not_exists(self, mock_main_db, tmp_path): @@ -448,11 +503,13 @@ async def test_import_replace_fails_when_clear_main_db_fails( importer._clear_main_db = AsyncMock( side_effect=RuntimeError("清空表 platform_stats 失败: db locked") ) + importer._import_main_database = AsyncMock(return_value={}) result = await importer.import_all(str(zip_path), mode="replace") assert result.success is False assert any("清空表 platform_stats 失败" in err for err in result.errors) + importer._import_main_database.assert_not_awaited() class TestSecureFilename: From a50727da185c7495d179ce7171ec7f528c887b7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 19:25:29 +0900 Subject: [PATCH 03/16] fix: normalize invalid platform_stats count for non-duplicate rows --- astrbot/core/backup/importer.py | 3 +++ tests/test_backup.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 0f8291f49d..ec5ab0f37f 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -564,6 +564,9 @@ def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int: if existing is None: normalized_row = dict(row) normalized_row["timestamp"] = normalized_timestamp + normalized_row["count"] = parse_count( + normalized_row.get("count", 0), key + ) merged[key] = normalized_row continue duplicate_count += 1 diff --git a/tests/test_backup.py b/tests/test_backup.py index 4ee6641c8c..9dd23fb8fd 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -426,6 +426,24 @@ def test_merge_platform_stats_rows_warns_on_invalid_count(self): importer._merge_platform_stats_rows(many_invalid_rows) assert warning_mock.call_count == PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + warning_mock.reset_mock() + + single_invalid_row = [ + { + "timestamp": "2025-12-13T23:00:00+00:00", + "platform_id": "telegram", + "platform_type": "unknown", + "count": "still-bad", + }, + ] + merged_rows, duplicate_count = importer._merge_platform_stats_rows( + single_invalid_row + ) + assert duplicate_count == 0 + assert len(merged_rows) == 1 + assert merged_rows[0]["count"] == 0 + assert warning_mock.call_count == 1 + @pytest.mark.asyncio async def test_import_file_not_exists(self, mock_main_db, tmp_path): """测试导入不存在的文件""" From f091b85af64423cc551dab3a43d5340dac4dc48e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 19:31:08 +0900 Subject: [PATCH 04/16] fix: avoid merging invalid platform_stats timestamps --- astrbot/core/backup/importer.py | 43 +++++++++++++++++++++------------ tests/test_backup.py | 30 +++++++++++++++++++++++ 2 files changed, 58 insertions(+), 15 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index ec5ab0f37f..aed608a98b 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -61,6 +61,7 @@ def _get_major_version(version_str: str) -> str: CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json") KB_PATH = get_astrbot_knowledge_base_path() +# Warning limit per _merge_platform_stats_rows invocation. PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5 @@ -524,8 +525,14 @@ async def _import_main_database( def _merge_platform_stats_rows( self, rows: list[dict[str, Any]] ) -> tuple[list[dict[str, Any]], int]: + """Merge duplicate platform_stats rows by normalized timestamp/platform key. + + Note: + - Invalid/empty timestamps are kept as distinct rows to avoid accidental merging. + - Invalid count warnings are rate-limited per function invocation. + """ merged: dict[tuple[str, str, str], dict[str, Any]] = {} - timestamp_cache: dict[str, str] = {} + timestamp_cache: dict[str, tuple[str, bool]] = {} invalid_count_warned = 0 duplicate_count = 0 @@ -542,21 +549,27 @@ def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int: invalid_count_warned += 1 return 0 - for row in rows: + for row_index, row in enumerate(rows): raw_timestamp = row.get("timestamp") if isinstance(raw_timestamp, str): - normalized_timestamp = timestamp_cache.get(raw_timestamp) - if normalized_timestamp is None: - normalized_timestamp = self._normalize_platform_stats_timestamp( + timestamp_result = timestamp_cache.get(raw_timestamp) + if timestamp_result is None: + timestamp_result = self._normalize_platform_stats_timestamp( raw_timestamp ) - timestamp_cache[raw_timestamp] = normalized_timestamp + timestamp_cache[raw_timestamp] = timestamp_result else: - normalized_timestamp = self._normalize_platform_stats_timestamp( + timestamp_result = self._normalize_platform_stats_timestamp( raw_timestamp ) + normalized_timestamp, is_timestamp_valid = timestamp_result + timestamp_for_key = ( + normalized_timestamp + if is_timestamp_valid + else f"__invalid_timestamp_row_{row_index}" + ) key = ( - normalized_timestamp, + timestamp_for_key, str(row.get("platform_id")), str(row.get("platform_type")), ) @@ -575,28 +588,28 @@ def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int: existing["count"] = existing_count + incoming_count return list(merged.values()), duplicate_count - def _normalize_platform_stats_timestamp(self, value: Any) -> str: + def _normalize_platform_stats_timestamp(self, value: Any) -> tuple[str, bool]: if isinstance(value, datetime): dt = value if dt.tzinfo is not None: dt = dt.astimezone(timezone.utc) - return dt.isoformat() + return dt.isoformat(), True if isinstance(value, str): timestamp = value.strip() if not timestamp: - return "" + return "", False if timestamp.endswith("Z"): timestamp = f"{timestamp[:-1]}+00:00" try: dt = datetime.fromisoformat(timestamp) if dt.tzinfo is not None: dt = dt.astimezone(timezone.utc) - return dt.isoformat() + return dt.isoformat(), True except ValueError: - return value.strip() + return value.strip(), False if value is None: - return "" - return str(value) + return "", False + return str(value), False async def _import_knowledge_bases( self, diff --git a/tests/test_backup.py b/tests/test_backup.py index 9dd23fb8fd..7bf2c16d4e 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -444,6 +444,36 @@ def test_merge_platform_stats_rows_warns_on_invalid_count(self): assert merged_rows[0]["count"] == 0 assert warning_mock.call_count == 1 + def test_merge_platform_stats_rows_keeps_invalid_timestamps_distinct(self): + """测试空/非法 timestamp 不参与聚合,避免误合并""" + importer = AstrBotImporter(main_db=MagicMock()) + rows = [ + { + "timestamp": "", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 2, + }, + { + "timestamp": "not-a-datetime", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 3, + }, + { + "timestamp": "not-a-datetime", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 4, + }, + ] + + merged_rows, duplicate_count = importer._merge_platform_stats_rows(rows) + + assert duplicate_count == 0 + assert len(merged_rows) == 3 + assert [row["count"] for row in merged_rows] == [2, 3, 4] + @pytest.mark.asyncio async def test_import_file_not_exists(self, mock_main_db, tmp_path): """测试导入不存在的文件""" From 7a590f2f1381f821141496e41098f49b7511c8df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 19:36:04 +0900 Subject: [PATCH 05/16] refactor: simplify platform stats merge and normalize naive UTC --- astrbot/core/backup/importer.py | 100 +++++++++++++++++++++----------- tests/test_backup.py | 39 ++++++++++++- 2 files changed, 102 insertions(+), 37 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index aed608a98b..2dc98be1fd 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -536,37 +536,13 @@ def _merge_platform_stats_rows( invalid_count_warned = 0 duplicate_count = 0 - def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int: - nonlocal invalid_count_warned - try: - return int(raw_count) - except (TypeError, ValueError): - if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: - logger.warning( - "platform_stats count 非法,已按 0 处理: " - f"value={raw_count!r}, key={key}" - ) - invalid_count_warned += 1 - return 0 - for row_index, row in enumerate(rows): - raw_timestamp = row.get("timestamp") - if isinstance(raw_timestamp, str): - timestamp_result = timestamp_cache.get(raw_timestamp) - if timestamp_result is None: - timestamp_result = self._normalize_platform_stats_timestamp( - raw_timestamp - ) - timestamp_cache[raw_timestamp] = timestamp_result - else: - timestamp_result = self._normalize_platform_stats_timestamp( - raw_timestamp - ) + timestamp_result = self._normalize_platform_stats_timestamp_cached( + row.get("timestamp"), timestamp_cache + ) normalized_timestamp, is_timestamp_valid = timestamp_result - timestamp_for_key = ( - normalized_timestamp - if is_timestamp_valid - else f"__invalid_timestamp_row_{row_index}" + timestamp_for_key = self._platform_stats_key_timestamp( + normalized_timestamp, is_timestamp_valid, row_index ) key = ( timestamp_for_key, @@ -577,21 +553,73 @@ def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int: if existing is None: normalized_row = dict(row) normalized_row["timestamp"] = normalized_timestamp - normalized_row["count"] = parse_count( - normalized_row.get("count", 0), key + normalized_row["count"], invalid_count_warned = ( + self._parse_platform_stats_count( + normalized_row.get("count", 0), key, invalid_count_warned + ) ) merged[key] = normalized_row continue duplicate_count += 1 - existing_count = parse_count(existing.get("count", 0), key) - incoming_count = parse_count(row.get("count", 0), key) + existing_count, invalid_count_warned = self._parse_platform_stats_count( + existing.get("count", 0), key, invalid_count_warned + ) + incoming_count, invalid_count_warned = self._parse_platform_stats_count( + row.get("count", 0), key, invalid_count_warned + ) existing["count"] = existing_count + incoming_count return list(merged.values()), duplicate_count + def _parse_platform_stats_count( + self, + raw_count: Any, + key: tuple[str, str, str], + warned_count: int, + ) -> tuple[int, int]: + """Parse count and rate-limit invalid-value warnings per merge invocation.""" + try: + return int(raw_count), warned_count + except (TypeError, ValueError): + if warned_count < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: + logger.warning( + "platform_stats count 非法,已按 0 处理: " + f"value={raw_count!r}, key={key}" + ) + warned_count += 1 + return 0, warned_count + + def _normalize_platform_stats_timestamp_cached( + self, + raw_timestamp: Any, + cache: dict[str, tuple[str, bool]], + ) -> tuple[str, bool]: + """Normalize timestamp with a cache for repeated string values.""" + if isinstance(raw_timestamp, str): + cached = cache.get(raw_timestamp) + if cached is not None: + return cached + result = self._normalize_platform_stats_timestamp(raw_timestamp) + cache[raw_timestamp] = result + return result + return self._normalize_platform_stats_timestamp(raw_timestamp) + + def _platform_stats_key_timestamp( + self, + normalized_timestamp: str, + is_valid: bool, + row_index: int, + ) -> str: + """Build key timestamp value; keep invalid timestamps distinct by row index.""" + if is_valid: + return normalized_timestamp + return f"__invalid_timestamp_row_{row_index}" + def _normalize_platform_stats_timestamp(self, value: Any) -> tuple[str, bool]: if isinstance(value, datetime): dt = value - if dt.tzinfo is not None: + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + else: dt = dt.astimezone(timezone.utc) return dt.isoformat(), True if isinstance(value, str): @@ -602,7 +630,9 @@ def _normalize_platform_stats_timestamp(self, value: Any) -> tuple[str, bool]: timestamp = f"{timestamp[:-1]}+00:00" try: dt = datetime.fromisoformat(timestamp) - if dt.tzinfo is not None: + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + else: dt = dt.astimezone(timezone.utc) return dt.isoformat(), True except ValueError: diff --git a/tests/test_backup.py b/tests/test_backup.py index 7bf2c16d4e..5a7910af39 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -327,6 +327,13 @@ def test_merge_platform_stats_rows(self): "platform_type": "unknown", "count": 3, }, + { + "id": 81, + "timestamp": "2025-12-13T20:00:00", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 2, + }, { "id": 2, "timestamp": "2025-12-13T21:00:00", @@ -338,7 +345,7 @@ def test_merge_platform_stats_rows(self): merged_rows, duplicate_count = importer._merge_platform_stats_rows(rows) - assert duplicate_count == 1 + assert duplicate_count == 2 assert len(merged_rows) == 2 webchat_row = next( ( @@ -354,7 +361,35 @@ def test_merge_platform_stats_rows(self): assert webchat_row["timestamp"] == "2025-12-13T20:00:00+00:00" assert webchat_row["platform_id"] == "webchat" assert webchat_row["platform_type"] == "unknown" - assert webchat_row["count"] == 17 + assert webchat_row["count"] == 19 + + aiocq_row = next( + ( + r + for r in merged_rows + if r.get("platform_id") == "aiocqhttp" + and r.get("platform_type") == "unknown" + ), + None, + ) + assert aiocq_row is not None + assert aiocq_row["timestamp"] == "2025-12-13T21:00:00+00:00" + + def test_normalize_platform_stats_timestamp_treats_naive_as_utc(self): + """测试 naive timestamp 会统一转为显式 UTC 偏移""" + importer = AstrBotImporter(main_db=MagicMock()) + + normalized, is_valid = importer._normalize_platform_stats_timestamp( + "2025-12-13T21:00:00" + ) + assert is_valid is True + assert normalized == "2025-12-13T21:00:00+00:00" + + normalized_dt, is_valid_dt = importer._normalize_platform_stats_timestamp( + datetime(2025, 12, 13, 21, 0, 0) + ) + assert is_valid_dt is True + assert normalized_dt == "2025-12-13T21:00:00+00:00" def test_merge_platform_stats_rows_warns_on_invalid_count(self): """测试 platform_stats count 非法时会告警并按 0 处理(含上限)""" From 9df6a66184e6181df98954efe9168ec276082a59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 19:40:33 +0900 Subject: [PATCH 06/16] refactor: inline platform stats merge helpers --- astrbot/core/backup/importer.py | 93 ++++++++++++--------------------- 1 file changed, 33 insertions(+), 60 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 2dc98be1fd..47fdc19e89 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -536,14 +536,37 @@ def _merge_platform_stats_rows( invalid_count_warned = 0 duplicate_count = 0 + def normalize_ts(raw_timestamp: Any) -> tuple[str, bool]: + if isinstance(raw_timestamp, str): + cached = timestamp_cache.get(raw_timestamp) + if cached is not None: + return cached + result = self._normalize_platform_stats_timestamp(raw_timestamp) + if isinstance(raw_timestamp, str): + timestamp_cache[raw_timestamp] = result + return result + + def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int: + nonlocal invalid_count_warned + try: + return int(raw_count) + except (TypeError, ValueError): + if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: + logger.warning( + "platform_stats count 非法,已按 0 处理: " + f"value={raw_count!r}, key={key}" + ) + invalid_count_warned += 1 + return 0 + for row_index, row in enumerate(rows): - timestamp_result = self._normalize_platform_stats_timestamp_cached( - row.get("timestamp"), timestamp_cache - ) - normalized_timestamp, is_timestamp_valid = timestamp_result - timestamp_for_key = self._platform_stats_key_timestamp( - normalized_timestamp, is_timestamp_valid, row_index + normalized_timestamp, is_timestamp_valid = normalize_ts( + row.get("timestamp") ) + if is_timestamp_valid: + timestamp_for_key = normalized_timestamp + else: + timestamp_for_key = f"__invalid_timestamp_row_{row_index}" key = ( timestamp_for_key, str(row.get("platform_id")), @@ -553,67 +576,17 @@ def _merge_platform_stats_rows( if existing is None: normalized_row = dict(row) normalized_row["timestamp"] = normalized_timestamp - normalized_row["count"], invalid_count_warned = ( - self._parse_platform_stats_count( - normalized_row.get("count", 0), key, invalid_count_warned - ) + normalized_row["count"] = parse_count( + normalized_row.get("count", 0), key ) merged[key] = normalized_row continue duplicate_count += 1 - existing_count, invalid_count_warned = self._parse_platform_stats_count( - existing.get("count", 0), key, invalid_count_warned - ) - incoming_count, invalid_count_warned = self._parse_platform_stats_count( - row.get("count", 0), key, invalid_count_warned - ) + existing_count = parse_count(existing.get("count", 0), key) + incoming_count = parse_count(row.get("count", 0), key) existing["count"] = existing_count + incoming_count return list(merged.values()), duplicate_count - def _parse_platform_stats_count( - self, - raw_count: Any, - key: tuple[str, str, str], - warned_count: int, - ) -> tuple[int, int]: - """Parse count and rate-limit invalid-value warnings per merge invocation.""" - try: - return int(raw_count), warned_count - except (TypeError, ValueError): - if warned_count < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: - logger.warning( - "platform_stats count 非法,已按 0 处理: " - f"value={raw_count!r}, key={key}" - ) - warned_count += 1 - return 0, warned_count - - def _normalize_platform_stats_timestamp_cached( - self, - raw_timestamp: Any, - cache: dict[str, tuple[str, bool]], - ) -> tuple[str, bool]: - """Normalize timestamp with a cache for repeated string values.""" - if isinstance(raw_timestamp, str): - cached = cache.get(raw_timestamp) - if cached is not None: - return cached - result = self._normalize_platform_stats_timestamp(raw_timestamp) - cache[raw_timestamp] = result - return result - return self._normalize_platform_stats_timestamp(raw_timestamp) - - def _platform_stats_key_timestamp( - self, - normalized_timestamp: str, - is_valid: bool, - row_index: int, - ) -> str: - """Build key timestamp value; keep invalid timestamps distinct by row index.""" - if is_valid: - return normalized_timestamp - return f"__invalid_timestamp_row_{row_index}" - def _normalize_platform_stats_timestamp(self, value: Any) -> tuple[str, bool]: if isinstance(value, datetime): dt = value From 5bd71e4f685b24fedd5f8fee0f8ad68b96a3312b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 19:44:46 +0900 Subject: [PATCH 07/16] refactor: flatten platform stats merge flow --- astrbot/core/backup/importer.py | 104 ++++++++++++++++---------------- tests/test_backup.py | 19 +++--- 2 files changed, 62 insertions(+), 61 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 47fdc19e89..e34fa1dfc9 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -498,9 +498,8 @@ async def _import_main_database( continue normalized_rows = rows if table_name == "platform_stats": - normalized_rows, duplicate_count = ( - self._merge_platform_stats_rows(rows) - ) + normalized_rows = self._merge_platform_stats_rows(rows) + duplicate_count = len(rows) - len(normalized_rows) if duplicate_count > 0: logger.warning( f"检测到 platform_stats 重复键 {duplicate_count} 条,已在导入前聚合" @@ -524,7 +523,7 @@ async def _import_main_database( def _merge_platform_stats_rows( self, rows: list[dict[str, Any]] - ) -> tuple[list[dict[str, Any]], int]: + ) -> list[dict[str, Any]]: """Merge duplicate platform_stats rows by normalized timestamp/platform key. Note: @@ -532,60 +531,61 @@ def _merge_platform_stats_rows( - Invalid count warnings are rate-limited per function invocation. """ merged: dict[tuple[str, str, str], dict[str, Any]] = {} - timestamp_cache: dict[str, tuple[str, bool]] = {} + result: list[dict[str, Any]] = [] invalid_count_warned = 0 - duplicate_count = 0 - - def normalize_ts(raw_timestamp: Any) -> tuple[str, bool]: - if isinstance(raw_timestamp, str): - cached = timestamp_cache.get(raw_timestamp) - if cached is not None: - return cached - result = self._normalize_platform_stats_timestamp(raw_timestamp) - if isinstance(raw_timestamp, str): - timestamp_cache[raw_timestamp] = result - return result - - def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int: - nonlocal invalid_count_warned - try: - return int(raw_count) - except (TypeError, ValueError): - if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: - logger.warning( - "platform_stats count 非法,已按 0 处理: " - f"value={raw_count!r}, key={key}" - ) - invalid_count_warned += 1 - return 0 - - for row_index, row in enumerate(rows): - normalized_timestamp, is_timestamp_valid = normalize_ts( - row.get("timestamp") + for row in rows: + normalized_row = dict(row) + normalized_timestamp, is_timestamp_valid = ( + self._normalize_platform_stats_timestamp( + normalized_row.get("timestamp") + ) ) - if is_timestamp_valid: - timestamp_for_key = normalized_timestamp - else: - timestamp_for_key = f"__invalid_timestamp_row_{row_index}" - key = ( - timestamp_for_key, - str(row.get("platform_id")), - str(row.get("platform_type")), + normalized_row["timestamp"] = normalized_timestamp + + platform_id = str(normalized_row.get("platform_id")) + platform_type = str(normalized_row.get("platform_type")) + key_for_log = ( + normalized_timestamp if is_timestamp_valid else "", + platform_id, + platform_type, + ) + count, invalid_count_warned = self._parse_platform_stats_count( + normalized_row.get("count", 0), invalid_count_warned, key_for_log ) + normalized_row["count"] = count + + # Invalid timestamps should never be merged. + if not is_timestamp_valid: + result.append(normalized_row) + continue + + key = (normalized_timestamp, platform_id, platform_type) existing = merged.get(key) if existing is None: - normalized_row = dict(row) - normalized_row["timestamp"] = normalized_timestamp - normalized_row["count"] = parse_count( - normalized_row.get("count", 0), key - ) merged[key] = normalized_row - continue - duplicate_count += 1 - existing_count = parse_count(existing.get("count", 0), key) - incoming_count = parse_count(row.get("count", 0), key) - existing["count"] = existing_count + incoming_count - return list(merged.values()), duplicate_count + result.append(normalized_row) + else: + existing["count"] += count + + return result + + def _parse_platform_stats_count( + self, + raw_count: Any, + invalid_count_warned: int, + key: tuple[str, str, str], + ) -> tuple[int, int]: + """Safe int parse with per-call rate-limited warning.""" + try: + return int(raw_count), invalid_count_warned + except (TypeError, ValueError): + if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: + logger.warning( + "platform_stats count 非法,已按 0 处理: " + f"value={raw_count!r}, key={key}" + ) + invalid_count_warned += 1 + return 0, invalid_count_warned def _normalize_platform_stats_timestamp(self, value: Any) -> tuple[str, bool]: if isinstance(value, datetime): diff --git a/tests/test_backup.py b/tests/test_backup.py index 5a7910af39..1072ec0160 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -343,7 +343,8 @@ def test_merge_platform_stats_rows(self): }, ] - merged_rows, duplicate_count = importer._merge_platform_stats_rows(rows) + merged_rows = importer._merge_platform_stats_rows(rows) + duplicate_count = len(rows) - len(merged_rows) assert duplicate_count == 2 assert len(merged_rows) == 2 @@ -409,7 +410,8 @@ def test_merge_platform_stats_rows_warns_on_invalid_count(self): "count": "bad-count", }, ] - merged_rows, duplicate_count = importer._merge_platform_stats_rows(rows) + merged_rows = importer._merge_platform_stats_rows(rows) + duplicate_count = len(rows) - len(merged_rows) assert duplicate_count == 1 assert len(merged_rows) == 1 assert merged_rows[0]["count"] == 5 @@ -431,9 +433,8 @@ def test_merge_platform_stats_rows_warns_on_invalid_count(self): "count": 7, }, ] - merged_rows, duplicate_count = importer._merge_platform_stats_rows( - rows_existing_invalid - ) + merged_rows = importer._merge_platform_stats_rows(rows_existing_invalid) + duplicate_count = len(rows_existing_invalid) - len(merged_rows) assert duplicate_count == 1 assert len(merged_rows) == 1 assert merged_rows[0]["count"] == 7 @@ -471,9 +472,8 @@ def test_merge_platform_stats_rows_warns_on_invalid_count(self): "count": "still-bad", }, ] - merged_rows, duplicate_count = importer._merge_platform_stats_rows( - single_invalid_row - ) + merged_rows = importer._merge_platform_stats_rows(single_invalid_row) + duplicate_count = len(single_invalid_row) - len(merged_rows) assert duplicate_count == 0 assert len(merged_rows) == 1 assert merged_rows[0]["count"] == 0 @@ -503,7 +503,8 @@ def test_merge_platform_stats_rows_keeps_invalid_timestamps_distinct(self): }, ] - merged_rows, duplicate_count = importer._merge_platform_stats_rows(rows) + merged_rows = importer._merge_platform_stats_rows(rows) + duplicate_count = len(rows) - len(merged_rows) assert duplicate_count == 0 assert len(merged_rows) == 3 From 817286f13f0d52dc7d3081cd0c40773487171443 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 19:51:34 +0900 Subject: [PATCH 08/16] refactor: harden platform stats merge key handling --- astrbot/core/backup/importer.py | 66 +++++++++++++++++++++++++++++---- tests/test_backup.py | 60 +++++++++++++++++++++++++++++- 2 files changed, 118 insertions(+), 8 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index e34fa1dfc9..689d154d76 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -61,8 +61,46 @@ def _get_major_version(version_str: str) -> str: CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json") KB_PATH = get_astrbot_knowledge_base_path() -# Warning limit per _merge_platform_stats_rows invocation. -PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5 +DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5 +PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV = ( + "ASTRBOT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT" +) + + +def _resolve_platform_stats_invalid_count_warn_limit( + raw_value: str | None, +) -> tuple[int, bool]: + """Resolve warn limit value and return whether the input was valid.""" + if raw_value is None: + return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, True + try: + value = int(raw_value) + except (TypeError, ValueError): + return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, False + if value < 0: + return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, False + return value, True + + +def _load_platform_stats_invalid_count_warn_limit() -> int: + raw_value = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV) + resolved_value, is_valid = _resolve_platform_stats_invalid_count_warn_limit( + raw_value + ) + if raw_value is not None and not is_valid: + logger.warning( + "Invalid env %s=%r, fallback to default %d", + PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, + raw_value, + DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, + ) + return resolved_value + + +# Warning limit per _merge_platform_stats_rows invocation; configurable by env. +PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = ( + _load_platform_stats_invalid_count_warn_limit() +) @dataclass @@ -140,6 +178,10 @@ def to_dict(self) -> dict: } +class DatabaseClearError(RuntimeError): + """Raised when clearing the main database in replace mode fails.""" + + class AstrBotImporter: """AstrBot 数据导入器 @@ -344,6 +386,9 @@ async def import_all( imported = await self._import_main_database(main_data) result.imported_tables.update(imported) + except DatabaseClearError as e: + result.add_error(f"清空主数据库失败: {e}") + return result except Exception as e: result.add_error(f"导入主数据库失败: {e}") return result @@ -454,7 +499,9 @@ async def _clear_main_db(self) -> None: await session.execute(delete(model_class)) logger.debug(f"已清空表 {table_name}") except Exception as e: - raise RuntimeError(f"清空表 {table_name} 失败: {e}") from e + raise DatabaseClearError( + f"清空表 {table_name} 失败: {e}" + ) from e async def _clear_kb_data(self) -> None: """清空知识库数据""" @@ -528,6 +575,7 @@ def _merge_platform_stats_rows( Note: - Invalid/empty timestamps are kept as distinct rows to avoid accidental merging. + - Non-string platform_id/platform_type are kept as distinct rows. - Invalid count warnings are rate-limited per function invocation. """ merged: dict[tuple[str, str, str], dict[str, Any]] = {} @@ -542,12 +590,12 @@ def _merge_platform_stats_rows( ) normalized_row["timestamp"] = normalized_timestamp - platform_id = str(normalized_row.get("platform_id")) - platform_type = str(normalized_row.get("platform_type")) + platform_id = normalized_row.get("platform_id") + platform_type = normalized_row.get("platform_type") key_for_log = ( normalized_timestamp if is_timestamp_valid else "", - platform_id, - platform_type, + repr(platform_id), + repr(platform_type), ) count, invalid_count_warned = self._parse_platform_stats_count( normalized_row.get("count", 0), invalid_count_warned, key_for_log @@ -559,6 +607,10 @@ def _merge_platform_stats_rows( result.append(normalized_row) continue + if not isinstance(platform_id, str) or not isinstance(platform_type, str): + result.append(normalized_row) + continue + key = (normalized_timestamp, platform_id, platform_type) existing = merged.get(key) if existing is None: diff --git a/tests/test_backup.py b/tests/test_backup.py index 1072ec0160..b27cb2c269 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -17,10 +17,13 @@ ) from astrbot.core.backup.exporter import AstrBotExporter from astrbot.core.backup.importer import ( + DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, + DatabaseClearError, PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, AstrBotImporter, ImportResult, _get_major_version, + _resolve_platform_stats_invalid_count_warn_limit, ) from astrbot.core.config.default import VERSION from astrbot.core.db.po import ( @@ -392,6 +395,24 @@ def test_normalize_platform_stats_timestamp_treats_naive_as_utc(self): assert is_valid_dt is True assert normalized_dt == "2025-12-13T21:00:00+00:00" + def test_resolve_platform_stats_invalid_count_warn_limit(self): + """测试非法/合法告警阈值配置解析""" + value, valid = _resolve_platform_stats_invalid_count_warn_limit(None) + assert valid is True + assert value == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + + value, valid = _resolve_platform_stats_invalid_count_warn_limit("10") + assert valid is True + assert value == 10 + + value, valid = _resolve_platform_stats_invalid_count_warn_limit("-1") + assert valid is False + assert value == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + + value, valid = _resolve_platform_stats_invalid_count_warn_limit("bad") + assert valid is False + assert value == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + def test_merge_platform_stats_rows_warns_on_invalid_count(self): """测试 platform_stats count 非法时会告警并按 0 处理(含上限)""" importer = AstrBotImporter(main_db=MagicMock()) @@ -510,6 +531,42 @@ def test_merge_platform_stats_rows_keeps_invalid_timestamps_distinct(self): assert len(merged_rows) == 3 assert [row["count"] for row in merged_rows] == [2, 3, 4] + def test_merge_platform_stats_rows_keeps_non_string_platform_keys_distinct(self): + """测试非字符串 platform_id/platform_type 不参与聚合""" + importer = AstrBotImporter(main_db=MagicMock()) + rows = [ + { + "timestamp": "2025-12-13T20:00:00+00:00", + "platform_id": None, + "platform_type": "unknown", + "count": 2, + }, + { + "timestamp": "2025-12-13T20:00:00Z", + "platform_id": None, + "platform_type": "unknown", + "count": 3, + }, + { + "timestamp": "2025-12-13T20:00:00+00:00", + "platform_id": "webchat", + "platform_type": 1, + "count": 4, + }, + { + "timestamp": "2025-12-13T20:00:00Z", + "platform_id": "webchat", + "platform_type": 1, + "count": 5, + }, + ] + + merged_rows = importer._merge_platform_stats_rows(rows) + duplicate_count = len(rows) - len(merged_rows) + + assert duplicate_count == 0 + assert len(merged_rows) == 4 + @pytest.mark.asyncio async def test_import_file_not_exists(self, mock_main_db, tmp_path): """测试导入不存在的文件""" @@ -585,13 +642,14 @@ async def test_import_replace_fails_when_clear_main_db_fails( importer = AstrBotImporter(main_db=mock_main_db) importer._clear_main_db = AsyncMock( - side_effect=RuntimeError("清空表 platform_stats 失败: db locked") + side_effect=DatabaseClearError("清空表 platform_stats 失败: db locked") ) importer._import_main_database = AsyncMock(return_value={}) result = await importer.import_all(str(zip_path), mode="replace") assert result.success is False + assert any("清空主数据库失败" in err for err in result.errors) assert any("清空表 platform_stats 失败" in err for err in result.errors) importer._import_main_database.assert_not_awaited() From 7de31644c4dbccd78f6f0b3fccaffd6360adf406 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 19:58:28 +0900 Subject: [PATCH 09/16] refactor: streamline platform stats preprocessing --- astrbot/core/backup/importer.py | 125 +++++++++++++++----------------- tests/test_backup.py | 40 ++++++---- 2 files changed, 85 insertions(+), 80 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 689d154d76..2a668fe8d9 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -11,6 +11,7 @@ import os import shutil import zipfile +from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path @@ -67,34 +68,23 @@ def _get_major_version(version_str: str) -> str: ) -def _resolve_platform_stats_invalid_count_warn_limit( - raw_value: str | None, -) -> tuple[int, bool]: - """Resolve warn limit value and return whether the input was valid.""" +def _load_platform_stats_invalid_count_warn_limit() -> int: + raw_value = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV) if raw_value is None: - return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, True + return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT try: value = int(raw_value) + if value < 0: + raise ValueError("negative") + return value except (TypeError, ValueError): - return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, False - if value < 0: - return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, False - return value, True - - -def _load_platform_stats_invalid_count_warn_limit() -> int: - raw_value = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV) - resolved_value, is_valid = _resolve_platform_stats_invalid_count_warn_limit( - raw_value - ) - if raw_value is not None and not is_valid: logger.warning( "Invalid env %s=%r, fallback to default %d", PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, raw_value, DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, ) - return resolved_value + return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT # Warning limit per _merge_platform_stats_rows invocation; configurable by env. @@ -210,6 +200,11 @@ def __init__( self.kb_manager = kb_manager self.config_path = config_path self.kb_root_dir = kb_root_dir + self._main_table_preprocessors: dict[ + str, Callable[[list[dict[str, Any]]], list[dict[str, Any]]] + ] = { + "platform_stats": self._merge_platform_stats_rows, + } def pre_check(self, zip_path: str) -> ImportPreCheckResult: """预检查备份文件 @@ -543,14 +538,7 @@ async def _import_main_database( if not model_class: logger.warning(f"未知的表: {table_name}") continue - normalized_rows = rows - if table_name == "platform_stats": - normalized_rows = self._merge_platform_stats_rows(rows) - duplicate_count = len(rows) - len(normalized_rows) - if duplicate_count > 0: - logger.warning( - f"检测到 platform_stats 重复键 {duplicate_count} 条,已在导入前聚合" - ) + normalized_rows = self._preprocess_main_table_rows(table_name, rows) count = 0 for row in normalized_rows: @@ -568,6 +556,20 @@ async def _import_main_database( return imported + def _preprocess_main_table_rows( + self, table_name: str, rows: list[dict[str, Any]] + ) -> list[dict[str, Any]]: + preprocessor = self._main_table_preprocessors.get(table_name) + if preprocessor is None: + return rows + normalized_rows = preprocessor(rows) + duplicate_count = len(rows) - len(normalized_rows) + if duplicate_count > 0: + logger.warning( + f"检测到 {table_name} 重复键 {duplicate_count} 条,已在导入前聚合" + ) + return normalized_rows + def _merge_platform_stats_rows( self, rows: list[dict[str, Any]] ) -> list[dict[str, Any]]: @@ -579,8 +581,23 @@ def _merge_platform_stats_rows( - Invalid count warnings are rate-limited per function invocation. """ merged: dict[tuple[str, str, str], dict[str, Any]] = {} - result: list[dict[str, Any]] = [] + non_mergeable: list[dict[str, Any]] = [] invalid_count_warned = 0 + + def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int: + nonlocal invalid_count_warned + try: + return int(raw_count) + except (TypeError, ValueError): + if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: + logger.warning( + "platform_stats count 非法,已按 0 处理: value=%r, key=%s", + raw_count, + key, + ) + invalid_count_warned += 1 + return 0 + for row in rows: normalized_row = dict(row) normalized_timestamp, is_timestamp_valid = ( @@ -597,56 +614,37 @@ def _merge_platform_stats_rows( repr(platform_id), repr(platform_type), ) - count, invalid_count_warned = self._parse_platform_stats_count( - normalized_row.get("count", 0), invalid_count_warned, key_for_log - ) + count = parse_count(normalized_row.get("count", 0), key_for_log) normalized_row["count"] = count - # Invalid timestamps should never be merged. if not is_timestamp_valid: - result.append(normalized_row) + non_mergeable.append(normalized_row) continue if not isinstance(platform_id, str) or not isinstance(platform_type, str): - result.append(normalized_row) + non_mergeable.append(normalized_row) continue key = (normalized_timestamp, platform_id, platform_type) existing = merged.get(key) if existing is None: merged[key] = normalized_row - result.append(normalized_row) else: existing["count"] += count - return result + return [*non_mergeable, *merged.values()] - def _parse_platform_stats_count( - self, - raw_count: Any, - invalid_count_warned: int, - key: tuple[str, str, str], - ) -> tuple[int, int]: - """Safe int parse with per-call rate-limited warning.""" - try: - return int(raw_count), invalid_count_warned - except (TypeError, ValueError): - if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: - logger.warning( - "platform_stats count 非法,已按 0 处理: " - f"value={raw_count!r}, key={key}" - ) - invalid_count_warned += 1 - return 0, invalid_count_warned + def _to_utc_iso(self, dt: datetime) -> str: + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + else: + dt = dt.astimezone(timezone.utc) + return dt.isoformat() def _normalize_platform_stats_timestamp(self, value: Any) -> tuple[str, bool]: if isinstance(value, datetime): - dt = value - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - else: - dt = dt.astimezone(timezone.utc) - return dt.isoformat(), True + return self._to_utc_iso(value), True + if isinstance(value, str): timestamp = value.strip() if not timestamp: @@ -654,16 +652,13 @@ def _normalize_platform_stats_timestamp(self, value: Any) -> tuple[str, bool]: if timestamp.endswith("Z"): timestamp = f"{timestamp[:-1]}+00:00" try: - dt = datetime.fromisoformat(timestamp) - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - else: - dt = dt.astimezone(timezone.utc) - return dt.isoformat(), True + return self._to_utc_iso(datetime.fromisoformat(timestamp)), True except ValueError: - return value.strip(), False + return timestamp, False + if value is None: return "", False + return str(value), False async def _import_knowledge_bases( diff --git a/tests/test_backup.py b/tests/test_backup.py index b27cb2c269..dc4e65fe94 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -19,11 +19,12 @@ from astrbot.core.backup.importer import ( DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, DatabaseClearError, + PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, AstrBotImporter, ImportResult, + _load_platform_stats_invalid_count_warn_limit, _get_major_version, - _resolve_platform_stats_invalid_count_warn_limit, ) from astrbot.core.config.default import VERSION from astrbot.core.db.po import ( @@ -395,23 +396,32 @@ def test_normalize_platform_stats_timestamp_treats_naive_as_utc(self): assert is_valid_dt is True assert normalized_dt == "2025-12-13T21:00:00+00:00" - def test_resolve_platform_stats_invalid_count_warn_limit(self): - """测试非法/合法告警阈值配置解析""" - value, valid = _resolve_platform_stats_invalid_count_warn_limit(None) - assert valid is True - assert value == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + def test_load_platform_stats_invalid_count_warn_limit(self, monkeypatch): + """测试告警阈值环境变量解析""" + monkeypatch.delenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, raising=False) + assert ( + _load_platform_stats_invalid_count_warn_limit() + == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + ) - value, valid = _resolve_platform_stats_invalid_count_warn_limit("10") - assert valid is True - assert value == 10 + monkeypatch.setenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, "10") + assert _load_platform_stats_invalid_count_warn_limit() == 10 - value, valid = _resolve_platform_stats_invalid_count_warn_limit("-1") - assert valid is False - assert value == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + with patch("astrbot.core.backup.importer.logger.warning") as warning_mock: + monkeypatch.setenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, "-1") + assert ( + _load_platform_stats_invalid_count_warn_limit() + == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + ) + assert warning_mock.call_count == 1 - value, valid = _resolve_platform_stats_invalid_count_warn_limit("bad") - assert valid is False - assert value == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + with patch("astrbot.core.backup.importer.logger.warning") as warning_mock: + monkeypatch.setenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, "bad") + assert ( + _load_platform_stats_invalid_count_warn_limit() + == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + ) + assert warning_mock.call_count == 1 def test_merge_platform_stats_rows_warns_on_invalid_count(self): """测试 platform_stats count 非法时会告警并按 0 处理(含上限)""" From ec57b98a6a6897c384c5e6a98f3d20a4d629fd2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 20:03:59 +0900 Subject: [PATCH 10/16] refactor: simplify platform stats merge helpers --- astrbot/core/backup/importer.py | 135 +++++++++++++++----------------- tests/test_backup.py | 38 +-------- 2 files changed, 64 insertions(+), 109 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 2a668fe8d9..705c55a937 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -11,7 +11,6 @@ import os import shutil import zipfile -from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path @@ -62,35 +61,7 @@ def _get_major_version(version_str: str) -> str: CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json") KB_PATH = get_astrbot_knowledge_base_path() -DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5 -PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV = ( - "ASTRBOT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT" -) - - -def _load_platform_stats_invalid_count_warn_limit() -> int: - raw_value = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV) - if raw_value is None: - return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT - try: - value = int(raw_value) - if value < 0: - raise ValueError("negative") - return value - except (TypeError, ValueError): - logger.warning( - "Invalid env %s=%r, fallback to default %d", - PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, - raw_value, - DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, - ) - return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT - - -# Warning limit per _merge_platform_stats_rows invocation; configurable by env. -PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = ( - _load_platform_stats_invalid_count_warn_limit() -) +PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5 @dataclass @@ -200,11 +171,6 @@ def __init__( self.kb_manager = kb_manager self.config_path = config_path self.kb_root_dir = kb_root_dir - self._main_table_preprocessors: dict[ - str, Callable[[list[dict[str, Any]]], list[dict[str, Any]]] - ] = { - "platform_stats": self._merge_platform_stats_rows, - } def pre_check(self, zip_path: str) -> ImportPreCheckResult: """预检查备份文件 @@ -559,16 +525,15 @@ async def _import_main_database( def _preprocess_main_table_rows( self, table_name: str, rows: list[dict[str, Any]] ) -> list[dict[str, Any]]: - preprocessor = self._main_table_preprocessors.get(table_name) - if preprocessor is None: - return rows - normalized_rows = preprocessor(rows) - duplicate_count = len(rows) - len(normalized_rows) - if duplicate_count > 0: - logger.warning( - f"检测到 {table_name} 重复键 {duplicate_count} 条,已在导入前聚合" - ) - return normalized_rows + if table_name == "platform_stats": + normalized_rows = self._merge_platform_stats_rows(rows) + duplicate_count = len(rows) - len(normalized_rows) + if duplicate_count > 0: + logger.warning( + f"检测到 {table_name} 重复键 {duplicate_count} 条,已在导入前聚合" + ) + return normalized_rows + return rows def _merge_platform_stats_rows( self, rows: list[dict[str, Any]] @@ -584,28 +549,10 @@ def _merge_platform_stats_rows( non_mergeable: list[dict[str, Any]] = [] invalid_count_warned = 0 - def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int: - nonlocal invalid_count_warned - try: - return int(raw_count) - except (TypeError, ValueError): - if invalid_count_warned < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: - logger.warning( - "platform_stats count 非法,已按 0 处理: value=%r, key=%s", - raw_count, - key, - ) - invalid_count_warned += 1 - return 0 - for row in rows: - normalized_row = dict(row) - normalized_timestamp, is_timestamp_valid = ( - self._normalize_platform_stats_timestamp( - normalized_row.get("timestamp") - ) + normalized_row, normalized_timestamp, is_timestamp_valid = ( + self._normalize_platform_stats_row(row) ) - normalized_row["timestamp"] = normalized_timestamp platform_id = normalized_row.get("platform_id") platform_type = normalized_row.get("platform_type") @@ -614,7 +561,11 @@ def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int: repr(platform_id), repr(platform_type), ) - count = parse_count(normalized_row.get("count", 0), key_for_log) + count, invalid_count_warned = self._parse_platform_stats_count( + normalized_row.get("count", 0), + key_for_log, + invalid_count_warned, + ) normalized_row["count"] = count if not is_timestamp_valid: @@ -634,6 +585,44 @@ def parse_count(raw_count: Any, key: tuple[str, str, str]) -> int: return [*non_mergeable, *merged.values()] + def _parse_platform_stats_count( + self, + raw_count: Any, + key_for_log: tuple[str, str, str], + warned_count: int, + ) -> tuple[int, int]: + if warned_count >= PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: + try: + return int(raw_count), warned_count + except (TypeError, ValueError): + return 0, warned_count + try: + return int(raw_count), warned_count + except (TypeError, ValueError): + logger.warning( + "platform_stats count 非法,已按 0 处理: value=%r, key=%s", + raw_count, + key_for_log, + ) + return 0, warned_count + 1 + + def _normalize_platform_stats_row( + self, row: dict[str, Any] + ) -> tuple[dict[str, Any], str, bool]: + normalized_row = dict(row) + raw_timestamp = normalized_row.get("timestamp") + normalized_timestamp = self._normalize_platform_stats_timestamp(raw_timestamp) + if normalized_timestamp is None: + if isinstance(raw_timestamp, str): + normalized_row["timestamp"] = raw_timestamp.strip() + elif raw_timestamp is None: + normalized_row["timestamp"] = "" + else: + normalized_row["timestamp"] = str(raw_timestamp) + return normalized_row, normalized_row["timestamp"], False + normalized_row["timestamp"] = normalized_timestamp + return normalized_row, normalized_timestamp, True + def _to_utc_iso(self, dt: datetime) -> str: if dt.tzinfo is None: dt = dt.replace(tzinfo=timezone.utc) @@ -641,25 +630,25 @@ def _to_utc_iso(self, dt: datetime) -> str: dt = dt.astimezone(timezone.utc) return dt.isoformat() - def _normalize_platform_stats_timestamp(self, value: Any) -> tuple[str, bool]: + def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: if isinstance(value, datetime): - return self._to_utc_iso(value), True + return self._to_utc_iso(value) if isinstance(value, str): timestamp = value.strip() if not timestamp: - return "", False + return None if timestamp.endswith("Z"): timestamp = f"{timestamp[:-1]}+00:00" try: - return self._to_utc_iso(datetime.fromisoformat(timestamp)), True + return self._to_utc_iso(datetime.fromisoformat(timestamp)) except ValueError: - return timestamp, False + return None if value is None: - return "", False + return None - return str(value), False + return None async def _import_knowledge_bases( self, diff --git a/tests/test_backup.py b/tests/test_backup.py index dc4e65fe94..d41008f93e 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -17,13 +17,10 @@ ) from astrbot.core.backup.exporter import AstrBotExporter from astrbot.core.backup.importer import ( - DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, DatabaseClearError, - PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, AstrBotImporter, ImportResult, - _load_platform_stats_invalid_count_warn_limit, _get_major_version, ) from astrbot.core.config.default import VERSION @@ -384,45 +381,14 @@ def test_normalize_platform_stats_timestamp_treats_naive_as_utc(self): """测试 naive timestamp 会统一转为显式 UTC 偏移""" importer = AstrBotImporter(main_db=MagicMock()) - normalized, is_valid = importer._normalize_platform_stats_timestamp( - "2025-12-13T21:00:00" - ) - assert is_valid is True + normalized = importer._normalize_platform_stats_timestamp("2025-12-13T21:00:00") assert normalized == "2025-12-13T21:00:00+00:00" - normalized_dt, is_valid_dt = importer._normalize_platform_stats_timestamp( + normalized_dt = importer._normalize_platform_stats_timestamp( datetime(2025, 12, 13, 21, 0, 0) ) - assert is_valid_dt is True assert normalized_dt == "2025-12-13T21:00:00+00:00" - def test_load_platform_stats_invalid_count_warn_limit(self, monkeypatch): - """测试告警阈值环境变量解析""" - monkeypatch.delenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, raising=False) - assert ( - _load_platform_stats_invalid_count_warn_limit() - == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT - ) - - monkeypatch.setenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, "10") - assert _load_platform_stats_invalid_count_warn_limit() == 10 - - with patch("astrbot.core.backup.importer.logger.warning") as warning_mock: - monkeypatch.setenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, "-1") - assert ( - _load_platform_stats_invalid_count_warn_limit() - == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT - ) - assert warning_mock.call_count == 1 - - with patch("astrbot.core.backup.importer.logger.warning") as warning_mock: - monkeypatch.setenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, "bad") - assert ( - _load_platform_stats_invalid_count_warn_limit() - == DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT - ) - assert warning_mock.call_count == 1 - def test_merge_platform_stats_rows_warns_on_invalid_count(self): """测试 platform_stats count 非法时会告警并按 0 处理(含上限)""" importer = AstrBotImporter(main_db=MagicMock()) From 628defa9a5a712767ccbc5215ee7c73708e3a5a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 20:11:11 +0900 Subject: [PATCH 11/16] refactor: inline platform stats merge normalization --- astrbot/core/backup/importer.py | 134 ++++++++++++++------------------ tests/test_backup.py | 29 +++++-- 2 files changed, 78 insertions(+), 85 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 705c55a937..8db096764d 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -549,34 +549,72 @@ def _merge_platform_stats_rows( non_mergeable: list[dict[str, Any]] = [] invalid_count_warned = 0 - for row in rows: - normalized_row, normalized_timestamp, is_timestamp_valid = ( - self._normalize_platform_stats_row(row) - ) + def normalize_timestamp(value: Any) -> str | None: + if isinstance(value, datetime): + return self._to_utc_iso(value) + if isinstance(value, str): + timestamp = value.strip() + if not timestamp: + return None + if timestamp.endswith("Z"): + timestamp = f"{timestamp[:-1]}+00:00" + try: + return self._to_utc_iso(datetime.fromisoformat(timestamp)) + except ValueError: + return None + return None + + def build_key(row: dict[str, Any]) -> tuple[str, str, str] | None: + normalized_timestamp = normalize_timestamp(row.get("timestamp")) + if normalized_timestamp is not None: + row["timestamp"] = normalized_timestamp + platform_id = row.get("platform_id") + platform_type = row.get("platform_type") + if ( + normalized_timestamp is None + or not isinstance(platform_id, str) + or not isinstance(platform_type, str) + ): + return None + return (normalized_timestamp, platform_id, platform_type) + + def parse_count( + raw_count: Any, + key_for_log: tuple[Any, Any, Any], + warned_count: int, + ) -> tuple[int, int]: + if warned_count >= PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: + try: + return int(raw_count), warned_count + except (TypeError, ValueError): + return 0, warned_count + try: + return int(raw_count), warned_count + except (TypeError, ValueError): + logger.warning( + "platform_stats count 非法,已按 0 处理: value=%r, key=%s", + raw_count, + key_for_log, + ) + return 0, warned_count + 1 - platform_id = normalized_row.get("platform_id") - platform_type = normalized_row.get("platform_type") + for row in rows: + normalized_row = dict(row) + key = build_key(normalized_row) key_for_log = ( - normalized_timestamp if is_timestamp_valid else "", - repr(platform_id), - repr(platform_type), + normalized_row.get("timestamp"), + repr(normalized_row.get("platform_id")), + repr(normalized_row.get("platform_type")), ) - count, invalid_count_warned = self._parse_platform_stats_count( - normalized_row.get("count", 0), - key_for_log, - invalid_count_warned, + count, invalid_count_warned = parse_count( + normalized_row.get("count", 0), key_for_log, invalid_count_warned ) normalized_row["count"] = count - if not is_timestamp_valid: - non_mergeable.append(normalized_row) - continue - - if not isinstance(platform_id, str) or not isinstance(platform_type, str): + if key is None: non_mergeable.append(normalized_row) continue - key = (normalized_timestamp, platform_id, platform_type) existing = merged.get(key) if existing is None: merged[key] = normalized_row @@ -585,44 +623,6 @@ def _merge_platform_stats_rows( return [*non_mergeable, *merged.values()] - def _parse_platform_stats_count( - self, - raw_count: Any, - key_for_log: tuple[str, str, str], - warned_count: int, - ) -> tuple[int, int]: - if warned_count >= PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: - try: - return int(raw_count), warned_count - except (TypeError, ValueError): - return 0, warned_count - try: - return int(raw_count), warned_count - except (TypeError, ValueError): - logger.warning( - "platform_stats count 非法,已按 0 处理: value=%r, key=%s", - raw_count, - key_for_log, - ) - return 0, warned_count + 1 - - def _normalize_platform_stats_row( - self, row: dict[str, Any] - ) -> tuple[dict[str, Any], str, bool]: - normalized_row = dict(row) - raw_timestamp = normalized_row.get("timestamp") - normalized_timestamp = self._normalize_platform_stats_timestamp(raw_timestamp) - if normalized_timestamp is None: - if isinstance(raw_timestamp, str): - normalized_row["timestamp"] = raw_timestamp.strip() - elif raw_timestamp is None: - normalized_row["timestamp"] = "" - else: - normalized_row["timestamp"] = str(raw_timestamp) - return normalized_row, normalized_row["timestamp"], False - normalized_row["timestamp"] = normalized_timestamp - return normalized_row, normalized_timestamp, True - def _to_utc_iso(self, dt: datetime) -> str: if dt.tzinfo is None: dt = dt.replace(tzinfo=timezone.utc) @@ -630,26 +630,6 @@ def _to_utc_iso(self, dt: datetime) -> str: dt = dt.astimezone(timezone.utc) return dt.isoformat() - def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: - if isinstance(value, datetime): - return self._to_utc_iso(value) - - if isinstance(value, str): - timestamp = value.strip() - if not timestamp: - return None - if timestamp.endswith("Z"): - timestamp = f"{timestamp[:-1]}+00:00" - try: - return self._to_utc_iso(datetime.fromisoformat(timestamp)) - except ValueError: - return None - - if value is None: - return None - - return None - async def _import_knowledge_bases( self, zf: zipfile.ZipFile, diff --git a/tests/test_backup.py b/tests/test_backup.py index d41008f93e..c9b65eabaa 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -377,17 +377,30 @@ def test_merge_platform_stats_rows(self): assert aiocq_row is not None assert aiocq_row["timestamp"] == "2025-12-13T21:00:00+00:00" - def test_normalize_platform_stats_timestamp_treats_naive_as_utc(self): - """测试 naive timestamp 会统一转为显式 UTC 偏移""" + def test_merge_platform_stats_rows_normalizes_naive_timestamp_to_utc(self): + """测试 platform_stats 合并前会将 naive timestamp 标准化为 UTC 偏移""" importer = AstrBotImporter(main_db=MagicMock()) - normalized = importer._normalize_platform_stats_timestamp("2025-12-13T21:00:00") - assert normalized == "2025-12-13T21:00:00+00:00" + rows = [ + { + "timestamp": "2025-12-13T21:00:00", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 1, + }, + { + "timestamp": datetime(2025, 12, 13, 22, 0, 0), + "platform_id": "telegram", + "platform_type": "unknown", + "count": 1, + }, + ] - normalized_dt = importer._normalize_platform_stats_timestamp( - datetime(2025, 12, 13, 21, 0, 0) - ) - assert normalized_dt == "2025-12-13T21:00:00+00:00" + merged_rows = importer._merge_platform_stats_rows(rows) + assert len(merged_rows) == 2 + by_platform = {row["platform_id"]: row for row in merged_rows} + assert by_platform["webchat"]["timestamp"] == "2025-12-13T21:00:00+00:00" + assert by_platform["telegram"]["timestamp"] == "2025-12-13T22:00:00+00:00" def test_merge_platform_stats_rows_warns_on_invalid_count(self): """测试 platform_stats count 非法时会告警并按 0 处理(含上限)""" From 960c4dba347e3b3d13ffb2cc9f198ad5a7e6d331 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 20:16:24 +0900 Subject: [PATCH 12/16] refactor: extract platform stats merge helpers --- astrbot/core/backup/importer.py | 152 +++++++++++++++++++++----------- tests/test_backup.py | 8 +- 2 files changed, 106 insertions(+), 54 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 8db096764d..b71bddfc1c 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -548,67 +548,46 @@ def _merge_platform_stats_rows( merged: dict[tuple[str, str, str], dict[str, Any]] = {} non_mergeable: list[dict[str, Any]] = [] invalid_count_warned = 0 + suppression_warned = False - def normalize_timestamp(value: Any) -> str | None: - if isinstance(value, datetime): - return self._to_utc_iso(value) - if isinstance(value, str): - timestamp = value.strip() - if not timestamp: - return None - if timestamp.endswith("Z"): - timestamp = f"{timestamp[:-1]}+00:00" - try: - return self._to_utc_iso(datetime.fromisoformat(timestamp)) - except ValueError: - return None - return None + for row in rows: + normalized_row = dict(row) + raw_timestamp = normalized_row.get("timestamp") + normalized_timestamp = self._normalize_platform_stats_timestamp( + raw_timestamp + ) + platform_id = normalized_row.get("platform_id") + platform_type = normalized_row.get("platform_type") - def build_key(row: dict[str, Any]) -> tuple[str, str, str] | None: - normalized_timestamp = normalize_timestamp(row.get("timestamp")) if normalized_timestamp is not None: - row["timestamp"] = normalized_timestamp - platform_id = row.get("platform_id") - platform_type = row.get("platform_type") - if ( - normalized_timestamp is None - or not isinstance(platform_id, str) - or not isinstance(platform_type, str) - ): - return None - return (normalized_timestamp, platform_id, platform_type) - - def parse_count( - raw_count: Any, - key_for_log: tuple[Any, Any, Any], - warned_count: int, - ) -> tuple[int, int]: - if warned_count >= PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: - try: - return int(raw_count), warned_count - except (TypeError, ValueError): - return 0, warned_count - try: - return int(raw_count), warned_count - except (TypeError, ValueError): - logger.warning( - "platform_stats count 非法,已按 0 处理: value=%r, key=%s", - raw_count, - key_for_log, - ) - return 0, warned_count + 1 + normalized_row["timestamp"] = normalized_timestamp + elif isinstance(raw_timestamp, str): + normalized_row["timestamp"] = raw_timestamp.strip() + elif raw_timestamp is None: + normalized_row["timestamp"] = "" + else: + normalized_row["timestamp"] = str(raw_timestamp) - for row in rows: - normalized_row = dict(row) - key = build_key(normalized_row) + key = self._build_platform_stats_key( + normalized_timestamp, platform_id, platform_type + ) key_for_log = ( normalized_row.get("timestamp"), - repr(normalized_row.get("platform_id")), - repr(normalized_row.get("platform_type")), + repr(platform_id), + repr(platform_type), ) - count, invalid_count_warned = parse_count( - normalized_row.get("count", 0), key_for_log, invalid_count_warned + count, is_valid_count = self._parse_platform_stats_count_value( + normalized_row.get("count", 0) ) + if not is_valid_count: + invalid_count_warned, suppression_warned = ( + self._log_invalid_platform_stats_count( + normalized_row.get("count", 0), + key_for_log, + invalid_count_warned, + suppression_warned, + ) + ) normalized_row["count"] = count if key is None: @@ -623,6 +602,73 @@ def parse_count( return [*non_mergeable, *merged.values()] + def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: + if isinstance(value, datetime): + return self._to_utc_iso(value) + if isinstance(value, str): + timestamp = value.strip() + if not timestamp: + return None + if timestamp.endswith("Z"): + timestamp = f"{timestamp[:-1]}+00:00" + try: + return self._to_utc_iso(datetime.fromisoformat(timestamp)) + except ValueError: + return None + return None + + def _build_platform_stats_key( + self, + normalized_timestamp: str | None, + platform_id: Any, + platform_type: Any, + ) -> tuple[str, str, str] | None: + if ( + normalized_timestamp is None + or not isinstance(platform_id, str) + or not isinstance(platform_type, str) + ): + return None + return (normalized_timestamp, platform_id, platform_type) + + def _parse_platform_stats_count_value(self, raw_count: Any) -> tuple[int, bool]: + try: + return int(raw_count), True + except (TypeError, ValueError): + return 0, False + + def _log_invalid_platform_stats_count( + self, + raw_count: Any, + key_for_log: tuple[Any, Any, Any], + warned_count: int, + suppression_warned: bool, + ) -> tuple[int, bool]: + if warned_count < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: + logger.warning( + "platform_stats count 非法,已按 0 处理: value=%r, key=%s", + raw_count, + key_for_log, + ) + warned_count += 1 + if ( + warned_count >= PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + and not suppression_warned + ): + logger.warning( + "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", + PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, + ) + suppression_warned = True + return warned_count, suppression_warned + if not suppression_warned: + logger.warning( + "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", + PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, + ) + suppression_warned = True + return warned_count, suppression_warned + def _to_utc_iso(self, dt: datetime) -> str: if dt.tzinfo is None: dt = dt.replace(tzinfo=timezone.utc) diff --git a/tests/test_backup.py b/tests/test_backup.py index c9b65eabaa..c2f2a578f0 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -470,7 +470,13 @@ def test_merge_platform_stats_rows_warns_on_invalid_count(self): ], ] importer._merge_platform_stats_rows(many_invalid_rows) - assert warning_mock.call_count == PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + assert ( + warning_mock.call_count == PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + 1 + ) + assert any( + "告警已达到上限" in str(call.args[0]) + for call in warning_mock.call_args_list + ) warning_mock.reset_mock() From d21cf5824b6b46ed295dc64c034733859dc11f46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 20:30:21 +0900 Subject: [PATCH 13/16] refactor: simplify platform stats preprocessing flow --- astrbot/core/backup/importer.py | 127 +++++++++++++------------------- tests/test_backup.py | 40 ++++++++++ 2 files changed, 92 insertions(+), 75 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index b71bddfc1c..861f27e772 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -525,8 +525,10 @@ async def _import_main_database( def _preprocess_main_table_rows( self, table_name: str, rows: list[dict[str, Any]] ) -> list[dict[str, Any]]: - if table_name == "platform_stats": - normalized_rows = self._merge_platform_stats_rows(rows) + preprocessors = {"platform_stats": self._merge_platform_stats_rows} + preprocessor = preprocessors.get(table_name) + if preprocessor is not None: + normalized_rows = preprocessor(rows) duplicate_count = len(rows) - len(normalized_rows) if duplicate_count > 0: logger.warning( @@ -546,9 +548,44 @@ def _merge_platform_stats_rows( - Invalid count warnings are rate-limited per function invocation. """ merged: dict[tuple[str, str, str], dict[str, Any]] = {} - non_mergeable: list[dict[str, Any]] = [] - invalid_count_warned = 0 - suppression_warned = False + result: list[dict[str, Any]] = [] + invalid_count_warnings = 0 + + def log_invalid_count( + raw_count: Any, key_for_log: tuple[Any, Any, Any] + ) -> None: + nonlocal invalid_count_warnings + + limit = PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + if limit <= 0: + if invalid_count_warnings == 0: + logger.warning( + "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", + limit, + ) + invalid_count_warnings = 1 + return + + if invalid_count_warnings < limit: + logger.warning( + "platform_stats count 非法,已按 0 处理: value=%r, key=%s", + raw_count, + key_for_log, + ) + invalid_count_warnings += 1 + if invalid_count_warnings == limit: + logger.warning( + "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", + limit, + ) + invalid_count_warnings += 1 + + def parse_count(raw_count: Any, key_for_log: tuple[Any, Any, Any]) -> int: + try: + return int(raw_count) + except (TypeError, ValueError): + log_invalid_count(raw_count, key_for_log) + return 0 for row in rows: normalized_row = dict(row) @@ -568,39 +605,31 @@ def _merge_platform_stats_rows( else: normalized_row["timestamp"] = str(raw_timestamp) - key = self._build_platform_stats_key( - normalized_timestamp, platform_id, platform_type - ) key_for_log = ( normalized_row.get("timestamp"), repr(platform_id), repr(platform_type), ) - count, is_valid_count = self._parse_platform_stats_count_value( - normalized_row.get("count", 0) - ) - if not is_valid_count: - invalid_count_warned, suppression_warned = ( - self._log_invalid_platform_stats_count( - normalized_row.get("count", 0), - key_for_log, - invalid_count_warned, - suppression_warned, - ) - ) + count = parse_count(normalized_row.get("count", 0), key_for_log) normalized_row["count"] = count - if key is None: - non_mergeable.append(normalized_row) + if ( + normalized_timestamp is None + or not isinstance(platform_id, str) + or not isinstance(platform_type, str) + ): + result.append(normalized_row) continue + key = (normalized_timestamp, platform_id, platform_type) existing = merged.get(key) if existing is None: merged[key] = normalized_row + result.append(normalized_row) else: existing["count"] += count - return [*non_mergeable, *merged.values()] + return result def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: if isinstance(value, datetime): @@ -617,58 +646,6 @@ def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: return None return None - def _build_platform_stats_key( - self, - normalized_timestamp: str | None, - platform_id: Any, - platform_type: Any, - ) -> tuple[str, str, str] | None: - if ( - normalized_timestamp is None - or not isinstance(platform_id, str) - or not isinstance(platform_type, str) - ): - return None - return (normalized_timestamp, platform_id, platform_type) - - def _parse_platform_stats_count_value(self, raw_count: Any) -> tuple[int, bool]: - try: - return int(raw_count), True - except (TypeError, ValueError): - return 0, False - - def _log_invalid_platform_stats_count( - self, - raw_count: Any, - key_for_log: tuple[Any, Any, Any], - warned_count: int, - suppression_warned: bool, - ) -> tuple[int, bool]: - if warned_count < PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT: - logger.warning( - "platform_stats count 非法,已按 0 处理: value=%r, key=%s", - raw_count, - key_for_log, - ) - warned_count += 1 - if ( - warned_count >= PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT - and not suppression_warned - ): - logger.warning( - "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", - PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, - ) - suppression_warned = True - return warned_count, suppression_warned - if not suppression_warned: - logger.warning( - "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", - PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, - ) - suppression_warned = True - return warned_count, suppression_warned - def _to_utc_iso(self, dt: datetime) -> str: if dt.tzinfo is None: dt = dt.replace(tzinfo=timezone.utc) diff --git a/tests/test_backup.py b/tests/test_backup.py index c2f2a578f0..cf3c4d9494 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -562,6 +562,46 @@ def test_merge_platform_stats_rows_keeps_non_string_platform_keys_distinct(self) assert duplicate_count == 0 assert len(merged_rows) == 4 + def test_merge_platform_stats_rows_preserves_input_order(self): + """测试 platform_stats 聚合后仍保持输入顺序(按首次出现位置)""" + importer = AstrBotImporter(main_db=MagicMock()) + rows = [ + { + "id": 1, + "timestamp": "2025-12-13T20:00:00Z", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 2, + }, + { + "id": 2, + "timestamp": "", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 3, + }, + { + "id": 3, + "timestamp": "2025-12-13T20:00:00+00:00", + "platform_id": "webchat", + "platform_type": "unknown", + "count": 5, + }, + { + "id": 4, + "timestamp": "2025-12-13T21:00:00+00:00", + "platform_id": "telegram", + "platform_type": "unknown", + "count": 7, + }, + ] + + merged_rows = importer._merge_platform_stats_rows(rows) + + assert len(merged_rows) == 3 + assert [row["id"] for row in merged_rows] == [1, 2, 4] + assert merged_rows[0]["count"] == 7 + @pytest.mark.asyncio async def test_import_file_not_exists(self, mock_main_db, tmp_path): """测试导入不存在的文件""" From 60cedc7804c7c22638bc43c25966adb23ff6b28d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 20:34:16 +0900 Subject: [PATCH 14/16] refactor: flatten platform stats preprocess helpers --- astrbot/core/backup/importer.py | 177 +++++++++++++++++++------------- 1 file changed, 106 insertions(+), 71 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 861f27e772..45e892b151 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -525,14 +525,14 @@ async def _import_main_database( def _preprocess_main_table_rows( self, table_name: str, rows: list[dict[str, Any]] ) -> list[dict[str, Any]]: - preprocessors = {"platform_stats": self._merge_platform_stats_rows} - preprocessor = preprocessors.get(table_name) - if preprocessor is not None: - normalized_rows = preprocessor(rows) + if table_name == "platform_stats": + normalized_rows = self._merge_platform_stats_rows(rows) duplicate_count = len(rows) - len(normalized_rows) if duplicate_count > 0: - logger.warning( - f"检测到 {table_name} 重复键 {duplicate_count} 条,已在导入前聚合" + logger.info( + "检测到 %s 重复键 %d 条,已在导入前聚合", + table_name, + duplicate_count, ) return normalized_rows return rows @@ -551,86 +551,121 @@ def _merge_platform_stats_rows( result: list[dict[str, Any]] = [] invalid_count_warnings = 0 - def log_invalid_count( - raw_count: Any, key_for_log: tuple[Any, Any, Any] - ) -> None: - nonlocal invalid_count_warnings - - limit = PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT - if limit <= 0: - if invalid_count_warnings == 0: - logger.warning( - "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", - limit, - ) - invalid_count_warnings = 1 - return - - if invalid_count_warnings < limit: - logger.warning( - "platform_stats count 非法,已按 0 处理: value=%r, key=%s", - raw_count, - key_for_log, - ) - invalid_count_warnings += 1 - if invalid_count_warnings == limit: - logger.warning( - "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", - limit, - ) - invalid_count_warnings += 1 - - def parse_count(raw_count: Any, key_for_log: tuple[Any, Any, Any]) -> int: - try: - return int(raw_count) - except (TypeError, ValueError): - log_invalid_count(raw_count, key_for_log) - return 0 - for row in rows: - normalized_row = dict(row) - raw_timestamp = normalized_row.get("timestamp") - normalized_timestamp = self._normalize_platform_stats_timestamp( - raw_timestamp + merge_key, normalized_row, key_for_log = self._normalize_platform_stats_row( + row ) - platform_id = normalized_row.get("platform_id") - platform_type = normalized_row.get("platform_type") - - if normalized_timestamp is not None: - normalized_row["timestamp"] = normalized_timestamp - elif isinstance(raw_timestamp, str): - normalized_row["timestamp"] = raw_timestamp.strip() - elif raw_timestamp is None: - normalized_row["timestamp"] = "" - else: - normalized_row["timestamp"] = str(raw_timestamp) - - key_for_log = ( - normalized_row.get("timestamp"), - repr(platform_id), - repr(platform_type), + count, invalid_count_warnings = self._parse_platform_stats_count( + normalized_row.get("count", 0), + key_for_log, + invalid_count_warnings, ) - count = parse_count(normalized_row.get("count", 0), key_for_log) normalized_row["count"] = count - if ( - normalized_timestamp is None - or not isinstance(platform_id, str) - or not isinstance(platform_type, str) - ): + if merge_key is None: result.append(normalized_row) continue - key = (normalized_timestamp, platform_id, platform_type) - existing = merged.get(key) + existing = merged.get(merge_key) if existing is None: - merged[key] = normalized_row + merged[merge_key] = normalized_row result.append(normalized_row) else: existing["count"] += count return result + def _normalize_platform_stats_row( + self, row: dict[str, Any] + ) -> tuple[tuple[str, str, str] | None, dict[str, Any], tuple[Any, Any, Any]]: + normalized_row = dict(row) + raw_timestamp = normalized_row.get("timestamp") + normalized_timestamp = self._normalize_platform_stats_timestamp(raw_timestamp) + platform_id = normalized_row.get("platform_id") + platform_type = normalized_row.get("platform_type") + + if normalized_timestamp is not None: + normalized_row["timestamp"] = normalized_timestamp + elif isinstance(raw_timestamp, str): + normalized_row["timestamp"] = raw_timestamp.strip() + elif raw_timestamp is None: + normalized_row["timestamp"] = "" + else: + normalized_row["timestamp"] = str(raw_timestamp) + + key_for_log = ( + normalized_row.get("timestamp"), + repr(platform_id), + repr(platform_type), + ) + + if ( + normalized_timestamp is None + or not isinstance(platform_id, str) + or not isinstance(platform_type, str) + ): + return None, normalized_row, key_for_log + return ( + (normalized_timestamp, platform_id, platform_type), + normalized_row, + key_for_log, + ) + + def _parse_platform_stats_count( + self, + raw_count: Any, + key_for_log: tuple[Any, Any, Any], + invalid_count_warnings: int, + ) -> tuple[int, int]: + try: + return int(raw_count), invalid_count_warnings + except (TypeError, ValueError): + next_warnings = self._log_platform_stats_invalid_count( + raw_count, + key_for_log, + invalid_count_warnings, + ) + return 0, next_warnings + + def _log_platform_stats_invalid_count( + self, + raw_count: Any, + key_for_log: tuple[Any, Any, Any], + invalid_count_warnings: int, + ) -> int: + """Rate-limit invalid count warnings. + + Behavior: + - limit <= 0: log the suppression message once. + - limit > 0: log invalid values up to the limit, then one suppression message. + """ + limit = PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + + if limit <= 0: + if invalid_count_warnings == 0: + logger.warning( + "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", + limit, + ) + return 1 + return invalid_count_warnings + + if invalid_count_warnings < limit: + logger.warning( + "platform_stats count 非法,已按 0 处理: value=%r, key=%s", + raw_count, + key_for_log, + ) + invalid_count_warnings += 1 + if invalid_count_warnings == limit: + logger.warning( + "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", + limit, + ) + return invalid_count_warnings + 1 + + return invalid_count_warnings + def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: if isinstance(value, datetime): return self._to_utc_iso(value) From ae446edebd24c24513d8dfefb05f21e1cd2ac0f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 20:40:27 +0900 Subject: [PATCH 15/16] refactor: streamline platform stats merge helpers --- astrbot/core/backup/importer.py | 180 ++++++++++++++++---------------- 1 file changed, 90 insertions(+), 90 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 45e892b151..d9996b9d36 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -61,7 +61,35 @@ def _get_major_version(version_str: str) -> str: CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json") KB_PATH = get_astrbot_knowledge_base_path() -PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5 +DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5 +PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV = ( + "ASTRBOT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT" +) + + +def _load_platform_stats_invalid_count_warn_limit() -> int: + raw_value = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV) + if raw_value is None: + return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + + try: + value = int(raw_value) + if value < 0: + raise ValueError("negative") + return value + except (TypeError, ValueError): + logger.warning( + "Invalid env %s=%r, fallback to default %d", + PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV, + raw_value, + DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT, + ) + return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + + +PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = ( + _load_platform_stats_invalid_count_warn_limit() +) @dataclass @@ -529,7 +557,7 @@ def _preprocess_main_table_rows( normalized_rows = self._merge_platform_stats_rows(rows) duplicate_count = len(rows) - len(normalized_rows) if duplicate_count > 0: - logger.info( + logger.warning( "检测到 %s 重复键 %d 条,已在导入前聚合", table_name, duplicate_count, @@ -550,22 +578,57 @@ def _merge_platform_stats_rows( merged: dict[tuple[str, str, str], dict[str, Any]] = {} result: list[dict[str, Any]] = [] invalid_count_warnings = 0 + suppression_logged = False + limit = PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT for row in rows: - merge_key, normalized_row, key_for_log = self._normalize_platform_stats_row( - row - ) - count, invalid_count_warnings = self._parse_platform_stats_count( - normalized_row.get("count", 0), - key_for_log, - invalid_count_warnings, + normalized_row = self._normalize_platform_stats_row(row) + timestamp = normalized_row.get("timestamp") + platform_id = normalized_row.get("platform_id") + platform_type = normalized_row.get("platform_type") + key_for_log = (timestamp, repr(platform_id), repr(platform_type)) + + parsed_count = self._parse_platform_stats_count( + normalized_row.get("count", 0) ) + if parsed_count is None: + if limit > 0: + if invalid_count_warnings < limit: + logger.warning( + "platform_stats count 非法,已按 0 处理: value=%r, key=%s", + normalized_row.get("count", 0), + key_for_log, + ) + invalid_count_warnings += 1 + if invalid_count_warnings == limit and not suppression_logged: + logger.warning( + "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", + limit, + ) + suppression_logged = True + elif not suppression_logged: + # limit <= 0: emit only one suppression warning. + logger.warning( + "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", + limit, + ) + suppression_logged = True + count = 0 + else: + count = parsed_count + normalized_row["count"] = count - if merge_key is None: + normalized_timestamp = self._normalize_platform_stats_timestamp(timestamp) + if ( + normalized_timestamp is None + or not isinstance(platform_id, str) + or not isinstance(platform_type, str) + ): result.append(normalized_row) continue + merge_key = (normalized_timestamp, platform_id, platform_type) existing = merged.get(merge_key) if existing is None: merged[merge_key] = normalized_row @@ -575,14 +638,10 @@ def _merge_platform_stats_rows( return result - def _normalize_platform_stats_row( - self, row: dict[str, Any] - ) -> tuple[tuple[str, str, str] | None, dict[str, Any], tuple[Any, Any, Any]]: + def _normalize_platform_stats_row(self, row: dict[str, Any]) -> dict[str, Any]: normalized_row = dict(row) raw_timestamp = normalized_row.get("timestamp") normalized_timestamp = self._normalize_platform_stats_timestamp(raw_timestamp) - platform_id = normalized_row.get("platform_id") - platform_type = normalized_row.get("platform_type") if normalized_timestamp is not None: normalized_row["timestamp"] = normalized_timestamp @@ -593,82 +652,25 @@ def _normalize_platform_stats_row( else: normalized_row["timestamp"] = str(raw_timestamp) - key_for_log = ( - normalized_row.get("timestamp"), - repr(platform_id), - repr(platform_type), - ) - - if ( - normalized_timestamp is None - or not isinstance(platform_id, str) - or not isinstance(platform_type, str) - ): - return None, normalized_row, key_for_log - return ( - (normalized_timestamp, platform_id, platform_type), - normalized_row, - key_for_log, - ) + return normalized_row def _parse_platform_stats_count( self, raw_count: Any, - key_for_log: tuple[Any, Any, Any], - invalid_count_warnings: int, - ) -> tuple[int, int]: + ) -> int | None: try: - return int(raw_count), invalid_count_warnings + return int(raw_count) except (TypeError, ValueError): - next_warnings = self._log_platform_stats_invalid_count( - raw_count, - key_for_log, - invalid_count_warnings, - ) - return 0, next_warnings - - def _log_platform_stats_invalid_count( - self, - raw_count: Any, - key_for_log: tuple[Any, Any, Any], - invalid_count_warnings: int, - ) -> int: - """Rate-limit invalid count warnings. - - Behavior: - - limit <= 0: log the suppression message once. - - limit > 0: log invalid values up to the limit, then one suppression message. - """ - limit = PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT - - if limit <= 0: - if invalid_count_warnings == 0: - logger.warning( - "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", - limit, - ) - return 1 - return invalid_count_warnings - - if invalid_count_warnings < limit: - logger.warning( - "platform_stats count 非法,已按 0 处理: value=%r, key=%s", - raw_count, - key_for_log, - ) - invalid_count_warnings += 1 - if invalid_count_warnings == limit: - logger.warning( - "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", - limit, - ) - return invalid_count_warnings + 1 - - return invalid_count_warnings + return None def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: if isinstance(value, datetime): - return self._to_utc_iso(value) + dt = value + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + else: + dt = dt.astimezone(timezone.utc) + return dt.isoformat() if isinstance(value, str): timestamp = value.strip() if not timestamp: @@ -676,18 +678,16 @@ def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: if timestamp.endswith("Z"): timestamp = f"{timestamp[:-1]}+00:00" try: - return self._to_utc_iso(datetime.fromisoformat(timestamp)) + dt = datetime.fromisoformat(timestamp) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + else: + dt = dt.astimezone(timezone.utc) + return dt.isoformat() except ValueError: return None return None - def _to_utc_iso(self, dt: datetime) -> str: - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - else: - dt = dt.astimezone(timezone.utc) - return dt.isoformat() - async def _import_knowledge_bases( self, zf: zipfile.ZipFile, From 1cb1348675496d6e546177f9a68d888bdae17962 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 20:44:31 +0900 Subject: [PATCH 16/16] refactor: isolate platform stats warning limiter --- astrbot/core/backup/importer.py | 102 +++++++++++++++++--------------- 1 file changed, 55 insertions(+), 47 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index d9996b9d36..b51c7d9560 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -92,6 +92,40 @@ def _load_platform_stats_invalid_count_warn_limit() -> int: ) +class _InvalidCountWarnLimiter: + """Rate-limit warnings for invalid platform_stats count values.""" + + def __init__(self, limit: int) -> None: + self.limit = limit + self._count = 0 + self._suppression_logged = False + + def warn_invalid_count(self, value: Any, key_for_log: tuple[Any, ...]) -> None: + if self.limit > 0: + if self._count < self.limit: + logger.warning( + "platform_stats count 非法,已按 0 处理: value=%r, key=%s", + value, + key_for_log, + ) + self._count += 1 + if self._count == self.limit and not self._suppression_logged: + logger.warning( + "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", + self.limit, + ) + self._suppression_logged = True + return + + if not self._suppression_logged: + # limit <= 0: emit only one suppression warning. + logger.warning( + "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", + self.limit, + ) + self._suppression_logged = True + + @dataclass class ImportPreCheckResult: """导入预检查结果 @@ -577,49 +611,15 @@ def _merge_platform_stats_rows( """ merged: dict[tuple[str, str, str], dict[str, Any]] = {} result: list[dict[str, Any]] = [] - invalid_count_warnings = 0 - suppression_logged = False - limit = PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT + warn_limiter = _InvalidCountWarnLimiter(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT) for row in rows: - normalized_row = self._normalize_platform_stats_row(row) - timestamp = normalized_row.get("timestamp") + normalized_row, normalized_timestamp, count = ( + self._normalize_platform_stats_entry(row, warn_limiter) + ) platform_id = normalized_row.get("platform_id") platform_type = normalized_row.get("platform_type") - key_for_log = (timestamp, repr(platform_id), repr(platform_type)) - parsed_count = self._parse_platform_stats_count( - normalized_row.get("count", 0) - ) - if parsed_count is None: - if limit > 0: - if invalid_count_warnings < limit: - logger.warning( - "platform_stats count 非法,已按 0 处理: value=%r, key=%s", - normalized_row.get("count", 0), - key_for_log, - ) - invalid_count_warnings += 1 - if invalid_count_warnings == limit and not suppression_logged: - logger.warning( - "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", - limit, - ) - suppression_logged = True - elif not suppression_logged: - # limit <= 0: emit only one suppression warning. - logger.warning( - "platform_stats 非法 count 告警已达到上限 (%d),后续将抑制", - limit, - ) - suppression_logged = True - count = 0 - else: - count = parsed_count - - normalized_row["count"] = count - - normalized_timestamp = self._normalize_platform_stats_timestamp(timestamp) if ( normalized_timestamp is None or not isinstance(platform_id, str) @@ -638,7 +638,11 @@ def _merge_platform_stats_rows( return result - def _normalize_platform_stats_row(self, row: dict[str, Any]) -> dict[str, Any]: + def _normalize_platform_stats_entry( + self, + row: dict[str, Any], + warn_limiter: _InvalidCountWarnLimiter, + ) -> tuple[dict[str, Any], str | None, int]: normalized_row = dict(row) raw_timestamp = normalized_row.get("timestamp") normalized_timestamp = self._normalize_platform_stats_timestamp(raw_timestamp) @@ -652,16 +656,20 @@ def _normalize_platform_stats_row(self, row: dict[str, Any]) -> dict[str, Any]: else: normalized_row["timestamp"] = str(raw_timestamp) - return normalized_row - - def _parse_platform_stats_count( - self, - raw_count: Any, - ) -> int | None: + raw_count = normalized_row.get("count", 0) try: - return int(raw_count) + count = int(raw_count) except (TypeError, ValueError): - return None + key_for_log = ( + normalized_row.get("timestamp"), + repr(normalized_row.get("platform_id")), + repr(normalized_row.get("platform_type")), + ) + warn_limiter.warn_invalid_count(raw_count, key_for_log) + count = 0 + + normalized_row["count"] = count + return normalized_row, normalized_timestamp, count def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: if isinstance(value, datetime):