Source code for solarwindpy.plotting.spiral

#!/usr/bin/env python
r"""Spiral mesh plots and associated binning utilities."""

import logging

import numpy as np
import pandas as pd
import matplotlib as mpl

from datetime import datetime
from numbers import Number
from collections import namedtuple
from numba import njit, prange

from matplotlib import pyplot as plt

from . import base
from . import labels as labels_module

# Edges of the starting rectangular grid, one array per axis.
InitialSpiralEdges = namedtuple("InitialSpiralEdges", ["x", "y"])
# Result of `SpiralMesh.calculate_bin_number`: per-point cell index, the fill
# value marking points outside the mesh, and per-cell visit counts.
SpiralMeshBinID = namedtuple("SpiralMeshBinID", ["id", "fill", "visited"])
# Quantile thresholds used by `SpiralMesh.cell_filter`. Only the trailing
# field (`size`) has a default of False; `density` must be supplied.
SpiralFilterThresholds = namedtuple(
    "SpiralFilterThresholds", ["density", "size"], defaults=[False]
)


@njit(parallel=True)
def get_counts_per_bin(bins, x, y):
    r"""Count how many points fall inside each rectangular cell.

    Parameters
    ----------
    bins : np.ndarray
        Rows of cell edges ``(x0, x1, y0, y1)``.
    x, y : np.ndarray
        Point coordinates.

    Returns
    -------
    np.ndarray
        ``int64`` count per row of `bins`. A point is counted when
        ``x0 <= x < x1`` and ``y0 <= y < y1`` (half-open on the upper edge).
    """
    n_cells = bins.shape[0]
    counts = np.zeros(n_cells, dtype=np.int64)
    # Each cell is independent, so the loop is parallelized with `prange`.
    for k in prange(n_cells):
        edges = bins[k]
        in_x = (x >= edges[0]) & (x < edges[1])
        in_y = (y >= edges[2]) & (y < edges[3])
        counts[k] = (in_x & in_y).sum()
    return counts
@njit(parallel=True)
def calculate_bin_number_with_numba(mesh, x, y):
    r"""Assign every point the index of the mesh cell containing it.

    Parameters
    ----------
    mesh : np.ndarray
        Rows of cell edges ``(x0, x1, y0, y1)``.
    x, y : np.ndarray
        Point coordinates.

    Returns
    -------
    zbin : np.ndarray
        ``int64`` cell index per point; points in no cell keep `fill`.
    fill : int
        The fill value, ``-9999``.
    bin_visited : np.ndarray
        ``int64`` count of loop iterations per cell.
    """
    fill = -9999
    n_cells = mesh.shape[0]
    zbin = np.full(x.size, fill, dtype=np.int64)
    bin_visited = np.zeros(n_cells, dtype=np.int64)
    for k in prange(n_cells):
        x0, x1, y0, y1 = mesh[k]
        # The largest x- and y-edges are assumed to have been pushed outward
        # (see `SpiralPlot2D.calc_initial_bins`), so a plain `<` upper bound
        # needs no special case for points sitting on the data's maximum.
        inside = (x >= x0) & (x < x1) & (y >= y0) & (y < y1)
        zbin[inside] = k
        bin_visited[k] += 1
    return zbin, fill, bin_visited
