Source code for pyttop.plot.plotfuncs

# -*- coding: utf-8 -*-
"""
Created on Mon Aug  5 19:04:11 2024

@author: Yu-Chen Wang
"""

import numpy as np
import matplotlib.pyplot as plt
from .base import plotFunc, plotFuncAx
from .base import scatter as ptscatter
from collections.abc import Iterable

# Public names exported by a star-import of this module.
__all__ = [
    'refline',
    'annotate',
    'binned_quantiles',
]

@plotFunc
def refline(x=None, y=None, xpos=.1, ypos=.1, xtxt=None, ytxt=None,
            xfmt='.2f', yfmt='.2f', marker='', style='through', label=None,
            ax=None, **lineargs):
    '''
    Plot reference line(s) and optionally marker(s) at given position(s).

    This function adds vertical and/or horizontal lines to the plot, anchored
    at the specified `x` and/or `y` values. Optionally, marker(s) can be drawn
    at the intersection(s), and text annotations can be shown on the reference
    line(s) to indicate the values.

    Parameters
    ----------
    x : float or Iterable, optional
        The x-coordinate(s) at which to draw vertical reference line(s).
        The default is None.
    y : float or Iterable, optional
        The y-coordinate(s) at which to draw horizontal reference line(s).
        The default is None.
    xpos : float, Iterable or None, optional
        Relative x (horizontal) position (in axes fraction) for y-value
        annotation text, i.e. the text shown on the horizontal line(s).
        An iterable gives one position per horizontal line.
        If None, no text is shown.
        The default is 0.1.
    ypos : float, Iterable or None, optional
        Relative y (vertical) position (in axes fraction) for x-value
        annotation text, i.e. the text shown on the vertical line(s).
        An iterable gives one position per vertical line.
        If None, no text is shown.
        The default is 0.1.
    xtxt : str, optional
        If not None, the x label text will be overwritten by this.
    ytxt : str, optional
        If not None, the y label text will be overwritten by this.
    xfmt : str, optional
        Format string for x label (if ``xtxt`` not specified).
        The default is ``'.2f'``.
    yfmt : str, optional
        Format string for y label (if ``ytxt`` not specified).
        The default is ``'.2f'``.
    marker : optional
        Marker style for the intersection point(s), if both x and y are
        provided. The default is '' (no marker).
    style : {'through', 'axis'}, optional
        Line style:

        - ``'through'``: line(s) extend across the full axis.
        - ``'axis'``: only plot line(s) on the left and/or beneath the point.

        The default is ``'through'``.
    label : str, optional
        Label assigned to the line(s), useful for legends. Only the first
        line drawn carries the label, so the legend gets a single entry.
    ax : matplotlib.axes.Axes, optional
        The axis on which to plot. If None, uses the current axis.
    **lineargs :
        Additional keyword arguments passed to ``ax.axhline`` and
        ``ax.axvline``.

    Returns
    -------
    artists : dict
        The artists that were created, stored under the keys ``'vline'``,
        ``'vtext'``, ``'hline'``, ``'htext'`` and ``'scat'``. A key is only
        present if the corresponding artist was drawn; when several lines
        are drawn, the line/text keys hold the most recently created artist.

    Raises
    ------
    ValueError
        If ``style`` is not ``'through'`` or ``'axis'``, or if neither ``x``
        nor ``y`` is given.

    Notes
    -----
    When both ``x`` and ``y`` are given they are paired element-wise (with
    ``zip``), so entries beyond the shorter of the two are ignored.
    '''
    def _format_val(v, formatter):
        # A ScalarFormatter may factor an offset and a power of ten out of
        # the tick labels; apply the same transformation to the annotation.
        # Formatters without these attributes leave the value unchanged.
        offset = getattr(formatter, 'offset', 0)
        order = getattr(formatter, 'orderOfMagnitude', 0)
        return (v - offset) / 10.**order

    def _as_list(value):
        # Accept both a single coordinate and an iterable of coordinates.
        return list(value) if isinstance(value, Iterable) else [value]

    def _broadcast(pos, n):
        # None or a scalar applies to all n lines; an iterable is per-line.
        if pos is None or not isinstance(pos, Iterable):
            return [pos] * n
        return list(pos)

    def _to_axes_fraction(v, vmin, span, scale):
        # Data coordinate -> fraction of the axis length.
        if scale == 'log':
            return (np.log10(v) - np.log10(vmin)) / span
        return (v - vmin) / span

    def _to_data_coord(frac, vmin, vmax, span, scale):
        # Fraction of the axis length -> data coordinate.
        if scale == 'log':
            return vmin * (vmax / vmin)**frac
        return vmin + frac * span

    # check input
    if style not in ('through', 'axis'):
        raise ValueError(f"'style' should be 'through' or 'axis', got '{style}'")
    if x is None and y is None:
        raise ValueError('You should at least specify one of the parameters: "x" and "y".')

    artists = {}

    if ax is None:
        ax = plt.gca()
    fig = ax.figure

    # Axis limits are captured once, before any line is added.
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    xscale, yscale = ax.get_xscale(), ax.get_yscale()
    dx = np.log10(xmax) - np.log10(xmin) if xscale == 'log' else xmax - xmin
    dy = np.log10(ymax) - np.log10(ymin) if yscale == 'log' else ymax - ymin

    plotx, ploty = x is not None, y is not None
    if plotx:
        xs = _as_list(x)
    if ploty:
        ys = _as_list(y)
    # A missing coordinate defaults to the upper axis limit, so that
    # style='axis' lines still span the whole axis.
    if not plotx:
        xs = [xmax] * len(ys)
    if not ploty:
        ys = [ymax] * len(xs)

    # ypos positions the text on vertical lines, xpos on horizontal lines.
    yposs = _broadcast(ypos, len(xs))
    xposs = _broadcast(xpos, len(ys))

    if plotx:
        for i, (xi, yi, text_ypos) in enumerate(zip(xs, ys, yposs)):
            if style == 'through':
                lineymax = 1
            else:  # 'axis': stop the line at the height of the point
                lineymax = _to_axes_fraction(yi, ymin, dy, yscale)
            artists['vline'] = ax.axvline(
                xi, ymax=lineymax,
                label=label if i == 0 else None,
                **lineargs)
            if text_ypos is not None:
                fig.canvas.draw()  # makes sure the ScalarFormatter has been set
                x_fmter = ax.xaxis.get_major_formatter()
                if xtxt is None:
                    xtxt1 = f'{_format_val(xi, x_fmter):{xfmt}}'
                else:
                    xtxt1 = xtxt
                yt = _to_data_coord(text_ypos, ymin, ymax, dy, yscale)
                artists['vtext'] = ax.text(
                    xi, yt, xtxt1,
                    horizontalalignment='center', backgroundcolor='white')

    if ploty:
        for i, (yi, xi, text_xpos) in enumerate(zip(ys, xs, xposs)):
            if style == 'through':
                linexmax = 1
            else:  # 'axis': stop the line at the x-position of the point
                linexmax = _to_axes_fraction(xi, xmin, dx, xscale)
            artists['hline'] = ax.axhline(
                yi, xmax=linexmax,
                # the label is already used by the vertical line(s), if any
                label=label if (i == 0 and not plotx) else None,
                **lineargs)
            if text_xpos is not None:
                fig.canvas.draw()  # makes sure the ScalarFormatter has been set
                y_fmter = ax.yaxis.get_major_formatter()
                if ytxt is None:
                    ytxt1 = f'{_format_val(yi, y_fmter):{yfmt}}'
                else:
                    ytxt1 = ytxt
                xt = _to_data_coord(text_xpos, xmin, xmax, dx, xscale)
                artists['htext'] = ax.text(
                    xt, yi, ytxt1,
                    verticalalignment='center', backgroundcolor='white')

    if plotx and ploty:
        # Mark every (x, y) pair that was drawn above.
        n_pairs = min(len(xs), len(ys))
        artists['scat'] = ax.scatter(
            xs[:n_pairs], ys[:n_pairs], marker=marker, c='k')

    return artists
