Skip to content

Commit 6d6c2f9

Browse files
committed
add type hints
1 parent bdf2072 commit 6d6c2f9

File tree

1 file changed

+69
-66
lines changed

1 file changed

+69
-66
lines changed

pyspark_datasources/robinhood.py

Lines changed: 69 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Dict
2+
from typing import Dict, List, Optional, Generator, Union
33
import requests
44
import json
55
import base64
@@ -10,67 +10,9 @@
1010
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
1111

1212

13-
class RobinhoodDataSource(DataSource):
14-
"""
15-
A data source for reading cryptocurrency data from Robinhood Crypto API.
16-
17-
This data source allows you to fetch real-time cryptocurrency market data,
18-
trading pairs, and price information using Robinhood's official Crypto API.
19-
It implements proper API key authentication and signature-based security.
20-
21-
Name: `robinhood`
22-
23-
Schema: `symbol string, price double, bid_price double, ask_price double, updated_at string`
24-
25-
Examples
26-
--------
27-
Register the data source:
28-
29-
>>> from pyspark_datasources import RobinhoodDataSource
30-
>>> spark.dataSource.register(RobinhoodDataSource)
31-
32-
Load cryptocurrency market data with API authentication:
33-
34-
>>> df = spark.read.format("robinhood") \\
35-
... .option("api_key", "your-api-key") \\
36-
... .option("private_key", "your-base64-private-key") \\
37-
... .load("BTC-USD,ETH-USD,DOGE-USD")
38-
>>> df.show()
39-
+--------+--------+---------+---------+--------------------+
40-
| symbol| price|bid_price|ask_price| updated_at|
41-
+--------+--------+---------+---------+--------------------+
42-
|BTC-USD |45000.50|45000.25 |45000.75 |2024-01-15T16:00:...|
43-
|ETH-USD | 2650.75| 2650.50 | 2651.00 |2024-01-15T16:00:...|
44-
|DOGE-USD| 0.085| 0.084| 0.086|2024-01-15T16:00:...|
45-
+--------+--------+---------+---------+--------------------+
4613

4714

4815

49-
Notes
50-
-----
51-
- Requires valid Robinhood Crypto API credentials (API key and base64-encoded private key)
52-
- Supports all major cryptocurrencies available on Robinhood
53-
- Implements proper API authentication with NaCl (Sodium) signing
54-
- Rate limiting is handled automatically
55-
- Based on official Robinhood Crypto Trading API documentation
56-
- Requires 'pynacl' library for cryptographic signing: pip install pynacl
57-
- Reference: https://docs.robinhood.com/crypto/trading/
58-
"""
59-
60-
@classmethod
61-
def name(cls) -> str:
62-
return "robinhood"
63-
64-
def schema(self) -> str:
65-
return (
66-
"symbol string, price double, bid_price double, ask_price double, "
67-
"updated_at string"
68-
)
69-
70-
def reader(self, schema: StructType):
71-
return RobinhoodDataReader(schema, self.options)
72-
73-
7416
@dataclass
7517
class CryptoPair(InputPartition):
7618
"""Represents a single crypto trading pair partition for parallel processing."""
@@ -80,7 +22,7 @@ class CryptoPair(InputPartition):
8022
class RobinhoodDataReader(DataSourceReader):
8123
"""Reader implementation for Robinhood Crypto API data source."""
8224

83-
def __init__(self, schema: StructType, options: Dict):
25+
def __init__(self, schema: StructType, options: Dict[str, str]) -> None:
8426
self.schema = schema
8527
self.options = options
8628

