import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PolyCollection
from matplotlib.patches import Polygon

from uvisbox.Core.CommonInterface import BoxplotStyleConfig

try:
    import pyvista as pv
    PYVISTA_AVAILABLE = True
except ImportError:
    PYVISTA_AVAILABLE = False
def visualize_curve_boxplot(mesh_data, boxplot_style=None, ax=None):
    """
    Visualize a curve boxplot from precomputed mesh data.

    Dispatches on ``mesh_data['n_dims']``:
    - 2D curves are always drawn with matplotlib.
    - 3D curves are drawn with PyVista when it is importable and ``ax`` is
      either None or a PyVista plotter. Otherwise they are drawn with
      matplotlib's 3D toolkit.

    Parameters:
    -----------
    mesh_data : dict
        Output of the mesh pipeline, with the following keys:
        - 'percentile_meshes': dict of percentile meshes
        - 'median_curve': median curve
        - 'outliers': outlier curves
        - 'n_dims': dimensionality (2 or 3)
    boxplot_style : BoxplotStyleConfig, optional
        Percentiles, colors and median/outlier styling. A default
        ``BoxplotStyleConfig()`` is used when None.
    ax : matplotlib.axes.Axes or pyvista.Plotter, optional
        Target to draw into:
        - matplotlib.axes.Axes for 2D data, or a 3D-projection Axes for 3D
          data (the 3D matplotlib renderer uses ``plot_trisurf``)
        - pyvista.Plotter (any object exposing ``add_mesh``) for 3D PyVista
          rendering
        - None: a suitable figure/axes or plotter is created

    Returns:
    --------
    ax : matplotlib.axes.Axes or pyvista.Plotter
        The object that was drawn into:
        - 2D data always yields matplotlib axes.
        - 3D data yields a pyvista.Plotter when PyVista is available and
          ``ax`` was None or a plotter; otherwise matplotlib 3D axes.
        A newly created plotter is returned without being shown.

    Raises:
    -------
    ValueError
        If ``mesh_data['n_dims']`` is neither 2 nor 3.

    Examples:
    ---------
    >>> import numpy as np
    >>> from uvisbox.Modules.CurveBoxplot.curve_boxplot_stats import curve_boxplot_summary_statistics
    >>> from uvisbox.Modules.CurveBoxplot.curve_boxplot_mesh import curve_boxplot_mesh
    >>> from uvisbox.Modules.CurveBoxplot.curve_boxplot_vis import visualize_curve_boxplot
    >>> from uvisbox.Core.CommonInterface import BoxplotStyleConfig
    >>>
    >>> # Generate synthetic curve data
    >>> curves = np.random.randn(50, 100, 2).cumsum(axis=1)
    >>>
    >>> # Process through pipeline
    >>> stats = curve_boxplot_summary_statistics(curves)
    >>> mesh_data = curve_boxplot_mesh(stats)
    >>>
    >>> # Visualize
    >>> ax = visualize_curve_boxplot(mesh_data)
    >>>
    >>> # Custom styling
    >>> style = BoxplotStyleConfig(percentiles=[25, 50, 75], show_outliers=True)
    >>> ax = visualize_curve_boxplot(mesh_data, boxplot_style=style)
    """
    style = BoxplotStyleConfig() if boxplot_style is None else boxplot_style
    dims = mesh_data['n_dims']

    if dims == 2:
        return _visualize_curve_boxplot_2d_matplotlib(mesh_data, style, ax)

    if dims == 3:
        if PYVISTA_AVAILABLE:
            if ax is None:
                # Nothing supplied: render 3D output into a fresh PyVista plotter.
                return _visualize_curve_boxplot_3d_pyvista(mesh_data, style, pv.Plotter())
            if hasattr(ax, 'add_mesh'):
                # Caller handed us something plotter-like (duck-typed on add_mesh).
                return _visualize_curve_boxplot_3d_pyvista(mesh_data, style, ax)
        # PyVista missing, or ax is not a plotter: use matplotlib 3D instead.
        return _visualize_curve_boxplot_3d_matplotlib(mesh_data, style, ax)

    raise ValueError(f"Unsupported curve dimension: {dims}. Must be 2 or 3.")
