Skip to content

Commit 8c6644c

Browse files
committed
handle corner case for factors during quantization
1 parent fdeb3a6 commit 8c6644c

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

rabitqlib/quantization/rabitq_impl.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ inline void one_bit_code_with_factor(
100100
// dot product between centroid and xu_cb
101101
T ip_cent_xucb = dot_product<T>(centroid, xu_cb.data(), dim);
102102

103+
// corner case
104+
if (ip_resi_xucb == 0) {
105+
ip_resi_xucb = std::numeric_limits<T>::infinity();
106+
}
107+
103108
// We use unnormalized vector to get error factor. To be more specific,
104109
// sqrt((1 - <o, o_bar>^2) / <o, o_bar>^2) / sqrt(dim - 1) = 3rd item in following
105110
// expression
@@ -465,6 +470,11 @@ inline void ex_bits_code_with_factor(
465470
T ip_resi_xucb = dot_product<T>(residual_arr.data(), xu_cb.data(), dim);
466471
T ip_cent_xucb = dot_product<T>(centroid, xu_cb.data(), dim);
467472

473+
// corner case
474+
if (ip_resi_xucb == 0) {
475+
ip_resi_xucb = std::numeric_limits<T>::infinity();
476+
}
477+
468478
T tmp_error =
469479
l2_norm * kConstEpsilon *
470480
std::sqrt(
@@ -556,7 +566,8 @@ static inline void rabitq_scalar_impl(
556566

557567
float norm_data = std::sqrt(l2norm_sqr(residual_arr.data(), dim));
558568
float norm_quan = std::sqrt(l2norm_sqr(u_cb.data(), dim));
559-
float cos_similarity = dot_product<T>(residual_arr.data(), u_cb.data(), dim) / (norm_data * norm_quan);
569+
float cos_similarity =
570+
dot_product<T>(residual_arr.data(), u_cb.data(), dim) / (norm_data * norm_quan);
560571

561572
if (scalar_quantizer_type == ScalarQuantizerType::RECONSTRUCTION) {
562573
delta = norm_data / norm_quan * cos_similarity;

0 commit comments

Comments
 (0)