From 5a719c6e0c791cc9d3dd7371ebf40f97268a8c7c Mon Sep 17 00:00:00 2001 From: Martin Habedank Date: Wed, 15 Oct 2025 10:53:58 +0100 Subject: [PATCH 1/5] let find() take format as argument to allow returning as list --- hepdata_cli/api.py | 16 ++++++++++++---- tests/test_search.py | 18 +++++++++--------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/hepdata_cli/api.py b/hepdata_cli/api.py index 60e7ec6..f604217 100644 --- a/hepdata_cli/api.py +++ b/hepdata_cli/api.py @@ -32,15 +32,18 @@ def __init__(self, verbose=False): # check service availability resilient_requests('get', SITE_URL + '/ping') - def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_per_page=MATCHES_PER_PAGE): + def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_per_page=MATCHES_PER_PAGE, format=str): """ Search function for the hepdata database. Calls hepdata.net search function. :param query: string passed to hepdata.net search function. See advanced search tips at hepdata.net. :param keyword: filters return dictionary for given keyword. Exact match is first attempted, otherwise partial match is accepted. :param ids: accepts one of ("arxiv", "inspire", "hepdata"). + :param max_matches: maximum number of matches to return. Default is 10,000. + :param matches_per_page: number of matches per page. Default is 10. + :param format: specifies the return format if 'ids' is specified. Allowed formats are: str, list. Default is str. - :return: returns a list of (filtered if 'keyword' is specified) dictionaries for the search matches. If 'ids' is specified it instead returns a list of ids as a string. + :return: returns a list of (filtered if 'keyword' is specified) dictionaries for the search matches. If 'ids' is specified it instead returns a list of ids in the format 'format'. 
""" find_results = [] for counter in range(int(max_matches / matches_per_page)): @@ -53,7 +56,7 @@ def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_p # return full list of dictionary find_results += data['results'] else: - assert ids in [None, "arxiv", "inspire", "hepdata", "id"], "allowd ids are: arxiv, inspire and hepdata" + assert ids in [None, "arxiv", "inspire", "hepdata", "id"], "allowed ids are: arxiv, inspire and hepdata" if ids is not None: if ids == "hepdata": ids = "id" @@ -76,7 +79,12 @@ def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_p if ids is None: return find_results else: - return ' '.join(find_results) + if format==str: + return ' '.join(find_results) + elif format==list: + return find_results + else: + raise TypeError(f"Cannot return results in specfied format: {format}. Allowed formats are: {str}, {list}.") def download(self, id_list, file_format=None, ids=None, table_name='', download_dir='./hepdata-downloads'): """ diff --git a/tests/test_search.py b/tests/test_search.py index ccf5633..18c5682 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -4,16 +4,17 @@ from click.testing import CliRunner -from hepdata_cli.api import Client +from hepdata_cli.api import Client, MAX_MATCHES, MATCHES_PER_PAGE from hepdata_cli.cli import cli # arguments for testing test_api_find_arguments = [ - ('reactions:"P P --> LQ LQ X"', None, None), - ('reactions:"P P --> LQ LQ"', 'year', None), - ('phrases:"(diffractive AND elastic)"', None, 'arxiv'), + ('reactions:"P P --> LQ LQ X"', None, None, None), + ('reactions:"P P --> LQ LQ"', 'year', None, None), + ('phrases:"(diffractive AND elastic)"', None, 'arxiv', str), + ('phrases:"(diffractive AND elastic)"', None, 'arxiv', list), ] test_cli_find_arguments = [ @@ -24,17 +25,16 @@ # api test -@pytest.mark.parametrize("query, keyword, ids", test_api_find_arguments) -def test_api_find(query, keyword, ids): +@pytest.mark.parametrize("query, keyword, 
ids, format", test_api_find_arguments) +def test_api_find(query, keyword, ids, format): client = Client(verbose=True) - search_result = client.find(query, keyword, ids) + search_result = client.find(query, keyword, ids, format=format) if ids is None: assert type(search_result) is list if len(search_result) > 0: assert all([type(entry) is dict for entry in search_result]) else: - assert type(search_result) is str - + assert type(search_result) is format # cli testing From eb26509fc66609bd2e84fbe0510653c0a90a38ae Mon Sep 17 00:00:00 2001 From: Martin Habedank Date: Wed, 15 Oct 2025 11:05:23 +0100 Subject: [PATCH 2/5] add set as format --- hepdata_cli/api.py | 4 +++- tests/test_search.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/hepdata_cli/api.py b/hepdata_cli/api.py index f604217..155d502 100644 --- a/hepdata_cli/api.py +++ b/hepdata_cli/api.py @@ -41,7 +41,7 @@ def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_p :param ids: accepts one of ("arxiv", "inspire", "hepdata"). :param max_matches: maximum number of matches to return. Default is 10,000. :param matches_per_page: number of matches per page. Default is 10. - :param format: specifies the return format if 'ids' is specified. Allowed formats are: str, list. Default is str. + :param format: specifies the return format if 'ids' is specified. Allowed formats are: str, list, set. Default is str. :return: returns a list of (filtered if 'keyword' is specified) dictionaries for the search matches. If 'ids' is specified it instead returns a list of ids in the format 'format'. """ @@ -83,6 +83,8 @@ def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_p return ' '.join(find_results) elif format==list: return find_results + elif format==set: + return set(find_results) else: raise TypeError(f"Cannot return results in specfied format: {format}. 
Allowed formats are: {str}, {list}.") diff --git a/tests/test_search.py b/tests/test_search.py index 18c5682..2e99234 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -15,6 +15,7 @@ ('reactions:"P P --> LQ LQ"', 'year', None, None), ('phrases:"(diffractive AND elastic)"', None, 'arxiv', str), ('phrases:"(diffractive AND elastic)"', None, 'arxiv', list), + ('reactions:"P P --> LQ LQ X"', None, 'arxiv', set), ] test_cli_find_arguments = [ From ca7c3a177b4aad261d8d21def805540510b568b1 Mon Sep 17 00:00:00 2001 From: Martin Habedank Date: Wed, 15 Oct 2025 11:55:25 +0100 Subject: [PATCH 3/5] update readme, add tuple case, extend tests to download --- README.md | 5 +++-- hepdata_cli/api.py | 10 +++++----- hepdata_cli/version.py | 2 +- tests/test_download.py | 33 +++++++++++++++++++++++++++------ tests/test_search.py | 3 ++- 5 files changed, 38 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 340ef59..e504c33 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,8 @@ client.fetch_names(id_list, ids) client.upload(path_to_file, email, recid, invitation_cookie, sandbox, password) ``` +`client.find()` takes the keyword argument `format` to specify which format from `str`, `list`, `set`, or `tuple` shall be returned. +Default is `str`. ## Examples @@ -188,8 +190,7 @@ Then, ```python import hepdata_cli hepdata_client = hepdata_cli.Client() -id_list = hepdata_client.find('reactions:"P P --> LQ LQ X"', ids='arxiv') -id_list = id_list.split() +id_list = hepdata_client.find('reactions:"P P --> LQ LQ X"', ids='arxiv', format=list) print(id_list) # ['1605.06035', '2101.11582', ...] import arxiv diff --git a/hepdata_cli/api.py b/hepdata_cli/api.py index 155d502..a4be3bc 100644 --- a/hepdata_cli/api.py +++ b/hepdata_cli/api.py @@ -41,7 +41,7 @@ def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_p :param ids: accepts one of ("arxiv", "inspire", "hepdata"). :param max_matches: maximum number of matches to return. 
Default is 10,000. :param matches_per_page: number of matches per page. Default is 10. - :param format: specifies the return format if 'ids' is specified. Allowed formats are: str, list, set. Default is str. + :param format: specifies the return format if 'ids' is specified. Allowed formats are: str, list, set, tuple. Default is str. :return: returns a list of (filtered if 'keyword' is specified) dictionaries for the search matches. If 'ids' is specified it instead returns a list of ids in the format 'format'. """ @@ -83,8 +83,8 @@ def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_p return ' '.join(find_results) elif format==list: return find_results - elif format==set: - return set(find_results) + elif format in (set, tuple): + return format(find_results) else: raise TypeError(f"Cannot return results in specfied format: {format}. Allowed formats are: {str}, {list}.") @@ -155,14 +155,14 @@ def _build_urls(self, id_list, file_format, ids, table_name): """ Builds urls for download and fetch_names, given the specified parameters. - :param id_list: list of ids to download. + :param id_list: list of ids to download. Format is tuple, list, set or space-separated string. :param file_format: accepts one of ('csv', 'root', 'yaml', 'yoda', 'yoda1', 'yoda.h5', 'json'). :param ids: accepts one of ('inspire', 'hepdata'). :param table_name: restricts download to specific tables. :return: dictionary mapping id to url. """ - if type(id_list) not in (tuple, list): + if type(id_list) not in (tuple, list, set): id_list = id_list.split() assert len(id_list) > 0, 'Ids are required.' 
assert file_format in ALLOWED_FORMATS, f"allowed formats are: {ALLOWED_FORMATS}" diff --git a/hepdata_cli/version.py b/hepdata_cli/version.py index 493f741..260c070 100644 --- a/hepdata_cli/version.py +++ b/hepdata_cli/version.py @@ -1 +1 @@ -__version__ = "0.3.0" +__version__ = "0.3.1" diff --git a/tests/test_download.py b/tests/test_download.py index 50ceb3f..4e834f8 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -38,13 +38,23 @@ def cleanup(directory): test_api_download_arguments = [ (["73322"], "json", "hepdata", ''), - (["1222326", "1694381", "1462258", "1309874"], "csv", "inspire", ''), + ("1222326 1694381 1462258 1309874", "csv", "inspire", ''), # str + (["1222326", "1694381", "1462258", "1309874"], "csv", "inspire", ''), # list + ({"1222326", "1694381", "1462258", "1309874"}, "csv", "inspire", ''), # set + (("1222326", "1694381", "1462258", "1309874"), "csv", "inspire", ''), # tuple (["61434"], "yaml", "hepdata", "Table1"), (["1762350"], "yoda", "inspire", "Number density and Sum p_T pT>0.15 GeV/c"), (["2862529"], "yoda.h5", "inspire", "95% CL upper limit on XSEC times BF"), (["2862529"], "yoda.h5", "inspire", '') ] +test_api_find_download_arguments = [ + ("json", "hepdata", str), + ("csv", "inspire", list), + ("json", "inspire", set), + ("csv", "hepdata", tuple), +] + test_cli_download_arguments = [ (["2862529"], "json", "inspire", ''), (["1222326", "1694381", "1462258", "1309874"], "root", "inspire", ''), @@ -54,18 +64,29 @@ def cleanup(directory): # api testing +def download_and_test(client, id_list, file_format, ids, table, test_download_dir): + path_map = client.download(id_list, file_format, ids, table, test_download_dir) + file_paths = [fp for fps in path_map.values() for fp in fps] + assert len(os.listdir(test_download_dir)) > 0 + assert all(os.path.exists(fp) for fp in file_paths) + cleanup(test_download_dir) + @pytest.mark.parametrize("id_list, file_format, ids, table", test_api_download_arguments) def 
test_api_download(id_list, file_format, ids, table): test_download_dir = './.pytest_downloads/' mkdir(test_download_dir) assert len(os.listdir(test_download_dir)) == 0 client = Client(verbose=True) - path_map = client.download(id_list, file_format, ids, table, test_download_dir) - file_paths = [fp for fps in path_map.values() for fp in fps] - assert len(os.listdir(test_download_dir)) > 0 - assert all(os.path.exists(fp) for fp in file_paths) - cleanup(test_download_dir) + download_and_test(client, id_list, file_format, ids, table, test_download_dir) +@pytest.mark.parametrize("file_format, ids, format", test_api_find_download_arguments) +def test_api_find_download(file_format, ids, format): + test_download_dir = './.pytest_downloads/' + mkdir(test_download_dir) + assert len(os.listdir(test_download_dir)) == 0 + client = Client(verbose=True) + id_list = client.find('reactions:"P P --> LQ LQ"', ids=ids, format=format) + download_and_test(client, id_list, file_format, ids, '', test_download_dir) # cli testing diff --git a/tests/test_search.py b/tests/test_search.py index 2e99234..82f89d7 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -14,8 +14,9 @@ ('reactions:"P P --> LQ LQ X"', None, None, None), ('reactions:"P P --> LQ LQ"', 'year', None, None), ('phrases:"(diffractive AND elastic)"', None, 'arxiv', str), - ('phrases:"(diffractive AND elastic)"', None, 'arxiv', list), + ('phrases:"(diffractive AND elastic)"', None, 'hepdata', list), ('reactions:"P P --> LQ LQ X"', None, 'arxiv', set), + ('reactions:"P P --> LQ LQ X"', None, 'inspire', tuple), ] test_cli_find_arguments = [ From 5c6b8dcdbf723b59498afe0718154be1c6f6741b Mon Sep 17 00:00:00 2001 From: Martin Habedank Date: Wed, 15 Oct 2025 14:17:40 +0100 Subject: [PATCH 4/5] check whether id_list is instance of str instead of negating other possibilities --- hepdata_cli/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hepdata_cli/api.py b/hepdata_cli/api.py index a4be3bc..a0d7c5e 
100644 --- a/hepdata_cli/api.py +++ b/hepdata_cli/api.py @@ -162,7 +162,7 @@ def _build_urls(self, id_list, file_format, ids, table_name): :return: dictionary mapping id to url. """ - if type(id_list) not in (tuple, list, set): + if isinstance(id_list, str): id_list = id_list.split() assert len(id_list) > 0, 'Ids are required.' assert file_format in ALLOWED_FORMATS, f"allowed formats are: {ALLOWED_FORMATS}" From aec4772190c3071512cecb75facdb29683e8ecfd Mon Sep 17 00:00:00 2001 From: Martin Habedank Date: Wed, 15 Oct 2025 16:16:43 +0100 Subject: [PATCH 5/5] correct error message --- hepdata_cli/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hepdata_cli/api.py b/hepdata_cli/api.py index a0d7c5e..91252ad 100644 --- a/hepdata_cli/api.py +++ b/hepdata_cli/api.py @@ -86,7 +86,7 @@ def find(self, query, keyword=None, ids=None, max_matches=MAX_MATCHES, matches_p elif format in (set, tuple): return format(find_results) else: - raise TypeError(f"Cannot return results in specfied format: {format}. Allowed formats are: {str}, {list}.") + raise TypeError(f"Cannot return results in specfied format: {format}. Allowed formats are: {str}, {list}, {set}, {tuple}.") def download(self, id_list, file_format=None, ids=None, table_name='', download_dir='./hepdata-downloads'): """