Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,30 @@ class ProxyOptions:
use_insecure_credentials: Whether to use insecure gRPC credentials for the
proxy server.
xla_flags: A list of XLA flags to pass to the proxy server.
sidecar_name: The name of the colocated Python sidecar to register with the
proxy. When set (e.g. to "external"), the proxy passes
``--sidecar_name=<value>`` so that ``jax.experimental.colocated_python``
can reach the sidecar containers running on the worker pods. Leave as
``None`` when no sidecar is deployed.
"""
use_insecure_credentials: bool = False
xla_flags: list[str] = dataclasses.field(default_factory=list)
sidecar_name: str | None = None

@classmethod
def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions":
"""Creates a ProxyOptions object from a list of 'key:value' strings."""
use_insecure = False
xla_flags = []
sidecar_name = None
for option in options or []:
if ":" in option:
key, value = option.split(":", 1)
key_strip = key.strip().lower()
if key_strip == "use_insecure_credentials":
use_insecure = value.strip().lower() == "true"
elif key.strip().lower() == "sidecar_name":
sidecar_name = value.strip()
elif key_strip == "xla_flags":
val_strip = value.strip()
if (
Expand All @@ -78,7 +87,10 @@ def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions":
if xla_flags:
validators.validate_xla_flags(xla_flags)

return cls(use_insecure_credentials=use_insecure, xla_flags=xla_flags)
return cls(
use_insecure_credentials=use_insecure, xla_flags=xla_flags,
sidecar_name=sidecar_name,
)


def _deploy_pathways_proxy_server(
Expand Down Expand Up @@ -134,6 +146,12 @@ def _deploy_pathways_proxy_server(
)
proxy_args_str = "\n" + proxy_args_str

sidecar_args_str = ""
if proxy_options.sidecar_name:
sidecar_args_str = (
f"- --sidecar_name={proxy_options.sidecar_name}"
)

template = string.Template(yaml_template)
substituted_yaml = template.substitute(
PROXY_JOB_NAME=proxy_job_name,
Expand All @@ -145,6 +163,7 @@ def _deploy_pathways_proxy_server(
PROXY_SERVER_IMAGE=proxy_server_image,
PROXY_ENV=proxy_env_str,
PROXY_ARGS=proxy_args_str,
SIDECAR_ARGS=sidecar_args_str,
)

_logger.info("Deploying Pathways proxy: %s", proxy_job_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ spec:
- --resource_manager_address=${PATHWAYS_HEAD_HOSTNAME}:${PATHWAYS_HEAD_PORT}
- --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
- --virtual_slices=${EXPECTED_INSTANCES}${PROXY_ARGS}
${SIDECAR_ARGS}
env:
${PROXY_ENV}
ports:
Expand Down