"""
Cell-level spatial visualization for AnnData objects.
Provides :func:`plot_segmentation`, a general-purpose renderer for
cell polygons (from segmentation) or centroid scatter plots, with
categorical and continuous colouring, customizable legends, optional
crop, and a scalebar. Works for any technology that exposes per-cell
2D coordinates (Xenium, VisiumHD-segmented, Atera, ...).
"""
import math
import textwrap
from typing import List, Optional, Sequence, Tuple, Union
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.sparse
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon as MplPolygon
from matplotlib.transforms import blended_transform_factory, offset_copy
def _add_scalebar(
    ax,
    length=None,
    units="µm",
    color="black",
    linewidth=3,
    fontsize=9,
    position="lower_right",
    pad_frac=0.05,
    text_pad_pts=2.0,
    remove_axes=True,
):
    """Add a scalebar to a spatial axes (1-2-5 auto-snap, inversion-aware).

    The bar is drawn with a blended transform (data X, axes-fraction Y), so
    its length is expressed in data units while its vertical position is
    independent of the data extent, the figure size, and ``invert_yaxis``.
    The label is placed on the side away from the data (below the bar for
    ``lower_*``, above for ``upper_*``).

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on. Its current x-limits define the visible width.
    length : float, optional
        Bar length in data units. When None, roughly 20% of the visible
        width is used, snapped down to the nearest 1, 2 or 5 times a power
        of ten (falls back to 1.0 when the width is zero).
    units : str
        Unit suffix appended to the label.
    color : color
        Colour of the bar and its label.
    linewidth : float
        Bar thickness in points.
    fontsize : float
        Label font size.
    position : str
        Corner name. ``"right"`` in the string anchors the bar to the right
        edge (otherwise left); ``"lower"`` anchors it to the bottom
        (otherwise top).
    pad_frac : float
        Inset from the chosen corner, as a fraction of the axes
        width/height.
    text_pad_pts : float
        Gap, in points, between the bar and its label.
    remove_axes : bool
        When True, also remove ticks, axis labels and spines.
    """
    x_left, x_right = ax.get_xlim()
    width = abs(x_right - x_left)

    if length is None:
        raw = 0.2 * width
        if raw > 0:
            exponent = int(math.floor(math.log10(raw)))
            base = raw / (10 ** exponent)
            # Inclusive thresholds: a target of exactly 5eN or 2eN snaps to
            # itself instead of dropping a full step (the same way 1eN does).
            if base >= 5:
                nice_base = 5
            elif base >= 2:
                nice_base = 2
            else:
                nice_base = 1
            length = nice_base * (10 ** exponent)
        else:
            length = 1.0

    # Data-unit step that moves visually to the right: +1 for a regular
    # x-axis, -1 when the x-axis is inverted (xlim[0] > xlim[1]).
    direction = 1.0 if x_right >= x_left else -1.0
    inset = direction * pad_frac * width
    if "right" in position:
        x_end = x_right - inset
        x_start = x_end - direction * length
    else:
        x_start = x_left + inset
        x_end = x_start + direction * length

    is_lower = "lower" in position
    y_axes = pad_frac if is_lower else 1.0 - pad_frac

    blended = blended_transform_factory(ax.transData, ax.transAxes)
    ax.plot(
        [x_start, x_end],
        [y_axes, y_axes],
        color=color,
        lw=linewidth,
        solid_capstyle="butt",
        transform=blended,
        clip_on=False,
    )

    # Whole-number lengths print without a decimal point; fractional ones
    # keep their fraction so the label always matches the drawn bar.
    if float(length).is_integer():
        length_text = str(int(length))
    else:
        length_text = f"{length:g}"
    label = f"{length_text} {units}"

    # Shift the label away from the bar by half the bar thickness plus the
    # requested gap, both measured in points.
    sign = -1.0 if is_lower else 1.0
    text_transform = offset_copy(
        blended,
        fig=ax.figure,
        y=sign * (linewidth / 2.0 + text_pad_pts),
        units="points",
    )
    ax.text(
        (x_start + x_end) / 2.0,
        y_axes,
        label,
        color=color,
        transform=text_transform,
        ha="center",
        va="top" if is_lower else "bottom",
        fontsize=fontsize,
        clip_on=False,
    )

    if remove_axes:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")
        for spine in ax.spines.values():
            spine.set_visible(False)
