diff --git a/.github/workflows/typedb.yml b/.github/workflows/typedb.yml new file mode 100644 index 0000000..b4fa7e0 --- /dev/null +++ b/.github/workflows/typedb.yml @@ -0,0 +1,44 @@ +name: LocalStack TypeDB Extension Tests + +on: + pull_request: + workflow_dispatch: + +env: + LOCALSTACK_DISABLE_EVENTS: "1" + LOCALSTACK_AUTH_TOKEN: ${{ secrets.LOCALSTACK_AUTH_TOKEN }} + +jobs: + integration-tests: + name: Run Integration Tests + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup LocalStack and extension + run: | + cd typedb + + docker pull localstack/localstack-pro & + docker pull typedb/typedb & + pip install localstack + + make install + make dist + localstack extensions -v install file://$(ls ./dist/localstack_extension_typedb-*.tar.gz) + + DEBUG=1 localstack start -d + localstack wait + + - name: Run integration tests + run: | + cd typedb + make test + + - name: Print logs + if: always() + run: | + localstack logs + localstack stop diff --git a/README.md b/README.md index f3c2a49..1ffe6fc 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ $ localstack extensions install "git+https://github.com/localstack/localstack-ex ## Official LocalStack Extensions Here is the current list of extensions developed by the LocalStack team and their support status. -You can install the respective extension by calling `localstack install `. +You can install the respective extension by calling `localstack extensions install `. | Extension | Install name | Version | Support status | |----------------------------------------------------------------------------------------------------| ------------ |---------| -------------- | @@ -75,6 +75,7 @@ You can install the respective extension by calling `localstack install list[str]: + result = DOCKER_CLIENT.inspect_image(self.DOCKER_IMAGE) + image_command = result["Config"]["Cmd"] + return image_command + + def should_proxy_request(self, headers: Headers) -> bool: + # determine if this is a gRPC request targeting TypeDB + content_type = headers.get("content-type") or "" + req_path = headers.get(":path") or "" + is_typedb_grpc_request = ( + "grpc" in content_type and "/typedb.protocol.TypeDB" in req_path + ) + return is_typedb_grpc_request + + def request_to_port_router(self, request: Request) -> int: + # TODO add REST API / gRPC routing based on request + return 1729 diff --git a/typedb/localstack_typedb/utils/__init__.py b/typedb/localstack_typedb/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/typedb/localstack_typedb/utils/docker.py b/typedb/localstack_typedb/utils/docker.py new file mode 100644 index 0000000..bac92f6 --- /dev/null +++ b/typedb/localstack_typedb/utils/docker.py @@ -0,0 +1,197 @@ +import re +import logging +from functools import cache +from typing import Callable +import requests + +from localstack import config +from localstack.config import is_env_true +from localstack_typedb.utils.h2_proxy import ( + apply_http2_patches_for_grpc_support, + ProxyRequestMatcher, +) +from localstack.utils.docker_utils import DOCKER_CLIENT +from localstack.extensions.api import Extension, http +from localstack.http import Request +from localstack.utils.container_utils.container_client import PortMappings +from localstack.utils.net import get_addressable_container_host +from localstack.utils.sync import retry +from rolo import route +from rolo.proxy import Proxy +from rolo.routing import RuleAdapter, WithHost + +LOG = logging.getLogger(__name__) 
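+# align this package's log level with LocalStack's DEBUG flag (descriptive comment for the setup below)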
+logging.getLogger("localstack_typedb").setLevel( + logging.DEBUG if config.DEBUG else logging.INFO +) +logging.basicConfig() + + +class ProxiedDockerContainerExtension(Extension, ProxyRequestMatcher): + """ + Utility class to create a LocalStack Extension backed by a Docker container that exposes a service + on a network port (or several ports), with requests being proxied through the LocalStack gateway. + + Requests may potentially use HTTP2 with binary content as the protocol (e.g., gRPC over HTTP2). + To ensure proper routing of requests, subclasses can define the `http2_ports`. + """ + + name: str + """Name of this extension""" + image_name: str + """Docker image name""" + container_name: str | None + """Name of the Docker container spun up by the extension""" + container_ports: list[int] + """List of network ports of the Docker container spun up by the extension""" + host: str | None + """ + Optional host on which to expose the container endpoints. + Can be either a static hostname, or a pattern like `myext.` + """ + path: str | None + """Optional path on which to expose the container endpoints.""" + command: list[str] | None + """Optional command (and flags) to execute in the container.""" + + request_to_port_router: Callable[[Request], int] | None + """Callable that returns the target port for a given request, for routing purposes""" + http2_ports: list[int] | None + """List of ports for which HTTP2 proxy forwarding into the container should be enabled.""" + + def __init__( + self, + image_name: str, + container_ports: list[int], + host: str | None = None, + path: str | None = None, + container_name: str | None = None, + command: list[str] | None = None, + request_to_port_router: Callable[[Request], int] | None = None, + http2_ports: list[int] | None = None, + ): + self.image_name = image_name + self.container_ports = container_ports + self.host = host + self.path = path + self.container_name = container_name + self.command = command + self.request_to_port_router = request_to_port_router + self.http2_ports = http2_ports + + def update_gateway_routes(self, router: http.Router[http.RouteHandler]): + if self.path: + raise NotImplementedError( + "Path-based routing not yet implemented for this extension" + ) + # note: for simplicity, starting the external container at startup - could be optimized over time ... 
+ self.start_container() + # add resource for HTTP/1.1 requests + resource = RuleAdapter(ProxyResource(self)) + if self.host: + resource = WithHost(self.host, [resource]) + router.add(resource) + + # apply patches to serve HTTP/2 requests + for port in self.http2_ports or []: + apply_http2_patches_for_grpc_support( + get_addressable_container_host(), port, self + ) + + def on_platform_shutdown(self): + self._remove_container() + + def _get_container_name(self) -> str: + if self.container_name: + return self.container_name + name = f"ls-ext-{self.name}" + name = re.sub(r"\W", "-", name) + return name + + @cache + def start_container(self) -> None: + container_name = self._get_container_name() + LOG.debug("Starting extension container %s", container_name) + + ports = PortMappings() + for port in self.container_ports: + ports.add(port) + + kwargs = {} + if self.command: + kwargs["command"] = self.command + + try: + DOCKER_CLIENT.run_container( + self.image_name, + detach=True, + remove=True, + name=container_name, + ports=ports, + **kwargs, + ) + except Exception as e: + LOG.debug("Failed to start container %s: %s", container_name, e) + # allow running TypeDB in a local server in dev mode, if TYPEDB_DEV_MODE is enabled + if not is_env_true("TYPEDB_DEV_MODE"): + raise + + main_port = self.container_ports[0] + container_host = get_addressable_container_host() + + def _ping_endpoint(): + # TODO: allow defining a custom healthcheck endpoint ... + response = requests.get(f"http://{container_host}:{main_port}/") + assert response.ok + + try: + retry(_ping_endpoint, retries=40, sleep=1) + except Exception as e: + LOG.info("Failed to connect to container %s: %s", container_name, e) + self._remove_container() + raise + + LOG.debug("Successfully started extension container %s", container_name) + + def _remove_container(self): + container_name = self._get_container_name() + LOG.debug("Stopping extension container %s", container_name) + DOCKER_CLIENT.remove_container( + container_name, force=True, check_existence=False + ) + + +class ProxyResource: + """ + Simple proxy resource that forwards incoming requests from the + LocalStack Gateway to the target Docker container. 
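+
+    For example, a request arriving at the LocalStack gateway (port 4566 by default), such as
+    ``GET http://<extension-host>:4566/some/path``, is forwarded to
+    ``http://<container-host>:<first-container-port>/some/path``, with the ``Host`` and
+    ``Content-Length`` headers adjusted before forwarding.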
+ """ + + extension: ProxiedDockerContainerExtension + + def __init__(self, extension: ProxiedDockerContainerExtension): + self.extension = extension + + @route("/") + def index(self, request: Request, path: str, *args, **kwargs): + return self._proxy_request(request, forward_path=f"/{path}") + + def _proxy_request(self, request: Request, forward_path: str, *args, **kwargs): + self.extension.start_container() + + port = self.extension.container_ports[0] + container_host = get_addressable_container_host() + base_url = f"http://{container_host}:{port}" + proxy = Proxy(forward_base_url=base_url) + + # update content length (may have changed due to content compression) + if request.method not in ("GET", "OPTIONS"): + request.headers["Content-Length"] = str(len(request.data)) + + # make sure we're forwarding the correct Host header + request.headers["Host"] = f"localhost:{port}" + + # forward the request to the target + result = proxy.forward(request, forward_path=forward_path) + + return result diff --git a/typedb/localstack_typedb/utils/h2_proxy.py b/typedb/localstack_typedb/utils/h2_proxy.py new file mode 100644 index 0000000..2beccca --- /dev/null +++ b/typedb/localstack_typedb/utils/h2_proxy.py @@ -0,0 +1,166 @@ +import logging +import socket +from abc import abstractmethod + +from h2.frame_buffer import FrameBuffer +from hpack import Decoder +from hyperframe.frame import HeadersFrame, Frame +from twisted.internet import reactor + +from localstack.utils.patch import patch +from twisted.web._http2 import H2Connection +from werkzeug.datastructures import Headers + +LOG = logging.getLogger(__name__) + + +class ProxyRequestMatcher: + """ + Abstract base class that defines a request matcher, for an extension to define which incoming + request messages should be proxied to an upstream target (and which ones shouldn't). + """ + + @abstractmethod + def should_proxy_request(self, headers: Headers) -> bool: + """Define whether a request should be proxied, based on request headers.""" + + +class TcpForwarder: + """Simple helper class for bidirectional forwarding of TPC traffic.""" + + buffer_size: int = 1024 + """Data buffer size for receiving data from upstream socket.""" + + def __init__(self, port: int, host: str = "localhost"): + self.port = port + self.host = host + self._socket = None + self.connect() + + def connect(self): + if not self._socket: + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket.connect((self.host, self.port)) + + def receive_loop(self, callback): + while True: + data = self._socket.recv(self.buffer_size) + callback(data) + if not data: + break + + def send(self, data): + self._socket.sendall(data) + + def close(self): + LOG.debug("Closing connection to upstream HTTP2 server on port %s", self.port) + try: + self._socket.shutdown(socket.SHUT_RDWR) + self._socket.close() + except Exception: + # swallow exceptions here (e.g., "bad file descriptor") + pass + + +def apply_http2_patches_for_grpc_support( + target_host: str, target_port: int, request_matcher: ProxyRequestMatcher +): + """ + Apply some patches to proxy incoming gRPC requests and forward them to a target port. + Note: this is a very brute-force approach and needs to be fixed/enhanced over time! 
+ """ + + @patch(H2Connection.connectionMade) + def _connectionMade(fn, self, *args, **kwargs): + def _process(data): + LOG.debug("Received data (%s bytes) from upstream HTTP2 server", len(data)) + self.transport.write(data) + + # TODO: make port configurable + self._ls_forwarder = TcpForwarder(target_port, host=target_host) + LOG.debug( + "Starting TCP forwarder to port %s for new HTTP2 connection", target_port + ) + reactor.getThreadPool().callInThread(self._ls_forwarder.receive_loop, _process) + + @patch(H2Connection.dataReceived) + def _dataReceived(fn, self, data, *args, **kwargs): + forwarder = getattr(self, "_ls_forwarder", None) + should_proxy_request = getattr(self, "_ls_should_proxy_request", None) + if not forwarder or should_proxy_request is False: + return fn(self, data, *args, **kwargs) + + if should_proxy_request: + forwarder.send(data) + return + + setattr(self, "_data_received", getattr(self, "_data_received", [])) + self._data_received.append(data) + + # parse headers from request frames received so far + headers = get_headers_from_data_stream(self._data_received) + if not headers: + # if no headers received yet, then return (method will be called again for next chunk of data) + return + + # check if the incoming request should be proxies, based on the request headers + self._ls_should_proxy_request = request_matcher.should_proxy_request(headers) + + if not self._ls_should_proxy_request: + # if this is not a target request, then call the upstream function + result = None + for chunk in self._data_received: + result = fn(self, chunk, *args, **kwargs) + self._data_received = [] + return result + + # forward data chunks to the target + for chunk in self._data_received: + LOG.debug( + "Forwarding data (%s bytes) from HTTP2 client to server", len(chunk) + ) + forwarder.send(chunk) + self._data_received = [] + + @patch(H2Connection.connectionLost) + def connectionLost(fn, self, *args, **kwargs): + forwarder = getattr(self, "_ls_forwarder", None) + if not forwarder: + return fn(self, *args, **kwargs) + forwarder.close() + + +def get_headers_from_data_stream(data_list: list[bytes]) -> Headers: + """Get headers from a data stream (list of bytes data), if any headers are contained.""" + data_combined = b"".join(data_list) + frames = parse_http2_stream(data_combined) + headers = get_headers_from_frames(frames) + return headers + + +def get_headers_from_frames(frames: list[Frame]) -> Headers: + """Parse the given list of HTTP2 frames and return a dict of headers, if any""" + result = {} + decoder = Decoder() + for frame in frames: + if isinstance(frame, HeadersFrame): + try: + headers = decoder.decode(frame.data) + result.update(dict(headers)) + except Exception: + pass + return Headers(result) + + +def parse_http2_stream(data: bytes) -> list[Frame]: + """Parse the data from an HTTP2 stream into a list of frames""" + frames = [] + buffer = FrameBuffer(server=True) + buffer.max_frame_size = 16384 + buffer.add_data(data) + try: + for frame in buffer: + frames.append(frame) + except Exception: + pass + return frames diff --git a/typedb/pyproject.toml b/typedb/pyproject.toml new file mode 100644 index 0000000..71e348f --- /dev/null +++ b/typedb/pyproject.toml @@ -0,0 +1,39 @@ +[build-system] +requires = ["setuptools", "wheel", "plux>=1.3.1"] +build-backend = "setuptools.build_meta" + +[project] +name = "localstack-extension-typedb" +version = "0.1.0" +description = "LocalStack Extension: TypeDB on LocalStack" +readme = {file = "README.md", content-type = "text/markdown; charset=UTF-8"} 
+requires-python = ">=3.9" +authors = [ + { name = "LocalStack + TypeDB team"} +] +keywords = ["LocalStack", "TypeDB"] +classifiers = [] +dependencies = [ + "httpx", + "h2", + "priority", +] + +[project.urls] +Homepage = "https://github.com/whummer/localstack-utils" + +[project.optional-dependencies] +dev = [ + "boto3", + "build", + "jsonpatch", + "localstack", + "pytest", + "rolo", + "ruff", + "twisted", + "typedb-driver", +] + +[project.entry-points."localstack.extensions"] +localstack_typedb = "localstack_typedb.extension:TypeDbExtension" diff --git a/typedb/tests/test_extension.py b/typedb/tests/test_extension.py new file mode 100644 index 0000000..4d8c3bc --- /dev/null +++ b/typedb/tests/test_extension.py @@ -0,0 +1,84 @@ +import requests +from localstack.utils.strings import short_uid +from localstack_typedb.utils.h2_proxy import parse_http2_stream, get_headers_from_frames +from typedb.driver import TypeDB, Credentials, DriverOptions, TransactionType + + +def test_connect_to_db_via_http_api(): + host = "typedb.localhost.localstack.cloud:4566" + + # get auth token + response = requests.post( + f"http://{host}/v1/signin", json={"username": "admin", "password": "password"} + ) + assert response.ok + token = response.json()["token"] + + # create database + db_name = f"db{short_uid()}" + response = requests.post( + f"http://{host}/v1/databases/{db_name}", + json={}, + headers={"Authorization": f"bearer {token}"}, + ) + assert response.ok + + # list databases + response = requests.get( + f"http://{host}/v1/databases", headers={"Authorization": f"bearer {token}"} + ) + assert response.ok + databases = [db["name"] for db in response.json()["databases"]] + assert db_name in databases + + # clean up + response = requests.delete( + f"http://{host}/v1/databases/{db_name}", + headers={"Authorization": f"bearer {token}"}, + ) + assert response.ok + + +def test_connect_to_db_via_grpc_endpoint(): + db_name = "access-management-db" + server_host = "typedb.localhost.localstack.cloud:4566" + + driver_cfg = TypeDB.driver( + server_host, + Credentials("admin", "password"), + DriverOptions(is_tls_enabled=False), + ) + with driver_cfg as driver: + if driver.databases.contains(db_name): + driver.databases.get(db_name).delete() + driver.databases.create(db_name) + + with driver.transaction(db_name, TransactionType.SCHEMA) as tx: + tx.query("define entity person;").resolve() + tx.query("define attribute name, value string; person owns name;").resolve() + tx.commit() + + with driver.transaction(db_name, TransactionType.WRITE) as tx: + tx.query("insert $p isa person, has name 'Alice';").resolve() + tx.query("insert $p isa person, has name 'Bob';").resolve() + tx.commit() + with driver.transaction(db_name, TransactionType.READ) as tx: + results = tx.query( + 'match $p isa person; fetch {"name": $p.name};' + ).resolve() + for json in results: + print(json) + + +def test_parse_http2_frames(): + # note: the data below is a dump taken from a browser request made against the emulator + data = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n\x00\x00\x18\x04\x00\x00\x00\x00\x00\x00\x01\x00\x01\x00\x00\x00\x02\x00\x00\x00\x00\x00\x04\x00\x02\x00\x00\x00\x05\x00\x00@\x00\x00\x00\x04\x08\x00\x00\x00\x00\x00\x00\xbf\x00\x01" + data += 
b"\x00\x01V\x01%\x00\x00\x00\x03\x00\x00\x00\x00\x15C\x87\xd5\xaf~MZw\x7f\x05\x8eb*\x0eA\xd0\x84\x8c\x9dX\x9c\xa3\xa13\xffA\x96\xa0\xe4\x1d\x13\x9d\t^\x83\x90t!#'U\xc9A\xed\x92\xe3M\xb8\xe7\x87z\xbe\xd0\x7ff\xa2\x81\xb0\xda\xe0S\xfa\xd02\x1a\xa4\x9d\x13\xfd\xa9\x92\xa4\x96\x854\x0c\x8aj\xdc\xa7\xe2\x81\x02\xe1o\xedK;\xdc\x0bM.\x0f\xedLE'S\xb0 \x04\x00\x08\x02\xa6\x13XYO\xe5\x80\xb4\xd2\xe0S\x83\xf9c\xe7Q\x8b-Kp\xdd\xf4Z\xbe\xfb@\x05\xdbP\x92\x9b\xd9\xab\xfaRB\xcb@\xd2_\xa5#\xb3\xe9OhL\x9f@\x94\x19\x08T!b\x1e\xa4\xd8z\x16\xb0\xbd\xad*\x12\xb5%L\xe7\x93\x83\xc5\x83\x7f@\x95\x19\x08T!b\x1e\xa4\xd8z\x16\xb0\xbd\xad*\x12\xb4\xe5\x1c\x85\xb1\x1f\x89\x1d\xa9\x9c\xf6\x1b\xd8\xd2c\xd5s\x95\x9d)\xad\x17\x18`u\xd6\xbd\x07 \xe8BFN\xab\x92\x83\xdb#\x1f@\x85=\x86\x98\xd5\x7f\x94\x9d)\xad\x17\x18`u\xd6\xbd\x07 \xe8BFN\xab\x92\x83\xdb'@\x8aAH\xb4\xa5I'ZB\xa1?\x84-5\xa7\xd7@\x8aAH\xb4\xa5I'Z\x93\xc8_\x83!\xecG@\x8aAH\xb4\xa5I'Y\x06I\x7f\x86@\xe9*\xc82K@\x86\xae\xc3\x1e\xc3'\xd7\x83\xb6\x06\xbf@\x82I\x7f\x86M\x835\x05\xb1\x1f\x00\x00\x04\x08\x00\x00\x00\x00\x03\x00\xbe\x00\x00" + + frames = parse_http2_stream(data) + assert frames + headers = get_headers_from_frames(frames) + assert headers + assert headers[":scheme"] == "https" + assert headers[":method"] == "OPTIONS" + assert headers[":path"] == "/_localstack/health"