Skip to content

Commit 3ff2f73

Browse files
authored
Merge pull request #4 from go-spring-projects/redesign-bean-init
Redesign the BeanInit interface to get away from Go-Spring dependence
2 parents b719a2b + 52c1517 commit 3ff2f73

File tree

3 files changed

+169
-40
lines changed

3 files changed

+169
-40
lines changed

gs/gs.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,19 @@ type Context interface {
7979
Go(fn func(ctx context.Context))
8080
}
8181

82+
type contextKey struct{}
83+
84+
func WithContext(ctx Context) context.Context {
85+
return context.WithValue(ctx.Context(), contextKey{}, ctx)
86+
}
87+
88+
func FromContext(ctx context.Context) Context {
89+
if val := ctx.Value(contextKey{}); val != nil {
90+
return val.(Context)
91+
}
92+
return nil
93+
}
94+
8295
type tempContainer struct {
8396
props *conf.Properties
8497
beans []*BeanDefinition

gs/gs_bean.go

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package gs
1818

1919
import (
20+
"context"
2021
"errors"
2122
"fmt"
2223
"reflect"
@@ -66,7 +67,7 @@ func BeanID(typ interface{}, name string) string {
6667
}
6768

6869
type BeanInit interface {
69-
OnInit(ctx Context) error
70+
OnInit(ctx context.Context) error
7071
}
7172

7273
type BeanDestroy interface {
@@ -215,9 +216,24 @@ func validLifeCycleFunc(fnType reflect.Type, beanValue reflect.Value) bool {
215216
if !utils.IsFuncType(fnType) {
216217
return false
217218
}
218-
if fnType.NumIn() != 1 || !utils.HasReceiver(fnType, beanValue) {
219+
220+
switch fnType.NumIn() {
221+
case 1:
222+
// func(bean)
223+
// func(bean) error
224+
if !utils.HasReceiver(fnType, beanValue) {
225+
return false
226+
}
227+
case 2:
228+
// func(bean, ctx)
229+
// func(bean, ctx) error
230+
if !utils.HasReceiver(fnType, beanValue) || !utils.IsContextType(fnType.In(1)) {
231+
return false
232+
}
233+
default:
219234
return false
220235
}
236+
221237
return utils.ReturnNothing(fnType) || utils.ReturnOnlyError(fnType)
222238
}
223239

@@ -227,7 +243,7 @@ func (d *BeanDefinition) Init(fn interface{}) *BeanDefinition {
227243
d.init = fn
228244
return d
229245
}
230-
panic(errors.New("init should be func(bean) or func(bean)error"))
246+
panic(errors.New("init should be func(bean,[ctx]) or func(bean,[ctx])error"))
231247
}
232248

233249
// Destroy Set the destruction function for a bean.
@@ -236,7 +252,7 @@ func (d *BeanDefinition) Destroy(fn interface{}) *BeanDefinition {
236252
d.destroy = fn
237253
return d
238254
}
239-
panic(errors.New("destroy should be func(bean) or func(bean)error"))
255+
panic(errors.New("destroy should be func(bean,[ctx]) or func(bean,[ctx])error"))
240256
}
241257

242258
// Export indicates the types of interface to export.
@@ -283,14 +299,19 @@ func (d *BeanDefinition) export(exports ...interface{}) error {
283299
func (d *BeanDefinition) constructor(ctx Context) error {
284300
if d.init != nil {
285301
fnValue := reflect.ValueOf(d.init)
286-
out := fnValue.Call([]reflect.Value{d.Value()})
302+
fnValues := []reflect.Value{d.Value()}
303+
if fnValue.Type().NumIn() > 1 {
304+
fnValues = append(fnValues, reflect.ValueOf(WithContext(ctx)))
305+
}
306+
307+
out := fnValue.Call(fnValues)
287308
if len(out) > 0 && !out[0].IsNil() {
288309
return out[0].Interface().(error)
289310
}
290311
}
291312

292313
if f, ok := d.Interface().(BeanInit); ok {
293-
if err := f.OnInit(ctx); err != nil {
314+
if err := f.OnInit(WithContext(ctx)); err != nil {
294315
return err
295316
}
296317
}
@@ -300,6 +321,10 @@ func (d *BeanDefinition) constructor(ctx Context) error {
300321
func (d *BeanDefinition) destructor() {
301322
if d.destroy != nil {
302323
fnValue := reflect.ValueOf(d.destroy)
324+
fnValues := []reflect.Value{d.Value()}
325+
if fnValue.Type().NumIn() > 1 {
326+
fnValues = append(fnValues, reflect.ValueOf(context.Background()))
327+
}
303328
fnValue.Call([]reflect.Value{d.Value()})
304329
}
305330

gs/gs_test.go

Lines changed: 125 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package gs
1818

1919
import (
2020
"bytes"
21+
"context"
2122
"errors"
2223
"fmt"
2324
"image"
@@ -926,6 +927,8 @@ type destroyable interface {
926927
Init()
927928
Destroy()
928929
InitWithError() error
930+
InitWithCtx(ctx context.Context)
931+
InitWithCtxError(ctx context.Context) error
929932
DestroyWithError() error
930933
}
931934

@@ -951,6 +954,27 @@ func (d *callDestroy) InitWithError() error {
951954
return fmt.Errorf("error")
952955
}
953956

957+
func (d *callDestroy) InitWithCtx(ctx context.Context) {
958+
if d.i == 0 {
959+
d.inited = true
960+
}
961+
if nil == FromContext(ctx) {
962+
panic("invalid context")
963+
}
964+
}
965+
966+
func (d *callDestroy) InitWithCtxError(ctx context.Context) error {
967+
if d.i == 0 {
968+
d.inited = true
969+
970+
if nil == FromContext(ctx) {
971+
return fmt.Errorf("invalid context")
972+
}
973+
return nil
974+
}
975+
return fmt.Errorf("error")
976+
}
977+
954978
func (d *callDestroy) DestroyWithError() error {
955979
if d.i == 0 {
956980
d.destroyed = true
@@ -971,15 +995,29 @@ func TestRegisterBean_InitFunc(t *testing.T) {
971995

972996
t.Run("call init", func(t *testing.T) {
973997

974-
c := New()
975-
c.Object(new(callDestroy)).Init((*callDestroy).Init)
976-
err := runTest(c, func(p Context) {
977-
var d *callDestroy
978-
err := p.Get(&d)
998+
{
999+
c := New()
1000+
c.Object(new(callDestroy)).Init((*callDestroy).Init)
1001+
err := runTest(c, func(p Context) {
1002+
var d *callDestroy
1003+
err := p.Get(&d)
1004+
assert.Nil(t, err)
1005+
assert.True(t, d.inited)
1006+
})
9791007
assert.Nil(t, err)
980-
assert.True(t, d.inited)
981-
})
982-
assert.Nil(t, err)
1008+
}
1009+
1010+
{
1011+
c := New()
1012+
c.Object(new(callDestroy)).Init((*callDestroy).InitWithCtx)
1013+
err := runTest(c, func(p Context) {
1014+
var d *callDestroy
1015+
err := p.Get(&d)
1016+
assert.Nil(t, err)
1017+
assert.True(t, d.inited)
1018+
})
1019+
assert.Nil(t, err)
1020+
}
9831021
})
9841022

9851023
t.Run("call init with error", func(t *testing.T) {
@@ -991,32 +1029,72 @@ func TestRegisterBean_InitFunc(t *testing.T) {
9911029
assert.Error(t, err, "error")
9921030
}
9931031

994-
c := New()
995-
p := conf.New()
996-
p.Set("int", 0)
997-
c.Object(&callDestroy{}).Init((*callDestroy).InitWithError)
1032+
{
1033+
c := New()
1034+
c.Object(&callDestroy{i: 1}).Init((*callDestroy).InitWithCtxError)
1035+
err := c.Refresh()
1036+
assert.Error(t, err, "error")
1037+
}
9981038

999-
err := c.Properties().Refresh(p)
1000-
assert.Nil(t, err)
1001-
err = runTest(c, func(p Context) {
1002-
var d *callDestroy
1003-
err = p.Get(&d)
1039+
{
1040+
c := New()
1041+
p := conf.New()
1042+
p.Set("int", 0)
1043+
c.Object(&callDestroy{}).Init((*callDestroy).InitWithError)
1044+
1045+
err := c.Properties().Refresh(p)
10041046
assert.Nil(t, err)
1005-
assert.True(t, d.inited)
1006-
})
1007-
assert.Nil(t, err)
1047+
err = runTest(c, func(p Context) {
1048+
var d *callDestroy
1049+
err = p.Get(&d)
1050+
assert.Nil(t, err)
1051+
assert.True(t, d.inited)
1052+
})
1053+
assert.Nil(t, err)
1054+
}
1055+
1056+
{
1057+
c := New()
1058+
p := conf.New()
1059+
p.Set("int", 0)
1060+
c.Object(&callDestroy{}).Init((*callDestroy).InitWithCtxError)
1061+
1062+
err := c.Properties().Refresh(p)
1063+
assert.Nil(t, err)
1064+
err = runTest(c, func(p Context) {
1065+
var d *callDestroy
1066+
err = p.Get(&d)
1067+
assert.Nil(t, err)
1068+
assert.True(t, d.inited)
1069+
})
1070+
assert.Nil(t, err)
1071+
}
10081072
})
10091073

10101074
t.Run("call interface init", func(t *testing.T) {
1011-
c := New()
1012-
c.Provide(func() destroyable { return new(callDestroy) }).Init(destroyable.Init)
1013-
err := runTest(c, func(p Context) {
1014-
var d destroyable
1015-
err := p.Get(&d)
1075+
{
1076+
c := New()
1077+
c.Provide(func() destroyable { return new(callDestroy) }).Init(destroyable.Init)
1078+
err := runTest(c, func(p Context) {
1079+
var d destroyable
1080+
err := p.Get(&d)
1081+
assert.Nil(t, err)
1082+
assert.True(t, d.(*callDestroy).inited)
1083+
})
10161084
assert.Nil(t, err)
1017-
assert.True(t, d.(*callDestroy).inited)
1018-
})
1019-
assert.Nil(t, err)
1085+
}
1086+
1087+
{
1088+
c := New()
1089+
c.Provide(func() destroyable { return new(callDestroy) }).Init(destroyable.InitWithCtx)
1090+
err := runTest(c, func(p Context) {
1091+
var d destroyable
1092+
err := p.Get(&d)
1093+
assert.Nil(t, err)
1094+
assert.True(t, d.(*callDestroy).inited)
1095+
})
1096+
assert.Nil(t, err)
1097+
}
10201098
})
10211099

10221100
t.Run("call interface init with error", func(t *testing.T) {
@@ -1028,6 +1106,13 @@ func TestRegisterBean_InitFunc(t *testing.T) {
10281106
assert.Error(t, err, "error")
10291107
}
10301108

1109+
{
1110+
c := New()
1111+
c.Provide(func() destroyable { return &callDestroy{i: 1} }).Init(destroyable.InitWithCtxError)
1112+
err := c.Refresh()
1113+
assert.Error(t, err, "error")
1114+
}
1115+
10311116
c := New()
10321117
p := conf.New()
10331118
p.Set("int", 0)
@@ -1932,22 +2017,22 @@ func TestApplicationContext_Close(t *testing.T) {
19322017
assert.Panic(t, func() {
19332018
c := New()
19342019
c.Object(func() {}).Destroy(func() {})
1935-
}, "destroy should be func\\(bean\\) or func\\(bean\\)error")
2020+
}, "destroy should be func\\(bean,\\[ctx\\]\\) or func\\(bean,\\[ctx\\]\\)error")
19362021

19372022
assert.Panic(t, func() {
19382023
c := New()
19392024
c.Object(func() {}).Destroy(func() int { return 0 })
1940-
}, "destroy should be func\\(bean\\) or func\\(bean\\)error")
2025+
}, "destroy should be func\\(bean,\\[ctx\\]\\) or func\\(bean,\\[ctx\\]\\)error")
19412026

19422027
assert.Panic(t, func() {
19432028
c := New()
19442029
c.Object(func() {}).Destroy(func(int) {})
1945-
}, "destroy should be func\\(bean\\) or func\\(bean\\)error")
2030+
}, "destroy should be func\\(bean,\\[ctx\\]\\) or func\\(bean,\\[ctx\\]\\)error")
19462031

19472032
assert.Panic(t, func() {
19482033
c := New()
19492034
c.Object(func() {}).Destroy(func(int, int) {})
1950-
}, "destroy should be func\\(bean\\) or func\\(bean\\)error")
2035+
}, "destroy should be func\\(bean,\\[ctx\\]\\) or func\\(bean,\\[ctx\\]\\)error")
19512036
})
19522037

19532038
t.Run("call destroy fn", func(t *testing.T) {
@@ -2848,7 +2933,10 @@ func TestLazy(t *testing.T) {
28482933
type memory struct {
28492934
}
28502935

2851-
func (m *memory) OnInit(ctx Context) error {
2936+
func (m *memory) OnInit(ctx context.Context) error {
2937+
if nil == FromContext(ctx) {
2938+
panic("invalid context")
2939+
}
28522940
fmt.Println("memory.OnInit")
28532941
return nil
28542942
}
@@ -2861,7 +2949,10 @@ type table struct {
28612949
_ *memory `autowire:""`
28622950
}
28632951

2864-
func (t *table) OnInit(ctx Context) error {
2952+
func (t *table) OnInit(ctx context.Context) error {
2953+
if nil == FromContext(ctx) {
2954+
panic("invalid context")
2955+
}
28652956
fmt.Println("table.OnInit")
28662957
return nil
28672958
}

0 commit comments

Comments
 (0)