class SpiralMesh(object):
    r"""Adaptive mesh that recursively quarters over-populated cells.

    Starting from the rectangular grid given by the initial edges, every cell
    holding more than ``min_per_bin`` points is split into four equal
    quadrants. Splitting repeats until no cell exceeds the threshold.

    Parameters
    ----------
    x, y :
        Point coordinates, combined column-wise with :py:func:`pd.concat`
        (presumably :py:class:`pd.Series` sharing an index).
    initial_xedges, initial_yedges : np.ndarray
        Edges of the starting rectangular grid.
    min_per_bin : int
        Cells with more than this many points are split.
    """

    def __init__(self, x, y, initial_xedges, initial_yedges, min_per_bin=250):
        self.set_data(x, y)
        self.set_min_per_bin(min_per_bin)
        self.set_initial_edges(initial_xedges, initial_yedges)
        self._cell_filter_thresholds = SpiralFilterThresholds(
            density=False, size=False
        )

    @property
    def bin_id(self):
        r"""`SpiralMeshBinID` produced by :py:meth:`calculate_bin_number`."""
        return self._bin_id

    @property
    def cat(self):
        r""":py:class:`pd.Categorical` version of `bin_id`, with fill bin removed."""
        return self._cat

    @property
    def data(self):
        r"""DataFrame with columns ``x`` and ``y``."""
        return self._data

    @property
    def initial_edges(self):
        r"""`InitialSpiralEdges` defining the starting grid."""
        return self._initial_edges

    @property
    def mesh(self):
        r"""Final mesh, rows of ``(x0, x1, y0, y1)`` from :py:meth:`generate_mesh`."""
        return self._mesh

    @property
    def min_per_bin(self):
        r"""Cells with more than this many points are split."""
        return self._min_per_bin

    @property
    def cell_filter_thresholds(self):
        r"""`SpiralFilterThresholds` used by :py:attr:`cell_filter`."""
        return self._cell_filter_thresholds

    @property
    def cell_filter(self):
        r"""Boolean :py:class:`np.ndarray` identifying properly filled mesh cells.

        Selects mesh cells that meet the density and area criteria specified
        by :py:attr:`cell_filter_thresholds`.

        Notes
        -----
        Neither `density` nor `size` convert log-scale edges into linear scale.
        Doing so would overweight the area of mesh cells at larger values on a
        given axis.
        """
        density = self.cell_filter_thresholds.density
        size = self.cell_filter_thresholds.size

        mesh = self.mesh
        dx = mesh[:, 1] - mesh[:, 0]
        dy = mesh[:, 3] - mesh[:, 2]
        dA = dx * dy

        tk = np.ones(dA.shape, dtype=bool)
        if size:
            size_quantile = np.quantile(dA, size)
            tk &= dA < size_quantile

        if density:
            bin_id = self.bin_id
            ids = np.asarray(bin_id.id)
            # Points outside the mesh carry the (negative) fill value, which
            # `np.bincount` rejects. They belong to no cell, so drop them.
            ids = ids[ids != bin_id.fill]
            cnt = np.bincount(ids, minlength=mesh.shape[0])
            assert cnt.shape == tk.shape
            cell_density = cnt / dA
            density_quantile = np.quantile(cell_density, density)
            tk &= cell_density > density_quantile

        return tk

    def set_cell_filter_thresholds(self, **kwargs):
        r"""Set or update :py:attr:`cell_filter_thresholds`.

        Parameters
        ----------
        density: scalar
            The density quantile above which we want to select bins, e.g.
            above the 0.01 quantile. This ensures that each bin meets some
            sufficient fill factor.
        size: scalar
            The size quantile below which we want to select bins, e.g.
            below the 0.99 quantile. This ensures that the bin isn't so large
            that it will appear as an outlier.

        Raises
        ------
        KeyError
            If any keyword other than `density` or `size` is passed.
        """
        density = kwargs.pop("density", False)
        size = kwargs.pop("size", False)
        if len(kwargs.keys()):
            extra = "\n".join(["{}: {}".format(k, v) for k, v in kwargs.items()])
            raise KeyError("Unexpected kwarg\n{}".format(extra))

        self._cell_filter_thresholds = SpiralFilterThresholds(
            density=density, size=size
        )

    def set_initial_edges(self, xedges, yedges):
        r"""Store the starting-grid edges as `InitialSpiralEdges`."""
        self._initial_edges = InitialSpiralEdges(xedges, yedges)

    def set_data(self, x, y):
        r"""Store `x` and `y` as columns ``x`` and ``y`` of a DataFrame."""
        self._data = pd.concat({"x": x, "y": y}, axis=1)

    def set_min_per_bin(self, new):
        r"""Set the split threshold, coerced to ``int``."""
        self._min_per_bin = int(new)

    def initialize_bins(self):
        r"""Build the starting rectangular mesh from :py:attr:`initial_edges`.

        Returns
        -------
        np.ndarray
            ``float64`` array of shape ``(nx * ny, 4)`` with rows
            ``(x0, x1, y0, y1)``. Row ``i * ny + j`` is x-bin ``i`` and
            y-bin ``j``, i.e. x is the outer index and y the inner one.
            A copy is also stored as ``self.initial_mesh``.
        """
        xbins = np.asarray(self.initial_edges.x)
        ybins = np.asarray(self.initial_edges.y)
        nx = xbins.size - 1
        ny = ybins.size - 1

        mesh = np.column_stack(
            [
                np.repeat(xbins[:-1], ny),
                np.repeat(xbins[1:], ny),
                np.tile(ybins[:-1], nx),
                np.tile(ybins[1:], nx),
            ]
        ).astype(np.float64)

        self.initial_mesh = mesh.copy()
        return mesh

    @staticmethod
    def process_one_spiral_step(bins, x, y, min_per_bin):
        r"""Split every cell in `bins` holding more than `min_per_bin` points.

        Parameters
        ----------
        bins : np.ndarray
            Rows of ``(x0, x1, y0, y1)``. **Modified in place:** split rows
            are overwritten with NaN so they can be discarded later.
        x, y : np.ndarray
            Point coordinates.
        min_per_bin : int
            Split threshold.

        Returns
        -------
        new_cells : np.ndarray or None
            Four quadrant rows per split cell, grouped by parent cell in the
            order lower-left, lower-right, upper-right, upper-left. None when
            nothing was split.
        nbins_to_replace : int
            Number of cells split.
        """
        cell_count = get_counts_per_bin(bins, x, y)
        bins_to_replace = cell_count > min_per_bin
        nbins_to_replace = bins_to_replace.sum()
        if not nbins_to_replace:
            return None, 0

        x0, x1, y0, y1 = bins[bins_to_replace].T
        xh = 0.5 * (x0 + x1)
        yh = 0.5 * (y0 + y1)
        quadrants = [
            (x0, xh, y0, yh),
            (xh, x1, y0, yh),
            (xh, x1, yh, y1),
            (x0, xh, yh, y1),
        ]
        # Stack to (ncells, 4 quadrants, 4 edges) so each parent's quadrants
        # stay contiguous after flattening.
        new_cells = np.stack(
            [np.column_stack(quadrant) for quadrant in quadrants], axis=1
        ).reshape(-1, 4)

        bins[bins_to_replace] = np.nan
        return new_cells, nbins_to_replace

    @staticmethod
    def _visualize_logged_stats(stats_str):
        r"""Plot the per-step table logged by :py:meth:`generate_mesh`.

        `stats_str` is the header, its underline row, and one row per step
        (step number, number of divisions, elapsed time).

        Returns
        -------
        ax, tax : mpl.axes.Axes
            Elapsed-time axis and twin axis for the number of divisions.
        stats : pd.DataFrame
            Parsed table indexed by step.
        """
        from matplotlib import pyplot as plt

        stats = [[y.strip() for y in x.split(" ") if y] for x in stats_str.split("\n")]
        stats.pop(1)  # Remove column underline row
        stats = np.array(stats)

        index = pd.Index(stats[1:, 0].astype(int), name="Step")
        n_replaced = stats[1:, 1].astype(int)
        dt = pd.to_timedelta(stats[1:, 2]).total_seconds()

        # Promote to the coarsest unit that keeps the maximum readable.
        dt_unit = "s"
        if dt.max() > 60:
            dt /= 60
            dt_unit = "m"
            if dt.max() > 60:
                dt /= 60
                dt_unit = "H"
                if dt.max() > 24:
                    dt /= 24
                    dt_unit = "D"

        dt_key = f"Elapsed [{dt_unit}]"
        stats = pd.DataFrame({dt_key: dt, "N Divisions": n_replaced}, index=index)

        fig, ax = plt.subplots()
        tax = ax.twinx()
        x = stats.index

        k = f"Elapsed [{dt_unit}]"
        ax.plot(x, stats.loc[:, k], label=k, marker="+", ms=8)
        k = "N Divisions"
        tax.plot(x, stats.loc[:, k], label=k, c="C1", ls="--", marker="x", ms=8)
        tax.grid(False)

        ax.set_xlabel("Step Number")
        ax.set_ylabel(dt_key)
        tax.set_ylabel("N Divisions")

        h0, l0 = ax.get_legend_handles_labels()
        h1, l1 = tax.get_legend_handles_labels()
        ax.legend(
            h0 + h1,
            l0 + l1,
            title=rf"$\Delta t = {stats.loc[:, dt_key].sum():.0f} \, {dt_unit}$",
        )

        ax.set_yscale("log")
        tax.set_yscale("log")

        return ax, tax, stats

    def generate_mesh(self):
        r"""Refine the initial grid until no cell exceeds :py:attr:`min_per_bin`.

        Only finite points inside the initial grid's bounding box take part.
        Progress is logged at WARNING level. The resulting cells (all split
        parents discarded) are stored as :py:attr:`mesh`.
        """
        # Same logger as `calculate_bin_number`.
        logger = logging.getLogger(__name__)
        start = datetime.now()
        logger.warning(f"Generating {self.__class__.__name__} at {start}")

        x = self.data.x.values
        y = self.data.y.values
        min_per_bin = self.min_per_bin
        initial_bins = self.initialize_bins()

        # To reduce memory needs, only process data in mesh.
        x0 = initial_bins[:, 0].min()
        x1 = initial_bins[:, 1].max()
        y0 = initial_bins[:, 2].min()
        y1 = initial_bins[:, 3].max()
        tk_data_in_mesh = (
            (x0 <= x)
            & (x <= x1)
            & (y0 <= y)
            & (y <= y1)
            & np.isfinite(x)
            & np.isfinite(y)
        )
        x = x[tk_data_in_mesh]
        y = y[tk_data_in_mesh]

        initial_cell_count = get_counts_per_bin(initial_bins, x, y)
        nbins_to_replace = (initial_cell_count > min_per_bin).sum()

        list_of_bins = [initial_bins]
        active_bins = initial_bins

        logger.warning(
            """
  Step       N   Elapsed Time
====== ======= =============="""
        )

        step_start = datetime.now()
        step = 0
        while nbins_to_replace > 0:
            # Split rows in `active_bins` are NaN-ed in place, so the copy
            # already held in `list_of_bins` is updated too.
            active_bins, nbins_to_replace = self.process_one_spiral_step(
                active_bins, x, y, min_per_bin
            )
            now = datetime.now()
            logger.warning(f"{step:>6} {nbins_to_replace:>7} {(now - step_start)}")
            list_of_bins.append(active_bins)
            step += 1
            step_start = now

        list_of_bins = [b for b in list_of_bins if b is not None]
        final_bins = np.vstack(list_of_bins)
        # Drop parent cells that were replaced by their quadrants.
        final_bins = final_bins[np.isfinite(final_bins).all(axis=1)]

        stop = datetime.now()
        logger.warning(f"\nCompleted {self.__class__.__name__} at {stop}")
        logger.warning(f"Elasped time {stop - start}")
        logger.warning(f"Split bin threshold {min_per_bin}")
        logger.warning(
            f"Generated {final_bins.shape[0]} bins for {x.size} spectra (~{x.size / final_bins.shape[0]:.3f} spectra per bin)\n"
        )

        self._mesh = final_bins

    def calculate_bin_number(self):
        r"""Assign each point in :py:attr:`data` to a cell of :py:attr:`mesh`.

        Returns
        -------
        SpiralMeshBinID
            Per-point cell index (``fill`` for points outside the mesh), the
            fill value, and per-cell visit counts. Also stored as
            :py:attr:`bin_id`.

        Raises
        ------
        ValueError
            If the per-cell frequency array does not match the number of
            mesh cells.
        """
        logger = logging.getLogger(__name__)
        logger.warning(
            f"Calculating {self.__class__.__name__} bin_number at {datetime.now()}"
        )

        x = self.data.loc[:, "x"].values
        y = self.data.loc[:, "y"].values
        mesh = self.mesh
        nbins = mesh.shape[0]

        start = datetime.now()
        zbin, fill, bin_visited = calculate_bin_number_with_numba(mesh, x, y)
        stop = datetime.now()
        logger.warning(f"Elapsed time {stop - start}")

        is_fill = zbin == fill
        if is_fill.any():
            logger.warning(
                f"""`zbin` contains {is_fill.sum()} ({100 * is_fill.mean():.1f}%) fill values that are outside of mesh.
They will be replaced by NaNs and excluded from the aggregation.
"""
            )

        # `minlength=nbins` forces us to include empty bins at the end of the array.
        bin_frequency = np.bincount(zbin[~is_fill], minlength=nbins)
        n_empty = (bin_frequency == 0).sum()
        logger.warning(
            f"""Largest bin population is {bin_frequency.max()}
{n_empty} of {nbins} bins ({100 * n_empty / nbins:.1f}%) are empty
"""
        )

        # `bin_visited` is an integer array, so count zeros explicitly;
        # bitwise `~` on ints yields -(v + 1), not a boolean negation.
        n_unvisited = (bin_visited == 0).sum()
        if n_unvisited:
            logger.warning(f"{n_unvisited} bins went unvisited.")
        n_revisited = (bin_visited > 1).sum()
        if n_revisited:
            logger.warning(f"({n_revisited} bins visted more than once.")

        if nbins - bin_frequency.shape[0] != 0:
            raise ValueError(
                f"{nbins - bin_frequency.shape[0]} mesh cells do not have an associated z-value"
            )

        bin_id = SpiralMeshBinID(zbin, fill, bin_visited)
        self._bin_id = bin_id
        return bin_id

    def place_spectra_in_mesh(self):
        r"""Generate the mesh, then bin the data into it; return the bin id."""
        self.generate_mesh()
        return self.calculate_bin_number()

    def build_cat(self):
        r"""Store `bin_id` as an unordered :py:class:`pd.Categorical`.

        The fill category is removed, so out-of-mesh points become NaN and
        are excluded from any later groupby.
        """
        bin_id = self.bin_id.id
        fill = self.bin_id.fill

        # Integer number corresponds to the order over
        # which the mesh was traversed.
        cat = pd.Categorical(bin_id, ordered=False)
        if fill in cat.categories:
            # `inplace=True` was removed in pandas 2.0; rebind the result.
            cat = cat.remove_categories(fill)

        self._cat = cat
