Skip to content

Commit c0c93dc

Browse files
committed
WIP efficient VQ search
1 parent e8748cb commit c0c93dc

File tree

4 files changed

+61
-22
lines changed

4 files changed

+61
-22
lines changed

misc/vq_mbest.c

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,19 @@
1919
#define MAX_ENTRIES 4096
2020
#define MAX_STAGES 5
2121

22-
void quant_pred_mbest(float vec_out[],
23-
int indexes[],
24-
float vec_in[],
25-
int num_stages,
26-
float vq[],
27-
int m[], int k,
28-
int mbest_survivors, int st, int en);
22+
void quant_mbest(float vec_out[],
23+
int indexes[],
24+
float vec_in[],
25+
int num_stages,
26+
float vq[], float vqsq[],
27+
int m[], int k,
28+
int mbest_survivors, int st, int en);
2929

3030
int verbose = 0;
3131

3232
int main(int argc, char *argv[]) {
3333
float vq[MAX_STAGES*MAX_K*MAX_ENTRIES];
34+
float vqsq[MAX_STAGES*MAX_ENTRIES];
3435
int m[MAX_STAGES];
3536
int k=0, mbest_survivors=1, num_stages=0;
3637
char fnames[256], fn[256], *comma, *p;
@@ -143,6 +144,10 @@ int main(int argc, char *argv[]) {
143144
if (st == -1) st = 0;
144145
if (en == -1) en = k-1;
145146

147+
/* precompute vqsq table for efficient search */
148+
for(int s=0; s<num_stages; s++)
149+
mbest_precompute_cbsq(&vqsq[s*MAX_ENTRIES], &vq[s*k*MAX_ENTRIES], k, m[s]);
150+
146151
int indexes[num_stages], nvecs = 0; int vec_usage[m[0]];
147152
for(int i=0; i<m[0]; i++) vec_usage[i] = 0;
148153
float target[k], quantised[k];
@@ -161,7 +166,7 @@ int main(int argc, char *argv[]) {
161166
target[i] += -difference;
162167
dont_count = 1;
163168
}
164-
quant_pred_mbest(quantised, indexes, target, num_stages, vq, m, k, mbest_survivors, st, en);
169+
quant_mbest(quantised, indexes, target, num_stages, vq, vqsq, m, k, mbest_survivors, st, en);
165170
if (dont_count == 0) {
166171
for(int i=st; i<=en; i++)
167172
sqe += pow(target[i]-quantised[i], 2.0);
@@ -196,13 +201,14 @@ void pv(char s[], float v[], int k) {
196201

197202
// mbest algorithm version, backported from LPCNet/src
198203

199-
void quant_pred_mbest(float vec_out[],
200-
int indexes[],
201-
float vec_in[],
202-
int num_stages,
203-
float vq[],
204-
int m[], int k,
205-
int mbest_survivors, int st, int en)
204+
void quant_mbest(float vec_out[],
205+
int indexes[],
206+
float vec_in[],
207+
int num_stages,
208+
float vq[],
209+
float vqsq[],
210+
int m[], int k,
211+
int mbest_survivors, int st, int en)
206212
{
207213
float err[k], w[k], se1;
208214
int i,j,s,s1,ind;
@@ -233,7 +239,7 @@ void quant_pred_mbest(float vec_out[],
233239
/* now quantise err[] using multi-stage mbest search, preserving
234240
mbest_survivors at each stage */
235241

236-
mbest_search(vq, err, w, k, m[0], mbest_stage[0], index);
242+
mbest_search(vq, vqsq, err, w, k, m[0], mbest_stage[0], index);
237243
if (verbose) mbest_print("Stage 1:", mbest_stage[0]);
238244

239245
for(s=1; s<num_stages; s++) {
@@ -255,7 +261,7 @@ void quant_pred_mbest(float vec_out[],
255261
}
256262
}
257263
pv(" target: ", target, k);
258-
mbest_search(&vq[s*k*MAX_ENTRIES], target, w, k, m[s], mbest_stage[s], index);
264+
mbest_search(&vq[s*k*MAX_ENTRIES], &vqsq[s*MAX_ENTRIES], target, w, k, m[s], mbest_stage[s], index);
259265
}
260266
char str[80]; sprintf(str,"Stage %d:", s+1);
261267
if (verbose) mbest_print(str, mbest_stage[s]);

src/mbest.c

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@ void mbest_destroy(struct MBEST *mbest) {
6464
}
6565

6666

67+
/* precompyte table for efficient VQ search */
68+
69+
void mbest_precompute_cbsq(float cbsq[], float cb[], int k, int m) {
70+
for (int j=0; j<m; j++) {
71+
cbsq[j] = 0.0;
72+
for(int i=0; i<k; i++)
73+
cbsq[j] += cb[j*k+i]*cb[j*k+i];
74+
}
75+
}
76+
6777
/*---------------------------------------------------------------------------*\
6878
6979
mbest_insert
@@ -113,6 +123,7 @@ void mbest_print(char title[], struct MBEST *mbest) {
113123

114124
void mbest_search(
115125
const float *cb, /* VQ codebook to search */
126+
const float *cbsq, /* sum sq of each VQ entry */
116127
float vec[], /* target vector */
117128
float w[], /* weighting vector */
118129
int k, /* dimension of vector */
@@ -128,12 +139,23 @@ void mbest_search(
128139
float diff;
129140
int i;
130141

142+
/*
143+
float cbsq = 0.0;
144+
for(int i = 0; i < k; i++) {
145+
cbsq += cb[j*k+i]*cb[j*k+i];
146+
}
147+
*/
148+
/*
131149
e = 0.0;
132150
for(i=0; i<k; i++) {
133151
diff = cb[j*k+i]-vec[i];
134152
e += diff*w[i]*diff*w[i];
135153
}
136-
154+
*/
155+
float corr = 0.0;
156+
for(i=0; i<k; i++)
157+
corr += cb[j*k+i]*vec[i];
158+
float e = cbsq[j] - 2*corr;
137159
index[0] = j;
138160
if (e < mbest->list[mbest->entries - 1].error)
139161
mbest_insert(mbest, index, e);

src/mbest.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ struct MBEST {
4444

4545
struct MBEST *mbest_create(int entries);
4646
void mbest_destroy(struct MBEST *mbest);
47+
void mbest_precompute_cbsq(float cbsq[], float cb[], int k, int m);
4748
void mbest_insert(struct MBEST *mbest, int index[], float error);
48-
void mbest_search(const float *cb, float vec[], float w[], int k, int m, struct MBEST *mbest, int index[]);
49+
void mbest_search(const float *cb, const float *cbsq, float vec[], float w[], int k, int m, struct MBEST *mbest, int index[]);
4950
void mbest_search_equalweight(const float *cb, float vec[], int k, int m, struct MBEST *mbest, int index[]);
5051
void mbest_search450(const float *cb, float vec[], float w[], int k,int shorterK, int m, struct MBEST *mbest, int index[]);
5152

src/newamp1.c

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,14 +166,24 @@ float rate_K_mbest_encode(int *indexes, float *x, float *xq, int ndim, int mbest
166166
const float *codebook2 = newamp1vq_cb[1].cb;
167167
struct MBEST *mbest_stage1, *mbest_stage2;
168168
float target[ndim];
169+
float w[ndim];
169170
int index[MBEST_STAGES];
170171
float mse, tmp;
172+
float codebook1sq[newamp1vq_cb[0].m];
173+
float codebook2sq[newamp1vq_cb[1].m];
171174

175+
/* precompute tables for efficient search */
176+
mbest_precompute_cbsq(codebook1sq, codebook1, newamp1vq_cb[0].k, newamp1vq_cb[0].m);
177+
mbest_precompute_cbsq(codebook2sq, codebook2, newamp1vq_cb[0].k, newamp1vq_cb[0].m);
178+
172179
/* codebook is compiled for a fixed K */
173180

174181
assert(ndim == newamp1vq_cb[0].k);
175182

176-
/* note: using equal weights, could be argued mel freq axis gives freq dep weighting */
183+
/* equal weights, could be argued mel freq axis gives freq dep weighting */
184+
185+
for(i=0; i<ndim; i++)
186+
w[i] = 1.0;
177187

178188
mbest_stage1 = mbest_create(mbest_entries);
179189
mbest_stage2 = mbest_create(mbest_entries);
@@ -182,15 +192,15 @@ float rate_K_mbest_encode(int *indexes, float *x, float *xq, int ndim, int mbest
182192

183193
/* Stage 1 */
184194

185-
mbest_search_equalweight(codebook1, x, ndim, newamp1vq_cb[0].m, mbest_stage1, index);
195+
mbest_search(codebook1, codebook1sq, x, w, ndim, newamp1vq_cb[0].m, mbest_stage1, index);
186196

187197
/* Stage 2 */
188198

189199
for (j=0; j<mbest_entries; j++) {
190200
index[1] = n1 = mbest_stage1->list[j].index[0];
191201
for(i=0; i<ndim; i++)
192202
target[i] = x[i] - codebook1[ndim*n1+i];
193-
mbest_search_equalweight(codebook2, target, ndim, newamp1vq_cb[1].m, mbest_stage2, index);
203+
mbest_search(codebook2, codebook2sq, target, w, ndim, newamp1vq_cb[1].m, mbest_stage2, index);
194204
}
195205

196206
n1 = mbest_stage2->list[0].index[1];

0 commit comments

Comments
 (0)