Skip to content

Commit

Permalink
Make sure that supplementary variables and weights have same chunks as parent cube (#2637)
Browse files Browse the repository at this point in the history

Co-authored-by: Bouwe Andela <b.andela@esciencecenter.nl>
  • Loading branch information
schlunma and bouweandela authored Jan 21, 2025
1 parent fc45c72 commit 2e6804a
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 53 deletions.
4 changes: 2 additions & 2 deletions esmvalcore/iris_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,14 +247,14 @@ def rechunk_cube(
cube:
Input cube.
complete_coords:
(Names of) coordinates along which the output cubes should not be
(Names of) coordinates along which the output cube should not be
chunked.
remaining_dims:
Chunksize of the remaining dimensions.
Returns
-------
Cube
iris.cube.Cube
Rechunked cube. This will always be a copy of the input cube.
"""
Expand Down
87 changes: 57 additions & 30 deletions esmvalcore/preprocessor/_supplementary_vars.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
"""Preprocessor functions for ancillary variables and cell measures."""

import logging
from typing import Iterable
from collections.abc import Callable
from typing import Iterable, Literal

import iris.coords
import iris.cube
from iris.cube import Cube

logger = logging.getLogger(__name__)

PREPROCESSOR_SUPPLEMENTARIES = {}


def register_supplementaries(variables, required):
def register_supplementaries(
variables: list[str],
required: Literal["require_at_least_one", "prefer_at_least_one"],
) -> Callable:
"""Register supplementary variables required for a preprocessor function.
Parameters
----------
variables: :obj:`list` of :obj`str`
variables:
List of variable names.
required:
How strong the requirement is. Can be 'require_at_least_one' if at
Expand All @@ -39,16 +43,25 @@ def wrapper(func):
return wrapper


def add_cell_measure(cube, cell_measure_cube, measure):
"""Add a cube as a cell_measure in the cube containing the data.
def add_cell_measure(
cube: Cube,
cell_measure_cube: Cube,
measure: Literal["area", "volume"],
) -> None:
"""Add cell measure to cube (in-place).
Note
----
This assumes that the cell measure spans the rightmost dimensions of the
cube.
Parameters
----------
cube: iris.cube.Cube
cube:
Iris cube with input data.
cell_measure_cube: iris.cube.Cube
cell_measure_cube:
Iris cube with cell measure data.
measure: str
measure:
Name of the measure, can be 'area' or 'volume'.
Returns
Expand All @@ -65,47 +78,62 @@ def add_cell_measure(cube, cell_measure_cube, measure):
raise ValueError(
f"measure name must be 'area' or 'volume', got {measure} instead"
)
measure = iris.coords.CellMeasure(
cell_measure_cube.core_data(),
coord_dims = tuple(
range(cube.ndim - len(cell_measure_cube.shape), cube.ndim)
)
cell_measure_data = cell_measure_cube.core_data()
if cell_measure_cube.has_lazy_data():
cube_chunks = tuple(cube.lazy_data().chunks[d] for d in coord_dims)
cell_measure_data = cell_measure_data.rechunk(cube_chunks)
cell_measure = iris.coords.CellMeasure(
cell_measure_data,
standard_name=cell_measure_cube.standard_name,
units=cell_measure_cube.units,
measure=measure,
var_name=cell_measure_cube.var_name,
attributes=cell_measure_cube.attributes,
)
start_dim = cube.ndim - len(measure.shape)
cube.add_cell_measure(measure, range(start_dim, cube.ndim))
cube.add_cell_measure(cell_measure, coord_dims)
logger.debug(
"Added %s as cell measure in cube of %s.",
cell_measure_cube.var_name,
cube.var_name,
)


def add_ancillary_variable(cube, ancillary_cube):
"""Add cube as an ancillary variable in the cube containing the data.
def add_ancillary_variable(cube: Cube, ancillary_cube: Cube) -> None:
"""Add ancillary variable to cube (in-place).
Note
----
This assumes that the ancillary variable spans the rightmost dimensions of
the cube.
Parameters
----------
cube: iris.cube.Cube
cube:
Iris cube with input data.
ancillary_cube: iris.cube.Cube
ancillary_cube:
Iris cube with ancillary data.
Returns
-------
iris.cube.Cube
Cube with added ancillary variables
"""
coord_dims = tuple(range(cube.ndim - len(ancillary_cube.shape), cube.ndim))
ancillary_data = ancillary_cube.core_data()
if ancillary_cube.has_lazy_data():
cube_chunks = tuple(cube.lazy_data().chunks[d] for d in coord_dims)
ancillary_data = ancillary_data.rechunk(cube_chunks)
ancillary_var = iris.coords.AncillaryVariable(
ancillary_cube.core_data(),
ancillary_data,
standard_name=ancillary_cube.standard_name,
units=ancillary_cube.units,
var_name=ancillary_cube.var_name,
attributes=ancillary_cube.attributes,
)
start_dim = cube.ndim - len(ancillary_var.shape)
cube.add_ancillary_variable(ancillary_var, range(start_dim, cube.ndim))
cube.add_ancillary_variable(ancillary_var, coord_dims)
logger.debug(
"Added %s as ancillary variable in cube of %s.",
ancillary_cube.var_name,
Expand All @@ -114,10 +142,10 @@ def add_ancillary_variable(cube, ancillary_cube):


def add_supplementary_variables(
cube: iris.cube.Cube,
supplementary_cubes: Iterable[iris.cube.Cube],
) -> iris.cube.Cube:
"""Add ancillary variables and/or cell measures.
cube: Cube,
supplementary_cubes: Iterable[Cube],
) -> Cube:
"""Add ancillary variables and/or cell measures to cube (in-place).
Parameters
----------
Expand All @@ -131,7 +159,7 @@ def add_supplementary_variables(
iris.cube.Cube
Cube with added ancillary variables and/or cell measures.
"""
measure_names = {
measure_names: dict[str, Literal["area", "volume"]] = {
"areacella": "area",
"areacello": "area",
"volcello": "volume",
Expand All @@ -145,15 +173,14 @@ def add_supplementary_variables(
return cube


def remove_supplementary_variables(cube: iris.cube.Cube):
"""Remove supplementary variables.
def remove_supplementary_variables(cube: Cube) -> Cube:
"""Remove supplementary variables from cube (in-place).
Strip cell measures or ancillary variables from the cube containing the
data.
Strip cell measures or ancillary variables from the cube.
Parameters
----------
cube: iris.cube.Cube
cube:
Iris cube with data and cell measures or ancillary variables.
Returns
Expand Down
21 changes: 15 additions & 6 deletions esmvalcore/preprocessor/_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def extract_volume(
return cube.extract(z_constraint)


def calculate_volume(cube: Cube) -> da.core.Array:
def calculate_volume(cube: Cube) -> np.ndarray | da.Array:
"""Calculate volume from a cube.
This function is used when the 'ocean_volume' cell measure can't be found.
Expand All @@ -119,13 +119,13 @@ def calculate_volume(cube: Cube) -> da.core.Array:
Parameters
----------
cube: iris.cube.Cube
cube:
input cube.
Returns
-------
dask.array.core.Array
Grid volumes.
np.ndarray | dask.array.Array
Grid volume.
"""
# Load depth field and figure out which dim is which
Expand Down Expand Up @@ -158,7 +158,11 @@ def calculate_volume(cube: Cube) -> da.core.Array:
# Calculate Z-direction thickness
thickness = depth.core_bounds()[..., 1] - depth.core_bounds()[..., 0]
if cube.has_lazy_data():
thickness = da.array(thickness)
z_chunks = tuple(cube.lazy_data().chunks[d] for d in z_dim)
if isinstance(thickness, da.Array):
thickness = thickness.rechunk(z_chunks)
else:
thickness = da.asarray(thickness, chunks=z_chunks)

# Get or calculate the horizontal areas of the cube
has_cell_measure = bool(cube.cell_measures("cell_area"))
Expand All @@ -182,6 +186,8 @@ def calculate_volume(cube: Cube) -> da.core.Array:
thickness, cube.shape, z_dim, chunks=chunks
)
grid_volume = area_arr * thickness_arr
if cube.has_lazy_data():
grid_volume = grid_volume.rechunk(chunks)

return grid_volume

Expand Down Expand Up @@ -403,7 +409,10 @@ def axis_statistics(
def _add_axis_stats_weights_coord(cube, coord, coord_dims):
"""Add weights for axis_statistics to cube (in-place)."""
weights = np.abs(coord.lazy_bounds()[:, 1] - coord.lazy_bounds()[:, 0])
if not cube.has_lazy_data():
if cube.has_lazy_data():
coord_chunks = tuple(cube.lazy_data().chunks[d] for d in coord_dims)
weights = weights.rechunk(coord_chunks)
else:
weights = weights.compute()
weights_coord = AuxCoord(
weights,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
:func:`esmvalcore.preprocessor._supplementary_vars` module.
"""

import dask.array as da
import iris
import iris.fileformats
import numpy as np
Expand Down Expand Up @@ -112,20 +113,37 @@ def setUp(self):
],
)

@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize("var_name", ["areacella", "areacello"])
def test_add_cell_measure_area(self, var_name):
def test_add_cell_measure_area(self, var_name, lazy):
"""Test add area fx variables as cell measures."""
if lazy:
self.fx_area.data = self.fx_area.lazy_data()
self.new_cube_data = da.array(self.new_cube_data).rechunk((1, 2))
self.fx_area.var_name = var_name
self.fx_area.standard_name = "cell_area"
self.fx_area.units = "m2"
cube = iris.cube.Cube(
self.new_cube_data, dim_coords_and_dims=self.coords_spec
)

cube = add_supplementary_variables(cube, [self.fx_area])
assert cube.cell_measure(self.fx_area.standard_name) is not None

def test_add_cell_measure_volume(self):
assert cube.has_lazy_data() is lazy
assert cube.cell_measures(self.fx_area.standard_name)
cell_measure = cube.cell_measure(self.fx_area.standard_name)
assert cell_measure.has_lazy_data() is lazy
if lazy:
assert cell_measure.lazy_data().chunks == cube.lazy_data().chunks

@pytest.mark.parametrize("lazy", [True, False])
def test_add_cell_measure_volume(self, lazy):
"""Test add volume as cell measure."""
if lazy:
self.fx_volume.data = self.fx_volume.lazy_data()
self.new_cube_3D_data = da.array(self.new_cube_3D_data).rechunk(
(1, 2, 3)
)
self.fx_volume.var_name = "volcello"
self.fx_volume.standard_name = "ocean_volume"
self.fx_volume.units = "m3"
Expand All @@ -137,8 +155,15 @@ def test_add_cell_measure_volume(self):
(self.lons, 2),
],
)

cube = add_supplementary_variables(cube, [self.fx_volume])
assert cube.cell_measure(self.fx_volume.standard_name) is not None

assert cube.has_lazy_data() is lazy
assert cube.cell_measures(self.fx_volume.standard_name)
cell_measure = cube.cell_measure(self.fx_volume.standard_name)
assert cell_measure.has_lazy_data() is lazy
if lazy:
assert cell_measure.lazy_data().chunks == cube.lazy_data().chunks

def test_no_cell_measure(self):
"""Test no cell measure is added."""
Expand All @@ -153,16 +178,27 @@ def test_no_cell_measure(self):
cube = add_supplementary_variables(cube, [])
assert cube.cell_measures() == []

def test_add_supplementary_vars(self):
"""Test invalid variable is not added as cell measure."""
@pytest.mark.parametrize("lazy", [True, False])
def test_add_ancillary_vars(self, lazy):
"""Test adding ancillary variables."""
if lazy:
self.fx_area.data = self.fx_area.lazy_data()
self.new_cube_data = da.array(self.new_cube_data).rechunk((1, 2))
self.fx_area.var_name = "sftlf"
self.fx_area.standard_name = "land_area_fraction"
self.fx_area.units = "%"
cube = iris.cube.Cube(
self.new_cube_data, dim_coords_and_dims=self.coords_spec
)

cube = add_supplementary_variables(cube, [self.fx_area])
assert cube.ancillary_variable(self.fx_area.standard_name) is not None

assert cube.has_lazy_data() is lazy
assert cube.ancillary_variables(self.fx_area.standard_name)
anc_var = cube.ancillary_variable(self.fx_area.standard_name)
assert anc_var.has_lazy_data() is lazy
if lazy:
assert anc_var.lazy_data().chunks == cube.lazy_data().chunks

def test_wrong_shape(self, monkeypatch):
"""Test variable is not added if it's not broadcastable to cube."""
Expand Down
Loading

0 comments on commit 2e6804a

Please sign in to comment.