Skip to content

Commit c9e7c73

Browse files
committed
RDBC-700 Subscription negotiation
1 parent 1e278bc commit c9e7c73

File tree

4 files changed

+93
-20
lines changed

4 files changed

+93
-20
lines changed

ravendb/documents/commands/subscriptions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,14 @@ def __init__(
6363
certificate: Optional[str] = None,
6464
urls: Optional[List[str]] = None,
6565
node_tag: Optional[str] = None,
66+
server_id: Optional[str] = None,
6667
):
6768
self.port = port
6869
self.url = url
6970
self.certificate = certificate
7071
self.urls = urls
7172
self.node_tag = node_tag
73+
self.server_id = server_id
7274

7375
@classmethod
7476
def from_json(cls, json_dict: Dict) -> TcpConnectionInfo:
@@ -78,6 +80,7 @@ def from_json(cls, json_dict: Dict) -> TcpConnectionInfo:
7880
json_dict.get("Certificate", None),
7981
json_dict.get("Urls", None),
8082
json_dict.get("NodeTag", None),
83+
json_dict.get("ServerId", None),
8184
)
8285

8386

ravendb/documents/subscriptions/worker.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from ravendb import constants
1616
from ravendb.documents.session.entity_to_json import EntityToJson
17-
from ravendb.documents.commands.subscriptions import GetTcpInfoForRemoteTaskCommand
17+
from ravendb.documents.commands.subscriptions import GetTcpInfoForRemoteTaskCommand, TcpConnectionInfo
1818
from ravendb.documents.session.document_session_operations.in_memory_document_session_operations import (
1919
InMemoryDocumentSessionOperations,
2020
)
@@ -319,21 +319,16 @@ def _connect_to_server(self) -> socket:
319319
except ClientVersionMismatchException:
320320
tcp_info = self._legacy_try_get_tcp_info(request_executor)
321321

