|
| 1 | +/* |
| 2 | + vq_binary_switch.c |
| 3 | + David Rowe Dec 2021 |
| 4 | +
|
| 5 | + C implementation of [1], that re-arranges VQ indexes so they are robust to single |
| 6 | + bit errors. |
| 7 | +
|
| 8 | + [1] Pseudo Gray Coding, Zeger & Gersho 1990 |
| 9 | +*/ |
| 10 | + |
| 11 | +#include <assert.h> |
| 12 | +#include <getopt.h> |
| 13 | +#include <math.h> |
| 14 | +#include <stdlib.h> |
| 15 | +#include <stdio.h> |
| 16 | +#include <string.h> |
| 17 | +#include <limits.h> |
| 18 | +#include "mbest.h" |
| 19 | + |
| 20 | +#define MAX_DIM 20 |
| 21 | +#define MAX_ENTRIES 4096 |
| 22 | + |
// Equation (33) of [1]: total cost of all Hamming distance 1 neighbours of VQ
// index k.
//
//   vq       codebook, n rows of dim floats stored row by row
//   n        number of codebook entries; floor(log2(n)) index bits are flipped
//   dim      number of floats in each codebook row
//   prob     per-entry weights; prob[k] scales every neighbour distance
//   k        index whose neighbours are evaluated
//   st, en   first and last (inclusive) vector element used in the distance
//   verbose  non-zero prints one diagnostic line per neighbour to stdout
//
// Returns the sum over each bit b of prob[k] * squared distance between row k
// and row (k ^ (1<<b)).  An index outside [0, n) - k itself, or a neighbour
// when n is not a power of two - has no codebook row and contributes nothing,
// so the function never reads past the n rows it was given.
float cost_of_distance_one(float *vq, int n, int dim, float *prob, int k, int st, int en, int verbose) {
    float c = 0.0;

    /* no row or probability exists for k outside the codebook (covers n <= 0) */
    if (k < 0 || k >= n)
        return c;

    /* floor(log2(n)) in integer arithmetic; avoids float truncation issues */
    int log2N = 0;
    while ((n >> (log2N + 1)) > 0)
        log2N++;

    for (int b = 0; b < log2N; b++) {
        int index_neighbour = k ^ (1 << b);

        /* only possible when n is not a power of two */
        if (index_neighbour >= n)
            continue;

        float dist = 0.0;
        for (int i = st; i <= en; i++) {
            /* float difference squared in double, as pow(x, 2.0) would give */
            double diff = vq[k*dim+i] - vq[index_neighbour*dim+i];
            dist += diff * diff;
        }
        c += prob[k]*dist;
        if (verbose)
            printf("k: %d b: %d index_neighbour: %d dist: %f prob: %f c: %f \n", k, b, index_neighbour, dist, prob[k], c);
    }
    return c;
}
| 38 | + |
// Equation (39) of [1]: distortion of the current index assignment, i.e. the
// sum of cost_of_distance_one() over every one of the n codebook indexes.
// st and en select the (inclusive) range of vector elements that is measured.
float distortion_of_current_mapping(float *vq, int n, int dim, float *prob, int st, int en) {
    float total = 0.0;
    int index = 0;

    while (index < n) {
        total += cost_of_distance_one(vq, n, dim, prob, index, st, en, 0);
        index++;
    }

    return total;
}
| 46 | + |
| 47 | +// we sort the cost array c[], returning the indexes of sorted elements |
| 48 | +float c[MAX_ENTRIES]; |
| 49 | + |
| 50 | +/* Note how the compare function compares the values of the |
| 51 | + * array to be sorted. The passed value to this function |
| 52 | + * by `qsort' are actually the `idx' array elements. |
| 53 | + */ |
| 54 | +int compare_increase (const void * a, const void * b) { |
| 55 | + int aa = *((int *) a), bb = *((int *) b); |
| 56 | + if (c[aa] < c[bb]) { |
| 57 | + return 1; |
| 58 | + } else if (c[aa] == c[bb]) { |
| 59 | + return 0; |
| 60 | + } else { |
| 61 | + return -1; |
| 62 | + } |
| 63 | +} |
| 64 | + |
| 65 | +void sort_c(int *idx, const size_t n) { |
| 66 | + for (size_t i=0; i<n; i++) idx[i] = i; |
| 67 | + qsort(idx, n, sizeof(int), compare_increase); |
| 68 | +} |
| 69 | + |
/* Exchange codebook rows index1 and index2 (dim floats each, stored row by
 * row in vq) together with their probabilities prob[index1] and prob[index2].
 *
 * Elements are exchanged one at a time through a scalar temporary, so no
 * variable length array is needed and the probability swap no longer borrows
 * element 0 of a row buffer (which did not exist when dim was 0).  Swapping an
 * index with itself leaves everything unchanged.
 */