@@ -131,7 +73,7 @@ def _generate_signature(self, timestamp: int, method: str, path: str, body: str
13173
signature = base64.b64encode(signed.signature).decode("utf-8")
13274
return signature
13375

134-
def _make_authenticated_request(self, method: str, path: str, params: Dict = None, json_data: Dict = None):
76+
def _make_authenticated_request(self, method: str, path: str, params: Optional[Dict[str, str]] = None, json_data: Optional[Dict] = None) -> Optional[Dict]:
13577
"""Make an authenticated request to the Robinhood Crypto API."""
13678
timestamp = self._get_current_timestamp()
13779
url = self.base_url + path
@@ -168,14 +110,14 @@ def _make_authenticated_request(self, method: str, path: str, params: Dict = Non
168110
return None
169111

170112
@staticmethod
171-
def _get_query_params(key: str, *args) -> str:
113+
def _get_query_params(key: str, *args: str) -> str:
172114
"""Build query parameters for API requests."""
173115
if not args:
174116
return ""
175117
params = [f"{key}={arg}" for arg in args if arg]
176118
return "?" + "&".join(params)
177119

178-
def partitions(self):
120+
def partitions(self) -> List[CryptoPair]:
179121
"""Create partitions for parallel processing of crypto pairs."""
180122
# Use specified symbols from path
181123
symbols_str = self.options.get("path", "")
@@ -196,7 +138,7 @@ def partitions(self):
196138

197139
return [CryptoPair(symbol=symbol) for symbol in formatted_symbols]
198140

199-
def read(self, partition: CryptoPair):
141+
def read(self, partition: CryptoPair) -> Generator[Row, None, None]:
200142
"""Read crypto data for a single trading pair partition."""
201143
symbol = partition.symbol
202144

@@ -206,7 +148,7 @@ def read(self, partition: CryptoPair):
206148
# Log error but don't fail the entire job
207149
print(f"Warning: Failed to fetch data for {symbol}: {str(e)}")
208150

209-
def _read_crypto_pair_data(self, symbol: str):
151+
def _read_crypto_pair_data(self, symbol: str) -> Generator[Row, None, None]:
210152
"""Fetch cryptocurrency market data for a given trading pair."""
211153
try:
212154
# Get best bid/ask data for the trading pair using query parameters
@@ -216,7 +158,7 @@ def _read_crypto_pair_data(self, symbol: str):
216158
if market_data and 'results' in market_data:
217159
for quote in market_data['results']:
218160
# Parse numeric values safely
219-
def safe_float(value, default=0.0):
161+
def safe_float(value: Union[str, int, float, None], default: float = 0.0) -> float:
220162
if value is None or value == "":
221163
return default
222164
try:
@@ -246,3 +188,64 @@ def safe_float(value, default=0.0):
246188
print(f"Data parsing error for {symbol}: {str(e)}")
247189
except Exception as e:
248190
print(f"Unexpected error fetching data for {symbol}: {str(e)}")
191+
192+
193+
class RobinhoodDataSource(DataSource):
194+
"""
195+
A data source for reading cryptocurrency data from Robinhood Crypto API.
196+
197+
This data source allows you to fetch real-time cryptocurrency market data,
198+
trading pairs, and price information using Robinhood's official Crypto API.
199+
It implements proper API key authentication and signature-based security.
200+
201+
Name: `robinhood`
202+
203+
Schema: `symbol string, price double, bid_price double, ask_price double, updated_at string`
204+
205+
Examples
206+
--------
207+
Register the data source:
208+
209+
>>> from pyspark_datasources import RobinhoodDataSource
210+
>>> spark.dataSource.register(RobinhoodDataSource)
211+
212+
Load cryptocurrency market data with API authentication:
213+
214+
>>> df = spark.read.format("robinhood") \\
215+
... .option("api_key", "your-api-key") \\
216+
... .option("private_key", "your-base64-private-key") \\
217+
... .load("BTC-USD,ETH-USD,DOGE-USD")
218+
>>> df.show()
219+
+--------+--------+---------+---------+--------------------+
220+
| symbol| price|bid_price|ask_price| updated_at|
221+
+--------+--------+---------+---------+--------------------+
222+
|BTC-USD |45000.50|45000.25 |45000.75 |2024-01-15T16:00:...|
223+
|ETH-USD | 2650.75| 2650.50 | 2651.00 |2024-01-15T16:00:...|
224+
|DOGE-USD| 0.085| 0.084| 0.086|2024-01-15T16:00:...|
225+
+--------+--------+---------+---------+--------------------+
226+
227+
228+
229+
Notes
230+
-----
231+
- Requires valid Robinhood Crypto API credentials (API key and base64-encoded private key)
232+
- Supports all major cryptocurrencies available on Robinhood
233+
- Implements proper API authentication with NaCl (Sodium) signing
234+
- Rate limiting is handled automatically
235+
- Based on official Robinhood Crypto Trading API documentation
236+
- Requires 'pynacl' library for cryptographic signing: pip install pynacl
237+
- Reference: https://docs.robinhood.com/crypto/trading/
238+
"""
239+
240+
@classmethod
241+
def name(cls) -> str:
242+
return "robinhood"
243+
244+
def schema(self) -> str:
245+
return (
246+
"symbol string, price double, bid_price double, ask_price double, "
247+
"updated_at string"
248+
)
249+
250+
def reader(self, schema: StructType) -> RobinhoodDataReader:
251+
return RobinhoodDataReader(schema, self.options)

0 commit comments

Comments
 (0)