|
| 1 | +package state |
| 2 | + |
| 3 | +import ( |
| 4 | + "github.com/pkg/errors" |
| 5 | + "go.uber.org/zap" |
| 6 | + "k8s.io/apimachinery/pkg/types" |
| 7 | + "sigs.k8s.io/controller-runtime/pkg/reconcile" |
| 8 | +) |
| 9 | + |
| 10 | +// State should provide a unique name, and a Reconcile function. |
| 11 | +// This function gets called by the Machine. The first two returned values |
| 12 | +// are returned to the caller, while the 3rd value is used to indicate if the |
| 13 | +// State completed successfully. A value of true will move onto the next State, |
| 14 | +// a value of false will repeat this State until true is returned. |
| 15 | +type State struct { |
| 16 | + // Name should be a unique identifier of the State |
| 17 | + Name string |
| 18 | + |
| 19 | + // Reconcile should perform the actual reconciliation of the State. |
| 20 | + // The reconcile.Result and error should be returned from the controller. |
| 21 | + // the boolean value indicates that the State has been successfully completed. |
| 22 | + Reconcile func() (reconcile.Result, error, bool) |
| 23 | + |
| 24 | + // OnEnter executes before the Reconcile function is called. |
| 25 | + OnEnter func() error |
| 26 | +} |
| 27 | + |
| 28 | +// transition represents a transition between two states. |
| 29 | +type transition struct { |
| 30 | + from, to State |
| 31 | + predicate TransitionPredicate |
| 32 | +} |
| 33 | + |
| 34 | +// Saver saves the next state name that should be reconciled. |
| 35 | +// If a transition is A -> B, after A finishes reconciling `SaveNextState("B")` will be called. |
| 36 | +type Saver interface { |
| 37 | + SaveNextState(nsName types.NamespacedName, stateName string) error |
| 38 | +} |
| 39 | + |
| 40 | +// Loader should return the value saved by Saver. |
| 41 | +type Loader interface { |
| 42 | + LoadNextState(nsName types.NamespacedName) (string, error) |
| 43 | +} |
| 44 | + |
| 45 | +// SaveLoader can both load and save the name of a state. |
| 46 | +type SaveLoader interface { |
| 47 | + Saver |
| 48 | + Loader |
| 49 | +} |
| 50 | + |
| 51 | +// TransitionPredicate is used to indicate if two States should be connected. |
| 52 | +type TransitionPredicate func() bool |
| 53 | + |
| 54 | +var FromBool = func(b bool) TransitionPredicate { |
| 55 | + return func() bool { |
| 56 | + return b |
| 57 | + } |
| 58 | +} |
| 59 | + |
| 60 | +// directTransition can be used to ensure two states are directly linked. |
| 61 | +var directTransition = FromBool(true) |
| 62 | + |
| 63 | +// Machine allows for several States to be registered via "AddTransition" |
| 64 | +// When calling Reconcile, the corresponding State will be used based on the values |
| 65 | +// stored/loaded from the SaveLoader. A Machine corresponds to a single Kubernetes resource. |
| 66 | +type Machine struct { |
| 67 | + allTransitions map[string][]transition |
| 68 | + currentState *State |
| 69 | + logger *zap.SugaredLogger |
| 70 | + saveLoader SaveLoader |
| 71 | + states map[string]State |
| 72 | + nsName types.NamespacedName |
| 73 | +} |
| 74 | + |
| 75 | +// NewStateMachine returns a Machine, it must be set up with calls to "AddTransition(s1, s2, predicate)" |
| 76 | +// before Reconcile is called. |
| 77 | +func NewStateMachine(saver SaveLoader, nsName types.NamespacedName, logger *zap.SugaredLogger) *Machine { |
| 78 | + return &Machine{ |
| 79 | + allTransitions: map[string][]transition{}, |
| 80 | + logger: logger, |
| 81 | + saveLoader: saver, |
| 82 | + states: map[string]State{}, |
| 83 | + nsName: nsName, |
| 84 | + } |
| 85 | +} |
| 86 | + |
| 87 | +// Reconcile will reconcile the currently active State. This method should be called |
| 88 | +// from the controllers. |
| 89 | +func (m *Machine) Reconcile() (reconcile.Result, error) { |
| 90 | + |
| 91 | + if err := m.determineState(); err != nil { |
| 92 | + m.logger.Errorf("error initializing starting state: %s", err) |
| 93 | + return reconcile.Result{}, err |
| 94 | + } |
| 95 | + |
| 96 | + m.logger.Infof("Reconciling state: [%s]", m.currentState.Name) |
| 97 | + |
| 98 | + if m.currentState.OnEnter != nil { |
| 99 | + if err := m.currentState.OnEnter(); err != nil { |
| 100 | + m.logger.Debugf("Error reconciling state [%s]: %s", m.currentState.Name, err) |
| 101 | + return reconcile.Result{}, err |
| 102 | + } |
| 103 | + } |
| 104 | + |
| 105 | + res, err, isComplete := m.currentState.Reconcile() |
| 106 | + |
| 107 | + if err != nil { |
| 108 | + m.logger.Debugf("Error reconciling state [%s]: %s", m.currentState.Name, err) |
| 109 | + return res, err |
| 110 | + } |
| 111 | + |
| 112 | + if isComplete { |
| 113 | + m.logger.Debugf("Completed state: [%s]", m.currentState.Name) |
| 114 | + |
| 115 | + transition := m.getTransitionForState(*m.currentState) |
| 116 | + nextState := "" |
| 117 | + if transition != nil { |
| 118 | + nextState = transition.to.Name |
| 119 | + } |
| 120 | + |
| 121 | + if nextState != "" { |
| 122 | + m.logger.Debugf("preparing transition [%s] -> [%s]", m.currentState.Name, nextState) |
| 123 | + } |
| 124 | + |
| 125 | + if err := m.saveLoader.SaveNextState(m.nsName, nextState); err != nil { |
| 126 | + m.logger.Debugf("Error marking state: [%s] as complete: %s", m.currentState.Name, err) |
| 127 | + return reconcile.Result{}, err |
| 128 | + } |
| 129 | + return res, err |
| 130 | + } |
| 131 | + |
| 132 | + m.logger.Debugf("State [%s] is not yet complete", m.currentState.Name) |
| 133 | + |
| 134 | + return res, err |
| 135 | +} |
| 136 | + |
| 137 | +// determineState ensures that "currentState" has a valid value. |
| 138 | +// the state that is loaded comes from the Loader. |
| 139 | +func (m *Machine) determineState() error { |
| 140 | + currentStateName, err := m.saveLoader.LoadNextState(m.nsName) |
| 141 | + if err != nil { |
| 142 | + return errors.Errorf("could not load starting state: %s", err) |
| 143 | + } |
| 144 | + nextState, ok := m.states[currentStateName] |
| 145 | + if !ok { |
| 146 | + return errors.Errorf("could not determine state %s as it was not added to the State Machine", currentStateName) |
| 147 | + } |
| 148 | + m.currentState = &nextState |
| 149 | + return nil |
| 150 | +} |
| 151 | + |
| 152 | +// AddDirectTransition creates a transition between the two |
| 153 | +// provided states which will always be valid. |
| 154 | +func (m *Machine) AddDirectTransition(from, to State) { |
| 155 | + m.AddTransition(from, to, directTransition) |
| 156 | +} |
| 157 | + |
| 158 | +// AddTransition creates a transition between the two states if the given |
| 159 | +// predicate returns true. |
| 160 | +func (m *Machine) AddTransition(from, to State, predicate TransitionPredicate) { |
| 161 | + _, ok := m.allTransitions[from.Name] |
| 162 | + if !ok { |
| 163 | + m.allTransitions[from.Name] = []transition{} |
| 164 | + } |
| 165 | + m.allTransitions[from.Name] = append(m.allTransitions[from.Name], transition{ |
| 166 | + from: from, |
| 167 | + to: to, |
| 168 | + predicate: predicate, |
| 169 | + }) |
| 170 | + |
| 171 | + m.states[from.Name] = from |
| 172 | + m.states[to.Name] = to |
| 173 | +} |
| 174 | + |
| 175 | +// getTransitionForState returns the first transition it finds that is available |
| 176 | +// from the current state. |
| 177 | +func (m *Machine) getTransitionForState(s State) *transition { |
| 178 | + transitions := m.allTransitions[s.Name] |
| 179 | + for _, t := range transitions { |
| 180 | + if t.predicate() { |
| 181 | + return &t |
| 182 | + } |
| 183 | + } |
| 184 | + return nil |
| 185 | +} |
0 commit comments