class SpiralPlot2D(base.PlotWithZdata, base.CbarMaker):
    r"""2D spiral plotting with adaptive mesh refinement.

    Examples
    --------
    splot = SpiralPlot2D(...)
    splot.initialize_mesh()
    """

    def __init__(
        self, x, y, z=None, logx=False, logy=False, initial_bins=5, clip_data=False
    ):
        super().__init__()
        self.set_log(x=logx, y=logy)
        self.set_data(x, y, z, clip_data)
        self.set_labels(x="x", y="y", z=labels_module.Count() if z is None else "z")
        self.calc_initial_bins(initial_bins)
        self.set_clim(None, None)

    @property
    def clim(self):
        r"""`RangeLimits` of allowed counts per bin, set by :py:meth:`set_clim`."""
        return self._clim

    @property
    def initial_bins(self):
        r"""Dict of initial edges keyed by ``"x"`` and ``"y"``."""
        return dict(self._initial_bins)

    @property
    def grouped(self):
        r"""z-data grouped by mesh cell, from :py:meth:`build_grouped`."""
        return self._grouped

    @property
    def mesh(self):
        r"""`SpiralMesh` attached by :py:meth:`initialize_mesh`."""
        return self._mesh

    def agg(self, fcn=None):
        r"""Aggregate the z-values into their bins.

        Parameters
        ----------
        fcn : str, callable, None
            Passed to ``groupby.agg``. If None, use ``"count"`` when z has a
            single unique value, otherwise ``"mean"``.

        Returns
        -------
        pd.Series
            One value per mesh cell (reindexed to ``0..nbins-1``). Cells
            outside :py:attr:`clim` or rejected by the mesh's ``cell_filter``
            are NaN.

        Raises
        ------
        ValueError
            If the aggregation and the cell filter have different shapes.
        """
        self.logger.debug("aggregating z-data")

        if fcn is None:
            if self.data.loc[:, "z"].unique().size == 1:
                fcn = "count"
            else:
                fcn = "mean"

        gb = self.grouped
        agg = gb.agg(fcn)

        c0, c1 = self.clim
        if c0 is not None or c1 is not None:
            cnt = gb.agg("count")
            tk = pd.Series(True, index=agg.index)
            if c0 is not None:
                tk = tk & (cnt >= c0)
            if c1 is not None:
                tk = tk & (cnt <= c1)

            agg = agg.where(tk)

        # reindex to ensure we have a z-value for every bin.
        reindex = pd.RangeIndex(start=0, stop=self.mesh.mesh.shape[0], step=1)
        agg = agg.reindex(reindex)

        cell_filter = self.mesh.cell_filter
        if agg.shape != cell_filter.shape:
            raise ValueError(
                f"""Unable to algin `agg` and `cell_filter.
agg    : {agg.shape}
filter : {cell_filter.shape}"""
            )

        agg = agg.where(cell_filter, axis=0)
        return agg

    def build_grouped(self):
        r"""Group z by the mesh's categorical bin ids; store as :py:attr:`grouped`.

        Raises
        ------
        ValueError
            If the categorical and z-data lengths differ.
        """
        cat = self.mesh.cat
        z = self.data.loc[:, "z"]
        if not (cat.size == z.size):
            raise ValueError(
                f"""`cat` must have same size as data's first dimesion
cat  : {cat.size}
data : {z.size}
"""
            )

        # `observed=False` is the current pandas default for categorical
        # groupers; passing it explicitly avoids the FutureWarning about the
        # default changing. `agg` reindexes to every mesh cell regardless.
        self._grouped = z.groupby(cat, observed=False)

    def calc_initial_bins(self, nbins):
        r"""Compute the starting-grid edges for each axis.

        Parameters
        ----------
        nbins : int, or length-2 sequence of int / np.ndarray
            A single int applies to both axes. Per axis, an int ``N`` yields
            ``N + 1`` edges at evenly spaced quantiles of the finite data; an
            ndarray is used as explicit edges (copied, never modified).

        Returns
        -------
        tuple
            ``(("x", xedges), ("y", yedges))``; also stored for
            :py:attr:`initial_bins`.

        Raises
        ------
        ValueError
            If `nbins` is neither an int nor a two-element sequence.
        TypeError
            If a per-axis spec is neither an int nor an ndarray.
        """
        data = self.data
        keys = ("x", "y")

        if isinstance(nbins, (int, np.integer)):
            # Single paramter for `nbins`.
            nbins = {k: nbins for k in keys}
        elif len(nbins) == len(keys):
            # Passed one bin spec per axis
            nbins = dict(zip(keys, nbins))
        else:
            msg = f"Unrecognized `nbins`\ntype: {type(nbins)}\n bins:{nbins}"
            raise ValueError(msg)

        bins = {}
        for k, b in nbins.items():
            # Numpy and Astropy don't like NaNs when calculating bins.
            # Infinities in bins (typically from log10(0)) also create problems.
            d = data.loc[:, k].replace([-np.inf, np.inf], np.nan).dropna()

            if isinstance(b, (int, np.integer)):
                # Need N + 1 edges to make N bins.
                b = np.quantile(d, np.linspace(0, 1, b + 1))
            elif isinstance(b, np.ndarray):
                # Copy as float so the caller's array is untouched and the
                # edge extension below is not truncated for integer input.
                b = np.array(b, dtype=np.float64)
            else:
                raise TypeError("Only want in integer or np.ndarrays for initial edges")

            # Extend the right-most edge by the larger of 1% (of its magnitude)
            # and 0.01, so that `y < y1` includes data at the real data edge.
            # Using `abs` keeps the extension outward for negative maxima
            # (common with log10 data).
            top = b.max()
            b[-1] = top + max(0.01, 0.01 * abs(top))

            assert not np.isnan(b).any()
            bins[k] = b

        bins = tuple(bins.items())
        self._initial_bins = bins
        return bins

    def initialize_mesh(self, **kwargs):
        r"""Build the `SpiralMesh`, bin the data, and build its categorical.

        ``**kwargs`` are passed to :py:class:`SpiralMesh` (e.g. ``min_per_bin``).
        """
        x = self.data.loc[:, "x"]
        y = self.data.loc[:, "y"]
        xbins = self.initial_bins["x"]
        ybins = self.initial_bins["y"]

        mesh = SpiralMesh(x, y, xbins, ybins, **kwargs)
        # Attach mesh before anything else. Makes debugging easier.
        self._mesh = mesh
        mesh.place_spectra_in_mesh()
        mesh.build_cat()

    def set_clim(self, lower=None, upper=None):
        """Set the min (lower) and max (upper) counts per bin.

        This limit is applied after the :py:meth:`groupby.agg` is run.
        """
        assert isinstance(lower, Number) or lower is None
        assert isinstance(upper, Number) or upper is None
        self._clim = base.RangeLimits(lower, upper)

    def set_data(self, x, y, z, clip):
        r"""Store the data; for log axes, replace the column with ``log10(|v|)``."""
        super().set_data(x, y, z, clip)
        data = self.data
        # Assign whole columns so a float result can replace an integer
        # column; `.loc[:, col] = ...` is an incompatible-dtype setitem there.
        if self.log.x:
            data["x"] = np.log10(np.abs(data.loc[:, "x"]))
        if self.log.y:
            data["y"] = np.log10(np.abs(data.loc[:, "y"]))
        self._data = data

    def _limit_color_norm(self, norm):
        r"""Fill unset `vmin`/`vmax` from the 1st/99th z percentiles and clip."""
        pct = self.data.loc[:, "z"].quantile([0.01, 0.99])
        v0 = pct.loc[0.01]
        v1 = pct.loc[0.99]
        if norm.vmin is None:
            norm.vmin = v0
        if norm.vmax is None:
            norm.vmax = v1
        norm.clip = True

    def make_plot(
        self,
        ax=None,
        cbar=True,
        limit_color_norm=False,
        cbar_kwargs=None,
        fcn=None,
        alpha_fcn=None,
        **kwargs,
    ):
        r"""Draw every mesh cell as a rectangle colored by its aggregated z.

        Parameters
        ----------
        ax : mpl.axes.Axes, None
            If None, create one with ``plt.subplots``.
        cbar : bool
            If True, add a colorbar via ``self._make_cbar``.
        limit_color_norm : bool
            If True and a `norm` is passed, fill its unset limits from the
            1st/99th z percentiles and enable clipping.
        cbar_kwargs : dict, None
            Passed to ``self._make_cbar``; ``ax`` defaults to `ax`.
        fcn : str, callable, None
            Aggregation passed to :py:meth:`agg`.
        alpha_fcn : str, callable, None
            If given, a second aggregation whose feature-scaled and inverted
            value, raised to the 0.25 power, sets each cell's opacity.
        **kwargs
            Only ``cmap`` and ``norm`` are accepted.

        Returns
        -------
        ax : mpl.axes.Axes
        cbar_or_mappable : colorbar or PatchCollection

        Raises
        ------
        ValueError
            On unexpected kwargs or if z-values and mesh are misaligned.
        """
        if ax is None:
            fig, ax = plt.subplots()

        C = self.agg(fcn=fcn)
        C = np.ma.masked_invalid(C.values)
        assert isinstance(C, np.ndarray)
        assert C.ndim == 1
        if C.shape[0] != self.mesh.mesh.shape[0]:
            raise ValueError(
                f"""{self.mesh.mesh.shape[0] - C.shape[0]} mesh cells do not have a z-value associated with them. The z-values and mesh are not properly aligned."""
            )

        xmesh = self.mesh.mesh[:, [0, 1]]
        ymesh = self.mesh.mesh[:, [2, 3]]
        if self.log.x:
            xmesh = 10.0**xmesh
        if self.log.y:
            ymesh = 10.0**ymesh

        # (x,y) of bin's lower left corner.
        xy = zip(xmesh[:, 0], ymesh[:, 0])
        dx = xmesh[:, 1] - xmesh[:, 0]
        dy = ymesh[:, 1] - ymesh[:, 0]

        start1 = datetime.now()
        self.logger.warning("Making patches")
        self.logger.warning(f"Start {start1}")
        patches = [
            mpl.patches.Rectangle(this_xy, this_dx, this_dy)
            for this_xy, this_dx, this_dy in zip(xy, dx, dy)
        ]
        stop1 = datetime.now()
        self.logger.warning(f"Stop {stop1}")
        self.logger.warning(f"Elapsed {stop1 - start1}")

        # TODO: `match_original=False` if calculate alpha for each patch.
        edgecolors = "none"
        linewidth = 0.0
        collection = mpl.collections.PatchCollection(
            patches, linewidth=linewidth, edgecolors=edgecolors
        )
        collection.set_array(C)

        cmap = kwargs.pop("cmap", None)
        norm = kwargs.pop("norm", None)
        if len(kwargs):
            raise ValueError(f"Unexpected kwargs {kwargs.keys()}")

        if limit_color_norm and norm is not None:
            self._limit_color_norm(norm)

        collection.set_alpha(None)
        collection.set_cmap(cmap)
        collection.set_norm(norm)
        collection.autoscale_None()

        ax.add_collection(collection, autolim=False)

        minx = xmesh[:, 0].min()
        miny = ymesh[:, 0].min()
        maxx = xmesh[:, 1].max()
        maxy = ymesh[:, 1].max()
        collection.sticky_edges.x[:] = [minx, maxx]
        collection.sticky_edges.y[:] = [miny, maxy]
        corners = (minx, miny), (maxx, maxy)
        ax.update_datalim(corners)
        ax.autoscale_view()

        cbar_or_mappable = collection
        if cbar:
            if cbar_kwargs is None:
                cbar_kwargs = dict()
            if "cax" not in cbar_kwargs.keys() and "ax" not in cbar_kwargs.keys():
                cbar_kwargs["ax"] = ax
            cbar_or_mappable = self._make_cbar(collection, norm=norm, **cbar_kwargs)

        self._format_axis(ax)

        if alpha_fcn is not None:
            alpha_agg = np.ma.masked_invalid(self.agg(fcn=alpha_fcn).values)
            # Feature scale then invert so smallest STD
            # is most opaque.
            alpha_agg = mpl.colors.Normalize()(alpha_agg)
            alpha = 1 - alpha_agg
            self.logger.warning("Scaling alpha filter as alpha**0.25")
            alpha = alpha**0.25

            # Set masked values to zero. Otherwise, masked
            # values are rendered as black.
            alpha = alpha.filled(0)

            # Map the data array to per-patch facecolors on *this* collection.
            # `plt.draw()` only draws the current figure, so it failed to
            # initialize facecolors when `ax` lived on a different figure.
            collection.update_scalarmappable()
            colors = collection.get_facecolors()
            colors[:, 3] = alpha
            collection.set_facecolor(colors)

        return ax, cbar_or_mappable

    def _verify_contour_passthrough_kwargs(
        self, ax, clabel_kwargs, edges_kwargs, cbar_kwargs
    ):
        r"""Replace None kwargs with empty dicts; default the colorbar's ``ax``."""
        if clabel_kwargs is None:
            clabel_kwargs = dict()
        if edges_kwargs is None:
            edges_kwargs = dict()
        if cbar_kwargs is None:
            cbar_kwargs = dict()
        if "cax" not in cbar_kwargs.keys() and "ax" not in cbar_kwargs.keys():
            cbar_kwargs["ax"] = ax

        return clabel_kwargs, edges_kwargs, cbar_kwargs

    def _interpolate_to_grid(self, x, y, z, resolution=100, method="cubic"):
        r"""Interpolate scattered data to a regular grid.

        Parameters
        ----------
        x, y : np.ndarray
            Coordinates of data points.
        z : np.ndarray
            Values at data points.
        resolution : int
            Number of grid points along each axis.
        method : {"linear", "cubic", "nearest"}
            Interpolation method passed to :func:`scipy.interpolate.griddata`.

        Returns
        -------
        XX, YY : np.ndarray
            2D meshgrid arrays.
        ZZ : np.ndarray
            Interpolated values on the grid.
        """
        from scipy.interpolate import griddata

        xi = np.linspace(x.min(), x.max(), resolution)
        yi = np.linspace(y.min(), y.max(), resolution)
        XX, YY = np.meshgrid(xi, yi)
        ZZ = griddata((x, y), z, (XX, YY), method=method)
        return XX, YY, ZZ

    def _interpolate_with_rbf(
        self,
        x,
        y,
        z,
        resolution=100,
        neighbors=50,
        smoothing=1.0,
        kernel="thin_plate_spline",
    ):
        r"""Interpolate scattered data using sparse RBF.

        Uses :class:`scipy.interpolate.RBFInterpolator` with the ``neighbors``
        parameter, so each evaluation uses only the nearest data points rather
        than all of them.

        Parameters
        ----------
        x, y : np.ndarray
            Coordinates of data points.
        z : np.ndarray
            Values at data points.
        resolution : int
            Number of grid points along each axis.
        neighbors : int
            Number of nearest neighbors to use for each interpolation point.
            Higher values produce smoother results but increase computation time.
        smoothing : float
            Smoothing parameter. Higher values produce smoother surfaces.
        kernel : str
            RBF kernel type. Options include "thin_plate_spline", "cubic",
            "quintic", "multiquadric", "inverse_multiquadric", "gaussian".

        Returns
        -------
        XX, YY : np.ndarray
            2D meshgrid arrays.
        ZZ : np.ndarray
            Interpolated values on the grid.
        """
        from scipy.interpolate import RBFInterpolator

        points = np.column_stack([x, y])
        rbf = RBFInterpolator(
            points, z, neighbors=neighbors, smoothing=smoothing, kernel=kernel
        )

        xi = np.linspace(x.min(), x.max(), resolution)
        yi = np.linspace(y.min(), y.max(), resolution)
        XX, YY = np.meshgrid(xi, yi)

        grid_pts = np.column_stack([XX.ravel(), YY.ravel()])
        ZZ = rbf(grid_pts).reshape(XX.shape)

        return XX, YY, ZZ

    def plot_contours(
        self,
        ax=None,
        method="rbf",
        # RBF method params (default method)
        rbf_neighbors=50,
        rbf_smoothing=1.0,
        rbf_kernel="thin_plate_spline",
        # Grid method params
        grid_resolution=100,
        gaussian_filter_std=1.5,
        interpolation="cubic",
        nan_aware_filter=True,
        # Common params
        label_levels=True,
        cbar=True,
        cbar_kwargs=None,
        fcn=None,
        clabel_kwargs=None,
        skip_max_clbl=True,
        use_contourf=False,
        **kwargs,
    ):
        r"""Make a contour plot from adaptive mesh data with optional smoothing.

        Supports three interpolation methods for generating contours from the
        irregular adaptive mesh:

        - ``"rbf"``: Sparse RBF interpolation (default, built-in smoothing)
        - ``"grid"``: Grid interpolation + Gaussian smoothing (matches Hist2D API)
        - ``"tricontour"``: Direct triangulated contours (no smoothing, for debugging)

        Parameters
        ----------
        ax : mpl.axes.Axes, None
            If None, create an Axes instance from ``plt.subplots``.
        method : {"rbf", "grid", "tricontour"}
            Interpolation method. Default is ``"rbf"``.

        RBF Method Parameters
        ---------------------
        rbf_neighbors : int
            Number of nearest neighbors for sparse RBF. Higher = smoother but
            slower. Default is 50.
        rbf_smoothing : float
            RBF smoothing parameter. Higher values produce smoother surfaces.
            Default is 1.0.
        rbf_kernel : str
            RBF kernel type. Options: "thin_plate_spline", "cubic", "quintic",
            "multiquadric", "inverse_multiquadric", "gaussian".

        Grid Method Parameters
        ----------------------
        grid_resolution : int
            Number of grid points along each axis (also used by ``"rbf"``).
            Default is 100.
        gaussian_filter_std : float
            Standard deviation for Gaussian smoothing, applied after either
            ``"rbf"`` or ``"grid"`` interpolation. Default is 1.5. Set to 0 to
            disable smoothing.
        interpolation : {"linear", "cubic", "nearest"}
            Interpolation method for griddata. Default is "cubic".
        nan_aware_filter : bool
            If True, use NaN-aware Gaussian filtering. Default is True.

        Common Parameters
        -----------------
        label_levels : bool
            If True, add labels to contours with ``ax.clabel``. Default is True.
        cbar : bool
            If True, create a colorbar. Default is True.
        cbar_kwargs : dict, None
            Keyword arguments passed to ``self._make_cbar``.
        fcn : callable, None
            Aggregation function. If None, automatically select in :py:meth:`agg`.
        clabel_kwargs : dict, None
            Keyword arguments passed to ``ax.clabel``.
        skip_max_clbl : bool
            If True, don't label the maximum contour level. Default is True.
        use_contourf : bool
            If True, use filled contours. Default is False.
        **kwargs
            Additional arguments passed to the contour function.
            Common options: ``levels``, ``cmap``, ``norm``, ``linestyles``.

        Returns
        -------
        ax : mpl.axes.Axes
            The axes containing the plot.
        lbls : list or None
            Contour labels if ``label_levels=True``, else None.
        cbar_or_mappable : Colorbar or QuadContourSet
            The colorbar if ``cbar=True``, else the contour set.
        qset : QuadContourSet
            The contour set object.

        Raises
        ------
        ValueError
            If `method` is invalid or z-values and mesh are misaligned.

        Examples
        --------
        >>> # Default: sparse RBF
        >>> ax, lbls, cbar, qset = splot.plot_contours()

        >>> # Grid interpolation with Gaussian smoothing
        >>> ax, lbls, cbar, qset = splot.plot_contours(
        ...     method='grid',
        ...     grid_resolution=100,
        ...     gaussian_filter_std=2.0
        ... )

        >>> # Debug: see raw triangulation
        >>> ax, lbls, cbar, qset = splot.plot_contours(method='tricontour')
        """
        from .tools import nan_gaussian_filter

        # Validate method
        valid_methods = ("rbf", "grid", "tricontour")
        if method not in valid_methods:
            raise ValueError(
                f"Invalid method '{method}'. Must be one of {valid_methods}."
            )

        # Pop contour-specific kwargs
        levels = kwargs.pop("levels", None)
        cmap = kwargs.pop("cmap", None)
        norm = kwargs.pop("norm", None)
        linestyles = kwargs.pop(
            "linestyles",
            [
                "-",
                ":",
                "--",
                (0, (7, 3, 1, 3, 1, 3, 1, 3, 1, 3)),
                "--",
                ":",
                "-",
                (0, (7, 3, 1, 3, 1, 3)),
            ],
        )

        if ax is None:
            fig, ax = plt.subplots()

        # Setup kwargs for clabel and cbar
        (
            clabel_kwargs,
            _edges_kwargs,
            cbar_kwargs,
        ) = self._verify_contour_passthrough_kwargs(
            ax, clabel_kwargs, None, cbar_kwargs
        )

        inline = clabel_kwargs.pop("inline", True)
        inline_spacing = clabel_kwargs.pop("inline_spacing", -3)
        fmt = clabel_kwargs.pop("fmt", "%s")

        # Get aggregated data and mesh cell centers
        C = self.agg(fcn=fcn).values
        if C.shape[0] != self.mesh.mesh.shape[0]:
            raise ValueError(
                f"{self.mesh.mesh.shape[0] - C.shape[0]} mesh cells do not have "
                "a z-value. The z-values and mesh are not properly aligned."
            )

        x = self.mesh.mesh[:, [0, 1]].mean(axis=1)
        y = self.mesh.mesh[:, [2, 3]].mean(axis=1)

        if self.log.x:
            x = 10.0**x
        if self.log.y:
            y = 10.0**y

        # Filter to finite values
        tk_finite = np.isfinite(C)
        x = x[tk_finite]
        y = y[tk_finite]
        C = C[tk_finite]

        # Select contour function based on method
        if method == "tricontour":
            # Direct triangulated contour (no smoothing)
            contour_fcn = ax.tricontourf if use_contourf else ax.tricontour
            if levels is None:
                args = [x, y, C]
            else:
                args = [x, y, C, levels]
            qset = contour_fcn(
                *args, linestyles=linestyles, cmap=cmap, norm=norm, **kwargs
            )
        else:
            # Interpolate to regular grid (rbf or grid method)
            if method == "rbf":
                XX, YY, ZZ = self._interpolate_with_rbf(
                    x,
                    y,
                    C,
                    resolution=grid_resolution,
                    neighbors=rbf_neighbors,
                    smoothing=rbf_smoothing,
                    kernel=rbf_kernel,
                )
            else:  # method == "grid"
                XX, YY, ZZ = self._interpolate_to_grid(
                    x,
                    y,
                    C,
                    resolution=grid_resolution,
                    method=interpolation,
                )

            # Apply Gaussian smoothing if requested
            if gaussian_filter_std > 0:
                if nan_aware_filter:
                    ZZ = nan_gaussian_filter(ZZ, sigma=gaussian_filter_std)
                else:
                    from scipy.ndimage import gaussian_filter

                    ZZ = gaussian_filter(
                        np.nan_to_num(ZZ, nan=0), sigma=gaussian_filter_std
                    )

            # Mask invalid values
            ZZ = np.ma.masked_invalid(ZZ)

            # Standard contour on regular grid
            contour_fcn = ax.contourf if use_contourf else ax.contour
            if levels is None:
                args = [XX, YY, ZZ]
            else:
                args = [XX, YY, ZZ, levels]
            qset = contour_fcn(
                *args, linestyles=linestyles, cmap=cmap, norm=norm, **kwargs
            )

        # Handle contour labels. `levels` may be None or an int (level count),
        # neither of which slices; fall back to labelling every level.
        try:
            label_args = (qset, levels[:-1] if skip_max_clbl else levels)
        except TypeError:
            label_args = (qset,)

        class _NumericFormatter(float):
            """Format float without trailing zeros for contour labels."""

            def __repr__(self):
                # Use float's repr to avoid recursion (str(self) calls __repr__)
                return float.__repr__(self).rstrip("0").rstrip(".")

        lbls = None
        if label_levels and len(qset.levels) > 0:
            qset.levels = [_NumericFormatter(level) for level in qset.levels]
            lbls = ax.clabel(
                *label_args,
                inline=inline,
                inline_spacing=inline_spacing,
                fmt=fmt,
                **clabel_kwargs,
            )

        # Add colorbar
        cbar_or_mappable = qset
        if cbar:
            cbar_or_mappable = self._make_cbar(qset, norm=norm, **cbar_kwargs)

        self._format_axis(ax)

        return ax, lbls, cbar_or_mappable, qset