Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions pygeoapi/provider/opensearch_.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
import logging
import uuid

from opensearchpy import OpenSearch, helpers
from opensearchpy import OpenSearch, RequestsHttpConnection, helpers
from opensearch_dsl import Search

from requests_auth_aws_sigv4 import AWSSigV4
from pygeofilter.backends.opensearch import to_filter

from pygeoapi.crs import crs_transform
Expand Down Expand Up @@ -65,14 +65,31 @@ def __init__(self, provider_def):
self.select_properties = []

self.os_host, self.index_name = self.data.rsplit('/', 1)

self.aws_role = provider_def.get("aws_role")

LOGGER.debug('Setting OpenSearch properties')

LOGGER.debug(f'host: {self.os_host}')
LOGGER.debug(f'index: {self.index_name}')

LOGGER.debug(f'index: {self.aws_role}')

LOGGER.debug('Connecting to OpenSearch')
self.os_ = OpenSearch(self.os_host, verify_certs=0)

token_url = 'http://169.254.169.254/latest/api/token'
token_headers = {'X-aws-ec2-metadata-token-ttl-seconds': '21600'}
token = requests.put(token_url, headers=token_headers).text
creds_headers = {'X-aws-ec2-metadata-token': token}
creds_url = f'http://169.254.169.254/latest/meta-data/iam/security-credentials/{self.aws_role}'
creds_json = requests.get(creds_url, headers=creds_headers).json()

aws_auth = AWSSigV4('es',
aws_access_key_id=creds_json['AccessKeyId'],
aws_secret_access_key=creds_json['SecretAccessKey'],
aws_session_token=creds_json['Token'],
region='us-east-1')

self.os_ = OpenSearch(hosts=self.os_host, http_auth=aws_auth, use_ssl=True, verify_certs=True, connection_class=RequestsHttpConnection)

if not self.os_.ping():
msg = f'Cannot connect to OpenSearch: {self.os_host}'
LOGGER.error(msg)
Expand Down Expand Up @@ -373,7 +390,7 @@ def get(self, identifier, **kwargs):

LOGGER.debug(f'Query: {query}')
try:
result = self.os_search(index=self.index_name, **query)
result = self.os_.search(index=self.index_name, body=query)
if len(result['hits']['hits']) == 0:
LOGGER.error(err)
raise ProviderItemNotFoundError(err)
Expand Down Expand Up @@ -425,7 +442,7 @@ def update(self, identifier, item):
identifier, json_data = self._load_and_prepare_item(
item, identifier, raise_if_exists=False)

_ = self.os_index(index=self.index_name, id=identifier, body=json_data)
_ = self.os_.index(index=self.index_name, id=identifier, body=json_data)

return True

Expand All @@ -439,7 +456,7 @@ def delete(self, identifier):
"""

LOGGER.debug(f'Deleting item {identifier}')
_ = self.os_delete(index=self.index_name, id=identifier)
_ = self.os_.delete(index=self.index_name, id=identifier)

return True

Expand Down