Skip to content

Commit 87a12cd

Browse files
added nil safety assertions
1 parent ec36741 commit 87a12cd

File tree

3 files changed

+32
-20
lines changed

3 files changed

+32
-20
lines changed

command_policy_resolver.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,20 @@ var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{
117117
},
118118
}
119119

120-
type CommandInfoResolver struct {
121-
resolve func(ctx context.Context, cmd Cmder) *routing.CommandPolicy
122-
fallBackResolver *CommandInfoResolver
120+
type CommandInfoResolveFunc func(ctx context.Context, cmd Cmder) *routing.CommandPolicy
121+
122+
type commandInfoResolver struct {
123+
resolveFunc CommandInfoResolveFunc
124+
fallBackResolver *commandInfoResolver
123125
}
124126

125-
func NewCommandInfoResolver(resolver func(ctx context.Context, cmd Cmder) *routing.CommandPolicy) *CommandInfoResolver {
126-
return &CommandInfoResolver{
127-
resolve: resolver,
127+
func NewCommandInfoResolver(resolveFunc CommandInfoResolveFunc) *commandInfoResolver {
128+
return &commandInfoResolver{
129+
resolveFunc: resolveFunc,
128130
}
129131
}
130132

131-
func NewDefaultCommandPolicyResolver() *CommandInfoResolver {
133+
func NewDefaultCommandPolicyResolver() *commandInfoResolver {
132134
return NewCommandInfoResolver(func(ctx context.Context, cmd Cmder) *routing.CommandPolicy {
133135
module := "core"
134136
command := cmd.Name()
@@ -146,12 +148,12 @@ func NewDefaultCommandPolicyResolver() *CommandInfoResolver {
146148
})
147149
}
148150

149-
func (r *CommandInfoResolver) GetCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy {
150-
if r.resolve == nil {
151+
func (r *commandInfoResolver) GetCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy {
152+
if r.resolveFunc == nil {
151153
return nil
152154
}
153155

154-
policy := r.resolve(ctx, cmd)
156+
policy := r.resolveFunc(ctx, cmd)
155157
if policy != nil {
156158
return policy
157159
}
@@ -163,6 +165,6 @@ func (r *CommandInfoResolver) GetCommandPolicy(ctx context.Context, cmd Cmder) *
163165
return nil
164166
}
165167

166-
func (r *CommandInfoResolver) SetFallbackResolver(fallbackResolver *CommandInfoResolver) {
168+
func (r *commandInfoResolver) SetFallbackResolver(fallbackResolver *commandInfoResolver) {
167169
r.fallBackResolver = fallbackResolver
168170
}

osscluster.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,7 @@ type ClusterClient struct {
925925
nodes *clusterNodes
926926
state *clusterStateHolder
927927
cmdsInfoCache *cmdsInfoCache
928-
cmdInfoResolver *CommandInfoResolver
928+
cmdInfoResolver *commandInfoResolver
929929
cmdable
930930
hooksMixin
931931
}
@@ -1340,7 +1340,10 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd
13401340

13411341
if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) {
13421342
for _, cmd := range cmds {
1343-
policy := c.extractCommandInfo(ctx, cmd)
1343+
var policy *routing.CommandPolicy
1344+
if c.cmdInfoResolver != nil {
1345+
policy = c.cmdInfoResolver.GetCommandPolicy(ctx, cmd)
1346+
}
13441347
if policy != nil && !policy.CanBeUsedInPipeline() {
13451348
return fmt.Errorf(
13461349
"redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(),
@@ -1357,7 +1360,10 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd
13571360
}
13581361

13591362
for _, cmd := range cmds {
1360-
policy := c.extractCommandInfo(ctx, cmd)
1363+
var policy *routing.CommandPolicy
1364+
if c.cmdInfoResolver != nil {
1365+
policy = c.cmdInfoResolver.GetCommandPolicy(ctx, cmd)
1366+
}
13611367
if policy != nil && !policy.CanBeUsedInPipeline() {
13621368
return fmt.Errorf(
13631369
"redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(),
@@ -1982,11 +1988,11 @@ func (c *ClusterClient) context(ctx context.Context) context.Context {
19821988
return context.Background()
19831989
}
19841990

1985-
func (c *ClusterClient) GetResolver() *CommandInfoResolver {
1991+
func (c *ClusterClient) GetResolver() *commandInfoResolver {
19861992
return c.cmdInfoResolver
19871993
}
19881994

1989-
func (c *ClusterClient) SetCommandInfoResolver(cmdInfoResolver *CommandInfoResolver) {
1995+
func (c *ClusterClient) SetCommandInfoResolver(cmdInfoResolver *commandInfoResolver) {
19901996
c.cmdInfoResolver = cmdInfoResolver
19911997
}
19921998

@@ -2001,9 +2007,9 @@ func (c *ClusterClient) extractCommandInfo(ctx context.Context, cmd Cmder) *rout
20012007

20022008
// NewDynamicResolver returns a CommandInfoResolver
20032009
// that uses the underlying cmdInfo cache to resolve the policies
2004-
func (c *ClusterClient) NewDynamicResolver() *CommandInfoResolver {
2005-
return &CommandInfoResolver{
2006-
resolve: c.extractCommandInfo,
2010+
func (c *ClusterClient) NewDynamicResolver() *commandInfoResolver {
2011+
return &commandInfoResolver{
2012+
resolveFunc: c.extractCommandInfo,
20072013
}
20082014
}
20092015

osscluster_router.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ type slotResult struct {
2121

2222
// routeAndRun routes a command to the appropriate cluster nodes and executes it
2323
func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error {
24-
policy := c.cmdInfoResolver.GetCommandPolicy(ctx, cmd)
24+
var policy *routing.CommandPolicy
25+
if c.cmdInfoResolver != nil {
26+
policy = c.cmdInfoResolver.GetCommandPolicy(ctx, cmd)
27+
}
28+
2529
if policy == nil {
2630
return c.executeDefault(ctx, cmd, node)
2731
}

0 commit comments

Comments
 (0)