diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 814f045c61b8..30f933ee3c56 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -339,11 +339,7 @@ def _forward_transform( ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """Input is of shape [B, N, P].""" mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads) - sigma = torch.where( - sigma < self.config.tolerance, - torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), - sigma, - ) + sigma = torch.clamp(sigma, min=self.config.tolerance) # Normalize each patch outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] @@ -522,24 +518,16 @@ def _get_patch_index(arr: torch.Tensor): # Calculate the number of valid elements num_valid_elements = torch.sum(mask, dim=1) - num_valid_elements = torch.where( - num_valid_elements == 0, - torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device), - num_valid_elements, - ) + num_valid_elements = torch.clamp(num_valid_elements, min=1.0) - # Calculate the masked sum and squared sum + # Calculate the masked sum and mean masked_sum = torch.sum(arr * mask, dim=1) - masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1) - - # Calculate the masked mean and standard deviation - masked_mean = masked_sum / num_valid_elements - masked_var = masked_squared_sum / num_valid_elements - masked_mean**2 - masked_var = torch.where( - masked_var < 0.0, - torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device), - masked_var, - ) + masked_mean = masked_sum / num_valid_elements # [b] + + # Calculate the masked variance using centered values + masked_centered_arr = (arr - masked_mean.unsqueeze(-1)) * mask + masked_var = torch.sum(masked_centered_arr**2, dim=1) / num_valid_elements + masked_var = torch.clamp(masked_var, min=0.0) masked_std = torch.sqrt(masked_var) return masked_mean, masked_std diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index f88973c420e9..619398b5aea7 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -295,11 +295,7 @@ def _forward_transform( ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """Input is of shape [B, N, P].""" mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads) - sigma = torch.where( - sigma < self.config.tolerance, - torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device), - sigma, - ) + sigma = torch.clamp(sigma, min=self.config.tolerance) # Normalize each patch outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] @@ -478,24 +474,16 @@ def _get_patch_index(arr: torch.Tensor): # Calculate the number of valid elements num_valid_elements = torch.sum(mask, dim=1) - num_valid_elements = torch.where( - num_valid_elements == 0, - torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device), - num_valid_elements, - ) + num_valid_elements = torch.clamp(num_valid_elements, min=1.0) - # Calculate the masked sum and squared sum + # Calculate the masked sum and mean masked_sum = torch.sum(arr * mask, dim=1) - masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1) - - # Calculate the masked mean and standard deviation - masked_mean = masked_sum / num_valid_elements - masked_var = masked_squared_sum / num_valid_elements - masked_mean**2 - masked_var = torch.where( - masked_var < 0.0, - torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device), - masked_var, - ) + masked_mean = masked_sum / num_valid_elements # [b] + + # Calculate the masked variance using centered values + masked_centered_arr = (arr - masked_mean.unsqueeze(-1)) * mask + masked_var = torch.sum(masked_centered_arr**2, dim=1) / num_valid_elements + masked_var = torch.clamp(masked_var, min=0.0) masked_std = torch.sqrt(masked_var) return masked_mean, masked_std