diff --git a/gigl/common/logger.py b/gigl/common/logger.py index 52102f3e9..5e46a1fe9 100644 --- a/gigl/common/logger.py +++ b/gigl/common/logger.py @@ -1,22 +1,88 @@ +import json import logging import os import pathlib -from datetime import datetime +import sys +from datetime import datetime, timezone from typing import Any, MutableMapping, Optional -from google.cloud import logging as google_cloud_logging - _BASE_LOG_FILE_PATH = "/tmp/research/gbml/logs" +_PYTHON_LEVEL_TO_GCP_SEVERITY: dict[str, str] = { + "DEBUG": "DEBUG", + "INFO": "INFO", + "WARNING": "WARNING", + "ERROR": "ERROR", + "CRITICAL": "CRITICAL", +} -class Logger(logging.LoggerAdapter): +# Key used by Logger.process() to pass user-supplied extras to the formatter +# without mixing them into the LogRecord's built-in attributes. +_GCP_LABELS_RECORD_ATTR: str = "_gcp_labels" + + +class _GcpJsonFormatter(logging.Formatter): + """A ``logging.Formatter`` that outputs one JSON object per line with + `GCP-recognized structured logging fields + <https://cloud.google.com/logging/docs/structured-logging>`_. + + Fields emitted: + + - ``severity`` -- mapped from the Python log level. + - ``message`` -- the formatted log message (with traceback appended when present). + - ``time`` -- ISO 8601 UTC timestamp. + - ``logging.googleapis.com/sourceLocation`` -- ``{file, line, function}``. + - ``logging.googleapis.com/labels`` -- any extra fields supplied via the + ``extra`` dict on the ``Logger`` adapter. Omitted when there are no extras. """ - GiGL's custom logger class used for local and cloud logging (VertexAI, Dataflow, etc.) + + def format(self, record: logging.LogRecord) -> str: + """Format *record* as a single-line JSON string. + + Args: + record: The ``LogRecord`` to format. + + Returns: + A JSON string (no trailing newline) suitable for writing to + ``sys.stderr`` on GCP-managed environments. 
+ """ + message = record.getMessage() + + if record.exc_info and not record.exc_text: + record.exc_text = self.formatException(record.exc_info) + if record.exc_text: + message = f"{message}\n{record.exc_text}" + if record.stack_info: + message = f"{message}\n{record.stack_info}" + + payload: dict[str, object] = { + "severity": _PYTHON_LEVEL_TO_GCP_SEVERITY.get( + record.levelname, record.levelname + ), + "message": message, + "time": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(), + "logging.googleapis.com/sourceLocation": { + "file": record.pathname, + "line": record.lineno, + "function": record.funcName, + }, + } + + labels: dict[str, object] = getattr(record, _GCP_LABELS_RECORD_ATTR, {}) + if labels: + payload["logging.googleapis.com/labels"] = labels + + return json.dumps(payload, ensure_ascii=False, default=str) + + +class Logger(logging.LoggerAdapter): + """GiGL's custom logger class used for local and cloud logging (VertexAI, Dataflow, etc.). + Args: - logger (Optional[logging.Logger]): A custom logger to use. If not provided, the default logger will be created. - name (Optional[str]): The name to be used for the logger. By default uses "root". - log_to_file (bool): If True, logs will be written to a file. If False, logs will be written to the console. - extra (Optional[dict[str, Any]]): Extra information to be added to the log message. + logger: A custom logger to use. If not provided, the default logger will be created. + name: The name to be used for the logger. By default uses "root". + log_to_file: If True, logs will be written to a file. If False, logs will be written to the console. + extra: Extra information to be added to the log message. 
""" def __init__( @@ -37,12 +103,11 @@ def _setup_logger( ) -> None: handler: logging.Handler if not logger.handlers: - if os.getenv("GAE_APPLICATION") or os.environ.get( - "KUBERNETES_SERVICE_HOST" - ): - # Google Cloud Logging - client = google_cloud_logging.Client() - client.setup_logging(log_level=logging.INFO) + # Check if running on GCP. + if os.getenv("GAE_APPLICATION") or os.getenv("KUBERNETES_SERVICE_HOST"): + handler = logging.StreamHandler(stream=sys.stderr) + handler.setFormatter(_GcpJsonFormatter()) + logger.addHandler(handler) else: # Logging locally. Set up logging to console or file if log_to_file: @@ -64,10 +129,10 @@ def _setup_logger( logger.setLevel(logging.INFO) def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Any: + merged: dict[str, Any] = dict(self.extra if self.extra else {}) if "extra" in kwargs: - kwargs["extra"].update(self.extra) - else: - kwargs["extra"] = self.extra + merged.update(kwargs["extra"]) + kwargs["extra"] = {**merged, _GCP_LABELS_RECORD_ATTR: merged} return msg, kwargs def __getattr__(self, name: str): diff --git a/tests/unit/common/logger_test.py b/tests/unit/common/logger_test.py new file mode 100644 index 000000000..56f201726 --- /dev/null +++ b/tests/unit/common/logger_test.py @@ -0,0 +1,101 @@ +import json +import logging +import sys + +from gigl.common.logger import _GCP_LABELS_RECORD_ATTR, _GcpJsonFormatter +from tests.test_assets.test_case import TestCase + + +class GcpJsonFormatterTest(TestCase): + def setUp(self) -> None: + self.formatter = _GcpJsonFormatter() + + def _make_record( + self, + message: str = "test message", + level: int = logging.INFO, + exc_info: object = None, + ) -> logging.LogRecord: + """Create a ``LogRecord`` with deterministic source location.""" + record = logging.LogRecord( + name="test", + level=level, + pathname="test_file.py", + lineno=42, + msg=message, + args=None, + exc_info=exc_info, # type: ignore[arg-type] + ) + record.funcName = "test_func" + return record + + def 
test_basic_info_message_produces_valid_json(self) -> None: + record = self._make_record() + output = self.formatter.format(record) + parsed = json.loads(output) + + self.assertEqual(parsed["severity"], "INFO") + self.assertEqual(parsed["message"], "test message") + self.assertIn("time", parsed) + self.assertIn("logging.googleapis.com/sourceLocation", parsed) + + def test_severity_mapping_for_all_levels(self) -> None: + levels = { + logging.DEBUG: "DEBUG", + logging.INFO: "INFO", + logging.WARNING: "WARNING", + logging.ERROR: "ERROR", + logging.CRITICAL: "CRITICAL", + } + for python_level, expected_severity in levels.items(): + with self.subTest(level=python_level): + record = self._make_record(level=python_level) + parsed = json.loads(self.formatter.format(record)) + self.assertEqual(parsed["severity"], expected_severity) + + def test_output_is_single_line(self) -> None: + record = self._make_record() + output = self.formatter.format(record) + self.assertEqual(output.count("\n"), 0) + + def test_time_field_is_iso_8601(self) -> None: + record = self._make_record() + parsed = json.loads(self.formatter.format(record)) + time_str = parsed["time"] + # ISO 8601 with timezone: contains 'T' separator and '+' offset + self.assertIn("T", time_str) + self.assertIn("+", time_str) + + def test_extra_fields_appear_under_labels(self) -> None: + record = self._make_record() + setattr(record, _GCP_LABELS_RECORD_ATTR, {"custom_key": "custom_value"}) + parsed = json.loads(self.formatter.format(record)) + + labels = parsed["logging.googleapis.com/labels"] + self.assertEqual(labels["custom_key"], "custom_value") + + def test_no_labels_key_when_no_extras(self) -> None: + record = self._make_record() + parsed = json.loads(self.formatter.format(record)) + self.assertNotIn("logging.googleapis.com/labels", parsed) + + def test_exception_traceback_in_message(self) -> None: + try: + raise ValueError("boom") + except ValueError: + exc_info = sys.exc_info() + + record = 
self._make_record(exc_info=exc_info) + parsed = json.loads(self.formatter.format(record)) + + self.assertIn("ValueError: boom", parsed["message"]) + self.assertIn("Traceback", parsed["message"]) + + def test_source_location_fields(self) -> None: + record = self._make_record() + parsed = json.loads(self.formatter.format(record)) + source = parsed["logging.googleapis.com/sourceLocation"] + + self.assertEqual(source["file"], "test_file.py") + self.assertEqual(source["line"], 42) + self.assertEqual(source["function"], "test_func")