
Commit c64d9a3

fixed run and onnx and uff support
Former-commit-id: 32b162a970d986c06d7d1abb459d5c1d04fdc045 [formerly 16421fb3ed67f59ed008915f362acaa8c85f6ef1] [formerly 92d5a72 [formerly 31e8790]] Former-commit-id: c93d688c24bcb5d7dd029aa8b99d55bd69155877 [formerly e323b474c51a518df534a54fa2914f249d10194d] Former-commit-id: 53f79d47710d97a4818e3f7a43fe9b2c1fcd738c
1 parent 2dd0c14 commit c64d9a3

File tree

12 files changed: +275 -199 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ build/*
 networks/*
 *.o
 *.so
+_fixtures

 tmp
 tensorrt

.vscode/settings.json

Lines changed: 15 additions & 1 deletion
@@ -44,6 +44,20 @@
     "utility": "cpp",
     "valarray": "cpp",
     "optional": "cpp",
-    "string_view": "cpp"
+    "string_view": "cpp",
+    "cstdarg": "cpp",
+    "atomic": "cpp",
+    "strstream": "cpp",
+    "bitset": "cpp",
+    "complex": "cpp",
+    "algorithm": "cpp",
+    "iterator": "cpp",
+    "map": "cpp",
+    "memory_resource": "cpp",
+    "set": "cpp",
+    "string": "cpp",
+    "iomanip": "cpp",
+    "cfenv": "cpp",
+    "cinttypes": "cpp"
   }
 }

cbits/predictor.hpp

Lines changed: 4 additions & 3 deletions
@@ -24,7 +24,7 @@ extern "C"
   {
     TensorRT_CaffeFormat = 1,
     TensorRT_OnnxFormat = 2,
-    TensorRT_TensorFlowFormat = 3
+    TensorRT_UffFormat = 3,
   } TensorRT_ModelFormat;

   typedef enum TensorRT_DType
@@ -63,8 +63,9 @@ extern "C"
   X(TensorRT_Double, double)

   PredictorHandle
-  NewTensorRTPredictor(TensorRT_ModelFormat model_format, char *deploy_file,
-                       char *weights_file, TensorRT_DType model_datatype,
+  NewTensorRTPredictor(TensorRT_ModelFormat model_format,
+                       char **model_files,
+                       TensorRT_DType model_datatype,
                        char **input_layer_names, int32_t num_input_layer_names,
                        char **output_layer_names, int32_t num_output_layer_names,
                        int32_t batch_size);
51.1 MB
Binary file not shown.
Binary file not shown.

examples/_fixtures/networks/bvlc_googlenet.caffemodel.2.tensorcache.REMOVED.git-id

Lines changed: 0 additions & 1 deletion
This file was deleted.

examples/_fixtures/networks/bvlc_googlenet.caffemodel.REMOVED.git-id

Lines changed: 0 additions & 1 deletion
This file was deleted.

examples/batch_mlmodelscope/main.go

Lines changed: 69 additions & 85 deletions
@@ -1,13 +1,13 @@
 package main

 import (
-    "bufio"
     "context"
     "fmt"
     "image"
-    "os"
+    "io/ioutil"
     "path/filepath"
     "sort"
+    "strings"

     "github.com/Unknwon/com"
     "github.com/anthonynsimon/bild/imgio"
@@ -28,19 +28,19 @@ import (
 )

 var (
-    batchSize   = 1
-    model       = "resnet50"
-    shape       = []int{1, 3, 224, 224}
-    mean        = []float32{123.68, 116.779, 103.939}
-    scale       = []float32{1.0, 1.0, 1.0}
-    imgDir, _   = filepath.Abs("../_fixtures")
-    imgPath     = filepath.Join(imgDir, "platypus.jpg")
-    graph_url   = "http://s3.amazonaws.com/store.carml.org/models/caffe/resnet50/ResNet-50-deploy.prototxt"
-    weights_url = "http://s3.amazonaws.com/store.carml.org/models/caffe/resnet50/ResNet-50-model.caffemodel"
-    synset_url  = "http://s3.amazonaws.com/store.carml.org/synsets/imagenet/synset1001.txt"
+    batchSize  = 1
+    model      = "resnet50"
+    shape      = []int{1, 3, 224, 224}
+    mean       = []float32{123.68, 116.779, 103.939}
+    scale      = []float32{1.0, 1.0, 1.0}
+    baseDir, _ = filepath.Abs("../../_fixtures")
+    imgPath    = filepath.Join(baseDir, "platypus.jpg")
+    graphURL   = "http://s3.amazonaws.com/store.carml.org/models/caffe/resnet50/ResNet-50-deploy.prototxt"
+    weightsURL = "http://s3.amazonaws.com/store.carml.org/models/caffe/resnet50/ResNet-50-model.caffemodel"
+    synsetURL  = "http://s3.amazonaws.com/store.carml.org/synsets/imagenet/synset.txt"
 )

-// convert go Image to 1-dim array
+// convert go RGB Image to 1D normalized RGB array
 func cvtRGBImageToNCHW1DArray(src image.Image, mean []float32, scale []float32) ([]float32, error) {
     if src == nil {
         return nil, fmt.Errorf("src image nil")
@@ -49,41 +49,41 @@ func cvtRGBImageToNCHW1DArray(src image.Image, mean []float32, scale []float32)
     in := src.Bounds()
     height := in.Max.Y - in.Min.Y // image height
     width := in.Max.X - in.Min.X  // image width
+    stride := width * height      // image size per channel

     out := make([]float32, 3*height*width)
     for y := 0; y < height; y++ {
         for x := 0; x < width; x++ {
             r, g, b, _ := src.At(x+in.Min.X, y+in.Min.Y).RGBA()
-            out[y*width+x] = (float32(b) - mean[2]) / scale[2]
-            out[width*height+y*width+x] = (float32(g) - mean[1]) / scale[1]
-            out[2*width*height+y*width+x] = (float32(r) - mean[0]) / scale[0]
+            out[0*stride+y*width+x] = (float32(r>>8) - mean[0]) / scale[0]
+            out[1*stride+y*width+x] = (float32(g>>8) - mean[1]) / scale[1]
+            out[2*stride+y*width+x] = (float32(b>>8) - mean[2]) / scale[2]
         }
     }

     return out, nil
 }

 func main() {
-    defer tracer.Close()
-
-    baseDir, _ := filepath.Abs("../tmp")
+    defer tracer.Close()
+
     dir := filepath.Join(baseDir, model)
     graph := filepath.Join(dir, model+".prototxt")
     weights := filepath.Join(dir, model+".caffemodel")
     synset := filepath.Join(dir, "synset.txt")

     if !com.IsFile(graph) {
-        if _, _, err := downloadmanager.DownloadFile(graph_url, graph); err != nil {
+        if _, _, err := downloadmanager.DownloadFile(graphURL, graph); err != nil {
             panic(err)
         }
     }
     if !com.IsFile(weights) {
-        if _, _, err := downloadmanager.DownloadFile(weights_url, weights); err != nil {
+        if _, _, err := downloadmanager.DownloadFile(weightsURL, weights); err != nil {
             panic(err)
         }
     }
     if !com.IsFile(synset) {
-        if _, _, err := downloadmanager.DownloadFile(synset_url, synset); err != nil {
+        if _, _, err := downloadmanager.DownloadFile(synsetURL, synset); err != nil {
             panic(err)
         }
     }
@@ -123,6 +123,10 @@ func main() {
         Shape: shape,
         Dtype: gotensor.Float32,
     }
+    out := options.Node{
+        Key:   "prob",
+        Dtype: gotensor.Float32,
+    }

     predictor, err := tensorrt.New(
         ctx,
@@ -132,26 +136,21 @@ func main() {
         options.Weights([]byte(weights)),
         options.BatchSize(batchSize),
         options.InputNodes([]options.Node{in}),
-        options.OutputNodes([]options.Node{
-            options.Node{
-                Key:   "prob",
-                Dtype: gotensor.Float32,
-            },
-        }),
+        options.OutputNodes([]options.Node{out}),
     )
     if err != nil {
         panic(fmt.Sprintf("%v", err))
     }
     defer predictor.Close()

-    for ii:=0; ii < 3; ii++ {
-    err = predictor.Predict(ctx, input)
-    if err != nil {
-        panic(err)
+    for ii := 0; ii < 3; ii++ {
+        err = predictor.Predict(ctx, input)
+        if err != nil {
+            panic(err)
+        }
     }
-    }

-    enableCupti := true
+    enableCupti := true
     var cu *cupti.CUPTI
     if enableCupti {
         cu, err = cupti.New(cupti.Context(ctx))
@@ -174,64 +173,49 @@ func main() {
         cu.Close()
     }

-    if true {
-        profBuffer, err := predictor.ReadProfile()
-        if err != nil {
-            panic(err)
-        }
+    profBuffer, err := predictor.ReadProfile()
+    if err != nil {
+        panic(err)
+    }

-        t, err := ctimer.New(profBuffer)
-        if err != nil {
-            panic(err)
-        }
-        t.Publish(ctx, tracer.FRAMEWORK_TRACE)
+    t, err := ctimer.New(profBuffer)
+    if err != nil {
+        panic(err)
+    }
+    t.Publish(ctx, tracer.FRAMEWORK_TRACE)

-        outputs, err := predictor.ReadPredictionOutputs(ctx)
-        if err != nil {
-            panic(err)
-        }
+    outputs, err := predictor.ReadPredictionOutputs(ctx)
+    if err != nil {
+        panic(err)
+    }

-        output := outputs[0]
+    output := outputs[0]
+    labelsFileContent, err := ioutil.ReadFile(synset)
+    if err != nil {
+        panic(err)
+    }
+    labels := strings.Split(string(labelsFileContent), "\n")

-        var labels []string
-        f, err := os.Open(synset)
-        if err != nil {
-            panic(err)
-        }
-        defer f.Close()
-        scanner := bufio.NewScanner(f)
-        for scanner.Scan() {
-            line := scanner.Text()
-            labels = append(labels, line)
-        }
+    features := make([]dlframework.Features, batchSize)
+    featuresLen := len(output) / batchSize

-        features := make([]dlframework.Features, batchSize)
-        featuresLen := len(output) / batchSize
-
-        for ii := 0; ii < batchSize; ii++ {
-            rprobs := make([]*dlframework.Feature, featuresLen)
-            for jj := 0; jj < featuresLen; jj++ {
-                rprobs[jj] = feature.New(
-                    feature.ClassificationIndex(int32(jj)),
-                    feature.ClassificationLabel(labels[jj]),
-                    feature.Probability(output[ii*featuresLen+jj]),
-                )
-            }
-            sort.Sort(dlframework.Features(rprobs))
-            features[ii] = rprobs
+    for ii := 0; ii < batchSize; ii++ {
+        rprobs := make([]*dlframework.Feature, featuresLen)
+        for jj := 0; jj < featuresLen; jj++ {
+            rprobs[jj] = feature.New(
+                feature.ClassificationIndex(int32(jj)),
+                feature.ClassificationLabel(labels[jj]),
+                feature.Probability(output[ii*featuresLen+jj]),
+            )
         }
+        sort.Sort(dlframework.Features(rprobs))
+        features[ii] = rprobs
+    }

-        if true {
-            results := features[0]
-            for i := 0; i < 3; i++ {
-                prediction := results[i]
-                pp.Println(prediction.Probability)
-                pp.Println(prediction.GetClassification().GetIndex())
-                pp.Println(prediction.GetClassification().GetLabel())
-            }
-        } else {
-            _ = features
-        }
+    results := features[0]
+    for i := 0; i < 3; i++ {
+        prediction := results[i]
+        pp.Println(prediction.Probability, prediction.GetClassification().GetIndex(), prediction.GetClassification().GetLabel())
     }
 }
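The new hunk also makes the intent of cvtRGBImageToNCHW1DArray clearer: src.At(...).RGBA() returns 16-bit channel values, so each channel is shifted right by 8 to get the usual 0-255 range before mean subtraction, and the planes are now written in R, G, B order instead of B, G, R. Below is a minimal sketch (not part of the commit) of how the fixed routine is typically driven; imgio.Open comes from the bild package already imported above, and the sketch assumes the image already matches the 224x224 network input.

// Sketch: feeding the fixed NCHW conversion routine.
img, err := imgio.Open(imgPath)
if err != nil {
    panic(err)
}
// Assumption: the image is already 224x224; resizing, if needed,
// happens before this point and is not shown here.
input, err := cvtRGBImageToNCHW1DArray(img, mean, scale)
if err != nil {
    panic(err)
}
// input now holds 3*224*224 float32 values laid out as the R plane,
// then the G plane, then the B plane (NCHW with N == 1), ready for
// predictor.Predict(ctx, input).
_ = input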

model_format.go

Lines changed: 21 additions & 6 deletions
@@ -1,15 +1,30 @@
 package tensorrt

+import (
+    "strings"
+)
+
 // ModelFormat ...
 type ModelFormat int

 const (
-    ModelFormatCaffe      ModelFormat = 1
-    ModelFormatOnnx       ModelFormat = 2
-    ModelFormatTensorFlow ModelFormat = 3
-    ModelFormatUnknown    ModelFormat = 999
+    ModelFormatCaffe   ModelFormat = 1
+    ModelFormatOnnx    ModelFormat = 2
+    ModelFormatUff     ModelFormat = 3
+    ModelFormatUnknown ModelFormat = 999
 )

-func ClassifyModelFormat(paths ...string) ModelFormat {
-    return ModelFormatCaffe
+func ClassifyModelFormat(path string) ModelFormat {
+    var format ModelFormat
+    if strings.HasSuffix(path, "prototxt") {
+        format = ModelFormatCaffe
+    } else if strings.HasSuffix(path, "onnx") {
+        format = ModelFormatOnnx
+    } else if strings.HasSuffix(path, "uff") {
+        format = ModelFormatUff
+    } else {
+        format = ModelFormatUnknown
+    }
+
+    return format
 }