@@ -100,6 +100,11 @@ inline void one_bit_code_with_factor(
100100 // dot product between centroid and xu_cb
101101 T ip_cent_xucb = dot_product<T>(centroid, xu_cb.data (), dim);
102102
103+ // corner case
104+ if (ip_resi_xucb == 0 ) {
105+ ip_resi_xucb = std::numeric_limits<T>::infinity ();
106+ }
107+
103108 // We use unnormalized vector to get error factor. To be more specific,
104109 // sqrt((1 - <o, o_bar>^2) / <o, o_bar>^2) / sqrt(dim - 1) = 3rd item in following
105110 // expression
@@ -465,6 +470,11 @@ inline void ex_bits_code_with_factor(
465470 T ip_resi_xucb = dot_product<T>(residual_arr.data (), xu_cb.data (), dim);
466471 T ip_cent_xucb = dot_product<T>(centroid, xu_cb.data (), dim);
467472
473+ // corner case
474+ if (ip_resi_xucb == 0 ) {
475+ ip_resi_xucb = std::numeric_limits<T>::infinity ();
476+ }
477+
468478 T tmp_error =
469479 l2_norm * kConstEpsilon *
470480 std::sqrt (
@@ -556,7 +566,8 @@ static inline void rabitq_scalar_impl(
556566
557567 float norm_data = std::sqrt (l2norm_sqr (residual_arr.data (), dim));
558568 float norm_quan = std::sqrt (l2norm_sqr (u_cb.data (), dim));
559- float cos_similarity = dot_product<T>(residual_arr.data (), u_cb.data (), dim) / (norm_data * norm_quan);
569+ float cos_similarity =
570+ dot_product<T>(residual_arr.data (), u_cb.data (), dim) / (norm_data * norm_quan);
560571
561572 if (scalar_quantizer_type == ScalarQuantizerType::RECONSTRUCTION) {
562573 delta = norm_data / norm_quan * cos_similarity;
0 commit comments