diff --git a/src/cmap/_colormap.py b/src/cmap/_colormap.py index 748742ddc..433072c2f 100644 --- a/src/cmap/_colormap.py +++ b/src/cmap/_colormap.py @@ -109,6 +109,12 @@ class Colormap: the matplotlib docs for more. - a `Callable` that takes an array of N values in the range [0, 1] and returns an (N, 4) array of RGBA values in the range [0, 1]. + cmap_kwargs : dict[str, Any] | None + Keyword arguments to pass to the colormap function when `value` is a + registered colormap name that maps to a callable function. For example: + `Colormap("cubehelix", cmap_kwargs={"start": 1.0, "rotation": -1.0})`. + If provided when `value` does not resolve to a callable colormap function, a + `TypeError` will be raised. name : str | None A name for the colormap. If None, will be set to the identifier or the string `"custom colormap"`. @@ -129,6 +135,14 @@ class Colormap: The color to use for values above the colormap's range. bad : ColorLike | None The color to use for bad (NaN, inf) values. + + Raises + ------ + TypeError + If `cmap_kwargs` is provided and `value` does not resolve to a callable + colormap function. + ValueError + If `value` is a string colormap name that is not found in the catalog. """ __slots__ = ( @@ -226,6 +240,7 @@ def __init__( under: ColorLike | None = None, over: ColorLike | None = None, bad: ColorLike | None = None, + cmap_kwargs: dict[str, Any] | None = None, ) -> None: self.info: CatalogItem | None = None @@ -238,6 +253,14 @@ def __init__( under = info.under if under is None else under bad = info.bad if bad is None else bad self.info = info + + # Check if cmap_kwargs is provided for a non-callable colormap + if cmap_kwargs and not callable(info.data): + raise TypeError( + f"Cannot apply cmap_kwargs to colormap {info.name!r}: " + "colormap is not a parametrized callable" + ) + if isinstance(info.data, list): if not info.data: # pragma: no cover raise ValueError(f"Catalog colormap {info.name!r} has no data") @@ -252,7 +275,12 @@ def __init__( f"Invalid catalog colormap data for {info.name!r}: {info.data}" ) else: - stops = _parse_colorstops(info.data) + _data: ColormapLike + if cmap_kwargs: + _data = partial(cast("Callable", info.data), **cmap_kwargs) + else: + _data = info.data + stops = _parse_colorstops(_data) if interpolation is None: interpolation = info.interpolation if rev: diff --git a/tests/test_colormap.py b/tests/test_colormap.py index 3a9484d71..d1550136b 100644 --- a/tests/test_colormap.py +++ b/tests/test_colormap.py @@ -236,3 +236,18 @@ def test_shifted() -> None: assert cm.shifted(0.5).shifted(-0.5) == cm # two shifts of 0.5 should give the original array assert cm.shifted().shifted() == cm + + +def test_function_colormap_with_cmap_kwargs() -> None: + # construct cubehelix with custom parameters + ch = Colormap("cubehelix", cmap_kwargs={"start": 1.0, "rotation": -1.0}) + + # values should be different from default cubehelix + default_ch = Colormap("cubehelix") + assert ch(0.5) != default_ch(0.5) + + +def test_invalid_function_colormap_with_cmap_kwargs() -> None: + # providing cmap_kwargs to a non-callable colormap should raise TypeError + with pytest.raises(TypeError, match=r"Cannot apply cmap_kwargs to colormap"): + Colormap("viridis", cmap_kwargs={"param": 1.0})