Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions providers/microsoft/azure/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ dependencies = [
"apache-airflow>=2.11.0",
"apache-airflow-providers-common-compat>=1.13.0",
"adlfs>=2023.10.0",
# azure-batch 15.x is a full rewrite of the Azure SDK (track 2) that removes BatchServiceClient, batch_auth,
# and the other references in AzureBatchHook. Lifting the upper bound cap needs a full hook rewrite.
"azure-batch>=8.0.0,<15.0.0",
"azure-batch>=15.0.0",
"azure-cosmos>=4.6.0",
"azure-mgmt-cosmosdb>=3.0.0",
"azure-datalake-store>=0.0.45",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any

from azure.batch import BatchServiceClient, batch_auth, models as batch_models
from azure.batch import BatchClient, models as batch_models
from azure.core.credentials import AzureNamedKeyCredential
from azure.core.exceptions import ResourceExistsError

from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.providers.microsoft.azure.utils import (
Expand All @@ -33,7 +35,7 @@
from airflow.utils import timezone

if TYPE_CHECKING:
from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParameter
from azure.batch.models import BatchJobCreateOptions, BatchPoolCreateOptions, BatchTaskCreateOptions


class AzureBatchHook(BaseHook):
Expand Down Expand Up @@ -85,11 +87,11 @@ def _get_field(self, extras, name):
)

@cached_property
def connection(self) -> BatchClient:
    """
    Get the Batch client connection (cached).

    The client is built by ``get_conn()`` on first access; ``cached_property``
    keeps the result on the instance, so later accesses reuse the same client
    instead of building a new one.
    """
    return self.get_conn()

def get_conn(self) -> BatchServiceClient:
def get_conn(self) -> BatchClient:
"""
Get the Batch client connection.

Expand All @@ -101,20 +103,20 @@ def get_conn(self) -> BatchServiceClient:
if not batch_account_url:
raise AirflowException("Batch Account URL parameter is missing.")

credentials: batch_auth.SharedKeyCredentials | AzureIdentityCredentialAdapter
credential: AzureNamedKeyCredential | AzureIdentityCredentialAdapter
if all([conn.login, conn.password]):
credentials = batch_auth.SharedKeyCredentials(conn.login, conn.password)
credential = AzureNamedKeyCredential(conn.login, conn.password)
else:
managed_identity_client_id = conn.extra_dejson.get("managed_identity_client_id")
workload_identity_tenant_id = conn.extra_dejson.get("workload_identity_tenant_id")
credentials = AzureIdentityCredentialAdapter(
credential = AzureIdentityCredentialAdapter(
None,
resource_id="https://batch.core.windows.net/.default",
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
)

batch_client = BatchServiceClient(credentials, batch_url=batch_account_url)
batch_client = BatchClient(endpoint=batch_account_url, credential=credential)
return batch_client

def configure_pool(
Expand All @@ -127,13 +129,11 @@ def configure_pool(
sku_starts_with: str | None = None,
vm_sku: str | None = None,
vm_version: str | None = None,
os_family: str | None = None,
os_version: str | None = None,
display_name: str | None = None,
target_dedicated_nodes: int | None = None,
use_latest_image_and_sku: bool = False,
**kwargs,
) -> PoolAddParameter:
) -> BatchPoolCreateOptions:
"""
Configure a pool.

Expand Down Expand Up @@ -162,17 +162,13 @@ def configure_pool(

:param vm_node_agent_sku_id: The node agent sku id of the virtual machine

:param os_family: The Azure Guest OS family to be installed on the virtual machines in the Pool.

:param os_version: The OS family version

"""
if use_latest_image_and_sku:
self.log.info("Using latest verified virtual machine image with node agent sku")
sku_to_use, image_ref_to_use = self._get_latest_verified_image_vm_and_sku(
publisher=vm_publisher, offer=vm_offer, sku_starts_with=sku_starts_with
)
pool = batch_models.PoolAddParameter(
pool = batch_models.BatchPoolCreateOptions(
id=pool_id,
vm_size=vm_size,
display_name=display_name,
Expand All @@ -183,29 +179,14 @@ def configure_pool(
**kwargs,
)

elif os_family:
self.log.info(
"Using cloud service configuration to create pool, virtual machine configuration ignored"
)
pool = batch_models.PoolAddParameter(
id=pool_id,
vm_size=vm_size,
display_name=display_name,
cloud_service_configuration=batch_models.CloudServiceConfiguration(
os_family=os_family, os_version=os_version
),
target_dedicated_nodes=target_dedicated_nodes,
**kwargs,
)

else:
self.log.info("Using virtual machine configuration to create a pool")
pool = batch_models.PoolAddParameter(
pool = batch_models.BatchPoolCreateOptions(
id=pool_id,
vm_size=vm_size,
display_name=display_name,
virtual_machine_configuration=batch_models.VirtualMachineConfiguration(
image_reference=batch_models.ImageReference(
image_reference=batch_models.BatchVmImageReference(
publisher=vm_publisher,
offer=vm_offer,
sku=vm_sku,
Expand All @@ -218,7 +199,7 @@ def configure_pool(
)
return pool

def create_pool(self, pool: PoolAddParameter) -> None:
def create_pool(self, pool: BatchPoolCreateOptions) -> None:
"""
Create a pool if not already existing.

Expand All @@ -227,11 +208,9 @@ def create_pool(self, pool: PoolAddParameter) -> None:
"""
try:
self.log.info("Attempting to create a pool: %s", pool.id)
self.connection.pool.add(pool)
self.connection.create_pool(pool)
self.log.info("Created pool: %s", pool.id)
except batch_models.BatchErrorException as err:
if not err.error or err.error.code != "PoolExists":
raise
except ResourceExistsError:
self.log.info("Pool %s already exists", pool.id)

def _get_latest_verified_image_vm_and_sku(
Expand All @@ -249,8 +228,7 @@ def _get_latest_verified_image_vm_and_sku(
For example, UbuntuServer or WindowsServer.
:param sku_starts_with: The start name of the sku to search
"""
options = batch_models.AccountListSupportedImagesOptions(filter="verificationType eq 'verified'")
images = self.connection.account.list_supported_images(account_list_supported_images_options=options)
images = self.connection.list_supported_images(filter="verificationType eq 'verified'")
# pick the latest supported sku
skus_to_use = [
(image.node_agent_sku_id, image.image_reference)
Expand All @@ -269,16 +247,16 @@ def wait_for_all_node_state(self, pool_id: str, node_state: set) -> list:
Wait for all nodes in a pool to reach given states.

:param pool_id: A string that identifies the pool
:param node_state: A set of batch_models.ComputeNodeState
:param node_state: A set of batch_models.BatchNodeState
"""
self.log.info("waiting for all nodes in pool %s to reach one of: %s", pool_id, node_state)
while True:
# refresh pool to ensure that there is no resize error
pool = self.connection.pool.get(pool_id)
pool = self.connection.get_pool(pool_id)
if pool.resize_errors is not None:
resize_errors = "\n".join(repr(e) for e in pool.resize_errors)
raise RuntimeError(f"resize error encountered for pool {pool.id}:\n{resize_errors}")
nodes = list(self.connection.compute_node.list(pool.id))
nodes = list(self.connection.list_nodes(pool.id))
if len(nodes) >= pool.target_dedicated_nodes and all(node.state in node_state for node in nodes):
return nodes
# Allow the timeout to be controlled by the AzureBatchOperator
Expand All @@ -292,34 +270,32 @@ def configure_job(
pool_id: str,
display_name: str | None = None,
**kwargs,
) -> JobAddParameter:
) -> BatchJobCreateOptions:
"""
Configure a job for use in the pool.

:param job_id: A string that uniquely identifies the job within the account
:param pool_id: A string that identifies the pool
:param display_name: The display name for the job
"""
job = batch_models.JobAddParameter(
job = batch_models.BatchJobCreateOptions(
id=job_id,
pool_info=batch_models.PoolInformation(pool_id=pool_id),
pool_info=batch_models.BatchPoolInfo(pool_id=pool_id),
display_name=display_name,
**kwargs,
)
return job

def create_job(self, job: BatchJobCreateOptions) -> None:
    """
    Create a job in the pool.

    A job that already exists is not treated as a failure: the conflict is
    logged and the call returns normally, so the method may be called again
    with the same job definition.

    :param job: The job object to create
    """
    try:
        self.connection.create_job(job)
    except ResourceExistsError:
        # NOTE(review): ResourceExistsError is presumably raised for any HTTP 409
        # Conflict response, not only the "JobExists" error code -- confirm that
        # no other conflict can end up reported as "already exists" here.
        self.log.info("Job %s already exists", job.id)
    else:
        self.log.info("Job %s created", job.id)

def configure_task(
Expand All @@ -329,7 +305,7 @@ def configure_task(
display_name: str | None = None,
container_settings=None,
**kwargs,
) -> TaskAddParameter:
) -> BatchTaskCreateOptions:
"""
Create a task.

Expand All @@ -341,7 +317,7 @@ def configure_task(
this must be set as well. If the Pool that will run this Task doesn't have
containerConfiguration set, this must not be set.
"""
task = batch_models.TaskAddParameter(
task = batch_models.BatchTaskCreateOptions(
id=task_id,
command_line=command_line,
display_name=display_name,
Expand All @@ -351,21 +327,19 @@ def configure_task(
self.log.info("Task created: %s", task_id)
return task

def add_single_task_to_job(self, job_id: str, task: BatchTaskCreateOptions) -> None:
    """
    Add a single task to given job if it doesn't exist.

    A task that already exists is not treated as a failure: the conflict is
    logged and the call returns normally.

    :param job_id: A string that identifies the given job
    :param task: The task to add
    """
    try:
        self.connection.create_task(job_id, task)
    except ResourceExistsError:
        # NOTE(review): ResourceExistsError is presumably raised for any HTTP 409
        # Conflict response, not only the "TaskExists" error code -- confirm.
        self.log.info("Task %s already exists", task.id)

def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batch_models.CloudTask]:
def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batch_models.BatchTask]:
"""
Wait for tasks in a particular job to complete.

Expand All @@ -374,15 +348,17 @@ def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batc
"""
timeout_time = timezone.utcnow() + timedelta(minutes=timeout)
while timezone.utcnow() < timeout_time:
tasks = list(self.connection.task.list(job_id))
tasks = list(self.connection.list_tasks(job_id))

incomplete_tasks = [task for task in tasks if task.state != batch_models.TaskState.completed]
incomplete_tasks = [
task for task in tasks if task.state != batch_models.BatchTaskState.COMPLETED
]
if not incomplete_tasks:
# detect if any task in job has failed
fail_tasks = [
task
for task in tasks
if task.execution_info.result == batch_models.TaskExecutionResult.failure
if task.execution_info.result == batch_models.BatchTaskExecutionResult.FAILURE
]
return fail_tasks
for task in incomplete_tasks:
Expand All @@ -393,12 +369,12 @@ def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batc
def test_connection(self):
    """
    Test a configured Azure Batch connection.

    :return: A ``(success, message)`` tuple: ``(True, <confirmation text>)`` when a
        job listing could be started, otherwise ``(False, <text of the raised error>)``.
    """
    try:
        # Attempt to list existing jobs under the configured Batch account and retrieve
        # the first in the returned iterator. The Azure Batch API does allow for creation of a
        # BatchClient with incorrect values but then will fail properly once items are
        # retrieved using the client. We need to _actually_ try to retrieve an object to properly
        # test the connection.
        next(self.get_conn().list_jobs(), None)
    except Exception as e:
        # Deliberately broad: any failure is reported back to the caller as a
        # failed connection test instead of being raised.
        return False, str(e)
    return True, "Successfully connected to Azure Batch."
Loading
Loading