void swap(float *vq, int dim, float *prob, int index1, int index2) {
    for (int i = 0; i < dim; i++) {
        float tmp = vq[index1*dim+i];
        vq[index1*dim+i] = vq[index2*dim+i];
        vq[index2*dim+i] = tmp;
    }

    float tmp_prob = prob[index1];
    prob[index1] = prob[index2];
    prob[index2] = tmp_prob;
}
| 80 | + |
| 81 | +int main(int argc, char *argv[]) { |
| 82 | + float vq[MAX_DIM*MAX_ENTRIES]; |
| 83 | + int dim = MAX_DIM; |
| 84 | + int max_iter = INT_MAX; |
| 85 | + int st = -1; |
| 86 | + int en = -1; |
| 87 | + int verbose = 0; |
| 88 | + int n = 0; |
| 89 | + int fast_en = 0; |
| 90 | + char prob_fn[80]=""; |
| 91 | + |
| 92 | + int o = 0; int opt_idx = 0; |
| 93 | + while (o != -1) { |
| 94 | + static struct option long_opts[] = { |
| 95 | + {"prob", required_argument, 0, 'p'}, |
| 96 | + {"st", required_argument, 0, 't'}, |
| 97 | + {"en", required_argument, 0, 'e'}, |
| 98 | + {0, 0, 0, 0} |
| 99 | + }; |
| 100 | + o = getopt_long(argc,argv,"hd:m:vt:e:n:fp:",long_opts,&opt_idx); |
| 101 | + switch (o) { |
| 102 | + case 'd': |
| 103 | + dim = atoi(optarg); |
| 104 | + assert(dim <= MAX_DIM); |
| 105 | + break; |
| 106 | + case 'm': |
| 107 | + max_iter = atoi(optarg); |
| 108 | + break; |
| 109 | + case 't': |
| 110 | + st = atoi(optarg); |
| 111 | + break; |
| 112 | + case 'e': |
| 113 | + en = atoi(optarg); |
| 114 | + break; |
| 115 | + case 'f': |
| 116 | + fast_en = 1; |
| 117 | + break; |
| 118 | + case 'n': |
| 119 | + n = atoi(optarg); |
| 120 | + break; |
| 121 | + case 'p': |
| 122 | + strcpy(prob_fn,optarg); |
| 123 | + break; |
| 124 | + case 'v': |
| 125 | + verbose = 1; |
| 126 | + break; |
| 127 | + help: |
| 128 | + fprintf(stderr, "\n"); |
| 129 | + fprintf(stderr, "usage: %s -d dimension [-m max_iterations -v --st Kst --en Ken -n nVQ] vq_in.f32 vq_out.f32\n", argv[0]); |
| 130 | + fprintf(stderr, "\n"); |
| 131 | + fprintf(stderr, "-n nVQ Run with just the first nVQ entries of the VQ\n"); |
| 132 | + fprintf(stderr, "--st Kst Start vector element for error calculation (default 0)\n"); |
| 133 | + fprintf(stderr, "--en Ken End vector element for error calculation (default K-1)\n"); |
| 134 | + fprintf(stderr, "--prob probFile f32 file of probabilities for each VQ element (default 1.0)\n"); |
| 135 | + fprintf(stderr, "-v verbose\n"); |
| 136 | + exit(1); |
| 137 | + } |
| 138 | + } |
| 139 | + |
| 140 | + int dx = optind; |
| 141 | + if ((argc - dx) < 2) { |
| 142 | + fprintf(stderr, "Too few arguments\n"); |
| 143 | + goto help; |
| 144 | + } |
| 145 | + if (dim == 0) goto help; |
| 146 | + |
| 147 | + /* default to measuring error on entire vector */ |
| 148 | + if (st == -1) st = 0; |
| 149 | + if (en == -1) en = dim-1; |
| 150 | + |
| 151 | + /* load VQ quantiser file --------------------*/ |
| 152 | + |
| 153 | + fprintf(stderr, "loading %s ... ", argv[dx]); |
| 154 | + FILE *fq=fopen(argv[dx], "rb"); |
| 155 | + if (fq == NULL) { |
| 156 | + fprintf(stderr, "Couldn't open: %s\n", argv[dx]); |
| 157 | + exit(1); |
| 158 | + } |
| 159 | + |
| 160 | + if (n==0) { |
| 161 | + /* count how many entries m of dimension k are in this VQ file */ |
| 162 | + float dummy[dim]; |
| 163 | + while (fread(dummy, sizeof(float), dim, fq) == (size_t)dim) |
| 164 | + n++; |
| 165 | + assert(n <= MAX_ENTRIES); |
| 166 | + fprintf(stderr, "%d entries of vectors width %d\n", n, dim); |
| 167 | + |
| 168 | + rewind(fq); |
| 169 | + } |
| 170 | + |
| 171 | + /* load VQ into memory */ |
| 172 | + int nrd = fread(vq, sizeof(float), n*dim, fq); |
| 173 | + assert(nrd == n*dim); |
| 174 | + fclose(fq); |
| 175 | + |
| 176 | + /* set probability of each vector to 1.0 as default */ |
| 177 | + float prob[n]; |
| 178 | + for(int l=0; l<n; l++) prob[l] = 1.0; |
| 179 | + if (strlen(prob_fn)) { |
| 180 | + fprintf(stderr, "Reading probability file: %s\n", prob_fn); |
| 181 | + FILE *fp = fopen(prob_fn,"rb"); |
| 182 | + assert(fp != NULL); |
| 183 | + int nrd = fread(prob, sizeof(float), n, fp); |
| 184 | + assert(nrd == n); |
| 185 | + fclose(fp); |
| 186 | + float sum = 0.0; |
| 187 | + for(int l=0; l<n; l++) sum += prob[l]; |
| 188 | + fprintf(stderr, "sum = %f\n", sum); |
| 189 | + } |
| 190 | + |
| 191 | + int iteration = 0; |
| 192 | + int i = 0; |
| 193 | + int finished = 0; |
| 194 | + int switches = 0; |
| 195 | + int log2N = log2(n); |
| 196 | + float distortion0 = distortion_of_current_mapping(vq, n, dim, prob, st, en); |
| 197 | + fprintf(stderr, "distortion0: %f\n", distortion0); |
| 198 | + |
| 199 | + while(!finished) { |
| 200 | + |
| 201 | + // generate a list A(i) of which vectors have the largest cost of bit errors |
| 202 | + for(int k=0; k<n; k++) { |
| 203 | + c[k] = cost_of_distance_one(vq, n, dim, prob, k, st, en, verbose); |
| 204 | + } |
| 205 | + int A[n]; |
| 206 | + sort_c(A, n); |
| 207 | + |
| 208 | + // Try switching each vector with A(i) |
| 209 | + float best_delta = 0; int best_j = 0; |
| 210 | + for(int j=1; j<n; j++) { |
| 211 | + float distortion1, distortion2, delta = 0.0; |
| 212 | + |
| 213 | + // we can't switch with ourself |
| 214 | + if (j != A[i]) { |
| 215 | + if (fast_en) { |
| 216 | + // subtract just those contributions to delta that will change |
| 217 | + delta -= cost_of_distance_one(vq, n, dim, prob, A[i], st, en, verbose); |
| 218 | + delta -= cost_of_distance_one(vq, n, dim, prob, j, st, en, verbose); |
| 219 | + for (int b=0; b<log2N; b++) { |
| 220 | + unsigned int index_neighbour; |
| 221 | + index_neighbour = A[i] ^ (1<<b); |
| 222 | + if ((index_neighbour != j) && (index_neighbour != A[i])) |
| 223 | + delta -= cost_of_distance_one(vq, n, dim, prob, index_neighbour, st, en, verbose); |
| 224 | + index_neighbour = j ^ (1<<b); |
| 225 | + if ((index_neighbour != j) && (index_neighbour != A[i])) |
| 226 | + delta -= cost_of_distance_one(vq, n, dim, prob, index_neighbour, st, en, verbose); |
| 227 | + } |
| 228 | + } |
| 229 | + else |
| 230 | + distortion1 = distortion_of_current_mapping(vq, n, dim, prob, st, en); |
| 231 | + |
| 232 | + // switch vq entries A(i) and j |
| 233 | + swap(vq, dim, prob, A[i], j); |
| 234 | + |
| 235 | + if (fast_en) { |
| 236 | + // add just those contributions to delta that will change |
| 237 | + delta += cost_of_distance_one(vq, n, dim, prob, A[i], st, en, verbose); |
| 238 | + delta += cost_of_distance_one(vq, n, dim, prob, j, st, en, verbose); |
| 239 | + for (int b=0; b<log2N; b++) { |
| 240 | + unsigned int index_neighbour; |
| 241 | + index_neighbour = A[i] ^ (1<<b); |
| 242 | + if ((index_neighbour != j) && (index_neighbour != A[i])) |
| 243 | + delta += cost_of_distance_one(vq, n, dim, prob, index_neighbour, st, en, verbose); |
| 244 | + index_neighbour = j ^ (1<<b); |
| 245 | + if ((index_neighbour != j) && (index_neighbour != A[i])) |
| 246 | + delta += cost_of_distance_one(vq, n, dim, prob, index_neighbour, st, en, verbose); |
| 247 | + } |
| 248 | + } |
| 249 | + else { |
| 250 | + distortion2 = distortion_of_current_mapping(vq, n, dim, prob, st, en); |
| 251 | + delta = distortion2 - distortion1; |
| 252 | + } |
| 253 | + |
| 254 | + if (delta < 0.0) { |
| 255 | + if (fabs(delta) > best_delta) { |
| 256 | + best_delta = fabs(delta); |
| 257 | + best_j = j; |
| 258 | + } |
| 259 | + } |
| 260 | + // unswitch |
| 261 | + swap(vq, dim, prob, A[i], j); |
| 262 | + } |
| 263 | + } //next j |
| 264 | + |
| 265 | + // printf("best_delta: %f best_j: %d\n", best_delta, best_j); |
| 266 | + if (best_delta == 0.0) { |
| 267 | + // Hmm, no improvement, lets try the next vector in the sorted cost list |
| 268 | + if (i == n-1) finished = 1; else i++; |
| 269 | + } else { |
| 270 | + // OK keep the switch that minimised the distortion |
| 271 | + swap(vq, dim, prob, A[i], best_j); |
| 272 | + switches++; |
| 273 | + |
| 274 | + // save results |
| 275 | + FILE *fq=fopen(argv[dx+1], "wb"); |
| 276 | + if (fq == NULL) { |
| 277 | + fprintf(stderr, "Couldn't open: %s\n", argv[dx+1]); |
| 278 | + exit(1); |
| 279 | + } |
| 280 | + int nwr = fwrite(vq, sizeof(float), n*dim, fq); |
| 281 | + assert(nwr == n*dim); |
| 282 | + fclose(fq); |
| 283 | + |
| 284 | + // set up for next iteration |
| 285 | + iteration++; |
| 286 | + float distortion = distortion_of_current_mapping(vq, n, dim, prob, st, en); |
| 287 | + fprintf(stderr, "it: %3d dist: %f %3.2f i: %3d sw: %3d\n", iteration, distortion, |
| 288 | + distortion/distortion0, i, switches); |
| 289 | + if (iteration >= max_iter) finished = 1; |
| 290 | + i = 0; |
| 291 | + } |
| 292 | + } |
| 293 | + |
| 294 | + return 0; |
| 295 | +} |
| 296 | + |
0 commit comments