From ca0d11053b45e82c1d72abe6040ad8c13cb0d84b Mon Sep 17 00:00:00 2001 From: Tom Runting Date: Fri, 29 Aug 2025 16:19:37 +0100 Subject: [PATCH 1/6] Return a list of paths to files we download for each ID --- hepdata_cli/api.py | 41 +++++++++++++++++++++++++++++++++-------- tests/test_download.py | 4 +++- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/hepdata_cli/api.py b/hepdata_cli/api.py index bf263c2..489054f 100644 --- a/hepdata_cli/api.py +++ b/hepdata_cli/api.py @@ -87,13 +87,18 @@ def download(self, id_list, file_format=None, ids=None, table_name='', download_ :param ids: accepts one of ('inspire', 'hepdata'). It specifies what type of ids have been passed. :param table_name: restricts download to specific tables. :param download_dir: defaults to ./hepdata-downloads. Specifies where to download the files. + + :return: dictionary mapping id to list of downloaded files. """ - urls = self._build_urls(id_list, file_format, ids, table_name) - for url in urls: + url_map = self._build_urls(id_list, file_format, ids, table_name) + file_map = {} + for id, url in url_map.items(): if self.verbose is True: print("Downloading: " + url) - download_url(url, download_dir) + files_downloaded = download_url(url, download_dir) + file_map[id] = files_downloaded + return file_map def fetch_names(self, id_list, ids=None): """ @@ -102,9 +107,9 @@ def fetch_names(self, id_list, ids=None): :param id_list: list of id of records of which to return table names. :param ids: accepts one of ('inspire', 'hepdata'). It specifies what type of ids have been passed. """ - urls = self._build_urls(id_list, 'json', ids, '') + url_map = self._build_urls(id_list, 'json', ids, '') table_names = [] - for url in urls: + for url in url_map.values(): response = resilient_requests('get', url) json_dict = response.json() table_names += [[data_table['name'] for data_table in json_dict['data_tables']]] @@ -136,7 +141,16 @@ def upload(self, path_to_file, email, recid=None, invitation_cookie=None, sandbo print('Uploaded ' + path_to_file + ' to ' + SITE_URL + '/record/' + str(recid)) def _build_urls(self, id_list, file_format, ids, table_name): - """Builds urls for download and fetch_names, given the specified parameters.""" + """ + Builds urls for download and fetch_names, given the specified parameters. + + :param id_list: list of ids to download. + :param file_format: accepts one of ('csv', 'root', 'yaml', 'yoda', 'yoda1', 'yoda.h5', 'json'). + :param ids: accepts one of ('inspire', 'hepdata'). + :param table_name: restricts download to specific tables. + + :return: dictionary mapping id to url. + """ if type(id_list) not in (tuple, list): id_list = id_list.split() assert len(id_list) > 0, 'Ids are required.' @@ -146,9 +160,12 @@ def _build_urls(self, id_list, file_format, ids, table_name): params = {'format': file_format} else: params = {'format': file_format, 'table': table_name} - urls = [resilient_requests('get', SITE_URL + '/record/' + ('ins' if ids == 'inspire' else '') + id_entry, params=params).url.replace('%2525', '%25') for id_entry in id_list] + url_mapping = {} + for id_entry in id_list: + url = resilient_requests('get', SITE_URL + '/record/' + ('ins' if ids == 'inspire' else '') + id_entry, params=params).url.replace('%2525', '%25') + url_mapping[id_entry] = url # TODO: Investigate root cause of double URL encoding (https://github.com/HEPData/hepdata-cli/issues/8). - return urls + return url_mapping def _query(self, query, page, size): """Builds the search query passed to hepdata.net.""" @@ -170,6 +187,7 @@ def mkdir(directory): def download_url(url, download_dir): """Download file and if necessary extract it.""" + files_downloaded = [] assert is_downloadable(url), "Given url is not downloadable: {}".format(url) response = resilient_requests('get', url, allow_redirects=True) if url[-4:] == 'json': @@ -184,8 +202,15 @@ def download_url(url, download_dir): if filepath.endswith("tar.gz") or filepath.endswith("tar"): tar = tarfile.open(filepath, "r:gz" if filepath.endswith("tar.gz") else "r:") tar.extractall(path=os.path.dirname(filepath)) + for member in tar.getmembers(): + if member.isfile(): + extracted_path = os.path.join(os.path.dirname(filepath), member.name) + files_downloaded.append(extracted_path) tar.close() os.remove(filepath) + else: + files_downloaded.append(filepath) + return files_downloaded def getFilename_fromCd(cd): diff --git a/tests/test_download.py b/tests/test_download.py index 28a727b..5f87260 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -57,8 +57,10 @@ def test_api_download(id_list, file_format, ids, table): mkdir(test_download_dir) assert len(os.listdir(test_download_dir)) == 0 client = Client(verbose=True) - client.download(id_list, file_format, ids, table, test_download_dir) + path_map = client.download(id_list, file_format, ids, table, test_download_dir) + file_paths = [fp for fps in path_map.values() for fp in fps] assert len(os.listdir(test_download_dir)) > 0 + assert all(os.path.exists(fp) for fp in file_paths) cleanup(test_download_dir) From 5578fe2d65719a414c134527d1c1138a5ff64aa4 Mon Sep 17 00:00:00 2001 From: Tom Runting Date: Mon, 1 Sep 2025 09:35:25 +0100 Subject: [PATCH 2/6] Better exception handling in tar unpacking --- hepdata_cli/api.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/hepdata_cli/api.py b/hepdata_cli/api.py index 489054f..ff2ab24 100644 --- a/hepdata_cli/api.py +++ b/hepdata_cli/api.py @@ -200,14 +200,28 @@ def download_url(url, download_dir): mkdir(os.path.dirname(filepath)) open(filepath, 'wb').write(response.content) if filepath.endswith("tar.gz") or filepath.endswith("tar"): - tar = tarfile.open(filepath, "r:gz" if filepath.endswith("tar.gz") else "r:") - tar.extractall(path=os.path.dirname(filepath)) - for member in tar.getmembers(): - if member.isfile(): - extracted_path = os.path.join(os.path.dirname(filepath), member.name) - files_downloaded.append(extracted_path) - tar.close() - os.remove(filepath) + tar = None + try: + tar = tarfile.open(filepath, "r:gz" if filepath.endswith("tar.gz") else "r:") + extract_dir = os.path.abspath(os.path.dirname(filepath)) + tar.extractall(path=os.path.dirname(filepath)) + for member in tar.getmembers(): + if member.isfile(): + extracted_path = os.path.join(os.path.dirname(filepath), member.name) + abs_extracted_path = os.path.abspath(extracted_path) + if abs_extracted_path.startswith(extract_dir + os.sep) and os.path.exists(abs_extracted_path): + files_downloaded.append(abs_extracted_path) + elif not abs_extracted_path.startswith(extract_dir + os.sep): + raise ValueError(f"Attempted path traversal for file {member.name}") + else: + raise FileNotFoundError(f"Extracted file {member.name} not found") + except Exception as e: + raise Exception(f"Failed to extract {filepath}: {str(e)}") + finally: + if tar: + tar.close() + if os.path.exists(filepath): + os.remove(filepath) else: files_downloaded.append(filepath) return files_downloaded From c8626d9fa085b188f8711ab0f597c44146aea887 Mon Sep 17 00:00:00 2001 From: Tom Runting Date: Mon, 1 Sep 2025 09:36:30 +0100 Subject: [PATCH 3/6] Update `download` documentation --- hepdata_cli/api.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hepdata_cli/api.py b/hepdata_cli/api.py index ff2ab24..60e7ec6 100644 --- a/hepdata_cli/api.py +++ b/hepdata_cli/api.py @@ -89,15 +89,16 @@ def download(self, id_list, file_format=None, ids=None, table_name='', download_ :param download_dir: defaults to ./hepdata-downloads. Specifies where to download the files. :return: dictionary mapping id to list of downloaded files. + :rtype: dict[int, list[str]] """ url_map = self._build_urls(id_list, file_format, ids, table_name) file_map = {} - for id, url in url_map.items(): + for record_id, url in url_map.items(): if self.verbose is True: print("Downloading: " + url) files_downloaded = download_url(url, download_dir) - file_map[id] = files_downloaded + file_map[record_id] = files_downloaded return file_map def fetch_names(self, id_list, ids=None): From 5e3e70daeab5296f08890e12663cf26febb56425 Mon Sep 17 00:00:00 2001 From: Tom Runting Date: Mon, 1 Sep 2025 09:36:38 +0100 Subject: [PATCH 4/6] Bump version --- hepdata_cli/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hepdata_cli/version.py b/hepdata_cli/version.py index d31c31e..493f741 100644 --- a/hepdata_cli/version.py +++ b/hepdata_cli/version.py @@ -1 +1 @@ -__version__ = "0.2.3" +__version__ = "0.3.0" From 5af1d4f8e30da232251acda756fc0f78fe6d8760 Mon Sep 17 00:00:00 2001 From: Tom Runting Date: Mon, 1 Sep 2025 09:44:42 +0100 Subject: [PATCH 5/6] Update example 4 --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6317015..0b9bbd0 100644 --- a/README.md +++ b/README.md @@ -155,10 +155,11 @@ or equivalently ```python id_list = client.find('reactions:"P P--> LQ LQ"', ids='inspire') -client.download(id_list, ids='inspire', file_format='csv') +downloads = client.download(id_list, ids='inspire', file_format='csv') +print(downloads) # {'1222326': ['./hepdata-downloads/HEPData-ins1222326-v1-csv/Table1.csv', ...], ...} ``` -downloads four .tar.gz archives containing csv files and unpacks them in the default ```./hepdata-downloads``` directory. +downloads four .tar.gz archives containing csv files and unpacks them in the default ```./hepdata-downloads``` directory. Using the API, a dictionary mapping ids to the downloaded files is returned. ### Example 5 - find table names in records: From 89ddc2ee2ecf9b49cff0e99a7033586f5d160ee1 Mon Sep 17 00:00:00 2001 From: Tom Runting Date: Mon, 1 Sep 2025 11:56:52 +0100 Subject: [PATCH 6/6] Add tests for tar unpacking --- tests/test_download.py | 73 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/tests/test_download.py b/tests/test_download.py index 5f87260..50ceb3f 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -3,10 +3,13 @@ import pytest import os import shutil +import tarfile +import tempfile +from unittest.mock import patch from click.testing import CliRunner -from hepdata_cli.api import Client, mkdir +from hepdata_cli.api import Client, download_url, mkdir from hepdata_cli.cli import cli @@ -76,3 +79,71 @@ def test_cli_download(id_list, file_format, ids, table): assert result.exit_code == 0 assert len(os.listdir(test_download_dir)) > 0 cleanup(test_download_dir) + + +# utility function testing + +@pytest.mark.parametrize("files_raises", [{"file": "test.txt", "raises": False}, + {"file": "../test.txt", "raises": True}, + {"file": None, "raises": True}]) +def test_tar_unpack(files_raises): + """ + Test the unpacking of a tarfile + """ + filename = files_raises["file"] + raises = files_raises["raises"] + if filename is None: # To hit FileNotFoundError branch + filename = 'test.txt' + real_exists = os.path.exists + def mock_exists(path): + if path.endswith(filename): + return False + return real_exists(path) + exists_patcher = patch('os.path.exists', mock_exists) + exists_patcher.start() + + # Create a some tarfile with known content + with tempfile.NamedTemporaryFile(delete=False, suffix='.tar.gz') as tmp: + tar_path = tmp.name + with tarfile.open(tar_path, "w:gz") as tar: + info = tarfile.TarInfo(name=filename) + content = b"Hello, World!" + info.size = len(content) + temp_content_file = tempfile.NamedTemporaryFile(delete=False) + try: + temp_content_file.write(content) + temp_content_file.close() + tar.add(temp_content_file.name, arcname=filename) + finally: + os.remove(temp_content_file.name) + + test_download_dir = './.pytest_downloads/' + mkdir(test_download_dir) + assert len(os.listdir(test_download_dir)) == 0 + + # Mock the requests part to return our tarfile + with patch('hepdata_cli.api.is_downloadable', return_value=True), \ + patch('hepdata_cli.api.resilient_requests') as mock_requests, \ + patch('hepdata_cli.api.getFilename_fromCd', return_value='test.tar.gz'): + + mock_response = mock_requests.return_value + mock_response.content = open(tar_path, 'rb').read() + mock_response.headers = {'content-disposition': 'filename=test.tar.gz'} + + # Test the download_url function + try: + if raises: + with pytest.raises(Exception): + files = download_url('http://example.com/test.tar.gz', test_download_dir) + else: + files = download_url('http://example.com/test.tar.gz', test_download_dir) + assert len(files) == 1 + for f in files: + assert os.path.exists(f) + with open(f, 'rb') as fr: + assert fr.read() == b"Hello, World!" + finally: + exists_patcher.stop() if filename is None else None + if os.path.exists(tar_path): + os.remove(tar_path) + cleanup(test_download_dir)