Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 139 additions & 1 deletion esmvalcore/cmor/_fixes/cmip6/icon_esm_lr.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,147 @@
"""Fixes for ICON-ESM-LR model."""

import numpy as np
from iris.coords import AuxCoord
from iris.mesh import Connectivity, MeshXY

from esmvalcore.cmor._fixes.fix import Fix
from esmvalcore.iris_helpers import has_unstructured_grid

# deduplicate decimals - round cartesian coords to merge
# identical vertices (especially pole vertices)
CARTESIAN_COORDINATE_DECIMALS = 12


class AllVars(Fix):
"""Fixes for all variables."""
"""Adapt the native ICON mesh fix for ICON-ESM-LR outputs.

Like the native ICON fix (`esmvalcore.cmor._fixes.icon._base_fixes`).
this avoids ``MeshXY.from_coords`` because shared
polygon vertices would be duplicated. Since CMIP6 ICON-ESM-LR files don't
have ``vertex_of_cell``, recreate the connectivity from coordinate
bounds.
"""

@staticmethod
def _can_create_mesh(cube):
# check if mesh is there
if cube.mesh is not None:
return False
# unstructured?
if not has_unstructured_grid(cube):
return False

lat = cube.coord("latitude")
lon = cube.coord("longitude")

# check bounds
if not lat.has_bounds() or not lon.has_bounds():
return False
if lat.bounds.shape != lon.bounds.shape:
return False
return lat.bounds.ndim == 2

@staticmethod
def _get_node_coords_and_connectivity(lat_bounds, lon_bounds):
"""Build unique mesh nodes and face-node connectivity.

Cell vertices are converted to cartesian coordinates to
identify shared nodes. Duplicate vertices are
removed and a face-node connectivity array is generated.
"""
lat_rad = np.deg2rad(lat_bounds)
lon_rad = np.deg2rad(lon_bounds)

cartesian = np.stack(
[
np.cos(lat_rad) * np.cos(lon_rad),
np.cos(lat_rad) * np.sin(lon_rad),
np.sin(lat_rad),
],
axis=-1,
)

# round coords to avoid floating point diffs
rounded = np.round(
cartesian.reshape(-1, 3),
decimals=CARTESIAN_COORDINATE_DECIMALS,
)
# unique mesh nodes
unique_nodes, inverse = np.unique(
rounded,
axis=0,
return_inverse=True,
)

# we create face-node connectivity and back to lon/lat
connectivity = inverse.reshape(lat_bounds.shape)
norm = np.linalg.norm(unique_nodes, axis=1)
unit_nodes = unique_nodes / norm[:, np.newaxis]
node_lat = np.rad2deg(np.arcsin(unit_nodes[:, 2]))
node_lon = (
np.rad2deg(np.arctan2(unit_nodes[:, 1], unit_nodes[:, 0])) % 360.0
)

return node_lat, node_lon, connectivity

def _fix_unstructured_mesh(self, cube):
"""Create and attach the Iris mesh.

Constructs node coordinates and face-node connectivity
from latitude longitude bounds and replaces the original
coordinate representation with Iris mesh coordinates.
"""
if not self._can_create_mesh(cube):
return

lat = cube.coord("latitude")
lon = cube.coord("longitude")
mesh_dim = cube.coord_dims(lat)

# construct the face_node_connectivity from bounds
node_lat_points, node_lon_points, face_node_connectivity = (
self._get_node_coords_and_connectivity(lat.bounds, lon.bounds)
)

node_lat = AuxCoord(
node_lat_points,
standard_name="latitude",
long_name="node latitude",
var_name="nlat",
units=lat.units,
)
node_lon = AuxCoord(
node_lon_points,
standard_name="longitude",
long_name="node longitude",
var_name="nlon",
units=lon.units,
)

face_lat = lat.copy()
face_lon = lon.copy()
# Update face bounds using the deduplicated node coords.
face_lat.bounds = node_lat.points[face_node_connectivity]
face_lon.bounds = node_lon.points[face_node_connectivity]

# Same create mesh logic with native ICON fix (Iris mesh object).
connectivity = Connectivity(
indices=face_node_connectivity,
cf_role="face_node_connectivity",
start_index=0,
location_axis=0,
)
mesh = MeshXY(
topology_dimension=2,
node_coords_and_axes=[(node_lat, "y"), (node_lon, "x")],
face_coords_and_axes=[(face_lat, "y"), (face_lon, "x")],
connectivities=[connectivity],
)

cube.remove_coord("latitude")
cube.remove_coord("longitude")
for mesh_coord in mesh.to_MeshCoords("face"):
cube.add_aux_coord(mesh_coord, mesh_dim)

def fix_metadata(self, cubes):
"""Rename ``var_name`` of latitude and longitude.
Expand All @@ -29,5 +166,6 @@ def fix_metadata(self, cubes):
for std_name, var_name in varnames_to_change.items():
if cube.coords(std_name):
cube.coord(std_name).var_name = var_name
self._fix_unstructured_mesh(cube)

return cubes
61 changes: 61 additions & 0 deletions tests/integration/cmor/_fixes/cmip6/test_icon_esm_lr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for the fixes of ICON-ESM-LR."""

import numpy as np
import pytest
from iris.coords import AuxCoord
from iris.cube import Cube, CubeList
Expand Down Expand Up @@ -95,3 +96,63 @@ def test_allvars_fix_metadata_no_lat_lon(cubes):
fix = AllVars(None)
out_cubes = fix.fix_metadata(cubes)
assert cubes is out_cubes


def test_allvars_fix_metadata_adds_mesh():
lat = AuxCoord(
[0.5, 0.5],
standard_name="latitude",
var_name="latitude",
units="degrees_north",
bounds=[[0.0, 0.0, 1.0], [0.0, 1.0, 1.0]],
)
lon = AuxCoord(
[0.5, 0.5],
standard_name="longitude",
var_name="longitude",
units="degrees_east",
bounds=[[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]],
)
cube = Cube(
np.array([1.0, 2.0]),
var_name="tas",
aux_coords_and_dims=[(lat, 0), (lon, 0)],
)

AllVars(None).fix_metadata(CubeList([cube]))

assert cube.mesh is not None
assert cube.coords("latitude", mesh_coords=True)
assert cube.coords("longitude", mesh_coords=True)
assert cube.mesh.connectivity().shape == (2, 3)


def test_allvars_fix_metadata_merges_pole_vertices():
"""Ensure pole vertices are merged into a single mesh node.

there must be 4 nodes: North Pole, (0,0), (0,120), (0,240)
"""
lat = AuxCoord(
[60.0, 60.0],
standard_name="latitude",
var_name="latitude",
units="degrees_north",
bounds=[[90.0, 0.0, 0.0], [90.0, 0.0, 0.0]],
)
lon = AuxCoord(
[60.0, 180.0],
standard_name="longitude",
var_name="longitude",
units="degrees_east",
bounds=[[0.0, 0.0, 120.0], [240.0, 120.0, 240.0]],
)
cube = Cube(
np.array([1.0, 2.0]),
var_name="tas",
aux_coords_and_dims=[(lat, 0), (lon, 0)],
)

AllVars(None).fix_metadata(CubeList([cube]))

node_lat = cube.mesh.coord(location="node", axis="y")
assert len(node_lat.points) == 4