222 changes: 219 additions & 3 deletions dodola/core.py
@@ -3,8 +3,10 @@
Math stuff and business logic go here.
"""

from typing import Union
import warnings
import logging
from numba import float32, float64, jit, types
import numpy as np
import xarray as xr
from xclim import sdba, set_options
@@ -19,6 +21,220 @@
# Assume data input here is generally clean and valid.


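# _argsort covers both a 1D case (all processed dims stacked into one) and a
# 2D case (extra dims stacked into a leading axis); declaring explicit
# signatures lets numba compile each dtype combination eagerly in nopython mode.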
@jit(
[
float32[:, :](float32[:, :], float32[:, :], float32[:], types.unicode_type),
float64[:, :](float64[:, :], float64[:, :], float64[:], types.unicode_type),
float64[:, :](float64[:, :], float32[:, :], float64[:], types.unicode_type),
float32[:](float32[:], float32[:], float32[:], types.unicode_type),
float64[:](float64[:], float64[:], float64[:], types.unicode_type),
],
nopython=True,
)
def _argsort(arr_coarse, arr_fine, q, arr_sort="coarse"):
if arr_coarse.ndim == 1:
inds = np.argsort(arr_coarse)
if arr_sort == "coarse":
out = arr_coarse[inds]
elif arr_sort == "fine":
out = arr_fine[inds]
else:
out = np.empty((arr_coarse.shape[0], q.size), dtype=arr_coarse.dtype)
for index in range(out.shape[0]):
inds = np.argsort(arr_coarse[index])
if arr_sort == "coarse":
out[index] = arr_coarse[index][inds]
elif arr_sort == "fine":
out[index] = arr_fine[index][inds]
return out


def argsort(da_ref_coarse, da_ref_fine, q, dim, arr_sort="coarse", axis=0):
"""Sort ref_coarse or ref_fine (specified by the arr_sort input arg)
with the indices used to quantile ref_coarse"""
# We have two cases :
# - When all dims are processed : we stack them and use _argsort1d
# - When the quantiles are vectorized over some dims, these are also stacked and then _argsort2D is used.
# All this stacking is so we can cover all ND+1D cases with one numba function.
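    # For example (toy values): with ref_coarse = [2., 0., 1.] and
    # ref_fine = [20., 0., 10.] along "time" and q = [0.25, 0.5, 0.75],
    # arr_sort="fine" gives [0., 10., 20.]: each quantile of the coarse
    # reference is paired with the fine value observed on the same timestep.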

    # remove any null timesteps from the input arrays
da_ref_coarse = da_ref_coarse.dropna(dim="time")
da_ref_fine = da_ref_fine.dropna(dim="time")

# Stack the dims and send to the last position
# This is in case there are more than one
dims = [dim] if isinstance(dim, str) else dim
tem = xr.core.utils.get_temp_dimname(da_ref_coarse.dims, "temporal")
da_ref_coarse = da_ref_coarse.stack({tem: dims})
da_ref_fine = da_ref_fine.stack({tem: dims})

    # Cast q to the data dtype so we only have to declare half as many numba signatures
if not hasattr(q, "dtype") or q.dtype != da_ref_coarse.dtype:
q = np.array(q, dtype=da_ref_coarse.dtype)

if len(da_ref_coarse.dims) > 1:
# There are some extra dims
extra = xr.core.utils.get_temp_dimname(da_ref_coarse.dims, "extra")
da_ref_coarse = da_ref_coarse.stack({extra: set(da_ref_coarse.dims) - {tem}})
da_ref_fine = da_ref_fine.stack({extra: set(da_ref_fine.dims) - {tem}})
da_ref_coarse = da_ref_coarse.transpose(..., tem)
da_ref_fine = da_ref_fine.transpose(..., tem)

if da_ref_coarse.values.shape != da_ref_fine.values.shape:
raise ValueError("shape of coarse values does not match fine values")

if da_ref_coarse.values.shape[1] != q.shape[0]:
raise ValueError(
"shape of q is {} and shape of ref coarse/fine is {}".format(
q.shape, da_ref_coarse.values.shape
)
)

out = _argsort(da_ref_coarse.values, da_ref_fine.values, q, arr_sort)

res = xr.DataArray(
out,
dims=(extra, "quantiles"),
coords={extra: da_ref_coarse[extra], "quantiles": q},
attrs=da_ref_coarse.attrs,
).unstack(extra)

else:
# All dims are processed
res = xr.DataArray(
_argsort(da_ref_coarse.values, da_ref_fine.values, q, arr_sort),
dims=("quantiles"),
coords={"quantiles": q},
attrs=da_ref_coarse.attrs,
)

return res


@sdba.base.map_groups(
af=[sdba.Grouper.PROP, "quantiles"],
ref_coarse_q=[sdba.Grouper.PROP, "quantiles"],
)
def _qplad_train(ds, *, dim, kind, quantiles):
"""QPLAD: Train step on one group.

Dataset variables:
ref_coarse : training target, coarse resolution
ref_fine : training target, fine resolution
"""
    # compute the indices of the days corresponding to each quantile of ref_coarse,
    # then sort ref_coarse with those indices (one day per quantile)
ref_coarse_q = argsort(
ds.ref_coarse, ds.ref_fine, quantiles, dim, arr_sort="coarse", axis=0
)

# sort ref fine with the same indices
ref_fine_q = argsort(
ds.ref_coarse, ds.ref_fine, quantiles, dim, arr_sort="fine", axis=0
)

    # compute adjustment factors between coarse and fine for those days
af = sdba.utils.get_correction(ref_coarse_q, ref_fine_q, kind)
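    # (for kind "+", get_correction reduces to ref_fine_q - ref_coarse_q;
    # for kind "*", to ref_fine_q / ref_coarse_q)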

return xr.Dataset(data_vars=dict(af=af, ref_coarse_q=ref_coarse_q))


@sdba.base.map_blocks(reduces=[sdba.Grouper.PROP, "quantiles"], scen=[], sim_q=[])
def _qplad_adjust(ds, *, group, interp, extrapolation, kind):
"""QPLAD: Adjust process on one block.

Dataset variables:
af : Adjustment factors
      ref_coarse_q : Empirical quantiles of the coarse reference
      sim : Data to adjust
      sim_q : Simulated quantiles of `sim`
"""
af, _ = sdba.utils.extrapolate_qm(ds.af, ds.ref_coarse_q, method=extrapolation)

sel = {dim: ds.sim_q[dim] for dim in set(af.dims).intersection(set(ds.sim_q.dims))}
sel["quantiles"] = ds.sim_q
af = sdba.utils.broadcast(af, ds.sim, group=group, interp=interp, sel=sel)
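    # broadcast() selects, for each timestep, the adjustment factor at that
    # timestep's simulated quantile (sim_q), interpolating between quantiles
    # according to `interp`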

scen = sdba.utils.apply_correction(ds.sim, af, kind)
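    # apply_correction reduces to ds.sim + af for kind "+" and ds.sim * af for "*"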
return xr.Dataset(dict(scen=scen, sim_q=ds.sim_q))


class QuantilePreservingAnalogDownscaling(sdba.adjustment.TrainAdjust):
r"""Quantile-Preserving Localized Analogs Downscaling.

    Adjustment factors are computed between the corresponding days of `ref_coarse` and `ref_fine`.
    The quantiles of `sim` are matched to the corresponding adjustment factors, and `sim` is
    corrected accordingly.

Parameters
----------
Train step:

nquantiles : int
        The number of quantiles to use. Endpoints at 1e-6 and 1 - 1e-6 are not added (``eps=None``).
kind : {'+', '*'}
The adjustment kind, either additive or multiplicative.
group : Union[str, Grouper]
The grouping information. See :py:class:`xclim.sdba.base.Grouper` for details.


"""

_allow_diff_calendars = False

@classmethod
def _train(
cls,
ref: xr.DataArray,
hist: xr.DataArray,
*,
nquantiles: int = 20,
kind: str = sdba.utils.ADDITIVE,
group: Union[str, sdba.Grouper] = "time",
):

quantiles = equally_spaced_nodes(nquantiles, eps=None).astype(ref.dtype)
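        # eps=None: midpoint nodes only, no 1e-6 / 1 - 1e-6 endpoints appended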

ds = _qplad_train(
xr.Dataset({"ref_coarse": ref, "ref_fine": hist}),
group=group,
quantiles=quantiles,
kind=kind,
)

ds.af.attrs.update(
standard_name="Adjustment factors",
long_name="Quantile Preserving Localized Analogs Downscaling Adjustment Factors",
)
ds.ref_coarse_q.attrs.update(
standard_name="Empirical quantiles",
long_name="Empirical quantiles of coarse reference data",
)

return ds, {"group": group, "kind": kind}

def _adjust(self, sim):

# match quantiles from sim to corresponding AFs for that DOY
ds = xr.Dataset(
{
"sim": sim.drop("sim_q"),
"af": self.ds.af,
"sim_q": sim.sim_q,
"ref_coarse_q": self.ds.ref_coarse_q,
}
)

out = _qplad_adjust(
ds,
group=self.group,
interp="linear",
extrapolation="constant",
kind=self.kind,
)

return out.scen


def train_quantiledeltamapping(
reference, historical, variable, kind, quantiles_n=100, window_n=31
):
@@ -202,7 +418,7 @@ def train_analogdownscaling(

Returns
-------
xclim.sdba.adjustment.QuantilePreservingAnalogDownscaling
QuantilePreservingAnalogDownscaling
"""

# QPLAD method requires that the number of quantiles equals
@@ -222,7 +438,7 @@
)
)

qplad = sdba.adjustment.QuantilePreservingAnalogDownscaling.train(
qplad = QuantilePreservingAnalogDownscaling.train(
ref=coarse_reference[variable],
hist=fine_reference[variable],
kind=str(kind),
@@ -256,7 +472,7 @@ def adjust_analogdownscaling(simulation, qplad, variable):
variable = str(variable)

if isinstance(qplad, xr.Dataset):
qplad = sdba.adjustment.QuantilePreservingAnalogDownscaling.from_dataset(qplad)
qplad = QuantilePreservingAnalogDownscaling.from_dataset(qplad)

out = qplad.adjust(simulation[variable]).to_dataset(name=variable)

3 changes: 1 addition & 2 deletions environment.yaml
@@ -19,9 +19,8 @@ dependencies:
- pytest-cov
- python=3.9
- s3fs=2022.1.0
- xclim=0.31.0
- xarray=0.21.1
- xesmf=0.6.2
- bottleneck=1.3.2
- zarr=2.11.0
- pip:
- git+https://github.com/ClimateImpactLab/xclim@63023d27f89a457c752568ffcec2e9ce9ad7a81a
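
A minimal usage sketch of the new class (untested; the toy data and variable names are illustrative, not part of this PR, and assume the QPLAD constraint that `nquantiles` equals the number of timesteps in each group):

import numpy as np
import xarray as xr
from dodola.core import QuantilePreservingAnalogDownscaling

# Toy training pair: one noleap year of daily values, with the "fine
# resolution" target a constant 0.1 offset from the coarse reference.
time = xr.cftime_range("2000-01-01", periods=365, freq="D", calendar="noleap")
coarse = xr.DataArray(
    np.random.default_rng(0).random(365),
    dims="time",
    coords={"time": time},
    attrs={"units": "K"},
)
fine = (coarse + 0.1).assign_attrs(units="K")

# A single "time" group of 365 timesteps, so nquantiles must be 365.
qplad = QuantilePreservingAnalogDownscaling.train(
    ref=coarse, hist=fine, nquantiles=365, kind="+", group="time"
)

# The data to adjust must carry a `sim_q` variable with the simulated
# quantiles (e.g. from an earlier quantile delta mapping step that
# returned its quantiles):
# scen = qplad.adjust(sim)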