1010from pyspark .sql .datasource import DataSource , DataSourceReader , InputPartition
1111
1212
13-
14-
15-
@dataclass
class CryptoPair(InputPartition):
    """Partition descriptor for one crypto trading pair.

    Each instance is one unit of work: the reader produces one of these per
    requested pair so the pairs can be processed in parallel.
    """

    # Trading pair identifier, e.g. "BTC-USD".
    symbol: str
2018
2119
@@ -25,21 +23,22 @@ class RobinhoodDataReader(DataSourceReader):
2523 def __init__ (self , schema : StructType , options : Dict [str , str ]) -> None :
2624 self .schema = schema
2725 self .options = options
28-
26+
2927 # Required API authentication
3028 self .api_key = options .get ("api_key" )
3129 self .private_key_base64 = options .get ("private_key" )
32-
30+
3331 if not self .api_key or not self .private_key_base64 :
3432 raise ValueError (
3533 "Robinhood Crypto API requires both 'api_key' and 'private_key' options. "
3634 "The private_key should be base64-encoded. "
3735 "Get your API credentials from https://docs.robinhood.com/crypto/trading/"
3836 )
39-
37+
4038 # Initialize NaCl signing key
4139 try :
4240 from nacl .signing import SigningKey
41+
4342 private_key_seed = base64 .b64decode (self .private_key_base64 )
4443 self .signing_key = SigningKey (private_key_seed )
4544 except ImportError :
@@ -49,17 +48,14 @@ def __init__(self, schema: StructType, options: Dict[str, str]) -> None:
4948 )
5049 except Exception as e :
5150 raise ValueError (f"Invalid private key format: { str (e )} " )
52-
53-
5451
55-
5652 # Crypto API base URL (configurable for testing)
5753 self .base_url = options .get ("base_url" , "https://trading.robinhood.com" )
5854
5955 def _get_current_timestamp (self ) -> int :
6056 """Get current UTC timestamp."""
6157 return int (datetime .datetime .now (tz = datetime .timezone .utc ).timestamp ())
62-
58+
6359 def _generate_signature (self , timestamp : int , method : str , path : str , body : str = "" ) -> str :
6460 """Generate NaCl signature for API authentication following Robinhood's specification."""
6561 # Official Robinhood signature format: f"{api_key}{current_timestamp}{path}{method}{body}"
@@ -68,41 +64,49 @@ def _generate_signature(self, timestamp: int, method: str, path: str, body: str
6864 message_to_sign = f"{ self .api_key } { timestamp } { path } { method .upper ()} "
6965 else :
7066 message_to_sign = f"{ self .api_key } { timestamp } { path } { method .upper ()} { body } "
71-
67+
7268 signed = self .signing_key .sign (message_to_sign .encode ("utf-8" ))
7369 signature = base64 .b64encode (signed .signature ).decode ("utf-8" )
7470 return signature
7571
76- def _make_authenticated_request (self , method : str , path : str , params : Optional [Dict [str , str ]] = None , json_data : Optional [Dict ] = None ) -> Optional [Dict ]:
72+ def _make_authenticated_request (
73+ self ,
74+ method : str ,
75+ path : str ,
76+ params : Optional [Dict [str , str ]] = None ,
77+ json_data : Optional [Dict ] = None ,
78+ ) -> Optional [Dict ]:
7779 """Make an authenticated request to the Robinhood Crypto API."""
7880 timestamp = self ._get_current_timestamp ()
7981 url = self .base_url + path
80-
82+
8183 # Prepare request body for signature (only for non-GET requests)
8284 body = ""
8385 if method .upper () != "GET" and json_data :
84- body = json .dumps (json_data , separators = (',' , ':' )) # Compact JSON format
85-
86+ body = json .dumps (json_data , separators = ("," , ":" )) # Compact JSON format
87+
8688 # Generate signature
8789 signature = self ._generate_signature (timestamp , method , path , body )
88-
90+
8991 # Set authentication headers
9092 headers = {
91- ' x-api-key' : self .api_key ,
92- ' x-signature' : signature ,
93- ' x-timestamp' : str (timestamp )
93+ " x-api-key" : self .api_key ,
94+ " x-signature" : signature ,
95+ " x-timestamp" : str (timestamp ),
9496 }
95-
97+
9698 try :
9799 # Make request
98100 if method .upper () == "GET" :
99101 response = requests .get (url , headers = headers , params = params , timeout = 10 )
100102 elif method .upper () == "POST" :
101- headers [' Content-Type' ] = ' application/json'
103+ headers [" Content-Type" ] = " application/json"
102104 response = requests .post (url , headers = headers , json = json_data , timeout = 10 )
103105 else :
104- response = requests .request (method , url , headers = headers , params = params , json = json_data , timeout = 10 )
105-
106+ response = requests .request (
107+ method , url , headers = headers , params = params , json = json_data , timeout = 10
108+ )
109+
106110 response .raise_for_status ()
107111 return response .json ()
108112 except requests .RequestException as e :
@@ -116,32 +120,30 @@ def _get_query_params(key: str, *args: str) -> str:
116120 return ""
117121 params = [f"{ key } ={ arg } " for arg in args if arg ]
118122 return "?" + "&" .join (params )
119-
123+
120124 def partitions (self ) -> List [CryptoPair ]:
121125 """Create partitions for parallel processing of crypto pairs."""
122126 # Use specified symbols from path
123127 symbols_str = self .options .get ("path" , "" )
124128 if not symbols_str :
125- raise ValueError (
126- "Must specify crypto pairs to load using .load('BTC-USD,ETH-USD')"
127- )
128-
129+ raise ValueError ("Must specify crypto pairs to load using .load('BTC-USD,ETH-USD')" )
130+
129131 # Split symbols by comma and create partitions
130132 symbols = [symbol .strip ().upper () for symbol in symbols_str .split ("," )]
131133 # Ensure proper format (e.g., BTC-USD)
132134 formatted_symbols = []
133135 for symbol in symbols :
134- if symbol and '-' not in symbol :
136+ if symbol and "-" not in symbol :
135137 symbol = f"{ symbol } -USD" # Default to USD pair
136138 if symbol :
137139 formatted_symbols .append (symbol )
138-
140+
139141 return [CryptoPair (symbol = symbol ) for symbol in formatted_symbols ]
140142
141143 def read (self , partition : CryptoPair ) -> Generator [Row , None , None ]:
142144 """Read crypto data for a single trading pair partition."""
143145 symbol = partition .symbol
144-
146+
145147 try :
146148 yield from self ._read_crypto_pair_data (symbol )
147149 except Exception as e :
@@ -154,34 +156,36 @@ def _read_crypto_pair_data(self, symbol: str) -> Generator[Row, None, None]:
154156 # Get best bid/ask data for the trading pair using query parameters
155157 path = f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={ symbol } "
156158 market_data = self ._make_authenticated_request ("GET" , path )
157-
158- if market_data and ' results' in market_data :
159- for quote in market_data [' results' ]:
159+
160+ if market_data and " results" in market_data :
161+ for quote in market_data [" results" ]:
160162 # Parse numeric values safely
161- def safe_float (value : Union [str , int , float , None ], default : float = 0.0 ) -> float :
163+ def safe_float (
164+ value : Union [str , int , float , None ], default : float = 0.0
165+ ) -> float :
162166 if value is None or value == "" :
163167 return default
164168 try :
165169 return float (value )
166170 except (ValueError , TypeError ):
167171 return default
168-
172+
169173 # Extract market data fields from best bid/ask response
170174 # Use the correct field names from the API response
171- price = safe_float (quote .get (' price' ))
172- bid_price = safe_float (quote .get (' bid_inclusive_of_sell_spread' ))
173- ask_price = safe_float (quote .get (' ask_inclusive_of_buy_spread' ))
174-
175+ price = safe_float (quote .get (" price" ))
176+ bid_price = safe_float (quote .get (" bid_inclusive_of_sell_spread" ))
177+ ask_price = safe_float (quote .get (" ask_inclusive_of_buy_spread" ))
178+
175179 yield Row (
176180 symbol = symbol ,
177181 price = price ,
178182 bid_price = bid_price ,
179183 ask_price = ask_price ,
180- updated_at = quote .get (' timestamp' , "" )
184+ updated_at = quote .get (" timestamp" , "" ),
181185 )
182186 else :
183187 print (f"Warning: No market data found for { symbol } " )
184-
188+
185189 except requests .exceptions .RequestException as e :
186190 print (f"Network error fetching data for { symbol } : { str (e )} " )
187191 except (ValueError , KeyError ) as e :
@@ -256,10 +260,7 @@ def name(cls) -> str:
256260 return "robinhood"
257261
258262 def schema (self ) -> str :
259- return (
260- "symbol string, price double, bid_price double, ask_price double, "
261- "updated_at string"
262- )
263+ return "symbol string, price double, bid_price double, ask_price double, updated_at string"
263264
264265 def reader (self , schema : StructType ) -> RobinhoodDataReader :
265- return RobinhoodDataReader (schema , self .options )
266+ return RobinhoodDataReader (schema , self .options )
0 commit comments