Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions amber/src/main/python/texera_run_python_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.

import base64
import json
import sys
from loguru import logger
Expand Down Expand Up @@ -79,15 +80,18 @@ def init_loguru_logger(stream_log_level) -> None:


def parse_startup_config(raw_config: str) -> dict:
"""Parse and validate the JSON startup configuration.
"""Parse and validate the startup configuration.

The configuration is passed by name (see PythonWorkflowWorker on the JVM
side), so the two sides must agree on an exact key set. Key order is
irrelevant since it is a JSON object. Any drift fails loudly:
side) as a Base64-encoded JSON object. Base64 is used so the argument carries
no quotes or spaces and survives command-line argv quoting on every platform
(a raw JSON string loses its quotes on Windows). The two sides must agree on
an exact key set; key order is irrelevant since it is a JSON object. Any drift
fails loudly:
- a missing or unexpected key raises ValueError;
- a non-string value raises TypeError.
"""
config = json.loads(raw_config)
config = json.loads(base64.b64decode(raw_config).decode("utf-8"))
if not isinstance(config, dict):
raise TypeError(
f"startup config must be a JSON object, got {type(config).__name__}"
Expand All @@ -112,7 +116,7 @@ def parse_startup_config(raw_config: str) -> dict:


def main(raw_config: str) -> None:
"""Start a Python worker from its validated JSON startup configuration."""
"""Start a Python worker from its validated Base64-encoded JSON startup config."""
config = parse_startup_config(raw_config)

init_loguru_logger(config["loggerLevel"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,36 @@ import org.apache.texera.amber.engine.common.ambermessage._
import org.apache.texera.amber.engine.common.{CheckpointState, Utils}
import org.apache.texera.amber.util.JSONUtils.objectMapper

import java.nio.charset.StandardCharsets
import java.nio.file.Path
import org.apache.texera.web.resource.pythonvirtualenvironment.PveManager
import java.util.Base64
import java.util.concurrent.{ExecutorService, Executors}
import scala.sys.process.{BasicIO, Process}

object PythonWorkflowWorker {
/** Create the `Props` that instantiate a `PythonWorkflowWorker` for the given worker configuration. */
def props(workerConfig: WorkerConfig): Props = {
  Props(new PythonWorkflowWorker(workerConfig))
}

/**
 * Serialize the Python worker startup configuration to a JSON object keyed by
 * name, then Base64-encode it for safe passing as a command-line argument. Built
 * from a sequence of (key, value) pairs so a duplicate key fails loudly here
 * instead of being silently dropped by Map construction.
 *
 * The Base64 step matters on Windows: a raw JSON string passed as argv loses its
 * quotes there (the JVM assembles argv into a single command line and the inner
 * double quotes are stripped before Python receives it), so `json.loads` fails.
 * Base64 uses only `[A-Za-z0-9+/=]`, which survives argv quoting on every
 * platform. The Python side Base64-decodes before parsing the JSON.
 *
 * @param entries the (key, value) pairs to emit; keys must be unique
 * @return the Base64 encoding of the UTF-8 bytes of the JSON object
 * @throws IllegalArgumentException if any key appears more than once
 */
def encodeStartupConfig(entries: Seq[(String, String)]): String = {
  val duplicateKeys = entries.groupBy(_._1).collect { case (key, group) if group.size > 1 => key }
  require(
    duplicateKeys.isEmpty,
    s"duplicate Python worker startup config keys: ${duplicateKeys.mkString(", ")}"
  )
  // Serialize exactly once; the JSON text is only an intermediate and is never
  // returned un-encoded.
  val json = objectMapper.writeValueAsString(entries.toMap)
  Base64.getEncoder.encodeToString(json.getBytes(StandardCharsets.UTF_8))
}

/**
Expand Down
30 changes: 23 additions & 7 deletions amber/src/test/python/test_run_python_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.

import base64
import json
from unittest import mock

Expand All @@ -23,6 +24,15 @@
import texera_run_python_worker as entry


def _encode(config) -> str:
"""Encode a config the way PythonWorkflowWorker does: Base64-encoded JSON.