# ``annotate`` is kept as an alias of ``refline`` (it is exported in __all__).
annotate = refline
@plotFunc
def binned_quantiles(x, y, bin_size=.1, bin_dist=.1,
                     quantiles=(.16, .50, .84), min_n=10,
                     xmin=None, xmax=None,
                     show_scatter=True, s=None, c=None, label=None,
                     show_bins=True,
                     show_errorbars=True, emarker='o', es=5, ec=None, elabel=None,
                     show_fill=False, fc=None, flabel=None,
                     errkwargs=None, fillkwargs=None,
                     **kwargs
                     ):
    """
    Plot sliding-window quantile errorbars/fill.

    This function visualizes a 2D distribution by plotting raw (x, y) points
    and computing sliding-window quantiles in x-bins. It then overlays error
    bars and/or filled regions to represent variability (e.g. 16th–84th
    percentile) in y-values within each x-bin.

    This is useful when visualizing scatter data along with robust estimates
    of central tendency and spread.

    Parameters
    ----------
    x, y : array-like
        Data coordinates. Masked entries (in either array) are ignored when
        computing the binned quantiles.
    bin_size : float, optional
        Width of each sliding bin in x-units. Default is 0.1.
    bin_dist : float, optional
        Step size between consecutive bin positions (i.e., sliding window
        stride). Default is 0.1.
    quantiles : sequence of 3 floats, optional
        Quantiles to compute within each bin (lower, middle, upper). Must be
        in increasing order. Default is (0.16, 0.50, 0.84).
    min_n : int, optional
        Minimum number of data points required in a bin to compute quantiles.
        Bins with fewer points (and empty bins) yield NaN. Default is 10.
    xmin, xmax : float, optional
        Range of x-values to include in binning. If None, inferred from data.
    show_scatter : bool, optional
        If True, plot the raw scatter points. Default is True.
    s, c : optional
        Marker size and color for scatter points.
    label : str, optional
        Label for the scatter plot.
    show_bins : bool, optional
        If True, include horizontal error bars showing bin width.
        Default is True.
    show_errorbars : bool, optional
        If True, show vertical error bars (quantile-based). Default is True.
    emarker : str, optional
        Marker style for error bar midpoints. Default is 'o'.
    es : float, optional
        Marker size for error bars. Default is 5.
    ec : color, optional
        Color for error bars.
    elabel : str, optional
        Label for the error bars. If None and ``show_scatter`` is False,
        inherits ``label``.
    show_fill : bool, optional
        If True, fills the area between lower and upper quantiles.
        Default is False.
    fc : color, optional
        Fill color. If None, inherits from ``ec``.
    flabel : str, optional
        Label for the fill.
    errkwargs : dict, optional
        Additional keyword arguments passed to ``plt.errorbar``.
    fillkwargs : dict, optional
        Additional keyword arguments passed to ``plt.fill_between``.
    **kwargs : dict
        Additional keyword arguments passed to the scatter plot.

    Returns
    -------
    artists : dict
        The created artists: ``'errorbar'`` always, ``'scatter'`` if
        ``show_scatter`` is True and ``'fill'`` if ``show_fill`` is True.

    Raises
    ------
    ValueError
        If ``quantiles`` does not contain exactly three values.
    """
    if len(quantiles) != 3:
        raise ValueError(
            f"'quantiles' should contain exactly 3 values, got {len(quantiles)}")
    if errkwargs is None:
        errkwargs = {}
    if fillkwargs is None:
        fillkwargs = {}

    # np.asarray() or np.array() will return a base np.ndarray (masks will be lost)
    x, y = np.asanyarray(x), np.asanyarray(y)

    artists = {}

    if show_scatter:
        ptscatter(x, y, s=s, c=c, label=label, **kwargs)
        artists['scatter'] = ptscatter.s
    elif elabel is None:
        # without the scatter, the error bars take over the main label
        elabel = label
    if fc is None:
        fc = ec

    # drop points that are masked in either coordinate
    mask = np.ma.getmaskarray(x) | np.ma.getmaskarray(y)
    x = x[~mask]
    y = y[~mask]

    if xmin is None:
        xmin = np.min(x)
    if xmax is None:
        xmax = np.max(x)

    # left edges are spaced by bin_dist; the last bin ends at or below
    # roughly xmax
    x_lefts = np.arange(xmin, xmax - bin_size + bin_dist, bin_dist)
    x_centers = x_lefts + bin_size / 2
    x_rights = x_lefts + bin_size

    ymids, yq0s, yq1s = [], [], []
    for left, right in zip(x_lefts, x_rights):
        in_bin = (x >= left) & (x < right)  # half-open bin [left, right)
        n_in_bin = np.sum(in_bin)
        # an empty bin never has quantiles, whatever min_n is
        if n_in_bin > 0 and n_in_bin >= min_n:
            yq0, ymid, yq1 = np.quantile(y[in_bin], quantiles)
        else:
            yq0 = ymid = yq1 = np.nan
        ymids.append(ymid)
        yq0s.append(yq0)
        yq1s.append(yq1)
    ymids = np.array(ymids)
    yq0s = np.array(yq0s)
    yq1s = np.array(yq1s)

    # user-supplied errkwargs override the defaults
    ekwargs = dict(
        linestyle='',
        )
    ekwargs.update(errkwargs)
    if show_bins:
        xerr = [x_centers - x_lefts, x_rights - x_centers]
    else:
        xerr = None
    if show_errorbars:
        yerr = [ymids - yq0s, yq1s - ymids]
    else:
        yerr = None
    artists['errorbar'] = plt.errorbar(
        x_centers, ymids,
        xerr=xerr,
        yerr=yerr,
        marker=emarker, markersize=es,
        color=ec,
        label=elabel,
        **ekwargs,
        )

    if show_fill:
        # user-supplied fillkwargs override the defaults
        fkwargs = {'alpha': .2}
        fkwargs.update(fillkwargs)
        artists['fill'] = plt.fill_between(
            x_centers, yq1s, yq0s,
            color=fc,
            label=flabel,
            **fkwargs,
            )

    return artists