#!/usr/bin/env python
r"""Utility functions for common :mod:`matplotlib` tasks.
These helpers provide shortcuts for creating figures, saving output, building grids
of axes with shared colorbars, and NaN-aware image filtering.
"""
import logging
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from datetime import datetime
from pathlib import Path
from scipy.ndimage import gaussian_filter
# Path to the solarwindpy style file
_STYLE_PATH = Path(__file__).parent / "solarwindpy.mplstyle"
[docs]
def use_style():
    r"""Activate the SolarWindPy matplotlib style sheet.

    The style sheet located at ``_STYLE_PATH`` is handed to
    :func:`matplotlib.pyplot.style.use`. It is intended to provide
    publication-ready defaults including:

    - 4x4 inch figure size
    - 12pt base font size
    - Spectral_r colormap
    - 300 DPI PDF output

    Examples
    --------
    >>> import solarwindpy.plotting as swp_pp
    >>> swp_pp.use_style()  # doctest: +SKIP
    """
    style_sheet = _STYLE_PATH
    plt.style.use(style_sheet)
[docs]
def subplots(nrows=1, ncols=1, scale_width=1.0, scale_height=1.0, **kwargs):
    r"""Build a subplot grid whose figure size grows with the grid shape.

    The base size (the ``figsize`` keyword if given, otherwise
    ``rcParams["figure.figsize"]``) is treated as the size of one panel: its
    width is multiplied by ``scale_width * ncols`` and its height by
    ``scale_height * nrows``.

    Parameters
    ----------
    nrows : int, optional
        Number of subplot rows.
    ncols : int, optional
        Number of subplot columns.
    scale_width : float, optional
        Extra factor applied to the figure width.
    scale_height : float, optional
        Extra factor applied to the figure height.
    **kwargs
        Remaining keyword arguments are forwarded unchanged to
        :func:`matplotlib.pyplot.subplots`. A ``figsize`` entry is consumed
        here and used as the base size that gets scaled.

    Returns
    -------
    fig : :class:`matplotlib.figure.Figure`
    ax : :class:`matplotlib.axes.Axes` or array of Axes

    Examples
    --------
    >>> fig, ax = subplots(2, 2, scale_width=1.5)
    """
    base_size = kwargs.pop("figsize", mpl.rcParams["figure.figsize"])
    # (width factor, height factor): one base-size panel per grid cell.
    grid_scale = np.array([scale_width * ncols, scale_height * nrows])
    return plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=grid_scale * base_size,
        **kwargs,
    )
[docs]
def save(
    fig,
    spath,
    add_info=True,
    info_x=0,
    info_y=0,
    log=True,
    pdf=True,
    png=True,
    **kwargs,
):
    r"""Save a figure as PDF and/or PNG.

    The PDF is written first and without the attribution stamp so it can be
    used directly in a final LaTeX document. The PNG, which is what is usually
    rendered while working and drafting, optionally carries an attribution and
    timestamp. The stamp is removed from the figure again once the PNG has
    been written, so repeated calls neither stack stamps on top of each other
    nor leak an earlier stamp into a later PDF.

    Parameters
    ----------
    fig : :class:`matplotlib.figure.Figure` or :class:`matplotlib.axes.Axes`
        The figure to save. If an axis is given, its parent figure is saved.
    spath : :class:`pathlib.Path`
        Base path for the output files. The suffix is replaced with ``.pdf``
        and ``.png`` respectively.
    add_info : bool, optional
        If ``True``, add an attribution and timestamp to the PNG version.
    info_x : float, optional
        X-position of the attribution text in figure coordinates.
    info_y : float, optional
        Y-position of the attribution text in figure coordinates.
    log : bool, optional
        If ``True``, log the saved files at INFO level to the logger named
        after this module.
    pdf : bool, optional
        Save a PDF version of the figure.
    png : bool, optional
        Save a PNG version of the figure.
    **kwargs
        Additional keyword arguments passed to :meth:`Figure.savefig`.
        ``bbox_inches`` defaults to ``"tight"``.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If ``fig`` is not a Figure (or Axes) or ``spath`` is not a
        :class:`pathlib.Path`.

    Examples
    --------
    >>> fig, ax = subplots()
    >>> save(fig, Path('my_plot'))
    """
    if isinstance(fig, mpl.axes.Axes):
        fig = fig.figure

    assert isinstance(fig, mpl.figure.Figure)
    assert isinstance(spath, Path)

    bbox_inches = kwargs.pop("bbox_inches", "tight")

    alog = logging.getLogger(__name__) if log else None
    if log:
        alog.info("Saving figure\n%s", spath.resolve().with_suffix(""))

    if pdf:
        fig.savefig(
            spath.with_suffix(".pdf"),
            bbox_inches=bbox_inches,
            format="pdf",
            **kwargs,
        )
        if log:
            alog.info("Suffix saved: pdf")

    if png:
        stamp = None
        if add_info:
            info = "B. L. Alterman {}".format(datetime.now().strftime("%Y%m%dT%H%M%S"))
            stamp = fig.text(info_x, info_y, info)

        try:
            fig.savefig(
                spath.with_suffix(".png"),
                bbox_inches=bbox_inches,
                format="png",
                **kwargs,
            )
        finally:
            # Take the stamp off again even if saving failed; otherwise every
            # call would leave another text artist on the figure.
            if stamp is not None:
                stamp.remove()

        if log:
            alog.info("Suffix saved: png")
