Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions climanet/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings

import numpy as np
from .utils import add_month_day_dims, calc_stats
from .utils import add_month_day_dims, calc_stats, add_month_hour_dims
from .geo_embedding_utils import (
calculate_sh_geo_pos_embeddings,
compute_patch_geo_pos_embedding,
Expand All @@ -28,7 +28,21 @@ def __init__(
sh_pos_table: str = None, # Optional; str formatted path to precomputed table of sh
sh_embed_dim: int = 96, # sh_embed_dim should <= (sh_order_L + 1)**2
sh_order_L: int = 10,
is_hourly: bool = False,
):
"""Initialize the dataset with daily and monthly data, and optional land mask.

Args:
daily_da: xarray DataArray with daily data (M, time, H, W)
monthly_da: xarray DataArray with monthly data (M, H, W)
land_mask: Optional xarray DataArray with land mask (H, W) or (1, H, W)
time_dim: Name of the time dimension in the input data
spatial_dims: Tuple of (lat_dim, lon_dim) names in the input data
patch_size: Tuple of (patch_height, patch_width) in pixels
stride: Tuple of (stride_height, stride_width) in pixels. If None, defaults to patch_size (non-overlapping patches).
is_hourly: Whether the daily data is hourly (T=31*24) or daily (T=31).

"""
self.spatial_dims = spatial_dims
self.patch_size = patch_size
self.daily_da = daily_da
Expand All @@ -53,11 +67,19 @@ def __init__(
f"Patch size {patch_size} is larger than data dimensions {daily_da.sizes[spatial_dims]}"
)

# Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31)
daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims(
daily_da, monthly_da, time_dim=time_dim
)
if is_hourly:
# hours_per_day == 24
# Reshape daily → (M, T=31*24, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31*24)
daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_hour_dims(
daily_da, monthly_da, time_dim=time_dim
)
else:
# Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31)
daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims(
daily_da, monthly_da, time_dim=time_dim
)

# Convert to numpy once — all __getitem__ calls use these
self.daily_np = daily_mt.to_numpy().copy().astype(np.float32) # (M, T=31, H, W) float
Expand Down
15 changes: 0 additions & 15 deletions climanet/st_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,21 +742,6 @@ def forward(
)
assert T % self.patch_size[0] == 0, "T must be divisible by patch size"

if self.patch_size[0] > 1:
daily_timef = daily_timef.view(B, M, Tp, self.patch_size[0], 4).mean(
dim=3
) # -> (B,M, Tp, 4)

if padded_days_mask is not None and self.patch_size[0] > 1:
B, M, T_days = padded_days_mask.shape
if T_days % self.patch_size[0] != 0:
raise ValueError(
f"T_days={T_days} must be divisible by patch_size[0]={self.patch_size[0]}"
)
padded_days_mask = padded_days_mask.view(
B, M, T_days // self.patch_size[0], self.patch_size[0]
).all(dim=-1) # (B, M, Tp)

# Step 1: Encode spatio-temporal patches
# each month independently by folding M into batch
# encoder input shape = (B, C, T, H, W) where C is channel.
Expand Down
91 changes: 91 additions & 0 deletions climanet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,97 @@ def save_model(model: torch.nn.Module, run_dir: str, verbose: bool) -> None:
print(f"Model saved to {model_path}")


def add_month_hour_dims(
    hourly_ts: xr.DataArray,  # (time, H, W) hourly
    monthly_ts: xr.DataArray,  # (time, H, W) monthly
    time_dim: str = "time",
    spatial_dims: Tuple[str, str] = ("lat", "lon"),
):
    """Regrid hourly and monthly series onto explicit month (M) / hour (T) axes.

    Every month is given 31 days * 24 hours = 744 hour slots, numbered from 1.
    Slots with no matching timestamp (short months, gaps in the input) are
    filled with NaN in the data and NaT in the aligned timestamps.

    Parameters
    ----------
    hourly_ts : xr.DataArray
        Hourly data carrying a datetime coordinate named ``time_dim``.
    monthly_ts : xr.DataArray
        Monthly data carrying a datetime coordinate named ``time_dim``.
    time_dim : str
        Name of the time dimension in both inputs.
    spatial_dims : tuple of str
        Dimensions reduced over when building the padded-hours mask.

    Returns
    -------
    hourly_m : xr.DataArray - dims: (M, T=744, <remaining dims of hourly_ts>)
    monthly_m : xr.DataArray - dims: (M, <remaining dims of monthly_ts>),
        selected in the month order found in ``hourly_ts``
    padded_hours_mask : xr.DataArray - dims: (M, T=744), bool, True where the
        hour slot holds no non-null value over ``spatial_dims``
    time_features : xr.DataArray - dims: (M, T=744, feature=2), holding the
        day-of-year phase and the hour-of-day phase in radians; slots without
        a timestamp get phase 0

    Notes
    -----
    NOTE(review): months present in ``hourly_ts`` are looked up in
    ``monthly_ts`` with ``.sel``; presumably a month missing there raises a
    KeyError - confirm against callers.
    """
    hours_per_day = 24
    slots_per_month = 31 * hours_per_day  # 744 hour slots in the longest month

    hourly_time = hourly_ts[time_dim]
    monthly_time = monthly_ts[time_dim]

    # Integer YYYYMM key for every hourly / monthly timestamp.
    hourly_keys = (hourly_time.dt.year * 100 + hourly_time.dt.month).values
    monthly_keys = (monthly_time.dt.year * 100 + monthly_time.dt.month).values

    # Distinct month keys in order of first appearance (np.unique sorts by
    # value, so go back through the first-occurrence positions).
    _, first_seen = np.unique(hourly_keys, return_index=True)
    month_keys = hourly_keys[np.sort(first_seen)]

    # 1-based hour-of-month slot: (day - 1) * 24 + hour + 1, in 1..744.
    days = hourly_time.dt.day.values
    hours = hourly_time.dt.hour.values
    hour_slot = (days - 1) * hours_per_day + hours + 1
    all_slots = np.arange(1, slots_per_month + 1)

    def _to_month_hour_grid(series: xr.DataArray) -> xr.DataArray:
        # Label each timestamp with (M, T), unstack the time axis into those
        # two dims, then reindex onto the full slot range / month order so
        # absent slots appear as missing values.
        labelled = series.assign_coords(
            M=(time_dim, hourly_keys),
            T=(time_dim, hour_slot),
        )
        unstacked = labelled.set_index({time_dim: ("M", "T")}).unstack(time_dim)
        return unstacked.reindex(T=all_slots, M=month_keys)

    # Hourly data on the (M, T) grid, forced to dim order (M, T, ...).
    trailing_dims = [name for name in hourly_ts.dims if name != time_dim]
    hourly_indexed = _to_month_hour_grid(hourly_ts)
    hourly_indexed = hourly_indexed.transpose("M", "T", *trailing_dims)

    # A slot counts as padded when no spatial cell holds a non-null value.
    slot_has_data = hourly_indexed.notnull().any(dim=spatial_dims)
    padded_hours_mask = ~slot_has_data

    # Re-key the monthly series by YYYYMM and pick the hourly month order.
    monthly_keyed = monthly_ts.assign_coords(M=(time_dim, monthly_keys))
    monthly_keyed = monthly_keyed.swap_dims({time_dim: "M"}).drop_vars(time_dim)
    monthly_m = monthly_keyed.sel(M=month_keys)

    # Timestamps on the same (M, T) grid; missing slots are NaT.
    time_indexed = _to_month_hour_grid(hourly_time)

    # Day-of-year and hour-of-day, with 0 substituted for missing slots.
    doy_period = 365.24
    hod_period = 24.0
    day_of_year = time_indexed.dt.dayofyear.fillna(0)
    hour_of_day = time_indexed.dt.hour.fillna(0)

    # Phase angles in radians (not yet mapped through sin/cos).
    two_pi = 2 * np.pi
    doy_phase = two_pi * day_of_year / doy_period
    hod_phase = two_pi * hour_of_day / hod_period

    # Stack the two phases along a new trailing "feature" axis -> (M, T, 2).
    time_features = xr.concat([doy_phase, hod_phase], dim="feature")
    time_features = time_features.transpose("M", "T", "feature")

    return hourly_indexed, monthly_m, padded_hours_mask, time_features


def configure_compute_resources(
model: torch.nn.Module, device: str, compute_threads: int, dataloader_num_workers: int
) -> torch.nn.Module:
Expand Down
Loading