Source code for ridgeplot._ridgeplot

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, cast

from ridgeplot._figure_factory import LabelsArray, RidgePlotFigureFactory, ShallowLabelsArray
from ridgeplot._kde import estimate_densities
from ridgeplot._missing import MISSING, MissingType
from ridgeplot._types import (
    Densities,
    Samples,
    ShallowDensities,
    ShallowSamples,
    is_flat_str_collection,
    is_shallow_densities,
    is_shallow_samples,
    nest_shallow_collection,
)

if TYPE_CHECKING:

    import plotly.graph_objects as go

    from ridgeplot._colors import ColorScale
    from ridgeplot._figure_factory import Colormode
    from ridgeplot._kde import KDEBandwidth, KDEPoints


[docs] def ridgeplot( samples: Samples | ShallowSamples | None = None, densities: Densities | ShallowDensities | None = None, kernel: str = "gau", bandwidth: KDEBandwidth = "normal_reference", kde_points: KDEPoints = 500, colorscale: str | ColorScale = "plasma", colormode: Colormode = "mean-minmax", coloralpha: float | None = None, labels: LabelsArray | ShallowLabelsArray | None = None, linewidth: float = 1.0, spacing: float = 0.5, show_annotations: bool | MissingType = MISSING, show_yticklabels: bool = True, xpad: float = 0.05, ) -> go.Figure: r"""Return an interactive ridgeline (Plotly) |~go.Figure|. .. note:: You must pass either :paramref:`samples` or :paramref:`densities` to this function, but not both. See descriptions below for more details. .. _bandwidths.py: https://www.statsmodels.org/stable/_modules/statsmodels/nonparametric/bandwidths.html .. _Plotly's built-in color-scales: https://plotly.com/python/builtin-colorscales/ .. _ragged: https://en.wikipedia.org/wiki/Jagged_array Parameters ---------- samples : Samples or ShallowSamples, optional If ``samples`` data is specified, Kernel Density Estimation (KDE) will be computed. See :paramref:`kernel`, :paramref:`bandwidth`, and :paramref:`kde_points` for more details and KDE configuration options. The ``samples`` argument should be an array of shape :math:`(R, T_r, S_t)`. Note that we support irregular (`ragged`_) arrays, where: - :math:`R` is the number of rows in the plot - :math:`T_r` is the number of traces per row, where each row :math:`r \in R` can have a different number of traces. - :math:`S_t` is the number of samples per trace, where each trace :math:`t \in T_r` can also have a different number of samples. The KDE will be performed over the sample values (:math:`S_t`) for all traces. After the KDE, the resulting array will be a (4D) :paramref:`densities` array with shape :math:`(R, T_r, P_t, 2)` (see below for more details). densities : Densities or ShallowDensities, optional If a ``densities`` array is specified, the KDE step will be skipped and all associated arguments ignored. Each density array should have shape :math:`(R, T_r, P_t, 2)` (4D). Just like the :paramref:`samples` argument, we also support irregular (`ragged`_) ``densities`` arrays, where: - :math:`R` is the number of rows in the plot - :math:`T_r` is the number of traces per row, where each row :math:`r \in R` can have a different number of traces. - :math:`P_t` is the number of points per trace, where each trace :math:`t \in T_r` can also have a different number of points. - :math:`2` is the number of coordinates per point (x and y) kernel : str The Kernel to be used during Kernel Density Estimation. The default is a Gaussian Kernel (``"gau"``). Choices are: - ``"biw"`` for biweight - ``"cos"`` for cosine - ``"epa"`` for Epanechnikov - ``"gau"`` for Gaussian. - ``"tri"`` for triangular - ``"triw"`` for triweight - ``"uni"`` for uniform bandwidth : KDEBandwidth The bandwidth to use during Kernel Density Estimation. The default is ``"normal_reference"``. Choices are: - ``"scott"`` - 1.059 * A * nobs ** (-1/5.), where A is ``min(std(x),IQR/1.34)`` - ``"silverman"`` - .9 * A * nobs ** (-1/5.), where A is ``min(std(x),IQR/1.34)`` - ``"normal_reference"`` - C * A * nobs ** (-1/5.), where C is calculated from the kernel. Equivalent (up to 2 dp) to the ``"scott"`` bandwidth for gaussian kernels. See `bandwidths.py`_. - If a float is given, its value is used as the bandwidth. - If a callable is given, it's return value is used. The callable should take exactly two parameters, i.e., ``fn(x, kern)``, and return a float, where: - ``x``: the clipped input data - ``kern``: the kernel instance used kde_points : KDEPoints This argument controls the points at which KDE is computed. If an ``int`` value is passed (default=500), the densities will be evaluated at ``kde_points`` evenly spaced points between the min and max of each set of samples. Optionally, you can also pass a custom 1D numerical array, which will be used for all traces. colorscale : str or ColorScale Any valid Plotly color-scale or a str with a valid named color-scale. Use :func:`~ridgeplot.list_all_colorscale_names()` to see which names are available or check out `Plotly's built-in color-scales`_. colormode : Colormode This argument controls the logic for choosing the color filling of each ridgeline trace. Each option provides a different method for calculating the :paramref:`colorscale` midpoint of each trace. The default is mode is ``"mean-means"``. Choices are: - ``"row-index"`` - uses the row's index. e.g. if the ridgeplot has 3 rows of traces, then the midpoints will be ``[[0, ...], [0.5, ...], [1, ...]]``. - ``"trace-index"`` - uses the trace's index. e.g. if the ridgeplot has a total of 3 traces (across all rows), then the midpoints will be 0, 0.5, and 1, respectively. - ``"mean-minmax"`` - uses the min-max normalized (weighted) mean of each density to calculate the midpoints. The normalization min and max values are the minimum and maximum x-values from all densities, respectively. - ``"mean-means"`` - uses the min-max normalized (weighted) mean of each density to calculate the midpoints. The normalization min and max values are the minimum and maximum mean values from all densities, respectively. coloralpha : float, optional If None (default), this argument will be ignored and the transparency values of the specifies color-scale will remain untouched. Otherwise, if a float value is passed, it will be used to overwrite the transparency (alpha) of the color-scale's colors. labels : LabelsArray or ShallowLabelsArray, optional A list of string labels for each trace. The default value is None, which will result in auto-generated labels of form "Trace n". If, instead, a list of labels is specified, it must be of the same size/length as the number of traces. linewidth : float The traces' line width (in px). spacing : float The vertical spacing between density traces, which is defined in units of the highest distribution (i.e. the maximum y-value). show_annotations : bool Whether to show the tick labels on the y-axis. The default is True. .. deprecated:: 0.1.21 Use :paramref:`show_yticklabels` instead. show_yticklabels : bool Whether to show the tick labels on the y-axis. The default is True. .. versionadded:: 0.1.21 Replaces the deprecated :paramref:`show_annotations` argument. xpad : float Specifies the extra padding to use on the x-axis. It is defined in units of the range between the minimum and maximum x-values from all distributions. Returns ------- :class:`plotly.graph_objects.Figure` A Plotly :class:`~plotly.graph_objects.Figure` with a ridgeline plot. You can further customize this figure to your liking (e.g. using the :meth:`~plotly.graph_objects.Figure.update_layout()` method). Raises ------ :exc:`ValueError` If both :paramref:`samples` and :paramref:`densities` are specified, or if neither of them is specified. i.e. you may only specify one of them. """ has_samples = samples is not None has_densities = densities is not None if has_samples and has_densities: raise ValueError("You may not specify both `samples` and `densities` arguments!") if not has_samples and not has_densities: raise ValueError("You must specify either `samples` or `densities`") if has_densities: if is_shallow_densities(densities): densities = cast(ShallowDensities, densities) densities = nest_shallow_collection(densities) densities = cast(Densities, densities) else: if is_shallow_samples(samples): samples = cast(ShallowSamples, samples) samples = nest_shallow_collection(samples) samples = cast(Samples, samples) # Convert samples to densities densities = estimate_densities( samples=samples, points=kde_points, kernel=kernel, bandwidth=bandwidth, ) if is_flat_str_collection(labels): labels = cast(ShallowLabelsArray, labels) labels = cast(LabelsArray, nest_shallow_collection(labels)) if colormode == "index": # type: ignore[comparison-overlap] warnings.warn( # type: ignore[unreachable] "The colormode='index' value has been deprecated in favor of " "colormode='row-index', which provides the same functionality but " "is more explicit and allows to distinguish between the " "'row-index' and 'trace-index' modes. Support for the " "deprecated value will be removed in a future version.", DeprecationWarning, stacklevel=2, ) colormode = "row-index" if show_annotations is not MISSING: warnings.warn( "The show_annotations argument has been deprecated in favor of " "show_yticklabels. Support for the deprecated argument will be " "removed in a future version.", DeprecationWarning, stacklevel=2, ) show_yticklabels = show_annotations ridgeplot_figure_factory = RidgePlotFigureFactory( densities=densities, labels=labels, colorscale=colorscale, coloralpha=coloralpha, colormode=colormode, linewidth=linewidth, spacing=spacing, show_yticklabels=show_yticklabels, xpad=xpad, ) return ridgeplot_figure_factory.make_figure()