def _render_panel(
    ax,
    *,
    adata,
    local_idx,
    cell_ids,
    coords,
    xlim,
    ylim,
    color_by,
    celltype_key,
    segmentation,
    palette,
    highlight,
    layer,
    legend,
    legend_loc,
    legend_bbox,
    legend_frameon,
    legend_title,
    legend_fontsize,
    legend_ncol,
    legend_params,
    axes_off,
    scatter_size,
    cmap,
    vmin,
    vmax,
    scalebar,
    scalebar_kwargs,
    cbar_kwargs,
    panel_title,
    title_fontsize,
):
    """Render a single panel (one feature) into ``ax``.

    The feature named by ``color_by`` (or ``celltype_key`` when ``color_by``
    is None) is resolved in this order:

    1. a column of ``adata.obs`` -- categorical, object and string columns
       are coloured through a palette, any other dtype through ``cmap``;
    2. an entry of ``adata.var_names`` -- values are read from
       ``adata.layers[layer]`` (or ``adata.X``) and coloured through ``cmap``;
    3. otherwise every cell is drawn in a uniform light gray.

    Cells are drawn as polygons when ``segmentation`` is given (cells without
    an entry in the mapping are skipped), else as a scatter of ``coords``.
    Only the cells indexed by ``local_idx`` are drawn. All remaining keyword
    arguments carry the same meaning as the identically named arguments of
    :func:`plot_segmentation`; ``panel_title`` is the already-wrapped title
    text, or None for no title.
    """
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    # Freeze the view so that adding artists cannot rescale the panel.
    ax.set_autoscale_on(False)

    target_col = color_by if color_by is not None else celltype_key
    colors_array = None
    cmap_vals = None
    is_categorical = False
    final_palette: dict = {}
    highlight_set = set(highlight) if highlight else set()

    if target_col in adata.obs.columns:
        col_data = adata.obs[target_col].iloc[local_idx]
        # String-dtype columns are label data as well; sending them down the
        # continuous branch would hand text to the colormap.
        if (
            isinstance(col_data.dtype, pd.CategoricalDtype)
            or pd.api.types.is_object_dtype(col_data)
            or pd.api.types.is_string_dtype(col_data)
        ):
            is_categorical = True
            # Categories come from the full column (not just the visible
            # cells) so colours stay stable across crops.
            unique_cats = list(adata.obs[target_col].astype("category").cat.categories)
            if palette:
                final_palette = dict(palette)
            elif f"{target_col}_colors" in adata.uns:
                final_palette = dict(zip(unique_cats, adata.uns[f"{target_col}_colors"]))
            else:
                s2 = plt.get_cmap("Set2")
                final_palette = {cat: s2(i % 8) for i, cat in enumerate(unique_cats)}
            if highlight:
                for cat in unique_cats:
                    if cat not in highlight_set:
                        final_palette[cat] = "whitesmoke"
            # Values without a palette entry (e.g. missing values) are muted.
            colors_array = [final_palette.get(v, "whitesmoke") for v in col_data.values]
        else:
            cmap_vals = col_data.values
    elif target_col in adata.var_names:
        sub = adata[local_idx, target_col]
        gene_data = sub.layers[layer] if layer is not None else sub.X
        if scipy.sparse.issparse(gene_data):
            cmap_vals = gene_data.toarray().flatten()
        else:
            cmap_vals = np.asarray(gene_data).flatten()
    else:
        # Unknown feature: draw the cells in a neutral colour, no legend.
        colors_array = ["#cccccc"] * len(local_idx)

    p_coll = None
    scatter = None
    if segmentation is not None:
        patches, v_cols, v_cmap = [], [], []
        for i, idx in enumerate(local_idx):
            poly = segmentation.get(cell_ids[idx])
            if poly is None:
                continue
            patches.append(MplPolygon(np.array(poly.exterior.coords), closed=True))
            # Colours/values are collected alongside the patches so they stay
            # aligned when cells without a polygon are skipped.
            if colors_array is not None:
                v_cols.append(colors_array[i])
            if cmap_vals is not None:
                v_cmap.append(cmap_vals[i])
        p_coll = PatchCollection(
            patches, alpha=1.0, edgecolor="gray", linewidth=0.05, antialiased=True
        )
        if colors_array is not None:
            p_coll.set_facecolor(v_cols)
        else:
            p_coll.set_array(np.array(v_cmap))
            p_coll.set_cmap(cmap)
            if vmin is not None or vmax is not None:
                p_coll.set_clim(vmin=vmin, vmax=vmax)
        ax.add_collection(p_coll)
    else:
        l_coords = coords[local_idx]
        continuous = cmap_vals is not None
        scatter = ax.scatter(
            l_coords[:, 0],
            l_coords[:, 1],
            c=colors_array if colors_array is not None else cmap_vals,
            # Colormap arguments only apply to numeric values.
            cmap=cmap if continuous else None,
            vmin=vmin if continuous else None,
            vmax=vmax if continuous else None,
            s=scatter_size,
            edgecolor="none",
        )

    if legend:
        if is_categorical:
            if highlight:
                # Filter on membership rather than on the muted colour: this
                # is well defined for any colour spec and also drops palette
                # keys that were never highlighted.
                items = [
                    (cat, col)
                    for cat, col in final_palette.items()
                    if cat in highlight_set
                ]
            else:
                items = list(final_palette.items())
            handles = [mpatches.Patch(color=col, label=cat) for cat, col in items]
            if handles:
                kw = dict(
                    handles=handles,
                    loc=legend_loc,
                    frameon=legend_frameon,
                    ncol=legend_ncol,
                )
                if legend_bbox is not None:
                    kw["bbox_to_anchor"] = legend_bbox
                if legend_title is not None:
                    kw["title"] = legend_title
                if legend_fontsize is not None:
                    kw["fontsize"] = legend_fontsize
                if legend_params:
                    # Free-form parameters win over the dedicated arguments.
                    kw.update(legend_params)
                ax.legend(**kw)
        elif cmap_vals is not None:
            mappable = p_coll if p_coll is not None else scatter
            cb_kw = dict(fraction=0.046, pad=0.04)
            if cbar_kwargs:
                cb_kw.update(cbar_kwargs)
            # Attach the colorbar to the figure that owns ``ax`` instead of
            # pyplot's current figure, which can differ when the caller
            # supplied the axes.
            ax.figure.colorbar(mappable, ax=ax, **cb_kw)

    ax.set_aspect("equal")

    if scalebar:
        sb_params = {"color": "black", "position": "lower_right", "pad_frac": 0.05}
        if scalebar_kwargs:
            sb_params.update(scalebar_kwargs)
        _add_scalebar(ax, remove_axes=axes_off, **sb_params)
    elif axes_off:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Final safeguard: re-apply limits after every artist has been added,
    # in case anything (notably colorbar or scalebar) caused a relim.
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)

    if panel_title is not None:
        ax.set_title(panel_title, fontsize=title_fontsize)
