Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/) and this project adheres to [Semantic Versioning](https://semver.org/)

## [Unreleased]

### Added

- `forward_entities` parameter
  - allows forwarding incoming entities to the output port
- `parallel_execution` parameter
  - allows executing X workflows in parallel

## [0.5.1] 2024-05-08

### Added
Expand All @@ -16,4 +25,3 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/) and this p
### Added

- initial version

255 changes: 212 additions & 43 deletions cmem_plugin_loopwf/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import json
from collections.abc import Sequence
from pathlib import Path
from tempfile import NamedTemporaryFile
from dataclasses import dataclass
from http import HTTPStatus
from time import sleep

from cmem.cmempy.workflow.workflow import execute_workflow_io, get_workflows_io
from cmem.cmempy.api import config, get_json
from cmem.cmempy.workflow.workflow import get_workflows_io
from cmem_plugin_base.dataintegration.context import ExecutionContext, ExecutionReport
from cmem_plugin_base.dataintegration.description import Icon, Plugin, PluginParameter
from cmem_plugin_base.dataintegration.entity import (
Expand All @@ -14,9 +16,11 @@
EntityPath,
EntitySchema,
)
from cmem_plugin_base.dataintegration.plugins import WorkflowPlugin
from cmem_plugin_base.dataintegration.plugins import PluginLogger, WorkflowPlugin
from cmem_plugin_base.dataintegration.ports import FixedNumberOfInputs, FlexibleSchemaPort
from cmem_plugin_base.dataintegration.types import BoolParameterType, IntParameterType
from cmem_plugin_base.dataintegration.utils import setup_cmempy_user_access
from requests import HTTPError

