diff --git a/pygeoapi/provider/opensearch_.py b/pygeoapi/provider/opensearch_.py index 1b46f6af4..bad77496a 100644 --- a/pygeoapi/provider/opensearch_.py +++ b/pygeoapi/provider/opensearch_.py @@ -35,9 +35,9 @@ import logging import uuid -from opensearchpy import OpenSearch, helpers +from opensearchpy import OpenSearch, RequestsHttpConnection, helpers from opensearch_dsl import Search - +from requests_auth_aws_sigv4 import AWSSigV4 from pygeofilter.backends.opensearch import to_filter from pygeoapi.crs import crs_transform @@ -65,14 +65,31 @@ def __init__(self, provider_def): self.select_properties = [] self.os_host, self.index_name = self.data.rsplit('/', 1) - + self.aws_role = provider_def.get("aws_role") + LOGGER.debug('Setting OpenSearch properties') LOGGER.debug(f'host: {self.os_host}') LOGGER.debug(f'index: {self.index_name}') - + LOGGER.debug(f'index: {self.aws_role}') + LOGGER.debug('Connecting to OpenSearch') - self.os_ = OpenSearch(self.os_host, verify_certs=0) + + token_url = 'http://169.254.169.254/latest/api/token' + token_headers = {'X-aws-ec2-metadata-token-ttl-seconds': '21600'} + token = requests.put(token_url, headers=token_headers).text + creds_headers = {'X-aws-ec2-metadata-token': token} + creds_url = f'http://169.254.169.254/latest/meta-data/iam/security-credentials/{self.aws_role}' + creds_json = requests.get(creds_url, headers=creds_headers).json() + + aws_auth = AWSSigV4('es', + aws_access_key_id=creds_json['AccessKeyId'], + aws_secret_access_key=creds_json['SecretAccessKey'], + aws_session_token=creds_json['Token'], + region='us-east-1') + + self.os_ = OpenSearch(hosts=self.os_host, http_auth=aws_auth, use_ssl=True, verify_certs=True, connection_class=RequestsHttpConnection) + if not self.os_.ping(): msg = f'Cannot connect to OpenSearch: {self.os_host}' LOGGER.error(msg) @@ -373,7 +390,7 @@ def get(self, identifier, **kwargs): LOGGER.debug(f'Query: {query}') try: - result = self.os_search(index=self.index_name, **query) + result = self.os_.search(index=self.index_name, body=query) if len(result['hits']['hits']) == 0: LOGGER.error(err) raise ProviderItemNotFoundError(err) @@ -425,7 +442,7 @@ def update(self, identifier, item): identifier, json_data = self._load_and_prepare_item( item, identifier, raise_if_exists=False) - _ = self.os_index(index=self.index_name, id=identifier, body=json_data) + _ = self.os_.index(index=self.index_name, id=identifier, body=json_data) return True @@ -439,7 +456,7 @@ def delete(self, identifier): """ LOGGER.debug(f'Deleting item {identifier}') - _ = self.os_delete(index=self.index_name, id=identifier) + _ = self.os_.delete(index=self.index_name, id=identifier) return True