diff --git a/hazelcast/asyncio/__init__.py b/hazelcast/asyncio/__init__.py index 89df6eea46..2781b856bc 100644 --- a/hazelcast/asyncio/__init__.py +++ b/hazelcast/asyncio/__init__.py @@ -17,3 +17,4 @@ from hazelcast.internal.asyncio_proxy.map import Map, EntryEventCallable from hazelcast.internal.asyncio_proxy.replicated_map import ReplicatedMap from hazelcast.internal.asyncio_proxy.vector_collection import VectorCollection +from hazelcast.internal.asyncio_proxy.reliable_topic import ReliableTopic, ReliableMessageListener diff --git a/hazelcast/core.py b/hazelcast/core.py index 3443759e44..7063c66db0 100644 --- a/hazelcast/core.py +++ b/hazelcast/core.py @@ -20,8 +20,8 @@ class MemberInfo: def __init__( self, address: "Address", - member_uuid: uuid.UUID, - attributes: typing.Dict[str, str], + member_uuid: uuid.UUID | None, + attributes: typing.Dict[str, str] | None, lite_member: bool, version: "MemberVersion", _, diff --git a/hazelcast/internal/asyncio_client.py b/hazelcast/internal/asyncio_client.py index 6452b2ca66..2669fa269d 100644 --- a/hazelcast/internal/asyncio_client.py +++ b/hazelcast/internal/asyncio_client.py @@ -27,6 +27,7 @@ MULTI_MAP_SERVICE, ProxyManager, QUEUE_SERVICE, + RELIABLE_TOPIC_SERVICE, REPLICATED_MAP_SERVICE, RINGBUFFER_SERVICE, SET_SERVICE, @@ -36,6 +37,7 @@ from hazelcast.internal.asyncio_proxy.map import Map from hazelcast.internal.asyncio_proxy.multi_map import MultiMap from hazelcast.internal.asyncio_proxy.queue import Queue +from hazelcast.internal.asyncio_proxy.reliable_topic import ReliableTopic from hazelcast.internal.asyncio_proxy.replicated_map import ReplicatedMap from hazelcast.internal.asyncio_proxy.ringbuffer import Ringbuffer from hazelcast.internal.asyncio_proxy.set import Set @@ -325,7 +327,18 @@ async def get_replicated_map(self, name: str) -> ReplicatedMap[KeyType, ValueTyp """ return await self._proxy_manager.get_or_create(REPLICATED_MAP_SERVICE, name) - async def get_ringbuffer(self, name: str) -> Ringbuffer[ItemType]: 
+ async def get_reliable_topic(self, name: str) -> ReliableTopic: + """Returns the ReliableTopic instance with the specified name. + + Args: + name: Name of the ReliableTopic. + + Returns: + Distributed ReliableTopic instance with the specified name. + """ + return await self._proxy_manager.get_or_create(RELIABLE_TOPIC_SERVICE, name) + + async def get_ringbuffer(self, name: str) -> Ringbuffer: """Returns the distributed Ringbuffer instance with the specified name. Args: diff --git a/hazelcast/internal/asyncio_proxy/manager.py b/hazelcast/internal/asyncio_proxy/manager.py index a0d277170e..08bfab6fc3 100644 --- a/hazelcast/internal/asyncio_proxy/manager.py +++ b/hazelcast/internal/asyncio_proxy/manager.py @@ -6,14 +6,15 @@ from hazelcast.internal.asyncio_proxy.queue import create_queue_proxy from hazelcast.internal.asyncio_proxy.set import create_set_proxy from hazelcast.internal.asyncio_proxy.vector_collection import ( - VectorCollection, create_vector_collection_proxy, ) from hazelcast.protocol.codec import client_create_proxy_codec, client_destroy_proxy_codec from hazelcast.internal.asyncio_invocation import Invocation from hazelcast.internal.asyncio_proxy.base import Proxy from hazelcast.internal.asyncio_proxy.map import create_map_proxy +from hazelcast.internal.asyncio_proxy.reliable_topic import ReliableTopic from hazelcast.internal.asyncio_proxy.replicated_map import create_replicated_map_proxy +from hazelcast.proxy.reliable_topic import _RINGBUFFER_PREFIX from hazelcast.internal.asyncio_proxy.ringbuffer import create_ringbuffer_proxy from hazelcast.util import to_list @@ -21,25 +22,12 @@ MAP_SERVICE = "hz:impl:mapService" MULTI_MAP_SERVICE = "hz:impl:multiMapService" QUEUE_SERVICE = "hz:impl:queueService" +RELIABLE_TOPIC_SERVICE = "hz:impl:reliableTopicService" REPLICATED_MAP_SERVICE = "hz:impl:replicatedMapService" RINGBUFFER_SERVICE = "hz:impl:ringbufferService" SET_SERVICE = "hz:impl:setService" VECTOR_SERVICE = "hz:service:vector" -_proxy_init: 
async def create_reliable_topic_proxy(service_name, name, context):
    """Builds a ReliableTopic proxy together with its backing ringbuffer.

    The ringbuffer proxy is obtained from the proxy manager under the
    reliable topic ringbuffer prefix followed by the topic name, and is
    then handed to the ReliableTopic constructor.

    Args:
        service_name: Service name the topic proxy is created for.
        name: Name of the reliable topic.
        context: Client context; its ``proxy_manager`` is used to obtain
            the ringbuffer proxy.

    Returns:
        The ReliableTopic proxy wrapping the ringbuffer.
    """
    ringbuffer_name = _RINGBUFFER_PREFIX + name
    # NOTE(review): create_on_remote=False presumably skips the remote
    # create-proxy request for the ringbuffer; confirm in get_or_create.
    backing_ringbuffer = await context.proxy_manager.get_or_create(
        RINGBUFFER_SERVICE, ringbuffer_name, create_on_remote=False
    )
    return ReliableTopic(service_name, name, context, backing_ringbuffer)


# Service name -> coroutine function creating the proxy for that service.
# This mapping has to be defined after create_reliable_topic_proxy because
# it references that function.
_proxy_init: typing.Dict[
    str,
    typing.Callable[[str, str, typing.Any], typing.Coroutine[typing.Any, typing.Any, typing.Any]],
] = {
    LIST_SERVICE: create_list_proxy,
    MAP_SERVICE: create_map_proxy,
    MULTI_MAP_SERVICE: create_multi_map_proxy,
    QUEUE_SERVICE: create_queue_proxy,
    REPLICATED_MAP_SERVICE: create_replicated_map_proxy,
    RELIABLE_TOPIC_SERVICE: create_reliable_topic_proxy,
    RINGBUFFER_SERVICE: create_ringbuffer_proxy,
    SET_SERVICE: create_set_proxy,
    VECTOR_SERVICE: create_vector_collection_proxy,
}
MemberVersion, EndpointQualifier, ProtocolType +from hazelcast.errors import ( + OperationTimeoutError, + IllegalArgumentError, + HazelcastClientNotActiveError, + ClientOfflineError, + HazelcastInstanceNotActiveError, + DistributedObjectDestroyedError, + TopicOverloadError, +) +from hazelcast.internal.asyncio_proxy.base import Proxy +from hazelcast.proxy.base import TopicMessage +from hazelcast.proxy.reliable_topic import ReliableMessageListener, _ReliableMessageListenerAdapter +from hazelcast.proxy.ringbuffer import OVERFLOW_POLICY_FAIL, OVERFLOW_POLICY_OVERWRITE +from hazelcast.serialization.compact import SchemaNotReplicatedError +from hazelcast.serialization.objects import ReliableTopicMessage +from hazelcast.types import MessageType +from hazelcast.util import check_not_none + +_INITIAL_BACKOFF = 0.1 +_MAX_BACKOFF = 2.0 + +_UNKNOWN_MEMBER_VERSION = MemberVersion(0, 0, 0) +_MEMBER_ENDPOINT_QUALIFIER = EndpointQualifier(ProtocolType.MEMBER, None) + +_logger = logging.getLogger(__name__) + + +class _MessageRunner: + def __init__( + self, + registration_id, + listener, + ringbuffer, + topic_name, + read_batch_size, + to_object, + runners, + ): + self._registration_id = registration_id + self._listener = listener + self._ringbuffer = ringbuffer + self._topic_name = topic_name + self._read_batch_size = read_batch_size + self._to_object = to_object + self._runners = runners + self._sequence = listener.retrieve_initial_sequence() + self._cancelled = False + self._task: asyncio.Task | None = None + + async def start(self): + """Starts the message runner by checking the given sequence. + + If the user provided an initial sequence via listener, we will + use it as it is. If not, we will ask server to get the tail + sequence and use it. + """ + if self._sequence != -1: + # User provided a sequence to start from + return + + # We are going to listen to next publication. + # We don't care about what already has been published. 
+ sequence = await self._ringbuffer.tail_sequence() + self._sequence = sequence + 1 + + def next_batch(self): + """Schedules an asyncio task to read the next batch from the + ringbuffer and call the listener on items when it is done. + """ + if self._cancelled: + return + # The task is assigned to an instance variable to keep a reference. + # That ensures it is not garbage collected before done. + # See: https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + self._task = asyncio.create_task(self._handle_next_batch()) + + def cancel(self): + """Sets the cancelled flag, cancels the running task, and removes + the runner registration. + """ + self._cancelled = True + self._runners.pop(self._registration_id, None) + self._listener.on_cancel() + + async def _handle_next_batch(self): + """Reads the next batch from the ringbuffer and processes the items.""" + if self._cancelled: + return + + try: + result = await self._ringbuffer.read_many(self._sequence, 1, self._read_batch_size) + + # Check if there are any messages lost since the last read + # and whether the listener can tolerate that. + lost_count = (result.next_sequence_to_read_from - result.read_count) - self._sequence + if lost_count != 0 and not self._is_loss_tolerable(lost_count): + self.cancel() + return + + # Call the listener for each item read. 
+ for i in range(result.size): + try: + message = result[i] + self._listener.store_sequence(result.get_sequence(i)) + member = None + if message.publisher_address: + member = MemberInfo( + message.publisher_address, + None, + None, + False, + _UNKNOWN_MEMBER_VERSION, + None, + { + _MEMBER_ENDPOINT_QUALIFIER: message.publisher_address, + }, + ) + + topic_message = TopicMessage( + self._topic_name, + message.payload, + message.publish_time, + member, + ) + self._listener.on_message(topic_message) + except Exception as e: + if self._terminate(e): + self.cancel() + return + + self._sequence = result.next_sequence_to_read_from + self.next_batch() + except asyncio.CancelledError: + pass + except Exception as e: + # read_many request failed. + if not await self._handle_internal_error(e): + self.cancel() + + def _is_loss_tolerable(self, loss_count: int) -> bool: + """Called when message loss is detected. + + Checks if the listener is able to tolerate the loss. + + Args: + loss_count: Number of lost messages. + + Returns: + ``True`` if the listener may continue reading. + """ + if self._listener.is_loss_tolerant(): + _logger.debug( + "MessageListener %s on topic %s lost %s messages.", + self._listener, + self._topic_name, + loss_count, + ) + return True + + _logger.warning( + "Terminating MessageListener %s on topic %s. " + "Reason: The listener was too slow or the retention period of the message has been violated. " + "%s messages lost.", + self._listener, + self._topic_name, + loss_count, + ) + return False + + def _terminate(self, error: Exception) -> bool: + """Checks if we should terminate the listener based on the error + received while calling on_message. + + Args: + error: Error received while calling the listener. + + Returns: + Should terminate the listener or not. + """ + if self._cancelled: + return True + + try: + terminate = self._listener.is_terminal(error) + if terminate: + _logger.warning( + "Terminating MessageListener %s on topic %s. 
Reason: Unhandled exception.", + self._listener, + self._topic_name, + exc_info=error, + ) + else: + _logger.debug( + "MessageListener %s on topic %s ran into an error.", + self._listener, + self._topic_name, + exc_info=error, + ) + return terminate + except Exception as e: + _logger.warning( + "Terminating MessageListener %s on topic %s. " + "Reason: Unhandled exception while calling is_terminal method", + self._listener, + self._topic_name, + exc_info=e, + ) + return True + + async def _handle_internal_error(self, error: Exception) -> bool: + """Called when the read_many request fails. + + Based on the error we receive, we will act differently. + + If we can tolerate the error, we will call next_batch here. + The reasoning behind is that, on some cases, we do not immediately + call next_batch, but make a request to the server, and based on + that, call next_batch. + + Args: + error: The error we received. + + Returns: + ``True`` if the error is handled internally. ``False`` otherwise. + When ``False`` is returned, listener should be cancelled. + """ + if isinstance(error, HazelcastClientNotActiveError): + return self._handle_client_not_active_error() + elif isinstance(error, ClientOfflineError): + return self._handle_client_offline_error() + elif isinstance(error, OperationTimeoutError): + return self._handle_timeout_error() + elif isinstance(error, IllegalArgumentError): + return await self._handle_illegal_argument_error(error) + elif isinstance(error, HazelcastInstanceNotActiveError): + return self._handle_instance_not_active_error() + elif isinstance(error, DistributedObjectDestroyedError): + return self._handle_distributed_object_destroyed_error() + else: + return self._handle_generic_error(error) + + def _handle_generic_error(self, error): + # Received an error we do not expect. + _logger.warning( + "Terminating MessageListener %s on topic %s. 
Reason: Unhandled exception.", + self._listener, + self._topic_name, + exc_info=error, + ) + return False + + def _handle_distributed_object_destroyed_error(self): + # Underlying ringbuffer is destroyed. It should only + # happen when the user destroys the reliable topic + # associated with it. + _logger.debug( + "Terminating MessageListener %s on topic %s. Reason: Topic is destroyed.", + self._listener, + self._topic_name, + ) + return False + + def _handle_instance_not_active_error(self): + # This error should be received from the server. + # We do not throw it anywhere on the client. + _logger.debug( + "Terminating MessageListener %s on topic %s. Reason: Server is shutting down.", + self._listener, + self._topic_name, + ) + return False + + def _handle_client_offline_error(self): + # Client is reconnecting to cluster. + _logger.debug( + "MessageListener %s on topic %s got error. " + "Continuing from the last known sequence %s.", + self._listener, + self._topic_name, + self._sequence, + ) + self.next_batch() + return True + + def _handle_client_not_active_error(self): + # Client#shutdown is called. + _logger.debug( + "Terminating MessageListener %s on topic %s. Reason: Client is shutting down.", + self._listener, + self._topic_name, + ) + return False + + def _handle_timeout_error(self): + # read_many invocation to the server timed out. + _logger.debug( + "MessageListener %s on topic %s timed out. " + "Continuing from the last known sequence %s.", + self._listener, + self._topic_name, + self._sequence, + ) + self.next_batch() + return True + + async def _handle_illegal_argument_error(self, error): + # Server sends this when it detects data loss + # on the underlying ringbuffer. + if self._listener.is_loss_tolerant(): + # Listener can tolerate message loss. Try to continue reading + # after getting head sequence, and try to read from there. 
+ try: + head_sequence = await self._ringbuffer.head_sequence() + _logger.debug( + "MessageListener %s on topic %s requested a too large sequence. " + "Jumping from old sequence %s to sequence %s.", + self._listener, + self._topic_name, + self._sequence, + head_sequence, + exc_info=error, + ) + self._sequence = head_sequence + # We call next_batch only after getting the new head + # sequence and updating our state with it. + self.next_batch() + except Exception as e: + _logger.warning( + "Terminating MessageListener %s on topic %s. " + "Reason: After the ring buffer data related " + "to reliable topic is lost, client tried to get the " + "current head sequence to continue since the listener " + "is loss tolerant, but that request failed.", + self._listener, + self._topic_name, + exc_info=e, + ) + # We said that we can handle that error so the listener + # is not cancelled. But, we could not continue since + # our request to the server failed. We should cancel + # the listener. + self.cancel() + return True + + _logger.warning( + "Terminating MessageListener %s on topic %s. " + "Reason: Underlying ring buffer data related to reliable topic is lost.", + self._listener, + self._topic_name, + ) + return False + + +class ReliableTopic(Proxy, typing.Generic[MessageType]): + """Hazelcast provides distribution mechanism for publishing messages that + are delivered to multiple subscribers, which is also known as a + publish/subscribe (pub/sub) messaging model. Publish and subscriptions are + cluster-wide. When a member subscribes for a topic, it is actually + registering for messages published by any member in the cluster, including + the new members joined after you added the listener. + + Messages are ordered, meaning that listeners(subscribers) will process the + messages in the order they are actually published. + + Hazelcast's Reliable Topic uses the same Topic interface as a regular topic. 
class ReliableTopic(Proxy, typing.Generic[MessageType]):
    """Publish/subscribe topic whose messages are kept in a Ringbuffer.

    Messages published to the topic are added to the backing ringbuffer.
    Every listener registered on this proxy reads the ringbuffer through
    its own message runner and receives the messages in the order they
    were stored.

    What happens when the ringbuffer has no room for a new message is
    decided by the overload policy of the topic configuration.
    """

    def __init__(self, service_name, name, context, ringbuffer):
        super().__init__(service_name, name, context)

        # Topics without an explicit configuration use the defaults.
        topic_config = context.config.reliable_topics.get(name)
        self._config = ReliableTopicConfig() if topic_config is None else topic_config
        self._ringbuffer = ringbuffer
        self._runners: typing.Dict[str, _MessageRunner] = {}

    async def publish(self, message: MessageType) -> None:
        """Publishes the message to all subscribers of this topic.

        Args:
            message: The message.
        """
        check_not_none(message, "Message cannot be None")
        try:
            payload = self._to_data(message)
        except SchemaNotReplicatedError as e:
            return await self._send_schema_and_retry(e, self.publish, message)

        topic_message = ReliableTopicMessage(time.time(), None, payload)

        # Overload policy -> coroutine that adds a single message.
        overload_policy = self._config.overload_policy
        strategies = (
            (TopicOverloadPolicy.BLOCK, self._add_with_backoff),
            (TopicOverloadPolicy.ERROR, self._add_or_fail),
            (TopicOverloadPolicy.DISCARD_OLDEST, self._add_or_overwrite),
            (TopicOverloadPolicy.DISCARD_NEWEST, self._add_or_discard),
        )
        for policy, add in strategies:
            if overload_policy == policy:
                return await add(topic_message)

        raise ValueError(f"Unexpected overload policy is passed {overload_policy}")

    async def publish_all(self, messages: typing.Sequence[MessageType]) -> None:
        """Publishes all messages to all subscribers of this topic.

        Args:
            messages: Messages to publish.
        """
        check_not_none(messages, "Messages cannot be None")
        try:
            topic_messages = [self._new_topic_message(item) for item in messages]
        except SchemaNotReplicatedError as e:
            return await self._send_schema_and_retry(e, self.publish_all, messages)

        # Overload policy -> coroutine that adds a batch of messages.
        overload_policy = self._config.overload_policy
        strategies = (
            (TopicOverloadPolicy.BLOCK, self._add_messages_with_backoff),
            (TopicOverloadPolicy.ERROR, self._add_messages_or_fail),
            (TopicOverloadPolicy.DISCARD_OLDEST, self._add_messages_or_overwrite),
            (TopicOverloadPolicy.DISCARD_NEWEST, self._add_messages_or_discard),
        )
        for policy, add_all in strategies:
            if overload_policy == policy:
                return await add_all(topic_messages)

        raise ValueError(f"Unexpected overload policy is passed {overload_policy}")

    async def add_listener(
        self,
        listener: typing.Union[
            ReliableMessageListener, typing.Callable[[TopicMessage[MessageType]], None]
        ],
    ) -> str:
        """Subscribes to this reliable topic.

        The listener is either an instance of
        :class:`ReliableMessageListener` or a plain function. A function
        is wrapped into a :class:`ReliableMessageListener` adapter.

        For every published message, the
        :func:`ReliableMessageListener.on_message` method of the listener
        (or the function passed) is called.

        Several listeners may be registered on the same instance.

        Args:
            listener: Listener to add.

        Returns:
            The registration id.
        """
        check_not_none(listener, "None listener is not allowed")
        registration_id = str(uuid4())
        runner = _MessageRunner(
            registration_id=registration_id,
            listener=self._to_reliable_message_listener(listener),
            ringbuffer=self._ringbuffer,
            topic_name=self.name,
            read_batch_size=self._config.read_batch_size,
            to_object=self._to_object,
            runners=self._runners,
        )
        await runner.start()

        # The runner is registered only after it started successfully.
        self._runners[registration_id] = runner
        runner.next_batch()

        # Yield to the event loop once so that the scheduled read task
        # gets a chance to run.
        await asyncio.sleep(0)
        return registration_id

    async def remove_listener(self, registration_id: str) -> bool:
        """Stops receiving messages for the given message listener.

        Nothing happens when no listener is registered under the id.

        Args:
            registration_id: ID of listener registration.

        Returns:
            ``True`` if registration is removed, ``False`` otherwise.
        """
        check_not_none(registration_id, "Registration id cannot be None")
        runner = self._runners.get(registration_id)
        if runner:
            runner.cancel()
            return True

        return False

    async def destroy(self) -> bool:
        """Destroys underlying Proxy and RingBuffer instances."""
        # Runners unregister themselves while being cancelled, hence the
        # iteration over a snapshot.
        for runner in tuple(self._runners.values()):
            runner.cancel()

        self._runners.clear()
        await super().destroy()
        return await self._ringbuffer.destroy()

    def _new_topic_message(self, message):
        # Validates and serializes a single message of a batch.
        check_not_none(message, "Message cannot be None")
        payload = self._to_data(message)
        return ReliableTopicMessage(time.time(), None, payload)

    async def _add_or_fail(self, message):
        # ERROR policy: a rejected add is reported to the caller.
        if await self._ringbuffer.add(message, OVERFLOW_POLICY_FAIL) == -1:
            raise TopicOverloadError(
                "Failed to publish message %s on topic %s." % (message, self.name)
            )

    async def _add_messages_or_fail(self, messages):
        # ERROR policy for batches.
        if await self._ringbuffer.add_all(messages, OVERFLOW_POLICY_FAIL) == -1:
            raise TopicOverloadError("Failed to publish messages on topic %s." % self.name)

    async def _add_or_overwrite(self, message):
        # DISCARD_OLDEST policy.
        await self._ringbuffer.add(message, OVERFLOW_POLICY_OVERWRITE)

    async def _add_messages_or_overwrite(self, messages):
        # DISCARD_OLDEST policy for batches.
        await self._ringbuffer.add_all(messages, OVERFLOW_POLICY_OVERWRITE)

    async def _add_or_discard(self, message):
        # DISCARD_NEWEST policy: the result of the add is ignored.
        await self._ringbuffer.add(message, OVERFLOW_POLICY_FAIL)

    async def _add_messages_or_discard(self, messages):
        # DISCARD_NEWEST policy for batches.
        await self._ringbuffer.add_all(messages, OVERFLOW_POLICY_FAIL)

    async def _add_with_backoff(self, message):
        # BLOCK policy: retry until the add is accepted, doubling the
        # pause between attempts up to _MAX_BACKOFF.
        pause = _INITIAL_BACKOFF
        while await self._ringbuffer.add(message, OVERFLOW_POLICY_FAIL) == -1:
            await asyncio.sleep(pause)
            pause = min(_MAX_BACKOFF, 2 * pause)

    async def _add_messages_with_backoff(self, messages):
        # BLOCK policy for batches.
        pause = _INITIAL_BACKOFF
        while await self._ringbuffer.add_all(messages, OVERFLOW_POLICY_FAIL) == -1:
            await asyncio.sleep(pause)
            pause = min(_MAX_BACKOFF, 2 * pause)

    @staticmethod
    def _to_reliable_message_listener(listener):
        # Listener instances are used as they are, callables get wrapped.
        if isinstance(listener, ReliableMessageListener):
            return listener

        if callable(listener):
            return _ReliableMessageListenerAdapter(listener)

        raise TypeError("Listener must be a callable")
file mode 100644 index 0000000000..e3dd050c10 --- /dev/null +++ b/tests/integration/asyncio/proxy/reliable_topic_test.py @@ -0,0 +1,605 @@ +import asyncio +import os +import unittest +from asyncio import InvalidStateError + +from hazelcast.util import AtomicInteger + +from tests.hzrc.ttypes import Lang + +try: + from hazelcast.config import TopicOverloadPolicy + from hazelcast.errors import ( + TopicOverloadError, + HazelcastClientNotActiveError, + TargetDisconnectedError, + ) + from hazelcast.proxy.reliable_topic import ReliableMessageListener +except ImportError: + # For backward compatibility. If we cannot import those, we won't + # be even referencing them in tests. + pass + +from tests.integration.asyncio.base import SingleMemberTestCase +from tests.util import ( + compare_client_version, + random_string, + event_collector, + get_current_timestamp, + skip_if_client_version_older_than, +) + +CAPACITY = 10 + + +@unittest.skipIf( + compare_client_version("4.1") < 0, "Tests the features added in 4.1 version of the client" +) +class ReliableTopicTest(SingleMemberTestCase): + @classmethod + def configure_cluster(cls): + path = os.path.abspath(__file__) + dir_path = os.path.dirname(path) + with open( + os.path.join(dir_path, "../../backward_compatible/proxy/hazelcast_topic.xml") + ) as f: + return f.read() + + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + if not compare_client_version("4.1") < 0: + # Add these config elements only to the 4.1+ clients + # since the older versions do not know anything + # about them. 
+ config["reliable_topics"] = { + "discard": { + "overload_policy": TopicOverloadPolicy.DISCARD_NEWEST, + }, + "overwrite": { + "overload_policy": TopicOverloadPolicy.DISCARD_OLDEST, + }, + "block": { + "overload_policy": TopicOverloadPolicy.BLOCK, + }, + "error": { + "overload_policy": TopicOverloadPolicy.ERROR, + }, + } + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.topics = [] + self.topic = await self.get_topic(random_string()) + + async def asyncTearDown(self): + for topic in self.topics: + await topic.destroy() + await super().asyncTearDown() + + async def test_add_listener_with_function(self): + topic = await self.get_topic(random_string()) + + collector = event_collector() + registration_id = await topic.add_listener(collector) + self.assertIsNotNone(registration_id) + + await topic.publish("a") + await topic.publish("b") + + await self.assertTrueEventually( + lambda: self.assertEqual(["a", "b"], list(map(lambda m: m.message, collector.events))) + ) + pass + + async def test_add_listener(self): + topic = await self.get_topic(random_string()) + + messages = [] + + on_cancel_call_count = AtomicInteger() + + class Listener(ReliableMessageListener): + def on_message(self, message): + messages.append(message.message) + + def retrieve_initial_sequence(self): + return -1 + + def store_sequence(self, sequence): + pass + + def is_loss_tolerant(self): + return False + + def is_terminal(self, error): + return False + + def on_cancel(self): + on_cancel_call_count.add(1) + + registration_id = await topic.add_listener(Listener()) + self.assertIsNotNone(registration_id) + + await topic.publish("a") + await topic.publish("b") + + await self.assertTrueEventually(lambda: self.assertEqual(["a", "b"], messages)) + + self.assertEqual(0, on_cancel_call_count.get()) + + async def test_add_listener_with_retrieve_initial_sequence(self): + topic = await self.get_topic(random_string()) + + messages = [] + + class Listener(ReliableMessageListener): 
+ def on_message(self, message): + messages.append(message.message) + + def retrieve_initial_sequence(self): + return 5 + + def store_sequence(self, sequence): + pass + + def is_loss_tolerant(self): + return False + + def is_terminal(self, error): + return False + + await topic.publish_all(range(10)) + + registration_id = await topic.add_listener(Listener()) + self.assertIsNotNone(registration_id) + + await self.assertTrueEventually(lambda: self.assertEqual(list(range(5, 10)), messages)) + + async def test_add_listener_with_store_sequence(self): + topic = await self.get_topic(random_string()) + + sequences = [] + + class Listener(ReliableMessageListener): + def on_message(self, message): + pass + + def retrieve_initial_sequence(self): + return -1 + + def store_sequence(self, sequence): + sequences.append(sequence) + + def is_loss_tolerant(self): + return False + + def is_terminal(self, error): + return False + + registration_id = await topic.add_listener(Listener()) + self.assertIsNotNone(registration_id) + + await topic.publish_all(["item-%s" % i for i in range(20)]) + + await self.assertTrueEventually(lambda: self.assertEqual(list(range(20)), sequences)) + + async def test_add_listener_with_loss_tolerant_listener_on_message_loss(self): + topic = await self.get_topic("overwrite") # has capacity of 10 + + messages = [] + + class Listener(ReliableMessageListener): + def on_message(self, message): + messages.append(message.message) + + def retrieve_initial_sequence(self): + return -1 + + def store_sequence(self, sequence): + pass + + def is_loss_tolerant(self): + return True + + def is_terminal(self, error): + return False + + registration_id = await topic.add_listener(Listener()) + self.assertIsNotNone(registration_id) + + # will overwrite first 10 messages, hence they will be lost + await topic.publish_all(range(2 * CAPACITY)) + + await self.assertTrueEventually( + lambda: self.assertEqual(list(range(CAPACITY, 2 * CAPACITY)), messages) + ) + + async def 
test_add_listener_with_non_loss_tolerant_listener_on_message_loss(self): + topic = await self.get_topic("overwrite") # has capacity of 10 + + messages = [] + + class Listener(ReliableMessageListener): + def on_message(self, message): + messages.append(message.message) + + def retrieve_initial_sequence(self): + return -1 + + def store_sequence(self, sequence): + pass + + def is_loss_tolerant(self): + return False + + def is_terminal(self, error): + return False + + registration_id = await topic.add_listener(Listener()) + self.assertIsNotNone(registration_id) + + # will overwrite first 10 messages, hence they will be lost + await topic.publish_all(range(2 * CAPACITY)) + + self.assertEqual(0, len(messages)) + + # Should be cancelled on message loss + await self.assertTrueEventually(lambda: self.assertEqual(0, len(topic._runners))) + + async def test_add_listener_when_on_message_raises_error(self): + topic = await self.get_topic(random_string()) + + messages = [] + + on_cancel_call_count = AtomicInteger() + + class Listener(ReliableMessageListener): + def on_message(self, message): + message = message.message + if message < 5: + messages.append(message) + else: + raise ValueError("expected") + + def retrieve_initial_sequence(self): + return -1 + + def store_sequence(self, sequence): + pass + + def is_loss_tolerant(self): + return False + + def is_terminal(self, error): + return isinstance(error, ValueError) + + def on_cancel(self) -> None: + on_cancel_call_count.add(1) + + registration_id = await topic.add_listener(Listener()) + self.assertIsNotNone(registration_id) + + await topic.publish_all(range(10)) + + await self.assertTrueEventually(lambda: self.assertEqual(list(range(5)), messages)) + + # Should be cancelled since on_message raised error + await self.assertTrueEventually(lambda: self.assertEqual(0, len(topic._runners))) + + if compare_client_version("5.4") >= 0: + self.assertEqual(1, on_cancel_call_count.get()) + + async def 
test_add_listener_when_on_message_and_is_terminal_raises_error(self): + topic = await self.get_topic(random_string()) + + messages = [] + + on_cancel_call_count = AtomicInteger() + + class Listener(ReliableMessageListener): + def on_message(self, message): + message = message.message + if message < 5: + messages.append(message) + else: + raise ValueError("expected") + + def retrieve_initial_sequence(self): + return -1 + + def store_sequence(self, sequence): + pass + + def is_loss_tolerant(self): + return False + + def is_terminal(self, error): + raise error + + def on_cancel(self) -> None: + on_cancel_call_count.add(1) + + registration_id = await topic.add_listener(Listener()) + self.assertIsNotNone(registration_id) + + await topic.publish_all(range(10)) + + await self.assertTrueEventually(lambda: self.assertEqual(list(range(5)), messages)) + + # Should be cancelled since on_message raised error + await self.assertTrueEventually(lambda: self.assertEqual(0, len(topic._runners))) + + if compare_client_version("5.4") >= 0: + self.assertEqual(1, on_cancel_call_count.get()) + + async def test_add_listener_with_non_callable(self): + topic = await self.get_topic(random_string()) + with self.assertRaises(TypeError): + await topic.add_listener(3) + + async def test_remove_listener(self): + topic = await self.get_topic(random_string()) + + on_cancel_call_count = AtomicInteger() + + class Listener(ReliableMessageListener): + def on_message(self, message) -> None: + pass + + def retrieve_initial_sequence(self) -> int: + return -1 + + def store_sequence(self, sequence: int) -> None: + pass + + def is_loss_tolerant(self) -> bool: + pass + + def is_terminal(self, error: Exception) -> bool: + pass + + def on_cancel(self) -> None: + on_cancel_call_count.add(1) + + registration_id = await topic.add_listener(Listener()) + self.assertTrue(await topic.remove_listener(registration_id)) + if compare_client_version("5.4") >= 0: + self.assertEqual(1, on_cancel_call_count.get()) + + async 
def test_remove_listener_does_not_receive_messages_after_removal(self): + topic = await self.get_topic(random_string()) + + collector = event_collector() + registration_id = await topic.add_listener(collector) + self.assertTrue(await topic.remove_listener(registration_id)) + + await topic.publish_all(range(10)) + + self.assertEqual(0, len(collector.events)) + + async def test_remove_listener_twice(self): + topic = await self.get_topic(random_string()) + registration_id = await topic.add_listener(lambda m: m) + self.assertTrue(await topic.remove_listener(registration_id)) + self.assertFalse(await topic.remove_listener(registration_id)) + + async def test_publish_with_discard_newest_policy(self): + topic = await self.get_topic("discard") + + collector = event_collector() + await topic.add_listener(collector) + + for i in range(2 * CAPACITY): + await topic.publish(i) + + await self.assertTrueEventually(lambda: self.assertEqual(CAPACITY, len(collector.events))) + self.assertEqual(list(range(CAPACITY)), await self.get_ringbuffer_data(topic)) + + async def test_publish_with_discard_oldest_policy(self): + topic = await self.get_topic("overwrite") + + collector = event_collector() + await topic.add_listener(collector) + + for i in range(2 * CAPACITY): + await topic.publish(i) + + await self.assertTrueEventually( + lambda: self.assertEqual(2 * CAPACITY, len(collector.events)) + ) + self.assertEqual(list(range(CAPACITY, 2 * CAPACITY)), await self.get_ringbuffer_data(topic)) + + async def test_publish_with_block_policy(self): + topic = await self.get_topic("block") + + collector = event_collector() + await topic.add_listener(collector) + + for i in range(CAPACITY): + await topic.publish(i) + + begin_time = get_current_timestamp() + + for i in range(CAPACITY, 2 * CAPACITY): + await topic.publish(i) + + time_passed = get_current_timestamp() - begin_time + + # TTL is set in the XML config + self.assertTrue(time_passed >= 2.0) + + await self.assertTrueEventually( + lambda: 
self.assertEqual(2 * CAPACITY, len(collector.events)) + ) + self.assertEqual(list(range(CAPACITY, CAPACITY * 2)), await self.get_ringbuffer_data(topic)) + + async def test_publish_with_error_policy(self): + topic = await self.get_topic("error") + + collector = event_collector() + await topic.add_listener(collector) + + for i in range(CAPACITY): + await topic.publish(i) + + for i in range(CAPACITY, 2 * CAPACITY): + with self.assertRaises(TopicOverloadError): + await topic.publish(i) + + await self.assertTrueEventually(lambda: self.assertEqual(CAPACITY, len(collector.events))) + self.assertEqual(list(range(CAPACITY)), await self.get_ringbuffer_data(topic)) + + async def test_publish_all_with_discard_newest_policy(self): + topic = await self.get_topic("discard") + + collector = event_collector() + await topic.add_listener(collector) + + await topic.publish_all(range(CAPACITY)) + await topic.publish_all(range(CAPACITY, 2 * CAPACITY)) + + await self.assertTrueEventually(lambda: self.assertEqual(CAPACITY, len(collector.events))) + self.assertEqual(list(range(CAPACITY)), await self.get_ringbuffer_data(topic)) + + async def test_publish_all_with_discard_oldest_policy(self): + topic = await self.get_topic("overwrite") + collector = event_collector() + await topic.add_listener(collector) + await topic.publish_all(range(CAPACITY)) + await topic.publish_all(range(CAPACITY, 2 * CAPACITY)) + await self.assertTrueEventually( + lambda: self.assertEqual(2 * CAPACITY, len(collector.events)) + ) + self.assertEqual(list(range(CAPACITY, 2 * CAPACITY)), await self.get_ringbuffer_data(topic)) + + async def test_publish_all_with_block_policy(self): + topic = await self.get_topic("block") + + collector = event_collector() + await topic.add_listener(collector) + + await topic.publish_all(range(CAPACITY)) + + begin_time = get_current_timestamp() + await topic.publish_all(range(CAPACITY, 2 * CAPACITY)) + time_passed = get_current_timestamp() - begin_time + + # TTL is set in the XML config + 
self.assertTrue(time_passed >= 2.0) + + await self.assertTrueEventually( + lambda: self.assertEqual(2 * CAPACITY, len(collector.events)) + ) + self.assertEqual(list(range(CAPACITY, CAPACITY * 2)), await self.get_ringbuffer_data(topic)) + + async def test_publish_all_with_error_policy(self): + topic = await self.get_topic("error") + + collector = event_collector() + await topic.add_listener(collector) + + await topic.publish_all(range(CAPACITY)) + + with self.assertRaises(TopicOverloadError): + await topic.publish_all(range(CAPACITY, 2 * CAPACITY)) + + await self.assertTrueEventually(lambda: self.assertEqual(CAPACITY, len(collector.events))) + self.assertEqual(list(range(CAPACITY)), await self.get_ringbuffer_data(topic)) + + async def test_durable_subscription(self): + topic = await self.get_topic(random_string()) + + class DurableListener(ReliableMessageListener): + def __init__(self): + self.objects = [] + self.sequences = [] + self.sequence = -1 + + def on_message(self, message): + self.objects.append(message.message) + + def retrieve_initial_sequence(self): + if self.sequence == -1: + return self.sequence + + # +1 to read the next item + return self.sequence + 1 + + def store_sequence(self, sequence): + self.sequences.append(sequence) + self.sequence = sequence + + def is_loss_tolerant(self): + return False + + def is_terminal(self, error): + return True + + listener = DurableListener() + + registration_id = await topic.add_listener(listener) + await topic.publish("item1") + + await self.assertTrueEventually(lambda: self.assertEqual(["item1"], listener.objects)) + + self.assertTrue(await topic.remove_listener(registration_id)) + + await topic.publish("item2") + await topic.publish("item3") + + await topic.add_listener(listener) + + def assertion(): + self.assertEqual(["item1", "item2", "item3"], listener.objects) + self.assertEqual([0, 1, 2], listener.sequences) + + await self.assertTrueEventually(assertion) + + async def 
test_client_receives_when_server_publish_messages(self): + skip_if_client_version_older_than(self, "4.2.1") + + topic_name = random_string() + topic = await self.get_topic(topic_name) + + received_message_count = [0] + + def listener(message): + self.assertIsNotNone(message.member) + received_message_count[0] += 1 + + await topic.add_listener(listener) + + message_count = 10 + + script = """ + var topic = instance_0.getReliableTopic("%s"); + for (var i = 0; i < %d; i++) { + topic.publish(i); + } + """ % ( + topic_name, + message_count, + ) + + self.rc.executeOnController(self.cluster.id, script, Lang.JAVASCRIPT) + await self.assertTrueEventually( + lambda: self.assertEqual(message_count, received_message_count[0]) + ) + + async def get_ringbuffer_data(self, topic): + ringbuffer = topic._ringbuffer + head_sequence = await ringbuffer.head_sequence() + items = await ringbuffer.read_many(head_sequence, CAPACITY, CAPACITY) + return list( + map( + lambda m: topic._to_object(m.payload), + items, + ) + ) + + async def get_topic(self, name): + topic = await self.client.get_reliable_topic(name) + self.topics.append(topic) + return topic