Skip to content

Commit 9718baa

Browse files
committed
remove load all pairs
1 parent 2b483d2 commit 9718baa

File tree

2 files changed: 21 additions and 53 deletions

pyspark_datasources/robinhood.py

Lines changed: 19 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,10 @@ class RobinhoodDataSource(DataSource):
4141
+--------+--------+---------+---------+--------------------+
4242
|BTC-USD |45000.50|45000.25 |45000.75 |2024-01-15T16:00:...|
4343
|ETH-USD | 2650.75| 2650.50 | 2651.00 |2024-01-15T16:00:...|
44-
|DOGE-USD| 0.085| 0.084| 0.086|2024-01-15T16:00:...|
44+
|DOGE-USD| 0.085| 0.084| 0.086|2024-01-15T16:00:...|
4545
+--------+--------+---------+---------+--------------------+
4646
47-
Load data for specific trading pairs:
4847
49-
>>> df = spark.read.format("robinhood") \\
50-
... .option("api_key", "your-api-key") \\
51-
... .option("private_key", "your-base64-private-key") \\
52-
... .load("BTC-USD,ETH-USD")
53-
>>> df.show()
54-
55-
Load all available trading pairs:
56-
57-
>>> df = spark.read.format("robinhood") \\
58-
... .option("api_key", "your-api-key") \\
59-
... .option("private_key", "your-base64-private-key") \\
60-
... .option("load_all_pairs", "true") \\
61-
... .load()
62-
>>> df.show()
6348
6449
Notes
6550
-----
@@ -123,8 +108,7 @@ def __init__(self, schema: StructType, options: Dict):
123108
except Exception as e:
124109
raise ValueError(f"Invalid private key format: {str(e)}")
125110

126-
# Option to load all available pairs
127-
self.load_all_pairs = options.get("load_all_pairs", "false").lower() == "true"
111+
128112

129113
# Initialize session for connection pooling
130114
self.session = requests.Session()
@@ -198,39 +182,24 @@ def _get_query_params(key: str, *args) -> str:
198182

199183
def partitions(self):
200184
"""Create partitions for parallel processing of crypto pairs."""
201-
if self.load_all_pairs:
202-
# Get all available trading pairs
203-
try:
204-
path = "/api/v1/crypto/trading/trading_pairs/"
205-
pairs_data = self._make_authenticated_request("GET", path)
206-
207-
if pairs_data and 'results' in pairs_data:
208-
symbols = [pair['symbol'] for pair in pairs_data['results']]
209-
return [CryptoPair(symbol=symbol) for symbol in symbols]
210-
else:
211-
raise ValueError("No trading pairs data returned from API")
212-
except Exception as e:
213-
raise ValueError(f"Failed to fetch available trading pairs: {str(e)}")
214-
else:
215-
# Use specified symbols from path
216-
symbols_str = self.options.get("path", "")
217-
if not symbols_str:
218-
raise ValueError(
219-
"Must specify crypto pairs to load using .load('BTC-USD,ETH-USD') "
220-
"or use .option('load_all_pairs', 'true') to load all available pairs"
221-
)
222-
223-
# Split symbols by comma and create partitions
224-
symbols = [symbol.strip().upper() for symbol in symbols_str.split(",")]
225-
# Ensure proper format (e.g., BTC-USD)
226-
formatted_symbols = []
227-
for symbol in symbols:
228-
if symbol and '-' not in symbol:
229-
symbol = f"{symbol}-USD" # Default to USD pair
230-
if symbol:
231-
formatted_symbols.append(symbol)
185+
# Use specified symbols from path
186+
symbols_str = self.options.get("path", "")
187+
if not symbols_str:
188+
raise ValueError(
189+
"Must specify crypto pairs to load using .load('BTC-USD,ETH-USD')"
190+
)
232191

233-
return [CryptoPair(symbol=symbol) for symbol in formatted_symbols]
192+
# Split symbols by comma and create partitions
193+
symbols = [symbol.strip().upper() for symbol in symbols_str.split(",")]
194+
# Ensure proper format (e.g., BTC-USD)
195+
formatted_symbols = []
196+
for symbol in symbols:
197+
if symbol and '-' not in symbol:
198+
symbol = f"{symbol}-USD" # Default to USD pair
199+
if symbol:
200+
formatted_symbols.append(symbol)
201+
202+
return [CryptoPair(symbol=symbol) for symbol in formatted_symbols]
234203

235204
def read(self, partition: CryptoPair):
236205
"""Read crypto data for a single trading pair partition."""

tests/test_robinhood.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def test_robinhood_missing_credentials(spark):
3838
assert "ValueError" in str(excinfo.value) and ("api_key" in str(excinfo.value) or "private_key" in str(excinfo.value))
3939

4040

41-
def test_robinhood_missing_symbols_without_load_all(spark):
42-
"""Test that missing symbols without load_all_pairs raises an error."""
41+
def test_robinhood_missing_symbols(spark):
42+
"""Test that missing symbols raises an error."""
4343
with pytest.raises(AnalysisException) as excinfo:
4444
df = spark.read.format("robinhood") \
4545
.option("api_key", "test-key") \
@@ -136,4 +136,3 @@ def test_robinhood_multiple_crypto_pairs(spark):
136136

137137
# Test passes only if we have real data for the requested pairs
138138
assert len(symbols_found) >= 3, f"Expected at least 3 different symbols, got {len(symbols_found)}: {symbols_found}"
139-

0 commit comments

Comments
 (0)