Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 55 additions & 55 deletions amber/src/main/python/core/architecture/packaging/output_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,47 +133,31 @@ def set_up_port_storage_writer(self, port_id: PortIdentity, storage_uri_base: st
state materialization on the same port. `storage_uri_base` is the
port's base URI; the result and state URIs are derived from it.
"""
document, _ = DocumentFactory.open_document(
VFSURIFactory.result_uri(storage_uri_base)
)
buffered_item_writer = document.writer(str(get_worker_index(self.worker_id)))
writer_queue = Queue()
port_storage_writer = PortStorageWriter(
buffered_item_writer=buffered_item_writer, queue=writer_queue
)
writer_thread = threading.Thread(
target=port_storage_writer.run,
daemon=True,
name=f"port_storage_writer_thread_{port_id}",
)
writer_thread.start()
self._port_storage_writers[port_id] = (
writer_queue,
port_storage_writer,
writer_thread,
)

state_document, _ = DocumentFactory.open_document(
VFSURIFactory.state_uri(storage_uri_base)
)
state_buffered_item_writer = state_document.writer(
str(get_worker_index(self.worker_id))
)
state_writer_queue = Queue()
state_port_writer = PortStorageWriter(
buffered_item_writer=state_buffered_item_writer,
queue=state_writer_queue,
)
state_writer_thread = threading.Thread(
target=state_port_writer.run,
daemon=True,
name=f"port_state_writer_thread_{port_id}",
def start_writer(uri: str, name_prefix: str, registry: dict) -> None:
document, _ = DocumentFactory.open_document(uri)
writer_queue = Queue()
writer = PortStorageWriter(
buffered_item_writer=document.writer(
str(get_worker_index(self.worker_id))
),
queue=writer_queue,
)
thread = threading.Thread(
target=writer.run, daemon=True, name=f"{name_prefix}_{port_id}"
)
thread.start()
registry[port_id] = (writer_queue, writer, thread)

start_writer(
VFSURIFactory.result_uri(storage_uri_base),
"port_storage_writer_thread",
self._port_storage_writers,
)
state_writer_thread.start()
self._port_state_writers[port_id] = (
state_writer_queue,
state_port_writer,
state_writer_thread,
start_writer(
VFSURIFactory.state_uri(storage_uri_base),
"port_state_writer_thread",
self._port_state_writers,
)

def get_port(self, port_id=None) -> WorkerPort:
Expand Down Expand Up @@ -203,14 +187,23 @@ def save_tuple_to_storage_if_needed(self, tuple_: Tuple, port_id=None) -> None:
PortStorageWriterElement(data_tuple=tuple_)
)

def save_state_to_storage_if_needed(self, state: State, port_id=None) -> None:
def save_state_to_storage_if_needed(
self,
state: State,
loop_counter: int = 0,
loop_start_id: str = "",
loop_start_state_uri: str = "",
port_id=None,
) -> None:
# When port_id is omitted the same state row is fanned out to
# every output port's state table. This mirrors the
# broadcast-to-all-workers behavior on the emit side: state is
# shared context, not per-key data, so every downstream operator
# (and every worker reading the materialization) needs the full
# set.
element = PortStorageWriterElement(data_tuple=state.to_tuple())
element = PortStorageWriterElement(
data_tuple=state.to_tuple(loop_counter, loop_start_id, loop_start_state_uri)
)
if port_id is None:
for writer_queue, _, _ in self._port_state_writers.values():
writer_queue.put(element)
Expand All @@ -223,18 +216,16 @@ def close_port_storage_writers(self) -> None:
writer threads to finish, which indicates the port storage writing
are finished.
"""
for _, writer, _ in self._port_storage_writers.values():
# This non-blocking stop call will let the storage writers
# flush the remaining buffer
writer.stop()
for _, _, writer_thread in self._port_storage_writers.values():
# This blocking call will wait for all the writer to finish commit
writer_thread.join()
for _, state_writer, _ in self._port_state_writers.values():
state_writer.stop()
for _, _, state_writer_thread in self._port_state_writers.values():
state_writer_thread.join()
self._port_state_writers.clear()
for registry in (self._port_storage_writers, self._port_state_writers):
# Non-blocking stop lets each writer flush its remaining buffer;
# the join then waits for the commit to finish.
for _, writer, _ in registry.values():
writer.stop()
for _, _, thread in registry.values():
thread.join()
# Drop the stopped writers so a later close doesn't act on
# stale entries.
registry.clear()

def add_partitioning(self, tag: PhysicalLink, partitioning: Partitioning) -> None:
"""
Expand Down Expand Up @@ -290,15 +281,24 @@ def emit_ecm(
)

def emit_state(
self, state: State
self,
state: State,
loop_counter: int = 0,
loop_start_id: str = "",
loop_start_state_uri: str = "",
) -> Iterable[typing.Tuple[ActorVirtualIdentity, DataPayload]]:
return chain(
*(
(
(
receiver,
(
StateFrame(payload)
StateFrame(
payload,
loop_counter=loop_counter,
loop_start_id=loop_start_id,
loop_start_state_uri=loop_start_state_uri,
)
if isinstance(payload, State)
else self.tuple_to_frame(payload)
),
Expand Down
9 changes: 9 additions & 0 deletions amber/src/main/python/core/models/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,12 @@ class DataFrame(DataPayload):
@dataclass
class StateFrame(DataPayload):
frame: State
# Loop-control bookkeeping owned by the worker runtime, carried alongside
# the State payload (not inside it) so it never collides with user state.
# Defaults are the "no loop" values for all non-loop state.
loop_counter: int = 0
# Which LoopStart to jump back to, and the iceberg URI its input is read
# from. Set by the runtime on a LoopStart's output, consumed by the
# matching LoopEnd. Empty for non-loop / not-yet-stamped state.
loop_start_id: str = ""
loop_start_state_uri: str = ""
34 changes: 31 additions & 3 deletions amber/src/main/python/core/models/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,41 @@

class State(dict):
CONTENT = "content"
SCHEMA = Schema(raw_schema={CONTENT: "STRING"})
# Loop-control bookkeeping owned by the worker runtime, NOT user state -- it
# never appears in the content JSON. In memory it rides on the StateFrame
# envelope; it is materialized/serialized as its own column (parallel to
# content) by to_tuple(...). from_tuple() returns the bare State; callers
# that need these values read the corresponding columns off the tuple.
LOOP_COUNTER = "loop_counter"
LOOP_START_ID = "loop_start_id"
LOOP_START_STATE_URI = "loop_start_state_uri"
SCHEMA = Schema(
raw_schema={
CONTENT: "STRING",
LOOP_COUNTER: "LONG",
LOOP_START_ID: "STRING",
LOOP_START_STATE_URI: "STRING",
}
)

def to_json(self) -> str:
return json.dumps(_to_json_value(self), separators=(",", ":"))

def to_tuple(self) -> Tuple:
return Tuple({State.CONTENT: self.to_json()}, schema=State.SCHEMA)
def to_tuple(
self,
loop_counter: int = 0,
loop_start_id: str = "",
loop_start_state_uri: str = "",
) -> Tuple:
return Tuple(
{
State.CONTENT: self.to_json(),
State.LOOP_COUNTER: int(loop_counter),
State.LOOP_START_ID: loop_start_id,
State.LOOP_START_STATE_URI: loop_start_state_uri,
},
schema=State.SCHEMA,
)

@classmethod
def from_json(cls, payload: str) -> "State":
Expand Down
7 changes: 6 additions & 1 deletion amber/src/main/python/core/runnables/network_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ def data_handler(command: bytes, table: Table) -> int:
"Data",
lambda _: DataFrame(table),
"State",
lambda _: StateFrame(State.from_json(table[State.CONTENT][0].as_py())),
lambda _: StateFrame(
State.from_json(table[State.CONTENT][0].as_py()),
loop_counter=int(table[State.LOOP_COUNTER][0].as_py()),
loop_start_id=table[State.LOOP_START_ID][0].as_py(),
loop_start_state_uri=table[State.LOOP_START_STATE_URI][0].as_py(),
Comment thread
aglinxinyuan marked this conversation as resolved.
),
"ECM",
lambda _: EmbeddedControlMessage().parse(table["payload"][0].as_py()),
)
Expand Down
15 changes: 13 additions & 2 deletions amber/src/main/python/core/runnables/network_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
from overrides import overrides
from typing import Optional

from core.models import DataPayload, InternalQueue, DataFrame, State, StateFrame
from core.models import (
DataPayload,
InternalQueue,
DataFrame,
State,
StateFrame,
)
from core.models.internal_queue import (
InternalQueueElement,
DataElement,
Expand Down Expand Up @@ -100,7 +106,12 @@ def _send_data(self, to: ChannelIdentity, data_payload: DataPayload) -> None:
elif isinstance(data_payload, StateFrame):
data_header = PythonDataHeader(tag=to, payload_type="State")
table = pa.Table.from_pydict(
{State.CONTENT: [data_payload.frame.to_json()]},
{
State.CONTENT: [data_payload.frame.to_json()],
Comment thread
aglinxinyuan marked this conversation as resolved.
State.LOOP_COUNTER: [int(data_payload.loop_counter)],
State.LOOP_START_ID: [data_payload.loop_start_id],
State.LOOP_START_STATE_URI: [data_payload.loop_start_state_uri],
},
schema=State.SCHEMA.as_arrow_schema(),
)
self._proxy_client.send_data(bytes(data_header), table)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@
from core.architecture.sendsemantics.round_robin_partitioner import (
RoundRobinPartitioner,
)
from core.models import Tuple, InternalQueue, DataFrame, DataPayload, State, StateFrame
from core.models import (
Tuple,
InternalQueue,
DataFrame,
DataPayload,
State,
StateFrame,
)
from core.models.internal_queue import DataElement, ECMElement
from core.storage.document_factory import DocumentFactory
from core.storage.vfs_uri_factory import VFSURIFactory
Expand Down Expand Up @@ -152,7 +159,14 @@ def run(self) -> None:
VFSURIFactory.state_uri(self.uri)
)
for state_row in state_document.get():
self.emit_payload(StateFrame(State.from_tuple(state_row)))
self.emit_payload(
StateFrame(
State.from_tuple(state_row),
loop_counter=state_row[State.LOOP_COUNTER],
loop_start_id=state_row[State.LOOP_START_ID],
loop_start_state_uri=state_row[State.LOOP_START_STATE_URI],
)
)

storage_iterator = self.materialization.get()
# Iterate and process tuples.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class OutputManager(
// emit side: state is shared context, not per-key data, so every
// downstream operator (and every worker reading the materialization)
// needs the full set.
stateWriterThreads.values.foreach(_.queue.put(Left(state.toTuple)))
stateWriterThreads.values.foreach(_.queue.put(Left(state.toTuple())))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class PythonProxyClient(portNumberPromise: Promise[Int], val actorId: ActorVirtu
case DataFrame(frame) =>
writeArrowStream(mutable.Queue(ArraySeq.unsafeWrapArray(frame): _*), from, "Data")
case StateFrame(state) =>
writeArrowStream(mutable.Queue(state.toTuple), from, "State")
writeArrowStream(mutable.Queue(state.toTuple()), from, "State")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ def port_b(self):

@pytest.fixture
def state(self):
return State({"loop_counter": 1, "i": 2})
return State({"i": 2})

def test_no_state_writers_is_a_noop(self, output_manager, state):
# With no port set up, save_state_to_storage_if_needed must not
# touch any writer.
output_manager.save_state_to_storage_if_needed(state) # no-op
output_manager.save_state_to_storage_if_needed(state, 0) # no-op

def test_unknown_port_id_is_a_noop(self, output_manager, state, port_a):
output_manager.save_state_to_storage_if_needed(state, port_id=port_a)
output_manager.save_state_to_storage_if_needed(state, 0, port_id=port_a)
# No assertion needed -- the absence of any writer means nothing
# was attempted.

Expand All @@ -67,7 +67,7 @@ def test_enqueues_to_every_port_when_port_id_omitted(
queue_a, _, _ = _stub_state_writer(output_manager, port_a)
queue_b, _, _ = _stub_state_writer(output_manager, port_b)

output_manager.save_state_to_storage_if_needed(state)
output_manager.save_state_to_storage_if_needed(state, 0)

# Each port's writer queue receives one PortStorageWriterElement.
# Critically, save is non-blocking -- the call must not invoke
Expand All @@ -84,7 +84,7 @@ def test_enqueues_only_to_selected_port_when_port_id_specified(
queue_a, _, _ = _stub_state_writer(output_manager, port_a)
queue_b, _, _ = _stub_state_writer(output_manager, port_b)

output_manager.save_state_to_storage_if_needed(state, port_id=port_a)
output_manager.save_state_to_storage_if_needed(state, 0, port_id=port_a)

assert queue_a.put.call_count == 1
queue_b.put.assert_not_called()
Expand All @@ -105,3 +105,16 @@ def test_close_port_storage_writers_stops_state_threads(
thread_a.join.assert_called_once()
thread_b.join.assert_called_once()
assert output_manager._port_state_writers == {}

def test_defaults_loop_columns_when_omitted(self, output_manager, state, port_a):
# Dormancy: callers that pass no loop bookkeeping (every non-loop
# caller, e.g. MainLoop.process_input_state) still produce a valid
# 4-column state tuple with the loop columns at their no-loop defaults.
queue_a, _, _ = _stub_state_writer(output_manager, port_a)

output_manager.save_state_to_storage_if_needed(state) # no loop_counter

data_tuple = queue_a.put.call_args.args[0].data_tuple
assert data_tuple[State.LOOP_COUNTER] == 0
assert data_tuple[State.LOOP_START_ID] == ""
assert data_tuple[State.LOOP_START_STATE_URI] == ""
Loading
Loading