[docs]
def joint_legend(*axes, idx_for_legend=-1, **kwargs):
    r"""Create a combined legend for multiple axes.

    Legend entries are collected from every axis, de-duplicated by label (the
    first handle seen for a label wins), sorted alphabetically by label and
    drawn on a single axis.

    Parameters
    ----------
    *axes : :class:`matplotlib.axes.Axes`
        Axes objects from which to collect legend handles and labels.
    idx_for_legend : int, optional
        Index of the axis (after flattening) on which to place the legend.
        By default the legend is placed on the last axis. ``idx_for_legend=-1``
        assumes that the last axis is on the right hand side of the figure.
    **kwargs
        Extra keyword arguments forwarded to :meth:`Axes.legend`. ``loc``
        defaults to ``(1.05, 0.1)``, i.e. just outside the right edge of the
        chosen axis.

    Returns
    -------
    legend : :class:`matplotlib.legend.Legend`

    Examples
    --------
    >>> fig, ax = subplots(1, 2)
    >>> ax[0].plot([1, 2], label='a')  # doctest: +ELLIPSIS
    [<matplotlib.lines.Line2D object at 0x...>]
    >>> ax[1].plot([2, 3], label='b')  # doctest: +ELLIPSIS
    [<matplotlib.lines.Line2D object at 0x...>]
    >>> joint_legend(ax[0], ax[1])  # doctest: +ELLIPSIS
    <matplotlib.legend.Legend object at 0x...>
    """
    axes = np.array(axes).ravel()

    # Map each label to the first handle that carries it.
    handle_by_label = {}
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            if label in handle_by_label:
                continue
            if isinstance(handle, mpl.container.ErrorbarContainer):
                # Came from `ax.errorbar`: keep only the data line.
                handle = handle[0]
            handle_by_label[label] = handle

    # Sort with plain Python containers. Coercing handles into a NumPy array
    # would unpack tuple-like handles (e.g. bar containers) into extra
    # dimensions or fail on ragged input.
    labels = sorted(handle_by_label)
    handles = [handle_by_label[label] for label in labels]

    loc = kwargs.pop("loc", (1.05, 0.1))
    return axes[idx_for_legend].legend(handles, labels, loc=loc, **kwargs)