from cmem_plugin_loopwf import exceptions
from cmem_plugin_loopwf.workflow_type import SuitableWorkflowParameterType
Expand All @@ -37,6 +41,154 @@
"""


@dataclass
class WorkflowExecution:
    """Represents the status of a concrete workflow execution.

    One instance tracks one entity that is sent to one workflow: it holds the
    request data (project, task, entity, schema), the identifiers returned by
    the asynchronous start request and the last status fetched from the
    activity status endpoint.
    """

    task_id: str
    project_id: str
    entity: Entity
    schema: EntitySchema
    # identifiers returned by the executeAsync request (None until started)
    instance_id: str | None = None
    activity_id: str | None = None
    # last known status name; "QUEUED" means "not started yet"
    status: str = "QUEUED"
    is_running: bool = False
    # last raw response of the status endpoint
    raw: dict[str, str] | None = None
    # NOTE(review): stored by callers but not read by any method of this class.
    execution_context: ExecutionContext | None = None
    logger: PluginLogger | None = None

    @property
    def is_finished(self) -> bool:
        """True if workflow is finished"""
        return self.status.upper() == "FINISHED"

    @property
    def is_queued(self) -> bool:
        """True if workflow is queued"""
        return self.status.upper() == "QUEUED"

    def entity_as_json_str(self) -> str:
        """Return the entity as a JSON string"""
        entity_as_dict = StartWorkflow.entity_to_dict(entity=self.entity, schema=self.schema)
        return json.dumps(entity_as_dict)

    def start(self) -> bool:
        """Start the workflow.

        Returns True if the workflow was started (the status is refreshed
        afterwards) and False if the server answered with 503, in which case
        the execution stays queued and can be retried later.

        Raises:
            HTTPError: for every HTTP error other than 503. The original
                exception (including its response) is re-raised unchanged.

        """
        # serialize once - used for logging and as request body
        payload = self.entity_as_json_str()
        if self.logger:
            self.logger.debug(f"Starting workflow execution: {payload}")
        endpoint = (
            f"{config.get_di_api_endpoint()}/api/workflow/executeAsync/"
            f"{self.project_id}/{self.task_id}"
        )
        try:
            response = get_json(
                endpoint,
                headers={"Content-Type": "application/json"},
                method="POST",
                data=payload,
            )
        except HTTPError as error:
            if (
                error.response is not None
                and error.response.status_code == HTTPStatus.SERVICE_UNAVAILABLE
            ):
                # 503 - no more execution capacity > no status change
                return False
            # bare raise keeps message, response and traceback of the original error
            raise
        self.instance_id = response["instanceId"]
        self.activity_id = response["activityId"]
        self.update()
        return True

    def wait_until_finished(self) -> None:
        """Wait until the workflow is finished, polling the status once per second"""
        while self.is_running:
            self.update()
            if self.is_running:
                # only pause if another poll is needed
                sleep(1)

    def update(self) -> None:
        """Update the execution status from the activity status endpoint"""
        response = get_json(
            f"{config.get_di_api_endpoint()}/workspace/activities/status",
            params={
                "project": self.project_id,
                "task": self.task_id,
                "activity": self.activity_id,
                "instance": self.instance_id,
            },
        )
        self.status = response["statusName"]
        self.is_running = response["isRunning"]
        self.raw = response
        if self.logger:
            self.logger.debug(f"Updated Status: {self!s}")


@dataclass
class WorkflowExecutionList:
    """Workflow execution status list / registry.

    Collects WorkflowExecution objects, starts them in batches of at most
    ``parallel_execution`` and reports the progress to the execution report
    of the context and to the logger (each only if it is set).
    """

    statuses: list[WorkflowExecution]
    context: ExecutionContext | None
    logger: PluginLogger | None

    def __init__(
        self,
        context: ExecutionContext | None = None,
        logger: PluginLogger | None = None,
    ) -> None:
        self.statuses = []
        # always initialised, so repr()/== of the dataclass work before
        # a caller assigns context and logger
        self.context = context
        self.logger = logger

    def execute(self, parallel_execution: int, polling_time: int = 1) -> None:
        """Execute all workflow executions.

        Starts queued executions until ``parallel_execution`` of them are
        running, waits until all running ones are done and repeats this
        until nothing is queued anymore.

        Args:
            parallel_execution: maximum number of concurrently running workflows.
            polling_time: seconds to sleep between status polls / start retries.

        Raises:
            ValueError: if ``parallel_execution`` is lower than 1 (nothing
                could ever be started, so the queue would never drain).

        """
        if parallel_execution < 1:
            raise ValueError("parallel_execution must be >= 1")
        while self.queued > 0:
            while self.running < parallel_execution and self.queued > 0:
                if not self.start_next():
                    # start was refused (no capacity) - stop filling slots
                    # instead of hammering the server in a tight loop
                    break
                self.report()
            if self.running > 0:
                self.wait_until_finished(polling_time=polling_time)
            elif self.queued > 0:
                # nothing running and nothing could be started: back off, then retry
                sleep(polling_time)
        self.report()

    def start_next(self) -> bool:
        """Start next workflow execution in queue.

        Returns False if the queue is empty or the execution could not be started.
        """
        next_in_queue = next((_ for _ in self.statuses if _.is_queued), None)
        if next_in_queue is None:
            return False
        return next_in_queue.start()

    def wait_until_finished(self, polling_time: int = 1) -> None:
        """Wait until all running workflows are finished"""
        while self.running > 0:
            sleep(polling_time)
            self.update_running_status()

    def update_running_status(self) -> None:
        """Update status of running workflows"""
        for _ in self.statuses:
            if _.is_running:
                _.update()

    def append(self, status: WorkflowExecution) -> None:
        """Append a workflow execution to the list"""
        self.statuses.append(status)

    def report(self) -> None:
        """Report workflow statuses to logger and/or execution report from context"""
        finished = self.finished
        line = f"finished ({self.running} running, {self.queued} queued)"
        if self.context is not None:
            self.context.report.update(
                ExecutionReport(
                    entity_count=finished,
                    operation="start",
                    operation_desc=line,
                )
            )
        if self.logger is not None:
            self.logger.info(f"{finished} {line}")

    @property
    def running(self) -> int:
        """Returns the number of running workflows"""
        return sum(1 for _ in self.statuses if _.is_running)

    @property
    def finished(self) -> int:
        """Returns the number of finished workflows"""
        return sum(1 for _ in self.statuses if _.is_finished)

    @property
    def queued(self) -> int:
        """Returns the number of queued workflows"""
        return sum(1 for _ in self.statuses if _.is_queued)


@Plugin(
label="Start Workflow per Entity",
description="Loop over the output of a task and start a sub-workflow for each entity.",
Expand All @@ -50,36 +202,82 @@
param_type=SuitableWorkflowParameterType(),
description="Which workflow do you want to start per entity.",
),
PluginParameter(
name="parallel_execution",
label="How many workflow jobs should run in parallel?",
param_type=IntParameterType(),
default_value=1,
),
PluginParameter(
name="forward_entities",
label="Forward incoming entities to the output port?",
param_type=BoolParameterType(),
default_value=False,
),
],
)
class StartWorkflow(WorkflowPlugin):
"""Start Workflow per Entity"""

context: ExecutionContext
schema: EntitySchema
executions: WorkflowExecutionList

def __init__(self, workflow: str) -> None:
def __init__(
self, workflow: str, parallel_execution: int = 1, forward_entities: bool = False
) -> None:
self.workflow = workflow
if parallel_execution < 1:
raise ValueError("parallel_execution must be >= 1")
self.parallel_execution = parallel_execution
self.forward_entities = forward_entities
self.input_ports = FixedNumberOfInputs([FlexibleSchemaPort()])
self.output_port = None
self.output_port = FlexibleSchemaPort() if forward_entities else None
self.workflows_started = 0
self.executions = WorkflowExecutionList()

def start_workflows(self, inputs: Sequence[Entities]) -> Entities:
    """Queue one workflow execution per incoming entity, run them, return the entities.

    The first input provides both the entities and the schema. Every entity
    is registered as a WorkflowExecution in ``self.executions`` (with a
    progress report after each registration), then all executions are run
    with the configured parallelism. The returned Entities object uses the
    input schema and yields the entity of every registered execution.
    """
    first_input = inputs[0]
    schema = first_input.schema
    registry = self.executions
    registry.context = self.context
    registry.logger = self.log
    registry.report()
    for incoming in first_input.entities:
        execution = WorkflowExecution(
            task_id=self.workflow,
            project_id=self.context.task.project_id(),
            entity=incoming,
            schema=schema,
            execution_context=self.context,
            logger=self.log,
        )
        self.log.info(f"Got new entity: {execution.entity_as_json_str()}")
        registry.append(execution)
        registry.report()
    registry.execute(parallel_execution=self.parallel_execution)
    # remove execution via /workflow/workflows/{project}/{task}/execution/{executionId}

    forwarded = [execution.entity for execution in registry.statuses]
    return Entities(schema=schema, entities=iter(forwarded))

def execute(
self,
inputs: Sequence[Entities],
context: ExecutionContext,
) -> None:
) -> Entities | None:
"""Run the workflow operator."""
self.log.info("Start execute")
self.context = context
self.validate_inputs(inputs=inputs)
self.schema = inputs[0].schema
self.validate_workflow(workflow=self.workflow)

for entity in inputs[0].entities:
self.start_workflow(entity=entity)

self.log.info("Stop execute")
output_entities = self.start_workflows(inputs=inputs)
if self.forward_entities:
self.log.info("All done ... forward entities")
return output_entities
self.log.info("All done ...")
return None

@staticmethod
def validate_inputs(inputs: Sequence[Entities]) -> None:
Expand All @@ -106,35 +304,6 @@ def validate_workflow(self, workflow: str) -> None:
)
self.log.info(str(suitable_workflows))

def start_workflow(self, entity: Entity) -> None:
"""Start a single workflow."""
entity_as_dict: dict = self.entity_to_dict(entity=entity, schema=self.schema)
entity_as_json: str = json.dumps(entity_as_dict)
self.log.info(f"Processing new entity: {entity_as_json}")
# start workflow here
with NamedTemporaryFile(mode="w+") as temp_file:
self.log.info(f"temp file for entity: {temp_file.name}")
temp_file.write(entity_as_json)
temp_file.flush()
self.log.info(f"temp file content: {Path(temp_file.name).read_text()}")
setup_cmempy_user_access(context=self.context.user)
execute_workflow_io(
project_name=self.context.task.project_id(),
task_name=self.workflow,
input_file=temp_file.name,
input_mime_type="application/x-plugin-json",
output_mime_type="guess",
auto_config=False,
)
self.workflows_started += 1
self.context.report.update(
ExecutionReport(
entity_count=self.workflows_started,
operation="start",
operation_desc="workflows started",
)
)

@staticmethod
def entity_to_dict(entity: Entity, schema: EntitySchema) -> dict:
"""Convert an entity to a dictionary, using the schema"""
Expand Down
Loading