diff --git a/docs/datasources/robinhood.md b/docs/datasources/robinhood.md new file mode 100644 index 0000000..ccfb33e --- /dev/null +++ b/docs/datasources/robinhood.md @@ -0,0 +1,6 @@ +# RobinhoodDataSource + +> Requires the [`pynacl`](https://github.com/pyca/pynacl) library for cryptographic signing. You can install it manually: `pip install pynacl` +> or use `pip install pyspark-data-sources[robinhood]`. + +::: pyspark_datasources.robinhood.RobinhoodDataSource diff --git a/docs/index.md b/docs/index.md index 9aa6888..ee7983b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -42,4 +42,5 @@ spark.readStream.format("fake").load().writeStream.format("console").start() | [GoogleSheetsDataSource](./datasources/googlesheets.md) | `googlesheets` | Read table from public Google Sheets document | None | | [KaggleDataSource](./datasources/kaggle.md) | `kaggle` | Read datasets from Kaggle | `kagglehub`, `pandas` | | [JSONPlaceHolder](./datasources/jsonplaceholder.md) | `jsonplaceholder` | Read JSON data for testing and prototyping | None | -| [SalesforceDataSource](./datasources/salesforce.md) | `salesforce` | Write streaming data to Salesforce objects |`simple-salesforce` | +| [RobinhoodDataSource](./datasources/robinhood.md) | `robinhood` | Read cryptocurrency market data from Robinhood API | `pynacl` | +| [SalesforceDataSource](./datasources/salesforce.md) | `salesforce` | Write streaming data to Salesforce objects |`simple-salesforce` | \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 16b3ceb..6d0d8c0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -27,6 +27,7 @@ nav: - datasources/googlesheets.md - datasources/kaggle.md - datasources/jsonplaceholder.md + - datasources/robinhood.md markdown_extensions: - pymdownx.highlight: diff --git a/poetry.lock b/poetry.lock index 56b452f..61693c2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. 
+# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -234,7 +234,7 @@ description = "Foreign Function Interface for Python calling C code." optional = true python-versions = ">=3.8" groups = ["main"] -markers = "(extra == \"salesforce\" or extra == \"all\") and platform_python_implementation != \"PyPy\"" +markers = "(extra == \"salesforce\" or extra == \"all\") and platform_python_implementation != \"PyPy\" or extra == \"robinhood\" or extra == \"all\"" files = [ {file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"}, {file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"}, @@ -2162,7 +2162,7 @@ description = "C parser in Python" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "(extra == \"salesforce\" or extra == \"all\") and platform_python_implementation != \"PyPy\"" +markers = "(extra == \"salesforce\" or extra == \"all\") and platform_python_implementation != \"PyPy\" or extra == \"robinhood\" or extra == \"all\"" files = [ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, @@ -2224,6 +2224,34 @@ pyyaml = "*" [package.extras] extra = ["pygments (>=2.19.1)"] +[[package]] +name = "pynacl" +version = "1.5.0" +description = "Python binding to the Networking and Cryptography (NaCl) library" +optional = true +python-versions = ">=3.6" +groups = ["main"] +markers = "extra == \"robinhood\" or extra == \"all\"" +files = [ + {file = "PyNaCl-1.5.0-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:401002a4aaa07c9414132aaed7f6836ff98f59277a234704ff66878c2ee4a0d1"}, + {file = 
"PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:52cb72a79269189d4e0dc537556f4740f7f0a9ec41c1322598799b0bdad4ef92"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a36d4a9dda1f19ce6e03c9a784a2921a4b726b02e1c736600ca9c22029474394"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:0c84947a22519e013607c9be43706dd42513f9e6ae5d39d3613ca1e142fba44d"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06b8f6fa7f5de8d5d2f7573fe8c863c051225a27b61e6860fd047b1775807858"}, + {file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a422368fc821589c228f4c49438a368831cb5bbc0eab5ebe1d7fac9dded6567b"}, + {file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:61f642bf2378713e2c2e1de73444a3778e5f0a38be6fee0fe532fe30060282ff"}, + {file = "PyNaCl-1.5.0-cp36-abi3-win32.whl", hash = "sha256:e46dae94e34b085175f8abb3b0aaa7da40767865ac82c928eeb9e57e1ea8a543"}, + {file = "PyNaCl-1.5.0-cp36-abi3-win_amd64.whl", hash = "sha256:20f42270d27e1b6a29f54032090b972d97f0a1b0948cc52392041ef7831fee93"}, + {file = "PyNaCl-1.5.0.tar.gz", hash = "sha256:8ac7448f09ab85811607bdd21ec2464495ac8b7c66d146bf545b0f08fb9220ba"}, +] + +[package.dependencies] +cffi = ">=1.4.1" + +[package.extras] +docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"] +tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"] + [[package]] name = "pyspark" version = "4.0.0" @@ -2957,15 +2985,16 @@ test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.funct type = ["pytest-mypy"] [extras] -all = ["databricks-sdk", "datasets", "faker", "kagglehub", "simple-salesforce"] +all = ["databricks-sdk", "datasets", "faker", "kagglehub", "pynacl", "simple-salesforce"] databricks = ["databricks-sdk"] datasets = ["datasets"] faker = ["faker"] kaggle = 
["kagglehub"]
 lance = []
