diff --git a/redis/cluster.py b/redis/cluster.py index 7f3c0d3c82..8f42c1a235 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1141,7 +1141,7 @@ def _get_command_keys(self, *args): redis_conn = self.get_default_node().redis_connection return self.commands_parser.get_keys(redis_conn, *args) - def determine_slot(self, *args) -> int: + def determine_slot(self, *args) -> Optional[int]: """ Figure out what slot to use based on args. @@ -1156,6 +1156,12 @@ def determine_slot(self, *args) -> int: # Get the keys in the command + # CLIENT TRACKING is a special case. + # It doesn't have any keys, it needs to be sent to the provided nodes + # By default it will be sent to all nodes. + if command.upper() == "CLIENT TRACKING": + return None + # EVAL and EVALSHA are common enough that it's wasteful to go to the # redis server to parse the keys. Besides, there is a bug in redis<7.0 # where `self._get_command_keys()` fails anyway. So, we special case diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index 3918611cef..568f4b4914 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -11,6 +11,7 @@ Mapping, NoReturn, Optional, + Sequence, Union, ) @@ -25,6 +26,7 @@ PatternT, ResponseT, ) +from redis.utils import deprecated_function from .core import ( ACLCommands, @@ -755,6 +757,76 @@ def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: self.read_from_replicas = False return self.execute_command("READWRITE", target_nodes=target_nodes) + @deprecated_function( + version="7.2.0", + reason="Use client-side caching feature instead.", + ) + def client_tracking_on( + self, + clientid: Optional[int] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + target_nodes: Optional["TargetNodesT"] = "all", + ) -> ResponseT: + """ + Enables the tracking feature of the Redis server, that is used + for server assisted client side caching. + + When clientid is provided - in target_nodes only the node that owns the + connection with this id should be provided. + When clientid is not provided - target_nodes can be any node. + + For more information see https://redis.io/commands/client-tracking + """ + return self.client_tracking( + True, + clientid, + prefix, + bcast, + optin, + optout, + noloop, + target_nodes=target_nodes, + ) + + @deprecated_function( + version="7.2.0", + reason="Use client-side caching feature instead.", + ) + def client_tracking_off( + self, + clientid: Optional[int] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + target_nodes: Optional["TargetNodesT"] = "all", + ) -> ResponseT: + """ + Disables the tracking feature of the Redis server, that is used + for server assisted client side caching. + + When clientid is provided - in target_nodes only the node that owns the + connection with this id should be provided. + When clientid is not provided - target_nodes can be any node. + + For more information see https://redis.io/commands/client-tracking + """ + return self.client_tracking( + False, + clientid, + prefix, + bcast, + optin, + optout, + noloop, + target_nodes=target_nodes, + ) + class AsyncClusterManagementCommands( ClusterManagementCommands, AsyncManagementCommands @@ -782,6 +854,76 @@ async def cluster_delslots(self, *slots: EncodableT) -> List[bool]: ) ) + @deprecated_function( + version="7.2.0", + reason="Use client-side caching feature instead.", + ) + async def client_tracking_on( + self, + clientid: Optional[int] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + target_nodes: Optional["TargetNodesT"] = "all", + ) -> ResponseT: + """ + Enables the tracking feature of the Redis server, that is used + for server assisted client side caching. + + When clientid is provided - in target_nodes only the node that owns the + connection with this id should be provided. + When clientid is not provided - target_nodes can be any node. + + For more information see https://redis.io/commands/client-tracking + """ + return await self.client_tracking( + True, + clientid, + prefix, + bcast, + optin, + optout, + noloop, + target_nodes=target_nodes, + ) + + @deprecated_function( + version="7.2.0", + reason="Use client-side caching feature instead.", + ) + async def client_tracking_off( + self, + clientid: Optional[int] = None, + prefix: Sequence[KeyT] = [], + bcast: bool = False, + optin: bool = False, + optout: bool = False, + noloop: bool = False, + target_nodes: Optional["TargetNodesT"] = "all", + ) -> ResponseT: + """ + Disables the tracking feature of the Redis server, that is used + for server assisted client side caching. + + When clientid is provided - in target_nodes only the node that owns the + connection with this id should be provided. + When clientid is not provided - target_nodes can be any node. + + For more information see https://redis.io/commands/client-tracking + """ + return await self.client_tracking( + False, + clientid, + prefix, + bcast, + optin, + optout, + noloop, + target_nodes=target_nodes, + ) + class ClusterDataAccessCommands(DataAccessCommands): """ diff --git a/redis/commands/core.py b/redis/commands/core.py index 908895a846..6e1af05635 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -685,7 +685,7 @@ def client_tracking( if noloop: pieces.append("NOLOOP") - return self.execute_command("CLIENT TRACKING", *pieces) + return self.execute_command("CLIENT TRACKING", *pieces, **kwargs) def client_trackinginfo(self, **kwargs) -> ResponseT: """ diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 759c93ffc6..a6cfcd2d94 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1742,6 +1742,30 @@ def test_client_trackinginfo(self, r): assert len(res) > 2 assert "prefixes" in res or b"prefixes" in res + @skip_if_server_version_lt("6.0.0") + @skip_if_redis_enterprise() + def test_client_tracking(self, r): + # simple case - will execute on all nodes + assert r.client_tracking_on() + assert r.client_tracking_off() + + # id based + node = r.get_default_node() + # when id is provided - the command should be sent to the node that + # owns the connection with this id + client_id = node.redis_connection.client_id() + assert r.client_tracking_on(clientid=client_id, target_nodes=node) + assert r.client_tracking_off(clientid=client_id, target_nodes=node) + + # execute with client id and prefixes and bcast + assert r.client_tracking_on( + clientid=client_id, prefix=["foo", "bar"], bcast=True, target_nodes=node + ) + + # now with some prefixes and without bcast + with pytest.raises(DataError): + assert r.client_tracking_on(prefix=["foo", "bar", "blee"]) + @skip_if_server_version_lt("2.9.50") def test_client_pause(self, r): node = r.get_primaries()[0]