Skip to content

Commit e30fc43

Browse files
committed
feat(core): support modelConfig interface modelObj to enable use local model obj
1 parent 2de71ce commit e30fc43

File tree

4 files changed

+25
-11
lines changed

4 files changed

+25
-11
lines changed

packages/paddlejs-core/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@paddlejs/paddlejs-core",
3-
"version": "2.1.7",
3+
"version": "2.1.9",
44
"description": "",
55
"main": "lib/index",
66
"scripts": {

packages/paddlejs-core/src/commons/interface.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,14 @@ export interface FeedShape {
5959
fw: number;
6060
fh: number;
6161
};
62+
63+
interface ModelObj {
64+
model: Model;
65+
params: Float32Array
66+
}
6267
export interface RunnerConfig {
63-
modelPath: string;
68+
modelPath?: string;
69+
modelObj?: ModelObj;
6470
modelName?: string;
6571
feedShape?: FeedShape;
6672
fill?: string; // 缩放后用什么颜色填充不足方形部分

packages/paddlejs-core/src/loader.ts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ export default class ModelLoader {
7070
if (this.separateChunk) {
7171
if (this.dataType === 'binary') {
7272
await this.fetchChunks().then(allChunksData =>
73-
this.traverse(modelInfo.vars, allChunksData)
73+
ModelLoader.allocateParamsVar(modelInfo.vars, allChunksData)
7474
);
7575
}
7676
}
@@ -105,9 +105,7 @@ export default class ModelLoader {
105105
this.fetchOneChunk(this.urlConf.dir + this.getFileName(i))
106106
);
107107
}
108-
// console.time('加载时间');
109108
return Promise.all(chunkArray).then(chunks => {
110-
// console.timeEnd('加载时间');
111109
let chunksLength = 0;
112110
const f32Array: any[] = [];
113111
let float32Chunk;
@@ -128,7 +126,7 @@ export default class ModelLoader {
128126
});
129127
}
130128

131-
traverse(vars, allChunksData: Float32Array) {
129+
static allocateParamsVar(vars, allChunksData: Float32Array) {
132130
let marker = 0; // 读到哪个位置了
133131
let len; // 当前op长度
134132
traverseVars(vars, item => {

packages/paddlejs-core/src/runner.ts

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,19 @@ export default class Runner {
7474
}
7575

7676
async load() {
77-
const { modelPath } = this.runnerConfig;
78-
const loader = new Loader(modelPath);
79-
this.model = await loader.load();
77+
const { modelPath, modelObj = null } = this.runnerConfig;
78+
if (modelPath) {
79+
const loader = new Loader(modelPath);
80+
this.model = await loader.load();
81+
}
82+
else if (modelObj?.model && modelObj?.params) {
83+
const {
84+
model,
85+
params
86+
} = modelObj;
87+
Loader.allocateParamsVar(model.vars, params);
88+
this.model = model;
89+
}
8090
}
8191

8292
genGraph() {
@@ -373,8 +383,8 @@ export default class Runner {
373383

374384
if (env.get('debug')
375385
&& op.opData?.outputTensors
376-
&& op.opData.outputTensors[0]
377-
&& op.opData.outputTensors[0].tensorId === this.modelName + '_'
386+
&& op.opData.outputTensors[op.opData.outputTensors.length - 1]
387+
&& op.opData.outputTensors[op.opData.outputTensors.length - 1].tensorId === this.modelName + '_'
378388
+ (env.get('ns').layerName || env.get('layerName'))) {
379389
console.info(op.opData.name + '_' + op.opData.iLayer, 'runner op');
380390
return;

0 commit comments

Comments
 (0)