+robinhood = ["pynacl"]
 salesforce = ["simple-salesforce"]
 
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.9,<3.13"
-content-hash = "fe9bd32c2d9b9cb23070941a474fb6f409c16ebf1b2aa305f7dfdc08b7b4d290"
+content-hash = "c1d4d9a66408c793a3f2b1b55185856de39e332a4315fdcffc1a55c60e647231"
diff --git a/pyproject.toml b/pyproject.toml
index 4d16967..81bcdc1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ datasets = {version = "^2.17.0", optional = true}
 databricks-sdk = {version = "^0.28.0", optional = true}
 kagglehub = {extras = ["pandas-datasets"], version = "^0.3.10", optional = true}
 simple-salesforce = {version = "^1.12.0", optional = true}
+pynacl = {version = "^1.5.0", optional = true}
 
 [tool.poetry.extras]
 faker = ["faker"]
@@ -26,8 +27,9 @@ datasets = ["datasets"]
 databricks = ["databricks-sdk"]
 kaggle = ["kagglehub"]
 lance = ["pylance"]
+robinhood = ["pynacl"]
 salesforce = ["simple-salesforce"]
-all = ["faker", "datasets", "databricks-sdk", "kagglehub", "simple-salesforce"]
+all = ["faker", "datasets", "databricks-sdk", "kagglehub", "pynacl", "simple-salesforce"]
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.0.0"
diff --git a/pyspark_datasources/__init__.py b/pyspark_datasources/__init__.py
index e1d1f18..75016e5 100644
--- a/pyspark_datasources/__init__.py
+++ b/pyspark_datasources/__init__.py
@@ -5,6 +5,7 @@
 from .huggingface import HuggingFaceDatasets
 from .kaggle import KaggleDataSource
 from .opensky import OpenSkyDataSource
+from .robinhood import RobinhoodDataSource
 from .salesforce import SalesforceDataSource
 from .simplejson import SimpleJsonDataSource
 from .stock import StockDataSource
