diff --git a/docs/datasources/pokemon.md b/docs/datasources/pokemon.md
new file mode 100644
index 0000000..9ce8f9f
--- /dev/null
+++ b/docs/datasources/pokemon.md
@@ -0,0 +1,5 @@
+# PokemonDataSource
+
+> Uses the public [PokeAPI](https://pokeapi.co/) to retrieve Pokemon data. No API key required.
+
+::: pyspark_datasources.pokemon.PokemonDataSource
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index f907e77..064d831 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -36,3 +36,4 @@ spark.read.format("github").load("apache/spark").show()
 | [StockDataSource](./datasources/stock.md) | `stock` | Read stock data from Alpha Vantage | None |
 | [SimpleJsonDataSource](./datasources/simplejson.md) | `simplejson` | Read JSON data from a file | `databricks-sdk` |
 | [GoogleSheetsDataSource](./datasources/googlesheets.md) | `googlesheets` | Read table from public Google Sheets document | None |
+| [PokemonDataSource](./datasources/pokemon.md) | `pokemon` | Read Pokemon data from the PokeAPI | None |
diff --git a/pyspark_datasources/__init__.py b/pyspark_datasources/__init__.py
index 89413b3..47e40ae 100644
--- a/pyspark_datasources/__init__.py
+++ b/pyspark_datasources/__init__.py
@@ -4,3 +4,4 @@
 from .huggingface import HuggingFaceDatasets
 from .simplejson import SimpleJsonDataSource
 from .stock import StockDataSource
+from .pokemon import PokemonDataSource
diff --git a/pyspark_datasources/pokemon.py b/pyspark_datasources/pokemon.py
new file mode 100644
index 0000000..ade7bb4
--- /dev/null
+++ b/pyspark_datasources/pokemon.py
@@ -0,0 +1,157 @@
+import json
+import requests
+
+from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
+from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType
+
+# Bound every PokeAPI call so an unresponsive server fails the Spark task
+# instead of blocking it forever (requests applies no timeout by default).
+REQUEST_TIMEOUT_SECONDS = 30
+
+
+class PokemonDataSource(DataSource):
+    """
+    A data source for reading Pokemon data from the PokeAPI (https://pokeapi.co/).
+
+    Options
+    -------
+
+    - endpoint: The API endpoint to query. Default is "pokemon".
+      Valid options include: "pokemon", "type", "ability", "berry", etc.
+    - limit: Maximum number of results to return. Default is 20.
+    - offset: Number of results to skip. Default is 0.
+
+    Examples
+    --------
+
+    Register the data source:
+
+    >>> from pyspark_datasources import PokemonDataSource
+    >>> spark.dataSource.register(PokemonDataSource)
+
+    Load Pokemon data:
+
+    >>> df = spark.read.format("pokemon").option("limit", 10).load()
+    >>> df.show()
+    +----+-----------+--------+------+--------------------+
+    |  id|       name|  height|weight|           abilities|
+    +----+-----------+--------+------+--------------------+
+    |   1|  bulbasaur|       7|    69|[overgrow, chloro...|
+    |   2|    ivysaur|      10|   130|[overgrow, chloro...|
+    |   3|   venusaur|      20|  1000|[overgrow, chloro...|
+    |   4| charmander|       6|    85|[blaze, solar-power]|
+    |   5| charmeleon|      11|   190|[blaze, solar-power]|
+    |   6|  charizard|      17|   905|[blaze, solar-power]|
+    |   7|   squirtle|       5|    90|[torrent, rain-dish]|
+    |   8|  wartortle|      10|   225|[torrent, rain-dish]|
+    |   9|  blastoise|      16|   855|[torrent, rain-dish]|
+    |  10|   caterpie|       3|    29|       [shield-dust]|
+    +----+-----------+--------+------+--------------------+
+
+    Load specific Pokemon types:
+
+    >>> df = spark.read.format("pokemon").option("endpoint", "type").load()
+    >>> df.show()
+    """
+
+    @classmethod
+    def name(cls) -> str:
+        return "pokemon"
+
+    def schema(self) -> str:
+        if self.options.get("endpoint", "pokemon") == "pokemon":
+            return (
+                "id integer, name string, height integer, "
+                "weight integer, abilities array<string>"
+            )
+        elif self.options.get("endpoint") == "type":
+            return "id integer, name string, pokemon array<string>"
+        else:
+            # Generic schema for other endpoints
+            return "id integer, name string, details string"
+
+    def reader(self, schema):
+        return PokemonDataReader(schema, self.options)
+
+
+class PokemonPartition(InputPartition):
+    def __init__(self, endpoint, limit, offset):
+        self.endpoint = endpoint
+        self.limit = limit
+        self.offset = offset
+
+
+class PokemonDataReader(DataSourceReader):
+    def __init__(self, schema: StructType, options: dict):
+        self.schema = schema
+        self.options = options
+        self.endpoint = options.get("endpoint", "pokemon")
+        self.limit = int(options.get("limit", 20))
+        self.offset = int(options.get("offset", 0))
+        self.base_url = "https://pokeapi.co/api/v2"
+
+    def partitions(self):
+        # Create a single partition for simplicity
+        return [PokemonPartition(self.endpoint, self.limit, self.offset)]
+
+    def read(self, partition: PokemonPartition):
+        session = requests.Session()
+        try:
+            # Fetch list of resources
+            url = f"{self.base_url}/{partition.endpoint}?limit={partition.limit}&offset={partition.offset}"
+            response = session.get(url, timeout=REQUEST_TIMEOUT_SECONDS)
+            response.raise_for_status()
+            results = response.json()["results"]
+
+            # Process based on endpoint type
+            if partition.endpoint == "pokemon":
+                for result in results:
+                    pokemon_data = self._fetch_pokemon(result["url"], session)
+                    abilities = [ability["ability"]["name"] for ability in pokemon_data["abilities"]]
+
+                    yield (
+                        pokemon_data["id"],
+                        pokemon_data["name"],
+                        pokemon_data["height"],
+                        pokemon_data["weight"],
+                        abilities
+                    )
+
+            elif partition.endpoint == "type":
+                for result in results:
+                    type_data = self._fetch_resource(result["url"], session)
+                    pokemon_names = [pokemon["pokemon"]["name"] for pokemon in type_data["pokemon"]]
+
+                    yield (
+                        type_data["id"],
+                        type_data["name"],
+                        pokemon_names
+                    )
+
+            else:
+                # Generic handler for other endpoints
+                for result in results:
+                    resource_data = self._fetch_resource(result["url"], session)
+
+                    yield (
+                        resource_data.get("id", 0),
+                        resource_data.get("name", ""),
+                        json.dumps(resource_data)
+                    )
+        finally:
+            # Release pooled connections, also when iteration stops early
+            session.close()
+
+    @staticmethod
+    def _fetch_pokemon(url, session):
+        """Fetch detailed Pokemon data"""
+        response = session.get(url, timeout=REQUEST_TIMEOUT_SECONDS)
+        response.raise_for_status()
+        return response.json()
+
+    @staticmethod
+    def _fetch_resource(url, session):
+        """Fetch any resource data"""
+        response = session.get(url, timeout=REQUEST_TIMEOUT_SECONDS)
+        response.raise_for_status()
+        return response.json()
\ No newline at end of file
diff --git a/tests/test_pokemon.py b/tests/test_pokemon.py
new file mode 100644
index 0000000..3974c1b
--- /dev/null
+++ b/tests/test_pokemon.py
@@ -0,0 +1,91 @@
+import unittest
+from unittest.mock import patch, MagicMock
+import json
+
+from pyspark_datasources.pokemon import (
+    REQUEST_TIMEOUT_SECONDS,
+    PokemonDataReader,
+    PokemonDataSource,
+    PokemonPartition,
+)
+
+
+class TestPokemonDataSource(unittest.TestCase):
+    def test_name(self):
+        self.assertEqual(PokemonDataSource.name(), "pokemon")
+
+    def test_schema_pokemon(self):
+        source = PokemonDataSource({"endpoint": "pokemon"})
+        expected_schema = (
+            "id integer, name string, height integer, "
+            "weight integer, abilities array<string>"
+        )
+        self.assertEqual(source.schema(), expected_schema)
+
+    def test_schema_type(self):
+        source = PokemonDataSource({"endpoint": "type"})
+        expected_schema = "id integer, name string, pokemon array<string>"
+        self.assertEqual(source.schema(), expected_schema)
+
+    def test_schema_other(self):
+        source = PokemonDataSource({"endpoint": "ability"})
+        expected_schema = "id integer, name string, details string"
+        self.assertEqual(source.schema(), expected_schema)
+
+
+class TestPokemonDataReader(unittest.TestCase):
+    @patch('pyspark_datasources.pokemon.requests.Session')
+    def test_read_pokemon(self, mock_session):
+        # Mock response for the list endpoint
+        mock_response_list = MagicMock()
+        mock_response_list.json.return_value = {
+            "results": [
+                {"name": "bulbasaur", "url": "https://pokeapi.co/api/v2/pokemon/1/"}
+            ]
+        }
+
+        # Mock response for the individual pokemon
+        mock_response_pokemon = MagicMock()
+        mock_response_pokemon.json.return_value = {
+            "id": 1,
+            "name": "bulbasaur",
+            "height": 7,
+            "weight": 69,
+            "abilities": [
+                {"ability": {"name": "overgrow"}},
+                {"ability": {"name": "chlorophyll"}}
+            ]
+        }
+
+        # Set up the session mock
+        mock_session_instance = mock_session.return_value
+        mock_session_instance.get.side_effect = [mock_response_list, mock_response_pokemon]
+
+        # Create reader and partition
+        schema = None  # Not used in the test
+        reader = PokemonDataReader(schema, {"endpoint": "pokemon", "limit": 1})
+        partition = PokemonPartition("pokemon", 1, 0)
+
+        # Get results
+        results = list(reader.read(partition))
+
+        # Verify the results
+        self.assertEqual(len(results), 1)
+        pokemon = results[0]
+        self.assertEqual(pokemon[0], 1)  # id
+        self.assertEqual(pokemon[1], "bulbasaur")  # name
+        self.assertEqual(pokemon[2], 7)  # height
+        self.assertEqual(pokemon[3], 69)  # weight
+        self.assertEqual(pokemon[4], ["overgrow", "chlorophyll"])  # abilities
+
+        # Verify the correct URLs were called
+        mock_session_instance.get.assert_any_call(
+            "https://pokeapi.co/api/v2/pokemon?limit=1&offset=0", timeout=REQUEST_TIMEOUT_SECONDS
+        )
+        mock_session_instance.get.assert_any_call(
+            "https://pokeapi.co/api/v2/pokemon/1/", timeout=REQUEST_TIMEOUT_SECONDS
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file