From 176df9304c796a878f07d3b000cdfb1da66c0497 Mon Sep 17 00:00:00 2001 From: Siddhant Rao Date: Mon, 9 Mar 2026 23:10:49 +0000 Subject: [PATCH 1/5] feat: add SparkMonitor integration for VS Code progress display Intercepts gRPC ExecutePlanResponse stream to extract SparkMonitor progress messages (field 24) and forward them to the VS Code SparkMonitor extension via IPython display outputs. Changes: - Add SparkMonitor interception pipeline to DataprocSparkSession: - _setup_cell_execution_tracking: hooks IPython pre_run_cell events to assign a unique run ID per cell for display routing - _setup_sparkmonitor_interception: wraps the gRPC response stream with a background consumer thread + queue to extract SparkMonitor data without blocking PySpark - _extract_and_send_sparkmonitor: manually locates and slices field 24 from raw response bytes (field is unknown to PySpark's proto class), parses with sparkmonitor_pb2, and forwards to VS Code - _proto_to_scala_json_format: converts SparkMonitorProgress proto to the Scala-compatible JSON format expected by the VS Code extension - _convert_string_numbers_to_int: recursively converts string number fields (e.g. jobId, numTasks) to ints - _send_to_vscode: emits display output with application/vnd.sparkmonitor+json mime type - Add protobuf>=3.20.0 to install_requires and include proto/pb2 files in package_data - Add SparkMonitorTests unit test class (16 tests) covering all new methods including binary extraction, proto conversion, IPython display routing, and graceful fallback when IPython is absent --- .../dataproc_spark_connect/proto/__init__.py | 0 .../proto/sparkmonitor.proto | 187 +++++++++++ .../proto/sparkmonitor_pb2.py | 56 ++++ .../cloud/dataproc_spark_connect/session.py | 305 ++++++++++++++++++ setup.py | 10 +- tests/unit/test_session.py | 254 +++++++++++++++ 6 files changed, 811 insertions(+), 1 deletion(-) create mode 100644 google/cloud/dataproc_spark_connect/proto/__init__.py create mode 100644 google/cloud/dataproc_spark_connect/proto/sparkmonitor.proto create mode 100644 google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py diff --git a/google/cloud/dataproc_spark_connect/proto/__init__.py b/google/cloud/dataproc_spark_connect/proto/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/google/cloud/dataproc_spark_connect/proto/sparkmonitor.proto b/google/cloud/dataproc_spark_connect/proto/sparkmonitor.proto new file mode 100644 index 00000000..36e03c7f --- /dev/null +++ b/google/cloud/dataproc_spark_connect/proto/sparkmonitor.proto @@ -0,0 +1,187 @@ +syntax = "proto3"; + +package sparkmonitor; + +option java_multiple_files = true; +option java_package = "org.apache.spark.connect.proto"; + +message SparkMonitorProgress { + // Message type indicating which event this is + string msg_type = 1; + + // Application events + optional ApplicationStartData application_start = 2; + optional ApplicationEndData application_end = 3; + + // Job events + optional JobStartData job_start = 4; + optional JobEndData job_end = 5; + + // Stage events + optional StageSubmittedData stage_submitted = 6; + optional StageCompletedData stage_completed = 7; + optional StageActiveData stage_active = 8; + + // Task events + optional TaskStartData task_start = 9; + optional TaskEndData task_end = 10; + + // Executor events + optional ExecutorAddedData executor_added = 11; + optional ExecutorRemovedData executor_removed = 12; + + // Application Start/End + message ApplicationStartData { + int64 start_time = 1; + string app_id = 2; + string 
app_attempt_id = 3; + string app_name = 4; + string spark_user = 5; + } + + message ApplicationEndData { + int64 end_time = 1; + } + + // Job Start/End + message JobStartData { + string job_group = 1; + int64 job_id = 2; + string status = 3; + int64 submission_time = 4; + repeated int32 stage_ids = 5; + map stage_infos = 6; + int32 num_tasks = 7; + int32 total_cores = 8; + string app_id = 9; + int32 num_executors = 10; + string name = 11; + } + + message StageInfoForJob { + int32 attempt_id = 1; + string name = 2; + int32 num_tasks = 3; + int64 completion_time = 4; + int64 submission_time = 5; + } + + message JobEndData { + int64 job_id = 1; + string status = 2; + int64 completion_time = 3; + } + + // Stage Submitted/Completed/Active + message StageSubmittedData { + int64 stage_id = 1; + int32 stage_attempt_id = 2; + string name = 3; + int32 num_tasks = 4; + repeated int32 parent_ids = 5; + int64 submission_time = 6; + repeated int64 job_ids = 7; + int32 num_active_tasks = 8; + int32 num_failed_tasks = 9; + int32 num_completed_tasks = 10; + } + + message StageCompletedData { + int64 stage_id = 1; + int32 stage_attempt_id = 2; + int64 completion_time = 3; + int64 submission_time = 4; + int32 num_tasks = 5; + int32 num_failed_tasks = 6; + int32 num_completed_tasks = 7; + string status = 8; + repeated int64 job_ids = 9; + } + + message StageActiveData { + int64 stage_id = 1; + int32 stage_attempt_id = 2; + string name = 3; + repeated int32 parent_ids = 4; + int32 num_tasks = 5; + int32 num_active_tasks = 6; + int32 num_failed_tasks = 7; + int32 num_completed_tasks = 8; + repeated int64 job_ids = 9; + } + + // Task Start/End + message TaskStartData { + int64 launch_time = 1; + int64 task_id = 2; + int64 stage_id = 3; + int32 stage_attempt_id = 4; + int32 index = 5; + int32 attempt_number = 6; + string executor_id = 7; + string host = 8; + string status = 9; + bool speculative = 10; + } + + message TaskEndData { + int64 launch_time = 1; + int64 finish_time = 2; + int64 task_id = 3; + int64 stage_id = 4; + string task_type = 5; + int32 stage_attempt_id = 6; + int32 index = 7; + int32 attempt_number = 8; + string executor_id = 9; + string host = 10; + string status = 11; + bool speculative = 12; + string error_message = 13; + TaskMetrics metrics = 14; // Also optional since it may not be present + } + + message TaskMetrics { + int64 shuffle_read_time = 1; + int64 shuffle_write_time = 2; + int64 serialization_time = 3; + int64 deserialization_time = 4; + int64 getting_result_time = 5; + int64 executor_computing_time = 6; + int64 scheduler_delay = 7; + double shuffle_read_time_proportion = 8; + double shuffle_write_time_proportion = 9; + double serialization_time_proportion = 10; + double deserialization_time_proportion = 11; + double getting_result_time_proportion = 12; + double executor_computing_time_proportion = 13; + double scheduler_delay_proportion = 14; + double shuffle_read_time_proportion_pos = 15; + double shuffle_write_time_proportion_pos = 16; + double serialization_time_proportion_pos = 17; + double deserialization_time_proportion_pos = 18; + double getting_result_time_proportion_pos = 19; + double executor_computing_time_proportion_pos = 20; + double scheduler_delay_proportion_pos = 21; + int64 result_size = 22; + int64 jvm_gc_time = 23; + int64 memory_bytes_spilled = 24; + int64 disk_bytes_spilled = 25; + int64 peak_execution_memory = 26; + } + + // Executor Added/Removed + message ExecutorAddedData { + string executor_id = 1; + int64 time = 2; + string host = 3; + int32 num_cores = 
4; + int32 total_cores = 5; + } + + message ExecutorRemovedData { + string executor_id = 1; + int64 time = 2; + int32 total_cores = 3; + } +} \ No newline at end of file diff --git a/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py b/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py new file mode 100644 index 00000000..a609e673 --- /dev/null +++ b/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: sparkmonitor.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12sparkmonitor.proto\x12\x0csparkmonitor\"\xf3 \n\x14SparkMonitorProgress\x12\x10\n\x08msg_type\x18\x01 \x01(\t\x12W\n\x11\x61pplication_start\x18\x02 \x01(\x0b\x32\x37.sparkmonitor.SparkMonitorProgress.ApplicationStartDataH\x00\x88\x01\x01\x12S\n\x0f\x61pplication_end\x18\x03 \x01(\x0b\x32\x35.sparkmonitor.SparkMonitorProgress.ApplicationEndDataH\x01\x88\x01\x01\x12G\n\tjob_start\x18\x04 \x01(\x0b\x32/.sparkmonitor.SparkMonitorProgress.JobStartDataH\x02\x88\x01\x01\x12\x43\n\x07job_end\x18\x05 \x01(\x0b\x32-.sparkmonitor.SparkMonitorProgress.JobEndDataH\x03\x88\x01\x01\x12S\n\x0fstage_submitted\x18\x06 \x01(\x0b\x32\x35.sparkmonitor.SparkMonitorProgress.StageSubmittedDataH\x04\x88\x01\x01\x12S\n\x0fstage_completed\x18\x07 \x01(\x0b\x32\x35.sparkmonitor.SparkMonitorProgress.StageCompletedDataH\x05\x88\x01\x01\x12M\n\x0cstage_active\x18\x08 \x01(\x0b\x32\x32.sparkmonitor.SparkMonitorProgress.StageActiveDataH\x06\x88\x01\x01\x12I\n\ntask_start\x18\t \x01(\x0b\x32\x30.sparkmonitor.SparkMonitorProgress.TaskStartDataH\x07\x88\x01\x01\x12\x45\n\x08task_end\x18\n \x01(\x0b\x32..sparkmonitor.SparkMonitorProgress.TaskEndDataH\x08\x88\x01\x01\x12Q\n\x0e\x65xecutor_added\x18\x0b \x01(\x0b\x32\x34.sparkmonitor.SparkMonitorProgress.ExecutorAddedDataH\t\x88\x01\x01\x12U\n\x10\x65xecutor_removed\x18\x0c \x01(\x0b\x32\x36.sparkmonitor.SparkMonitorProgress.ExecutorRemovedDataH\n\x88\x01\x01\x1ax\n\x14\x41pplicationStartData\x12\x12\n\nstart_time\x18\x01 \x01(\x03\x12\x0e\n\x06\x61pp_id\x18\x02 \x01(\t\x12\x16\n\x0e\x61pp_attempt_id\x18\x03 \x01(\t\x12\x10\n\x08\x61pp_name\x18\x04 \x01(\t\x12\x12\n\nspark_user\x18\x05 \x01(\t\x1a&\n\x12\x41pplicationEndData\x12\x10\n\x08\x65nd_time\x18\x01 \x01(\x03\x1a\x87\x03\n\x0cJobStartData\x12\x11\n\tjob_group\x18\x01 \x01(\t\x12\x0e\n\x06job_id\x18\x02 \x01(\x03\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x17\n\x0fsubmission_time\x18\x04 \x01(\x03\x12\x11\n\tstage_ids\x18\x05 \x03(\x05\x12T\n\x0bstage_infos\x18\x06 \x03(\x0b\x32?.sparkmonitor.SparkMonitorProgress.JobStartData.StageInfosEntry\x12\x11\n\tnum_tasks\x18\x07 \x01(\x05\x12\x13\n\x0btotal_cores\x18\x08 \x01(\x05\x12\x0e\n\x06\x61pp_id\x18\t \x01(\t\x12\x15\n\rnum_executors\x18\n \x01(\x05\x12\x0c\n\x04name\x18\x0b \x01(\t\x1a\x65\n\x0fStageInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32\x32.sparkmonitor.SparkMonitorProgress.StageInfoForJob:\x02\x38\x01\x1ax\n\x0fStageInfoForJob\x12\x12\n\nattempt_id\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tnum_tasks\x18\x03 
\x01(\x05\x12\x17\n\x0f\x63ompletion_time\x18\x04 \x01(\x03\x12\x17\n\x0fsubmission_time\x18\x05 \x01(\x03\x1a\x45\n\nJobEndData\x12\x0e\n\x06job_id\x18\x01 \x01(\x03\x12\x0e\n\x06status\x18\x02 \x01(\t\x12\x17\n\x0f\x63ompletion_time\x18\x03 \x01(\x03\x1a\xf0\x01\n\x12StageSubmittedData\x12\x10\n\x08stage_id\x18\x01 \x01(\x03\x12\x18\n\x10stage_attempt_id\x18\x02 \x01(\x05\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x11\n\tnum_tasks\x18\x04 \x01(\x05\x12\x12\n\nparent_ids\x18\x05 \x03(\x05\x12\x17\n\x0fsubmission_time\x18\x06 \x01(\x03\x12\x0f\n\x07job_ids\x18\x07 \x03(\x03\x12\x18\n\x10num_active_tasks\x18\x08 \x01(\x05\x12\x18\n\x10num_failed_tasks\x18\t \x01(\x05\x12\x1b\n\x13num_completed_tasks\x18\n \x01(\x05\x1a\xdd\x01\n\x12StageCompletedData\x12\x10\n\x08stage_id\x18\x01 \x01(\x03\x12\x18\n\x10stage_attempt_id\x18\x02 \x01(\x05\x12\x17\n\x0f\x63ompletion_time\x18\x03 \x01(\x03\x12\x17\n\x0fsubmission_time\x18\x04 \x01(\x03\x12\x11\n\tnum_tasks\x18\x05 \x01(\x05\x12\x18\n\x10num_failed_tasks\x18\x06 \x01(\x05\x12\x1b\n\x13num_completed_tasks\x18\x07 \x01(\x05\x12\x0e\n\x06status\x18\x08 \x01(\t\x12\x0f\n\x07job_ids\x18\t \x03(\x03\x1a\xd4\x01\n\x0fStageActiveData\x12\x10\n\x08stage_id\x18\x01 \x01(\x03\x12\x18\n\x10stage_attempt_id\x18\x02 \x01(\x05\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x12\n\nparent_ids\x18\x04 \x03(\x05\x12\x11\n\tnum_tasks\x18\x05 \x01(\x05\x12\x18\n\x10num_active_tasks\x18\x06 \x01(\x05\x12\x18\n\x10num_failed_tasks\x18\x07 \x01(\x05\x12\x1b\n\x13num_completed_tasks\x18\x08 \x01(\x05\x12\x0f\n\x07job_ids\x18\t \x03(\x03\x1a\xd0\x01\n\rTaskStartData\x12\x13\n\x0blaunch_time\x18\x01 \x01(\x03\x12\x0f\n\x07task_id\x18\x02 \x01(\x03\x12\x10\n\x08stage_id\x18\x03 \x01(\x03\x12\x18\n\x10stage_attempt_id\x18\x04 \x01(\x05\x12\r\n\x05index\x18\x05 \x01(\x05\x12\x16\n\x0e\x61ttempt_number\x18\x06 \x01(\x05\x12\x13\n\x0b\x65xecutor_id\x18\x07 \x01(\t\x12\x0c\n\x04host\x18\x08 \x01(\t\x12\x0e\n\x06status\x18\t \x01(\t\x12\x13\n\x0bspeculative\x18\n \x01(\x08\x1a\xce\x02\n\x0bTaskEndData\x12\x13\n\x0blaunch_time\x18\x01 \x01(\x03\x12\x13\n\x0b\x66inish_time\x18\x02 \x01(\x03\x12\x0f\n\x07task_id\x18\x03 \x01(\x03\x12\x10\n\x08stage_id\x18\x04 \x01(\x03\x12\x11\n\ttask_type\x18\x05 \x01(\t\x12\x18\n\x10stage_attempt_id\x18\x06 \x01(\x05\x12\r\n\x05index\x18\x07 \x01(\x05\x12\x16\n\x0e\x61ttempt_number\x18\x08 \x01(\x05\x12\x13\n\x0b\x65xecutor_id\x18\t \x01(\t\x12\x0c\n\x04host\x18\n \x01(\t\x12\x0e\n\x06status\x18\x0b \x01(\t\x12\x13\n\x0bspeculative\x18\x0c \x01(\x08\x12\x15\n\rerror_message\x18\r \x01(\t\x12?\n\x07metrics\x18\x0e \x01(\x0b\x32..sparkmonitor.SparkMonitorProgress.TaskMetrics\x1a\x9e\x07\n\x0bTaskMetrics\x12\x19\n\x11shuffle_read_time\x18\x01 \x01(\x03\x12\x1a\n\x12shuffle_write_time\x18\x02 \x01(\x03\x12\x1a\n\x12serialization_time\x18\x03 \x01(\x03\x12\x1c\n\x14\x64\x65serialization_time\x18\x04 \x01(\x03\x12\x1b\n\x13getting_result_time\x18\x05 \x01(\x03\x12\x1f\n\x17\x65xecutor_computing_time\x18\x06 \x01(\x03\x12\x17\n\x0fscheduler_delay\x18\x07 \x01(\x03\x12$\n\x1cshuffle_read_time_proportion\x18\x08 \x01(\x01\x12%\n\x1dshuffle_write_time_proportion\x18\t \x01(\x01\x12%\n\x1dserialization_time_proportion\x18\n \x01(\x01\x12\'\n\x1f\x64\x65serialization_time_proportion\x18\x0b \x01(\x01\x12&\n\x1egetting_result_time_proportion\x18\x0c \x01(\x01\x12*\n\"executor_computing_time_proportion\x18\r \x01(\x01\x12\"\n\x1ascheduler_delay_proportion\x18\x0e \x01(\x01\x12(\n shuffle_read_time_proportion_pos\x18\x0f 
\x01(\x01\x12)\n!shuffle_write_time_proportion_pos\x18\x10 \x01(\x01\x12)\n!serialization_time_proportion_pos\x18\x11 \x01(\x01\x12+\n#deserialization_time_proportion_pos\x18\x12 \x01(\x01\x12*\n\"getting_result_time_proportion_pos\x18\x13 \x01(\x01\x12.\n&executor_computing_time_proportion_pos\x18\x14 \x01(\x01\x12&\n\x1escheduler_delay_proportion_pos\x18\x15 \x01(\x01\x12\x13\n\x0bresult_size\x18\x16 \x01(\x03\x12\x13\n\x0bjvm_gc_time\x18\x17 \x01(\x03\x12\x1c\n\x14memory_bytes_spilled\x18\x18 \x01(\x03\x12\x1a\n\x12\x64isk_bytes_spilled\x18\x19 \x01(\x03\x12\x1d\n\x15peak_execution_memory\x18\x1a \x01(\x03\x1al\n\x11\x45xecutorAddedData\x12\x13\n\x0b\x65xecutor_id\x18\x01 \x01(\t\x12\x0c\n\x04time\x18\x02 \x01(\x03\x12\x0c\n\x04host\x18\x03 \x01(\t\x12\x11\n\tnum_cores\x18\x04 \x01(\x05\x12\x13\n\x0btotal_cores\x18\x05 \x01(\x05\x1aM\n\x13\x45xecutorRemovedData\x12\x13\n\x0b\x65xecutor_id\x18\x01 \x01(\t\x12\x0c\n\x04time\x18\x02 \x01(\x03\x12\x13\n\x0btotal_cores\x18\x03 \x01(\x05\x42\x14\n\x12_application_startB\x12\n\x10_application_endB\x0c\n\n_job_startB\n\n\x08_job_endB\x12\n\x10_stage_submittedB\x12\n\x10_stage_completedB\x0f\n\r_stage_activeB\r\n\x0b_task_startB\x0b\n\t_task_endB\x11\n\x0f_executor_addedB\x13\n\x11_executor_removedB\"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sparkmonitor_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\036org.apache.spark.connect.protoP\001' + _SPARKMONITORPROGRESS_JOBSTARTDATA_STAGEINFOSENTRY._options = None + _SPARKMONITORPROGRESS_JOBSTARTDATA_STAGEINFOSENTRY._serialized_options = b'8\001' + _SPARKMONITORPROGRESS._serialized_start=37 + _SPARKMONITORPROGRESS._serialized_end=4248 + _SPARKMONITORPROGRESS_APPLICATIONSTARTDATA._serialized_start=960 + _SPARKMONITORPROGRESS_APPLICATIONSTARTDATA._serialized_end=1080 + _SPARKMONITORPROGRESS_APPLICATIONENDDATA._serialized_start=1082 + _SPARKMONITORPROGRESS_APPLICATIONENDDATA._serialized_end=1120 + _SPARKMONITORPROGRESS_JOBSTARTDATA._serialized_start=1123 + _SPARKMONITORPROGRESS_JOBSTARTDATA._serialized_end=1514 + _SPARKMONITORPROGRESS_JOBSTARTDATA_STAGEINFOSENTRY._serialized_start=1413 + _SPARKMONITORPROGRESS_JOBSTARTDATA_STAGEINFOSENTRY._serialized_end=1514 + _SPARKMONITORPROGRESS_STAGEINFOFORJOB._serialized_start=1516 + _SPARKMONITORPROGRESS_STAGEINFOFORJOB._serialized_end=1636 + _SPARKMONITORPROGRESS_JOBENDDATA._serialized_start=1638 + _SPARKMONITORPROGRESS_JOBENDDATA._serialized_end=1707 + _SPARKMONITORPROGRESS_STAGESUBMITTEDDATA._serialized_start=1710 + _SPARKMONITORPROGRESS_STAGESUBMITTEDDATA._serialized_end=1950 + _SPARKMONITORPROGRESS_STAGECOMPLETEDDATA._serialized_start=1953 + _SPARKMONITORPROGRESS_STAGECOMPLETEDDATA._serialized_end=2174 + _SPARKMONITORPROGRESS_STAGEACTIVEDATA._serialized_start=2177 + _SPARKMONITORPROGRESS_STAGEACTIVEDATA._serialized_end=2389 + _SPARKMONITORPROGRESS_TASKSTARTDATA._serialized_start=2392 + _SPARKMONITORPROGRESS_TASKSTARTDATA._serialized_end=2600 + _SPARKMONITORPROGRESS_TASKENDDATA._serialized_start=2603 + _SPARKMONITORPROGRESS_TASKENDDATA._serialized_end=2937 + _SPARKMONITORPROGRESS_TASKMETRICS._serialized_start=2940 + _SPARKMONITORPROGRESS_TASKMETRICS._serialized_end=3866 + _SPARKMONITORPROGRESS_EXECUTORADDEDDATA._serialized_start=3868 + _SPARKMONITORPROGRESS_EXECUTORADDEDDATA._serialized_end=3976 + 
_SPARKMONITORPROGRESS_EXECUTORREMOVEDDATA._serialized_start=3978 + _SPARKMONITORPROGRESS_EXECUTORREMOVEDDATA._serialized_end=4055 +# @@protoc_insertion_point(module_scope) diff --git a/google/cloud/dataproc_spark_connect/session.py b/google/cloud/dataproc_spark_connect/session.py index 7ad24fe8..6b75124d 100644 --- a/google/cloud/dataproc_spark_connect/session.py +++ b/google/cloud/dataproc_spark_connect/session.py @@ -24,6 +24,7 @@ import threading import time import uuid +import queue import tqdm from packaging import version from types import MethodType @@ -42,6 +43,7 @@ from google.auth.exceptions import DefaultCredentialsError from google.cloud.dataproc_spark_connect.client import DataprocChannelBuilder from google.cloud.dataproc_spark_connect.exceptions import DataprocSparkConnectException +from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 from google.cloud.dataproc_spark_connect.pypi_artifacts import PyPiArtifacts from google.cloud.dataproc_v1 import ( AuthenticationConfig, @@ -54,6 +56,7 @@ ) from google.cloud.dataproc_v1.types import sessions from google.cloud.dataproc_spark_connect import environment +from google.protobuf import json_format from pyspark.sql.connect.session import SparkSession from pyspark.sql.utils import to_str @@ -955,6 +958,20 @@ def __init__( super().__init__(connection, user_id) + # Unique ID for the currently executing cell + # This is set by the pre_run_cell hook before each cell executes + self._current_cell_run_id: Optional[str] = None + + # Track if we're in an IPython environment + self._ipython_available = False + + # Setup cell tracking FIRST (sets up the run_id mechanism) + self._setup_cell_execution_tracking() + + # Then setup SparkMonitor interception + self._setup_sparkmonitor_interception() + + # Setup your existing wrappers execute_plan_request_base_method = ( self.client._execute_plan_request_with_metadata ) @@ -1008,6 +1025,294 @@ def clearProgressHandlers_wrapper_method(_, *args, **kwargs): clearProgressHandlers_wrapper_method, self ) + def _setup_cell_execution_tracking(self): + """ + Hook into IPython's cell execution events to generate unique IDs + for each cell execution. This allows VS Code to associate SparkMonitor + messages with the correct cell. + """ + try: + from IPython import get_ipython + from IPython.display import display + + ip = get_ipython() + + if ip is not None: + self._ipython_available = True + + # Set run_id for the current cell (the one creating the session) + self._current_cell_run_id = str(uuid.uuid4()) + + # Bootstrap the session-creation cell: the pre_run_cell hook did not exist + # when this cell started executing, so it never fired for it. This one-time + # call manually injects the initial SparkMonitor payload for the current cell, + # ensuring the widget occupies the top output slot (index 0) before any + # subsequent print statements from session creation execute. + display_data = { + 'application/vnd.sparkmonitor+json': { + 'msgtype': 'fromscala', + 'msg': '{"msgtype": "sparkMonitorInit"}' + } + } + display(display_data, raw=True, display_id=self._current_cell_run_id) + + def pre_run_cell_hook(*args, **kwargs): + """ + Called by IPython BEFORE each cell executes. + Generates a new unique ID for this cell execution. + """ + self._current_cell_run_id = str(uuid.uuid4()) + + # Inject an initial empty payload right when the cell starts. + # This guarantees the SparkMonitor widget occupies the top spot (index 0) + # in the VS Code outputs before any user code `print` statements execute. 
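+ # Subsequent _send_to_vscode calls reuse this run ID as their display_id, so SparkMonitor outputs are routed back to the cell that is currently executing.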
+ display_data = { + 'application/vnd.sparkmonitor+json': { + 'msgtype': 'fromscala', + 'msg': '{"msgtype": "sparkMonitorInit"}' + } + } + display(display_data, raw=True, display_id=self._current_cell_run_id) + + ip.events.register('pre_run_cell', pre_run_cell_hook) + else: + logger.debug("Not in IPython environment - cell tracking disabled") + + except Exception as e: + logger.warning(f"Could not setup cell tracking: {e}") + + def _setup_sparkmonitor_interception(self): + """Intercept gRPC ExecutePlan responses to extract SparkMonitorProgress messages""" + original_execute_plan = self.client._stub.ExecutePlan + + def sparkmonitor_intercepting_execute_plan(request, **kwargs): + """Wrapper that intercepts raw ExecutePlanResponse objects with background consumption""" + # Query-scoped counters (not shared across queries) + msg_type_counts = {} + responses_with_sparkmonitor = [0] + + response_queue = queue.Queue() + background_error = [None] # Mutable container for thread errors + stream_exhausted = threading.Event() + + def background_consumer(): + """Background thread that consumes all messages from gRPC stream""" + try: + count = 0 + for raw_response in original_execute_plan(request, **kwargs): + count += 1 + response_queue.put(raw_response) + self._extract_and_send_sparkmonitor(raw_response, count, msg_type_counts, responses_with_sparkmonitor) + + # Mark stream as exhausted + stream_exhausted.set() + except Exception as e: + background_error[0] = e + stream_exhausted.set() + finally: + # Signal end of stream + response_queue.put(None) + + # Start background consumer thread + consumer_thread = threading.Thread(target=background_consumer, daemon=True) + consumer_thread.start() + + # Yield responses from queue to main consumer (PySpark) + while True: + try: + raw_response = response_queue.get(timeout=0.1) + if raw_response is None: + # End of stream marker + break + yield raw_response + except queue.Empty: + # Check if stream is exhausted and queue is empty + if stream_exhausted.is_set() and response_queue.empty(): + break + continue + + if background_error[0]: + raise background_error[0] + + self.client._stub.ExecutePlan = sparkmonitor_intercepting_execute_plan + + def _extract_and_send_sparkmonitor(self, raw_response, response_num: int, msg_type_counts: dict, responses_with_sparkmonitor: list): + """Extract SparkMonitor data from a raw gRPC response and send it to VS Code. + + The SparkMonitor payload is embedded in field 24 of the ExecutePlanResponse proto. + PySpark's generated proto class has no definition for field 24, so it treats it as + an unknown field — normal attribute access is impossible. We must re-serialize the + response to raw bytes and manually locate and slice out the embedded message. + Once extracted, sparkmonitor_pb2 handles all deserialization into typed fields. 
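+ (On the wire, field 24 with wire type 2 appears as the tag bytes 0xc2 0x01, since (24 << 3) | 2 = 194 varint-encodes to two bytes, followed by a varint length prefix and the embedded SparkMonitorProgress bytes.)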
+ + Args: + raw_response: The gRPC ExecutePlanResponse + response_num: Response number in this query + msg_type_counts: Query-scoped message type counter dict + responses_with_sparkmonitor: Query-scoped counter for responses with SparkMonitor data + """ + try: + serialized = raw_response.SerializeToString() + + if b'\xc2\x01' not in serialized: + return + + responses_with_sparkmonitor[0] += 1 + + # Extract field 24 + idx = serialized.find(b'\xc2\x01') + pos = idx + 2 + length = 0 + shift = 0 + + while pos < len(serialized): + byte = serialized[pos] + pos += 1 + length |= (byte & 0x7F) << shift + if not (byte & 0x80): + break + shift += 7 + + spark_monitor_data = serialized[pos:pos + length] + + # Parse proto + sm = sparkmonitor_pb2.SparkMonitorProgress() + sm.ParseFromString(spark_monitor_data) + + # Track message types (query-scoped) + msg_type = sm.msg_type + msg_type_counts[msg_type] = msg_type_counts.get(msg_type, 0) + 1 + + # Skip stream completion signal (don't forward to VS Code) + if msg_type == "sparkMonitorStreamComplete": + return + + # Convert to Scala-compatible JSON and send to VS Code + json_msg = self._proto_to_scala_json_format(sm) + self._send_to_vscode(json_msg) + + except Exception as e: + logger.debug(f"Error extracting SparkMonitor: {e}") + + def _convert_string_numbers_to_int(self, obj): + """ + Recursively convert string numbers to integers in a dictionary. + + MessageToJson converts int64 fields to strings by default to avoid JavaScript + precision issues, but the VS Code SparkMonitor extension expects numeric values. + """ + if isinstance(obj, dict): + return {k: self._convert_string_numbers_to_int(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._convert_string_numbers_to_int(item) for item in obj] + elif isinstance(obj, str): + # Try to convert string to int if it looks like a number + # Negative numbers (like -1 for completionTime) should also be converted + if obj.lstrip('-').isdigit(): + return int(obj) + return obj + else: + return obj + + def _proto_to_scala_json_format(self, sm: sparkmonitor_pb2.SparkMonitorProgress) -> dict: + """ + Convert protobuf message to JSON format matching the Scala listener's output. + + This ensures compatibility with existing VS Code SparkMonitor extension. 
+ The format uses: + - 'msgtype' (lowercase) for the message type field + - camelCase for all other nested fields + - Numeric fields as JSON numbers (not strings) + """ + try: + # Convert proto to JSON with camelCase field names + # Try newer protobuf 5.x+ parameter first, fall back to older parameter + # This ensures fields with default values (like jobId=0, attemptId=0) are included + try: + # Protobuf 5.x+ uses always_print_fields_with_no_presence + json_str = json_format.MessageToJson( + sm, + preserving_proto_field_name=False, + always_print_fields_with_no_presence=True + ) + except TypeError: + # Protobuf <5.x uses including_default_value_fields + json_str = json_format.MessageToJson( + sm, + preserving_proto_field_name=False, + including_default_value_fields=True + ) + except Exception as e: + logger.error(f"Failed to convert proto to JSON: {e}") + # Emergency fallback + return {"msgtype": sm.msg_type or "unknown", "error": "conversion_failed"} + + msg = json.loads(json_str) + + # Convert string numbers to actual numbers for compatibility with VS Code extension + # MessageToJson converts int64 to strings by default to avoid JS precision issues, + # but the SparkMonitor extension expects numeric values + msg = self._convert_string_numbers_to_int(msg) + + # Extract the actual event data (everything except msg_type) + # The proto has msg_type at top level and one of the event fields set + event_data = {} + + # Find which event field is set and extract its data + for field_name in [ + 'applicationStart', 'applicationEnd', + 'jobStart', 'jobEnd', + 'stageSubmitted', 'stageCompleted', 'stageActive', + 'taskStart', 'taskEnd', + 'executorAdded', 'executorRemoved' + ]: + if field_name in msg: + event_data = msg[field_name] + break + + # Get the msgtype from msg_type field (it's already camelCase from MessageToJson) + msgtype_value = msg.get('msgType', sm.msg_type) + + # Build the final message with 'msgtype' (lowercase) and camelCase event data + result = { + 'msgtype': msgtype_value, # lowercase 'msgtype' + **event_data # Spread the event data (already in camelCase) + } + + return result + + def _send_to_vscode(self, msg: dict): + """Send SparkMonitor data to VS Code using IPython display mechanism. + + Matches the remote kernel format exactly: + - Wraps the event in a 'fromscala' envelope + - Converts the msg dict to a JSON string (like the Scala listener does) + """ + if not self._ipython_available: + return + + try: + from IPython.display import display + + display_id = self._current_cell_run_id or str(uuid.uuid4()) + + # Match the remote kernel format exactly: + # 1. Convert dict to JSON string (like Scala's pretty(render(json))) + # 2. 
Wrap in fromscala envelope (like kernel extension does) + wrapper = { + 'msgtype': 'fromscala', + 'msg': json.dumps(msg) # Convert to JSON string + } + + display_data = { + 'application/vnd.sparkmonitor+json': wrapper, + } + + display(display_data, raw=True, display_id=display_id) + + except Exception as e: + logger.debug(f"Error sending to VS Code: {e}") + @staticmethod @functools.lru_cache(maxsize=1) def get_tqdm_bar(): diff --git a/setup.py b/setup.py index 539e50ed..77a497a3 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,13 @@ url="https://github.com/GoogleCloudDataproc/dataproc-spark-connect-python", license="Apache 2.0", packages=find_namespace_packages(include=["google.*"]), + package_data={ + "google.cloud.dataproc_spark_connect.proto": [ + "*.proto", + "*_pb2.py", + ], + }, + include_package_data=True, install_requires=[ "google-api-core>=2.19", "google-cloud-dataproc>=5.18", @@ -35,5 +42,6 @@ "pyspark[connect]~=4.0.0", "tqdm>=4.67", "websockets>=14.0", + "protobuf>=3.20.0", # Added for proto support ], -) +) \ No newline at end of file diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 2b1a6245..68bec8d3 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2644,5 +2644,259 @@ def test_session_skip_terminated(self, mock_session_controller_client): mock_client.get_session.assert_called_once() +class SparkMonitorTests(unittest.TestCase): + """Tests for the SparkMonitor integration added to DataprocSparkSession.""" + + def setUp(self): + self.original_environment = dict(os.environ) + os.environ.clear() + os.environ["GOOGLE_CLOUD_PROJECT"] = "test-project" + os.environ["GOOGLE_CLOUD_REGION"] = "test-region" + + def tearDown(self): + os.environ.clear() + os.environ.update(self.original_environment) + + @staticmethod + def _make_session_instance(**attrs): + """Create a minimal mock DataprocSparkSession with given attributes.""" + session = mock.MagicMock(spec=DataprocSparkSession) + for key, value in attrs.items(): + setattr(session, key, value) + return session + + @staticmethod + def _encode_varint(value): + """Encode an integer as a protobuf base-128 varint.""" + result = b'' + while value > 127: + result += bytes([(value & 0x7F) | 0x80]) + value >>= 7 + result += bytes([value]) + return result + + def _build_fake_grpc_response(self, sm): + """Build a fake gRPC response with a SparkMonitorProgress message embedded at field 24.""" + from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 + sm_bytes = sm.SerializeToString() + # Field 24, wire type 2 tag = (24 << 3) | 2 = 194 = 0xC2 0x01 as a varint + payload = b'\xc2\x01' + self._encode_varint(len(sm_bytes)) + sm_bytes + mock_response = mock.MagicMock() + mock_response.SerializeToString.return_value = payload + return mock_response + + def test_convert_string_numbers_to_int_positive(self): + session = self._make_session_instance() + result = DataprocSparkSession._convert_string_numbers_to_int(session, "42") + self.assertEqual(result, 42) + self.assertIsInstance(result, int) + + def test_convert_string_numbers_to_int_negative(self): + """Negative string numbers such as completionTime=-1 should be converted.""" + session = self._make_session_instance() + result = DataprocSparkSession._convert_string_numbers_to_int(session, "-1") + self.assertEqual(result, -1) + self.assertIsInstance(result, int) + + def test_convert_string_numbers_to_int_preserves_non_numeric(self): + session = self._make_session_instance() + result = 
DataprocSparkSession._convert_string_numbers_to_int(session, "sparkJobStart") + self.assertEqual(result, "sparkJobStart") + + def test_convert_string_numbers_to_int_nested_dict_and_list(self): + session = self._make_session_instance() + # Wire up the recursive self-call so nested values are also converted + session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) + obj = {"jobId": "5", "status": "SUCCEEDED", "stageIds": ["1", "2"]} + result = DataprocSparkSession._convert_string_numbers_to_int(session, obj) + self.assertEqual(result, {"jobId": 5, "status": "SUCCEEDED", "stageIds": [1, 2]}) + + def test_convert_string_numbers_to_int_passthrough_non_string(self): + session = self._make_session_instance() + self.assertEqual(DataprocSparkSession._convert_string_numbers_to_int(session, 99), 99) + self.assertIsNone(DataprocSparkSession._convert_string_numbers_to_int(session, None)) + + def test_proto_to_scala_json_format_job_start(self): + from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 + session = self._make_session_instance() + # Wire up _convert_string_numbers_to_int so _proto_to_scala_json_format gets real values + session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) + + sm = sparkmonitor_pb2.SparkMonitorProgress() + sm.msg_type = "sparkJobStart" + sm.job_start.job_id = 3 + sm.job_start.num_tasks = 10 + sm.job_start.num_executors = 2 + + result = DataprocSparkSession._proto_to_scala_json_format(session, sm) + + self.assertEqual(result["msgtype"], "sparkJobStart") + self.assertEqual(result["jobId"], 3) + self.assertEqual(result["numTasks"], 10) + self.assertNotIn("jobStart", result) # event data should be spread to top level + + def test_proto_to_scala_json_format_job_end(self): + from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 + session = self._make_session_instance() + session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) + + sm = sparkmonitor_pb2.SparkMonitorProgress() + sm.msg_type = "sparkJobEnd" + sm.job_end.job_id = 3 + sm.job_end.status = "SUCCEEDED" + + result = DataprocSparkSession._proto_to_scala_json_format(session, sm) + + self.assertEqual(result["msgtype"], "sparkJobEnd") + self.assertEqual(result["jobId"], 3) + self.assertEqual(result["status"], "SUCCEEDED") + + def test_proto_to_scala_json_format_stage_active(self): + from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 + session = self._make_session_instance() + session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) + + sm = sparkmonitor_pb2.SparkMonitorProgress() + sm.msg_type = "sparkStageActive" + sm.stage_active.stage_id = 7 + sm.stage_active.num_tasks = 20 + sm.stage_active.num_completed_tasks = 15 + + result = DataprocSparkSession._proto_to_scala_json_format(session, sm) + + self.assertEqual(result["msgtype"], "sparkStageActive") + self.assertEqual(result["stageId"], 7) + self.assertEqual(result["numTasks"], 20) + self.assertEqual(result["numCompletedTasks"], 15) + + def test_send_to_vscode_skips_when_ipython_unavailable(self): + session = self._make_session_instance(_ipython_available=False) + + with mock.patch("IPython.display.display") as mock_display: + DataprocSparkSession._send_to_vscode(session, {"msgtype": "sparkJobStart"}) + mock_display.assert_not_called() + + def 
test_send_to_vscode_calls_display_when_ipython_available(self): + import json + run_id = "test-run-id-1234" + session = self._make_session_instance( + _ipython_available=True, + _current_cell_run_id=run_id, + ) + msg = {"msgtype": "sparkJobEnd", "jobId": 1} + + with mock.patch("IPython.display.display") as mock_display: + # Patch the import inside the method + with mock.patch.dict("sys.modules", {"IPython.display": mock.MagicMock(display=mock_display)}): + DataprocSparkSession._send_to_vscode(session, msg) + + mock_display.assert_called_once() + call_args = mock_display.call_args + display_data = call_args[0][0] + self.assertIn("application/vnd.sparkmonitor+json", display_data) + wrapper = display_data["application/vnd.sparkmonitor+json"] + self.assertEqual(wrapper["msgtype"], "fromscala") + self.assertEqual(json.loads(wrapper["msg"]), msg) + + def test_extract_and_send_skips_response_without_sparkmonitor_data(self): + session = self._make_session_instance() + + mock_response = mock.MagicMock() + mock_response.SerializeToString.return_value = b'\x0a\x05hello' # No \xc2\x01 field tag + + msg_type_counts = {} + responses_with_sparkmonitor = [0] + + DataprocSparkSession._extract_and_send_sparkmonitor( + session, mock_response, 1, msg_type_counts, responses_with_sparkmonitor + ) + + self.assertEqual(responses_with_sparkmonitor[0], 0) + session._send_to_vscode.assert_not_called() + + def test_extract_and_send_skips_stream_complete_signal(self): + from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 + session = self._make_session_instance() + + sm = sparkmonitor_pb2.SparkMonitorProgress() + sm.msg_type = "sparkMonitorStreamComplete" + mock_response = self._build_fake_grpc_response(sm) + + msg_type_counts = {} + responses_with_sparkmonitor = [0] + + DataprocSparkSession._extract_and_send_sparkmonitor( + session, mock_response, 1, msg_type_counts, responses_with_sparkmonitor + ) + + # Counter incremented but _send_to_vscode NOT called + self.assertEqual(responses_with_sparkmonitor[0], 1) + self.assertEqual(msg_type_counts["sparkMonitorStreamComplete"], 1) + session._send_to_vscode.assert_not_called() + + def test_extract_and_send_processes_valid_job_start_payload(self): + from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 + session = self._make_session_instance() + + sm = sparkmonitor_pb2.SparkMonitorProgress() + sm.msg_type = "sparkJobStart" + sm.job_start.job_id = 1 + sm.job_start.num_tasks = 8 + + mock_response = self._build_fake_grpc_response(sm) + + # Wire up real implementations so the full extraction pipeline runs + session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) + session._proto_to_scala_json_format = lambda s: DataprocSparkSession._proto_to_scala_json_format(session, s) + + msg_type_counts = {} + responses_with_sparkmonitor = [0] + + DataprocSparkSession._extract_and_send_sparkmonitor( + session, mock_response, 1, msg_type_counts, responses_with_sparkmonitor + ) + + self.assertEqual(responses_with_sparkmonitor[0], 1) + self.assertEqual(msg_type_counts["sparkJobStart"], 1) + session._send_to_vscode.assert_called_once() + sent_msg = session._send_to_vscode.call_args[0][0] + self.assertEqual(sent_msg["msgtype"], "sparkJobStart") + + def test_setup_cell_tracking_sets_flag_when_ipython_present(self): + """When IPython is available and has a live shell, _ipython_available should be True.""" + session = self._make_session_instance(_ipython_available=False, _current_cell_run_id=None) + + 
mock_ip = mock.MagicMock() + with mock.patch("IPython.get_ipython", return_value=mock_ip): + with mock.patch("IPython.display.display"): + DataprocSparkSession._setup_cell_execution_tracking(session) + + self.assertTrue(session._ipython_available) + self.assertIsNotNone(session._current_cell_run_id) + mock_ip.events.register.assert_called_once_with( + "pre_run_cell", mock.ANY + ) + + def test_setup_cell_tracking_leaves_flag_false_when_no_ipython_shell(self): + """When get_ipython() returns None, _ipython_available should remain False.""" + session = self._make_session_instance(_ipython_available=False, _current_cell_run_id=None) + + with mock.patch("IPython.get_ipython", return_value=None): + DataprocSparkSession._setup_cell_execution_tracking(session) + + self.assertFalse(session._ipython_available) + self.assertIsNone(session._current_cell_run_id) + + def test_setup_cell_tracking_is_resilient_to_import_error(self): + """If IPython is not installed, the method should not raise.""" + session = self._make_session_instance(_ipython_available=False, _current_cell_run_id=None) + + with mock.patch.dict("sys.modules", {"IPython": None}): + # Should not raise + DataprocSparkSession._setup_cell_execution_tracking(session) + + self.assertFalse(session._ipython_available) + + if __name__ == "__main__": unittest.main() From 878f5f2b6d3364448cde2763b3b0e6486e9b5aa0 Mon Sep 17 00:00:00 2001 From: Siddhant Rao Date: Thu, 19 Mar 2026 14:16:55 -0700 Subject: [PATCH 2/5] Update session.py to now use the protobuf Any extension field from server side sparkmonitor soa s to be immune to upgrades from parent apache/spark repo --- .../proto/sparkmonitor.proto | 238 ++++++------------ .../proto/sparkmonitor_pb2.py | 64 +++-- .../cloud/dataproc_spark_connect/session.py | 177 +++++++------ tests/unit/test_session.py | 57 +++-- 4 files changed, 252 insertions(+), 284 deletions(-) diff --git a/google/cloud/dataproc_spark_connect/proto/sparkmonitor.proto b/google/cloud/dataproc_spark_connect/proto/sparkmonitor.proto index 36e03c7f..8e3d9201 100644 --- a/google/cloud/dataproc_spark_connect/proto/sparkmonitor.proto +++ b/google/cloud/dataproc_spark_connect/proto/sparkmonitor.proto @@ -1,118 +1,89 @@ syntax = "proto3"; -package sparkmonitor; +package spark.connect; option java_multiple_files = true; option java_package = "org.apache.spark.connect.proto"; +// SparkMonitor progress data delivered via the upstream extension slot on ExecutePlanResponse +// (google.protobuf.Any extension = 999). 
+// type_url: "type.googleapis.com/spark.connect.SparkMonitorProgress" message SparkMonitorProgress { - // Message type indicating which event this is - string msg_type = 1; - - // Application events - optional ApplicationStartData application_start = 2; - optional ApplicationEndData application_end = 3; - - // Job events - optional JobStartData job_start = 4; - optional JobEndData job_end = 5; - - // Stage events - optional StageSubmittedData stage_submitted = 6; - optional StageCompletedData stage_completed = 7; - optional StageActiveData stage_active = 8; - - // Task events - optional TaskStartData task_start = 9; - optional TaskEndData task_end = 10; - - // Executor events - optional ExecutorAddedData executor_added = 11; - optional ExecutorRemovedData executor_removed = 12; - - // Application Start/End - message ApplicationStartData { - int64 start_time = 1; - string app_id = 2; - string app_attempt_id = 3; - string app_name = 4; - string spark_user = 5; - } - - message ApplicationEndData { - int64 end_time = 1; + optional ApplicationInfo application_info = 1; + repeated JobEvent job_events = 2; + repeated DetailedStageEvent stage_events = 3; + repeated TaskEvent task_events = 4; + repeated ExecutorEvent executor_events = 5; + optional bool stream_complete = 6; + + // Application lifecycle info (start_time present = start event, end_time present = end event) + message ApplicationInfo { + optional int64 start_time = 1; + optional int64 end_time = 2; + optional string app_id = 3; + optional string app_attempt_id = 4; + optional string app_name = 5; + optional string spark_user = 6; } - - // Job Start/End - message JobStartData { - string job_group = 1; + + // Job events (JOB_START=0, JOB_END=1) + message JobEvent { + enum JobEventType { + JOB_START = 0; + JOB_END = 1; + } + JobEventType event_type = 1; int64 job_id = 2; string status = 3; - int64 submission_time = 4; - repeated int32 stage_ids = 5; - map stage_infos = 6; - int32 num_tasks = 7; - int32 total_cores = 8; - string app_id = 9; - int32 num_executors = 10; - string name = 11; + optional int64 submission_time = 4; + optional int64 completion_time = 5; + optional string job_group = 6; + optional string name = 7; + repeated int32 stage_ids = 8; + map stage_infos = 9; + optional int32 num_tasks = 10; + optional int32 total_cores = 11; + optional string app_id = 12; + optional int32 num_executors = 13; } - - message StageInfoForJob { + + message JobStageInfo { int32 attempt_id = 1; string name = 2; int32 num_tasks = 3; int64 completion_time = 4; int64 submission_time = 5; } - - message JobEndData { - int64 job_id = 1; - string status = 2; - int64 completion_time = 3; - } - - // Stage Submitted/Completed/Active - message StageSubmittedData { - int64 stage_id = 1; - int32 stage_attempt_id = 2; - string name = 3; - int32 num_tasks = 4; - repeated int32 parent_ids = 5; - int64 submission_time = 6; - repeated int64 job_ids = 7; - int32 num_active_tasks = 8; - int32 num_failed_tasks = 9; - int32 num_completed_tasks = 10; - } - - message StageCompletedData { - int64 stage_id = 1; - int32 stage_attempt_id = 2; - int64 completion_time = 3; - int64 submission_time = 4; - int32 num_tasks = 5; - int32 num_failed_tasks = 6; - int32 num_completed_tasks = 7; - string status = 8; - repeated int64 job_ids = 9; - } - - message StageActiveData { - int64 stage_id = 1; - int32 stage_attempt_id = 2; - string name = 3; - repeated int32 parent_ids = 4; + + // Detailed stage events (STAGE_SUBMITTED=0, STAGE_ACTIVE=1, STAGE_COMPLETED=2) + message 
DetailedStageEvent { + enum StageEventType { + STAGE_SUBMITTED = 0; + STAGE_ACTIVE = 1; + STAGE_COMPLETED = 2; + } + StageEventType event_type = 1; + int64 stage_id = 2; + int32 stage_attempt_id = 3; + string name = 4; int32 num_tasks = 5; - int32 num_active_tasks = 6; - int32 num_failed_tasks = 7; - int32 num_completed_tasks = 8; + repeated int32 parent_ids = 6; + optional int64 submission_time = 7; + optional int64 completion_time = 8; repeated int64 job_ids = 9; + optional int32 num_active_tasks = 10; + optional int32 num_failed_tasks = 11; + optional int32 num_completed_tasks = 12; + optional string status = 13; } - - // Task Start/End - message TaskStartData { - int64 launch_time = 1; + + // Task events (TASK_START=0, TASK_END=1) + message TaskEvent { + enum TaskEventType { + TASK_START = 0; + TASK_END = 1; + } + TaskEventType event_type = 1; int64 task_id = 2; int64 stage_id = 3; int32 stage_attempt_id = 4; @@ -122,66 +93,23 @@ message SparkMonitorProgress { string host = 8; string status = 9; bool speculative = 10; + optional int64 launch_time = 11; + optional int64 finish_time = 12; + optional string task_type = 13; + optional string error_message = 14; } - - message TaskEndData { - int64 launch_time = 1; - int64 finish_time = 2; - int64 task_id = 3; - int64 stage_id = 4; - string task_type = 5; - int32 stage_attempt_id = 6; - int32 index = 7; - int32 attempt_number = 8; - string executor_id = 9; - string host = 10; - string status = 11; - bool speculative = 12; - string error_message = 13; - TaskMetrics metrics = 14; // Also optional since it may not be present - } - - message TaskMetrics { - int64 shuffle_read_time = 1; - int64 shuffle_write_time = 2; - int64 serialization_time = 3; - int64 deserialization_time = 4; - int64 getting_result_time = 5; - int64 executor_computing_time = 6; - int64 scheduler_delay = 7; - double shuffle_read_time_proportion = 8; - double shuffle_write_time_proportion = 9; - double serialization_time_proportion = 10; - double deserialization_time_proportion = 11; - double getting_result_time_proportion = 12; - double executor_computing_time_proportion = 13; - double scheduler_delay_proportion = 14; - double shuffle_read_time_proportion_pos = 15; - double shuffle_write_time_proportion_pos = 16; - double serialization_time_proportion_pos = 17; - double deserialization_time_proportion_pos = 18; - double getting_result_time_proportion_pos = 19; - double executor_computing_time_proportion_pos = 20; - double scheduler_delay_proportion_pos = 21; - int64 result_size = 22; - int64 jvm_gc_time = 23; - int64 memory_bytes_spilled = 24; - int64 disk_bytes_spilled = 25; - int64 peak_execution_memory = 26; - } - - // Executor Added/Removed - message ExecutorAddedData { - string executor_id = 1; - int64 time = 2; - string host = 3; - int32 num_cores = 4; - int32 total_cores = 5; - } - - message ExecutorRemovedData { - string executor_id = 1; - int64 time = 2; - int32 total_cores = 3; + + // Executor events (EXECUTOR_ADDED=0, EXECUTOR_REMOVED=1) + message ExecutorEvent { + enum ExecutorEventType { + EXECUTOR_ADDED = 0; + EXECUTOR_REMOVED = 1; + } + ExecutorEventType event_type = 1; + string executor_id = 2; + int64 time = 3; + optional string host = 4; + optional int32 num_cores = 5; + optional int32 total_cores = 6; } } \ No newline at end of file diff --git a/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py b/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py index a609e673..e8009fb0 100644 --- 
a/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py +++ b/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! -# source: sparkmonitor.proto +# source: google/cloud/dataproc_spark_connect/proto/sparkmonitor.proto """Generated protocol buffer code.""" from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor @@ -13,44 +13,38 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12sparkmonitor.proto\x12\x0csparkmonitor\"\xf3 \n\x14SparkMonitorProgress\x12\x10\n\x08msg_type\x18\x01 \x01(\t\x12W\n\x11\x61pplication_start\x18\x02 \x01(\x0b\x32\x37.sparkmonitor.SparkMonitorProgress.ApplicationStartDataH\x00\x88\x01\x01\x12S\n\x0f\x61pplication_end\x18\x03 \x01(\x0b\x32\x35.sparkmonitor.SparkMonitorProgress.ApplicationEndDataH\x01\x88\x01\x01\x12G\n\tjob_start\x18\x04 \x01(\x0b\x32/.sparkmonitor.SparkMonitorProgress.JobStartDataH\x02\x88\x01\x01\x12\x43\n\x07job_end\x18\x05 \x01(\x0b\x32-.sparkmonitor.SparkMonitorProgress.JobEndDataH\x03\x88\x01\x01\x12S\n\x0fstage_submitted\x18\x06 \x01(\x0b\x32\x35.sparkmonitor.SparkMonitorProgress.StageSubmittedDataH\x04\x88\x01\x01\x12S\n\x0fstage_completed\x18\x07 \x01(\x0b\x32\x35.sparkmonitor.SparkMonitorProgress.StageCompletedDataH\x05\x88\x01\x01\x12M\n\x0cstage_active\x18\x08 \x01(\x0b\x32\x32.sparkmonitor.SparkMonitorProgress.StageActiveDataH\x06\x88\x01\x01\x12I\n\ntask_start\x18\t \x01(\x0b\x32\x30.sparkmonitor.SparkMonitorProgress.TaskStartDataH\x07\x88\x01\x01\x12\x45\n\x08task_end\x18\n \x01(\x0b\x32..sparkmonitor.SparkMonitorProgress.TaskEndDataH\x08\x88\x01\x01\x12Q\n\x0e\x65xecutor_added\x18\x0b \x01(\x0b\x32\x34.sparkmonitor.SparkMonitorProgress.ExecutorAddedDataH\t\x88\x01\x01\x12U\n\x10\x65xecutor_removed\x18\x0c \x01(\x0b\x32\x36.sparkmonitor.SparkMonitorProgress.ExecutorRemovedDataH\n\x88\x01\x01\x1ax\n\x14\x41pplicationStartData\x12\x12\n\nstart_time\x18\x01 \x01(\x03\x12\x0e\n\x06\x61pp_id\x18\x02 \x01(\t\x12\x16\n\x0e\x61pp_attempt_id\x18\x03 \x01(\t\x12\x10\n\x08\x61pp_name\x18\x04 \x01(\t\x12\x12\n\nspark_user\x18\x05 \x01(\t\x1a&\n\x12\x41pplicationEndData\x12\x10\n\x08\x65nd_time\x18\x01 \x01(\x03\x1a\x87\x03\n\x0cJobStartData\x12\x11\n\tjob_group\x18\x01 \x01(\t\x12\x0e\n\x06job_id\x18\x02 \x01(\x03\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x17\n\x0fsubmission_time\x18\x04 \x01(\x03\x12\x11\n\tstage_ids\x18\x05 \x03(\x05\x12T\n\x0bstage_infos\x18\x06 \x03(\x0b\x32?.sparkmonitor.SparkMonitorProgress.JobStartData.StageInfosEntry\x12\x11\n\tnum_tasks\x18\x07 \x01(\x05\x12\x13\n\x0btotal_cores\x18\x08 \x01(\x05\x12\x0e\n\x06\x61pp_id\x18\t \x01(\t\x12\x15\n\rnum_executors\x18\n \x01(\x05\x12\x0c\n\x04name\x18\x0b \x01(\t\x1a\x65\n\x0fStageInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32\x32.sparkmonitor.SparkMonitorProgress.StageInfoForJob:\x02\x38\x01\x1ax\n\x0fStageInfoForJob\x12\x12\n\nattempt_id\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tnum_tasks\x18\x03 \x01(\x05\x12\x17\n\x0f\x63ompletion_time\x18\x04 \x01(\x03\x12\x17\n\x0fsubmission_time\x18\x05 \x01(\x03\x1a\x45\n\nJobEndData\x12\x0e\n\x06job_id\x18\x01 \x01(\x03\x12\x0e\n\x06status\x18\x02 \x01(\t\x12\x17\n\x0f\x63ompletion_time\x18\x03 \x01(\x03\x1a\xf0\x01\n\x12StageSubmittedData\x12\x10\n\x08stage_id\x18\x01 \x01(\x03\x12\x18\n\x10stage_attempt_id\x18\x02 \x01(\x05\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x11\n\tnum_tasks\x18\x04 
\x01(\x05\x12\x12\n\nparent_ids\x18\x05 \x03(\x05\x12\x17\n\x0fsubmission_time\x18\x06 \x01(\x03\x12\x0f\n\x07job_ids\x18\x07 \x03(\x03\x12\x18\n\x10num_active_tasks\x18\x08 \x01(\x05\x12\x18\n\x10num_failed_tasks\x18\t \x01(\x05\x12\x1b\n\x13num_completed_tasks\x18\n \x01(\x05\x1a\xdd\x01\n\x12StageCompletedData\x12\x10\n\x08stage_id\x18\x01 \x01(\x03\x12\x18\n\x10stage_attempt_id\x18\x02 \x01(\x05\x12\x17\n\x0f\x63ompletion_time\x18\x03 \x01(\x03\x12\x17\n\x0fsubmission_time\x18\x04 \x01(\x03\x12\x11\n\tnum_tasks\x18\x05 \x01(\x05\x12\x18\n\x10num_failed_tasks\x18\x06 \x01(\x05\x12\x1b\n\x13num_completed_tasks\x18\x07 \x01(\x05\x12\x0e\n\x06status\x18\x08 \x01(\t\x12\x0f\n\x07job_ids\x18\t \x03(\x03\x1a\xd4\x01\n\x0fStageActiveData\x12\x10\n\x08stage_id\x18\x01 \x01(\x03\x12\x18\n\x10stage_attempt_id\x18\x02 \x01(\x05\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x12\n\nparent_ids\x18\x04 \x03(\x05\x12\x11\n\tnum_tasks\x18\x05 \x01(\x05\x12\x18\n\x10num_active_tasks\x18\x06 \x01(\x05\x12\x18\n\x10num_failed_tasks\x18\x07 \x01(\x05\x12\x1b\n\x13num_completed_tasks\x18\x08 \x01(\x05\x12\x0f\n\x07job_ids\x18\t \x03(\x03\x1a\xd0\x01\n\rTaskStartData\x12\x13\n\x0blaunch_time\x18\x01 \x01(\x03\x12\x0f\n\x07task_id\x18\x02 \x01(\x03\x12\x10\n\x08stage_id\x18\x03 \x01(\x03\x12\x18\n\x10stage_attempt_id\x18\x04 \x01(\x05\x12\r\n\x05index\x18\x05 \x01(\x05\x12\x16\n\x0e\x61ttempt_number\x18\x06 \x01(\x05\x12\x13\n\x0b\x65xecutor_id\x18\x07 \x01(\t\x12\x0c\n\x04host\x18\x08 \x01(\t\x12\x0e\n\x06status\x18\t \x01(\t\x12\x13\n\x0bspeculative\x18\n \x01(\x08\x1a\xce\x02\n\x0bTaskEndData\x12\x13\n\x0blaunch_time\x18\x01 \x01(\x03\x12\x13\n\x0b\x66inish_time\x18\x02 \x01(\x03\x12\x0f\n\x07task_id\x18\x03 \x01(\x03\x12\x10\n\x08stage_id\x18\x04 \x01(\x03\x12\x11\n\ttask_type\x18\x05 \x01(\t\x12\x18\n\x10stage_attempt_id\x18\x06 \x01(\x05\x12\r\n\x05index\x18\x07 \x01(\x05\x12\x16\n\x0e\x61ttempt_number\x18\x08 \x01(\x05\x12\x13\n\x0b\x65xecutor_id\x18\t \x01(\t\x12\x0c\n\x04host\x18\n \x01(\t\x12\x0e\n\x06status\x18\x0b \x01(\t\x12\x13\n\x0bspeculative\x18\x0c \x01(\x08\x12\x15\n\rerror_message\x18\r \x01(\t\x12?\n\x07metrics\x18\x0e \x01(\x0b\x32..sparkmonitor.SparkMonitorProgress.TaskMetrics\x1a\x9e\x07\n\x0bTaskMetrics\x12\x19\n\x11shuffle_read_time\x18\x01 \x01(\x03\x12\x1a\n\x12shuffle_write_time\x18\x02 \x01(\x03\x12\x1a\n\x12serialization_time\x18\x03 \x01(\x03\x12\x1c\n\x14\x64\x65serialization_time\x18\x04 \x01(\x03\x12\x1b\n\x13getting_result_time\x18\x05 \x01(\x03\x12\x1f\n\x17\x65xecutor_computing_time\x18\x06 \x01(\x03\x12\x17\n\x0fscheduler_delay\x18\x07 \x01(\x03\x12$\n\x1cshuffle_read_time_proportion\x18\x08 \x01(\x01\x12%\n\x1dshuffle_write_time_proportion\x18\t \x01(\x01\x12%\n\x1dserialization_time_proportion\x18\n \x01(\x01\x12\'\n\x1f\x64\x65serialization_time_proportion\x18\x0b \x01(\x01\x12&\n\x1egetting_result_time_proportion\x18\x0c \x01(\x01\x12*\n\"executor_computing_time_proportion\x18\r \x01(\x01\x12\"\n\x1ascheduler_delay_proportion\x18\x0e \x01(\x01\x12(\n shuffle_read_time_proportion_pos\x18\x0f \x01(\x01\x12)\n!shuffle_write_time_proportion_pos\x18\x10 \x01(\x01\x12)\n!serialization_time_proportion_pos\x18\x11 \x01(\x01\x12+\n#deserialization_time_proportion_pos\x18\x12 \x01(\x01\x12*\n\"getting_result_time_proportion_pos\x18\x13 \x01(\x01\x12.\n&executor_computing_time_proportion_pos\x18\x14 \x01(\x01\x12&\n\x1escheduler_delay_proportion_pos\x18\x15 \x01(\x01\x12\x13\n\x0bresult_size\x18\x16 \x01(\x03\x12\x13\n\x0bjvm_gc_time\x18\x17 
\x01(\x03\x12\x1c\n\x14memory_bytes_spilled\x18\x18 \x01(\x03\x12\x1a\n\x12\x64isk_bytes_spilled\x18\x19 \x01(\x03\x12\x1d\n\x15peak_execution_memory\x18\x1a \x01(\x03\x1al\n\x11\x45xecutorAddedData\x12\x13\n\x0b\x65xecutor_id\x18\x01 \x01(\t\x12\x0c\n\x04time\x18\x02 \x01(\x03\x12\x0c\n\x04host\x18\x03 \x01(\t\x12\x11\n\tnum_cores\x18\x04 \x01(\x05\x12\x13\n\x0btotal_cores\x18\x05 \x01(\x05\x1aM\n\x13\x45xecutorRemovedData\x12\x13\n\x0b\x65xecutor_id\x18\x01 \x01(\t\x12\x0c\n\x04time\x18\x02 \x01(\x03\x12\x13\n\x0btotal_cores\x18\x03 \x01(\x05\x42\x14\n\x12_application_startB\x12\n\x10_application_endB\x0c\n\n_job_startB\n\n\x08_job_endB\x12\n\x10_stage_submittedB\x12\n\x10_stage_completedB\x0f\n\r_stage_activeB\r\n\x0b_task_startB\x0b\n\t_task_endB\x11\n\x0f_executor_addedB\x13\n\x11_executor_removedB\"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n 0 or + len(sm.stage_events) > 0 or + len(sm.task_events) > 0 or + len(sm.executor_events) > 0 or + sm.HasField('stream_complete') + ) + if not is_sparkmonitor: + return - # Track message types (query-scoped) - msg_type = sm.msg_type + responses_with_sparkmonitor[0] += 1 + + # Derive msgtype for tracking (mirrors old string-based tracking) + msg_type = self._derive_sparkmonitor_msgtype(sm) msg_type_counts[msg_type] = msg_type_counts.get(msg_type, 0) + 1 # Skip stream completion signal (don't forward to VS Code) - if msg_type == "sparkMonitorStreamComplete": + if sm.HasField('stream_complete') and sm.stream_complete: return # Convert to Scala-compatible JSON and send to VS Code @@ -1194,6 +1201,22 @@ def _extract_and_send_sparkmonitor(self, raw_response, response_num: int, msg_ty except Exception as e: logger.debug(f"Error extracting SparkMonitor: {e}") + def _derive_sparkmonitor_msgtype(self, sm: sparkmonitor_pb2.SparkMonitorProgress) -> str: + """Derive a msgtype string from the new enum-based SparkMonitor proto structure.""" + if sm.HasField('stream_complete'): + return "sparkMonitorStreamComplete" + if sm.HasField('application_info'): + return "sparkApplicationStart" if sm.application_info.HasField('start_time') else "sparkApplicationEnd" + if sm.job_events: + return "sparkJobStart" if sm.job_events[0].event_type == 0 else "sparkJobEnd" + if sm.stage_events: + return ["sparkStageSubmitted", "sparkStageActive", "sparkStageCompleted"][sm.stage_events[0].event_type] + if sm.task_events: + return "sparkTaskStart" if sm.task_events[0].event_type == 0 else "sparkTaskEnd" + if sm.executor_events: + return "sparkExecutorAdded" if sm.executor_events[0].event_type == 0 else "sparkExecutorRemoved" + return "unknown" + def _convert_string_numbers_to_int(self, obj): """ Recursively convert string numbers to integers in a dictionary. @@ -1217,17 +1240,18 @@ def _convert_string_numbers_to_int(self, obj): def _proto_to_scala_json_format(self, sm: sparkmonitor_pb2.SparkMonitorProgress) -> dict: """ Convert protobuf message to JSON format matching the Scala listener's output. - - This ensures compatibility with existing VS Code SparkMonitor extension. - The format uses: - - 'msgtype' (lowercase) for the message type field - - camelCase for all other nested fields + + Handles the new ExecutionProgress-based protocol where events are delivered as + typed sub-messages with enums (JobEvent, DetailedStageEvent, TaskEvent, ExecutorEvent) + rather than the old string msg_type + separate data messages approach. 
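+ Only the first populated event category (and only its first entry) is converted and forwarded.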
+ + The output format is unchanged from before: + - 'msgtype' (lowercase) for the event type string + - camelCase for all other fields - Numeric fields as JSON numbers (not strings) """ try: # Convert proto to JSON with camelCase field names - # Try newer protobuf 5.x+ parameter first, fall back to older parameter - # This ensures fields with default values (like jobId=0, attemptId=0) are included try: # Protobuf 5.x+ uses always_print_fields_with_no_presence json_str = json_format.MessageToJson( @@ -1244,42 +1268,53 @@ def _proto_to_scala_json_format(self, sm: sparkmonitor_pb2.SparkMonitorProgress) ) except Exception as e: logger.error(f"Failed to convert proto to JSON: {e}") - # Emergency fallback - return {"msgtype": sm.msg_type or "unknown", "error": "conversion_failed"} - + return {"msgtype": "unknown", "error": "conversion_failed"} + msg = json.loads(json_str) - + # Convert string numbers to actual numbers for compatibility with VS Code extension # MessageToJson converts int64 to strings by default to avoid JS precision issues, # but the SparkMonitor extension expects numeric values msg = self._convert_string_numbers_to_int(msg) - - # Extract the actual event data (everything except msg_type) - # The proto has msg_type at top level and one of the event fields set - event_data = {} - - # Find which event field is set and extract its data - for field_name in [ - 'applicationStart', 'applicationEnd', - 'jobStart', 'jobEnd', - 'stageSubmitted', 'stageCompleted', 'stageActive', - 'taskStart', 'taskEnd', - 'executorAdded', 'executorRemoved' - ]: - if field_name in msg: - event_data = msg[field_name] - break - - # Get the msgtype from msg_type field (it's already camelCase from MessageToJson) - msgtype_value = msg.get('msgType', sm.msg_type) - + + # Use proto HasField / list length for type detection (more reliable than JSON key checks + # because always_print_fields_with_no_presence makes all keys present in JSON). + # Then pull event data from the corresponding JSON key and strip the enum 'eventType' field. 
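# [Editor's note: illustrative sketch, not part of the patch.] The detection-and-strip
# step described in the comment above is easier to see with concrete values. Assuming
# the new enum-based schema this patch targets (jobEvents entries carrying an eventType
# enum), the conversion for a job-start event looks roughly like this; field names come
# from the surrounding diff, the literal values are made up:
msg = {"jobEvents": [{"eventType": "JOB_START", "jobId": "3", "numTasks": 10}]}
raw = msg.get("jobEvents", [{}])[0]
event_data = {k: v for k, v in raw.items() if k != "eventType"}  # strip the enum field
result = {"msgtype": "sparkJobStart", **event_data}
# result == {"msgtype": "sparkJobStart", "jobId": "3", "numTasks": 10}; int64 fields such
# as jobId are still strings at this point, which is why _convert_string_numbers_to_int
# runs on the message next.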
+ if sm.HasField('application_info'): + msgtype = ( + "sparkApplicationStart" + if sm.application_info.HasField('start_time') + else "sparkApplicationEnd" + ) + event_data = msg.get('applicationInfo', {}) + elif sm.job_events: + msgtype = "sparkJobStart" if sm.job_events[0].event_type == 0 else "sparkJobEnd" + raw = msg.get('jobEvents', [{}])[0] + event_data = {k: v for k, v in raw.items() if k != 'eventType'} + elif sm.stage_events: + msgtype = ["sparkStageSubmitted", "sparkStageActive", "sparkStageCompleted"][ + sm.stage_events[0].event_type + ] + raw = msg.get('stageEvents', [{}])[0] + event_data = {k: v for k, v in raw.items() if k != 'eventType'} + elif sm.task_events: + msgtype = "sparkTaskStart" if sm.task_events[0].event_type == 0 else "sparkTaskEnd" + raw = msg.get('taskEvents', [{}])[0] + event_data = {k: v for k, v in raw.items() if k != 'eventType'} + elif sm.executor_events: + msgtype = ( + "sparkExecutorAdded" + if sm.executor_events[0].event_type == 0 + else "sparkExecutorRemoved" + ) + raw = msg.get('executorEvents', [{}])[0] + event_data = {k: v for k, v in raw.items() if k != 'eventType'} + else: + return {"msgtype": "unknown"} + # Build the final message with 'msgtype' (lowercase) and camelCase event data - result = { - 'msgtype': msgtype_value, # lowercase 'msgtype' - **event_data # Spread the event data (already in camelCase) - } - - return result + return {'msgtype': msgtype, **event_data} + def _send_to_vscode(self, msg: dict): """Send SparkMonitor data to VS Code using IPython display mechanism. diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 68bec8d3..e03279e1 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1617,6 +1617,7 @@ def test_execute_plan_request_default_behaviour( try: session = DataprocSparkSession.builder.getOrCreate() + mock_uuid4.reset_mock() # clear calls from session init (e.g. _setup_cell_execution_tracking) client = session.client result_request = client._execute_plan_request_with_metadata() @@ -1710,6 +1711,7 @@ def test_execute_plan_request_with_operation_id_provided( try: session = DataprocSparkSession.builder.getOrCreate() + mock_uuid4.reset_mock() # clear calls from session init (e.g. 
_setup_cell_execution_tracking) client = session.client result_request = client._execute_plan_request_with_metadata() @@ -2676,13 +2678,13 @@ def _encode_varint(value): return result def _build_fake_grpc_response(self, sm): - """Build a fake gRPC response with a SparkMonitorProgress message embedded at field 24.""" - from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 + """Build a fake gRPC response with SparkMonitorProgress packed in extension (Any, field 999).""" + from google.cloud.dataproc_spark_connect.session import _SPARK_MONITOR_TYPE_URL sm_bytes = sm.SerializeToString() - # Field 24, wire type 2 tag = (24 << 3) | 2 = 194 = 0xC2 0x01 as a varint - payload = b'\xc2\x01' + self._encode_varint(len(sm_bytes)) + sm_bytes mock_response = mock.MagicMock() - mock_response.SerializeToString.return_value = payload + mock_response.HasField.side_effect = lambda field: field == "extension" + mock_response.extension.type_url = _SPARK_MONITOR_TYPE_URL + mock_response.extension.value = sm_bytes return mock_response def test_convert_string_numbers_to_int_positive(self): @@ -2723,17 +2725,18 @@ def test_proto_to_scala_json_format_job_start(self): session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) sm = sparkmonitor_pb2.SparkMonitorProgress() - sm.msg_type = "sparkJobStart" - sm.job_start.job_id = 3 - sm.job_start.num_tasks = 10 - sm.job_start.num_executors = 2 + je = sm.job_events.add() + je.event_type = sparkmonitor_pb2.SparkMonitorProgress.JobEvent.JOB_START + je.job_id = 3 + je.num_tasks = 10 + je.num_executors = 2 result = DataprocSparkSession._proto_to_scala_json_format(session, sm) self.assertEqual(result["msgtype"], "sparkJobStart") self.assertEqual(result["jobId"], 3) self.assertEqual(result["numTasks"], 10) - self.assertNotIn("jobStart", result) # event data should be spread to top level + self.assertNotIn("eventType", result) # enum field should be stripped from event data def test_proto_to_scala_json_format_job_end(self): from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 @@ -2741,9 +2744,10 @@ def test_proto_to_scala_json_format_job_end(self): session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) sm = sparkmonitor_pb2.SparkMonitorProgress() - sm.msg_type = "sparkJobEnd" - sm.job_end.job_id = 3 - sm.job_end.status = "SUCCEEDED" + je = sm.job_events.add() + je.event_type = sparkmonitor_pb2.SparkMonitorProgress.JobEvent.JOB_END + je.job_id = 3 + je.status = "SUCCEEDED" result = DataprocSparkSession._proto_to_scala_json_format(session, sm) @@ -2757,17 +2761,18 @@ def test_proto_to_scala_json_format_stage_active(self): session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) sm = sparkmonitor_pb2.SparkMonitorProgress() - sm.msg_type = "sparkStageActive" - sm.stage_active.stage_id = 7 - sm.stage_active.num_tasks = 20 - sm.stage_active.num_completed_tasks = 15 + se = sm.stage_events.add() + se.event_type = sparkmonitor_pb2.SparkMonitorProgress.DetailedStageEvent.STAGE_ACTIVE + se.stage_id = 7 + se.num_tasks = 20 + se.num_completed_tasks = 20 # optional field result = DataprocSparkSession._proto_to_scala_json_format(session, sm) self.assertEqual(result["msgtype"], "sparkStageActive") self.assertEqual(result["stageId"], 7) self.assertEqual(result["numTasks"], 20) - self.assertEqual(result["numCompletedTasks"], 15) + self.assertNotIn("eventType", result) def 
test_send_to_vscode_skips_when_ipython_unavailable(self): session = self._make_session_instance(_ipython_available=False) @@ -2801,8 +2806,9 @@ def test_send_to_vscode_calls_display_when_ipython_available(self): def test_extract_and_send_skips_response_without_sparkmonitor_data(self): session = self._make_session_instance() + # Response that has no extension field at all mock_response = mock.MagicMock() - mock_response.SerializeToString.return_value = b'\x0a\x05hello' # No \xc2\x01 field tag + mock_response.HasField.side_effect = lambda field: False msg_type_counts = {} responses_with_sparkmonitor = [0] @@ -2819,9 +2825,12 @@ def test_extract_and_send_skips_stream_complete_signal(self): session = self._make_session_instance() sm = sparkmonitor_pb2.SparkMonitorProgress() - sm.msg_type = "sparkMonitorStreamComplete" + sm.stream_complete = True mock_response = self._build_fake_grpc_response(sm) + # Wire up _derive_sparkmonitor_msgtype + session._derive_sparkmonitor_msgtype = lambda s: DataprocSparkSession._derive_sparkmonitor_msgtype(session, s) + msg_type_counts = {} responses_with_sparkmonitor = [0] @@ -2839,15 +2848,17 @@ def test_extract_and_send_processes_valid_job_start_payload(self): session = self._make_session_instance() sm = sparkmonitor_pb2.SparkMonitorProgress() - sm.msg_type = "sparkJobStart" - sm.job_start.job_id = 1 - sm.job_start.num_tasks = 8 + je = sm.job_events.add() + je.event_type = sparkmonitor_pb2.SparkMonitorProgress.JobEvent.JOB_START + je.job_id = 1 + je.num_tasks = 8 mock_response = self._build_fake_grpc_response(sm) # Wire up real implementations so the full extraction pipeline runs session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) session._proto_to_scala_json_format = lambda s: DataprocSparkSession._proto_to_scala_json_format(session, s) + session._derive_sparkmonitor_msgtype = lambda s: DataprocSparkSession._derive_sparkmonitor_msgtype(session, s) msg_type_counts = {} responses_with_sparkmonitor = [0] From 2c3f05f08e703d5a814f73d56403b5776bfda0f7 Mon Sep 17 00:00:00 2001 From: Siddhant Rao Date: Thu, 19 Mar 2026 14:28:16 -0700 Subject: [PATCH 3/5] Clean up unnecessary comments --- google/cloud/dataproc_spark_connect/session.py | 15 +++------------ tests/unit/test_session.py | 4 +--- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/google/cloud/dataproc_spark_connect/session.py b/google/cloud/dataproc_spark_connect/session.py index 37e1a614..6ae36d18 100644 --- a/google/cloud/dataproc_spark_connect/session.py +++ b/google/cloud/dataproc_spark_connect/session.py @@ -971,10 +971,9 @@ def __init__( # Setup cell tracking FIRST (sets up the run_id mechanism) self._setup_cell_execution_tracking() - # Then setup SparkMonitor interception + # Setup SparkMonitor interception self._setup_sparkmonitor_interception() - # Setup your existing wrappers execute_plan_request_base_method = ( self.client._execute_plan_request_with_metadata ) @@ -1116,13 +1115,11 @@ def background_consumer(): else: response_queue.put(raw_response) - # Mark stream as exhausted stream_exhausted.set() except Exception as e: background_error[0] = e stream_exhausted.set() finally: - # Signal end of stream response_queue.put(None) # Start background consumer thread @@ -1186,7 +1183,6 @@ def _extract_and_send_sparkmonitor(self, raw_response, response_num: int, msg_ty responses_with_sparkmonitor[0] += 1 - # Derive msgtype for tracking (mirrors old string-based tracking) msg_type = self._derive_sparkmonitor_msgtype(sm) 
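# [Editor's note: minimal sketch, not part of the patch.] The background-consumer code
# touched above only loses comments in this hunk; for readers skimming the diff, this is
# the queue hand-off pattern it relies on, reduced to a self-contained example with
# illustrative names (the real implementation wires this into the gRPC response stream
# and records errors from the worker thread):
import queue
import threading

def iterate_via_background_consumer(source):
    q = queue.Queue()

    def consumer():
        try:
            for item in source:
                q.put(item)
        finally:
            q.put(None)  # sentinel: stream exhausted

    threading.Thread(target=consumer, daemon=True).start()
    while True:
        item = q.get()
        if item is None:
            return
        yield item

# list(iterate_via_background_consumer(range(3))) == [0, 1, 2]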
msg_type_counts[msg_type] = msg_type_counts.get(msg_type, 0) + 1 @@ -1277,8 +1273,7 @@ def _proto_to_scala_json_format(self, sm: sparkmonitor_pb2.SparkMonitorProgress) # but the SparkMonitor extension expects numeric values msg = self._convert_string_numbers_to_int(msg) - # Use proto HasField / list length for type detection (more reliable than JSON key checks - # because always_print_fields_with_no_presence makes all keys present in JSON). + # Use proto HasField / list length for type detection. # Then pull event data from the corresponding JSON key and strip the enum 'eventType' field. if sm.HasField('application_info'): msgtype = ( @@ -1312,7 +1307,6 @@ def _proto_to_scala_json_format(self, sm: sparkmonitor_pb2.SparkMonitorProgress) else: return {"msgtype": "unknown"} - # Build the final message with 'msgtype' (lowercase) and camelCase event data return {'msgtype': msgtype, **event_data} @@ -1331,12 +1325,9 @@ def _send_to_vscode(self, msg: dict): display_id = self._current_cell_run_id or str(uuid.uuid4()) - # Match the remote kernel format exactly: - # 1. Convert dict to JSON string (like Scala's pretty(render(json))) - # 2. Wrap in fromscala envelope (like kernel extension does) wrapper = { 'msgtype': 'fromscala', - 'msg': json.dumps(msg) # Convert to JSON string + 'msg': json.dumps(msg) } display_data = { diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index e03279e1..08a012d0 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2721,7 +2721,6 @@ def test_convert_string_numbers_to_int_passthrough_non_string(self): def test_proto_to_scala_json_format_job_start(self): from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 session = self._make_session_instance() - # Wire up _convert_string_numbers_to_int so _proto_to_scala_json_format gets real values session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) sm = sparkmonitor_pb2.SparkMonitorProgress() @@ -2736,7 +2735,7 @@ def test_proto_to_scala_json_format_job_start(self): self.assertEqual(result["msgtype"], "sparkJobStart") self.assertEqual(result["jobId"], 3) self.assertEqual(result["numTasks"], 10) - self.assertNotIn("eventType", result) # enum field should be stripped from event data + self.assertNotIn("eventType", result) def test_proto_to_scala_json_format_job_end(self): from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 @@ -2791,7 +2790,6 @@ def test_send_to_vscode_calls_display_when_ipython_available(self): msg = {"msgtype": "sparkJobEnd", "jobId": 1} with mock.patch("IPython.display.display") as mock_display: - # Patch the import inside the method with mock.patch.dict("sys.modules", {"IPython.display": mock.MagicMock(display=mock_display)}): DataprocSparkSession._send_to_vscode(session, msg) From 367c83fe05ccf92dd4972c3f4f8b797ff569996b Mon Sep 17 00:00:00 2001 From: Siddhant Rao Date: Thu, 19 Mar 2026 14:32:44 -0700 Subject: [PATCH 4/5] Clean up unnecessary comments - setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 77a497a3..6eb91abd 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,6 @@ "pyspark[connect]~=4.0.0", "tqdm>=4.67", "websockets>=14.0", - "protobuf>=3.20.0", # Added for proto support + "protobuf>=3.20.0", ], ) \ No newline at end of file From 10ef0d3f0db02c6d9bf19c4fc370774a2b4eb184 Mon Sep 17 00:00:00 2001 From: Siddhant Rao Date: Thu, 19 Mar 2026 14:37:40 -0700 Subject: [PATCH 5/5] Formatted changes using pyink 
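[Editor's note, not part of the commit message above: the _send_to_vscode hunk in PATCH 3 trims the comments that described the display payload. For reference, this is the shape of that payload, sketched under the assumption of a live IPython kernel; the mime type, the 'fromscala' envelope, and the json.dumps wrapping come from the diff, while the helper name and defaults here are illustrative.]

    import json
    import uuid
    from typing import Optional

    from IPython.display import display

    def send_to_vscode(msg: dict, run_id: Optional[str] = None) -> None:
        # The VS Code extension expects the inner message as a JSON string inside a
        # 'fromscala' envelope, published under a custom mime type.
        wrapper = {"msgtype": "fromscala", "msg": json.dumps(msg)}
        display(
            {"application/vnd.sparkmonitor+json": wrapper},
            raw=True,
            display_id=run_id or str(uuid.uuid4()),
        )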
--- .../proto/sparkmonitor_pb2.py | 78 ++++--- .../cloud/dataproc_spark_connect/session.py | 200 ++++++++++++------ setup.py | 2 +- tests/unit/test_session.py | 128 ++++++++--- 4 files changed, 280 insertions(+), 128 deletions(-) diff --git a/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py b/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py index e8009fb0..c8840881 100644 --- a/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py +++ b/google/cloud/dataproc_spark_connect/proto/sparkmonitor_pb2.py @@ -11,40 +11,54 @@ _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n 0 or - len(sm.stage_events) > 0 or - len(sm.task_events) > 0 or - len(sm.executor_events) > 0 or - sm.HasField('stream_complete') + sm.HasField("application_info") + or len(sm.job_events) > 0 + or len(sm.stage_events) > 0 + or len(sm.task_events) > 0 + or len(sm.executor_events) > 0 + or sm.HasField("stream_complete") ) if not is_sparkmonitor: return @@ -1187,7 +1214,7 @@ def _extract_and_send_sparkmonitor(self, raw_response, response_num: int, msg_ty msg_type_counts[msg_type] = msg_type_counts.get(msg_type, 0) + 1 # Skip stream completion signal (don't forward to VS Code) - if sm.HasField('stream_complete') and sm.stream_complete: + if sm.HasField("stream_complete") and sm.stream_complete: return # Convert to Scala-compatible JSON and send to VS Code @@ -1197,43 +1224,70 @@ def _extract_and_send_sparkmonitor(self, raw_response, response_num: int, msg_ty except Exception as e: logger.debug(f"Error extracting SparkMonitor: {e}") - def _derive_sparkmonitor_msgtype(self, sm: sparkmonitor_pb2.SparkMonitorProgress) -> str: + def _derive_sparkmonitor_msgtype( + self, sm: sparkmonitor_pb2.SparkMonitorProgress + ) -> str: """Derive a msgtype string from the new enum-based SparkMonitor proto structure.""" - if sm.HasField('stream_complete'): + if sm.HasField("stream_complete"): return "sparkMonitorStreamComplete" - if sm.HasField('application_info'): - return "sparkApplicationStart" if sm.application_info.HasField('start_time') else "sparkApplicationEnd" + if sm.HasField("application_info"): + return ( + "sparkApplicationStart" + if sm.application_info.HasField("start_time") + else "sparkApplicationEnd" + ) if sm.job_events: - return "sparkJobStart" if sm.job_events[0].event_type == 0 else "sparkJobEnd" + return ( + "sparkJobStart" + if sm.job_events[0].event_type == 0 + else "sparkJobEnd" + ) if sm.stage_events: - return ["sparkStageSubmitted", "sparkStageActive", "sparkStageCompleted"][sm.stage_events[0].event_type] + return [ + "sparkStageSubmitted", + "sparkStageActive", + "sparkStageCompleted", + ][sm.stage_events[0].event_type] if sm.task_events: - return "sparkTaskStart" if sm.task_events[0].event_type == 0 else "sparkTaskEnd" + return ( + "sparkTaskStart" + if sm.task_events[0].event_type == 0 + else "sparkTaskEnd" + ) if sm.executor_events: - return "sparkExecutorAdded" if sm.executor_events[0].event_type == 0 else "sparkExecutorRemoved" + return ( + "sparkExecutorAdded" + if sm.executor_events[0].event_type == 0 + else "sparkExecutorRemoved" + ) return "unknown" def _convert_string_numbers_to_int(self, obj): """ Recursively convert string numbers to integers in a dictionary. - + MessageToJson converts int64 fields to strings by default to avoid JavaScript precision issues, but the VS Code SparkMonitor extension expects numeric values. 
""" if isinstance(obj, dict): - return {k: self._convert_string_numbers_to_int(v) for k, v in obj.items()} + return { + k: self._convert_string_numbers_to_int(v) + for k, v in obj.items() + } elif isinstance(obj, list): return [self._convert_string_numbers_to_int(item) for item in obj] elif isinstance(obj, str): # Try to convert string to int if it looks like a number # Negative numbers (like -1 for completionTime) should also be converted - if obj.lstrip('-').isdigit(): + if obj.lstrip("-").isdigit(): return int(obj) return obj else: return obj - def _proto_to_scala_json_format(self, sm: sparkmonitor_pb2.SparkMonitorProgress) -> dict: + def _proto_to_scala_json_format( + self, sm: sparkmonitor_pb2.SparkMonitorProgress + ) -> dict: """ Convert protobuf message to JSON format matching the Scala listener's output. @@ -1253,14 +1307,14 @@ def _proto_to_scala_json_format(self, sm: sparkmonitor_pb2.SparkMonitorProgress) json_str = json_format.MessageToJson( sm, preserving_proto_field_name=False, - always_print_fields_with_no_presence=True + always_print_fields_with_no_presence=True, ) except TypeError: # Protobuf <5.x uses including_default_value_fields json_str = json_format.MessageToJson( sm, preserving_proto_field_name=False, - including_default_value_fields=True + including_default_value_fields=True, ) except Exception as e: logger.error(f"Failed to convert proto to JSON: {e}") @@ -1275,40 +1329,49 @@ def _proto_to_scala_json_format(self, sm: sparkmonitor_pb2.SparkMonitorProgress) # Use proto HasField / list length for type detection. # Then pull event data from the corresponding JSON key and strip the enum 'eventType' field. - if sm.HasField('application_info'): + if sm.HasField("application_info"): msgtype = ( "sparkApplicationStart" - if sm.application_info.HasField('start_time') + if sm.application_info.HasField("start_time") else "sparkApplicationEnd" ) - event_data = msg.get('applicationInfo', {}) + event_data = msg.get("applicationInfo", {}) elif sm.job_events: - msgtype = "sparkJobStart" if sm.job_events[0].event_type == 0 else "sparkJobEnd" - raw = msg.get('jobEvents', [{}])[0] - event_data = {k: v for k, v in raw.items() if k != 'eventType'} + msgtype = ( + "sparkJobStart" + if sm.job_events[0].event_type == 0 + else "sparkJobEnd" + ) + raw = msg.get("jobEvents", [{}])[0] + event_data = {k: v for k, v in raw.items() if k != "eventType"} elif sm.stage_events: - msgtype = ["sparkStageSubmitted", "sparkStageActive", "sparkStageCompleted"][ - sm.stage_events[0].event_type - ] - raw = msg.get('stageEvents', [{}])[0] - event_data = {k: v for k, v in raw.items() if k != 'eventType'} + msgtype = [ + "sparkStageSubmitted", + "sparkStageActive", + "sparkStageCompleted", + ][sm.stage_events[0].event_type] + raw = msg.get("stageEvents", [{}])[0] + event_data = {k: v for k, v in raw.items() if k != "eventType"} elif sm.task_events: - msgtype = "sparkTaskStart" if sm.task_events[0].event_type == 0 else "sparkTaskEnd" - raw = msg.get('taskEvents', [{}])[0] - event_data = {k: v for k, v in raw.items() if k != 'eventType'} + msgtype = ( + "sparkTaskStart" + if sm.task_events[0].event_type == 0 + else "sparkTaskEnd" + ) + raw = msg.get("taskEvents", [{}])[0] + event_data = {k: v for k, v in raw.items() if k != "eventType"} elif sm.executor_events: msgtype = ( "sparkExecutorAdded" if sm.executor_events[0].event_type == 0 else "sparkExecutorRemoved" ) - raw = msg.get('executorEvents', [{}])[0] - event_data = {k: v for k, v in raw.items() if k != 'eventType'} + raw = msg.get("executorEvents", [{}])[0] 
+ event_data = {k: v for k, v in raw.items() if k != "eventType"} else: return {"msgtype": "unknown"} - return {'msgtype': msgtype, **event_data} - + return {"msgtype": msgtype, **event_data} def _send_to_vscode(self, msg: dict): """Send SparkMonitor data to VS Code using IPython display mechanism. @@ -1325,13 +1388,10 @@ def _send_to_vscode(self, msg: dict): display_id = self._current_cell_run_id or str(uuid.uuid4()) - wrapper = { - 'msgtype': 'fromscala', - 'msg': json.dumps(msg) - } + wrapper = {"msgtype": "fromscala", "msg": json.dumps(msg)} display_data = { - 'application/vnd.sparkmonitor+json': wrapper, + "application/vnd.sparkmonitor+json": wrapper, } display(display_data, raw=True, display_id=display_id) diff --git a/setup.py b/setup.py index 6eb91abd..2163f51e 100644 --- a/setup.py +++ b/setup.py @@ -44,4 +44,4 @@ "websockets>=14.0", "protobuf>=3.20.0", ], -) \ No newline at end of file +) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 08a012d0..3e0ddb39 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2670,7 +2670,7 @@ def _make_session_instance(**attrs): @staticmethod def _encode_varint(value): """Encode an integer as a protobuf base-128 varint.""" - result = b'' + result = b"" while value > 127: result += bytes([(value & 0x7F) | 0x80]) value >>= 7 @@ -2680,6 +2680,7 @@ def _encode_varint(value): def _build_fake_grpc_response(self, sm): """Build a fake gRPC response with SparkMonitorProgress packed in extension (Any, field 999).""" from google.cloud.dataproc_spark_connect.session import _SPARK_MONITOR_TYPE_URL + sm_bytes = sm.SerializeToString() mock_response = mock.MagicMock() mock_response.HasField.side_effect = lambda field: field == "extension" @@ -2689,39 +2690,62 @@ def _build_fake_grpc_response(self, sm): def test_convert_string_numbers_to_int_positive(self): session = self._make_session_instance() - result = DataprocSparkSession._convert_string_numbers_to_int(session, "42") + result = DataprocSparkSession._convert_string_numbers_to_int( + session, "42" + ) self.assertEqual(result, 42) self.assertIsInstance(result, int) def test_convert_string_numbers_to_int_negative(self): """Negative string numbers such as completionTime=-1 should be converted.""" session = self._make_session_instance() - result = DataprocSparkSession._convert_string_numbers_to_int(session, "-1") + result = DataprocSparkSession._convert_string_numbers_to_int( + session, "-1" + ) self.assertEqual(result, -1) self.assertIsInstance(result, int) def test_convert_string_numbers_to_int_preserves_non_numeric(self): session = self._make_session_instance() - result = DataprocSparkSession._convert_string_numbers_to_int(session, "sparkJobStart") + result = DataprocSparkSession._convert_string_numbers_to_int( + session, "sparkJobStart" + ) self.assertEqual(result, "sparkJobStart") def test_convert_string_numbers_to_int_nested_dict_and_list(self): session = self._make_session_instance() # Wire up the recursive self-call so nested values are also converted - session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) + session._convert_string_numbers_to_int = ( + lambda x: DataprocSparkSession._convert_string_numbers_to_int( + session, x + ) + ) obj = {"jobId": "5", "status": "SUCCEEDED", "stageIds": ["1", "2"]} - result = DataprocSparkSession._convert_string_numbers_to_int(session, obj) - self.assertEqual(result, {"jobId": 5, "status": "SUCCEEDED", "stageIds": [1, 2]}) + result = 
DataprocSparkSession._convert_string_numbers_to_int( + session, obj + ) + self.assertEqual( + result, {"jobId": 5, "status": "SUCCEEDED", "stageIds": [1, 2]} + ) def test_convert_string_numbers_to_int_passthrough_non_string(self): session = self._make_session_instance() - self.assertEqual(DataprocSparkSession._convert_string_numbers_to_int(session, 99), 99) - self.assertIsNone(DataprocSparkSession._convert_string_numbers_to_int(session, None)) + self.assertEqual( + DataprocSparkSession._convert_string_numbers_to_int(session, 99), 99 + ) + self.assertIsNone( + DataprocSparkSession._convert_string_numbers_to_int(session, None) + ) def test_proto_to_scala_json_format_job_start(self): from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 + session = self._make_session_instance() - session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) + session._convert_string_numbers_to_int = ( + lambda x: DataprocSparkSession._convert_string_numbers_to_int( + session, x + ) + ) sm = sparkmonitor_pb2.SparkMonitorProgress() je = sm.job_events.add() @@ -2739,8 +2763,13 @@ def test_proto_to_scala_json_format_job_start(self): def test_proto_to_scala_json_format_job_end(self): from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 + session = self._make_session_instance() - session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) + session._convert_string_numbers_to_int = ( + lambda x: DataprocSparkSession._convert_string_numbers_to_int( + session, x + ) + ) sm = sparkmonitor_pb2.SparkMonitorProgress() je = sm.job_events.add() @@ -2756,12 +2785,19 @@ def test_proto_to_scala_json_format_job_end(self): def test_proto_to_scala_json_format_stage_active(self): from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 + session = self._make_session_instance() - session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) + session._convert_string_numbers_to_int = ( + lambda x: DataprocSparkSession._convert_string_numbers_to_int( + session, x + ) + ) sm = sparkmonitor_pb2.SparkMonitorProgress() se = sm.stage_events.add() - se.event_type = sparkmonitor_pb2.SparkMonitorProgress.DetailedStageEvent.STAGE_ACTIVE + se.event_type = ( + sparkmonitor_pb2.SparkMonitorProgress.DetailedStageEvent.STAGE_ACTIVE + ) se.stage_id = 7 se.num_tasks = 20 se.num_completed_tasks = 20 # optional field @@ -2777,11 +2813,14 @@ def test_send_to_vscode_skips_when_ipython_unavailable(self): session = self._make_session_instance(_ipython_available=False) with mock.patch("IPython.display.display") as mock_display: - DataprocSparkSession._send_to_vscode(session, {"msgtype": "sparkJobStart"}) + DataprocSparkSession._send_to_vscode( + session, {"msgtype": "sparkJobStart"} + ) mock_display.assert_not_called() def test_send_to_vscode_calls_display_when_ipython_available(self): import json + run_id = "test-run-id-1234" session = self._make_session_instance( _ipython_available=True, @@ -2790,7 +2829,10 @@ def test_send_to_vscode_calls_display_when_ipython_available(self): msg = {"msgtype": "sparkJobEnd", "jobId": 1} with mock.patch("IPython.display.display") as mock_display: - with mock.patch.dict("sys.modules", {"IPython.display": mock.MagicMock(display=mock_display)}): + with mock.patch.dict( + "sys.modules", + {"IPython.display": mock.MagicMock(display=mock_display)}, + ): DataprocSparkSession._send_to_vscode(session, msg) 
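# [Editor's note: reference sketch, not part of the patch.] The _encode_varint helper kept
# earlier in this test class (originally used to hand-craft the field-24 payload) encodes
# base-128 varints: emit 7 bits at a time, low bits first, setting the high bit on every
# byte except the last. A standalone version with a couple of checked examples:
def encode_varint(value: int) -> bytes:
    out = b""
    while value > 127:
        out += bytes([(value & 0x7F) | 0x80])  # low 7 bits plus continuation bit
        value >>= 7
    return out + bytes([value])

assert encode_varint(1) == b"\x01"
assert encode_varint(300) == b"\xac\x02"  # 300 -> 0b10_0101100 -> 0xAC, then 0x02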
mock_display.assert_called_once() @@ -2812,7 +2854,11 @@ def test_extract_and_send_skips_response_without_sparkmonitor_data(self): responses_with_sparkmonitor = [0] DataprocSparkSession._extract_and_send_sparkmonitor( - session, mock_response, 1, msg_type_counts, responses_with_sparkmonitor + session, + mock_response, + 1, + msg_type_counts, + responses_with_sparkmonitor, ) self.assertEqual(responses_with_sparkmonitor[0], 0) @@ -2820,6 +2866,7 @@ def test_extract_and_send_skips_response_without_sparkmonitor_data(self): def test_extract_and_send_skips_stream_complete_signal(self): from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 + session = self._make_session_instance() sm = sparkmonitor_pb2.SparkMonitorProgress() @@ -2827,13 +2874,21 @@ def test_extract_and_send_skips_stream_complete_signal(self): mock_response = self._build_fake_grpc_response(sm) # Wire up _derive_sparkmonitor_msgtype - session._derive_sparkmonitor_msgtype = lambda s: DataprocSparkSession._derive_sparkmonitor_msgtype(session, s) + session._derive_sparkmonitor_msgtype = ( + lambda s: DataprocSparkSession._derive_sparkmonitor_msgtype( + session, s + ) + ) msg_type_counts = {} responses_with_sparkmonitor = [0] DataprocSparkSession._extract_and_send_sparkmonitor( - session, mock_response, 1, msg_type_counts, responses_with_sparkmonitor + session, + mock_response, + 1, + msg_type_counts, + responses_with_sparkmonitor, ) # Counter incremented but _send_to_vscode NOT called @@ -2843,6 +2898,7 @@ def test_extract_and_send_skips_stream_complete_signal(self): def test_extract_and_send_processes_valid_job_start_payload(self): from google.cloud.dataproc_spark_connect.proto import sparkmonitor_pb2 + session = self._make_session_instance() sm = sparkmonitor_pb2.SparkMonitorProgress() @@ -2854,15 +2910,31 @@ def test_extract_and_send_processes_valid_job_start_payload(self): mock_response = self._build_fake_grpc_response(sm) # Wire up real implementations so the full extraction pipeline runs - session._convert_string_numbers_to_int = lambda x: DataprocSparkSession._convert_string_numbers_to_int(session, x) - session._proto_to_scala_json_format = lambda s: DataprocSparkSession._proto_to_scala_json_format(session, s) - session._derive_sparkmonitor_msgtype = lambda s: DataprocSparkSession._derive_sparkmonitor_msgtype(session, s) + session._convert_string_numbers_to_int = ( + lambda x: DataprocSparkSession._convert_string_numbers_to_int( + session, x + ) + ) + session._proto_to_scala_json_format = ( + lambda s: DataprocSparkSession._proto_to_scala_json_format( + session, s + ) + ) + session._derive_sparkmonitor_msgtype = ( + lambda s: DataprocSparkSession._derive_sparkmonitor_msgtype( + session, s + ) + ) msg_type_counts = {} responses_with_sparkmonitor = [0] DataprocSparkSession._extract_and_send_sparkmonitor( - session, mock_response, 1, msg_type_counts, responses_with_sparkmonitor + session, + mock_response, + 1, + msg_type_counts, + responses_with_sparkmonitor, ) self.assertEqual(responses_with_sparkmonitor[0], 1) @@ -2873,7 +2945,9 @@ def test_extract_and_send_processes_valid_job_start_payload(self): def test_setup_cell_tracking_sets_flag_when_ipython_present(self): """When IPython is available and has a live shell, _ipython_available should be True.""" - session = self._make_session_instance(_ipython_available=False, _current_cell_run_id=None) + session = self._make_session_instance( + _ipython_available=False, _current_cell_run_id=None + ) mock_ip = mock.MagicMock() with mock.patch("IPython.get_ipython", 
return_value=mock_ip):
@@ -2888,7 +2962,9 @@ def test_setup_cell_tracking_sets_flag_when_ipython_present(self):
 
     def test_setup_cell_tracking_leaves_flag_false_when_no_ipython_shell(self):
         """When get_ipython() returns None, _ipython_available should remain False."""
-        session = self._make_session_instance(_ipython_available=False, _current_cell_run_id=None)
+        session = self._make_session_instance(
+            _ipython_available=False, _current_cell_run_id=None
+        )
 
         with mock.patch("IPython.get_ipython", return_value=None):
             DataprocSparkSession._setup_cell_execution_tracking(session)
@@ -2898,7 +2974,9 @@ def test_setup_cell_tracking_leaves_flag_false_when_no_ipython_shell(self):
 
     def test_setup_cell_tracking_is_resilient_to_import_error(self):
         """If IPython is not installed, the method should not raise."""
-        session = self._make_session_instance(_ipython_available=False, _current_cell_run_id=None)
+        session = self._make_session_instance(
+            _ipython_available=False, _current_cell_run_id=None
+        )
 
         with mock.patch.dict("sys.modules", {"IPython": None}):
             # Should not raise
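[Editor's closing note: the _setup_cell_execution_tracking tests above all exercise the same graceful-degradation check. A minimal sketch of that pattern, assuming (per the diff) that a missing or non-interactive IPython simply disables display routing rather than raising:]

    def ipython_shell_available() -> bool:
        """True only when running inside a live IPython/Jupyter shell."""
        try:
            from IPython import get_ipython
        except ImportError:
            return False  # IPython not installed: degrade silently
        return get_ipython() is not None  # get_ipython() returns None outside a shell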