The JVM side passes the startup config as Base64 so it survives command-line
argv quoting on every platform (a raw JSON string loses its quotes on Windows).
"""
return base64.b64encode(json.dumps(config).encode("utf-8")).decode("ascii")


def _full_config() -> dict:
"""A complete startup config matching the keys PythonWorkflowWorker emits."""
return {
Expand Down Expand Up @@ -69,7 +79,7 @@ def test_main_maps_named_config_to_storage_and_worker():
config = _full_config()
storage_patch, worker_patch, _logger_patch = _patched_collaborators()
with storage_patch as storage_config, worker_patch as python_worker, _logger_patch:
entry.main(json.dumps(config))
entry.main(_encode(config))

storage_config.initialize.assert_called_once_with(
"postgres",
Expand Down Expand Up @@ -99,7 +109,7 @@ def test_main_mapping_is_independent_of_key_order():
reordered = dict(reversed(list(_full_config().items())))
storage_patch, worker_patch, _logger_patch = _patched_collaborators()
with storage_patch as storage_config, worker_patch as python_worker, _logger_patch:
entry.main(json.dumps(reordered))
entry.main(_encode(reordered))

storage_config.initialize.assert_called_once_with(
"postgres",
Expand Down Expand Up @@ -131,7 +141,7 @@ def test_main_sets_r_home_when_r_path_present(monkeypatch):
with storage_patch, worker_patch, _logger_patch:
import os

entry.main(json.dumps(config))
entry.main(_encode(config))
assert os.environ["R_HOME"] == "/opt/R"


Expand All @@ -141,25 +151,31 @@ def test_parse_rejects_a_missing_key(missing_key):
config = _full_config()
del config[missing_key]
with pytest.raises(ValueError, match="key mismatch"):
entry.parse_startup_config(json.dumps(config))
entry.parse_startup_config(_encode(config))


def test_parse_rejects_an_unexpected_key():
    """An extra key (e.g. the JVM side added a field) fails instead of being ignored."""
    config = _full_config()
    config["someNewField"] = "value"
    with pytest.raises(ValueError, match="key mismatch"):
        # The parser Base64-decodes its argument, so the payload must be encoded
        # the same way the JVM side encodes it; raw JSON would fail before the
        # key check is ever reached.
        entry.parse_startup_config(_encode(config))


def test_parse_rejects_a_non_string_value():
    """A wrongly-typed value (e.g. a number instead of a string) fails."""
    config = _full_config()
    config["outputPort"] = 5005  # number instead of the expected string
    with pytest.raises(TypeError, match="must be strings"):
        # Pass Base64-encoded JSON: the parser decodes before validating, so a
        # raw JSON string would never reach the type check.
        entry.parse_startup_config(_encode(config))


def test_parse_rejects_a_non_object_payload():
    """A payload that decodes to something other than a JSON object fails."""
    with pytest.raises(TypeError, match="must be a JSON object"):
        # Encode the list as Base64 JSON so it decodes cleanly and the rejection
        # comes from the object check rather than from the decoding step.
        entry.parse_startup_config(_encode(["not", "an", "object"]))


def test_parse_round_trips_a_base64_encoded_config():
    """Parsing a Base64-encoded JSON config yields the original mapping."""
    expected = _full_config()
    parsed = entry.parse_startup_config(_encode(expected))
    assert parsed == expected
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,32 @@ package org.apache.texera.amber.engine.architecture.pythonworker
import org.apache.texera.amber.util.JSONUtils.objectMapper
import org.scalatest.flatspec.AnyFlatSpec

import java.nio.charset.StandardCharsets
import java.util.Base64

class PythonWorkflowWorkerStartupConfigSpec extends AnyFlatSpec {

/** Decode the Base64 text returned by `encodeStartupConfig` back into its UTF-8 JSON string. */
private def decode(encoded: String): String =
  new String(Base64.getDecoder.decode(encoded), StandardCharsets.UTF_8)

"encodeStartupConfig" should "serialize entries to a Base64-encoded JSON object keyed by name" in {
  val encoded = PythonWorkflowWorker.encodeStartupConfig(
    Seq("workerId" -> "w-1", "outputPort" -> "5005", "s3Region" -> "us-west-2")
  )
  // The result is Base64 text, so it must be decoded before it can be parsed as JSON.
  val parsed = objectMapper.readValue(decode(encoded), classOf[java.util.Map[String, String]])
  assert(parsed.get("workerId") == "w-1")
  assert(parsed.get("outputPort") == "5005")
  assert(parsed.get("s3Region") == "us-west-2")
  assert(parsed.size() == 3)
}

it should "produce output free of quotes and whitespace so it survives argv quoting on Windows" in {
  // Every character must be neither a double quote nor whitespace.
  val sampleEntries = Seq("workerId" -> "w-1", "s3Region" -> "us-west-2")
  val payload = PythonWorkflowWorker.encodeStartupConfig(sampleEntries)
  assert(payload.forall(ch => ch != '"' && !ch.isWhitespace))
}

it should "fail loudly when the same key appears more than once" in {
val exception = intercept[IllegalArgumentException] {
PythonWorkflowWorker.encodeStartupConfig(
Expand Down Expand Up @@ -80,10 +93,10 @@ class PythonWorkflowWorkerStartupConfigSpec extends AnyFlatSpec {
}

it should "produce a config that round-trips through encodeStartupConfig" in {
val json = PythonWorkflowWorker.encodeStartupConfig(
val encoded = PythonWorkflowWorker.encodeStartupConfig(
PythonWorkflowWorker.buildStartupConfig("w", "1", "", "uri")
)
val parsed = objectMapper.readValue(json, classOf[java.util.Map[String, String]])
val parsed = objectMapper.readValue(decode(encoded), classOf[java.util.Map[String, String]])
assert(parsed.get("workerId") == "w")
assert(parsed.get("s3LargeBinariesBaseUri") == "uri")
assert(parsed.size() == expectedKeys.size)
Expand Down
Loading