11package main
22
33import (
4- "bufio"
54 "context"
65 "fmt"
76 "image"
8- "os "
7+ "io/ioutil "
98 "path/filepath"
109 "sort"
10+ "strings"
1111
1212 "github.com/Unknwon/com"
1313 "github.com/anthonynsimon/bild/imgio"
@@ -28,19 +28,19 @@ import (
2828)
2929
3030var (
31- batchSize = 1
32- model = "resnet50"
33- shape = []int {1 , 3 , 224 , 224 }
34- mean = []float32 {123.68 , 116.779 , 103.939 }
35- scale = []float32 {1.0 , 1.0 , 1.0 }
36- imgDir , _ = filepath .Abs ("../_fixtures" )
37- imgPath = filepath .Join (imgDir , "platypus.jpg" )
38- graph_url = "http://s3.amazonaws.com/store.carml.org/models/caffe/resnet50/ResNet-50-deploy.prototxt"
39- weights_url = "http://s3.amazonaws.com/store.carml.org/models/caffe/resnet50/ResNet-50-model.caffemodel"
40- synset_url = "http://s3.amazonaws.com/store.carml.org/synsets/imagenet/synset1001 .txt"
31+ batchSize = 1
32+ model = "resnet50"
33+ shape = []int {1 , 3 , 224 , 224 }
34+ mean = []float32 {123.68 , 116.779 , 103.939 }
35+ scale = []float32 {1.0 , 1.0 , 1.0 }
36+ baseDir , _ = filepath .Abs ("../ ../_fixtures" )
37+ imgPath = filepath .Join (baseDir , "platypus.jpg" )
38+ graphURL = "http://s3.amazonaws.com/store.carml.org/models/caffe/resnet50/ResNet-50-deploy.prototxt"
39+ weightsURL = "http://s3.amazonaws.com/store.carml.org/models/caffe/resnet50/ResNet-50-model.caffemodel"
40+ synsetURL = "http://s3.amazonaws.com/store.carml.org/synsets/imagenet/synset .txt"
4141)
4242
43- // convert go Image to 1-dim array
43+ // convert go RGB Image to 1D normalized RGB array
4444func cvtRGBImageToNCHW1DArray (src image.Image , mean []float32 , scale []float32 ) ([]float32 , error ) {
4545 if src == nil {
4646 return nil , fmt .Errorf ("src image nil" )
@@ -49,41 +49,41 @@ func cvtRGBImageToNCHW1DArray(src image.Image, mean []float32, scale []float32)
4949 in := src .Bounds ()
5050 height := in .Max .Y - in .Min .Y // image height
5151 width := in .Max .X - in .Min .X // image width
52+ stride := width * height // image size per channel
5253
5354 out := make ([]float32 , 3 * height * width )
5455 for y := 0 ; y < height ; y ++ {
5556 for x := 0 ; x < width ; x ++ {
5657 r , g , b , _ := src .At (x + in .Min .X , y + in .Min .Y ).RGBA ()
57- out [y * width + x ] = (float32 (b ) - mean [2 ]) / scale [2 ]
58- out [width * height + y * width + x ] = (float32 (g ) - mean [1 ]) / scale [1 ]
59- out [2 * width * height + y * width + x ] = (float32 (r ) - mean [0 ]) / scale [0 ]
58+ out [0 * stride + y * width + x ] = (float32 (r >> 8 ) - mean [0 ]) / scale [0 ]
59+ out [1 * stride + y * width + x ] = (float32 (g >> 8 ) - mean [1 ]) / scale [1 ]
60+ out [2 * stride + y * width + x ] = (float32 (b >> 8 ) - mean [2 ]) / scale [2 ]
6061 }
6162 }
6263
6364 return out , nil
6465}
6566
6667func main () {
67- defer tracer .Close ()
68-
69- baseDir , _ := filepath .Abs ("../tmp" )
68+ defer tracer .Close ()
69+
7070 dir := filepath .Join (baseDir , model )
7171 graph := filepath .Join (dir , model + ".prototxt" )
7272 weights := filepath .Join (dir , model + ".caffemodel" )
7373 synset := filepath .Join (dir , "synset.txt" )
7474
7575 if ! com .IsFile (graph ) {
76- if _ , _ , err := downloadmanager .DownloadFile (graph_url , graph ); err != nil {
76+ if _ , _ , err := downloadmanager .DownloadFile (graphURL , graph ); err != nil {
7777 panic (err )
7878 }
7979 }
8080 if ! com .IsFile (weights ) {
81- if _ , _ , err := downloadmanager .DownloadFile (weights_url , weights ); err != nil {
81+ if _ , _ , err := downloadmanager .DownloadFile (weightsURL , weights ); err != nil {
8282 panic (err )
8383 }
8484 }
8585 if ! com .IsFile (synset ) {
86- if _ , _ , err := downloadmanager .DownloadFile (synset_url , synset ); err != nil {
86+ if _ , _ , err := downloadmanager .DownloadFile (synsetURL , synset ); err != nil {
8787 panic (err )
8888 }
8989 }
@@ -123,6 +123,10 @@ func main() {
123123 Shape : shape ,
124124 Dtype : gotensor .Float32 ,
125125 }
126+ out := options.Node {
127+ Key : "prob" ,
128+ Dtype : gotensor .Float32 ,
129+ }
126130
127131 predictor , err := tensorrt .New (
128132 ctx ,
@@ -132,26 +136,21 @@ func main() {
132136 options .Weights ([]byte (weights )),
133137 options .BatchSize (batchSize ),
134138 options .InputNodes ([]options.Node {in }),
135- options .OutputNodes ([]options.Node {
136- options.Node {
137- Key : "prob" ,
138- Dtype : gotensor .Float32 ,
139- },
140- }),
139+ options .OutputNodes ([]options.Node {out }),
141140 )
142141 if err != nil {
143142 panic (fmt .Sprintf ("%v" , err ))
144143 }
145144 defer predictor .Close ()
146145
147- for ii := 0 ; ii < 3 ; ii ++ {
148- err = predictor .Predict (ctx , input )
149- if err != nil {
150- panic (err )
146+ for ii := 0 ; ii < 3 ; ii ++ {
147+ err = predictor .Predict (ctx , input )
148+ if err != nil {
149+ panic (err )
150+ }
151151 }
152- }
153152
154- enableCupti := true
153+ enableCupti := true
155154 var cu * cupti.CUPTI
156155 if enableCupti {
157156 cu , err = cupti .New (cupti .Context (ctx ))
@@ -174,64 +173,49 @@ func main() {
174173 cu .Close ()
175174 }
176175
177- if true {
178- profBuffer , err := predictor .ReadProfile ()
179- if err != nil {
180- panic (err )
181- }
176+ profBuffer , err := predictor .ReadProfile ()
177+ if err != nil {
178+ panic (err )
179+ }
182180
183- t , err := ctimer .New (profBuffer )
184- if err != nil {
185- panic (err )
186- }
187- t .Publish (ctx , tracer .FRAMEWORK_TRACE )
181+ t , err := ctimer .New (profBuffer )
182+ if err != nil {
183+ panic (err )
184+ }
185+ t .Publish (ctx , tracer .FRAMEWORK_TRACE )
188186
189- outputs , err := predictor .ReadPredictionOutputs (ctx )
190- if err != nil {
191- panic (err )
192- }
187+ outputs , err := predictor .ReadPredictionOutputs (ctx )
188+ if err != nil {
189+ panic (err )
190+ }
193191
194- output := outputs [0 ]
192+ output := outputs [0 ]
193+ labelsFileContent , err := ioutil .ReadFile (synset )
194+ if err != nil {
195+ panic (err )
196+ }
197+ labels := strings .Split (string (labelsFileContent ), "\n " )
195198
196- var labels []string
197- f , err := os .Open (synset )
198- if err != nil {
199- panic (err )
200- }
201- defer f .Close ()
202- scanner := bufio .NewScanner (f )
203- for scanner .Scan () {
204- line := scanner .Text ()
205- labels = append (labels , line )
206- }
199+ features := make ([]dlframework.Features , batchSize )
200+ featuresLen := len (output ) / batchSize
207201
208- features := make ([]dlframework.Features , batchSize )
209- featuresLen := len (output ) / batchSize
210-
211- for ii := 0 ; ii < batchSize ; ii ++ {
212- rprobs := make ([]* dlframework.Feature , featuresLen )
213- for jj := 0 ; jj < featuresLen ; jj ++ {
214- rprobs [jj ] = feature .New (
215- feature .ClassificationIndex (int32 (jj )),
216- feature .ClassificationLabel (labels [jj ]),
217- feature .Probability (output [ii * featuresLen + jj ]),
218- )
219- }
220- sort .Sort (dlframework .Features (rprobs ))
221- features [ii ] = rprobs
202+ for ii := 0 ; ii < batchSize ; ii ++ {
203+ rprobs := make ([]* dlframework.Feature , featuresLen )
204+ for jj := 0 ; jj < featuresLen ; jj ++ {
205+ rprobs [jj ] = feature .New (
206+ feature .ClassificationIndex (int32 (jj )),
207+ feature .ClassificationLabel (labels [jj ]),
208+ feature .Probability (output [ii * featuresLen + jj ]),
209+ )
222210 }
211+ sort .Sort (dlframework .Features (rprobs ))
212+ features [ii ] = rprobs
213+ }
223214
224- if true {
225- results := features [0 ]
226- for i := 0 ; i < 3 ; i ++ {
227- prediction := results [i ]
228- pp .Println (prediction .Probability )
229- pp .Println (prediction .GetClassification ().GetIndex ())
230- pp .Println (prediction .GetClassification ().GetLabel ())
231- }
232- } else {
233- _ = features
234- }
215+ results := features [0 ]
216+ for i := 0 ; i < 3 ; i ++ {
217+ prediction := results [i ]
218+ pp .Println (prediction .Probability , prediction .GetClassification ().GetIndex (), prediction .GetClassification ().GetLabel ())
235219 }
236220}
237221
0 commit comments