diff --git a/pyspark_datasources/robinhood.py b/pyspark_datasources/robinhood.py
new file mode 100644
index 0000000..0f08842
--- /dev/null
+++ b/pyspark_datasources/robinhood.py
@@ -0,0 +1,290 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Generator, Union
+import requests
+import json
+import base64
+import datetime
+import logging
+from urllib.parse import quote, urlencode
+
+from pyspark.sql import Row
+from pyspark.sql.types import StructType
+from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class CryptoPair(InputPartition):
+    """Represents a single crypto trading pair partition for parallel processing."""
+
+    symbol: str
+
+
+class RobinhoodDataReader(DataSourceReader):
+    """Reader implementation for Robinhood Crypto API data source."""
+
+    def __init__(self, schema: StructType, options: Dict[str, str]) -> None:
+        """
+        Validate the credentials and build the Ed25519 signing key.
+
+        Raises ValueError when 'api_key' or 'private_key' is missing or the key
+        cannot be loaded, and ImportError when PyNaCl is not installed.
+        """
+        self.schema = schema
+        self.options = options
+
+        # Required API authentication
+        self.api_key = options.get("api_key")
+        self.private_key_base64 = options.get("private_key")
+
+        if not self.api_key or not self.private_key_base64:
+            raise ValueError(
+                "Robinhood Crypto API requires both 'api_key' and 'private_key' options. "
+                "The private_key should be base64-encoded. "
+                "Get your API credentials from https://docs.robinhood.com/crypto/trading/"
+            )
+
+        # PyNaCl is an optional extra, so import it lazily and report its absence
+        # separately from a malformed key
+        try:
+            from nacl.signing import SigningKey
+        except ImportError as e:
+            raise ImportError(
+                "PyNaCl library is required for Robinhood Crypto API authentication. "
+                "Install it with: pip install pynacl"
+            ) from e
+
+        # Initialize NaCl signing key
+        try:
+            private_key_seed = base64.b64decode(self.private_key_base64)
+            self.signing_key = SigningKey(private_key_seed)
+        except Exception as e:
+            raise ValueError(f"Invalid private key format: {str(e)}") from e
+
+        # Crypto API base URL (configurable for testing)
+        self.base_url = options.get("base_url", "https://trading.robinhood.com")
+
+    def _get_current_timestamp(self) -> int:
+        """Get current UTC timestamp."""
+        return int(datetime.datetime.now(tz=datetime.timezone.utc).timestamp())
+
+    def _generate_signature(self, timestamp: int, method: str, path: str, body: str = "") -> str:
+        """
+        Generate NaCl signature for API authentication following Robinhood's specification.
+
+        The signed message is f"{api_key}{timestamp}{path}{method}{body}". An empty
+        body contributes nothing to it, so bodiless GET requests need no special case.
+        Returns the base64-encoded signature.
+        """
+        message_to_sign = f"{self.api_key}{timestamp}{path}{method.upper()}{body}"
+        signed = self.signing_key.sign(message_to_sign.encode("utf-8"))
+        return base64.b64encode(signed.signature).decode("utf-8")
+
+    def _make_authenticated_request(
+        self,
+        method: str,
+        path: str,
+        params: Optional[Dict[str, str]] = None,
+        json_data: Optional[Dict] = None,
+    ) -> Optional[Dict]:
+        """
+        Make an authenticated request to the Robinhood Crypto API.
+
+        Returns the decoded JSON response, or None when requests raises a
+        RequestException (the failure is logged instead of propagated).
+        """
+        verb = method.upper()
+        timestamp = self._get_current_timestamp()
+
+        # The signature covers the path including its query string, so fold
+        # ``params`` into the path up front rather than letting requests append
+        # them to the URL after signing.
+        if params:
+            separator = "&" if "?" in path else "?"
+            path = f"{path}{separator}{urlencode(params)}"
+        url = self.base_url + path
+
+        # Prepare request body for signature (only for non-GET requests)
+        body = ""
+        if verb != "GET" and json_data:
+            body = json.dumps(json_data, separators=(",", ":"))  # Compact JSON format
+
+        # Generate signature
+        signature = self._generate_signature(timestamp, verb, path, body)
+
+        # Set authentication headers
+        headers = {
+            "x-api-key": self.api_key,
+            "x-signature": signature,
+            "x-timestamp": str(timestamp),
+        }
+        if body:
+            headers["Content-Type"] = "application/json"
+
+        try:
+            # Send exactly the bytes that were signed; json=... would have requests
+            # re-serialize the payload with its own separators.
+            response = requests.request(
+                verb,
+                url,
+                headers=headers,
+                data=body.encode("utf-8") if body else None,
+                timeout=10,
+            )
+            response.raise_for_status()
+            return response.json()
+        except requests.RequestException as e:
+            logger.warning("Error making API request to %s: %s", path, e)
+            return None
+
+    @staticmethod
+    def _get_query_params(key: str, *args: str) -> str:
+        """
+        Build a query string that repeats ``key`` once per non-empty value.
+
+        Values are percent-encoded. Returns "" when no non-empty value is given,
+        e.g. ("symbol", "BTC-USD", "ETH-USD") -> "?symbol=BTC-USD&symbol=ETH-USD".
+        """
+        params = [f"{key}={quote(arg, safe='')}" for arg in args if arg]
+        return "?" + "&".join(params) if params else ""
+
+    def partitions(self) -> List[CryptoPair]:
+        """
+        Create partitions for parallel processing of crypto pairs.
+
+        Raises ValueError when the load path names no trading pair.
+        """
+        # Use specified symbols from path
+        symbols_str = self.options.get("path", "")
+
+        # Split symbols by comma, dropping blanks left by stray commas
+        symbols = [symbol.strip().upper() for symbol in symbols_str.split(",")]
+        # Ensure proper format (e.g., BTC-USD); bare symbols default to the USD pair
+        formatted_symbols = [
+            symbol if "-" in symbol else f"{symbol}-USD" for symbol in symbols if symbol
+        ]
+        if not formatted_symbols:
+            raise ValueError("Must specify crypto pairs to load using .load('BTC-USD,ETH-USD')")
+
+        return [CryptoPair(symbol=symbol) for symbol in formatted_symbols]
+
+    def read(self, partition: CryptoPair) -> Generator[Row, None, None]:
+        """Read crypto data for a single trading pair partition."""
+        symbol = partition.symbol
+
+        try:
+            yield from self._read_crypto_pair_data(symbol)
+        except Exception as e:
+            # Log error but don't fail the entire job
+            logger.warning("Failed to fetch data for %s: %s", symbol, e)
+
+    @staticmethod
+    def _safe_float(value: Union[str, int, float, None], default: float = 0.0) -> float:
+        """Convert an API value to float, using ``default`` when missing or malformed."""
+        if value is None or value == "":
+            return default
+        try:
+            return float(value)
+        except (ValueError, TypeError):
+            return default
+
+    def _read_crypto_pair_data(self, symbol: str) -> Generator[Row, None, None]:
+        """
+        Fetch cryptocurrency market data for a given trading pair.
+
+        Yields one Row per entry in the response's "results" list. When the request
+        failed or the response has no "results", a warning is logged and nothing
+        is yielded. Errors raised while parsing are left to read() to handle.
+        """
+        # Get best bid/ask data for the trading pair using query parameters
+        query = self._get_query_params("symbol", symbol)
+        path = f"/api/v1/crypto/marketdata/best_bid_ask/{query}"
+        market_data = self._make_authenticated_request("GET", path)
+
+        if not market_data or "results" not in market_data:
+            logger.warning("No market data found for %s", symbol)
+            return
+
+        for quote_data in market_data["results"]:
+            # Extract market data fields from best bid/ask response
+            yield Row(
+                symbol=symbol,
+                price=self._safe_float(quote_data.get("price")),
+                bid_price=self._safe_float(quote_data.get("bid_inclusive_of_sell_spread")),
+                ask_price=self._safe_float(quote_data.get("ask_inclusive_of_buy_spread")),
+                updated_at=quote_data.get("timestamp", ""),
+            )
+
+
+class RobinhoodDataSource(DataSource):
+    """
+    A data source for reading cryptocurrency data from Robinhood Crypto API.
+
+    This data source allows you to fetch real-time cryptocurrency market data,
+    trading pairs, and price information using Robinhood's official Crypto API.
+    It implements proper API key authentication and signature-based security.
+
+    Name: `robinhood`
+
+    Schema: `symbol string, price double, bid_price double, ask_price double, updated_at string`
+
+    Examples
+    --------
+    Register the data source:
+
+    >>> from pyspark_datasources import RobinhoodDataSource
+    >>> spark.dataSource.register(RobinhoodDataSource)
+
+    Load cryptocurrency market data with API authentication:
+
+    >>> df = spark.read.format("robinhood") \\
+    ...     .option("api_key", "your-api-key") \\
+    ...     .option("private_key", "your-base64-private-key") \\
+    ...     .load("BTC-USD,ETH-USD,DOGE-USD")
+    >>> df.show()
+    +--------+--------+---------+---------+--------------------+
+    |  symbol|   price|bid_price|ask_price|          updated_at|
+    +--------+--------+---------+---------+--------------------+
+    |BTC-USD |45000.50|45000.25 |45000.75 |2024-01-15T16:00:...|
+    |ETH-USD | 2650.75| 2650.50 | 2651.00 |2024-01-15T16:00:...|
+    |DOGE-USD|   0.085|    0.084|    0.086|2024-01-15T16:00:...|
+    +--------+--------+---------+---------+--------------------+
+
+    Options
+    -------
+    - api_key: string (required) — Robinhood Crypto API key.
+    - private_key: string (required) — Base64-encoded Ed25519 private key seed.
+    - base_url: string (optional, default "https://trading.robinhood.com") — Override for sandbox/testing.
+
+    Errors
+    ------
+    - Raises ValueError when required options are missing, private_key is invalid, or no trading pair is given.
+    - Network/API errors are logged and skipped per symbol; no rows are emitted for failed symbols.
+
+    Partitioning
+    ------------
+    - One partition per requested trading pair (e.g., "BTC-USD,ETH-USD"). Symbols are uppercased and auto-appended with "-USD" if missing pair format.
+
+    Arrow
+    -----
+    - Rows are yielded directly; Arrow-based batches can be added in future for improved performance.
+
+    Notes
+    -----
+    - Requires 'pynacl' for Ed25519 signing: pip install pynacl
+    - Refer to official Robinhood documentation for authentication details.
+    """
+
+    @classmethod
+    def name(cls) -> str:
+        """Return the short name used with spark.read.format()."""
+        return "robinhood"
+
+    def schema(self) -> str:
+        """Return the fixed DDL schema of the rows produced by the reader."""
+        return "symbol string, price double, bid_price double, ask_price double, updated_at string"
+
+    def reader(self, schema: StructType) -> RobinhoodDataReader:
+        """Create the reader, passing through the options given to the data source."""
+        return RobinhoodDataReader(schema, self.options)
diff --git a/tests/test_robinhood.py b/tests/test_robinhood.py
new file mode 100644
index 0000000..fdb0a21
--- /dev/null
+++ b/tests/test_robinhood.py
@@ -0,0 +1,178 @@
+import os
+import pytest
+from unittest.mock import Mock, patch
+from pyspark.sql import SparkSession, Row
+from pyspark.errors.exceptions.captured import AnalysisException
+
+from pyspark_datasources import RobinhoodDataSource
+from pyspark_datasources.robinhood import RobinhoodDataReader
+
+
+@pytest.fixture
+def spark():
+    """Create SparkSession for testing."""
+    spark = SparkSession.builder.getOrCreate()
+    spark.dataSource.register(RobinhoodDataSource)
+    yield spark
+
+
+def test_robinhood_datasource_registration(spark):
+    """Test that RobinhoodDataSource can be registered."""
+    # Test registration
+    assert RobinhoodDataSource.name() == "robinhood"
+
+    # Test schema
+    expected_schema = (
+        "symbol string, price double, bid_price double, ask_price double, updated_at string"
+    )
+    datasource = RobinhoodDataSource({})
+    assert datasource.schema() == expected_schema
+
+
+def test_robinhood_query_params():
+    """Test query-string building without touching Spark or the network."""
+    build = RobinhoodDataReader._get_query_params
+    assert build("symbol", "BTC-USD", "ETH-USD") == "?symbol=BTC-USD&symbol=ETH-USD"
+    assert build("symbol", "BTC USD") == "?symbol=BTC%20USD"
+    assert build("symbol") == ""
+    assert build("symbol", "", "") == ""
+
+
+def test_robinhood_missing_credentials(spark):
+    """Test that missing API credentials raises an error."""
+    with pytest.raises(AnalysisException) as excinfo:
+        df = spark.read.format("robinhood").load("BTC-USD")
+        df.collect()  # Trigger execution
+
+    assert "ValueError" in str(excinfo.value) and (
+        "api_key" in str(excinfo.value) or "private_key" in str(excinfo.value)
+    )
+
+
+def test_robinhood_missing_symbols(spark):
+    """Test that missing symbols raises an error."""
+    with pytest.raises(AnalysisException) as excinfo:
+        df = (
+            spark.read.format("robinhood")
+            .option("api_key", "test-key")
+            .option("private_key", "FAPmPMsQqDFOFiRvpUMJ6BC5eFOh/tPx7qcTYGKc8nE=")
+            .load("")
+        )
+        df.collect()  # Trigger execution
+
+    assert "ValueError" in str(excinfo.value) and "crypto pairs" in str(excinfo.value)
+
+
+def test_robinhood_invalid_private_key_format(spark):
+    """Test that invalid private key format raises proper error."""
+    with pytest.raises(AnalysisException) as excinfo:
+        df = (
+            spark.read.format("robinhood")
+            .option("api_key", "test-key")
+            .option("private_key", "invalid-key-format")
+            .load("BTC-USD")
+        )
+        df.collect()  # Trigger execution
+
+    assert "Invalid private key format" in str(excinfo.value)
+
+
+def test_robinhood_btc_data(spark):
+    """Test BTC-USD data retrieval with registered API key - REQUIRES API CREDENTIALS."""
+    # Get credentials from environment variables
+    api_key = os.environ.get("ROBINHOOD_API_KEY")
+    private_key = os.environ.get("ROBINHOOD_PRIVATE_KEY")
+
+    if not api_key or not private_key:
+        pytest.skip(
+            "ROBINHOOD_API_KEY and ROBINHOOD_PRIVATE_KEY environment variables required for real API tests"
+        )
+
+    # Test loading BTC-USD data
+    df = (
+        spark.read.format("robinhood")
+        .option("api_key", api_key)
+        .option("private_key", private_key)
+        .load("BTC-USD")
+    )
+
+    rows = df.collect()
+    print(f"Retrieved {len(rows)} rows")
+
+    # CRITICAL: Test MUST fail if no data is returned
+    assert len(rows) > 0, "TEST FAILED: No data returned! Expected at least 1 BTC-USD record."
+
+    for i, row in enumerate(rows):
+        print(f"Row {i + 1}: {row}")
+
+        # Validate data structure
+        assert row.symbol == "BTC-USD", f"Expected BTC-USD, got {row.symbol}"
+        assert isinstance(row.price, (int, float)), (
+            f"Price should be numeric, got {type(row.price)}"
+        )
+        assert row.price > 0, f"Price should be > 0, got {row.price}"
+        assert isinstance(row.bid_price, (int, float)), (
+            f"Bid price should be numeric, got {type(row.bid_price)}"
+        )
+        assert isinstance(row.ask_price, (int, float)), (
+            f"Ask price should be numeric, got {type(row.ask_price)}"
+        )
+        assert isinstance(row.updated_at, str), (
+            f"Updated timestamp should be string, got {type(row.updated_at)}"
+        )
+
+
+def test_robinhood_multiple_crypto_pairs(spark):
+    """Test multi-crypto data retrieval with registered API key - REQUIRES API CREDENTIALS."""
+    # Get credentials from environment variables
+    api_key = os.environ.get("ROBINHOOD_API_KEY")
+    private_key = os.environ.get("ROBINHOOD_PRIVATE_KEY")
+
+    if not api_key or not private_key:
+        pytest.skip(
+            "ROBINHOOD_API_KEY and ROBINHOOD_PRIVATE_KEY environment variables required for real API tests"
+        )
+
+    # Test loading multiple crypto pairs
+    df = (
+        spark.read.format("robinhood")
+        .option("api_key", api_key)
+        .option("private_key", private_key)
+        .load("BTC-USD,ETH-USD,DOGE-USD")
+    )
+
+    rows = df.collect()
+    print(f"Retrieved {len(rows)} rows")
+
+    # CRITICAL: Test MUST fail if no data is returned
+    assert len(rows) > 0, "TEST FAILED: No data returned! Expected at least 1 crypto record."
+
+    # CRITICAL: Should get data for all 3 requested pairs
+    assert len(rows) >= 3, f"TEST FAILED: Expected 3 crypto pairs, got {len(rows)} records."
+
+    symbols_found = set()
+
+    for i, row in enumerate(rows):
+        symbols_found.add(row.symbol)
+        print(f"Row {i + 1}: {row}")
+
+        # Validate each record
+        assert isinstance(row.symbol, str), f"Symbol should be string, got {type(row.symbol)}"
+        assert isinstance(row.price, (int, float)), (
+            f"Price should be numeric, got {type(row.price)}"
+        )
+        assert row.price > 0, f"Price should be > 0, got {row.price}"
+        assert isinstance(row.bid_price, (int, float)), (
+            f"Bid price should be numeric, got {type(row.bid_price)}"
+        )
+        assert isinstance(row.ask_price, (int, float)), (
+            f"Ask price should be numeric, got {type(row.ask_price)}"
+        )
+        assert isinstance(row.updated_at, str), (
+            f"Updated timestamp should be string, got {type(row.updated_at)}"
+        )
+
+    # Test passes only if we have real data for the requested pairs
+    assert len(symbols_found) >= 3, (
+        f"Expected at least 3 different symbols, got {len(symbols_found)}: {symbols_found}"
+    )