[docs]
def build_ax_array_with_common_colorbar(  # noqa: C901 - complexity justified by 4 cbar positions
    nrows=1,
    ncols=1,
    cbar_loc="top",
    figsize="auto",
    sharex=True,
    sharey=True,
    hspace=0,
    wspace=0,
    fig_kwargs=None,
    gs_kwargs=None,
):
    r"""Build an array of axes that share a colour bar.

    The figure is laid out on a :class:`matplotlib.gridspec.GridSpec` holding
    the ``nrows x ncols`` panels plus two thin extra tracks on the
    ``cbar_loc`` side: a spacer (ratio 0.05) next to the panels and the
    colorbar axis (ratio 0.075) on the outside.

    Parameters
    ----------
    nrows, ncols : int, optional
        Desired grid shape.
    cbar_loc : {"top", "bottom", "left", "right"}, optional
        Location of the colorbar relative to the axes grid. Case-insensitive.
    figsize : tuple or "auto", optional
        Figure size as (width, height) in inches. If ``"auto"`` (default),
        scales from ``rcParams["figure.figsize"]`` based on nrows/ncols, with
        an extra factor of 1.3 along the direction that holds the colorbar.
    sharex : bool, optional
        If ``True``, share x-axis limits across all panels. Default ``True``.
    sharey : bool, optional
        If ``True``, share y-axis limits across all panels. Default ``True``.
    hspace : float, optional
        Vertical spacing between subplots. Default ``0``.
    wspace : float, optional
        Horizontal spacing between subplots. Default ``0``.
    fig_kwargs : dict, optional
        Keyword arguments forwarded to :func:`matplotlib.pyplot.figure`.
    gs_kwargs : dict, optional
        Additional options for :class:`matplotlib.gridspec.GridSpec`.

    Returns
    -------
    fig : :class:`matplotlib.figure.Figure`
    axes : ndarray of :class:`matplotlib.axes.Axes`
        Squeezed array of panels; a single ``Axes`` when the grid is 1x1.
    cax : :class:`matplotlib.axes.Axes`
        Axis reserved for the colorbar.

    Raises
    ------
    ValueError
        If ``cbar_loc`` is not one of the four supported locations.

    Examples
    --------
    >>> fig, axes, cax = build_ax_array_with_common_colorbar(2, 3, cbar_loc='right')  # doctest: +SKIP
    >>> fig, axes, cax = build_ax_array_with_common_colorbar(3, 1, figsize=(5, 12))  # doctest: +SKIP
    """
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    gs_kwargs = {} if gs_kwargs is None else gs_kwargs

    valid_locs = ("top", "bottom", "left", "right")
    cbar_loc = cbar_loc.lower()
    if cbar_loc not in valid_locs:
        raise ValueError(f"cbar_loc must be one of {valid_locs}, got {cbar_loc!r}")
    cbar_is_vertical = cbar_loc in ("right", "left")

    # Compute figsize. The isinstance guard keeps array-like sizes (e.g. a
    # NumPy array) from being compared element-wise against the string.
    if isinstance(figsize, str) and figsize == "auto":
        base_figsize = np.array(mpl.rcParams["figure.figsize"])
        fig_scale = np.array([ncols, nrows])
        # Leave extra room along the direction that holds the colorbar.
        cbar_scale = np.array([1.3, 1]) if cbar_is_vertical else np.array([1, 1.3])
        figsize = base_figsize * fig_scale * cbar_scale

    # Compute grid ratios (independent of figsize). The lists are written for
    # "right"/"top" and reversed for the opposite side.
    if cbar_is_vertical:
        height_ratios = nrows * [1]
        width_ratios = (ncols * [1]) + [0.05, 0.075]
        if cbar_loc == "left":
            width_ratios = width_ratios[::-1]
    else:
        height_ratios = [0.075, 0.05] + (nrows * [1])
        if cbar_loc == "bottom":
            height_ratios = height_ratios[::-1]
        width_ratios = ncols * [1]

    fig = plt.figure(figsize=figsize, **fig_kwargs)
    gs = mpl.gridspec.GridSpec(
        len(height_ratios),
        len(width_ratios),
        hspace=hspace,
        wspace=wspace,
        height_ratios=height_ratios,
        width_ratios=width_ratios,
        **gs_kwargs,
    )

    # Pick the colorbar cell and the grid indices occupied by the panels.
    # When the colorbar leads (left/top) the panels start at index 2, after
    # the colorbar and spacer tracks.
    if cbar_loc == "left":
        cax_spec = gs[:, 0]
        col_range = range(2, ncols + 2)
        row_range = range(nrows)
    elif cbar_loc == "right":
        cax_spec = gs[:, -1]
        col_range = range(ncols)
        row_range = range(nrows)
    elif cbar_loc == "top":
        cax_spec = gs[0, :]
        col_range = range(ncols)
        row_range = range(2, nrows + 2)
    else:  # "bottom"
        cax_spec = gs[-1, :]
        col_range = range(ncols)
        row_range = range(nrows)

    cax = fig.add_subplot(cax_spec)

    # Create the panels; every panel after the first shares its axes with the
    # first one when requested.
    axes = np.empty((nrows, ncols), dtype=object)
    first_ax = None
    for row_idx, i in enumerate(row_range):
        for col_idx, j in enumerate(col_range):
            if first_ax is None:
                ax = fig.add_subplot(gs[i, j])
                first_ax = ax
            else:
                ax = fig.add_subplot(
                    gs[i, j],
                    sharex=first_ax if sharex else None,
                    sharey=first_ax if sharey else None,
                )
            axes[row_idx, col_idx] = ax

    if cbar_loc == "top":
        cax.xaxis.set_ticks_position("top")
        cax.xaxis.set_label_position("top")
    elif cbar_loc == "left":
        cax.yaxis.set_ticks_position("left")
        cax.yaxis.set_label_position("left")

    axes = axes.squeeze()
    if axes.ndim == 0:
        # 1x1 grid: hand back the bare Axes instead of a 0-d array.
        axes = axes.item()

    return fig, axes, cax