322-
self._tcp_client, chosen_url = TcpUtils.connect_with_priority(
323-
tcp_info, command.result.certificate, self._store.certificate_pem_path
322+
result = TcpUtils.connect_secured_tcp_socket(
323+
tcp_info,
324+
command.result.certificate,
325+
self._store.certificate_pem_path,
326+
None,
327+
TcpConnectionHeaderMessage.OperationTypes.SUBSCRIPTION,
328+
self.__negotiate_protocol_version_for_subscription,
324329
)
325-
326-
database_name = self._store.get_effective_database(self._db_name)
327-
328-
parameters = TcpNegotiateParameters()
329-
parameters.database = database_name
330-
parameters.operation = TcpConnectionHeaderMessage.OperationTypes.SUBSCRIPTION
331-
parameters.version = TcpConnectionHeaderMessage.SUBSCRIPTION_TCP_VERSION
332-
parameters.read_response_and_get_version_callback = self._read_server_response_and_get_version
333-
parameters.destination_node_tag = self.current_node_tag
334-
parameters.destination_url = chosen_url
335-
336-
self._supported_features = TcpNegotiation.negotiate_protocol_version(self._tcp_client, parameters)
330+
self._tcp_client = result.socket
331+
self._supported_features = result.supported_features
337332

338333
if self._supported_features.protocol_version <= 0:
339334
raise RuntimeError(
@@ -363,6 +358,22 @@ def _connect_to_server(self) -> socket:
363358

364359
return self._tcp_client
365360

361+
def __negotiate_protocol_version_for_subscription(
362+
self, chosen_url: str, tcp_info: TcpConnectionInfo, s: socket
363+
) -> TcpConnectionHeaderMessage.SupportedFeatures:
364+
database_name = self._store.get_effective_database(self._db_name)
365+
366+
parameters = TcpNegotiateParameters()
367+
parameters.database = database_name
368+
parameters.operation = TcpConnectionHeaderMessage.OperationTypes.SUBSCRIPTION
369+
parameters.version = TcpConnectionHeaderMessage.SUBSCRIPTION_TCP_VERSION
370+
parameters.read_response_and_get_version_callback = self._read_server_response_and_get_version
371+
parameters.destination_node_tag = self.current_node_tag
372+
parameters.destination_url = chosen_url
373+
parameters.destination_server_id = tcp_info.server_id
374+
375+
return TcpNegotiation.negotiate_protocol_version(s, parameters)
376+
366377
def _legacy_try_get_tcp_info(self, request_executor: RequestExecutor, node: Optional[ServerNode] = None):
367378
tcp_command = GetTcpInfoCommand(f"Subscription/{self._db_name}", self._db_name)
368379

@@ -378,10 +389,10 @@ def _ensure_parser(self) -> None:
378389
# python doesn't use parsers
379390
pass
380391

381-
def _read_server_response_and_get_version(self, url: str) -> int:
392+
def _read_server_response_and_get_version(self, url: str, sock: socket) -> int:
382393
# reading reply from server
383394
self._ensure_parser()
384-
response = self._tcp_client.recv(self._options.receive_buffer_size)
395+
response = sock.recv(self._options.receive_buffer_size)
385396
reply = TcpConnectionHeaderResponse.from_json(json.loads(response.decode("utf-8")))
386397

387398
if reply.status == TcpConnectionStatus.OK:
@@ -897,6 +908,10 @@ def items(self) -> List[Item[_T]]:
897908
def number_of_items_in_batch(self) -> int:
898909
return 0 if self.items is None else len(self.items)
899910

911+
@property
912+
def number_of_includes(self) -> int:
913+
return len(self._includes) if self._includes is not None else 0
914+
900915
def open_session(self, options: Optional[SessionOptions] = None) -> DocumentSession:
901916
if not options:
902917
options = SessionOptions()

ravendb/serverwide/tcp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,8 @@ def __init__(
253253
source_node_tag: Optional[str] = None,
254254
destination_node_tag: Optional[str] = None,
255255
destination_url: Optional[str] = None,
256-
read_response_and_get_version_callback: Optional[Callable[[str], int]] = None,
256+
read_response_and_get_version_callback: Optional[Callable[[str, socket.socket], int]] = None,
257+
destination_server_id: Optional[str] = None,
257258
):
258259
self.operation = operation
259260
self.authorize_info = authorize_info
@@ -263,6 +264,7 @@ def __init__(
263264
self.destination_node_tag = destination_node_tag
264265
self.destination_url = destination_url
265266
self.read_response_and_get_version_callback = read_response_and_get_version_callback
267+
self.destination_server_id = destination_server_id
266268

267269

268270
class TcpNegotiation:
@@ -280,7 +282,7 @@ def negotiate_protocol_version(
280282
current = parameters.version
281283
while True:
282284
cls._send_tcp_version_info(sock, parameters, current)
283-
version = parameters.read_response_and_get_version_callback(parameters.destination_url)
285+
version = parameters.read_response_and_get_version_callback(parameters.destination_url, sock)
284286

285287
cls.logger.info(
286288
f"Read response from {parameters.source_node_tag or parameters.destination_url} "

ravendb/util/tcp_utils.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import base64
22
import socket
33
import ssl
4-
from typing import Tuple, Optional
4+
from typing import Tuple, Optional, Callable
55

66
from ravendb.documents.commands.subscriptions import TcpConnectionInfo
7+
from ravendb.serverwide.tcp import TcpConnectionHeaderMessage
78

89

910
class TcpUtils:
@@ -55,3 +56,55 @@ def connect_with_priority(
5556
)
5657

5758
return s, info.url
59+
60+
@staticmethod
61+
def invoke_negotiation(
62+
info: TcpConnectionInfo,
63+
operation_type: TcpConnectionHeaderMessage.OperationTypes,
64+
negotiation_callback,
65+
url: str,
66+
socket: socket.socket,
67+
) -> TcpConnectionHeaderMessage.SupportedFeatures:
68+
if operation_type == TcpConnectionHeaderMessage.OperationTypes.SUBSCRIPTION:
69+
return negotiation_callback(url, info, socket)
70+
else:
71+
raise NotImplementedError(f"Operation type '{operation_type}' not supported")
72+
73+
class ConnectSecuredTcpSocketResult:
74+
def __init__(
75+
self,
76+
url: str = None,
77+
socket: socket.socket = None,
78+
supported_features: TcpConnectionHeaderMessage.SupportedFeatures = None,
79+
):
80+
self.url = url
81+
self.socket = socket
82+
self.supported_features = supported_features
83+
84+
@staticmethod
85+
def connect_secured_tcp_socket(
86+
info: TcpConnectionInfo,
87+
server_certificate: str,
88+
client_certificate_pem_path: str,
89+
certificate_private_key_password: Optional[str],
90+
operation_type: TcpConnectionHeaderMessage.OperationTypes,
91+
negotiation_callback: Callable,
92+
) -> ConnectSecuredTcpSocketResult:
93+
if info.urls:
94+
for url in info.urls:
95+
try:
96+
s = TcpUtils.connect(
97+
url, server_certificate, client_certificate_pem_path, certificate_private_key_password
98+
)
99+
supported_features = TcpUtils.invoke_negotiation(info, operation_type, negotiation_callback, url, s)
100+
101+
return TcpUtils.ConnectSecuredTcpSocketResult(url, s, supported_features)
102+
except Exception as e:
103+
pass
104+
# ignored
105+
s = TcpUtils.connect(
106+
info.url, server_certificate, client_certificate_pem_path, certificate_private_key_password
107+
)
108+
109+
supported_features = TcpUtils.invoke_negotiation(info, operation_type, negotiation_callback, info.url, s)
110+
return TcpUtils.ConnectSecuredTcpSocketResult(info.url, s, supported_features)

0 commit comments

Comments
 (0)