# Public entry point.
def plot_segmentation(
    adata,
    spatial_key: str = "X_spatial",
    color_by: Optional[Union[str, Sequence[str]]] = None,
    celltype_key: str = "cell_type",
    segmentation: Optional[dict] = None,
    cell_id_col: Optional[str] = None,
    palette: Optional[dict] = None,
    highlight: Optional[List[str]] = None,
    layer: Optional[str] = None,
    crop: Optional[Tuple[float, float, float, float]] = None,
    invert_yaxis: bool = True,
    legend: bool = True,
    legend_loc: str = "center left",
    legend_bbox: Optional[Tuple[float, float]] = (1.01, 0.5),
    legend_frameon: bool = False,
    legend_title: Optional[str] = None,
    legend_fontsize: Optional[float] = 7.0,
    legend_ncol: int = 1,
    legend_params: Optional[dict] = None,
    axes_off: bool = True,
    figsize: Optional[Tuple[float, float]] = None,
    ax=None,
    ncols: int = 4,
    panel_titles: bool = True,
    title: Optional[Union[str, Sequence[str]]] = None,
    title_fontsize: Optional[float] = 12,
    wrapped_title_length: int = 45,
    dpi: int = 150,
    scatter_size: float = 2.0,
    cmap: str = "viridis",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    y_pad_ratio: float = 0.1,
    x_pad_ratio: float = 0.0,
    scalebar: bool = True,
    scalebar_kwargs: Optional[dict] = None,
    cbar_kwargs: Optional[dict] = None,
    save: Optional[str] = None,
):
    """
    Plot cell-resolution spatial data from an AnnData object.

    Renders cells as segmentation polygons when ``segmentation`` is
    provided, otherwise as a centroid scatter plot. Supports categorical
    and continuous colouring, optional highlighting of a subset of
    categories, axis cropping, and a scalebar with bottom/top padding.

    When ``color_by`` is a list, multiple panels are drawn in a grid
    laid out by ``ncols`` (matching ``sc.pl.spatial`` semantics): the
    crop mask and view limits are computed once and shared across
    panels; each panel is coloured independently and gets its own legend
    or colorbar.

    Parameters
    ----------
    adata : anndata.AnnData
        AnnData with spatial coordinates in ``adata.obsm[spatial_key]``.
    spatial_key : str, optional (default: "X_spatial")
        Key in ``adata.obsm`` for the ``(n_cells, 2+)`` coordinate
        array. Defaults to scCellFie's canonical key.
    color_by : str, list of str, or None, optional (default: None)
        Column in ``adata.obs`` or name in ``adata.var_names`` to colour
        by. If None, falls back to ``celltype_key``. Pass a list of
        names (e.g. ``["task_A", "task_B", "GENE1"]``) to render a
        multi-panel figure with one panel per feature. Tuples, NumPy
        arrays and pandas Index objects are treated like lists.
    celltype_key : str, optional (default: "cell_type")
        Default categorical column used when ``color_by`` is None.
    segmentation : dict, optional (default: None)
        Mapping ``cell_id -> shapely.Polygon`` (e.g. output of
        :func:`sccellfie.io.load_segmentation` with ``output="dict"``).
        If None, a scatter of centroids is drawn.
    cell_id_col : str, optional (default: None)
        Column in ``adata.obs`` identifying cells. Defaults to
        ``adata.obs.index``.
    palette : dict, optional (default: None)
        Custom ``{category: color}`` mapping for categorical colouring.
        Falls back to ``adata.uns["{color_by}_colors"]`` or matplotlib
        ``Set2`` cycling. In multi-panel mode the same palette is reused
        for every categorical feature.
    highlight : list of str, optional (default: None)
        Subset of categories to highlight; all others are drawn in
        ``whitesmoke`` and excluded from the legend.
    layer : str, optional (default: None)
        Layer name in ``adata.layers`` used when ``color_by`` is a gene.
        If None, uses ``adata.X``.
    crop : tuple, optional (default: None)
        ``(minx, miny, maxx, maxy)`` bounds to restrict the view. Cells
        whose coordinates fall outside this box are not rendered. If
        None, uses the data extent.
    invert_yaxis : bool, optional (default: True)
        Invert the y-axis (microscopy convention).
    legend : bool, optional (default: True)
        Show the legend for categorical data, or a colorbar for
        continuous data.
    legend_loc : str, optional (default: "center left")
        ``loc`` argument passed to ``ax.legend()``. Ignored for colorbar.
    legend_bbox : tuple, optional (default: (1.01, 0.5))
        ``bbox_to_anchor`` for the legend. Use ``None`` to disable the
        anchor and rely on ``legend_loc`` alone.
    legend_frameon : bool, optional (default: False)
        Whether the legend frame/border is drawn.
    legend_title : str, optional (default: None)
        Title shown above the legend entries.
    legend_fontsize : float, optional (default: 7.0)
        Font size for legend labels. ``None`` falls back to the matplotlib
        default. The small default suits spatial plots with many
        categories; bump it via ``legend_params={'fontsize': 10}`` (or
        the dedicated arg) when needed.
    legend_ncol : int, optional (default: 1)
        Number of columns in the legend.
    legend_params : dict, optional (default: None)
        Arbitrary kwargs forwarded to ``ax.legend(...)`` (e.g.
        ``handlelength``, ``labelspacing``, ``borderpad``,
        ``columnspacing``). Keys here override the dedicated
        ``legend_*`` arguments on conflict.
    axes_off : bool, optional (default: True)
        Remove ticks, tick labels, and spines (standard for spatial plots).
    figsize : tuple, optional (default: None)
        - Single panel (``color_by`` is a str or None): the figure size,
          defaulting to ``(10, 10)`` when None.
        - Multi panel (``color_by`` is a list): the per-panel size,
          defaulting to ``(4, 4)`` when None. The total figure size is
          ``(figsize[0] * ncols, figsize[1] * nrows)``.

        Ignored when ``ax`` is provided.
    ax : matplotlib.axes.Axes, optional (default: None)
        Existing axes to draw onto. Only valid when ``color_by`` is a
        single feature (or None). For multi-panel, omit ``ax`` and let
        the function build the grid.
    ncols : int, optional (default: 4)
        Number of columns in the panel grid when ``color_by`` is a list.
        Number of rows is ``ceil(len(color_by) / ncols)``. Mirrors
        ``sc.pl.spatial``'s ``ncols`` parameter. Must be at least 1.
    panel_titles : bool, optional (default: True)
        Master toggle for panel titles. When True, each panel's title is
        set to the corresponding feature name (or to the explicit string
        passed via ``title=``). Set False to suppress titles entirely
        (in single- and multi-panel modes).
    title : str, list of str, or None, optional (default: None)
        Explicit title override. For single-feature mode pass a string;
        for multi-feature mode pass a list of strings whose length
        matches ``color_by``. When None (default), titles are auto-derived
        from the feature names. Ignored if ``panel_titles=False``.
    title_fontsize : float, optional (default: 12)
        Font size of the per-panel title. Mirrors the convention in
        :func:`sccellfie.plotting.plot_spatial`.
    wrapped_title_length : int, optional (default: 45)
        Maximum number of characters per title line. Long feature names
        (e.g. metabolic-task labels) are wrapped via :func:`textwrap.wrap`
        before being set, matching the behavior of the other tool plots
        (:func:`plot_spatial`, :func:`create_multi_violin_plots`,
        :func:`create_volcano_plot`). Pass a large value (e.g. 1000) to
        disable wrapping.
    dpi : int, optional (default: 150)
        Figure DPI; also used when ``save`` is set.
    scatter_size : float, optional (default: 2.0)
        Marker size for centroid scatter mode.
    cmap : str, optional (default: "viridis")
        Matplotlib colormap name for continuous colouring.
    vmin, vmax : float, optional (default: None)
        Lower / upper bounds for continuous colouring. When set, values
        outside ``[vmin, vmax]`` are clipped at the colormap edges and
        the colorbar is restricted to that range. Ignored for
        categorical colouring. In multi-panel mode the same bounds apply
        to every panel (useful for comparing features on a shared
        scale). Pass only one of the two to cap a single side.
    y_pad_ratio : float, optional (default: 0.1)
        Fraction of the y range added as top/bottom whitespace (so the
        scalebar label has room).
    x_pad_ratio : float, optional (default: 0.0)
        Fraction of the x range added as left/right whitespace. Default
        keeps x tight to data; increase when the legend or a colorbar
        sits to the right of the plot and you want extra breathing room
        on the data side too.
    scalebar : bool, optional (default: True)
        Draw a scalebar on every panel.
    scalebar_kwargs : dict, optional (default: None)
        Overrides for the scalebar (e.g. ``length``, ``units``, ``color``,
        ``position``, ``pad_frac``, ``fontsize``, ``text_pad_pts``).
        ``pad_frac`` is the inset of the bar from the axes corner as a
        fraction of the axes height/width; ``text_pad_pts`` is the gap (in
        points) between the bar and its label. The label is always placed
        on the side of the bar away from the data, so it never overlaps
        cells when ``y_pad_ratio > 0``.
    cbar_kwargs : dict, optional (default: None)
        Overrides passed to the colorbar call for continuous colouring.
    save : str, optional (default: None)
        If given, save the figure to this path with ``dpi`` and
        ``bbox_inches="tight"``.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib figure object.
    ax : matplotlib.axes.Axes or numpy.ndarray of Axes
        Single Axes when ``color_by`` is a string (or None); a 2D array
        of Axes (shape ``(nrows, ncols)``) when ``color_by`` is a list.

    Raises
    ------
    ValueError
        If ``color_by`` is an empty list, if ``ax`` is combined with a list
        of features, if ``ncols`` is smaller than 1 in multi-panel mode, if
        ``title`` does not match the shape of ``color_by``, or if there are
        no cells and no ``crop`` to derive the view from.
    """
    # ---- Argument validation (before any figure or data access) ----------
    is_list = isinstance(color_by, (list, tuple, np.ndarray, pd.Index))
    features = list(color_by) if is_list else [color_by]
    n_panels = len(features)

    if is_list and n_panels == 0:
        raise ValueError("`color_by` is an empty list; pass at least one feature.")
    if is_list and ax is not None:
        raise ValueError(
            "`ax` is only supported for a single `color_by`. "
            "When passing a list of features, omit `ax` so the grid can be built."
        )
    if is_list and ncols < 1:
        raise ValueError(f"`ncols` must be a positive integer, got {ncols!r}.")

    # Titles are resolved up front so that an invalid `title` does not leave
    # a half-built figure behind.
    if not panel_titles:
        resolved_titles = [None] * n_panels
    elif title is None:
        resolved_titles = [
            (f if f is not None else celltype_key) for f in features
        ]
    elif is_list:
        if not isinstance(title, (list, tuple)):
            raise ValueError(
                "`title` must be a list/tuple of strings when `color_by` is a list."
            )
        if len(title) != n_panels:
            raise ValueError(
                f"`title` has {len(title)} entries but `color_by` has {n_panels}."
            )
        resolved_titles = list(title)
    else:
        if not isinstance(title, str):
            raise ValueError(
                "`title` must be a string when `color_by` is a single feature."
            )
        resolved_titles = [title]

    # ---- Grid geometry ---------------------------------------------------
    if is_list:
        nrows_grid = math.ceil(n_panels / ncols)
        ncols_grid = ncols
        per_panel = figsize if figsize is not None else (4, 4)
        total_figsize = (per_panel[0] * ncols_grid, per_panel[1] * nrows_grid)
    else:
        total_figsize = figsize if figsize is not None else (10, 10)
        nrows_grid, ncols_grid = 1, 1

    # ---- Shared coordinates, crop mask and view limits -------------------
    coords = np.asarray(adata.obsm[spatial_key])[:, :2]
    cell_ids = (
        adata.obs[cell_id_col].values
        if cell_id_col is not None
        else adata.obs.index.values
    )

    if crop is not None:
        d_minx, d_miny, d_maxx, d_maxy = crop
    else:
        if coords.shape[0] == 0:
            raise ValueError(
                "Cannot derive the plot extent: `adata` has no cells and no `crop` was given."
            )
        # NaN-aware extent: cells with missing coordinates are ignored here
        # and excluded by the mask below (comparisons with NaN are False).
        d_minx, d_maxx = np.nanmin(coords[:, 0]), np.nanmax(coords[:, 0])
        d_miny, d_maxy = np.nanmin(coords[:, 1]), np.nanmax(coords[:, 1])

    mask = (
        (coords[:, 0] >= d_minx)
        & (coords[:, 0] <= d_maxx)
        & (coords[:, 1] >= d_miny)
        & (coords[:, 1] <= d_maxy)
    )
    local_idx = np.where(mask)[0]

    xp = (d_maxx - d_minx) * x_pad_ratio
    yp = (d_maxy - d_miny) * y_pad_ratio
    xlim = (d_minx - xp, d_maxx + xp)
    # An inverted y-axis is expressed by passing the limits in reverse order.
    ylim = (d_maxy + yp, d_miny - yp) if invert_yaxis else (d_miny - yp, d_maxy + yp)

    # ---- Figure / axes ---------------------------------------------------
    if ax is None:
        fig, axes = plt.subplots(
            nrows_grid, ncols_grid,
            figsize=total_figsize, dpi=dpi, squeeze=False,
        )
    else:
        fig = ax.get_figure()
        axes = np.array([[ax]])
    flat_axes = axes.ravel()

    # ---- Panels ----------------------------------------------------------
    for i, feature in enumerate(features):
        raw = resolved_titles[i]
        if raw is None:
            wrapped = None
        else:
            wrapped = "\n".join(textwrap.wrap(str(raw), width=wrapped_title_length))
            # textwrap returns nothing for blank input; keep the raw text.
            if wrapped == "":
                wrapped = str(raw)
        _render_panel(
            flat_axes[i],
            adata=adata,
            local_idx=local_idx,
            cell_ids=cell_ids,
            coords=coords,
            xlim=xlim,
            ylim=ylim,
            color_by=feature,
            celltype_key=celltype_key,
            segmentation=segmentation,
            palette=palette,
            highlight=highlight,
            layer=layer,
            legend=legend,
            legend_loc=legend_loc,
            legend_bbox=legend_bbox,
            legend_frameon=legend_frameon,
            legend_title=legend_title,
            legend_fontsize=legend_fontsize,
            legend_ncol=legend_ncol,
            legend_params=legend_params,
            axes_off=axes_off,
            scatter_size=scatter_size,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            scalebar=scalebar,
            scalebar_kwargs=scalebar_kwargs,
            cbar_kwargs=cbar_kwargs,
            panel_title=wrapped,
            title_fontsize=title_fontsize,
        )

    # Hide the unused trailing cells of the grid.
    for j in range(n_panels, flat_axes.size):
        flat_axes[j].set_visible(False)

    if is_list:
        fig.tight_layout()
    if save:
        fig.savefig(save, dpi=dpi, bbox_inches="tight")

    if is_list:
        return fig, axes
    return fig, flat_axes[0]