From 33bb9a298c1f115aa8e604bce14eb9261e4bdc22 Mon Sep 17 00:00:00 2001
From: Stefan Zabka
Date: Fri, 9 Apr 2021 12:37:41 +0200
Subject: [PATCH] Removed collect_content from PySparkS3Dataset

Downloading files via the SparkContext was much slower than downloading
via boto (which is what S3Dataset does). So now both classes use the
same method, as PySparkS3Dataset inherits from S3Dataset.
---
 openwpm_utils/s3.py | 102 +++++++++++++++++---------------------------
 1 file changed, 40 insertions(+), 62 deletions(-)

diff --git a/openwpm_utils/s3.py b/openwpm_utils/s3.py
index b533e03..5654ba8 100644
--- a/openwpm_utils/s3.py
+++ b/openwpm_utils/s3.py
@@ -8,68 +8,7 @@
 from pyarrow.filesystem import S3FSWrapper  # noqa
 from pyspark.sql import SQLContext
 
-
-class PySparkS3Dataset(object):
-    def __init__(self, spark_context, s3_directory,
-                 s3_bucket='openwpm-crawls'):
-        """Helper class to load OpenWPM datasets from S3 using PySpark
-
-        Parameters
-        ----------
-        spark_context
-            Spark context. In databricks, this is available via the `sc`
-            variable.
-        s3_directory : string
-            Directory within the S3 bucket in which the dataset is saved.
-        s3_bucket : string, optional
-            The bucket name on S3. Defaults to `openwpm-crawls`.
-        """
-        self._s3_bucket = s3_bucket
-        self._s3_directory = s3_directory
-        self._spark_context = spark_context
-        self._sql_context = SQLContext(spark_context)
-        self._s3_table_loc = "s3a://%s/%s/visits/%%s/" % (
-            s3_bucket, s3_directory)
-        self._s3_content_loc = "s3a://%s/%s/content/%%s.gz" % (
-            s3_bucket, s3_directory)
-
-    def read_table(self, table_name, columns=None):
-        """Read `table_name` from OpenWPM dataset into a pyspark dataframe.
-
-        Parameters
-        ----------
-        table_name : string
-            OpenWPM table to read
-        columns : list of strings
-            The set of columns to filter the parquet dataset by
-        """
-        table = self._sql_context.read.parquet(self._s3_table_loc % table_name)
-        if columns is not None:
-            return table.select(columns)
-        return table
-
-    def read_content(self, content_hash):
-        """Read the content corresponding to `content_hash`.
-
-        NOTE: This can only be run in the driver process since it requires
-        access to the spark context
-        """
-        return self._spark_context.textFile(
-            self._s3_content_loc % content_hash)
-
-    def collect_content(self, content_hash, beautify=False):
-        """Collect content for `content_hash` to driver
-
-        NOTE: This can only be run in the driver process since it requires
-        access to the spark context
-        """
-        content = ''.join(self.read_content(content_hash).collect())
-        if beautify:
-            return jsbeautifier.beautify(content)
-        return content
-
-
-class S3Dataset(object):
+class S3Dataset:
     def __init__(self, s3_directory,
                  s3_bucket='openwpm-crawls'):
         """Helper class to load OpenWPM datasets from S3 using pandas
@@ -134,3 +73,42 @@ def collect_content(self, content_hash, beautify=False):
         except IndexError:
             pass
         return content
+
+class PySparkS3Dataset(S3Dataset):
+    def __init__(self, spark_context, s3_directory,
+                 s3_bucket='openwpm-crawls'):
+        """Helper class to load OpenWPM datasets from S3 using PySpark
+
+        Parameters
+        ----------
+        spark_context
+            Spark context. In databricks, this is available via the `sc`
+            variable.
+        s3_directory : string
+            Directory within the S3 bucket in which the dataset is saved.
+        s3_bucket : string, optional
+            The bucket name on S3. Defaults to `openwpm-crawls`.
+        """
+        self._s3_bucket = s3_bucket
+        self._s3_directory = s3_directory
+        self._spark_context = spark_context
+        self._sql_context = SQLContext(spark_context)
+        self._s3_table_loc = "s3a://%s/%s/visits/%%s/" % (
+            s3_bucket, s3_directory)
+        self._s3_content_loc = "s3a://%s/%s/content/%%s.gz" % (
+            s3_bucket, s3_directory)
+
+    def read_table(self, table_name, columns=None):
+        """Read `table_name` from OpenWPM dataset into a pyspark dataframe.
+
+        Parameters
+        ----------
+        table_name : string
+            OpenWPM table to read
+        columns : list of strings
+            The set of columns to filter the parquet dataset by
+        """
+        table = self._sql_context.read.parquet(self._s3_table_loc % table_name)
+        if columns is not None:
+            return table.select(columns)
+        return table
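
A minimal usage sketch of the refactored API (not part of the patch): the crawl directory, table name, columns, and content hash below are illustrative placeholders, and `sc` is the Spark context as available on Databricks. The point of the change is visible here: `read_table` is still Spark-backed, while `collect_content` is now inherited from `S3Dataset` and downloads via boto on the driver.

    from openwpm_utils.s3 import PySparkS3Dataset

    # Hypothetical crawl directory inside the default `openwpm-crawls` bucket.
    dataset = PySparkS3Dataset(sc, "2021-04-09_example_crawl")

    # read_table is overridden above and returns a PySpark DataFrame read from
    # s3a://openwpm-crawls/2021-04-09_example_crawl/visits/http_requests/
    requests = dataset.read_table("http_requests", columns=["url", "top_level_url"])

    # collect_content is no longer defined on PySparkS3Dataset; the inherited
    # S3Dataset implementation fetches the gzipped record with boto instead of
    # spark_context.textFile(). "example_hash" stands in for a real content hash.
    script = dataset.collect_content("example_hash", beautify=True)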