[docs]
def calculate_nrows_ncols(n):
    r"""Determine a sensible ``(nrows, ncols)`` pair for ``n`` axes.

    ``n`` is split into the factor pair closest to its square root. If that
    pair is degenerate (``1 x n``) and ``n > 4``, the pair for ``n + 1`` is
    used instead, leaving one grid cell spare. Layouts whose longer side is
    shorter than 4 are returned wide (``nrows <= ncols``); longer layouts are
    returned tall (``nrows >= ncols``).

    Parameters
    ----------
    n : int
        Total number of axes required. Must be a positive integer.

    Returns
    -------
    nrows : int
    ncols : int

    Raises
    ------
    ValueError
        If ``n`` is not a positive integer.

    Examples
    --------
    >>> calculate_nrows_ncols(5)  # doctest: +ELLIPSIS
    (...2..., ...3...)
    """

    def _factor_pair(total):
        # Largest divisor of `total` not exceeding its square root, plus cofactor.
        small = int(np.sqrt(total))
        while total % small:
            small -= 1
        return small, total // small

    count = int(n)
    if count != n or count < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    root, other = _factor_pair(count)
    if ((other == 1) or (root == 1)) and (count > 4):
        # A single row/column of more than four panels: pad by one cell so the
        # grid can be folded into a more compact shape.
        root, other = _factor_pair(count + 1)

    nrows, ncols = max(root, other), min(root, other)
    if nrows < 4:
        nrows, ncols = ncols, nrows

    return nrows, ncols
[docs]
def nan_gaussian_filter(array, sigma, **kwargs):
    r"""Apply Gaussian filter with proper NaN handling via normalized convolution.

    Unlike :func:`scipy.ndimage.gaussian_filter` which propagates NaN values to
    all neighboring cells, this function:

    1. Smooths valid data correctly near NaN regions
    2. Preserves NaN locations (no interpolation into NaN cells)

    The algorithm uses normalized convolution: both the data (with NaN replaced
    by 0) and a weight mask (1 for valid, 0 for NaN) are filtered. The result
    is the ratio of filtered data to filtered weights, ensuring proper
    normalization near boundaries.

    Parameters
    ----------
    array : array_like
        Array possibly containing NaN values. The input is never modified.
        Integer and boolean input is converted to ``float64`` so the result
        can hold fractional values and NaN; floating-point input keeps its
        dtype.
    sigma : float
        Standard deviation for the Gaussian kernel, in pixels.
    **kwargs
        Additional keyword arguments passed to
        :func:`scipy.ndimage.gaussian_filter`.

    Returns
    -------
    np.ndarray
        Filtered array with original NaN locations preserved.

    See Also
    --------
    scipy.ndimage.gaussian_filter : Underlying filter implementation.

    Notes
    -----
    This implementation follows the normalized convolution approach described
    in [1]_. The key insight is that filtering a weight mask alongside the
    data allows proper normalization at boundaries and near missing values.

    References
    ----------
    .. [1] Knutsson, H., & Westin, C. F. (1993). Normalized and differential
       convolution. In Proceedings of IEEE Conference on Computer Vision and
       Pattern Recognition (pp. 515-523).

    Examples
    --------
    >>> import numpy as np
    >>> arr = np.array([[1, 2, np.nan], [4, 5, 6], [7, 8, 9]])
    >>> result = nan_gaussian_filter(arr, sigma=1.0)
    >>> bool(np.isnan(result[0, 2]))  # NaN preserved
    True
    >>> bool(np.isfinite(result[0, 1]))  # Neighbor is valid
    True
    """
    # Work on a private copy. Non-floating input is promoted to float64:
    # gaussian_filter would otherwise return an integer array that can hold
    # neither the normalized result nor NaN.
    arr = np.array(array)
    if not np.issubdtype(arr.dtype, np.inexact):
        arr = arr.astype(float)

    nan_mask = np.isnan(arr)

    # Replace NaN with 0 so missing cells contribute nothing to the sum.
    arr[nan_mask] = 0

    # Weights: 1 where valid, 0 where NaN.
    weights = (~nan_mask).astype(float)

    # Filter both data and weights with the same kernel.
    filtered_data = gaussian_filter(arr, sigma=sigma, **kwargs)
    filtered_weights = gaussian_filter(weights, sigma=sigma, **kwargs)

    # Normalize: weighted average of valid neighbors only. Cells that received
    # no weight at all stay NaN.
    result = np.divide(
        filtered_data,
        filtered_weights,
        where=filtered_weights > 0,
        out=np.full_like(filtered_data, np.nan),
    )

    # Preserve original NaN locations.
    result[nan_mask] = np.nan

    return result