def _visualize_curve_boxplot_2d_matplotlib(mesh_data, boxplot_style, ax=None):
"""
Visualize 2D curve boxplot using matplotlib.
Internal helper function for 2D matplotlib rendering.
Parameters:
-----------
mesh_data : dict
Mesh data from curve_boxplot_mesh
boxplot_style : BoxplotStyleConfig
Style configuration
ax : matplotlib.axes.Axes, optional
Matplotlib axes to plot on
Returns:
--------
ax : matplotlib.axes.Axes
The matplotlib axes object with the plot
"""
# Create figure/axes if not provided
if ax is None:
fig, ax = plt.subplots(figsize=(10, 8))
# Get colors from colormap
colors = boxplot_style.get_percentile_colors()
percentiles = boxplot_style.percentiles
# Sort percentiles in descending order for proper plotting (largest first)
sorted_percentile_indices = np.argsort(percentiles)[::-1]
sorted_percentiles = [percentiles[i] for i in sorted_percentile_indices]
sorted_colors = [colors[i] for i in sorted_percentile_indices]
# Plot each percentile band from largest to smallest
for percentile, color in zip(sorted_percentiles, sorted_colors):
mesh_key = f'{int(percentile)}_percentile_mesh'
if mesh_key in mesh_data['percentile_meshes']:
points, triangles = mesh_data['percentile_meshes'][mesh_key]
_plot_band_mesh_2d(points, triangles, ax=ax, color=color, alpha=1.0)
# Plot outliers
if boxplot_style.show_outliers and mesh_data['outliers'].shape[0] > 0:
outliers = mesh_data['outliers']
for idx in range(len(outliers)):
outlier_curve = outliers[idx]
label = 'Outliers' if idx == 0 else None
ax.plot(outlier_curve[:, 0], outlier_curve[:, 1],
color=boxplot_style.outliers_color,
linewidth=boxplot_style.outliers_width,
alpha=boxplot_style.outliers_alpha,
label=label, zorder=5)
# Plot median curve
if boxplot_style.show_median:
median_curve = mesh_data['median_curve']
ax.plot(median_curve[:, 0], median_curve[:, 1],
color=boxplot_style.median_color,
linewidth=boxplot_style.median_width,
alpha=boxplot_style.median_alpha,
label='Median Curve', zorder=10)
return ax
def _visualize_curve_boxplot_3d_matplotlib(mesh_data, boxplot_style, ax=None):
"""
Visualize 3D curve boxplot using matplotlib 3D.
Internal helper function for 3D matplotlib rendering.
Parameters:
-----------
mesh_data : dict
Mesh data from curve_boxplot_mesh
boxplot_style : BoxplotStyleConfig
Style configuration
ax : matplotlib.axes.Axes, optional
Matplotlib 3D axes to plot on
Returns:
--------
ax : matplotlib.axes.Axes
The matplotlib 3D axes object with the plot
"""
# Create figure/axes if not provided
if ax is None:
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# Get colors from colormap
colors = boxplot_style.get_percentile_colors()
percentiles = boxplot_style.percentiles
# Sort percentiles in descending order
sorted_percentile_indices = np.argsort(percentiles)[::-1]
sorted_percentiles = [percentiles[i] for i in sorted_percentile_indices]
sorted_colors = [colors[i] for i in sorted_percentile_indices]
# Plot each percentile band
for percentile, color in zip(sorted_percentiles, sorted_colors):
mesh_key = f'{int(percentile)}_percentile_mesh'
if mesh_key in mesh_data['percentile_meshes']:
points, triangles = mesh_data['percentile_meshes'][mesh_key]
_plot_band_mesh_3d(points, triangles, ax=ax, color=color, alpha=1.0)
# Plot outliers
if boxplot_style.show_outliers and mesh_data['outliers'].shape[0] > 0:
outliers = mesh_data['outliers']
for idx in range(len(outliers)):
outlier_curve = outliers[idx]
label = 'Outliers' if idx == 0 else None
ax.plot(outlier_curve[:, 0], outlier_curve[:, 1], outlier_curve[:, 2],
color=boxplot_style.outliers_color,
linewidth=boxplot_style.outliers_width,
alpha=boxplot_style.outliers_alpha,
label=label, zorder=5)
# Plot median curve
if boxplot_style.show_median:
median_curve = mesh_data['median_curve']
ax.plot(median_curve[:, 0], median_curve[:, 1], median_curve[:, 2],
color=boxplot_style.median_color,
linewidth=boxplot_style.median_width,
alpha=boxplot_style.median_alpha,
label='Median Curve', zorder=10)
return ax
def _visualize_curve_boxplot_3d_pyvista(mesh_data, boxplot_style, ax):
    """
    Visualize 3D curve boxplot using PyVista.

    Internal helper function for 3D PyVista rendering.
    - Each percentile band is added as a triangle mesh whose opacity is
      (1 - percentile/100)**2. For example, 50 -> 0.25 and 90 -> 0.01; a 100th
      percentile band is fully transparent.
    - Outliers and the median are added as splines through their points.
    - The plotter is returned without being shown.

    Parameters:
    -----------
    mesh_data : dict
        Mesh data from curve_boxplot_mesh. Keys read here:
        - 'percentile_meshes': keys of the form
          f'{int(percentile)}_percentile_mesh', each mapping to a
          ``(points, triangles)`` pair.
        - 'outliers': read only when outliers are shown. It may be a stacked
          array or any sequence of (n_points, >=3) curves; an empty sequence
          or None means "no outliers".
        - 'median_curve': read only when the median is shown.
    boxplot_style : BoxplotStyleConfig
        Style configuration
    ax : pyvista.Plotter
        PyVista plotter to use

    Returns:
    --------
    ax : pyvista.Plotter
        The PyVista plotter object with the visualization

    Raises:
    -------
    ImportError
        If PyVista could not be imported.
    """
    if not PYVISTA_AVAILABLE:
        raise ImportError("PyVista is required for 3D visualization but is not installed. "
                          "Install it with: pip install pyvista")

    colors = boxplot_style.get_percentile_colors()
    percentiles = boxplot_style.percentiles

    # Descending order (largest percentile first), matching the other renderers.
    sorted_percentile_indices = np.argsort(percentiles)[::-1]
    sorted_percentiles = [percentiles[i] for i in sorted_percentile_indices]
    sorted_colors = [colors[i] for i in sorted_percentile_indices]

    for percentile, color in zip(sorted_percentiles, sorted_colors):
        mesh_key = f'{int(percentile)}_percentile_mesh'
        if mesh_key not in mesh_data['percentile_meshes']:
            continue
        points, triangles = mesh_data['percentile_meshes'][mesh_key]
        # Larger percentiles are rendered more transparently.
        opacity = (1 - (percentile / 100)) ** 2
        # PyVista documents `faces` as a flat padded connectivity array:
        # [3, i0, j0, k0, 3, i1, j1, k1, ...].
        triangles = np.asarray(triangles)
        faces = np.hstack(
            [np.full((triangles.shape[0], 1), 3), triangles]
        ).astype(int).ravel()
        poly_mesh = pv.PolyData(points, faces)
        ax.add_mesh(poly_mesh, color=color, opacity=opacity,
                    smooth_shading=True, show_edges=False,
                    label=f'{int(percentile)}th percentile')

    if boxplot_style.show_outliers:
        outliers = mesh_data['outliers']
        # Iterate rather than test `.shape[0]`, so plain lists (including an
        # empty list) work as well as stacked arrays.
        if outliers is not None:
            for idx, outlier_curve in enumerate(outliers):
                outlier_curve = np.asarray(outlier_curve)
                # Label only the first outlier to avoid duplicate legend entries.
                label = 'Outliers' if idx == 0 else None
                # Spline through the curve's own points for smooth rendering.
                curve_line = pv.Spline(outlier_curve[:, 0:3], n_points=outlier_curve.shape[0])
                ax.add_mesh(curve_line,
                            color=boxplot_style.outliers_color,
                            line_width=boxplot_style.outliers_width,
                            opacity=boxplot_style.outliers_alpha,
                            label=label)

    if boxplot_style.show_median:
        median_curve = mesh_data['median_curve']
        median_line = pv.Spline(median_curve[:, 0:3], n_points=median_curve.shape[0])
        ax.add_mesh(median_line,
                    color=boxplot_style.median_color,
                    line_width=boxplot_style.median_width,
                    opacity=boxplot_style.median_alpha,
                    label='Median Curve')
    return ax
