@@ -600,23 +600,131 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
         exp_reg_correction = SHFL_SYNC(exp_reg_correction, 0);
     """
     split_weight_update_cpu = """
+        auto offset_idx = momentum1_offsets_data[feature_begin] + idx;
+
+        // Counter update logic with halflife decay
+        at::acc_type<grad_t, true> freq = 1.0;
+        at::acc_type<grad_t, true> tail_id_threshold_val = tail_id_threshold;
+        if (max_counter != 0.0) {
+            if (is_tail_id_thresh_ratio == 1) {
+                tail_id_threshold_val = std::floor(tail_id_threshold * max_counter);
+            }
+
+            if (counter_halflife > 0) {
+                // Decay based on counter_halflife
+                const auto iter_delta = prev_iter_host[offset_idx] == 0 ? 1.0 : iter * 1.0 - prev_iter_host[offset_idx];
+                const auto counter_log_rho = std::log(2.0) / counter_halflife;
+                row_counter_host[offset_idx] = 1.0 + std::exp(-iter_delta * counter_log_rho) * row_counter_host[offset_idx];
+            } else if (counter_halflife == 0) {
+                // Count only 1 (appear or not)
+                row_counter_host[offset_idx] = 1.0;
+            } else {
+                // Count raw appearance without decaying
+                row_counter_host[offset_idx] += 1.0;
+            }
+            freq = counter_halflife / row_counter_host[offset_idx];
+        }
+
+        // Compute gradient statistics
         at::acc_type<grad_t, true> g_local_sum_square = 0.0;
+        at::acc_type<grad_t, true> w_local_sum_square = 0.0;
+
         for (int64_t d = 0; d < D; ++d) {
-            g_local_sum_square += grad_buffer[d] * grad_buffer[d];
+            auto grad = grad_buffer[d];
+            // For L2 regularization (weight_decay_mode=1), add weight_decay to gradient before other computation
+            if (weight_decay_mode == 1) {
+                grad += weight_decay * host_weights_data[embedding_begin + d];
+            }
+            g_local_sum_square += grad * grad;
+
+            // COW-clip (regularization_mode=4) requires weight norm
+            if (regularization_mode == 4) {
+                const auto weight = host_weights_data[embedding_begin + d];
+                w_local_sum_square += weight * weight;
+            }
         }
-        auto g_avg_square = g_local_sum_square / D;
-        auto offset_idx = momentum1_offsets_data[feature_begin] + idx;
+
+        const auto g_sum_square = g_local_sum_square;
+        const auto g_avg_square = g_sum_square / D;
+        const auto w_sum_square = w_local_sum_square;
+
+        // Update momentum
         at::acc_type<grad_t, true> new_sum_square_grads = momentum1_host[offset_idx] + g_avg_square;
         momentum1_host[offset_idx] = new_sum_square_grads;
-        at::acc_type<grad_t, true> multiplier;
-        multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
-        const auto iter_delta = iter * 1.0 - prev_iter_host[offset_idx];
+        const auto multiplier = learning_rate / (std::sqrt(new_sum_square_grads) + eps);
+        const auto adjustment_enabled = adjustment_iter <= 0 || (adjustment_iter > 0 && iter > adjustment_iter);
+
+        // Compute adjusted multiplier and regularization correction
+        at::acc_type<grad_t, true> adjusted_multiplier = 0.0;
+        at::acc_type<grad_t, true> exp_reg_correction = 0.0;
+
+        if (regularization_mode == 3) {
+            // Counter-based regularization (regularization_mode=3)
+            adjusted_multiplier = multiplier;
+            if (learning_rate_mode >= 0 && adjustment_enabled) {
+                if (row_counter_host[offset_idx] > tail_id_threshold_val) {
+                    if (learning_rate_mode == 0) {
+                        adjusted_multiplier = multiplier * std::max(std::min(std::pow(max_counter / (row_counter_host[offset_idx] + 1.0), adjustment_ub), 10.0), 1.0);
+                    } else if (learning_rate_mode == 1) {
+                        adjusted_multiplier = multiplier * std::min(std::max(std::pow((row_counter_host[offset_idx] + 1.0) / max_counter, adjustment_ub), 0.1), 1.0);
+                    } else if (learning_rate_mode == 2) {
+                        adjusted_multiplier = learning_rate / (std::sqrt(adjustment_ub * row_counter_host[offset_idx]) + eps);
+                    }
+                }
+            }
+        } else if (regularization_mode == 4) {
+            // COW-clip (regularization_mode=4)
+            const auto clip_thresh = row_counter_host[offset_idx] * std::max(weight_norm_coefficient * std::sqrt(w_sum_square), lower_bound);
+            adjusted_multiplier = std::min(1.0f, static_cast<float>(clip_thresh / std::sqrt(g_sum_square))) * multiplier;
+        } else {
+            // Default: no special regularization
+            adjusted_multiplier = multiplier;
+        }
+
+        exp_reg_correction = 1.0;
+        if (regularization_mode == 3) {
+            // Counter-based regularization (regularization_mode=3)
+            if (adjustment_enabled) {
+                if (weight_decay_mode == 3) {
+                    // AdagradW (weight_decay_mode=3)
+                    if (counter_halflife == -1) {
+                        adjusted_multiplier = multiplier * std::sqrt(row_counter_host[offset_idx] * 1.0);
+                    } else if (counter_halflife == -2) {
+                        adjusted_multiplier = std::min(static_cast<float>(learning_rate * std::pow(row_counter_host[offset_idx] * 1.0, 1.0)), adjustment_ub) / (std::sqrt(new_sum_square_grads) + eps);
+                    }
+                    exp_reg_correction = 1.0 - weight_decay * learning_rate;
+                    const auto lazy_delta = prev_iter_host[offset_idx] == 0 ? 1.0 : iter * 1.0 - prev_iter_host[offset_idx];
+                    const auto lazy_multiplier = std::pow(exp_reg_correction, std::min(lazy_delta, iter * 1.0 - adjustment_iter) - 1.0);
+                    adjusted_multiplier *= lazy_multiplier;
+                    exp_reg_correction *= lazy_multiplier;
+                } else if (weight_decay_mode == 2) {
+                    // Decoupled weight decay (weight_decay_mode=2)
+                    exp_reg_correction = 1.0 - freq * weight_decay * learning_rate;
+                } else if (weight_decay_mode == 1) {
+                    // L2 regularization (coupled wd)
+                    exp_reg_correction = 1.0 - freq * weight_decay * multiplier;
+                }
+            }
+        } else if (regularization_mode == 4) {
+            // COW-clip (regularization_mode=4)
+            if (weight_decay_mode == 2) {
+                // Decoupled weight decay (weight_decay_mode=2)
+                exp_reg_correction = 1.0 - weight_decay * learning_rate;
+            } else if (weight_decay_mode == 1) {
+                // L2 regularization (coupled wd)
+                exp_reg_correction = 1.0 - weight_decay * adjusted_multiplier;
+            }
+        } else {
+            // Default regularization
+            exp_reg_correction = 1.0;
+        }
+
+        // Update prev_iter
         prev_iter_host[offset_idx] = iter * 1.0;
-        const auto exp_reg = 1.0 / (weight_decay * multiplier + 1.0);
-        const auto exp_reg_correction = powf(exp_reg, iter_delta);
+
+        // Apply weight updates
         for (int64_t d = 0; d < D; ++d) {
-            const auto weight = host_weights_data[embedding_begin + d];
-            host_weights_data[embedding_begin + d] = exp_reg_correction * weight - exp_reg * multiplier * grad_buffer[d];
+            host_weights_data[embedding_begin + d] = exp_reg_correction * host_weights_data[embedding_begin + d] - adjusted_multiplier * grad_buffer[d];
         }
     """

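For readers skimming the diff, the following is a minimal Python sketch (not part of the patch) of one common path through the new CPU code: regularization_mode=3 with decoupled weight decay (weight_decay_mode=2), counter_halflife > 0, and no learning-rate adjustment, applied to a single embedding row. The function name, default parameter values, and the flat scalar state layout are illustrative assumptions, not the library's API.

# Illustrative sketch only; mirrors the reg_mode=3 / wd_mode=2 branch above.
import math

def rowwise_adagrad_with_counter_row(
    weights,          # list[float]: one embedding row
    grads,            # list[float]: gradient for that row
    momentum1,        # float: accumulated mean squared gradient for the row
    row_counter,      # float: halflife-decayed visit counter for the row
    prev_iter,        # float: iteration at which the row was last seen (0 = never)
    it,               # int: current iteration
    learning_rate=0.01,
    eps=1.0e-8,
    weight_decay=1.0e-4,
    counter_halflife=1000,   # > 0 selects the halflife-decay branch
):
    # Halflife decay of the per-row counter: visits from long ago count less.
    iter_delta = 1.0 if prev_iter == 0 else it - prev_iter
    counter_log_rho = math.log(2.0) / counter_halflife
    row_counter = 1.0 + math.exp(-iter_delta * counter_log_rho) * row_counter
    freq = counter_halflife / row_counter

    # Row-wise Adagrad statistic: mean squared gradient of the row.
    g_avg_square = sum(g * g for g in grads) / len(grads)
    momentum1 += g_avg_square
    multiplier = learning_rate / (math.sqrt(momentum1) + eps)

    # Decoupled weight decay (weight_decay_mode=2), scaled by the row frequency.
    exp_reg_correction = 1.0 - freq * weight_decay * learning_rate

    weights = [exp_reg_correction * w - multiplier * g for w, g in zip(weights, grads)]
    return weights, momentum1, row_counter, float(it)

Under these assumptions the step size is the plain row-wise Adagrad multiplier; the counter only changes how strongly weight decay is applied to frequently versus rarely seen rows.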