1919#define MAX_ENTRIES 4096
2020#define MAX_STAGES 5
2121
22- void quant_pred_mbest (float vec_out [],
23- int indexes [],
24- float vec_in [],
25- int num_stages ,
26- float vq [],
27- int m [], int k ,
28- int mbest_survivors , int st , int en );
22+ void quant_mbest (float vec_out [],
23+ int indexes [],
24+ float vec_in [],
25+ int num_stages ,
26+ float vqw [], float vq [],
27+ int m [], int k ,
28+ int mbest_survivors );
2929
3030int verbose = 0 ;
3131
3232int main (int argc , char * argv []) {
3333 float vq [MAX_STAGES * MAX_K * MAX_ENTRIES ];
34+ float vqw [MAX_STAGES * MAX_K * MAX_ENTRIES ];
3435 int m [MAX_STAGES ];
3536 int k = 0 , mbest_survivors = 1 , num_stages = 0 ;
3637 char fnames [256 ], fn [256 ], * comma , * p ;
@@ -143,11 +144,27 @@ int main(int argc, char *argv[]) {
143144 if (st == -1 ) st = 0 ;
144145 if (en == -1 ) en = k - 1 ;
145146
146- int indexes [num_stages ], nvecs = 0 ; int vec_usage [m [0 ]];
147+ float w [k ];
148+ for (int i = 0 ; i < st ; i ++ )
149+ w [i ] = 0.0 ;
150+ for (int i = st ; i <=en ; i ++ )
151+ w [i ] = 1.0 ;
152+ for (int i = en + 1 ; i < k ; i ++ )
153+ w [i ] = 0.0 ;
154+
155+ /* apply weighting to codebook (rather than in search) */
156+ memcpy (vqw , vq , sizeof (vq ));
157+ for (int s = 0 ; s < num_stages ; s ++ ) {
158+ mbest_precompute_weight (& vqw [s * k * MAX_ENTRIES ], w , k , m [s ]);
159+ }
160+
161+ int indexes [num_stages ], nvecs = 0 ; int vec_usage [m [0 ]];
147162 for (int i = 0 ; i < m [0 ]; i ++ ) vec_usage [i ] = 0 ;
148163 float target [k ], quantised [k ];
149164 float sqe = 0.0 ;
150165 while (fread (& target , sizeof (float ), k , stdin ) && (nvecs < num )) {
166+ for (int i = 0 ; i < k ; i ++ )
167+ target [i ] *= w [i ];
151168 int dont_count = 0 ;
152169 /* optional clamping to lower limit or mean */
153170 float mean = 0.0 ;
@@ -161,7 +178,7 @@ int main(int argc, char *argv[]) {
161178 target [i ] += - difference ;
162179 dont_count = 1 ;
163180 }
164- quant_pred_mbest (quantised , indexes , target , num_stages , vq , m , k , mbest_survivors , st , en );
181+ quant_mbest (quantised , indexes , target , num_stages , vqw , vq , m , k , mbest_survivors );
165182 if (dont_count == 0 ) {
166183 for (int i = st ; i <=en ; i ++ )
167184 sqe += pow (target [i ]- quantised [i ], 2.0 );
@@ -172,7 +189,7 @@ int main(int argc, char *argv[]) {
172189 vec_usage [indexes [0 ]]++ ;
173190 }
174191
175- fprintf (stderr , "%4.2f\n" , sqe /(nvecs * (en - st + 1 )));
192+ fprintf (stderr , "MSE: %4.2f\n" , sqe /(nvecs * (en - st + 1 )));
176193
177194 if (output_vec_usage ) {
178195 for (int i = 0 ; i < m [0 ]; i ++ )
@@ -196,15 +213,15 @@ void pv(char s[], float v[], int k) {
196213
197214// mbest algorithm version, backported from LPCNet/src
198215
199- void quant_pred_mbest (float vec_out [],
200- int indexes [],
201- float vec_in [],
202- int num_stages ,
203- float vq [],
204- int m [], int k ,
205- int mbest_survivors , int st , int en )
216+ void quant_mbest (float vec_out [],
217+ int indexes [],
218+ float vec_in [],
219+ int num_stages ,
220+ float vqw [], float vq [],
221+ int m [], int k ,
222+ int mbest_survivors )
206223{
207- float err [k ], w [ k ], se1 ;
224+ float err [k ], se1 ;
208225 int i ,j ,s ,s1 ,ind ;
209226
210227 struct MBEST * mbest_stage [num_stages ];
@@ -216,24 +233,17 @@ void quant_pred_mbest(float vec_out[],
216233 index [i ] = 0 ;
217234 }
218235
219- for (i = 0 ; i < st ; i ++ )
220- w [i ] = 0.0 ;
221- for (i = st ; i <=en ; i ++ )
222- w [i ] = 1.0 ;
223- for (i = en + 1 ; i < k ; i ++ )
224- w [i ] = 0.0 ;
225-
226236 se1 = 0.0 ;
227237 for (i = 0 ; i < k ; i ++ ) {
228238 err [i ] = vec_in [i ];
229- se1 += err [i ]* err [i ]* w [ i ] * w [ i ] ;
239+ se1 += err [i ]* err [i ];
230240 }
231241 se1 /= k ;
232242
233243 /* now quantise err[] using multi-stage mbest search, preserving
234244 mbest_survivors at each stage */
235245
236- mbest_search (vq , err , w , k , m [0 ], mbest_stage [0 ], index );
246+ mbest_search (vqw , err , k , m [0 ], mbest_stage [0 ], index );
237247 if (verbose ) mbest_print ("Stage 1:" , mbest_stage [0 ]);
238248
239249 for (s = 1 ; s < num_stages ; s ++ ) {
@@ -251,11 +261,11 @@ void quant_pred_mbest(float vec_out[],
251261 ind = index [s - s1 ];
252262 if (verbose ) fprintf (stderr , " s: %d s1: %d s-s1: %d ind: %d\n" , s ,s1 ,s - s1 ,ind );
253263 for (i = 0 ; i < k ; i ++ ) {
254- target [i ] -= vq [s1 * k * MAX_ENTRIES + ind * k + i ];
264+ target [i ] -= vqw [s1 * k * MAX_ENTRIES + ind * k + i ];
255265 }
256266 }
257267 pv (" target: " , target , k );
258- mbest_search (& vq [s * k * MAX_ENTRIES ], target , w , k , m [s ], mbest_stage [s ], index );
268+ mbest_search (& vqw [s * k * MAX_ENTRIES ], target , k , m [s ], mbest_stage [s ], index );
259269 }
260270 char str [80 ]; sprintf (str ,"Stage %d:" , s + 1 );
261271 if (verbose ) mbest_print (str , mbest_stage [s ]);
@@ -272,9 +282,9 @@ void quant_pred_mbest(float vec_out[],
272282 int ind = indexes [s ];
273283 float se2 = 0.0 ;
274284 for (i = 0 ; i < k ; i ++ ) {
275- err [i ] -= vq [s * k * MAX_ENTRIES + ind * k + i ];
285+ err [i ] -= vqw [s * k * MAX_ENTRIES + ind * k + i ];
276286 vec_out [i ] += vq [s * k * MAX_ENTRIES + ind * k + i ];
277- se2 += err [i ]* err [i ]* w [ i ] * w [ i ] ;
287+ se2 += err [i ]* err [i ];
278288 }
279289 se2 /= k ;
280290 pv (" err: " , err , k );
0 commit comments