From aa1b7b2088a162fc8e19f2520003fa791d72387d Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Thu, 14 Mar 2019 21:42:19 -0300 Subject: [PATCH 01/13] hooks: overhaul hook config and running mechanism - Prepare for running hooks individually instead of as a group, to be used to integrate them hooks into the stack execution graph - Add extra config keys for hook execution: profile, region, requires, required_by - Add stack-level key for integrated hooks: build_hooks and destroy_hooks - Add methods to context to retrieve hooks in internal format based on the action to be run --- stacker/config/__init__.py | 14 +++ stacker/context.py | 11 +- stacker/exceptions.py | 17 +++ stacker/hooks/__init__.py | 134 +++++++++++++++++++++++ stacker/tests/test_context.py | 57 +++++++++- stacker/tests/test_util.py | 200 ++++++++++++++-------------------- stacker/util.py | 71 +----------- 7 files changed, 312 insertions(+), 192 deletions(-) diff --git a/stacker/config/__init__.py b/stacker/config/__init__.py index 5fdde4162..f3666b0a6 100644 --- a/stacker/config/__init__.py +++ b/stacker/config/__init__.py @@ -280,6 +280,8 @@ class PackageSources(Model): class Hook(Model): + name = StringType(serialize_when_none=None) + path = StringType(required=True) required = BooleanType(default=True) @@ -290,6 +292,14 @@ class Hook(Model): args = DictType(AnyType) + required_by = ListType(StringType, serialize_when_none=False) + + requires = ListType(StringType, serialize_when_none=False) + + region = StringType(serialize_when_none=False) + + profile = StringType(serialize_when_none=False) + class Target(Model): name = StringType(required=True) @@ -414,10 +424,14 @@ class Config(Model): post_build = ListType(ModelType(Hook), serialize_when_none=False) + build_hooks = ListType(ModelType(Hook), serialize_when_none=False) + pre_destroy = ListType(ModelType(Hook), serialize_when_none=False) post_destroy = ListType(ModelType(Hook), serialize_when_none=False) + destroy_hooks = ListType(ModelType(Hook), serialize_when_none=False) + tags = DictType(StringType, serialize_when_none=False) template_indent = StringType(serialize_when_none=False) diff --git a/stacker/context.py b/stacker/context.py index 0eac9236f..96af24ae2 100644 --- a/stacker/context.py +++ b/stacker/context.py @@ -4,10 +4,12 @@ from builtins import object import collections import logging +import threading from stacker.config import Config from .stack import Stack from .target import Target +from .hooks import ActionHooks logger = logging.getLogger(__name__) @@ -57,6 +59,8 @@ def __init__(self, environment=None, self.force_stacks = force_stacks or [] self.hook_data = {} + self._hook_lock = threading.RLock() + @property def namespace(self): return self.config.namespace @@ -136,6 +140,7 @@ def get_targets(self): for target_def in self.config.targets or []: target = Target(target_def) targets.append(target) + self._targets = targets return self._targets @@ -183,6 +188,9 @@ def get_fqn(self, name=None): """ return get_fqn(self._base_fqn, self.namespace_delimiter, name) + def get_hooks_for_action(self, action_name): + return ActionHooks.from_config(self.config, action_name) + def set_hook_data(self, key, data): """Set hook data for the given key. @@ -201,4 +209,5 @@ def set_hook_data(self, key, data): raise KeyError("Hook data for key %s already exists, each hook " "must have a unique data_key.", key) - self.hook_data[key] = data + with self._hook_lock: + self.hook_data[key] = data diff --git a/stacker/exceptions.py b/stacker/exceptions.py index e1ae8339f..e2f463336 100644 --- a/stacker/exceptions.py +++ b/stacker/exceptions.py @@ -273,3 +273,20 @@ def __init__(self, exception, stack, dependency): "as a dependency of '%s': %s" ) % (dependency, stack, str(exception)) super(GraphError, self).__init__(message) + + +class HookExecutionFailed(Exception): + """Raised when running a required hook fails""" + + def __init__(self, hook, result=None, exception=None): + self.hook = hook + self.result = result + self.exception = exception + + if self.exception: + message = ("Hook '{}' threw exception: {}".format( + hook.name, exception)) + else: + message = ("Hook '{}' failed (result: {})".format( + hook.name, result)) + super(HookExecutionFailed, self).__init__(message) diff --git a/stacker/hooks/__init__.py b/stacker/hooks/__init__.py index e69de29bb..39912d3ab 100644 --- a/stacker/hooks/__init__.py +++ b/stacker/hooks/__init__.py @@ -0,0 +1,134 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +from collections import Mapping, namedtuple + +from stacker.exceptions import HookExecutionFailed +from stacker.util import load_object_from_string + + +logger = logging.getLogger(__name__) + + +def no_op(*args, **kwargs): + logger.info("No-op hook called with arguments: {}".format(kwargs)) + return True + + +class Hook(object): + @classmethod + def from_definition(cls, definition, name_fallback=None): + """Create a hook instance from a config definition""" + name = definition.name or name_fallback + if not name: + raise ValueError('Hook definition does not include name and no ' + 'fallback provided') + + data_key = definition.data_key or name + return cls( + name=name, + path=definition.path, + required=definition.required, + enabled=definition.enabled, + data_key=data_key, + args=definition.args, + required_by=definition.required_by, + requires=definition.requires, + profile=definition.profile, + region=definition.region) + + def __init__(self, name, path, required=True, enabled=True, + data_key=None, args=None, required_by=None, requires=None, + profile=None, region=None): + self.path = path + self.name = name + self.required = required + self.enabled = enabled + self.data_key = data_key + self.args = args or {} + self.required_by = set(required_by or []) + self.requires = set(requires or []) + self.profile = profile + self.region = region + + def run(self, provider, context): + """Run a Hook and capture its result + + These are pieces of external code that we want to run in addition to + CloudFormation deployments, to perform actions that are not easily + handled in a template. + + Args: + provider (:class:`stacker.provider.base.BaseProvider`): + Provider to pass to the hook + context (:class:`stacker.context.Context`): The current stacker + context + Raises: + :class:`stacker.exceptions.HookExecutionFailed`: + if the hook failed + Returns: the result of the hook if it was run, ``None`` if it was + skipped. + """ + + logger.info("Executing hook %s", self) + + data_key = self.data_key + required = self.required + kwargs = self.args or {} + enabled = self.enabled + + if not enabled: + logger.debug("Hook %s is disabled, skipping", self.name) + return + + try: + method = load_object_from_string(self.path) + except (AttributeError, ImportError) as e: + logger.exception("Unable to load method at %s for hook %s:", + self.path, self.name) + if required: + raise HookExecutionFailed(self, exception=e) + + return + + try: + result = method(context=context, provider=provider, **kwargs) + except Exception as e: + if required: + raise HookExecutionFailed(self, exception=e) + + return + + if not result: + if required: + raise HookExecutionFailed(self, result=result) + + logger.warning("Non-required hook %s failed. Return value: %s", + self.name, result) + return result + + if isinstance(result, Mapping): + if data_key: + logger.debug("Adding result for hook %s to context in " + "data_key %s.", self.name, data_key) + context.set_hook_data(data_key, result) + + return result + + +class ActionHooks(namedtuple('ActionHooks', 'action_name pre post custom')): + @classmethod + def from_config(cls, config, action_name): + def from_key(key): + for i, hook_def in enumerate(config.get(key) or [], 1): + name_fallback = '{}_{}_{}'.format(key, i, hook_def.path) + yield Hook.from_definition(hook_def, + name_fallback=name_fallback) + + return ActionHooks( + action_name=action_name, + pre=list(from_key('pre_{}'.format(action_name))), + post=list(from_key('post_{}'.format(action_name))), + custom=list(from_key('{}_hooks'.format(action_name)))) diff --git a/stacker/tests/test_context.py b/stacker/tests/test_context.py index 088fed5f0..629509710 100644 --- a/stacker/tests/test_context.py +++ b/stacker/tests/test_context.py @@ -1,11 +1,12 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import + import unittest +import mock from stacker.context import Context, get_fqn from stacker.config import load, Config -from stacker.util import handle_hooks class TestContext(unittest.TestCase): @@ -123,11 +124,61 @@ def test_hook_with_sys_path(self): "args": { "value": "mockResult"}}]}) load(config) + context = Context(config=config) - stage = "pre_build" - handle_hooks(stage, context.config[stage], "mock-region-1", context) + provider = mock.Mock() + hooks = context.get_hooks_for_action('build') + hook = hooks.pre[0] + + hook.run(provider, context) self.assertEqual("mockResult", context.hook_data["myHook"]["result"]) + def test_get_hooks_for_action(self): + config = Config({ + "pre_build": [ + {"path": "fake.hook"}, + {"name": "pre_build_test", "path": "fake.hook"}, + {"path": "fake.hook"} + ], + "post_build": [ + {"path": "fake.hook"}, + {"name": "post_build_test", "path": "fake.hook"}, + {"path": "fake.hook"} + ], + "build_hooks": [ + {"path": "fake.hook"}, + {"name": "build_test", "path": "fake.hook"}, + {"path": "fake.hook"} + ] + }) + + context = Context(config=config) + hooks = context.get_hooks_for_action('build') + + assert hooks.pre[0].name == "pre_build_1_fake.hook" + assert hooks.pre[1].name == "pre_build_test" + assert hooks.pre[2].name == "pre_build_3_fake.hook" + + assert hooks.post[0].name == "post_build_1_fake.hook" + assert hooks.post[1].name == "post_build_test" + assert hooks.post[2].name == "post_build_3_fake.hook" + + assert hooks.custom[0].name == "build_hooks_1_fake.hook" + assert hooks.custom[1].name == "build_test" + assert hooks.custom[2].name == "build_hooks_3_fake.hook" + + def test_hook_data_key_fallback(self): + config = Config({ + "build_hooks": [ + {"name": "my-hook", "path": "fake.hook"} + ] + }) + context = Context(config=config) + hooks = context.get_hooks_for_action("build") + hook = hooks.custom[0] + + assert hook.data_key == "my-hook" + class TestFunctions(unittest.TestCase): """ Test the module level functions """ diff --git a/stacker/tests/test_util.py b/stacker/tests/test_util.py index 9c4fa7635..cc6a8deab 100644 --- a/stacker/tests/test_util.py +++ b/stacker/tests/test_util.py @@ -3,23 +3,21 @@ from __future__ import absolute_import from future import standard_library standard_library.install_aliases() - -import unittest - -import string import os -import queue +import string +import unittest import mock import boto3 -from stacker.config import Hook, GitPackageSource +from stacker.config import GitPackageSource +from stacker.exceptions import HookExecutionFailed +from stacker.hooks import Hook from stacker.util import ( cf_safe_name, load_object_from_string, camel_to_snake, - handle_hooks, merge_map, yaml_to_ordered_dict, get_client_region, @@ -274,28 +272,7 @@ def test_SourceProcessor_helpers(self): ) -hook_queue = queue.Queue() - - -def mock_hook(*args, **kwargs): - hook_queue.put(kwargs) - return True - - -def fail_hook(*args, **kwargs): - return None - - -def exception_hook(*args, **kwargs): - raise Exception - - -def context_hook(*args, **kwargs): - return "context" in kwargs - - -def result_hook(*args, **kwargs): - return {"foo": "bar"} +mock_hook = mock.Mock() class TestHooks(unittest.TestCase): @@ -304,116 +281,103 @@ def setUp(self): self.context = mock_context(namespace="namespace") self.provider = mock_provider(region="us-east-1") - def test_empty_hook_stage(self): - hooks = [] - handle_hooks("fake", hooks, self.provider, self.context) - self.assertTrue(hook_queue.empty()) + global mock_hook + mock_hook = mock.Mock() def test_missing_required_hook(self): - hooks = [Hook({"path": "not.a.real.path", "required": True})] - with self.assertRaises(ImportError): - handle_hooks("missing", hooks, self.provider, self.context) + hook = Hook("test", path="not.a.real.path", required=True) + + with self.assertRaises(HookExecutionFailed) as raised: + hook.run(self.provider, self.context) + self.assertIsInstance(ImportError, raised.exception.exception) def test_missing_required_hook_method(self): - hooks = [{"path": "stacker.hooks.blah", "required": True}] - with self.assertRaises(AttributeError): - handle_hooks("missing", hooks, self.provider, self.context) + hook = Hook("test", path="stacker.hooks.blah", required=True) + + with self.assertRaises(HookExecutionFailed) as raised: + hook.run(self.provider, self.context) + self.assertIsInstance(AttributeError, raised.exception.exception) def test_missing_non_required_hook_method(self): - hooks = [Hook({"path": "stacker.hooks.blah", "required": False})] - handle_hooks("missing", hooks, self.provider, self.context) - self.assertTrue(hook_queue.empty()) + hook = Hook("test", path="stacker.hooks.blah", required=False) + + result = hook.run(self.provider, self.context) + self.assertIsNone(result) def test_default_required_hook(self): - hooks = [Hook({"path": "stacker.hooks.blah"})] - with self.assertRaises(AttributeError): - handle_hooks("missing", hooks, self.provider, self.context) - - def test_valid_hook(self): - hooks = [ - Hook({"path": "stacker.tests.test_util.mock_hook", - "required": True})] - handle_hooks("missing", hooks, self.provider, self.context) - good = hook_queue.get_nowait() - self.assertEqual(good["provider"].region, "us-east-1") - with self.assertRaises(queue.Empty): - hook_queue.get_nowait() + hook = Hook("test", path="stacker.hooks.blah") + + with self.assertRaises(HookExecutionFailed) as raised: + hook.run(self.provider, self.context) + self.assertIsInstance(AttributeError, raised.exception.exception) def test_valid_enabled_hook(self): - hooks = [ - Hook({"path": "stacker.tests.test_util.mock_hook", - "required": True, "enabled": True})] - handle_hooks("missing", hooks, self.provider, self.context) - good = hook_queue.get_nowait() - self.assertEqual(good["provider"].region, "us-east-1") - with self.assertRaises(queue.Empty): - hook_queue.get_nowait() - - def test_valid_enabled_false_hook(self): - hooks = [ - Hook({"path": "stacker.tests.test_util.mock_hook", - "required": True, "enabled": False})] - handle_hooks("missing", hooks, self.provider, self.context) - self.assertTrue(hook_queue.empty()) + hook = Hook("test", path="stacker.tests.test_util.mock_hook", + required=True, enabled=True) + + result = mock_hook.return_value = mock.Mock() + self.assertIs(result, hook.run(self.provider, self.context)) + mock_hook.assert_called_once() + + def test_valid_disabled_hook(self): + hook = Hook("test", path="stacker.tests.test_util.mock_hook", + required=True, enabled=False) + + self.assertIsNone(hook.run(self.provider, self.context)) + mock_hook.assert_not_called() def test_context_provided_to_hook(self): - hooks = [ - Hook({"path": "stacker.tests.test_util.context_hook", - "required": True})] - handle_hooks("missing", hooks, "us-east-1", self.context) + hook = Hook("test", path="stacker.tests.test_util.mock_hook", + required=True) + + def return_context(*args, **kwargs): + return kwargs['context'] + + mock_hook.side_effect = return_context + result = hook.run(self.provider, self.context) + self.assertIs(result, self.context) def test_hook_failure(self): - hooks = [ - Hook({"path": "stacker.tests.test_util.fail_hook", - "required": True})] - with self.assertRaises(SystemExit): - handle_hooks("fail", hooks, self.provider, self.context) - hooks = [{"path": "stacker.tests.test_util.exception_hook", - "required": True}] - with self.assertRaises(Exception): - handle_hooks("fail", hooks, self.provider, self.context) - hooks = [ - Hook({"path": "stacker.tests.test_util.exception_hook", - "required": False})] - # Should pass - handle_hooks("ignore_exception", hooks, self.provider, self.context) + hook = Hook("test", path="stacker.tests.test_util.mock_hook", + required=True) + + err = Exception() + mock_hook.side_effect = err + + with self.assertRaises(HookExecutionFailed) as raised: + hook.run(self.provider, self.context) + self.assertIs(hook, raised.exception.hook) + self.assertIs(err, raised.exception.exception) + + def test_hook_failure_skip(self): + hook = Hook("test", path="stacker.tests.test_util.mock_hook", + required=False) + + mock_hook.side_effect = Exception() + result = hook.run(self.provider, self.context) + self.assertIsNone(result) def test_return_data_hook(self): - hooks = [ - Hook({ - "path": "stacker.tests.test_util.result_hook", - "data_key": "my_hook_results" - }), - # Shouldn't return data - Hook({ - "path": "stacker.tests.test_util.context_hook" - }) - ] - handle_hooks("result", hooks, "us-east-1", self.context) + hook = Hook("test", path="stacker.tests.test_util.mock_hook", + data_key='test') + hook_data = {'hello': 'world'} + mock_hook.return_value = hook_data - self.assertEqual( - self.context.hook_data["my_hook_results"]["foo"], - "bar" - ) - # Verify only the first hook resulted in stored data - self.assertEqual( - list(self.context.hook_data.keys()), ["my_hook_results"] - ) + result = hook.run(self.provider, self.context) + self.assertEqual(hook_data, result) + self.assertEqual(hook_data, self.context.hook_data.get('test')) def test_return_data_hook_duplicate_key(self): - hooks = [ - Hook({ - "path": "stacker.tests.test_util.result_hook", - "data_key": "my_hook_results" - }), - Hook({ - "path": "stacker.tests.test_util.result_hook", - "data_key": "my_hook_results" - }) - ] + hook = Hook("test", path="stacker.tests.test_util.mock_hook", + data_key='test') + mock_hook.return_value = {'foo': 'bar'} + hook_data = {'hello': 'world'} + self.context.set_hook_data('test', hook_data) with self.assertRaises(KeyError): - handle_hooks("result", hooks, "us-east-1", self.context) + hook.run(self.provider, self.context) + + self.assertEqual(hook_data, self.context.hook_data['test']) class TestException1(Exception): diff --git a/stacker/util.py b/stacker/util.py index 4f95a52f6..41595dbec 100644 --- a/stacker/util.py +++ b/stacker/util.py @@ -16,7 +16,6 @@ import tempfile import zipfile -import collections from collections import OrderedDict import botocore.client @@ -26,7 +25,7 @@ from yaml.constructor import ConstructorError from yaml.nodes import MappingNode -from .awscli_yamlhelper import yaml_parse +from stacker.awscli_yamlhelper import yaml_parse from stacker.session_cache import get_session logger = logging.getLogger(__name__) @@ -337,74 +336,6 @@ def cf_safe_name(name): return "".join([uppercase_first_letter(part) for part in parts]) -def handle_hooks(stage, hooks, provider, context): - """ Used to handle pre/post_build hooks. - - These are pieces of code that we want to run before/after the builder - builds the stacks. - - Args: - stage (string): The current stage (pre_run, post_run, etc). - hooks (list): A list of :class:`stacker.config.Hook` containing the - hooks to execute. - provider (:class:`stacker.provider.base.BaseProvider`): The provider - the current stack is using. - context (:class:`stacker.context.Context`): The current stacker - context. - """ - if not hooks: - logger.debug("No %s hooks defined.", stage) - return - - hook_paths = [] - for i, h in enumerate(hooks): - try: - hook_paths.append(h.path) - except KeyError: - raise ValueError("%s hook #%d missing path." % (stage, i)) - - logger.info("Executing %s hooks: %s", stage, ", ".join(hook_paths)) - for hook in hooks: - data_key = hook.data_key - required = hook.required - kwargs = hook.args or {} - enabled = hook.enabled - if not enabled: - logger.debug("hook with method %s is disabled, skipping", - hook.path) - continue - try: - method = load_object_from_string(hook.path) - except (AttributeError, ImportError): - logger.exception("Unable to load method at %s:", hook.path) - if required: - raise - continue - try: - result = method(context=context, provider=provider, **kwargs) - except Exception: - logger.exception("Method %s threw an exception:", hook.path) - if required: - raise - continue - if not result: - if required: - logger.error("Required hook %s failed. Return value: %s", - hook.path, result) - sys.exit(1) - logger.warning("Non-required hook %s failed. Return value: %s", - hook.path, result) - else: - if isinstance(result, collections.Mapping): - if data_key: - logger.debug("Adding result for hook %s to context in " - "data_key %s.", hook.path, data_key) - context.set_hook_data(data_key, result) - else: - logger.debug("Hook %s returned result data, but no data " - "key set, so ignoring.", hook.path) - - def get_config_directory(): """Return the directory the config file is located in. From b821ea2fe4680dc6fdf899751ec52f9d2a7400f3 Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Fri, 15 Mar 2019 02:05:25 -0300 Subject: [PATCH 02/13] target: allow creation without config definition We need to define some syntethic targets to add more features to the execution graph, so make it possible to create instances without a config def. --- stacker/context.py | 2 +- stacker/target.py | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/stacker/context.py b/stacker/context.py index 96af24ae2..8fba2731c 100644 --- a/stacker/context.py +++ b/stacker/context.py @@ -138,7 +138,7 @@ def get_targets(self): if not hasattr(self, "_targets"): targets = [] for target_def in self.config.targets or []: - target = Target(target_def) + target = Target.from_definition(target_def) targets.append(target) self._targets = targets diff --git a/stacker/target.py b/stacker/target.py index b57b3e672..a2171fd9d 100644 --- a/stacker/target.py +++ b/stacker/target.py @@ -9,8 +9,15 @@ class Target(object): a set of stacks together that can be targeted with the `--targets` flag. """ - def __init__(self, definition): - self.name = definition.name - self.requires = definition.requires or [] - self.required_by = definition.required_by or [] - self.logging = False + @classmethod + def from_definition(cls, definition): + return cls(name=definition.name, + requires=definition.requires, + required_by=definition.required_by, + logging=False) + + def __init__(self, name, requires=None, required_by=None, logging=False): + self.name = name + self.requires = list(requires or []) + self.required_by = list(required_by or []) + self.logging = logging From 2680a11b0e1364e26682edf4718a207e36daf8bc Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Fri, 15 Mar 2019 14:00:56 -0300 Subject: [PATCH 03/13] plan: generalize step to accept hooks and targets --- stacker/actions/base.py | 9 ++-- stacker/plan.py | 71 ++++++++++++++++++----------- stacker/tests/actions/test_build.py | 2 +- stacker/tests/test_plan.py | 49 +++++++++++++------- 4 files changed, 81 insertions(+), 50 deletions(-) diff --git a/stacker/actions/base.py b/stacker/actions/base.py index 0763e5245..1d8f8a3dc 100644 --- a/stacker/actions/base.py +++ b/stacker/actions/base.py @@ -81,12 +81,11 @@ def plan(description, stack_action, context, def target_fn(*args, **kwargs): return COMPLETE - steps = [ - Step(stack, fn=stack_action, watch_func=tail) - for stack in context.get_stacks()] + steps = [Step.from_stack(stack, fn=stack_action, watch_func=tail) + for stack in context.get_stacks()] - steps += [ - Step(target, fn=target_fn) for target in context.get_targets()] + steps += [Step.from_target(target, fn=target_fn) + for target in context.get_targets()] graph = build_graph(steps) diff --git a/stacker/plan.py b/stacker/plan.py index 24b415e04..00760078e 100644 --- a/stacker/plan.py +++ b/stacker/plan.py @@ -8,6 +8,7 @@ import uuid import threading +from .stack import Stack from .util import stack_template_key_name from .exceptions import ( GraphError, @@ -44,26 +45,51 @@ def log_step(step): class Step(object): """State machine for executing generic actions related to stacks. Args: - stack (:class:`stacker.stack.Stack`): the stack associated - with this step + subject (Union[Stack, Target, Hook]): the subject associated with this + step. Usually a Stack, Target or Hook. fn (func): the function to run to execute the step. This function will be ran multiple times until the step is "done". watch_func (func): an optional function that will be called to "tail" the step action. """ - def __init__(self, stack, fn, watch_func=None): - self.stack = stack - self.status = PENDING - self.last_updated = time.time() + @classmethod + def from_stack(cls, stack, fn, **kwargs): + kwargs.setdefault('logging', stack.logging) + return cls(stack.name, subject=stack, fn=fn, **kwargs) + + @classmethod + def from_target(cls, target, fn, **kwargs): + kwargs.setdefault('logging', True) + return cls(target.name, subject=target, fn=fn, **kwargs) + + @classmethod + def from_hook(cls, hook, fn, **kwargs): + kwargs.setdefault('logging', True) + return cls(hook.name, subject=hook, fn=fn, **kwargs) + + def __init__(self, name, fn, subject=None, watch_func=None, requires=None, + required_by=None, logging=False): + self.name = name + self.subject = subject self.fn = fn + self.watch_func = watch_func + self.requires = set(requires or []) + self.required_by = set(required_by or []) + if subject is not None: + self.requires.update(subject.requires or []) + self.required_by.update(subject.required_by or []) + self.logging = logging + + self.status = PENDING + self.last_updated = time.time() def __repr__(self): - return "" % (self.stack.name,) + return "" % (self.name,) def __str__(self): - return self.stack.name + return self.name def run(self): """Runs this step until it has completed successfully, or been @@ -75,7 +101,7 @@ def run(self): if self.watch_func: watcher = threading.Thread( target=self.watch_func, - args=(self.stack, stop_watcher) + args=(self.subject, stop_watcher) ) watcher.start() @@ -90,25 +116,13 @@ def run(self): def _run_once(self): try: - status = self.fn(self.stack, status=self.status) + status = self.fn(self.subject, status=self.status) except Exception as e: logger.exception(e) status = FailedStatus(reason=str(e)) self.set_status(status) return status - @property - def name(self): - return self.stack.name - - @property - def requires(self): - return self.stack.requires - - @property - def required_by(self): - return self.stack.required_by - @property def completed(self): """Returns True if the step is in a COMPLETE state.""" @@ -147,11 +161,10 @@ def set_status(self, status): step to. """ if status is not self.status: - logger.debug("Setting %s state to %s.", self.stack.name, - status.name) + logger.debug("Setting %s state to %s.", self.name, status.name) self.status = status self.last_updated = time.time() - if self.stack.logging: + if self.logging: log_step(self) def complete(self): @@ -231,6 +244,7 @@ class Graph(object): Example: >>> dag = DAG() + >>> def build(*args, **kwargs): return COMPLETE >>> a = Step("a", fn=build) >>> b = Step("b", fn=build) >>> dag.add_step(a) @@ -335,11 +349,14 @@ def dump(self, directory, context, provider=None): os.makedirs(directory) def walk_func(step): - step.stack.resolve( + if not isinstance(step.subject, Stack): + return True + + step.subject.resolve( context=context, provider=provider, ) - blueprint = step.stack.blueprint + blueprint = step.subject.blueprint filename = stack_template_key_name(blueprint) path = os.path.join(directory, filename) diff --git a/stacker/tests/actions/test_build.py b/stacker/tests/actions/test_build.py index 018101401..a3ee8b2c5 100644 --- a/stacker/tests/actions/test_build.py +++ b/stacker/tests/actions/test_build.py @@ -228,7 +228,7 @@ def setUp(self): plan = self.build_action._generate_plan() self.step = plan.steps[0] - self.step.stack = self.stack + self.step.subject = self.stack def patch_object(*args, **kwargs): m = mock.patch.object(*args, **kwargs) diff --git a/stacker/tests/test_plan.py b/stacker/tests/test_plan.py index a88c5e460..81f5c1710 100644 --- a/stacker/tests/test_plan.py +++ b/stacker/tests/test_plan.py @@ -45,7 +45,7 @@ def setUp(self): stack = mock.MagicMock() stack.name = "stack" stack.fqn = "namespace-stack" - self.step = Step(stack=stack, fn=None) + self.step = Step.from_stack(stack=stack, fn=None) def test_status(self): self.assertFalse(self.step.submitted) @@ -87,7 +87,9 @@ def test_plan(self): context=self.context) graph = build_graph([ - Step(vpc, fn=None), Step(bastion, fn=None)]) + Step.from_stack(vpc, fn=None), + Step.from_stack(bastion, fn=None) + ]) plan = build_plan(description="Test", graph=graph) self.assertEqual(plan.graph.to_dict(), { @@ -108,7 +110,10 @@ def fn(stack, status=None): calls.append(stack.fqn) return COMPLETE - graph = build_graph([Step(vpc, fn), Step(bastion, fn)]) + graph = build_graph([ + Step.from_stack(vpc, fn), + Step.from_stack(bastion, fn) + ]) plan = build_plan( description="Test", graph=graph) plan.execute(walk) @@ -133,7 +138,10 @@ def fn(stack, status=None): return COMPLETE graph = build_graph([ - Step(vpc, fn), Step(db, fn), Step(app, fn)]) + Step.from_stack(vpc, fn), + Step.from_stack(db, fn), + Step.from_stack(app, fn) + ]) plan = build_plan( description="Test", graph=graph, @@ -159,8 +167,8 @@ def fn(stack, status=None): raise ValueError('Boom') return COMPLETE - vpc_step = Step(vpc, fn) - bastion_step = Step(bastion, fn) + vpc_step = Step.from_stack(vpc, fn) + bastion_step = Step.from_stack(bastion, fn) graph = build_graph([vpc_step, bastion_step]) plan = build_plan(description="Test", graph=graph) @@ -187,8 +195,8 @@ def fn(stack, status=None): return SKIPPED return COMPLETE - vpc_step = Step(vpc, fn) - bastion_step = Step(bastion, fn) + vpc_step = Step.from_stack(vpc, fn) + bastion_step = Step.from_stack(bastion, fn) graph = build_graph([vpc_step, bastion_step]) plan = build_plan(description="Test", graph=graph) @@ -215,9 +223,9 @@ def fn(stack, status=None): return FAILED return COMPLETE - vpc_step = Step(vpc, fn) - bastion_step = Step(bastion, fn) - db_step = Step(db, fn) + vpc_step = Step.from_stack(vpc, fn) + bastion_step = Step.from_stack(bastion, fn) + db_step = Step.from_stack(db, fn) graph = build_graph([ vpc_step, bastion_step, db_step]) @@ -245,8 +253,8 @@ def fn(stack, status=None): raise CancelExecution return COMPLETE - vpc_step = Step(vpc, fn) - bastion_step = Step(bastion, fn) + vpc_step = Step.from_stack(vpc, fn) + bastion_step = Step.from_stack(bastion, fn) graph = build_graph([vpc_step, bastion_step]) plan = build_plan(description="Test", graph=graph) @@ -261,7 +269,7 @@ def test_build_graph_missing_dependency(self): context=self.context) with self.assertRaises(GraphError) as expected: - build_graph([Step(bastion, None)]) + build_graph([Step.from_stack(bastion, None)]) message_starts = ( "Error detected when adding 'vpc.1' " "as a dependency of 'bastion.1':" @@ -285,7 +293,11 @@ def test_build_graph_cyclic_dependencies(self): context=self.context) with self.assertRaises(GraphError) as expected: - build_graph([Step(vpc, None), Step(db, None), Step(app, None)]) + build_graph([ + Step.from_stack(vpc, None), + Step.from_stack(db, None), + Step.from_stack(app, None) + ]) message = ("Error detected when adding 'db.1' " "as a dependency of 'app.1': graph is " "not acyclic") @@ -311,7 +323,7 @@ def test_dump(self, *args): context=self.context) requires = [stack.name] - steps += [Step(stack, None)] + steps += [Step.from_stack(stack, None)] graph = build_graph(steps) plan = build_plan(description="Test", graph=graph) @@ -321,9 +333,12 @@ def test_dump(self, *args): plan.dump(tmp_dir, context=self.context) for step in plan.steps: + if not isinstance(step.subject, Stack): + continue + template_path = os.path.join( tmp_dir, - stack_template_key_name(step.stack.blueprint)) + stack_template_key_name(step.subject.blueprint)) self.assertTrue(os.path.isfile(template_path)) finally: shutil.rmtree(tmp_dir) From 4a799792a5ee777429b7f1ca2a105581e3b13d3b Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Sat, 16 Mar 2019 02:15:09 -0300 Subject: [PATCH 04/13] hooks: integrate into stack execution graph --- stacker/actions/base.py | 169 +++++++++++++++++++------- stacker/actions/build.py | 59 ++------- stacker/actions/destroy.py | 31 +---- stacker/actions/diff.py | 7 +- stacker/actions/graph.py | 7 +- stacker/plan.py | 136 +++++++++++---------- stacker/tests/actions/test_build.py | 49 ++++++-- stacker/tests/actions/test_destroy.py | 51 +++++--- stacker/tests/test_plan.py | 42 +++---- 9 files changed, 316 insertions(+), 235 deletions(-) diff --git a/stacker/actions/base.py b/stacker/actions/base.py index 1d8f8a3dc..30096be37 100644 --- a/stacker/actions/base.py +++ b/stacker/actions/base.py @@ -8,20 +8,14 @@ import threading from ..dag import walk, ThreadedWalker, UnlimitedSemaphore -from ..plan import Step, build_plan, build_graph +from ..plan import Graph, Plan, Step +from ..target import Target import botocore.exceptions from stacker.session_cache import get_session -from stacker.exceptions import PlanFailed - -from ..status import ( - COMPLETE -) - -from stacker.util import ( - ensure_s3_bucket, - get_s3_endpoint, -) +from stacker.exceptions import HookExecutionFailed, PlanFailed +from stacker.status import COMPLETE, SKIPPED, FailedStatus +from stacker.util import ensure_s3_bucket, get_s3_endpoint logger = logging.getLogger(__name__) @@ -61,41 +55,6 @@ def build_walker(concurrency): return ThreadedWalker(semaphore).walk -def plan(description, stack_action, context, - tail=None, reverse=False): - """A simple helper that builds a graph based plan from a set of stacks. - - Args: - description (str): a description of the plan. - action (func): a function to call for each stack. - context (:class:`stacker.context.Context`): a - :class:`stacker.context.Context` to build the plan from. - tail (func): an optional function to call to tail the stack progress. - reverse (bool): if True, execute the graph in reverse (useful for - destroy actions). - - Returns: - :class:`plan.Plan`: The resulting plan object - """ - - def target_fn(*args, **kwargs): - return COMPLETE - - steps = [Step.from_stack(stack, fn=stack_action, watch_func=tail) - for stack in context.get_stacks()] - - steps += [Step.from_target(target, fn=target_fn) - for target in context.get_targets()] - - graph = build_graph(steps) - - return build_plan( - description=description, - graph=graph, - targets=context.stack_names, - reverse=reverse) - - def stack_template_key_name(blueprint): """Given a blueprint, produce an appropriate key name. @@ -155,6 +114,124 @@ def __init__(self, context, provider_builder=None, cancel=None): self.bucket_region = provider_builder.region self.s3_conn = get_session(self.bucket_region).client('s3') + def plan(self, description, action_name, action, context, tail=None, + reverse=False, run_hooks=True): + """A simple helper that builds a graph based plan from a set of stacks. + + Args: + description (str): a description of the plan. + action_name (str): name of the action being run. Used to generate + target names and filter out which hooks to run. + action (func): a function to call for each stack. + context (stacker.context.Context): a context to build the plan + from. + tail (func): an optional function to call to tail the stack + progress. + reverse (bool): whether to flip the direction of stack and target + dependencies. Use it when planning an action destroying + resources, which usually must happen in the reverse order + of creation. + Note: this does not change the order of execution of hooks, or + their dependencies, as the build and destroy hooks are + currently configured in separate. + run_hooks (bool): whether to run hooks configured for this action + + Returns: stacker.plan.Plan: the resulting plan for this action + """ + + def target_fn(*args, **kwargs): + return COMPLETE + + def hook_fn(hook, *args, **kwargs): + provider = self.provider_builder.build(profile=hook.profile, + region=hook.region) + + try: + result = hook.run(provider, self.context) + except HookExecutionFailed as e: + return FailedStatus(reason=str(e)) + + if result is None: + return SKIPPED + + return COMPLETE + + pre_hooks_target = Target( + name="pre_{}_run_hooks".format(action_name)) + pre_action_target = Target( + name="pre_{}".format(action_name), + requires=[pre_hooks_target.name]) + action_target = Target( + name=action_name, + requires=[pre_action_target.name]) + post_action_target = Target( + name="post_{}".format(action_name), + requires=[action_target.name]) + post_hooks_target = Target( + name="post_{}_run_hooks".format(action_name), + requires=[post_action_target.name]) + + def steps(): + yield Step.from_target(pre_hooks_target, fn=target_fn) + yield Step.from_target(pre_action_target, fn=target_fn) + yield Step.from_target(action_target, fn=target_fn) + yield Step.from_target(post_action_target, fn=target_fn) + yield Step.from_target(post_hooks_target, fn=target_fn) + + if run_hooks: + # Since we need to maintain compatibility with legacy hooks, + # we separate them completely from the new hooks. + # The legacy hooks will run in two separate phases, completely + # isolated from regular stacks and targets, and any of the new + # hooks. + # Hence, all legacy pre-hooks will finish before any of the + # new hooks, and all legacy post-hooks will only start after + # the new hooks. + + hooks = self.context.get_hooks_for_action(action_name) + + for hook in hooks.pre: + yield Step.from_hook( + hook, fn=hook_fn, + required_by=[pre_hooks_target.name]) + + for hook in hooks.custom: + yield Step.from_hook( + hook, fn=hook_fn, + requires=[pre_action_target.name], + required_by=[post_action_target.name]) + + for hook in hooks.post: + yield Step.from_hook( + hook, fn=hook_fn, + requires=[post_hooks_target.name]) + + for target in context.get_targets(): + step = Step.from_target(target, fn=target_fn) + if reverse: + step.reverse_requirements() + + yield step + + for stack in context.get_stacks(): + step = Step.from_stack(stack, fn=action, watch_func=tail) + if reverse: + step.reverse_requirements() + + # Contain stack execution in the boundaries of the pre_action + # and post_action targets. + step.requires.add(pre_action_target.name) + step.required_by.add(action_target.name) + + yield step + + graph = Graph.from_steps(list(steps())) + + return Plan.from_graph( + description=description, + graph=graph, + targets=context.stack_names) + def ensure_cfn_bucket(self): """The CloudFormation bucket where templates will be stored.""" if self.bucket_name: diff --git a/stacker/actions/build.py b/stacker/actions/build.py index bd2b91714..ca7bf6f62 100644 --- a/stacker/actions/build.py +++ b/stacker/actions/build.py @@ -3,11 +3,10 @@ from __future__ import absolute_import import logging -from .base import BaseAction, plan, build_walker +from .base import BaseAction, build_walker from .base import STACK_POLL_TIME from ..providers.base import Template -from .. import util from ..exceptions import ( MissingParameterException, StackDidNotChange, @@ -181,29 +180,6 @@ def _handle_missing_parameters(parameter_values, all_params, required_params, return list(parameter_values.items()) -def handle_hooks(stage, hooks, provider, context, dump, outline): - """Handle pre/post hooks. - - Args: - stage (str): The name of the hook stage - pre_build/post_build. - hooks (list): A list of dictionaries containing the hooks to execute. - provider (:class:`stacker.provider.base.BaseProvider`): The provider - the current stack is using. - context (:class:`stacker.context.Context`): The current stacker - context. - dump (bool): Whether running with dump set or not. - outline (bool): Whether running with outline set or not. - - """ - if not outline and not dump and hooks: - util.handle_hooks( - stage=stage, - hooks=hooks, - provider=provider, - context=context - ) - - class Action(BaseAction): """Responsible for building & coordinating CloudFormation stacks. @@ -391,26 +367,19 @@ def _stack_policy(self, stack): if stack.stack_policy: return Template(body=stack.stack_policy) - def _generate_plan(self, tail=False): - return plan( + def _generate_plan(self, tail=False, outline=False, dump=False): + return self.plan( description="Create/Update stacks", - stack_action=self._launch_stack, + action_name="build", + action=self._launch_stack, tail=self._tail_stack if tail else None, - context=self.context) + context=self.context, + run_hooks=not outline and not dump) def pre_run(self, outline=False, dump=False, *args, **kwargs): """Any steps that need to be taken prior to running the action.""" if should_ensure_cfn_bucket(outline, dump): self.ensure_cfn_bucket() - hooks = self.context.config.pre_build - handle_hooks( - "pre_build", - hooks, - self.provider, - self.context, - dump, - outline - ) def run(self, concurrency=0, outline=False, tail=False, dump=False, *args, **kwargs): @@ -419,7 +388,7 @@ def run(self, concurrency=0, outline=False, This is the main entry point for the Builder. """ - plan = self._generate_plan(tail=tail) + plan = self._generate_plan(tail=tail, outline=outline, dump=dump) if not plan.keys(): logger.warn('WARNING: No stacks detected (error in config?)') if not outline and not dump: @@ -433,15 +402,3 @@ def run(self, concurrency=0, outline=False, if dump: plan.dump(directory=dump, context=self.context, provider=self.provider) - - def post_run(self, outline=False, dump=False, *args, **kwargs): - """Any steps that need to be taken after running the action.""" - hooks = self.context.config.post_build - handle_hooks( - "post_build", - hooks, - self.provider, - self.context, - dump, - outline - ) diff --git a/stacker/actions/destroy.py b/stacker/actions/destroy.py index 4f26692ad..e5245cca7 100644 --- a/stacker/actions/destroy.py +++ b/stacker/actions/destroy.py @@ -3,10 +3,9 @@ from __future__ import absolute_import import logging -from .base import BaseAction, plan, build_walker +from .base import BaseAction, build_walker from .base import STACK_POLL_TIME from ..exceptions import StackDoesNotExist -from .. import util from ..status import ( CompleteStatus, SubmittedStatus, @@ -37,12 +36,14 @@ class Action(BaseAction): """ def _generate_plan(self, tail=False): - return plan( + return self.plan( description="Destroy stacks", - stack_action=self._destroy_stack, + action_name='destroy', + action=self._destroy_stack, tail=self._tail_stack if tail else None, context=self.context, - reverse=True) + reverse=True, + run_hooks=True) def _destroy_stack(self, stack, **kwargs): old_status = kwargs.get("status") @@ -78,16 +79,6 @@ def _destroy_stack(self, stack, **kwargs): provider.destroy_stack(provider_stack) return DestroyingStatus - def pre_run(self, outline=False, *args, **kwargs): - """Any steps that need to be taken prior to running the action.""" - pre_destroy = self.context.config.pre_destroy - if not outline and pre_destroy: - util.handle_hooks( - stage="pre_destroy", - hooks=pre_destroy, - provider=self.provider, - context=self.context) - def run(self, force, concurrency=0, tail=False, *args, **kwargs): plan = self._generate_plan(tail=tail) if not plan.keys(): @@ -101,13 +92,3 @@ def run(self, force, concurrency=0, tail=False, *args, **kwargs): else: plan.outline(message="To execute this plan, run with \"--force\" " "flag.") - - def post_run(self, outline=False, *args, **kwargs): - """Any steps that need to be taken after running the action.""" - post_destroy = self.context.config.post_destroy - if not outline and post_destroy: - util.handle_hooks( - stage="post_destroy", - hooks=post_destroy, - provider=self.provider, - context=self.context) diff --git a/stacker/actions/diff.py b/stacker/actions/diff.py index 97801ae7d..7d9cc3f98 100644 --- a/stacker/actions/diff.py +++ b/stacker/actions/diff.py @@ -8,7 +8,7 @@ import logging from operator import attrgetter -from .base import plan, build_walker +from .base import build_walker from . import build from ..ui import ui from .. import exceptions @@ -278,9 +278,10 @@ def _diff_stack(self, stack, **kwargs): return COMPLETE def _generate_plan(self): - return plan( + return self.plan( description="Diff stacks", - stack_action=self._diff_stack, + action_name="diff", + action=self._diff_stack, context=self.context) def run(self, concurrency=0, *args, **kwargs): diff --git a/stacker/actions/graph.py b/stacker/actions/graph.py index 1f069a68d..f7cdffb50 100644 --- a/stacker/actions/graph.py +++ b/stacker/actions/graph.py @@ -5,7 +5,7 @@ import sys import json -from .base import BaseAction, plan +from .base import BaseAction logger = logging.getLogger(__name__) @@ -55,9 +55,10 @@ def json_format(out, graph): class Action(BaseAction): def _generate_plan(self): - return plan( + return self.plan( description="Print graph", - stack_action=None, + action_name='graph', + action=None, context=self.context) def run(self, format=None, reduce=False, *args, **kwargs): diff --git a/stacker/plan.py b/stacker/plan.py index 00760078e..81f968852 100644 --- a/stacker/plan.py +++ b/stacker/plan.py @@ -44,13 +44,15 @@ def log_step(step): class Step(object): """State machine for executing generic actions related to stacks. + Args: - subject (Union[Stack, Target, Hook]): the subject associated with this - step. Usually a Stack, Target or Hook. - fn (func): the function to run to execute the step. This function will - be ran multiple times until the step is "done". - watch_func (func): an optional function that will be called to "tail" - the step action. + subject: the subject associated with this + step. Usually a :class:`stacker.stack.Stack`, + :class:`stacker.target.Target` or :class:`stacker.hooks.Hook` + fn (funcb): the function to run to execute the step. This function + will be ran multiple times until the step is "done". + watch_func (func): an optional function that will be called to + monitor the step action. """ @classmethod @@ -140,18 +142,20 @@ def failed(self): @property def done(self): - """Returns True if the step is finished (either COMPLETE, SKIPPED or FAILED) + """Whether this step is finished (either COMPLETE, SKIPPED or FAILED) """ return self.completed or self.skipped or self.failed @property def ok(self): - """Returns True if the step is finished (either COMPLETE or SKIPPED)""" + """Whether this step is finished (either COMPLETE or SKIPPED)""" return self.completed or self.skipped @property def submitted(self): - """Returns True if the step is SUBMITTED, COMPLETE, or SKIPPED.""" + """Whether this step is has been submitted (SUBMITTED, COMPLETE, or + SKIPPED). + """ return self.status >= SUBMITTED def set_status(self, status): @@ -179,58 +183,15 @@ def submit(self): """A shortcut for set_status(SUBMITTED)""" self.set_status(SUBMITTED) + def reverse_requirements(self): + """ + Change this step so it is suitable for use in operations in reverse + dependency order. -def build_plan(description, graph, - targets=None, reverse=False): - """Builds a plan from a list of steps. - Args: - description (str): an arbitrary string to - describe the plan. - graph (:class:`Graph`): a list of :class:`Graph` to execute. - targets (list): an optional list of step names to filter the graph to. - If provided, only these steps, and their transitive dependencies - will be executed. If no targets are specified, every node in the - graph will be executed. - reverse (bool): If provided, the graph will be walked in reverse order - (dependencies last). - """ - - # If we want to execute the plan in reverse (e.g. Destroy), transpose the - # graph. - if reverse: - graph = graph.transposed() - - # If we only want to build a specific target, filter the graph. - if targets: - nodes = [] - for target in targets: - for k, step in graph.steps.items(): - if step.name == target: - nodes.append(step.name) - graph = graph.filtered(nodes) - - return Plan(description=description, graph=graph) - - -def build_graph(steps): - """Builds a graph of steps. - Args: - steps (list): a list of :class:`Step` objects to execute. - """ - - graph = Graph() - - for step in steps: - graph.add_step(step) - - for step in steps: - for dep in step.requires: - graph.connect(step.name, dep) - - for parent in step.required_by: - graph.connect(parent, step.name) - - return graph + This can be used to correctly generate an action graph when destroying + stacks. + """ + self.required_by, self.requires = self.requires, self.required_by class Graph(object): @@ -252,11 +213,34 @@ class Graph(object): >>> dag.connect(a, b) Args: - steps (list): an optional list of :class:`Step` objects to execute. + steps (dict): an optional list of :class:`Step` objects to execute. dag (:class:`stacker.dag.DAG`): an optional :class:`stacker.dag.DAG` object. If one is not provided, a new one will be initialized. """ + @classmethod + def from_steps(cls, steps): + """Builds a graph of steps respecting dependencies + + Args: + steps (List[Step]): steps to include in the graph + Returns: :class:`Graph`: the resulting graph + """ + + graph = Graph() + + for step in steps: + graph.add_step(step) + + for step in steps: + for dep in step.requires: + graph.connect(step.name, dep) + + for parent in step.required_by: + graph.connect(parent, step.name) + + return graph + def __init__(self, steps=None, dag=None): self.steps = steps or {} self.dag = dag or DAG() @@ -301,6 +285,9 @@ def topological_sort(self): nodes = self.dag.topological_sort() return [self.steps[step_name] for step_name in nodes] + def get(self, name, default=None): + return self.steps.get(name, default) + def to_dict(self): return self.dag.graph @@ -312,6 +299,26 @@ class Plan(object): graph (:class:`Graph`): a graph of steps. """ + @classmethod + def from_graph(cls, description, graph, targets=None): + """Builds a plan from a list of steps. + + Args: + description (str): an arbitrary string to describe the plan. + graph (Graph): a :class:`Graph` to base the plan on + targets (list, optional): names of steps to include in the graph. + If provided, only these steps, and their transitive + dependencies will be executed. Otherwise, every node in the + graph will be executed. + Returns: Plan: the resulting plan + """ + + # If we only want to build a specific target, filter the graph. + if targets: + graph = graph.filtered(targets) + + return Plan(description=description, graph=graph) + def __init__(self, description, graph): self.id = uuid.uuid4() self.description = description @@ -418,3 +425,10 @@ def step_names(self): def keys(self): return self.step_names + + def get(self, name, default=None): + for step in self.steps: + if step.name == name: + return step + + return default diff --git a/stacker/tests/actions/test_build.py b/stacker/tests/actions/test_build.py index a3ee8b2c5..501228b20 100644 --- a/stacker/tests/actions/test_build.py +++ b/stacker/tests/actions/test_build.py @@ -80,6 +80,22 @@ def _get_context(self, **kwargs): "else": "${output bastion::something}"}}, {"name": "other", "variables": {}} ], + "build_hooks": [ + {"name": "before-db-hook", + "path": "stacker.hooks.no_op", + "required_by": ["db"]}, + {"name": "after-db-hook", + "path": "stacker.hooks.no_op", + "requires": ["db"]} + ], + "pre_build": [ + {"name": "pre-build-hook", + "path": "stacker.hooks.no_op"} + ], + "post_build": [ + {"name": "post-build-hook", + "path": "stacker.hooks.no_op"} + ] }) return Context(config=config, **kwargs) @@ -130,14 +146,28 @@ def test_existing_stack_params_dont_override_given_params(self): def test_generate_plan(self): context = self._get_context() build_action = build.Action(context, cancel=MockThreadingEvent()) + plan = build_action._generate_plan() + plan.graph.transitive_reduction() + self.assertEqual( - { - 'db': set(['bastion', 'vpc']), - 'bastion': set(['vpc']), - 'other': set([]), - 'vpc': set([])}, - plan.graph.to_dict() + sorted({ + 'pre-build-hook': set(), + 'pre_build_run_hooks': {'pre-build-hook'}, + 'pre_build': {'pre_build_run_hooks'}, + 'build': {'other', 'db'}, + 'post_build': {'build', 'after-db-hook'}, + 'post_build_run_hooks': {'post_build'}, + 'post-build-hook': {'post_build_run_hooks'}, + + 'other': {'pre_build'}, + 'vpc': {'pre_build'}, + 'bastion': {'vpc'}, + 'before-db-hook': {'pre_build'}, + 'db': {'before-db-hook', 'bastion'}, + 'after-db-hook': {'db'}, + }.items()), + sorted(plan.graph.to_dict().items()) ) def test_dont_execute_plan_when_outline_specified(self): @@ -227,7 +257,8 @@ def setUp(self): self.stack_status = None plan = self.build_action._generate_plan() - self.step = plan.steps[0] + self.step = next(step for step in plan.steps + if step.name == self.stack.name) self.step.subject = self.stack def patch_object(*args, **kwargs): @@ -244,9 +275,9 @@ def get_stack(name, *args, **kwargs): 'Outputs': [], 'Tags': []} - def get_events(name, *args, **kwargs): + def get_events(*args, **kwargs): return [{'ResourceStatus': 'ROLLBACK_IN_PROGRESS', - 'ResourceStatusReason': 'CFN fail'}] + 'ResourceStatusReason': 'CFN fail'}] patch_object(self.provider, 'get_stack', side_effect=get_stack) patch_object(self.provider, 'update_stack') diff --git a/stacker/tests/actions/test_destroy.py b/stacker/tests/actions/test_destroy.py index 697afd660..e4141ea81 100644 --- a/stacker/tests/actions/test_destroy.py +++ b/stacker/tests/actions/test_destroy.py @@ -38,10 +38,26 @@ def setUp(self): "stacks": [ {"name": "vpc"}, {"name": "bastion", "requires": ["vpc"]}, - {"name": "instance", "requires": ["vpc", "bastion"]}, - {"name": "db", "requires": ["instance", "vpc", "bastion"]}, - {"name": "other", "requires": ["db"]}, + {"name": "db", "requires": ["vpc", "bastion"]}, + {"name": "instance", "requires": ["db", "vpc", "bastion"]}, + {"name": "other", "requires": []}, ], + "hooks": [ + {"name": "before-db-hook", + "path": "stacker.hooks.no_op", + "required_by": ["db"]}, + {"name": "after-db-hook", + "path": "stacker.hooks.no_op", + "requires": ["db"]} + ], + "pre_destroy": [ + {"name": "pre-destroy-hook", + "path": "stacker.hooks.no_op"} + ], + "post_destroy": [ + {"name": "post-destroy-hook", + "path": "stacker.hooks.no_op"} + ] }) self.context = Context(config=config) self.action = destroy.Action(self.context, @@ -49,18 +65,25 @@ def setUp(self): def test_generate_plan(self): plan = self.action._generate_plan() + plan.graph.transitive_reduction() + self.assertEqual( { - 'vpc': set( - ['db', 'instance', 'bastion']), - 'other': set([]), - 'bastion': set( - ['instance', 'db']), - 'instance': set( - ['db']), - 'db': set( - ['other'])}, - plan.graph.to_dict() + 'pre-destroy-hook': set(), + 'pre_destroy_run_hooks': {'pre-destroy-hook'}, + 'pre_destroy': {'pre_destroy_run_hooks'}, + 'destroy': {'vpc', 'other'}, + 'post_destroy': {'destroy'}, + 'post_destroy_run_hooks': {'post_destroy'}, + 'post-destroy-hook': {'post_destroy_run_hooks'}, + + 'instance': {'pre_destroy'}, + 'db': {'instance'}, + 'bastion': {'db'}, + 'vpc': {'bastion'}, + 'other': {'pre_destroy'}, + }, + dict(plan.graph.to_dict()) ) def test_only_execute_plan_when_forced(self): @@ -98,7 +121,7 @@ def get_stack(stack_name): return stacks_dict.get(stack_name) plan = self.action._generate_plan() - step = plan.steps[0] + step = plan.get("vpc") # we need the AWS provider to generate the plan, but swap it for # the mock one to make the test easier self.action.provider_builder = MockProviderBuilder(mock_provider) diff --git a/stacker/tests/test_plan.py b/stacker/tests/test_plan.py index 81f5c1710..e71b6cd95 100644 --- a/stacker/tests/test_plan.py +++ b/stacker/tests/test_plan.py @@ -16,11 +16,7 @@ register_lookup_handler, unregister_lookup_handler, ) -from stacker.plan import ( - Step, - build_plan, - build_graph, -) +from stacker.plan import Graph, Step, Plan from stacker.exceptions import ( CancelExecution, GraphError, @@ -86,11 +82,11 @@ def test_plan(self): definition=generate_definition('bastion', 1, requires=[vpc.name]), context=self.context) - graph = build_graph([ + graph = Graph.from_steps([ Step.from_stack(vpc, fn=None), Step.from_stack(bastion, fn=None) ]) - plan = build_plan(description="Test", graph=graph) + plan = Plan.from_graph(description="Test", graph=graph) self.assertEqual(plan.graph.to_dict(), { 'bastion.1': set(['vpc.1']), @@ -110,11 +106,11 @@ def fn(stack, status=None): calls.append(stack.fqn) return COMPLETE - graph = build_graph([ + graph = Graph.from_steps([ Step.from_stack(vpc, fn), Step.from_stack(bastion, fn) ]) - plan = build_plan( + plan = Plan.from_graph( description="Test", graph=graph) plan.execute(walk) @@ -137,12 +133,12 @@ def fn(stack, status=None): calls.append(stack.fqn) return COMPLETE - graph = build_graph([ + graph = Graph.from_steps([ Step.from_stack(vpc, fn), Step.from_stack(db, fn), Step.from_stack(app, fn) ]) - plan = build_plan( + plan = Plan.from_graph( description="Test", graph=graph, targets=['db.1']) @@ -170,8 +166,8 @@ def fn(stack, status=None): vpc_step = Step.from_stack(vpc, fn) bastion_step = Step.from_stack(bastion, fn) - graph = build_graph([vpc_step, bastion_step]) - plan = build_plan(description="Test", graph=graph) + graph = Graph.from_steps([vpc_step, bastion_step]) + plan = Plan.from_graph(description="Test", graph=graph) with self.assertRaises(PlanFailed): plan.execute(walk) @@ -198,8 +194,8 @@ def fn(stack, status=None): vpc_step = Step.from_stack(vpc, fn) bastion_step = Step.from_stack(bastion, fn) - graph = build_graph([vpc_step, bastion_step]) - plan = build_plan(description="Test", graph=graph) + graph = Graph.from_steps([vpc_step, bastion_step]) + plan = Plan.from_graph(description="Test", graph=graph) plan.execute(walk) self.assertEquals(calls, ['namespace-vpc.1', 'namespace-bastion.1']) @@ -227,9 +223,9 @@ def fn(stack, status=None): bastion_step = Step.from_stack(bastion, fn) db_step = Step.from_stack(db, fn) - graph = build_graph([ + graph = Graph.from_steps([ vpc_step, bastion_step, db_step]) - plan = build_plan(description="Test", graph=graph) + plan = Plan.from_graph(description="Test", graph=graph) with self.assertRaises(PlanFailed): plan.execute(walk) @@ -256,8 +252,8 @@ def fn(stack, status=None): vpc_step = Step.from_stack(vpc, fn) bastion_step = Step.from_stack(bastion, fn) - graph = build_graph([vpc_step, bastion_step]) - plan = build_plan(description="Test", graph=graph) + graph = Graph.from_steps([vpc_step, bastion_step]) + plan = Plan.from_graph(description="Test", graph=graph) plan.execute(walk) self.assertEquals(calls, ['namespace-vpc.1', 'namespace-bastion.1']) @@ -269,7 +265,7 @@ def test_build_graph_missing_dependency(self): context=self.context) with self.assertRaises(GraphError) as expected: - build_graph([Step.from_stack(bastion, None)]) + Graph.from_steps([Step.from_stack(bastion, None)]) message_starts = ( "Error detected when adding 'vpc.1' " "as a dependency of 'bastion.1':" @@ -293,7 +289,7 @@ def test_build_graph_cyclic_dependencies(self): context=self.context) with self.assertRaises(GraphError) as expected: - build_graph([ + Graph.from_steps([ Step.from_stack(vpc, None), Step.from_stack(db, None), Step.from_stack(app, None) @@ -325,8 +321,8 @@ def test_dump(self, *args): steps += [Step.from_stack(stack, None)] - graph = build_graph(steps) - plan = build_plan(description="Test", graph=graph) + graph = Graph.from_steps(steps) + plan = Plan.from_graph(description="Test", graph=graph) tmp_dir = tempfile.mkdtemp() try: From 4cc2c7d1865a5ae91b36b7cee604c0b9c2ed3608 Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Sun, 17 Mar 2019 16:40:18 -0300 Subject: [PATCH 05/13] actions: base: rename hook running steps to be less verbose --- stacker/actions/base.py | 4 ++-- stacker/tests/actions/test_build.py | 8 ++++---- stacker/tests/actions/test_destroy.py | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/stacker/actions/base.py b/stacker/actions/base.py index 30096be37..a05711f6d 100644 --- a/stacker/actions/base.py +++ b/stacker/actions/base.py @@ -157,7 +157,7 @@ def hook_fn(hook, *args, **kwargs): return COMPLETE pre_hooks_target = Target( - name="pre_{}_run_hooks".format(action_name)) + name="pre_{}_hooks".format(action_name)) pre_action_target = Target( name="pre_{}".format(action_name), requires=[pre_hooks_target.name]) @@ -168,7 +168,7 @@ def hook_fn(hook, *args, **kwargs): name="post_{}".format(action_name), requires=[action_target.name]) post_hooks_target = Target( - name="post_{}_run_hooks".format(action_name), + name="post_{}_hooks".format(action_name), requires=[post_action_target.name]) def steps(): diff --git a/stacker/tests/actions/test_build.py b/stacker/tests/actions/test_build.py index 501228b20..586233437 100644 --- a/stacker/tests/actions/test_build.py +++ b/stacker/tests/actions/test_build.py @@ -153,12 +153,12 @@ def test_generate_plan(self): self.assertEqual( sorted({ 'pre-build-hook': set(), - 'pre_build_run_hooks': {'pre-build-hook'}, - 'pre_build': {'pre_build_run_hooks'}, + 'pre_build_hooks': {'pre-build-hook'}, + 'pre_build': {'pre_build_hooks'}, 'build': {'other', 'db'}, 'post_build': {'build', 'after-db-hook'}, - 'post_build_run_hooks': {'post_build'}, - 'post-build-hook': {'post_build_run_hooks'}, + 'post_build_hooks': {'post_build'}, + 'post-build-hook': {'post_build_hooks'}, 'other': {'pre_build'}, 'vpc': {'pre_build'}, diff --git a/stacker/tests/actions/test_destroy.py b/stacker/tests/actions/test_destroy.py index e4141ea81..cd6b1e20a 100644 --- a/stacker/tests/actions/test_destroy.py +++ b/stacker/tests/actions/test_destroy.py @@ -70,12 +70,12 @@ def test_generate_plan(self): self.assertEqual( { 'pre-destroy-hook': set(), - 'pre_destroy_run_hooks': {'pre-destroy-hook'}, - 'pre_destroy': {'pre_destroy_run_hooks'}, + 'pre_destroy_hooks': {'pre-destroy-hook'}, + 'pre_destroy': {'pre_destroy_hooks'}, 'destroy': {'vpc', 'other'}, 'post_destroy': {'destroy'}, - 'post_destroy_run_hooks': {'post_destroy'}, - 'post-destroy-hook': {'post_destroy_run_hooks'}, + 'post_destroy_hooks': {'post_destroy'}, + 'post-destroy-hook': {'post_destroy_hooks'}, 'instance': {'pre_destroy'}, 'db': {'instance'}, From ca767da264af616bdf59b179c7d4e7ffa25fe988 Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Sun, 17 Mar 2019 16:40:55 -0300 Subject: [PATCH 06/13] actions: base: properly reverse hook dependencies for use in destroy --- stacker/actions/base.py | 29 ++++++++++++++++----------- stacker/tests/actions/test_destroy.py | 19 ++++++++++++------ 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/stacker/actions/base.py b/stacker/actions/base.py index a05711f6d..1ce8ab516 100644 --- a/stacker/actions/base.py +++ b/stacker/actions/base.py @@ -116,7 +116,7 @@ def __init__(self, context, provider_builder=None, cancel=None): def plan(self, description, action_name, action, context, tail=None, reverse=False, run_hooks=True): - """A simple helper that builds a graph based plan from a set of stacks. + """A helper that builds a graph based plan from a set of stacks. Args: description (str): a description of the plan. @@ -127,13 +127,12 @@ def plan(self, description, action_name, action, context, tail=None, from. tail (func): an optional function to call to tail the stack progress. - reverse (bool): whether to flip the direction of stack and target - dependencies. Use it when planning an action destroying - resources, which usually must happen in the reverse order - of creation. - Note: this does not change the order of execution of hooks, or - their dependencies, as the build and destroy hooks are - currently configured in separate. + reverse (bool): whether to flip the direction of dependencies. + Use it when planning an action for destroying resources, + which usually must happen in the reverse order of creation. + Note: this does not change the order of execution of pre/post + action hooks, as the build and destroy hooks are currently + configured in separate. run_hooks (bool): whether to run hooks configured for this action Returns: stacker.plan.Plan: the resulting plan for this action @@ -189,6 +188,8 @@ def steps(): # the new hooks. hooks = self.context.get_hooks_for_action(action_name) + logger.debug("Found hooks for action {}: {}".format( + action_name, hooks)) for hook in hooks.pre: yield Step.from_hook( @@ -196,10 +197,14 @@ def steps(): required_by=[pre_hooks_target.name]) for hook in hooks.custom: - yield Step.from_hook( - hook, fn=hook_fn, - requires=[pre_action_target.name], - required_by=[post_action_target.name]) + step = Step.from_hook( + hook, fn=hook_fn) + if reverse: + step.reverse_requirements() + + step.requires.add(pre_action_target.name) + step.required_by.add(post_action_target.name) + yield step for hook in hooks.post: yield Step.from_hook( diff --git a/stacker/tests/actions/test_destroy.py b/stacker/tests/actions/test_destroy.py index cd6b1e20a..7eb3f0dd6 100644 --- a/stacker/tests/actions/test_destroy.py +++ b/stacker/tests/actions/test_destroy.py @@ -42,13 +42,16 @@ def setUp(self): {"name": "instance", "requires": ["db", "vpc", "bastion"]}, {"name": "other", "requires": []}, ], - "hooks": [ - {"name": "before-db-hook", + "destroy_hooks": [ + {"name": "before-db-hook-1", "path": "stacker.hooks.no_op", - "required_by": ["db"]}, + "args": {"x": "${output db::whatever}"}}, + {"name": "before-db-hook-2", + "path": "stacker.hooks.no_op", + "requires": ["db"]}, {"name": "after-db-hook", "path": "stacker.hooks.no_op", - "requires": ["db"]} + "required_by": ["db"]} ], "pre_destroy": [ {"name": "pre-destroy-hook", @@ -73,12 +76,16 @@ def test_generate_plan(self): 'pre_destroy_hooks': {'pre-destroy-hook'}, 'pre_destroy': {'pre_destroy_hooks'}, 'destroy': {'vpc', 'other'}, - 'post_destroy': {'destroy'}, + 'post_destroy': {'destroy', 'after-db-hook'}, 'post_destroy_hooks': {'post_destroy'}, 'post-destroy-hook': {'post_destroy_hooks'}, + 'before-db-hook-1': {'pre_destroy'}, + 'before-db-hook-2': {'pre_destroy'}, + 'after-db-hook': {'db'}, + 'instance': {'pre_destroy'}, - 'db': {'instance'}, + 'db': {'instance', 'before-db-hook-1', 'before-db-hook-2'}, 'bastion': {'db'}, 'vpc': {'bastion'}, 'other': {'pre_destroy'}, From d83378bd869292188a93f7c130b9e6bf60702cc3 Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Sun, 17 Mar 2019 16:43:03 -0300 Subject: [PATCH 07/13] stack: remove manual output updating, and always rely on provider instead --- stacker/actions/build.py | 6 ----- stacker/actions/diff.py | 3 --- stacker/lookups/handlers/output.py | 40 +++++++++++++++++------------- stacker/stack.py | 4 --- 4 files changed, 23 insertions(+), 30 deletions(-) diff --git a/stacker/actions/build.py b/stacker/actions/build.py index ca7bf6f62..64c042e03 100644 --- a/stacker/actions/build.py +++ b/stacker/actions/build.py @@ -249,8 +249,6 @@ def _launch_stack(self, stack, **kwargs): provider_stack = None if provider_stack and not should_update(stack): - stack.set_outputs( - self.provider.get_output_dict(provider_stack)) return NotUpdatedStatus() recreate = False @@ -292,8 +290,6 @@ def _launch_stack(self, stack, **kwargs): return FailedStatus(reason) elif provider.is_stack_completed(provider_stack): - stack.set_outputs( - provider.get_output_dict(provider_stack)) return CompleteStatus(old_status.reason) else: return old_status @@ -342,10 +338,8 @@ def _launch_stack(self, stack, **kwargs): else: return SubmittedStatus("destroying stack for re-creation") except CancelExecution: - stack.set_outputs(provider.get_output_dict(provider_stack)) return SkippedStatus(reason="canceled execution") except StackDidNotChange: - stack.set_outputs(provider.get_output_dict(provider_stack)) return DidNotChangeStatus() def _template(self, blueprint): diff --git a/stacker/actions/diff.py b/stacker/actions/diff.py index 7d9cc3f98..84d067f37 100644 --- a/stacker/actions/diff.py +++ b/stacker/actions/diff.py @@ -272,9 +272,6 @@ def _diff_stack(self, stack, **kwargs): new_params, old_params)) ui.info('\n' + '\n'.join(output)) - stack.set_outputs( - provider.get_output_dict(provider_stack)) - return COMPLETE def _generate_plan(self): diff --git a/stacker/lookups/handlers/output.py b/stacker/lookups/handlers/output.py index a40ba0fb3..66c4c0818 100644 --- a/stacker/lookups/handlers/output.py +++ b/stacker/lookups/handlers/output.py @@ -5,6 +5,9 @@ import re from collections import namedtuple +import yaml + +from stacker.exceptions import StackDoesNotExist from . import LookupHandler TYPE_NAME = "output" @@ -14,25 +17,28 @@ class OutputLookup(LookupHandler): @classmethod - def handle(cls, value, context=None, **kwargs): - """Fetch an output from the designated stack. - - Args: - value (str): string with the following format: - ::, ie. some-stack::SomeOutput - context (:class:`stacker.context.Context`): stacker context - - Returns: - str: output from the specified stack - - """ - - if context is None: - raise ValueError('Context is required') + def handle(cls, value, context, provider, **kwargs): + """Fetch an output from the designated stack.""" d = deconstruct(value) - stack = context.get_stack(d.stack_name) - return stack.outputs[d.output_name] + try: + stack = context.get_stack(d.stack_name) + if not stack: + raise StackDoesNotExist(d.stack_name) + outputs = provider.get_outputs(stack.fqn) + except StackDoesNotExist: + raise LookupError("Stack is missing from configuration or not " + "deployed: {}".format(d.stack_name)) + + try: + return outputs[d.output_name] + except KeyError: + available_lookups = yaml.safe_dump( + list(outputs.keys()), default_flow_style=False) + msg = ("Lookup missing from stack: {}::{}. " + "Available lookups:\n{}") + raise LookupError(msg.format( + d.stack_name, d.output_name, available_lookups)) @classmethod def dependencies(cls, lookup_data): diff --git a/stacker/stack.py b/stacker/stack.py index aa5ab81b4..60fb2f564 100644 --- a/stacker/stack.py +++ b/stacker/stack.py @@ -73,7 +73,6 @@ def __init__(self, definition, context, variables=None, mappings=None, self.enabled = enabled self.protected = protected self.context = context - self.outputs = None self.in_progress_behavior = definition.in_progress_behavior def __repr__(self): @@ -192,6 +191,3 @@ def resolve(self, context, provider): """ resolve_variables(self.variables, context, provider) self.blueprint.resolve_variables(self.variables) - - def set_outputs(self, outputs): - self.outputs = outputs From 41a4fe0eb8c955249b01df55efa49bf6494cc925 Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Sun, 17 Mar 2019 16:45:36 -0300 Subject: [PATCH 08/13] hooks: resolve lookups in arguments --- stacker/hooks/__init__.py | 40 +++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/stacker/hooks/__init__.py b/stacker/hooks/__init__.py index 39912d3ab..987b2b167 100644 --- a/stacker/hooks/__init__.py +++ b/stacker/hooks/__init__.py @@ -7,7 +7,7 @@ from stacker.exceptions import HookExecutionFailed from stacker.util import load_object_from_string - +from stacker.variables import Variable logger = logging.getLogger(__name__) @@ -47,12 +47,24 @@ def __init__(self, name, path, required=True, enabled=True, self.required = required self.enabled = enabled self.data_key = data_key - self.args = args or {} + self.args = args self.required_by = set(required_by or []) self.requires = set(requires or []) self.profile = profile self.region = region + self._args = {} + if args: + for key, value in args.items(): + var = self._args[key] = \ + Variable('{}.args.{}'.format(self.name, key), value) + self.requires.update(var.dependencies()) + + def resolve_args(self, provider, context): + for key, value in self._args.items(): + value.resolve(context, provider) + yield key, value.value + def run(self, provider, context): """Run a Hook and capture its result @@ -74,12 +86,7 @@ def run(self, provider, context): logger.info("Executing hook %s", self) - data_key = self.data_key - required = self.required - kwargs = self.args or {} - enabled = self.enabled - - if not enabled: + if not self.enabled: logger.debug("Hook %s is disabled, skipping", self.name) return @@ -88,21 +95,22 @@ def run(self, provider, context): except (AttributeError, ImportError) as e: logger.exception("Unable to load method at %s for hook %s:", self.path, self.name) - if required: + if self.required: raise HookExecutionFailed(self, exception=e) return + kwargs = dict(self.resolve_args(provider, context)) try: result = method(context=context, provider=provider, **kwargs) except Exception as e: - if required: + if self.required: raise HookExecutionFailed(self, exception=e) return if not result: - if required: + if self.required: raise HookExecutionFailed(self, result=result) logger.warning("Non-required hook %s failed. Return value: %s", @@ -110,13 +118,17 @@ def run(self, provider, context): return result if isinstance(result, Mapping): - if data_key: + if self.data_key: logger.debug("Adding result for hook %s to context in " - "data_key %s.", self.name, data_key) - context.set_hook_data(data_key, result) + "data_key %s.", self.name, self.data_key) + context.set_hook_data(self.data_key, result) return result + def __str__(self): + return 'Hook(name={}, path={}, profile={}, region={})'.format( + self.name, self.path, self.profile, self.region) + class ActionHooks(namedtuple('ActionHooks', 'action_name pre post custom')): @classmethod From 020b0e2533fe8ebdb4a06e6069c5bec7bff814a3 Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Sun, 17 Mar 2019 18:39:51 -0300 Subject: [PATCH 09/13] test: use an uniform pattern for mocking Provider --- stacker/tests/actions/test_build.py | 25 +--- stacker/tests/factories.py | 83 ++++++++--- .../tests/lookups/handlers/test_default.py | 14 +- stacker/tests/lookups/handlers/test_output.py | 25 ++-- stacker/tests/lookups/handlers/test_rxref.py | 22 +-- stacker/tests/lookups/handlers/test_xref.py | 16 +- stacker/tests/test_variables.py | 137 +++++------------- 7 files changed, 136 insertions(+), 186 deletions(-) diff --git a/stacker/tests/actions/test_build.py b/stacker/tests/actions/test_build.py index 586233437..1c0983ceb 100644 --- a/stacker/tests/actions/test_build.py +++ b/stacker/tests/actions/test_build.py @@ -2,9 +2,8 @@ from __future__ import division from __future__ import absolute_import from builtins import str -import unittest from collections import namedtuple - +import unittest import mock from stacker import exceptions @@ -18,7 +17,6 @@ from stacker.blueprints.variables.types import CFNString from stacker.context import Context, Config from stacker.exceptions import StackDidNotChange, StackDoesNotExist -from stacker.providers.base import BaseProvider from stacker.providers.aws.default import Provider from stacker.status import ( NotSubmittedStatus, @@ -29,7 +27,7 @@ FAILED ) -from ..factories import MockThreadingEvent, MockProviderBuilder +from ..factories import MockThreadingEvent, MockProviderBuilder, mock_provider def mock_stack_parameters(parameters): @@ -41,27 +39,10 @@ def mock_stack_parameters(parameters): } -class TestProvider(BaseProvider): - def __init__(self, outputs=None, *args, **kwargs): - self._outputs = outputs or {} - - def set_outputs(self, outputs): - self._outputs = outputs - - def get_stack(self, stack_name, **kwargs): - if stack_name not in self._outputs: - raise exceptions.StackDoesNotExist(stack_name) - return {"name": stack_name, "outputs": self._outputs[stack_name]} - - def get_outputs(self, stack_name, *args, **kwargs): - stack = self.get_stack(stack_name) - return stack["outputs"] - - class TestBuildAction(unittest.TestCase): def setUp(self): self.context = Context(config=Config({"namespace": "namespace"})) - self.provider = TestProvider() + self.provider = mock_provider() self.build_action = build.Action( self.context, provider_builder=MockProviderBuilder(self.provider)) diff --git a/stacker/tests/factories.py b/stacker/tests/factories.py index f930c5177..e818460a9 100644 --- a/stacker/tests/factories.py +++ b/stacker/tests/factories.py @@ -2,11 +2,13 @@ from __future__ import division from __future__ import absolute_import from builtins import object -from mock import MagicMock + +import mock from stacker.context import Context from stacker.config import Config, Stack -from stacker.lookups import Lookup +from stacker.exceptions import StackDoesNotExist, StackUpdateBadStatus +from stacker.providers.base import BaseProvider class MockThreadingEvent(object): @@ -23,23 +25,72 @@ def build(self, region=None, profile=None): return self.provider -def mock_provider(**kwargs): - return MagicMock(**kwargs) +class MockProvider(BaseProvider): + def __init__(self, outputs=None): + self._stacks = {} + for stack_name, stack_outputs in (outputs or {}).items(): + self._stacks[stack_name] = { + "StackName": stack_name, + "Outputs": stack_outputs, + "StackStatus": "CREATED" + } + + def get_stack(self, stack_name, **kwargs): + try: + return self._stacks[stack_name] + except KeyError: + raise StackDoesNotExist(stack_name) + + def get_outputs(self, stack_name, *args, **kwargs): + return self.get_stack(stack_name)["Outputs"] + + def get_stack_status(self, stack_name, *args, **kwargs): + return self.get_stack(stack_name)["StackStatus"] + + def create_stack(self, stack_name, *args, **kwargs): + try: + stack = self.get_stack(stack_name) + status = self.get_stack_status(stack) + if status != "DELETED": + raise StackUpdateBadStatus(stack_name, status, "can't create") + except StackDoesNotExist: + pass + + return None + + def update_stack(self, stack_name, *args, **kwargs): + stack = self.get_stack(stack_name) + status = self.get_stack_status(stack) + if status == "DELETED": + raise StackUpdateBadStatus(stack_name, status, "can't update") + + stack["StackStatus"] = "UPDATED" + return None + def destroy_stack(self, stack_name, *args, **kwargs): + stack = self.get_stack(stack_name) + status = self.get_stack_status(stack) + if status == "DELETED": + raise StackUpdateBadStatus(stack_name, status, "can't destroy") -def mock_context(namespace="default", extra_config_args=None, **kwargs): + stack["StackStatus"] = "DELETED" + return None + + +def mock_provider(outputs=None, **kwargs): + provider = mock.MagicMock(wraps=MockProvider(outputs), **kwargs) + return provider + + +def mock_context(namespace="default", extra_config_args=None, + environment=None, **kwargs): config_args = {"namespace": namespace} if extra_config_args: config_args.update(extra_config_args) + config = Config(config_args) - if kwargs.get("environment"): - return Context( - config=config, - **kwargs) - return Context( - config=config, - environment={}, - **kwargs) + environment = environment or {} + return Context(config=config, environment=environment, **kwargs) def generate_definition(base_name, stack_id, **overrides): @@ -53,12 +104,6 @@ def generate_definition(base_name, stack_id, **overrides): return Stack(definition) -def mock_lookup(lookup_input, lookup_type, raw=None): - if raw is None: - raw = "%s %s" % (lookup_type, lookup_input) - return Lookup(type=lookup_type, input=lookup_input, raw=raw) - - class SessionStub(object): """Stubber class for boto3 sessions made with session_cache.get_session() diff --git a/stacker/tests/lookups/handlers/test_default.py b/stacker/tests/lookups/handlers/test_default.py index a59ccd6d8..990e510fe 100644 --- a/stacker/tests/lookups/handlers/test_default.py +++ b/stacker/tests/lookups/handlers/test_default.py @@ -1,22 +1,18 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import -from mock import MagicMock import unittest -from stacker.context import Context from stacker.lookups.handlers.default import DefaultLookup +from ...factories import mock_context, mock_provider -class TestDefaultLookup(unittest.TestCase): +class TestDefaultLookup(unittest.TestCase): def setUp(self): - self.provider = MagicMock() - self.context = Context( - environment={ - 'namespace': 'test', - 'env_var': 'val_in_env'} - ) + self.provider = mock_provider() + self.context = mock_context( + namespace='test', environment={'env_var': 'val_in_env'}) def test_env_var_present(self): lookup_val = "env_var::fallback" diff --git a/stacker/tests/lookups/handlers/test_output.py b/stacker/tests/lookups/handlers/test_output.py index 3891dfe25..691f8814d 100644 --- a/stacker/tests/lookups/handlers/test_output.py +++ b/stacker/tests/lookups/handlers/test_output.py @@ -5,25 +5,22 @@ import unittest from stacker.stack import Stack -from ...factories import generate_definition from stacker.lookups.handlers.output import OutputLookup +from ...factories import generate_definition, mock_context, mock_provider -class TestOutputHandler(unittest.TestCase): +class TestOutputHandler(unittest.TestCase): def setUp(self): - self.context = MagicMock() + stack_def = generate_definition("vpc", 1) + self.context = mock_context() + self.stack = Stack(definition=stack_def, context=self.context) + self.context.get_stacks = MagicMock(return_value=[self.stack]) + self.provider = mock_provider( + outputs={self.stack.fqn: {"SomeOutput": "Test Output"}}) def test_output_handler(self): - stack = Stack( - definition=generate_definition("vpc", 1), - context=self.context) - stack.set_outputs({ - "SomeOutput": "Test Output"}) - self.context.get_stack.return_value = stack - value = OutputLookup.handle("stack-name::SomeOutput", - context=self.context) + value = OutputLookup.handle("{}::SomeOutput".format(self.stack.name), + context=self.context, + provider=self.provider) self.assertEqual(value, "Test Output") - self.assertEqual(self.context.get_stack.call_count, 1) - args = self.context.get_stack.call_args - self.assertEqual(args[0][0], "stack-name") diff --git a/stacker/tests/lookups/handlers/test_rxref.py b/stacker/tests/lookups/handlers/test_rxref.py index b5e7cb828..c6480b916 100644 --- a/stacker/tests/lookups/handlers/test_rxref.py +++ b/stacker/tests/lookups/handlers/test_rxref.py @@ -1,30 +1,24 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import -from mock import MagicMock import unittest from stacker.lookups.handlers.rxref import RxrefLookup -from ....context import Context -from ....config import Config + +from ...factories import mock_context, mock_provider class TestRxrefHandler(unittest.TestCase): def setUp(self): - self.provider = MagicMock() - self.context = Context( - config=Config({"namespace": "ns"}) - ) + self.context = mock_context() + self.stack_name = "stack-name" + self.stack_fqn = self.context.get_fqn(self.stack_name) + self.provider = mock_provider( + outputs={self.stack_fqn: {"SomeOutput": "Test Output"}}) def test_rxref_handler(self): - self.provider.get_output.return_value = "Test Output" - - value = RxrefLookup.handle("fully-qualified-stack-name::SomeOutput", + value = RxrefLookup.handle("{}::SomeOutput".format(self.stack_name), provider=self.provider, context=self.context) self.assertEqual(value, "Test Output") - - args = self.provider.get_output.call_args - self.assertEqual(args[0][0], "ns-fully-qualified-stack-name") - self.assertEqual(args[0][1], "SomeOutput") diff --git a/stacker/tests/lookups/handlers/test_xref.py b/stacker/tests/lookups/handlers/test_xref.py index cb611ed65..c2b1d1b46 100644 --- a/stacker/tests/lookups/handlers/test_xref.py +++ b/stacker/tests/lookups/handlers/test_xref.py @@ -1,25 +1,23 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import -from mock import MagicMock import unittest from stacker.lookups.handlers.xref import XrefLookup +from ...factories import mock_context, mock_provider + class TestXrefHandler(unittest.TestCase): def setUp(self): - self.provider = MagicMock() - self.context = MagicMock() + self.stack_fqn = "fully-qualified-stack-name" + self.context = mock_context() + self.provider = mock_provider( + outputs={self.stack_fqn: {"SomeOutput": "Test Output"}}) def test_xref_handler(self): - self.provider.get_output.return_value = "Test Output" - value = XrefLookup.handle("fully-qualified-stack-name::SomeOutput", + value = XrefLookup.handle("{}::SomeOutput".format(self.stack_fqn), provider=self.provider, context=self.context) self.assertEqual(value, "Test Output") - self.assertEqual(self.context.get_fqn.call_count, 0) - args = self.provider.get_output.call_args - self.assertEqual(args[0][0], "fully-qualified-stack-name") - self.assertEqual(args[0][1], "SomeOutput") diff --git a/stacker/tests/test_variables.py b/stacker/tests/test_variables.py index 2b1acbc55..74bc1d6ec 100644 --- a/stacker/tests/test_variables.py +++ b/stacker/tests/test_variables.py @@ -4,85 +4,49 @@ import unittest -from mock import MagicMock - from troposphere import s3 + from stacker.blueprints.variables.types import TroposphereType +from stacker.lookups.handlers import LookupHandler from stacker.variables import Variable from stacker.lookups import register_lookup_handler -from stacker.stack import Stack +from .factories import mock_context, mock_provider -from .factories import generate_definition + +class MockLookup(LookupHandler): + @classmethod + def handle(cls, value, context, provider): + return str(value) class TestVariables(unittest.TestCase): def setUp(self): - self.provider = MagicMock() - self.context = MagicMock() + self.provider = mock_provider() + self.context = mock_context() + + register_lookup_handler("test", MockLookup) def test_variable_replace_no_lookups(self): var = Variable("Param1", "2") self.assertEqual(var.value, "2") - def test_variable_replace_simple_lookup(self): - var = Variable("Param1", "${output fakeStack::FakeOutput}") - var._value._resolve("resolved") - self.assertEqual(var.value, "resolved") - def test_variable_resolve_simple_lookup(self): - stack = Stack( - definition=generate_definition("vpc", 1), - context=self.context) - stack.set_outputs({ - "FakeOutput": "resolved", - "FakeOutput2": "resolved2", - }) - - self.context.get_stack.return_value = stack - - var = Variable("Param1", "${output fakeStack::FakeOutput}") - var.resolve(self.context, self.provider) - self.assertTrue(var.resolved) - self.assertEqual(var.value, "resolved") - - def test_variable_resolve_default_lookup_empty(self): - var = Variable("Param1", "${default fakeStack::}") + var = Variable("Param1", "${noop test}") var.resolve(self.context, self.provider) self.assertTrue(var.resolved) - self.assertEqual(var.value, "") + self.assertEqual(var.value, "test") def test_variable_replace_multiple_lookups_string(self): var = Variable( "Param1", "url://" # 0 - "${output fakeStack::FakeOutput}" # 1 + "${test resolved}" # 1 "@" # 2 - "${output fakeStack::FakeOutput2}", # 3 - ) - var._value[1]._resolve("resolved") - var._value[3]._resolve("resolved2") - self.assertEqual(var.value, "url://resolved@resolved2") - - def test_variable_resolve_multiple_lookups_string(self): - var = Variable( - "Param1", - "url://${output fakeStack::FakeOutput}@" - "${output fakeStack::FakeOutput2}", + "${test resolved2}", # 3 ) - - stack = Stack( - definition=generate_definition("vpc", 1), - context=self.context) - stack.set_outputs({ - "FakeOutput": "resolved", - "FakeOutput2": "resolved2", - }) - - self.context.get_stack.return_value = stack var.resolve(self.context, self.provider) - self.assertTrue(var.resolved) self.assertEqual(var.value, "url://resolved@resolved2") def test_variable_replace_no_lookups_list(self): @@ -90,77 +54,52 @@ def test_variable_replace_no_lookups_list(self): self.assertEqual(var.value, ["something", "here"]) def test_variable_replace_lookups_list(self): - value = ["something", # 0 - "${output fakeStack::FakeOutput}", # 1 - "${output fakeStack::FakeOutput2}" # 2 - ] + value = ["something", "${test resolved}", "${test resolved2}"] var = Variable("Param1", value) - - var._value[1]._resolve("resolved") - var._value[2]._resolve("resolved2") + var.resolve(self.context, self.provider) self.assertEqual(var.value, ["something", "resolved", "resolved2"]) def test_variable_replace_lookups_dict(self): value = { - "something": "${output fakeStack::FakeOutput}", - "other": "${output fakeStack::FakeOutput2}", + "something": "${test resolved}", + "other": "${test resolved2}", } var = Variable("Param1", value) - var._value["something"]._resolve("resolved") - var._value["other"]._resolve("resolved2") - self.assertEqual(var.value, {"something": "resolved", "other": - "resolved2"}) + var.resolve(self.context, self.provider) + self.assertEqual(var.value, {"something": "resolved", + "other": "resolved2"}) def test_variable_replace_lookups_mixed(self): value = { - "something": [ - "${output fakeStack::FakeOutput}", - "other", + "list": [ + "${test 1}", + "2", ], - "here": { - "other": "${output fakeStack::FakeOutput2}", - "same": "${output fakeStack::FakeOutput}", - "mixed": "something:${output fakeStack::FakeOutput3}", + "dict": { + "1": "${test a}", + "2": "${test b}", + "3": "c:${test d}", }, } var = Variable("Param1", value) - var._value["something"][0]._resolve("resolved") - var._value["here"]["other"]._resolve("resolved2") - var._value["here"]["same"]._resolve("resolved") - var._value["here"]["mixed"][1]._resolve("resolved3") + var.resolve(self.context, self.provider) self.assertEqual(var.value, { - "something": [ - "resolved", - "other", - ], - "here": { - "other": "resolved2", - "same": "resolved", - "mixed": "something:resolved3", + "list": ["1", "2"], + "dict": { + "1": "a", + "2": "b", + "3": "c:d", }, }) def test_variable_resolve_nested_lookup(self): - stack = Stack( - definition=generate_definition("vpc", 1), - context=self.context) - stack.set_outputs({ - "FakeOutput": "resolved", - "FakeOutput2": "resolved2", - }) - - def mock_handler(value, context, provider, **kwargs): - return "looked up: {}".format(value) - - register_lookup_handler("lookup", mock_handler) - self.context.get_stack.return_value = stack var = Variable( "Param1", - "${lookup ${lookup ${output fakeStack::FakeOutput}}}", + "${test a:${test b:${test c}}}", ) var.resolve(self.context, self.provider) self.assertTrue(var.resolved) - self.assertEqual(var.value, "looked up: looked up: resolved") + self.assertEqual(var.value, "a:b:c") def test_troposphere_type_no_from_dict(self): with self.assertRaises(ValueError): From 395685ea148c975b207c2c89b2926092f2f48891 Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Sun, 17 Mar 2019 18:56:01 -0300 Subject: [PATCH 10/13] tests: disable boto3 session caching to fix moto issues If instances of clients are created outside of the moto patching, they will be cached by boto3 and return in subsequent runs, breaking some tests. While the best fix would be to avoid that situation, disabling the caching altogether is easier and ensures the issue won't return. --- stacker/tests/conftest.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/stacker/tests/conftest.py b/stacker/tests/conftest.py index 6597ebc81..81b40aa86 100644 --- a/stacker/tests/conftest.py +++ b/stacker/tests/conftest.py @@ -3,8 +3,11 @@ import logging import os +import mock import pytest import py.path +from boto3 import Session + logger = logging.getLogger(__name__) @@ -42,3 +45,13 @@ def stacker_fixture_dir(): path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'fixtures') return py.path.local(path) + + +@pytest.fixture(scope='session', autouse=True) +def boto3_disable_session_caching(): + def get_session(**kwargs): + return Session(**kwargs) + + with mock.patch('boto3._get_default_session', + side_effect=get_session): + yield From 37422fd976329ee5ca228f51f8cf277cc669bf32 Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Sun, 17 Mar 2019 19:25:02 -0300 Subject: [PATCH 11/13] provider: add get_session method and convert all hooks and lookups This makes hooks respect the universal region and profile settings add in the config. --- stacker/hooks/aws_lambda.py | 4 +- stacker/hooks/ecs.py | 3 +- stacker/hooks/iam.py | 5 +- stacker/hooks/keypair.py | 5 +- stacker/hooks/route53.py | 4 +- stacker/lookups/handlers/ami.py | 33 +- stacker/lookups/handlers/dynamodb.py | 5 +- stacker/lookups/handlers/kms.py | 5 +- stacker/lookups/handlers/output.py | 2 +- stacker/lookups/handlers/rxref.py | 2 +- stacker/lookups/handlers/ssmstore.py | 6 +- stacker/lookups/handlers/xref.py | 5 +- stacker/providers/aws/default.py | 58 +-- stacker/providers/base.py | 4 + stacker/tests/factories.py | 40 +- stacker/tests/hooks/test_aws_lambda.py | 49 +-- stacker/tests/hooks/test_ecs.py | 124 +++--- stacker/tests/hooks/test_iam.py | 86 ++-- stacker/tests/lookups/handlers/test_ami.py | 381 +++++++++--------- .../tests/lookups/handlers/test_dynamodb.py | 150 ++++--- stacker/tests/lookups/handlers/test_kms.py | 44 +- .../tests/lookups/handlers/test_ssmstore.py | 155 +++---- stacker/tests/providers/aws/test_default.py | 4 +- 23 files changed, 608 insertions(+), 566 deletions(-) diff --git a/stacker/hooks/aws_lambda.py b/stacker/hooks/aws_lambda.py index 4b388f40c..9e8c201b7 100644 --- a/stacker/hooks/aws_lambda.py +++ b/stacker/hooks/aws_lambda.py @@ -11,10 +11,10 @@ import hashlib from io import BytesIO as StringIO from zipfile import ZipFile, ZIP_DEFLATED + import botocore import formic from troposphere.awslambda import Code -from stacker.session_cache import get_session from stacker.util import ( get_config_directory, @@ -508,7 +508,7 @@ def create_template(self): payload_acl = kwargs.get('payload_acl', 'private') # Always use the global client for s3 - session = get_session(bucket_region) + session = provider.get_session(region=bucket_region) s3_client = session.client('s3') ensure_s3_bucket(s3_client, bucket_name, bucket_region) diff --git a/stacker/hooks/ecs.py b/stacker/hooks/ecs.py index 308c2eccc..daad432d3 100644 --- a/stacker/hooks/ecs.py +++ b/stacker/hooks/ecs.py @@ -7,7 +7,6 @@ from past.builtins import basestring import logging -from stacker.session_cache import get_session logger = logging.getLogger(__name__) @@ -26,7 +25,7 @@ def create_clusters(provider, context, **kwargs): Returns: boolean for whether or not the hook succeeded. """ - conn = get_session(provider.region).client('ecs') + conn = provider.get_session().client('ecs') try: clusters = kwargs["clusters"] diff --git a/stacker/hooks/iam.py b/stacker/hooks/iam.py index 009888157..2fe9c345a 100644 --- a/stacker/hooks/iam.py +++ b/stacker/hooks/iam.py @@ -5,7 +5,6 @@ import copy import logging -from stacker.session_cache import get_session from botocore.exceptions import ClientError from awacs.aws import Statement, Allow, Policy @@ -32,7 +31,7 @@ def create_ecs_service_role(provider, context, **kwargs): """ role_name = kwargs.get("role_name", "ecsServiceRole") - client = get_session(provider.region).client('iam') + client = provider.get_session().client('iam') try: client.create_role( @@ -125,7 +124,7 @@ def get_cert_contents(kwargs): def ensure_server_cert_exists(provider, context, **kwargs): - client = get_session(provider.region).client('iam') + client = provider.get_session().client('iam') cert_name = kwargs["cert_name"] status = "unknown" try: diff --git a/stacker/hooks/keypair.py b/stacker/hooks/keypair.py index 3114729cd..100d3ef72 100644 --- a/stacker/hooks/keypair.py +++ b/stacker/hooks/keypair.py @@ -8,7 +8,6 @@ from botocore.exceptions import ClientError -from stacker.session_cache import get_session from stacker.hooks import utils from stacker.ui import get_raw_input @@ -220,8 +219,8 @@ def ensure_keypair_exists(provider, context, **kwargs): "specified at the same time") return False - session = get_session(region=provider.region, - profile=kwargs.get("profile")) + session = provider.get_session( + profile=kwargs.get("profile")) ec2 = session.client("ec2") keypair = get_existing_key_pair(ec2, keypair_name) diff --git a/stacker/hooks/route53.py b/stacker/hooks/route53.py index c163e091d..01bc04b41 100644 --- a/stacker/hooks/route53.py +++ b/stacker/hooks/route53.py @@ -3,8 +3,6 @@ from __future__ import absolute_import import logging -from stacker.session_cache import get_session - from stacker.util import create_route53_zone logger = logging.getLogger(__name__) @@ -21,7 +19,7 @@ def create_domain(provider, context, **kwargs): Returns: boolean for whether or not the hook succeeded. """ - session = get_session(provider.region) + session = provider.get_session() client = session.client("route53") domain = kwargs.get("domain") if not domain: diff --git a/stacker/lookups/handlers/ami.py b/stacker/lookups/handlers/ami.py index 8d51c0619..634dce500 100644 --- a/stacker/lookups/handlers/ami.py +++ b/stacker/lookups/handlers/ami.py @@ -1,7 +1,7 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import -from stacker.session_cache import get_session + import re import operator @@ -22,31 +22,31 @@ def __init__(self, search_string): class AmiLookup(LookupHandler): @classmethod - def handle(cls, value, provider, **kwargs): + def handle(cls, value, context, provider): """Fetch the most recent AMI Id using a filter - + For example: - + ${ami [@]owners:self,account,amazon name_regex:serverX-[0-9]+ architecture:x64,i386} - + The above fetches the most recent AMI where owner is self account or amazon and the ami name matches the regex described, the architecture will be either x64 or i386 - + You can also optionally specify the region in which to perform the AMI lookup. - + Valid arguments: - + owners (comma delimited) REQUIRED ONCE: aws_account_id | amazon | self - + name_regex (a regex) REQUIRED ONCE: e.g. my-ubuntu-server-[0-9]+ - + executable_users (comma delimited) OPTIONAL ONCE: aws_account_id | amazon | self - + Any other arguments specified are sent as filters to the aws api For example, "architecture:x86_64" will add a filter """ # noqa @@ -57,13 +57,13 @@ def handle(cls, value, provider, **kwargs): else: region = provider.region - ec2 = get_session(region).client('ec2') + ec2 = provider.get_session(region=region).client('ec2') values = {} describe_args = {} # now find any other arguments that can be filters - matches = re.findall('([0-9a-zA-z_-]+:[^\s$]+)', value) + matches = re.findall(r'([0-9a-zA-z_-]+:[^\s$]+)', value) for match in matches: k, v = match.split(':', 1) values[k] = v @@ -77,10 +77,9 @@ def handle(cls, value, provider, **kwargs): raise Exception("'name_regex' value required when using ami") name_regex = values.pop('name_regex') - executable_users = None - if values.get('executable_users'): - executable_users = values.pop('executable_users').split(',') - describe_args["ExecutableUsers"] = executable_users + executable_users = values.get('executable_users') + if executable_users: + describe_args["ExecutableUsers"] = executable_users.split(',') filters = [] for k, v in values.items(): diff --git a/stacker/lookups/handlers/dynamodb.py b/stacker/lookups/handlers/dynamodb.py index 9dcd97ce8..44df1b4b0 100644 --- a/stacker/lookups/handlers/dynamodb.py +++ b/stacker/lookups/handlers/dynamodb.py @@ -4,7 +4,6 @@ from builtins import str from botocore.exceptions import ClientError import re -from stacker.session_cache import get_session from . import LookupHandler from ...util import read_value_from_path @@ -14,7 +13,7 @@ class DynamodbLookup(LookupHandler): @classmethod - def handle(cls, value, **kwargs): + def handle(cls, value, context, provider): """Get a value from a dynamodb table dynamodb field types should be in the following format: @@ -53,7 +52,7 @@ def handle(cls, value, **kwargs): projection_expression = _build_projection_expression(clean_table_keys) # lookup the data from dynamodb - dynamodb = get_session(region).client('dynamodb') + dynamodb = provider.get_session(region=region).client('dynamodb') try: response = dynamodb.get_item( TableName=table_name, diff --git a/stacker/lookups/handlers/kms.py b/stacker/lookups/handlers/kms.py index ba80d2779..921925768 100644 --- a/stacker/lookups/handlers/kms.py +++ b/stacker/lookups/handlers/kms.py @@ -2,7 +2,6 @@ from __future__ import division from __future__ import absolute_import import codecs -from stacker.session_cache import get_session from . import LookupHandler from ...util import read_value_from_path @@ -12,7 +11,7 @@ class KmsLookup(LookupHandler): @classmethod - def handle(cls, value, **kwargs): + def handle(cls, value, context, provider): """Decrypt the specified value with a master key in KMS. kmssimple field types should be in the following format: @@ -55,7 +54,7 @@ def handle(cls, value, **kwargs): if "@" in value: region, value = value.split("@", 1) - kms = get_session(region).client('kms') + kms = provider.get_session(region=region).client('kms') # encode str value as an utf-8 bytestring for use with codecs.decode. value = value.encode('utf-8') diff --git a/stacker/lookups/handlers/output.py b/stacker/lookups/handlers/output.py index 66c4c0818..7d67162cd 100644 --- a/stacker/lookups/handlers/output.py +++ b/stacker/lookups/handlers/output.py @@ -17,7 +17,7 @@ class OutputLookup(LookupHandler): @classmethod - def handle(cls, value, context, provider, **kwargs): + def handle(cls, value, context, provider): """Fetch an output from the designated stack.""" d = deconstruct(value) diff --git a/stacker/lookups/handlers/rxref.py b/stacker/lookups/handlers/rxref.py index 858a13a3d..cfb327d04 100644 --- a/stacker/lookups/handlers/rxref.py +++ b/stacker/lookups/handlers/rxref.py @@ -22,7 +22,7 @@ class RxrefLookup(LookupHandler): @classmethod - def handle(cls, value, provider=None, context=None, **kwargs): + def handle(cls, value, context, provider): """Fetch an output from the designated stack. Args: diff --git a/stacker/lookups/handlers/ssmstore.py b/stacker/lookups/handlers/ssmstore.py index 2da724d30..8f6d8eaae 100644 --- a/stacker/lookups/handlers/ssmstore.py +++ b/stacker/lookups/handlers/ssmstore.py @@ -3,8 +3,6 @@ from __future__ import absolute_import from builtins import str -from stacker.session_cache import get_session - from . import LookupHandler from ...util import read_value_from_path @@ -13,7 +11,7 @@ class SsmstoreLookup(LookupHandler): @classmethod - def handle(cls, value, **kwargs): + def handle(cls, value, context, provider): """Retrieve (and decrypt if applicable) a parameter from AWS SSM Parameter Store. @@ -48,7 +46,7 @@ def handle(cls, value, **kwargs): if "@" in value: region, value = value.split("@", 1) - client = get_session(region).client("ssm") + client = provider.get_session(region=region).client("ssm") response = client.get_parameters( Names=[ value, diff --git a/stacker/lookups/handlers/xref.py b/stacker/lookups/handlers/xref.py index a318d252b..6e591fa72 100644 --- a/stacker/lookups/handlers/xref.py +++ b/stacker/lookups/handlers/xref.py @@ -21,7 +21,7 @@ class XrefLookup(LookupHandler): @classmethod - def handle(cls, value, provider=None, **kwargs): + def handle(cls, value, context, provider): """Fetch an output from the designated stack. Args: @@ -34,9 +34,6 @@ def handle(cls, value, provider=None, **kwargs): str: output from the specified stack """ - if provider is None: - raise ValueError('Provider is required') - d = deconstruct(value) stack_fqn = d.stack_name output = provider.get_output(stack_fqn, d.output_name) diff --git a/stacker/providers/aws/default.py b/stacker/providers/aws/default.py index 808531346..1805c70c2 100644 --- a/stacker/providers/aws/default.py +++ b/stacker/providers/aws/default.py @@ -11,19 +11,16 @@ import time import urllib.parse import sys - -# thread safe, memoized, provider builder. from threading import Lock import botocore.exceptions from botocore.config import Config -from ..base import BaseProvider -from ... import exceptions -from ...ui import ui +from stacker import exceptions +from stacker.ui import ui +from stacker.providers.base import BaseProvider from stacker.session_cache import get_session - -from ...actions.diff import ( +from stacker.actions.diff import ( DictValue, diff_parameters, format_params_diff as format_diff @@ -550,17 +547,19 @@ def __init__(self, session, region=None, interactive=False, replacements_only=False, recreate_failed=False, service_role=None, **kwargs): self._outputs = {} - self.region = region - self.cloudformation = get_cloudformation_client(session) + self.region = region or session.region_name self.interactive = interactive # replacements only is only used in interactive mode self.replacements_only = interactive and replacements_only self.recreate_failed = interactive or recreate_failed self.service_role = service_role + self._session = session + self._cloudformation = get_cloudformation_client(session) + def get_stack(self, stack_name, **kwargs): try: - return self.cloudformation.describe_stacks( + return self._cloudformation.describe_stacks( StackName=stack_name)['Stacks'][0] except botocore.exceptions.ClientError as e: if "does not exist" not in str(e): @@ -630,11 +629,11 @@ def get_events(self, stack_name, chronological=True): event_list = [] while True: if next_token is not None: - events = self.cloudformation.describe_stack_events( + events = self._cloudformation.describe_stack_events( StackName=stack_name, NextToken=next_token ) else: - events = self.cloudformation.describe_stack_events( + events = self._cloudformation.describe_stack_events( StackName=stack_name ) event_list.append(events['StackEvents']) @@ -690,7 +689,7 @@ def destroy_stack(self, stack, **kwargs): if self.service_role: args["RoleARN"] = self.service_role - self.cloudformation.delete_stack(**args) + self._cloudformation.delete_stack(**args) return True def create_stack(self, fqn, template, parameters, tags, @@ -723,11 +722,11 @@ def create_stack(self, fqn, template, parameters, tags, logger.debug("force_change_set set to True, creating stack with " "changeset.") _changes, change_set_id = create_change_set( - self.cloudformation, fqn, template, parameters, tags, + self._cloudformation, fqn, template, parameters, tags, 'CREATE', service_role=self.service_role, **kwargs ) - self.cloudformation.execute_change_set( + self._cloudformation.execute_change_set( ChangeSetName=change_set_id, ) else: @@ -738,14 +737,14 @@ def create_stack(self, fqn, template, parameters, tags, ) try: - self.cloudformation.create_stack(**args) + self._cloudformation.create_stack(**args) except botocore.exceptions.ClientError as e: if e.response['Error']['Message'] == ('TemplateURL must ' 'reference a valid S3 ' 'object to which you ' 'have access.'): s3_fallback(fqn, template, parameters, tags, - self.cloudformation.create_stack, + self._cloudformation.create_stack, self.service_role) else: raise @@ -887,7 +886,7 @@ def deal_with_changeset_stack_policy(self, fqn, stack_policy): kwargs = generate_stack_policy_args(stack_policy) kwargs["StackName"] = fqn logger.debug("Setting stack policy on %s.", fqn) - self.cloudformation.set_stack_policy(**kwargs) + self._cloudformation.set_stack_policy(**kwargs) def interactive_update_stack(self, fqn, template, old_parameters, parameters, stack_policy, tags, @@ -909,7 +908,7 @@ def interactive_update_stack(self, fqn, template, old_parameters, """ logger.debug("Using interactive provider mode for %s.", fqn) changes, change_set_id = create_change_set( - self.cloudformation, fqn, template, parameters, tags, + self._cloudformation, fqn, template, parameters, tags, 'UPDATE', service_role=self.service_role, **kwargs ) old_parameters_as_dict = self.params_as_dict(old_parameters) @@ -944,7 +943,7 @@ def interactive_update_stack(self, fqn, template, old_parameters, self.deal_with_changeset_stack_policy(fqn, stack_policy) - self.cloudformation.execute_change_set( + self._cloudformation.execute_change_set( ChangeSetName=change_set_id, ) @@ -972,13 +971,13 @@ def noninteractive_changeset_update(self, fqn, template, old_parameters, logger.debug("Using noninterative changeset provider mode " "for %s.", fqn) _changes, change_set_id = create_change_set( - self.cloudformation, fqn, template, parameters, tags, + self._cloudformation, fqn, template, parameters, tags, 'UPDATE', service_role=self.service_role, **kwargs ) self.deal_with_changeset_stack_policy(fqn, stack_policy) - self.cloudformation.execute_change_set( + self._cloudformation.execute_change_set( ChangeSetName=change_set_id, ) @@ -1008,7 +1007,7 @@ def default_update_stack(self, fqn, template, old_parameters, parameters, ) try: - self.cloudformation.update_stack(**args) + self._cloudformation.update_stack(**args) except botocore.exceptions.ClientError as e: if "No updates are to be performed." in str(e): logger.debug( @@ -1021,7 +1020,7 @@ def default_update_stack(self, fqn, template, old_parameters, parameters, 'S3 object to which ' 'you have access.'): s3_fallback(fqn, template, parameters, tags, - self.cloudformation.update_stack, + self._cloudformation.update_stack, self.service_role) else: raise @@ -1038,9 +1037,6 @@ def get_outputs(self, stack_name, *args, **kwargs): self._outputs[stack_name] = get_output_dict(stack) return self._outputs[stack_name] - def get_output_dict(self, stack): - return get_output_dict(stack) - def get_stack_info(self, stack): """ Get the template and parameters of the stack currently in AWS @@ -1049,7 +1045,7 @@ def get_stack_info(self, stack): stack_name = stack['StackId'] try: - template = self.cloudformation.get_template( + template = self._cloudformation.get_template( StackName=stack_name)['TemplateBody'] except botocore.exceptions.ClientError as e: if "does not exist" not in str(e): @@ -1066,3 +1062,9 @@ def params_as_dict(parameters_list): for p in parameters_list: parameters[p['ParameterKey']] = p['ParameterValue'] return parameters + + def get_session(self, **kwargs): + kwargs.setdefault('region', self._session.region_name) + kwargs.setdefault('profile', self._session.profile_name) + + return get_session(**kwargs) diff --git a/stacker/providers/base.py b/stacker/providers/base.py index c48291f13..36208d1f0 100644 --- a/stacker/providers/base.py +++ b/stacker/providers/base.py @@ -43,6 +43,10 @@ def get_output(self, stack_name, output): # pylint: disable=unused-argument return self.get_outputs(stack_name)[output] + def get_session(self, region=None, profile=None): + # pylint: disable=unused-argument + not_implemented("get_session") + class Template(object): """A value object that represents a CloudFormation stack template, which diff --git a/stacker/tests/factories.py b/stacker/tests/factories.py index e818460a9..ca5c64337 100644 --- a/stacker/tests/factories.py +++ b/stacker/tests/factories.py @@ -5,6 +5,8 @@ import mock +import boto3 + from stacker.context import Context from stacker.config import Config, Stack from stacker.exceptions import StackDoesNotExist, StackUpdateBadStatus @@ -26,7 +28,10 @@ def build(self, region=None, profile=None): class MockProvider(BaseProvider): - def __init__(self, outputs=None): + def __init__(self, outputs=None, region=None, profile=None): + self.region = region + self.profile = profile + self._stacks = {} for stack_name, stack_outputs in (outputs or {}).items(): self._stacks[stack_name] = { @@ -34,6 +39,7 @@ def __init__(self, outputs=None): "Outputs": stack_outputs, "StackStatus": "CREATED" } + self._sessions = {} def get_stack(self, stack_name, **kwargs): try: @@ -76,9 +82,13 @@ def destroy_stack(self, stack_name, *args, **kwargs): stack["StackStatus"] = "DELETED" return None + def get_session(self, region=None, profile=None): + return boto3.Session(region_name=region or self.region, + profile_name=profile or self.profile) -def mock_provider(outputs=None, **kwargs): - provider = mock.MagicMock(wraps=MockProvider(outputs), **kwargs) + +def mock_provider(outputs=None, region=None, profile=None, **kwargs): + provider = MockProvider(outputs, region=region, profile=profile) return provider @@ -93,6 +103,30 @@ def mock_context(namespace="default", extra_config_args=None, return Context(config=config, environment=environment, **kwargs) +def mock_boto3_client(service_name, region=None, profile=None): + client = boto3.client(service_name, region_name=region) + default_session = boto3._get_default_session() + + region = region or default_session.region_name + profile = profile or default_session.profile_name + svc_name = service_name + + def create_client(self, service_name, region_name=None, **kwargs): + region_name = region_name or self.region_name + profile_name = self.profile_name + if (svc_name, region, profile) == \ + (service_name, region_name, profile_name): + return client + + raise AssertionError( + "Attempted to create non-mocked AWS client: service={} region={} " + "profile={}".format(service_name, region_name, profile_name)) + + mock_ = mock.patch('boto3.Session.client', autospec=True, + side_effect=create_client) + return client, mock_ + + def generate_definition(base_name, stack_id, **overrides): definition = { "name": "%s.%d" % (base_name, stack_id), diff --git a/stacker/tests/hooks/test_aws_lambda.py b/stacker/tests/hooks/test_aws_lambda.py index 67acc934d..d184d82f2 100644 --- a/stacker/tests/hooks/test_aws_lambda.py +++ b/stacker/tests/hooks/test_aws_lambda.py @@ -12,22 +12,18 @@ from io import BytesIO as StringIO from zipfile import ZipFile -import boto3 import botocore from troposphere.awslambda import Code from moto import mock_s3 from testfixtures import TempDirectory, ShouldRaise, compare -from stacker.context import Context -from stacker.config import Config from stacker.hooks.aws_lambda import ( upload_lambda_functions, ZIP_PERMS_MASK, _calculate_hash, select_bucket_region, ) -from ..factories import mock_provider - +from ..factories import mock_provider, mock_context, mock_boto3_client REGION = "us-east-1" ALL_FILES = ( @@ -52,12 +48,6 @@ def temp_directory_with_files(cls, files=ALL_FILES): d.write(f, b'') return d - @property - def s3(self): - if not hasattr(self, '_s3'): - self._s3 = boto3.client('s3', region_name=REGION) - return self._s3 - def assert_s3_zip_file_list(self, bucket, key, files): object_info = self.s3.get_object(Bucket=bucket, Key=key) zip_data = StringIO(object_info['Body'].read()) @@ -82,11 +72,6 @@ def assert_s3_bucket(self, bucket, present=True): if present: self.fail('s3: bucket {} does not exist'.format(bucket)) - def setUp(self): - self.context = Context( - config=Config({'namespace': 'test', 'stacker_bucket': 'test'})) - self.provider = mock_provider(region="us-east-1") - def run_hook(self, **kwargs): real_kwargs = { 'context': self.context, @@ -96,14 +81,26 @@ def run_hook(self, **kwargs): return upload_lambda_functions(**real_kwargs) - @mock_s3 + def setUp(self): + self.context = mock_context( + extra_config_args={'stacker_bucket': 'test'}) + self.provider = mock_provider(region="us-east-1") + + self.mock_s3 = mock_s3() + self.mock_s3.start() + self.s3, self.client_mock = mock_boto3_client('s3', 'us-east-1') + self.client_mock.start() + + def tearDown(self): + self.client_mock.stop() + self.mock_s3.stop() + def test_bucket_default(self): self.assertIsNotNone( self.run_hook(functions={})) self.assert_s3_bucket('test') - @mock_s3 def test_bucket_custom(self): self.assertIsNotNone( self.run_hook(bucket='custom', functions={})) @@ -111,7 +108,6 @@ def test_bucket_custom(self): self.assert_s3_bucket('test', present=False) self.assert_s3_bucket('custom') - @mock_s3 def test_prefix(self): with self.temp_directory_with_files() as d: results = self.run_hook(prefix='cloudformation-custom-resources/', @@ -129,7 +125,6 @@ def test_prefix(self): self.assertTrue(code.S3Key.startswith( 'cloudformation-custom-resources/lambda-MyFunction-')) - @mock_s3 def test_prefix_missing(self): with self.temp_directory_with_files() as d: results = self.run_hook(functions={ @@ -145,7 +140,6 @@ def test_prefix_missing(self): self.assert_s3_zip_file_list(code.S3Bucket, code.S3Key, F1_FILES) self.assertTrue(code.S3Key.startswith('lambda-MyFunction-')) - @mock_s3 def test_path_missing(self): msg = "missing required property 'path' in function 'MyFunction'" with ShouldRaise(ValueError(msg)): @@ -154,7 +148,6 @@ def test_path_missing(self): } }) - @mock_s3 def test_path_relative(self): get_config_directory = 'stacker.hooks.aws_lambda.get_config_directory' with self.temp_directory_with_files(['test/test.py']) as d, \ @@ -173,7 +166,6 @@ def test_path_relative(self): self.assertIsInstance(code, Code) self.assert_s3_zip_file_list(code.S3Bucket, code.S3Key, ['test.py']) - @mock_s3 def test_path_home_relative(self): test_path = '~/test' @@ -195,7 +187,6 @@ def test_path_home_relative(self): self.assertIsInstance(code, Code) self.assert_s3_zip_file_list(code.S3Bucket, code.S3Key, ['test.py']) - @mock_s3 def test_multiple_functions(self): with self.temp_directory_with_files() as d: results = self.run_hook(functions={ @@ -217,7 +208,6 @@ def test_multiple_functions(self): self.assertIsInstance(f2_code, Code) self.assert_s3_zip_file_list(f2_code.S3Bucket, f2_code.S3Key, F2_FILES) - @mock_s3 def test_patterns_invalid(self): msg = ("Invalid file patterns in key 'include': must be a string or " 'list of strings') @@ -230,7 +220,6 @@ def test_patterns_invalid(self): } }) - @mock_s3 def test_patterns_include(self): with self.temp_directory_with_files() as d: results = self.run_hook(functions={ @@ -252,7 +241,6 @@ def test_patterns_include(self): 'test2/test.txt' ]) - @mock_s3 def test_patterns_exclude(self): with self.temp_directory_with_files() as d: results = self.run_hook(functions={ @@ -272,7 +260,6 @@ def test_patterns_exclude(self): 'test2/test.txt' ]) - @mock_s3 def test_patterns_include_exclude(self): with self.temp_directory_with_files() as d: results = self.run_hook(functions={ @@ -292,7 +279,6 @@ def test_patterns_include_exclude(self): '__init__.py' ]) - @mock_s3 def test_patterns_exclude_all(self): msg = ('Empty list of files for Lambda payload. Check your ' 'include/exclude options for errors.') @@ -309,7 +295,6 @@ def test_patterns_exclude_all(self): self.assertIsNone(results) - @mock_s3 def test_idempotence(self): bucket_name = 'test' @@ -396,7 +381,6 @@ def test_select_bucket_region(self): for args, result in tests: self.assertEqual(select_bucket_region(*args), result) - @mock_s3 def test_follow_symlink_nonbool(self): msg = "follow_symlinks option must be a boolean" with ShouldRaise(ValueError(msg)): @@ -405,7 +389,6 @@ def test_follow_symlink_nonbool(self): } }) - @mock_s3 def test_follow_symlink_true(self): # Testing if symlinks are followed with self.temp_directory_with_files() as d1: @@ -439,7 +422,6 @@ def test_follow_symlink_true(self): 'f3/test2/test.txt' ]) - @mock_s3 def test_follow_symlink_false(self): # testing if syminks are present and not folllowed with self.temp_directory_with_files() as d1: @@ -466,7 +448,6 @@ def test_follow_symlink_false(self): 'f2/f2.js', ]) - @mock_s3 def test_follow_symlink_omitted(self): # same as test_follow_symlink_false, but default behaivor with self.temp_directory_with_files() as d1: diff --git a/stacker/tests/hooks/test_ecs.py b/stacker/tests/hooks/test_ecs.py index 12998590f..0b1980dcc 100644 --- a/stacker/tests/hooks/test_ecs.py +++ b/stacker/tests/hooks/test_ecs.py @@ -3,15 +3,12 @@ from __future__ import absolute_import import unittest -import boto3 from moto import mock_ecs from testfixtures import LogCapture from stacker.hooks.ecs import create_clusters -from ..factories import ( - mock_context, - mock_provider, -) +from ..factories import mock_boto3_client, mock_context, mock_provider + REGION = "us-east-1" @@ -22,14 +19,48 @@ def setUp(self): self.provider = mock_provider(region=REGION) self.context = mock_context(namespace="fake") + self.mock_ecs = mock_ecs() + self.mock_ecs.start() + self.ecs, self.ecs_mock = mock_boto3_client("ecs", region=REGION) + self.ecs_mock.start() + + def tearDown(self): + self.ecs_mock.stop() + self.mock_ecs.stop() + def test_create_single_cluster(self): - with mock_ecs(): - cluster = "test-cluster" - logger = "stacker.hooks.ecs" - client = boto3.client("ecs", region_name=REGION) - response = client.list_clusters() + cluster = "test-cluster" + logger = "stacker.hooks.ecs" + response = self.ecs.list_clusters() + + self.assertEqual(len(response["clusterArns"]), 0) + with LogCapture(logger) as logs: + self.assertTrue( + create_clusters( + provider=self.provider, + context=self.context, + clusters=cluster, + ) + ) - self.assertEqual(len(response["clusterArns"]), 0) + logs.check( + ( + logger, + "DEBUG", + "Creating ECS cluster: %s" % cluster + ) + ) + + response = self.ecs.list_clusters() + self.assertEqual(len(response["clusterArns"]), 1) + + def test_create_multiple_clusters(self): + clusters = ("test-cluster0", "test-cluster1") + logger = "stacker.hooks.ecs" + response = self.ecs.list_clusters() + + self.assertEqual(len(response["clusterArns"]), 0) + for cluster in clusters: with LogCapture(logger) as logs: self.assertTrue( create_clusters( @@ -47,58 +78,27 @@ def test_create_single_cluster(self): ) ) - response = client.list_clusters() - self.assertEqual(len(response["clusterArns"]), 1) - - def test_create_multiple_clusters(self): - with mock_ecs(): - clusters = ("test-cluster0", "test-cluster1") - logger = "stacker.hooks.ecs" - client = boto3.client("ecs", region_name=REGION) - response = client.list_clusters() - - self.assertEqual(len(response["clusterArns"]), 0) - for cluster in clusters: - with LogCapture(logger) as logs: - self.assertTrue( - create_clusters( - provider=self.provider, - context=self.context, - clusters=cluster, - ) - ) - - logs.check( - ( - logger, - "DEBUG", - "Creating ECS cluster: %s" % cluster - ) - ) - - response = client.list_clusters() - self.assertEqual(len(response["clusterArns"]), 2) + response = self.ecs.list_clusters() + self.assertEqual(len(response["clusterArns"]), 2) def test_fail_create_cluster(self): - with mock_ecs(): - logger = "stacker.hooks.ecs" - client = boto3.client("ecs", region_name=REGION) - response = client.list_clusters() - - self.assertEqual(len(response["clusterArns"]), 0) - with LogCapture(logger) as logs: - create_clusters( - provider=self.provider, - context=self.context - ) - - logs.check( - ( - logger, - "ERROR", - "setup_clusters hook missing \"clusters\" argument" - ) + logger = "stacker.hooks.ecs" + response = self.ecs.list_clusters() + + self.assertEqual(len(response["clusterArns"]), 0) + with LogCapture(logger) as logs: + create_clusters( + provider=self.provider, + context=self.context + ) + + logs.check( + ( + logger, + "ERROR", + "setup_clusters hook missing \"clusters\" argument" ) + ) - response = client.list_clusters() - self.assertEqual(len(response["clusterArns"]), 0) + response = self.ecs.list_clusters() + self.assertEqual(len(response["clusterArns"]), 0) diff --git a/stacker/tests/hooks/test_iam.py b/stacker/tests/hooks/test_iam.py index d194f4f06..55d799ec0 100644 --- a/stacker/tests/hooks/test_iam.py +++ b/stacker/tests/hooks/test_iam.py @@ -3,22 +3,15 @@ from __future__ import absolute_import import unittest -import boto3 +from awacs.helpers.trust import get_ecs_assumerole_policy from botocore.exceptions import ClientError - from moto import mock_iam from stacker.hooks.iam import ( create_ecs_service_role, _get_cert_arn_from_response, ) - -from awacs.helpers.trust import get_ecs_assumerole_policy - -from ..factories import ( - mock_context, - mock_provider, -) +from ..factories import mock_boto3_client, mock_context, mock_provider REGION = "us-east-1" @@ -34,6 +27,15 @@ def setUp(self): self.context = mock_context(namespace="fake") self.provider = mock_provider(region=REGION) + self.mock_iam = mock_iam() + self.mock_iam.start() + self.iam, self.client_mock = mock_boto3_client("iam", region=REGION) + self.client_mock.start() + + def tearDown(self): + self.client_mock.stop() + self.mock_iam.stop() + def test_get_cert_arn_from_response(self): arn = "fake-arn" # Creation response @@ -52,50 +54,48 @@ def test_get_cert_arn_from_response(self): def test_create_service_role(self): role_name = "ecsServiceRole" policy_name = "AmazonEC2ContainerServiceRolePolicy" - with mock_iam(): - client = boto3.client("iam", region_name=REGION) - with self.assertRaises(ClientError): - client.get_role(RoleName=role_name) + with self.assertRaises(ClientError): + self.iam.get_role(RoleName=role_name) - self.assertTrue( - create_ecs_service_role( - context=self.context, - provider=self.provider, - ) + self.assertTrue( + create_ecs_service_role( + context=self.context, + provider=self.provider, ) + ) - role = client.get_role(RoleName=role_name) + role = self.iam.get_role(RoleName=role_name) - self.assertIn("Role", role) - self.assertEqual(role_name, role["Role"]["RoleName"]) - client.get_role_policy( - RoleName=role_name, - PolicyName=policy_name - ) + self.assertIn("Role", role) + self.assertEqual(role_name, role["Role"]["RoleName"]) + + self.iam.get_role_policy( + RoleName=role_name, + PolicyName=policy_name + ) def test_create_service_role_already_exists(self): role_name = "ecsServiceRole" policy_name = "AmazonEC2ContainerServiceRolePolicy" - with mock_iam(): - client = boto3.client("iam", region_name=REGION) - client.create_role( - RoleName=role_name, - AssumeRolePolicyDocument=get_ecs_assumerole_policy().to_json() - ) - self.assertTrue( - create_ecs_service_role( - context=self.context, - provider=self.provider, - ) + self.iam.create_role( + RoleName=role_name, + AssumeRolePolicyDocument=get_ecs_assumerole_policy().to_json() + ) + + self.assertTrue( + create_ecs_service_role( + context=self.context, + provider=self.provider, ) + ) - role = client.get_role(RoleName=role_name) + role = self.iam.get_role(RoleName=role_name) - self.assertIn("Role", role) - self.assertEqual(role_name, role["Role"]["RoleName"]) - client.get_role_policy( - RoleName=role_name, - PolicyName=policy_name - ) + self.assertIn("Role", role) + self.assertEqual(role_name, role["Role"]["RoleName"]) + self.iam.get_role_policy( + RoleName=role_name, + PolicyName=policy_name + ) diff --git a/stacker/tests/lookups/handlers/test_ami.py b/stacker/tests/lookups/handlers/test_ami.py index 0e34b7b47..0b4f46423 100644 --- a/stacker/tests/lookups/handlers/test_ami.py +++ b/stacker/tests/lookups/handlers/test_ami.py @@ -1,194 +1,209 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import -import unittest -import mock -from botocore.stub import Stubber -from stacker.lookups.handlers.ami import AmiLookup, ImageNotFound -import boto3 -from stacker.tests.factories import SessionStub, mock_provider - -REGION = "us-east-1" +from botocore.stub import Stubber +import pytest -class TestAMILookup(unittest.TestCase): - client = boto3.client("ec2", region_name=REGION) - - def setUp(self): - self.stubber = Stubber(self.client) - self.provider = mock_provider(region=REGION) - - @mock.patch("stacker.lookups.handlers.ami.get_session", - return_value=SessionStub(client)) - def test_basic_lookup_single_image(self, mock_client): - image_id = "ami-fffccc111" - self.stubber.add_response( - "describe_images", - { - "Images": [ - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-13T01:17:44.000Z", - "State": "available", - "ImageId": image_id, - "Name": "Fake Image 1", - "VirtualizationType": "hvm", - } - ] - } - ) - - with self.stubber: - value = AmiLookup.handle( - value="owners:self name_regex:Fake\sImage\s\d", - provider=self.provider - ) - self.assertEqual(value, image_id) - - @mock.patch("stacker.lookups.handlers.ami.get_session", - return_value=SessionStub(client)) - def test_basic_lookup_with_region(self, mock_client): - image_id = "ami-fffccc111" - self.stubber.add_response( - "describe_images", - { - "Images": [ - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-13T01:17:44.000Z", - "State": "available", - "ImageId": image_id, - "Name": "Fake Image 1", - "VirtualizationType": "hvm", - } - ] - } - ) +from stacker.lookups.handlers.ami import AmiLookup, ImageNotFound +from ...factories import mock_boto3_client, mock_context, mock_provider - with self.stubber: - value = AmiLookup.handle( - value="us-west-1@owners:self name_regex:Fake\sImage\s\d", - provider=self.provider - ) - self.assertEqual(value, image_id) - - @mock.patch("stacker.lookups.handlers.ami.get_session", - return_value=SessionStub(client)) - def test_basic_lookup_multiple_images(self, mock_client): - image_id = "ami-fffccc111" - self.stubber.add_response( - "describe_images", - { - "Images": [ - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-13T01:17:44.000Z", - "State": "available", - "ImageId": "ami-fffccc110", - "Name": "Fake Image 1", - "VirtualizationType": "hvm", - }, - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-14T01:17:44.000Z", - "State": "available", - "ImageId": image_id, - "Name": "Fake Image 2", - "VirtualizationType": "hvm", - }, - ] - } - ) - with self.stubber: - value = AmiLookup.handle( - value="owners:self name_regex:Fake\sImage\s\d", - provider=self.provider - ) - self.assertEqual(value, image_id) - - @mock.patch("stacker.lookups.handlers.ami.get_session", - return_value=SessionStub(client)) - def test_basic_lookup_multiple_images_name_match(self, mock_client): - image_id = "ami-fffccc111" - self.stubber.add_response( - "describe_images", - { - "Images": [ - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-13T01:17:44.000Z", - "State": "available", - "ImageId": "ami-fffccc110", - "Name": "Fa---ke Image 1", - "VirtualizationType": "hvm", - }, - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-14T01:17:44.000Z", - "State": "available", - "ImageId": image_id, - "Name": "Fake Image 2", - "VirtualizationType": "hvm", - }, - ] - } +REGION = "us-east-1" +ALT_REGION = "us-east-2" + + +@pytest.fixture +def context(): + return mock_context() + + +@pytest.fixture(params=[dict(region=REGION)]) +def provider(request): + return mock_provider(**request.param) + + +@pytest.fixture(params=[dict(region=REGION)]) +def ec2(request): + client, mock = mock_boto3_client("ec2", **request.param) + with mock: + yield client + + +@pytest.fixture +def ec2_stubber(ec2): + with Stubber(ec2) as stubber: + yield stubber + + +def test_basic_lookup_single_image(ec2_stubber, context, provider): + image_id = "ami-fffccc111" + ec2_stubber.add_response( + "describe_images", + { + "Images": [ + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-13T01:17:44.000Z", + "State": "available", + "ImageId": image_id, + "Name": "Fake Image 1", + "VirtualizationType": "hvm", + } + ] + } + ) + + value = AmiLookup.handle( + value=r"owners:self name_regex:Fake\sImage\s\d", + context=context, + provider=provider + ) + assert value == image_id + + +@pytest.mark.parametrize("ec2", [dict(region=ALT_REGION)], indirect=True) +def test_basic_lookup_with_region(ec2_stubber, context, provider): + image_id = "ami-fffccc111" + ec2_stubber.add_response( + "describe_images", + { + "Images": [ + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-13T01:17:44.000Z", + "State": "available", + "ImageId": image_id, + "Name": "Fake Image 1", + "VirtualizationType": "hvm", + } + ] + } + ) + + key = r"{}@owners:self name_regex:Fake\sImage\s\d".format(ALT_REGION) + value = AmiLookup.handle( + value=key, + context=context, + provider=provider + ) + assert value == image_id + + +def test_basic_lookup_multiple_images(ec2_stubber, context, provider): + image_id = "ami-fffccc111" + ec2_stubber.add_response( + "describe_images", + { + "Images": [ + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-13T01:17:44.000Z", + "State": "available", + "ImageId": "ami-fffccc110", + "Name": "Fake Image 1", + "VirtualizationType": "hvm", + }, + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-14T01:17:44.000Z", + "State": "available", + "ImageId": image_id, + "Name": "Fake Image 2", + "VirtualizationType": "hvm", + }, + ] + } + ) + + value = AmiLookup.handle( + value=r"owners:self name_regex:Fake\sImage\s\d", + context=context, + provider=provider + ) + assert value == image_id + + +def test_basic_lookup_multiple_images_name_match(ec2_stubber, context, + provider): + image_id = "ami-fffccc111" + ec2_stubber.add_response( + "describe_images", + { + "Images": [ + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-13T01:17:44.000Z", + "State": "available", + "ImageId": "ami-fffccc110", + "Name": "Fa---ke Image 1", + "VirtualizationType": "hvm", + }, + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-14T01:17:44.000Z", + "State": "available", + "ImageId": image_id, + "Name": "Fake Image 2", + "VirtualizationType": "hvm", + }, + ] + } + ) + + value = AmiLookup.handle( + value=r"owners:self name_regex:Fake\sImage\s\d", + context=context, + provider=provider + ) + assert value == image_id + + +def test_basic_lookup_no_matching_images(ec2_stubber, context, provider): + ec2_stubber.add_response( + "describe_images", + { + "Images": [] + } + ) + + with pytest.raises(ImageNotFound): + AmiLookup.handle( + value=r"owners:self name_regex:Fake\sImage\s\d", + context=context, + provider=provider ) - with self.stubber: - value = AmiLookup.handle( - value="owners:self name_regex:Fake\sImage\s\d", - provider=self.provider - ) - self.assertEqual(value, image_id) - - @mock.patch("stacker.lookups.handlers.ami.get_session", - return_value=SessionStub(client)) - def test_basic_lookup_no_matching_images(self, mock_client): - self.stubber.add_response( - "describe_images", - { - "Images": [] - } - ) - with self.stubber: - with self.assertRaises(ImageNotFound): - AmiLookup.handle( - value="owners:self name_regex:Fake\sImage\s\d", - provider=self.provider - ) - - @mock.patch("stacker.lookups.handlers.ami.get_session", - return_value=SessionStub(client)) - def test_basic_lookup_no_matching_images_from_name(self, mock_client): - image_id = "ami-fffccc111" - self.stubber.add_response( - "describe_images", - { - "Images": [ - { - "OwnerId": "897883143566", - "Architecture": "x86_64", - "CreationDate": "2011-02-13T01:17:44.000Z", - "State": "available", - "ImageId": image_id, - "Name": "Fake Image 1", - "VirtualizationType": "hvm", - } - ] - } +def test_basic_lookup_no_matching_images_from_name(ec2_stubber, context, + provider): + image_id = "ami-fffccc111" + ec2_stubber.add_response( + "describe_images", + { + "Images": [ + { + "OwnerId": "897883143566", + "Architecture": "x86_64", + "CreationDate": "2011-02-13T01:17:44.000Z", + "State": "available", + "ImageId": image_id, + "Name": "Fake Image 1", + "VirtualizationType": "hvm", + } + ] + } + ) + + with pytest.raises(ImageNotFound): + AmiLookup.handle( + value=r"owners:self name_regex:MyImage\s\d", + context=context, + provider=provider ) - - with self.stubber: - with self.assertRaises(ImageNotFound): - AmiLookup.handle( - value="owners:self name_regex:MyImage\s\d", - provider=self.provider - ) diff --git a/stacker/tests/lookups/handlers/test_dynamodb.py b/stacker/tests/lookups/handlers/test_dynamodb.py index 44b6cc693..a23dcf3cf 100644 --- a/stacker/tests/lookups/handlers/test_dynamodb.py +++ b/stacker/tests/lookups/handlers/test_dynamodb.py @@ -2,18 +2,26 @@ from __future__ import division from __future__ import absolute_import import unittest -import mock + from botocore.stub import Stubber + from stacker.lookups.handlers.dynamodb import DynamodbLookup -import boto3 -from stacker.tests.factories import SessionStub +from ...factories import mock_context, mock_provider, mock_boto3_client +REGION = 'us-east-1' -class TestDynamoDBHandler(unittest.TestCase): - client = boto3.client('dynamodb', region_name='us-east-1') +class TestDynamoDBHandler(unittest.TestCase): def setUp(self): - self.stubber = Stubber(self.client) + self.context = mock_context() + self.provider = mock_provider(region=REGION) + + self.dynamodb, self.client_mock = \ + mock_boto3_client("dynamodb", region=REGION) + self.client_mock.start() + self.stubber = Stubber(self.dynamodb) + self.stubber.activate() + self.get_parameters_response = {'Item': {'TestMap': {'M': { 'String1': {'S': 'StringVal1'}, 'List1': {'L': [ @@ -21,9 +29,11 @@ def setUp(self): {'S': 'ListVal2'}]}, 'Number1': {'N': '12345'}, }}}} - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_handler(self, mock_client): + def tearDown(self): + self.client_mock.stop() + self.stubber.deactivate() + + def test_dynamodb_handler(self): expected_params = { 'TableName': 'TestTable', 'Key': { @@ -36,13 +46,11 @@ def test_dynamodb_handler(self, mock_client): self.stubber.add_response('get_item', self.get_parameters_response, expected_params) - with self.stubber: - value = DynamodbLookup.handle(base_lookup_key) - self.assertEqual(value, base_lookup_key_valid) + value = DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEqual(value, base_lookup_key_valid) - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_number_handler(self, mock_client): + def test_dynamodb_number_handler(self): expected_params = { 'TableName': 'TestTable', 'Key': { @@ -56,13 +64,12 @@ def test_dynamodb_number_handler(self, mock_client): self.stubber.add_response('get_item', self.get_parameters_response, expected_params) - with self.stubber: - value = DynamodbLookup.handle(base_lookup_key) - self.assertEqual(value, base_lookup_key_valid) - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_list_handler(self, mock_client): + value = DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEqual(value, base_lookup_key_valid) + + def test_dynamodb_list_handler(self): expected_params = { 'TableName': 'TestTable', 'Key': { @@ -76,13 +83,12 @@ def test_dynamodb_list_handler(self, mock_client): self.stubber.add_response('get_item', self.get_parameters_response, expected_params) - with self.stubber: - value = DynamodbLookup.handle(base_lookup_key) - self.assertEqual(value, base_lookup_key_valid) - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_empty_table_handler(self, mock_client): + value = DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEqual(value, base_lookup_key_valid) + + def test_dynamodb_empty_table_handler(self): expected_params = { 'TableName': '', 'Key': { @@ -94,17 +100,14 @@ def test_dynamodb_empty_table_handler(self, mock_client): self.stubber.add_response('get_item', self.get_parameters_response, expected_params) - with self.stubber: - try: - DynamodbLookup.handle(base_lookup_key) - except ValueError as e: - self.assertEqual( - 'Please make sure to include a dynamodb table name', - str(e)) - - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_missing_table_handler(self, mock_client): + + msg = 'Please make sure to include a dynamodb table name' + with self.assertRaises(ValueError) as raised: + DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEquals(raised.exception.message, msg) + + def test_dynamodb_missing_table_handler(self): expected_params = { 'Key': { 'TestKey': {'S': 'TestVal'} @@ -115,17 +118,14 @@ def test_dynamodb_missing_table_handler(self, mock_client): self.stubber.add_response('get_item', self.get_parameters_response, expected_params) - with self.stubber: - try: - DynamodbLookup.handle(base_lookup_key) - except ValueError as e: - self.assertEqual( - 'Please make sure to include a tablename', - str(e)) - - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_invalid_table_handler(self, mock_client): + + msg = 'Please make sure to include a tablename' + with self.assertRaises(ValueError) as raised: + DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEquals(raised.exception.message, msg) + + def test_dynamodb_invalid_table_handler(self): expected_params = { 'TableName': 'FakeTable', 'Key': { @@ -138,17 +138,14 @@ def test_dynamodb_invalid_table_handler(self, mock_client): self.stubber.add_client_error('get_item', service_error_code=service_error_code, expected_params=expected_params) - with self.stubber: - try: - DynamodbLookup.handle(base_lookup_key) - except ValueError as e: - self.assertEqual( - 'Cannot find the dynamodb table: FakeTable', - str(e)) - - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_invalid_partition_key_handler(self, mock_client): + + msg = 'Cannot find the dynamodb table: FakeTable' + with self.assertRaises(ValueError) as raised: + DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEquals(raised.exception.message, msg) + + def test_dynamodb_invalid_partition_key_handler(self): expected_params = { 'TableName': 'TestTable', 'Key': { @@ -162,17 +159,13 @@ def test_dynamodb_invalid_partition_key_handler(self, mock_client): service_error_code=service_error_code, expected_params=expected_params) - with self.stubber: - try: - DynamodbLookup.handle(base_lookup_key) - except ValueError as e: - self.assertEqual( - 'No dynamodb record matched the partition key: FakeKey', - str(e)) - - @mock.patch('stacker.lookups.handlers.dynamodb.get_session', - return_value=SessionStub(client)) - def test_dynamodb_invalid_partition_val_handler(self, mock_client): + msg = 'No dynamodb record matched the partition key: FakeKey' + with self.assertRaises(ValueError) as raised: + DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEquals(raised.exception.message, msg) + + def test_dynamodb_invalid_partition_val_handler(self): expected_params = { 'TableName': 'TestTable', 'Key': { @@ -185,11 +178,10 @@ def test_dynamodb_invalid_partition_val_handler(self, mock_client): self.stubber.add_response('get_item', empty_response, expected_params) - with self.stubber: - try: - DynamodbLookup.handle(base_lookup_key) - except ValueError as e: - self.assertEqual( - 'The dynamodb record could not be found using ' - 'the following key: {\'S\': \'FakeVal\'}', - str(e)) + + msg = ('The dynamodb record could not be found using the following ' + 'key: {\'S\': \'FakeVal\'}') + with self.assertRaises(ValueError) as raised: + DynamodbLookup.handle( + base_lookup_key, self.context, self.provider) + self.assertEquals(raised.exception.message, msg) diff --git a/stacker/tests/lookups/handlers/test_kms.py b/stacker/tests/lookups/handlers/test_kms.py index bb199a639..254a082de 100644 --- a/stacker/tests/lookups/handlers/test_kms.py +++ b/stacker/tests/lookups/handlers/test_kms.py @@ -6,31 +6,39 @@ from moto import mock_kms -import boto3 - from stacker.lookups.handlers.kms import KmsLookup +from ...factories import mock_boto3_client, mock_context, mock_provider + +REGION = 'us-east-1' class TestKMSHandler(unittest.TestCase): def setUp(self): + self.context = mock_context() + self.provider = mock_provider(region=REGION) + + self.mock_kms = mock_kms() + self.mock_kms.start() + self.kms, self.client_mock = mock_boto3_client('kms', region=REGION) + self.client_mock.start() + self.plain = b"my secret" - with mock_kms(): - kms = boto3.client("kms", region_name="us-east-1") - self.secret = kms.encrypt( - KeyId="alias/stacker", - Plaintext=codecs.encode(self.plain, 'base64').decode('utf-8'), - )["CiphertextBlob"] - if isinstance(self.secret, bytes): - self.secret = self.secret.decode() + self.secret = self.kms.encrypt( + KeyId="alias/stacker", + Plaintext=codecs.encode(self.plain, 'base64').decode('utf-8'), + )["CiphertextBlob"] + if isinstance(self.secret, bytes): + self.secret = self.secret.decode() + + def tearDown(self): + self.client_mock.stop() + self.mock_kms.stop() def test_kms_handler(self): - with mock_kms(): - decrypted = KmsLookup.handle(self.secret) - self.assertEqual(decrypted, self.plain) + decrypted = KmsLookup.handle(self.secret, self.context, self.provider) + self.assertEqual(decrypted, self.plain) def test_kms_handler_with_region(self): - region = "us-east-1" - value = "%s@%s" % (region, self.secret) - with mock_kms(): - decrypted = KmsLookup.handle(value) - self.assertEqual(decrypted, self.plain) + value = "%s@%s" % (REGION, self.secret) + decrypted = KmsLookup.handle(value, self.context, self.provider) + self.assertEqual(decrypted, self.plain) diff --git a/stacker/tests/lookups/handlers/test_ssmstore.py b/stacker/tests/lookups/handlers/test_ssmstore.py index daff2444d..d0d89a81a 100644 --- a/stacker/tests/lookups/handlers/test_ssmstore.py +++ b/stacker/tests/lookups/handlers/test_ssmstore.py @@ -2,74 +2,93 @@ from __future__ import division from __future__ import absolute_import from builtins import str -import unittest -import mock + +import pytest from botocore.stub import Stubber + from stacker.lookups.handlers.ssmstore import SsmstoreLookup -import boto3 -from stacker.tests.factories import SessionStub - - -class TestSSMStoreHandler(unittest.TestCase): - client = boto3.client('ssm', region_name='us-east-1') - - def setUp(self): - self.stubber = Stubber(self.client) - self.get_parameters_response = { - 'Parameters': [ - { - 'Name': 'ssmkey', - 'Type': 'String', - 'Value': 'ssmvalue' - } - ], - 'InvalidParameters': [ - 'invalidssmparam' - ] - } - self.invalid_get_parameters_response = { - 'InvalidParameters': [ - 'ssmkey' - ] - } - self.expected_params = { - 'Names': ['ssmkey'], - 'WithDecryption': True +from ...factories import mock_context, mock_provider, mock_boto3_client + +REGION = 'us-east-1' +ALT_REGION = 'us-east-2' + + +@pytest.fixture +def context(): + return mock_context() + + +@pytest.fixture(params=[dict(region=REGION)]) +def provider(request): + return mock_provider(**request.param) + + +@pytest.fixture(params=[dict(region=REGION)]) +def ssm(request): + client, mock = mock_boto3_client("ssm", **request.param) + with mock: + yield client + + +@pytest.fixture +def ssm_stubber(ssm): + with Stubber(ssm) as stubber: + yield stubber + + +get_parameters_response = { + 'Parameters': [ + { + 'Name': 'ssmkey', + 'Type': 'String', + 'Value': 'ssmvalue' } - self.ssmkey = "ssmkey" - self.ssmvalue = "ssmvalue" - - @mock.patch('stacker.lookups.handlers.ssmstore.get_session', - return_value=SessionStub(client)) - def test_ssmstore_handler(self, mock_client): - self.stubber.add_response('get_parameters', - self.get_parameters_response, - self.expected_params) - with self.stubber: - value = SsmstoreLookup.handle(self.ssmkey) - self.assertEqual(value, self.ssmvalue) - self.assertIsInstance(value, str) - - @mock.patch('stacker.lookups.handlers.ssmstore.get_session', - return_value=SessionStub(client)) - def test_ssmstore_invalid_value_handler(self, mock_client): - self.stubber.add_response('get_parameters', - self.invalid_get_parameters_response, - self.expected_params) - with self.stubber: - try: - SsmstoreLookup.handle(self.ssmkey) - except ValueError: - assert True - - @mock.patch('stacker.lookups.handlers.ssmstore.get_session', - return_value=SessionStub(client)) - def test_ssmstore_handler_with_region(self, mock_client): - self.stubber.add_response('get_parameters', - self.get_parameters_response, - self.expected_params) - region = "us-east-1" - temp_value = "%s@%s" % (region, self.ssmkey) - with self.stubber: - value = SsmstoreLookup.handle(temp_value) - self.assertEqual(value, self.ssmvalue) + ], + 'InvalidParameters': [ + 'invalidssmparam' + ] +} + +invalid_get_parameters_response = { + 'InvalidParameters': [ + 'ssmkey' + ] +} + +expected_params = { + 'Names': ['ssmkey'], + 'WithDecryption': True +} + +ssmkey = "ssmkey" +ssmvalue = "ssmvalue" + + +def test_ssmstore_handler(ssm_stubber, context, provider): + ssm_stubber.add_response('get_parameters', + get_parameters_response, + expected_params) + + value = SsmstoreLookup.handle(ssmkey, context, provider) + assert value == ssmvalue + assert isinstance(value, str) + + +def test_ssmstore_invalid_value_handler(ssm_stubber, context, provider): + ssm_stubber.add_response('get_parameters', + invalid_get_parameters_response, + expected_params) + + with pytest.raises(ValueError): + SsmstoreLookup.handle(ssmkey, context, provider) + + +@pytest.mark.parametrize("ssm", [dict(region=ALT_REGION)], indirect=True) +def test_ssmstore_handler_with_region(ssm_stubber, context, provider): + ssm_stubber.add_response('get_parameters', + get_parameters_response, + expected_params) + temp_value = '%s@%s' % (ALT_REGION, ssmkey) + + value = SsmstoreLookup.handle(temp_value, context, provider) + assert value == ssmvalue diff --git a/stacker/tests/providers/aws/test_default.py b/stacker/tests/providers/aws/test_default.py index 10dc5577c..e9a729722 100644 --- a/stacker/tests/providers/aws/test_default.py +++ b/stacker/tests/providers/aws/test_default.py @@ -382,7 +382,7 @@ def setUp(self): self.session = get_session(region=region) self.provider = Provider( self.session, region=region, recreate_failed=False) - self.stubber = Stubber(self.provider.cloudformation) + self.stubber = Stubber(self.provider._cloudformation) def test_get_stack_stack_does_not_exist(self): stack_name = "MockStack" @@ -657,7 +657,7 @@ def setUp(self): self.session = get_session(region=region) self.provider = Provider( self.session, interactive=True, recreate_failed=True) - self.stubber = Stubber(self.provider.cloudformation) + self.stubber = Stubber(self.provider._cloudformation) def test_successful_init(self): replacements = True From e11a905590d26c7057dcef6def2605bf8f559123 Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Mon, 18 Mar 2019 22:10:48 -0300 Subject: [PATCH 12/13] tests: add functional tests for integrated hooks --- setup.py | 2 +- stacker/tests/factories.py | 1 - tests/fixtures/blueprints/bucket.yaml.j2 | 10 ++ tests/test_helper.bash | 18 ++- .../34_stacker_build-integrated-hooks.bats | 130 ++++++++++++++++++ 5 files changed, 158 insertions(+), 3 deletions(-) create mode 100644 tests/fixtures/blueprints/bucket.yaml.j2 create mode 100644 tests/test_suite/34_stacker_build-integrated-hooks.bats diff --git a/setup.py b/setup.py index 8aae81862..4e97e09d2 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ "mock~=2.0", "moto~=1.3.7", "testfixtures~=4.10.0", - "flake8-future-import", + "flake8-future-import" ] scripts = [ diff --git a/stacker/tests/factories.py b/stacker/tests/factories.py index ca5c64337..e46e2cf37 100644 --- a/stacker/tests/factories.py +++ b/stacker/tests/factories.py @@ -39,7 +39,6 @@ def __init__(self, outputs=None, region=None, profile=None): "Outputs": stack_outputs, "StackStatus": "CREATED" } - self._sessions = {} def get_stack(self, stack_name, **kwargs): try: diff --git a/tests/fixtures/blueprints/bucket.yaml.j2 b/tests/fixtures/blueprints/bucket.yaml.j2 new file mode 100644 index 000000000..2687f4473 --- /dev/null +++ b/tests/fixtures/blueprints/bucket.yaml.j2 @@ -0,0 +1,10 @@ +AWSTemplateFormatVersion: 2010-09-09 +Resources: + Bucket: + Type: AWS::S3::Bucket + Properties: + BucketName: {{ variables.BucketName }} + AccessControl: Private +Outputs: + BucketName: + Value: !Ref Bucket diff --git a/tests/test_helper.bash b/tests/test_helper.bash index 1d0d52194..1392df715 100644 --- a/tests/test_helper.bash +++ b/tests/test_helper.bash @@ -36,7 +36,23 @@ assert() { # Checks that the given line is in $output. assert_has_line() { - echo "$output" | grep "$@" 1>/dev/null + echo "$output" | grep -q "$@" +} + +assert_has_lines_in_order() { + local search_line + read -r search_line || return $? + + for line in "${lines[@]}"; do + if grep -q "$@" "$search_line" <<< "$line"; then + if ! read -r search_line && [ -z "$search_line" ]; then + return 0 + fi + fi + done + + echo "Error: did not match line in correct order: '$search_line'" >&2 + return 1 } # This helper wraps "stacker" with bats' "run" and also outputs debug diff --git a/tests/test_suite/34_stacker_build-integrated-hooks.bats b/tests/test_suite/34_stacker_build-integrated-hooks.bats new file mode 100644 index 000000000..b1da384f5 --- /dev/null +++ b/tests/test_suite/34_stacker_build-integrated-hooks.bats @@ -0,0 +1,130 @@ +#!/usr/bin/env bats + +# This test will exercise the integration of hooks among the execution of stacks +# making use of the fact that S3 buckets cannot be deleted when not empty. +# The test will create the bucket and populate it during build, and erase the +# objects before destruction. If the hooks are not executed in the proper order, +# the destruction will fail, and so will the tst. + +load ../test_helper + +@test "stacker build - integrated hooks" { + needs_aws + + config() { + echo "namespace: ${STACKER_NAMESPACE}-integrated-hooks" + cat <<'EOF' +stacks: + - name: bucket + profile: stacker + template_path: fixtures/blueprints/bucket.yaml.j2 + variables: + BucketName: "stacker-${envvar STACKER_NAMESPACE}-integrated-hooks-${awsparam AccountId}" + +build_hooks: + - name: write-hello + path: stacker.hooks.command.run_command + args: + command: 'echo "Hello from Stacker!" > /tmp/hello.txt' + shell: true + + - name: send-hello + path: stacker.hooks.command.run_command + requires: + - write-hello + args: + command: 'aws s3 cp /tmp/hello.txt "s3://$BUCKET/hello.txt"' + shell: true + env: + BUCKET: "${output bucket::BucketName}" + AWS_PROFILE: stacker + + - name: send-world + path: stacker.hooks.command.run_command + requires: + - send-hello + args: + command: 'aws s3 cp "s3://$BUCKET/hello.txt" "s3://$BUCKET/world.txt"' + shell: true + env: + BUCKET: "${output bucket::BucketName}" + AWS_PROFILE: stacker + +destroy_hooks: + - name: remove-world + path: stacker.hooks.command.run_command + args: + command: 'aws s3 rm "s3://$BUCKET/world.txt"' + shell: true + env: + BUCKET: "${output bucket::BucketName}" + AWS_PROFILE: stacker + + - name: remove-hello + path: stacker.hooks.command.run_command + required_by: + - remove-world + args: + command: 'aws s3 rm "s3://$BUCKET/hello.txt"' + shell: true + env: + BUCKET: "${output bucket::BucketName}" + AWS_PROFILE: stacker + + - name: clean-hello + path: stacker.hooks.command.run_command + required_by: + - bucket + args: + command: [rm, -f, /tmp/hello.txt] +EOF + } + + teardown() { + stacker destroy --force <(config) + } + + stacker build -t --recreate-failed <(config) + assert "$status" -eq 0 + assert_has_line "Using default AWS provider mode" + assert_has_lines_in_order -E <<'EOF' +pre_build_hooks: complete +write-hello: complete +bucket: submitted \(creating new stack\) +bucket: complete \(creating new stack\) +upload: [^ ]*/hello.txt to s3://[^ ]*/hello.txt +send-hello: complete +copy: s3://[^ ]*/hello.txt to s3://[^ ]*/world.txt +send-world: complete +post_build_hooks: complete +EOF + + stacker destroy --force <(config) + assert "$status" -eq 0 + assert_has_line "Using default AWS provider mode" + assert_has_lines_in_order -E <<'EOF' +pre_destroy_hooks: complete +delete: s3://[^ ]*/world.txt +remove-world: complete +delete: s3://[^ ]*/hello.txt +remove-hello: complete +bucket: submitted \(submitted for destruction\) +bucket: complete \(stack destroyed\) +clean-hello: complete +post_destroy_hooks: complete +EOF + assert ! -e /tmp/hello.txt + + # Check that hooks that use lookups from stacks that do not exist anymore are + # not run + stacker destroy --force <(config) + assert "$status" -eq 0 + assert_has_lines_in_order <<'EOF' +pre_destroy_hooks: complete +remove-world: skipped +remove-hello: skipped +bucket: skipped +clean-hello: complete +post_destroy_hooks: complete +EOF +} From 39a2f82f8408c608adf27fd8ce650fa2e11efe49 Mon Sep 17 00:00:00 2001 From: Daniel Miranda Date: Tue, 19 Mar 2019 17:02:31 -0300 Subject: [PATCH 13/13] hooks: treat lookup dependencies specially to handle destroy When running the destroy action, we reverse the order of execution compared to build, such that steps that required others are now required by them instead. Unfortunately that is not enough to express the execution requirements of hooks: they might be `required_by` a stack to run some cleanup before destroying that stack, but still need it to be deployed to lookup outputs from. Handle that by manually checking for the dependencies of lookups before executing hooks, skipping them if a required stack is not deployed (or has already been destroyed). --- stacker/actions/base.py | 19 ++--- stacker/exceptions.py | 9 ++- stacker/hooks/__init__.py | 98 ++++++++++++++++++----- stacker/tests/test_context.py | 40 +++++---- stacker/tests/test_hooks.py | 147 ++++++++++++++++++++++++++++++++++ stacker/tests/test_util.py | 114 -------------------------- 6 files changed, 257 insertions(+), 170 deletions(-) create mode 100644 stacker/tests/test_hooks.py diff --git a/stacker/actions/base.py b/stacker/actions/base.py index 1ce8ab516..65a09afd5 100644 --- a/stacker/actions/base.py +++ b/stacker/actions/base.py @@ -13,10 +13,11 @@ import botocore.exceptions from stacker.session_cache import get_session -from stacker.exceptions import HookExecutionFailed, PlanFailed -from stacker.status import COMPLETE, SKIPPED, FailedStatus +from stacker.exceptions import PlanFailed +from stacker.status import COMPLETE from stacker.util import ensure_s3_bucket, get_s3_endpoint + logger = logging.getLogger(__name__) # After submitting a stack update/create, this controls how long we'll wait @@ -142,18 +143,8 @@ def target_fn(*args, **kwargs): return COMPLETE def hook_fn(hook, *args, **kwargs): - provider = self.provider_builder.build(profile=hook.profile, - region=hook.region) - - try: - result = hook.run(provider, self.context) - except HookExecutionFailed as e: - return FailedStatus(reason=str(e)) - - if result is None: - return SKIPPED - - return COMPLETE + return hook.run_step(provider_builder=self.provider_builder, + context=self.context) pre_hooks_target = Target( name="pre_{}_hooks".format(action_name)) diff --git a/stacker/exceptions.py b/stacker/exceptions.py index e2f463336..9b6cbd50e 100644 --- a/stacker/exceptions.py +++ b/stacker/exceptions.py @@ -126,6 +126,7 @@ def __init__(self, stack_name, *args, **kwargs): message = ("Stack: \"%s\" does not exist in outputs or the lookup is " "not available in this stacker run") % (stack_name,) super(StackDoesNotExist, self).__init__(message, *args, **kwargs) + self.stack_name = stack_name class MissingParameterException(Exception): @@ -278,14 +279,14 @@ def __init__(self, exception, stack, dependency): class HookExecutionFailed(Exception): """Raised when running a required hook fails""" - def __init__(self, hook, result=None, exception=None): + def __init__(self, hook, result=None, cause=None): self.hook = hook self.result = result - self.exception = exception + self.cause = cause - if self.exception: + if self.cause: message = ("Hook '{}' threw exception: {}".format( - hook.name, exception)) + hook.name, cause)) else: message = ("Hook '{}' failed (result: {})".format( hook.name, result)) diff --git a/stacker/hooks/__init__.py b/stacker/hooks/__init__.py index 987b2b167..7869e0489 100644 --- a/stacker/hooks/__init__.py +++ b/stacker/hooks/__init__.py @@ -5,8 +5,11 @@ import logging from collections import Mapping, namedtuple -from stacker.exceptions import HookExecutionFailed +from stacker.exceptions import HookExecutionFailed, StackDoesNotExist from stacker.util import load_object_from_string +from stacker.status import ( + COMPLETE, SKIPPED, FailedStatus, NotSubmittedStatus, SkippedStatus +) from stacker.variables import Variable logger = logging.getLogger(__name__) @@ -54,11 +57,55 @@ def __init__(self, name, path, required=True, enabled=True, self.region = region self._args = {} + self._args, deps = self.parse_args(args) + self.requires.update(deps) + + self._callable = self.resolve_path() + + def parse_args(self, args): + arg_vars = {} + deps = set() + if args: for key, value in args.items(): - var = self._args[key] = \ + var = arg_vars[key] = \ Variable('{}.args.{}'.format(self.name, key), value) - self.requires.update(var.dependencies()) + deps.update(var.dependencies()) + + return arg_vars, deps + + def resolve_path(self): + try: + return load_object_from_string(self.path) + except (AttributeError, ImportError) as e: + raise ValueError("Unable to load method at %s for hook %s: %s", + self.path, self.name, str(e)) + + def check_args_dependencies(self, provider, context): + # When running hooks for destruction, we might rely on outputs of + # stacks that we assume have been deployed. Unfortunately, since + # destruction must happen in the reverse order of creation, those stack + # dependencies will not be present on `requires`, but in `required_by`, + # meaning the execution engine won't stop the hook from running early. + + # To deal with that, manually find the dependencies coming from + # lookups in the hook arguments, select those that represent stacks, + # and check if they are actually available. + + dependencies = set() + for value in self._args.values(): + dependencies.update(value.dependencies()) + + for dep in dependencies: + # We assume all dependency names are valid here. Hence, if we can't + # find a stack with that same name, it must be a target or a hook, + # and hence we don't need to check it + stack = context.get_stack(dep) + if stack is None: + continue + + # This will raise if the stack is missing + provider.get_stack(stack.fqn) def resolve_args(self, provider, context): for key, value in self._args.items(): @@ -85,29 +132,15 @@ def run(self, provider, context): """ logger.info("Executing hook %s", self) - - if not self.enabled: - logger.debug("Hook %s is disabled, skipping", self.name) - return - - try: - method = load_object_from_string(self.path) - except (AttributeError, ImportError) as e: - logger.exception("Unable to load method at %s for hook %s:", - self.path, self.name) - if self.required: - raise HookExecutionFailed(self, exception=e) - - return - kwargs = dict(self.resolve_args(provider, context)) try: - result = method(context=context, provider=provider, **kwargs) + result = self._callable(context=context, provider=provider, + **kwargs) except Exception as e: if self.required: - raise HookExecutionFailed(self, exception=e) + raise HookExecutionFailed(self, cause=e) - return + return None if not result: if self.required: @@ -125,6 +158,29 @@ def run(self, provider, context): return result + def run_step(self, provider_builder, context): + if not self.enabled: + return NotSubmittedStatus() + + provider = provider_builder.build(profile=self.profile, + region=self.region) + + try: + self.check_args_dependencies(provider, context) + except StackDoesNotExist as e: + reason = "required stack not deployed: {}".format(e.stack_name) + return SkippedStatus(reason=reason) + + try: + result = self.run(provider, context) + except HookExecutionFailed as e: + return FailedStatus(reason=str(e)) + + if not result: + return SKIPPED + + return COMPLETE + def __str__(self): return 'Hook(name={}, path={}, profile={}, region={})'.format( self.name, self.path, self.profile, self.region) diff --git a/stacker/tests/test_context.py b/stacker/tests/test_context.py index 629509710..689e1a36e 100644 --- a/stacker/tests/test_context.py +++ b/stacker/tests/test_context.py @@ -9,6 +9,9 @@ from stacker.config import load, Config +FAKE_HOOK_PATH = "stacker.tests.fixtures.mock_hooks.mock_hook" + + class TestContext(unittest.TestCase): def setUp(self): @@ -119,7 +122,7 @@ def test_hook_with_sys_path(self): "pre_build": [ { "data_key": "myHook", - "path": "fixtures.mock_hooks.mock_hook", + "path": FAKE_HOOK_PATH.replace('stacker.tests.', ''), "required": True, "args": { "value": "mockResult"}}]}) @@ -134,43 +137,46 @@ def test_hook_with_sys_path(self): self.assertEqual("mockResult", context.hook_data["myHook"]["result"]) def test_get_hooks_for_action(self): + config = Config({ "pre_build": [ - {"path": "fake.hook"}, - {"name": "pre_build_test", "path": "fake.hook"}, - {"path": "fake.hook"} + {"path": FAKE_HOOK_PATH}, + {"name": "pre_build_test", "path": FAKE_HOOK_PATH}, + {"path": FAKE_HOOK_PATH} ], "post_build": [ - {"path": "fake.hook"}, - {"name": "post_build_test", "path": "fake.hook"}, - {"path": "fake.hook"} + {"path": FAKE_HOOK_PATH}, + {"name": "post_build_test", "path": FAKE_HOOK_PATH}, + {"path": FAKE_HOOK_PATH} ], "build_hooks": [ - {"path": "fake.hook"}, - {"name": "build_test", "path": "fake.hook"}, - {"path": "fake.hook"} + {"path": FAKE_HOOK_PATH}, + {"name": "build_test", "path": FAKE_HOOK_PATH}, + {"path": FAKE_HOOK_PATH} ] }) context = Context(config=config) hooks = context.get_hooks_for_action('build') - assert hooks.pre[0].name == "pre_build_1_fake.hook" + assert hooks.pre[0].name == "pre_build_1_{}".format(FAKE_HOOK_PATH) assert hooks.pre[1].name == "pre_build_test" - assert hooks.pre[2].name == "pre_build_3_fake.hook" + assert hooks.pre[2].name == "pre_build_3_{}".format(FAKE_HOOK_PATH) - assert hooks.post[0].name == "post_build_1_fake.hook" + assert hooks.post[0].name == "post_build_1_{}".format(FAKE_HOOK_PATH) assert hooks.post[1].name == "post_build_test" - assert hooks.post[2].name == "post_build_3_fake.hook" + assert hooks.post[2].name == "post_build_3_{}".format(FAKE_HOOK_PATH) - assert hooks.custom[0].name == "build_hooks_1_fake.hook" + assert hooks.custom[0].name == \ + "build_hooks_1_{}".format(FAKE_HOOK_PATH) assert hooks.custom[1].name == "build_test" - assert hooks.custom[2].name == "build_hooks_3_fake.hook" + assert hooks.custom[2].name == \ + "build_hooks_3_{}".format(FAKE_HOOK_PATH) def test_hook_data_key_fallback(self): config = Config({ "build_hooks": [ - {"name": "my-hook", "path": "fake.hook"} + {"name": "my-hook", "path": FAKE_HOOK_PATH} ] }) context = Context(config=config) diff --git a/stacker/tests/test_hooks.py b/stacker/tests/test_hooks.py new file mode 100644 index 000000000..275f7f0b2 --- /dev/null +++ b/stacker/tests/test_hooks.py @@ -0,0 +1,147 @@ +from __future__ import print_function +from __future__ import division +from __future__ import absolute_import +import unittest +import mock + + +from stacker.exceptions import HookExecutionFailed +from stacker.hooks import Hook +from stacker.status import ( + COMPLETE, FailedStatus, NotSubmittedStatus, SkippedStatus +) +from .factories import MockProviderBuilder, mock_context, mock_provider + + +mock_hook = mock.Mock() + + +class TestHooks(unittest.TestCase): + mock_hook_path = __name__ + ".mock_hook" + + def setUp(self): + self.context = mock_context(extra_config_args={ + "stacks": [ + {"name": "undeployed-stack", "template_path": "missing"} + ] + }) + self.provider = mock_provider(region="us-east-1") + self.provider_builder = MockProviderBuilder(self.provider) + + global mock_hook + mock_hook = mock.Mock() + + def test_missing_module(self): + with self.assertRaises(ValueError): + Hook("test", path="not.a.real.path") + + def test_missing_method(self): + with self.assertRaises(ValueError): + Hook("test", path=self.mock_hook_path + "garbage") + + def test_valid_enabled_hook(self): + hook = Hook("test", path=self.mock_hook_path, + required=True, enabled=True) + + result = mock_hook.return_value = mock.Mock() + self.assertIs(result, hook.run(self.provider, self.context)) + mock_hook.assert_called_once() + + def test_context_provided_to_hook(self): + hook = Hook("test", path=self.mock_hook_path, + required=True) + + def return_context(*args, **kwargs): + return kwargs['context'] + + mock_hook.side_effect = return_context + result = hook.run(self.provider, self.context) + self.assertIs(result, self.context) + + def test_hook_failure(self): + hook = Hook("test", path=self.mock_hook_path, + required=True) + + err = Exception() + mock_hook.side_effect = err + + with self.assertRaises(HookExecutionFailed) as raised: + hook.run(self.provider, self.context) + + self.assertIs(hook, raised.exception.hook) + self.assertIs(err, raised.exception.cause) + + def test_hook_failure_skip(self): + hook = Hook("test", path=self.mock_hook_path, + required=False) + + mock_hook.side_effect = Exception() + result = hook.run(self.provider, self.context) + self.assertIsNone(result) + + def test_return_data_hook(self): + hook = Hook("test", path=self.mock_hook_path, + data_key='test') + hook_data = {'hello': 'world'} + mock_hook.return_value = hook_data + + result = hook.run(self.provider, self.context) + self.assertEqual(hook_data, result) + self.assertEqual(hook_data, self.context.hook_data.get('test')) + + def test_return_data_hook_duplicate_key(self): + hook = Hook("test", path=self.mock_hook_path, + data_key='test') + mock_hook.return_value = {'foo': 'bar'} + + hook_data = {'hello': 'world'} + self.context.set_hook_data('test', hook_data) + with self.assertRaises(KeyError): + hook.run(self.provider, self.context) + + self.assertEqual(hook_data, self.context.hook_data['test']) + + def test_run_step_disabled(self): + hook = Hook("test", path=self.mock_hook_path, enabled=False) + + status = hook.run_step(provider_builder=self.provider_builder, + context=self.context) + self.assertIsInstance(status, NotSubmittedStatus) + + def test_run_step_stack_dep_missing(self): + hook = Hook("test", path=self.mock_hook_path, + args={"hello": "${output undeployed-stack::Output}"}) + stack_fqn = self.context.get_stack("undeployed-stack").fqn + + status = hook.run_step(provider_builder=self.provider_builder, + context=self.context) + self.assertIsInstance(status, SkippedStatus) + self.assertEqual(status.reason, + "required stack not deployed: {}".format(stack_fqn)) + + def test_run_step_hook_raised(self): + hook = Hook("test", path=self.mock_hook_path) + err = HookExecutionFailed(hook, cause=RuntimeError("canary")) + hook.run = mock.Mock(side_effect=err) + + status = hook.run_step(provider_builder=self.provider_builder, + context=self.context) + self.assertIsInstance(status, FailedStatus) + self.assertIn("canary", status.reason) + self.assertIn("threw exception", status.reason) + + def test_run_step_hook_failed(self): + hook = Hook("test", path=self.mock_hook_path, required=True) + hook.run = mock.Mock(return_value=False) + + status = hook.run_step(provider_builder=self.provider_builder, + context=self.context) + self.assertIsInstance(status, SkippedStatus) + + def test_run_step_hook_succeeded(self): + hook = Hook("test", path=self.mock_hook_path) + hook.run = mock.Mock(return_value=True) + + status = hook.run_step(provider_builder=self.provider_builder, + context=self.context) + self.assertEqual(status, COMPLETE) diff --git a/stacker/tests/test_util.py b/stacker/tests/test_util.py index cc6a8deab..218b594af 100644 --- a/stacker/tests/test_util.py +++ b/stacker/tests/test_util.py @@ -12,8 +12,6 @@ import boto3 from stacker.config import GitPackageSource -from stacker.exceptions import HookExecutionFailed -from stacker.hooks import Hook from stacker.util import ( cf_safe_name, load_object_from_string, @@ -31,10 +29,6 @@ SourceProcessor ) -from .factories import ( - mock_context, - mock_provider, -) regions = ["us-east-1", "cn-north-1", "ap-northeast-1", "eu-west-1", "ap-southeast-1", "ap-southeast-2", "us-west-2", "us-gov-west-1", @@ -272,114 +266,6 @@ def test_SourceProcessor_helpers(self): ) -mock_hook = mock.Mock() - - -class TestHooks(unittest.TestCase): - - def setUp(self): - self.context = mock_context(namespace="namespace") - self.provider = mock_provider(region="us-east-1") - - global mock_hook - mock_hook = mock.Mock() - - def test_missing_required_hook(self): - hook = Hook("test", path="not.a.real.path", required=True) - - with self.assertRaises(HookExecutionFailed) as raised: - hook.run(self.provider, self.context) - self.assertIsInstance(ImportError, raised.exception.exception) - - def test_missing_required_hook_method(self): - hook = Hook("test", path="stacker.hooks.blah", required=True) - - with self.assertRaises(HookExecutionFailed) as raised: - hook.run(self.provider, self.context) - self.assertIsInstance(AttributeError, raised.exception.exception) - - def test_missing_non_required_hook_method(self): - hook = Hook("test", path="stacker.hooks.blah", required=False) - - result = hook.run(self.provider, self.context) - self.assertIsNone(result) - - def test_default_required_hook(self): - hook = Hook("test", path="stacker.hooks.blah") - - with self.assertRaises(HookExecutionFailed) as raised: - hook.run(self.provider, self.context) - self.assertIsInstance(AttributeError, raised.exception.exception) - - def test_valid_enabled_hook(self): - hook = Hook("test", path="stacker.tests.test_util.mock_hook", - required=True, enabled=True) - - result = mock_hook.return_value = mock.Mock() - self.assertIs(result, hook.run(self.provider, self.context)) - mock_hook.assert_called_once() - - def test_valid_disabled_hook(self): - hook = Hook("test", path="stacker.tests.test_util.mock_hook", - required=True, enabled=False) - - self.assertIsNone(hook.run(self.provider, self.context)) - mock_hook.assert_not_called() - - def test_context_provided_to_hook(self): - hook = Hook("test", path="stacker.tests.test_util.mock_hook", - required=True) - - def return_context(*args, **kwargs): - return kwargs['context'] - - mock_hook.side_effect = return_context - result = hook.run(self.provider, self.context) - self.assertIs(result, self.context) - - def test_hook_failure(self): - hook = Hook("test", path="stacker.tests.test_util.mock_hook", - required=True) - - err = Exception() - mock_hook.side_effect = err - - with self.assertRaises(HookExecutionFailed) as raised: - hook.run(self.provider, self.context) - self.assertIs(hook, raised.exception.hook) - self.assertIs(err, raised.exception.exception) - - def test_hook_failure_skip(self): - hook = Hook("test", path="stacker.tests.test_util.mock_hook", - required=False) - - mock_hook.side_effect = Exception() - result = hook.run(self.provider, self.context) - self.assertIsNone(result) - - def test_return_data_hook(self): - hook = Hook("test", path="stacker.tests.test_util.mock_hook", - data_key='test') - hook_data = {'hello': 'world'} - mock_hook.return_value = hook_data - - result = hook.run(self.provider, self.context) - self.assertEqual(hook_data, result) - self.assertEqual(hook_data, self.context.hook_data.get('test')) - - def test_return_data_hook_duplicate_key(self): - hook = Hook("test", path="stacker.tests.test_util.mock_hook", - data_key='test') - mock_hook.return_value = {'foo': 'bar'} - - hook_data = {'hello': 'world'} - self.context.set_hook_data('test', hook_data) - with self.assertRaises(KeyError): - hook.run(self.provider, self.context) - - self.assertEqual(hook_data, self.context.hook_data['test']) - - class TestException1(Exception): pass