def _plot_band_mesh_2d(points, triangles, ax, color, alpha):
"""
Plot a triangulated mesh band in 2D using matplotlib.
Internal helper function for rendering 2D triangular meshes.
Parameters:
-----------
points : np.ndarray
Vertex coordinates of the mesh. Shape: (n_points, 2)
triangles : np.ndarray
Triangle faces defined by point indices. Shape: (n_triangles, 3)
ax : matplotlib.axes.Axes
Matplotlib axes to plot on
color : str or tuple
Color for the mesh
alpha : float
Transparency of the mesh (0=transparent, 1=opaque)
Returns:
--------
ax : matplotlib.axes.Axes
The matplotlib axes object used for plotting
"""
for tri in triangles:
poly = Polygon(points[tri], facecolor=color, edgecolor='none', alpha=alpha)
ax.add_patch(poly)
return ax
def _plot_band_mesh_3d(points, triangles, ax, color, alpha):
"""
Plot a triangulated mesh band in 3D using matplotlib.
Internal helper function for rendering 3D triangular meshes.
Parameters:
-----------
points : np.ndarray
Vertex coordinates of the mesh. Shape: (n_points, 3)
triangles : np.ndarray
Triangle faces defined by point indices. Shape: (n_triangles, 3)
ax : matplotlib.axes.Axes
Matplotlib 3D axes to plot on
color : str or tuple
Color for the mesh
alpha : float
Transparency of the mesh (0=transparent, 1=opaque)
Returns:
--------
ax : matplotlib.axes.Axes
The matplotlib 3D axes object used for plotting
"""
ax.plot_trisurf(points[:, 0], points[:, 1], points[:, 2],
triangles=triangles, color=color, alpha=alpha)
return ax