Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion climanet/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings

import numpy as np
from .utils import add_month_day_dims
from .utils import add_month_day_dims, calc_stats
import xarray as xr
import torch
from torch.utils.data import Dataset
Expand Down Expand Up @@ -52,6 +52,9 @@ def __init__(
self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy()
self.lon_coords = daily_da[spatial_dims[1]].to_numpy().copy()

# Store the stats of the daily data before filling NaNs
self.daily_mean, self.daily_std = calc_stats(self.daily_np)

if land_mask is not None:
lm = land_mask.to_numpy().copy()
if lm.ndim == 3:
Expand Down
5 changes: 4 additions & 1 deletion climanet/st_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,10 @@ def __init__(
spatial_heads: Number of attention heads in the spatial Transformer
"""
super().__init__()

# Store arguments to be used later for model saving/loading
self.config = {k: v for k, v in locals().items() if k not in ('self', '__class__')}

self.encoder = VideoEncoder(
in_chans=in_chans, embed_dim=embed_dim, patch_size=patch_size
)
Expand Down Expand Up @@ -568,7 +572,6 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
Tp = T // self.patch_size[0]
Hp = H // self.patch_size[1]
Wp = W // self.patch_size[2]
Np = Tp * Hp * Wp

# check shape and patch compatibility
assert daily_mask.shape == daily_data.shape, (
Expand Down
234 changes: 211 additions & 23 deletions climanet/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import copy
from pathlib import Path
from typing import Tuple

import numpy as np
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
import xarray as xr
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

def regrid_to_boundary_centered_grid(
da: xr.DataArray,
roll = False
) -> xr.DataArray:

def regrid_to_boundary_centered_grid(da: xr.DataArray, roll=False) -> xr.DataArray:
"""
Interpolates a DataArray from its current center-based grid onto a new
grid whose coordinates are derived from user-specified boundaries.
Expand All @@ -20,7 +24,7 @@ def regrid_to_boundary_centered_grid(

# --- 0. Longitude Domain Check and Correction ---

input_lon = da['longitude']
input_lon = da["longitude"]

# Check if roll for 0-360 to -180-180 is requested
if roll:
Expand All @@ -30,18 +34,24 @@ def regrid_to_boundary_centered_grid(
lon_diff = np.abs(input_lon - 180.0)
# We need to roll such that the 180-degree line is moved to the edge
# and the new array starts near -180
roll_amount = int(lon_diff.argmin().item() + (input_lon.size / 2)) % input_lon.size
roll_amount = (
int(lon_diff.argmin().item() + (input_lon.size / 2)) % input_lon.size
)

# Roll the DataArray and its coordinates
da = da.roll(longitude=roll_amount, roll_coords=True)

# Correct the longitude coordinate values: shift values > 180 down by 360
new_lon_coords = da['longitude'].where(da['longitude'] <= 180, da['longitude'] - 360)
new_lon_coords = da["longitude"].where(
da["longitude"] <= 180, da["longitude"] - 360
)

# Assign the corrected and sorted coordinates
da = da.assign_coords(longitude=new_lon_coords).sortby('longitude')
print(f"Longitudes adjusted. New range: {da['longitude'].min().item():.2f} "
f"to {da['longitude'].max().item():.2f}")
da = da.assign_coords(longitude=new_lon_coords).sortby("longitude")
print(
f"Longitudes adjusted. New range: {da['longitude'].min().item():.2f} "
f"to {da['longitude'].max().item():.2f}"
)

# --- 1. Define Target Grid Boundaries ---

Expand All @@ -67,23 +77,19 @@ def regrid_to_boundary_centered_grid(

# Use linear interpolation (suitable for gappy data) to map data onto the
# new centers. xarray handles the NaNs automatically.
da_regridded = da.interp(
latitude=new_lats,
longitude=new_lons,
method="linear"
)
da_regridded = da.interp(latitude=new_lats, longitude=new_lons, method="linear")

print(f"Regridding complete. New dimensions: {da_regridded.dims}")
return da_regridded


def add_month_day_dims(
daily_ts: xr.DataArray, # (time, H, W) daily
daily_ts: xr.DataArray, # (time, H, W) daily
monthly_ts: xr.DataArray, # (time, H, W) monthly
time_dim: str = "time",
spatial_dims: Tuple[str, str] = ("lat", "lon")
spatial_dims: Tuple[str, str] = ("lat", "lon"),
):
""" Reshape daily and monthly data to have explicit month (M) and day (T) dimensions.
"""Reshape daily and monthly data to have explicit month (M) and day (T) dimensions.

Here we assume maximum 31 days in a month, and invalid day entries will be
padded with NaN.
Expand All @@ -104,8 +110,9 @@ def add_month_day_dims(

# Add M (month key) and T (day of month) coordinates to daily data
daily_indexed = (
daily_ts
.assign_coords(M=(time_dim, dkey.values), T=(time_dim, daily_ts[time_dim].dt.day.values))
daily_ts.assign_coords(
M=(time_dim, dkey.values), T=(time_dim, daily_ts[time_dim].dt.day.values)
)
.set_index({time_dim: ("M", "T")})
.unstack(time_dim)
.reindex(T=np.arange(1, 32), M=month_keys)
Expand All @@ -119,8 +126,7 @@ def add_month_day_dims(

# Align monthly data to same month keys/order
monthly_m = (
monthly_ts
.assign_coords(M=(time_dim, mkey.values))
monthly_ts.assign_coords(M=(time_dim, mkey.values))
.swap_dims({time_dim: "M"})
.drop_vars(time_dim)
.sel(M=month_keys)
Expand All @@ -146,7 +152,189 @@ def pred_to_numpy(pred, orig_H=None, orig_W=None, land_mask=None):
if land_mask is not None:
pred = pred.clone().to(torch.float32)
land_mask = land_mask.bool()
land_mask = land_mask.unsqueeze(1) # (B, H,W) -> (B, 1, H, W) for broadcasting
land_mask = land_mask.unsqueeze(1) # (B, H,W) -> (B, 1, H, W) for broadcasting
pred = torch.where(land_mask, torch.full_like(pred, float("nan")), pred)

return pred.detach().cpu().numpy()


def calc_stats(arr: np.ndarray, mean_axis: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    """Compute NaN-ignoring mean and std for each index along ``mean_axis``.

    All axes *except* ``mean_axis`` are reduced, so the result has one entry
    per position of ``mean_axis`` — e.g. per-month statistics when axis 0
    indexes months. (The previous docstring said "along the specified axis",
    which described the opposite of what the code does.)

    Args:
        arr: Input array; NaN entries are ignored in both statistics.
        mean_axis: The axis to *keep*; statistics are computed per index of
            this axis over all remaining axes.

    Returns:
        Tuple ``(mean, std)`` of arrays with shape ``(arr.shape[mean_axis],)``.
    """
    # Reduce over every axis except the one we want statistics for.
    reduce_axes = tuple(ax for ax in range(arr.ndim) if ax != mean_axis)

    mean = np.nanmean(arr, axis=reduce_axes)
    std = np.nanstd(arr, axis=reduce_axes)
    return mean, std


def _setup_logging(log_dir: str) -> SummaryWriter:
    """Create the TensorBoard log directory (if needed) and return a writer."""
    log_path = Path(log_dir)
    log_path.mkdir(parents=True, exist_ok=True)
    return SummaryWriter(log_dir)


def _compute_masked_loss(
pred: torch.Tensor, target: torch.Tensor, land_mask: torch.Tensor
) -> torch.Tensor:
"""Compute L1 loss masked to ocean pixels only."""
ocean = (~land_mask).to(pred.device).unsqueeze(1).float()
loss = torch.nn.functional.l1_loss(pred, target, reduction="none") * ocean

num = loss.sum(dim=(-2, -1))
denom = ocean.sum(dim=(-2, -1)).clamp_min(1)

return (num / denom).mean()


def _save_model(model: torch.nn.Module, log_dir: str, verbose: bool) -> None:
    """Persist the model weights and its constructor config for reloading.

    Assumes the model exposes a ``config`` attribute (the stored __init__
    arguments) so the checkpoint is self-describing.
    """
    destination = Path(log_dir) / "best_model.pth"
    payload = {
        "model_state_dict": model.state_dict(),
        "model_config": model.config,
    }
    torch.save(payload, destination)
    if verbose:
        print(f"Model saved to {destination}")


def train_monthly_model(
    model: torch.nn.Module,
    dataset: Dataset,
    shuffle: bool = True,
    batch_size: int = 2,
    num_epoch: int = 100,
    patience: int = 10,
    accumulation_steps: int = 1,
    optimizer_lr: float = 1e-3,
    log_dir: str = ".",
    save_model: bool = True,
    device: str = "cpu",
    verbose: bool = True,
):
    """Train the model to predict monthly data from daily data.

    Args:
        model: the PyTorch model to train; must expose a ``decoder`` with
            ``bias``/``scale`` parameters, and a ``config`` dict if
            ``save_model`` is True
        dataset: Dataset object containing the training data; must provide
            ``daily_mean`` and ``daily_std`` numpy arrays
        shuffle: whether to shuffle the data each epoch
        batch_size: number of samples per batch
        num_epoch: number of epochs to train
        patience: number of epochs to wait for improvement before early stopping
        accumulation_steps: number of batches to accumulate gradients over before updating weights
        optimizer_lr: learning rate for the optimizer
        log_dir: directory to save logs
        save_model: whether to save the best model to disk
        device: device to run training on ("cpu" or "cuda")
        verbose: whether to print training progress

    Returns:
        The trained model with the best-performing weights restored.

    Raises:
        ValueError: if the dataset yields no batches.
    """

    # Initialize the decoder's de-normalization from the dataset statistics;
    # the 1e-6 offset guards against a zero scale for constant pixels.
    model = model.to(device)
    decoder = model.decoder
    with torch.no_grad():
        decoder.bias.copy_(torch.from_numpy(dataset.daily_mean))
        decoder.scale.copy_(torch.from_numpy(dataset.daily_std) + 1e-6)

    # Create data loader
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=False,
    )

    # Set up logging
    writer = _setup_logging(log_dir)

    # Set the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=optimizer_lr)
    best_loss = float("inf")
    counter = 0
    best_state_dict = None  # Store best model state

    # Add scheduler - reduces LR instead of stopping immediately
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=patience // 2,  # Reduce LR before early stop triggers
        min_lr=1e-7,
    )

    model.train()
    for epoch in range(num_epoch):
        epoch_loss = 0.0
        num_batches = 0

        optimizer.zero_grad()

        for batch in dataloader:
            # BUGFIX: move batch tensors to the training device — the model
            # was moved above, so batches left on the CPU break on "cuda".
            batch = {
                key: val.to(device) if isinstance(val, torch.Tensor) else val
                for key, val in batch.items()
            }

            # Batch prediction
            pred = model(
                batch["daily_patch"],
                batch["daily_mask_patch"],
                batch["land_mask_patch"],
                batch["padded_days_mask"],
            )  # (B, M, H, W)

            # Compute masked loss (ocean pixels only)
            loss = _compute_masked_loss(
                pred, batch["monthly_patch"], batch["land_mask_patch"]
            )

            # Scale loss for gradient accumulation
            scaled_loss = loss / accumulation_steps
            scaled_loss.backward()

            # Track unscaled loss for logging
            epoch_loss += loss.item()
            num_batches += 1

            # Update weights every accumulation_steps batches
            if num_batches % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        # BUGFIX: an empty dataloader previously raised a confusing NameError
        # on the loop variable; fail with an explicit message instead.
        if num_batches == 0:
            raise ValueError("Dataset produced no batches; cannot train.")

        # Handle remaining gradients if num_batches is not divisible by accumulation_steps
        if num_batches % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        # Calculate average epoch loss
        avg_epoch_loss = epoch_loss / num_batches

        # Step scheduler (may reduce the learning rate on a plateau)
        scheduler.step(avg_epoch_loss)

        # Early stopping bookkeeping
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            best_state_dict = copy.deepcopy(model.state_dict())
            counter = 0
        else:
            counter += 1

        # Log to TensorBoard. BUGFIX: best is logged after the update above
        # so "Loss/best" reflects this epoch rather than lagging by one.
        writer.add_scalar("Loss/train", avg_epoch_loss, epoch)
        writer.add_scalar("Loss/best", best_loss, epoch)

        if verbose and epoch % 20 == 0:
            print(f"Epoch {epoch}: best_loss = {best_loss:.6f}")

        # Only stop if LR is at minimum AND no improvement
        current_lr = optimizer.param_groups[0]["lr"]
        if counter >= patience and current_lr <= scheduler.min_lrs[0]:
            writer.add_text("Training", f"Early stop at epoch {epoch}", epoch)
            break

    # Restore best model
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

    # Close the writer when done
    writer.close()

    if verbose:
        print(f"Training complete. Best loss: {best_loss:.6f}")

    if save_model:
        _save_model(model, log_dir, verbose)

    return model
Loading
Loading