# Source code for sccellfie.plotting.communication

import os
import networkx as nx
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from matplotlib.patches import FancyArrowPatch
from matplotlib.path import Path as MplPath
import matplotlib.patches as patches


def normalize_vector(vector, normalize_to):
    """
    Rescales `vector` so that its Euclidean norm equals `normalize_to`

    Parameters
    ----------
    vector : np.array
        Vector to normalize. Any array-like of numbers is accepted
        and converted with `np.asarray`.

    normalize_to : float
        Value to normalize the vector to.

    Returns
    -------
    normalized_vector : np.array
        Normalized vector. A vector whose norm is zero has no direction
        to rescale, so an array of zeros with the same shape is
        returned instead of NaNs.
    """
    vector = np.asarray(vector)
    vector_norm = np.linalg.norm(vector)
    if vector_norm == 0:
        # Dividing by a zero norm would fill the result with NaN and
        # emit a RuntimeWarning; zeros keep downstream geometry finite.
        return np.zeros(vector.shape)
    normalized_vector = vector * normalize_to / vector_norm
    return normalized_vector


def orthogonal_vector(point, width, normalize_to=None):
    """
    Builds a 2D vector that is (approximately) perpendicular to `point`

    The x component is fixed to `width`; the y component is then chosen
    so that the dot product with `point` is close to zero. A small
    constant is added to `point[1]` before dividing by it.

    Parameters
    ----------
    point : np.array
        2D point (read as a vector from the origin) to which the result
        should be orthogonal.

    width : float
        x component given to the orthogonal vector before any
        normalization.

    normalize_to : float, optional (default: None)
        If not None, the result is rescaled with `normalize_vector`
        so that its norm equals this value.

    Returns
    -------
    ort_vector : np.array
        Orthogonal vector to the point.
    """
    # Small shift applied to the denominator so that points with
    # point[1] == 0 do not divide by zero.
    eps = 0.000001

    x_component = width
    y_component = -x_component * point[0] / (point[1] + eps)
    perpendicular = np.array([x_component, y_component])

    if normalize_to is None:
        return perpendicular
    return normalize_vector(perpendicular, normalize_to)


def draw_self_loop(point, ax, node_radius, padding=1.2, width=0.1,
                   linewidth=1, color="pink", alpha=0.5,
                   mutation_scale=20):
    """
    Draws a loop from `point` to itself, starting from node border

    The loop is a single cubic Bezier curve added to `ax` as a
    `FancyArrowPatch`. It bulges along the direction that goes from the
    centre of the current axis limits to `point`.

    Parameters
    ----------
    point : np.array
        Point to draw the loop.

    ax : matplotlib.axes.Axes
        Axes object where the loop will be drawn. Its current x/y limits
        are read to determine the centre of the plot.

    node_radius : float
        Radius of the node.

    padding : float, optional (default: 1.2)
        Padding for the loop. Multiplies the centre-to-start offset to
        place the two Bezier control points.

    width : float, optional (default: 0.1)
        Width of the loop. Half-distance between the two control points,
        measured perpendicular to the centre-to-point direction.

    linewidth : float, optional (default: 1)
        Width of the loop line.

    color : str, optional (default: "pink")
        Color of the loop (used for both the line and the arrow head).

    alpha : float, optional (default: 0.5)
        Transparency of the loop.

    mutation_scale : float, optional (default: 20)
        Mutation scale of the loop.
    """
    point = np.asarray(point, dtype=float)

    # Get the center of the plot
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    center = np.array([(xlim[1] + xlim[0]) / 2, (ylim[1] + ylim[0]) / 2])

    # Center the point
    centered_point = point - center

    # Calculate direction vector from center to point. A node lying exactly
    # on the plot centre has no outward direction (0/0 would turn every
    # vertex into NaN and nothing would be drawn), so point the loop upwards.
    distance = np.linalg.norm(centered_point)
    if distance > 0:
        direction = centered_point / distance
    else:
        direction = np.array([0.0, 1.0])

    # Start point: `node_radius` away from the node centre, on the side
    # facing the plot centre
    start_point = point - direction * node_radius

    # Control points: the start offset scaled by `padding`, shifted to
    # either side by a vector perpendicular to the centre-to-point direction
    point_with_padding = padding * (start_point - center)
    ort_vector = orthogonal_vector(centered_point, width, normalize_to=width)

    first_anchor = ort_vector + point_with_padding + center
    second_anchor = -ort_vector + point_with_padding + center

    # End point sits a bit closer to the node centre than the start point
    # so that the arrow head is visible
    end_point = point - direction * (node_radius * 0.7)

    # Define path: one cubic Bezier segment with two control points
    verts = [start_point, first_anchor, second_anchor, end_point]
    codes = [MplPath.MOVETO, MplPath.CURVE4, MplPath.CURVE4, MplPath.CURVE4]

    path = MplPath(verts, codes)
    # `color` sets both edge and face colour. Passing `facecolor` as well
    # has no effect (matplotlib lets `color` win) and only triggers a
    # UserWarning on every call, so it is not passed here.
    patch = patches.FancyArrowPatch(
        path=path,
        lw=linewidth,
        arrowstyle="-|>",
        color=color,
        alpha=alpha,
        mutation_scale=mutation_scale
    )
    ax.add_patch(patch)


def plot_communication_network(ccc_scores, sender_col, receiver_col, score_col, score_threshold=None,
                               panel_size=(12, 8), network_layout='spring', edge_color='magenta',
                               edge_width=25, edge_arrow_size=20, edge_alpha=0.25, node_color="#210070",
                               node_size=1000, node_alpha=0.9, node_label_size=12, node_label_alpha=0.7,
                               node_label_offset=(0.05, -0.2), title=None, title_fontsize=14, ax=None,
                               save=None, dpi=300, tight_layout=True, bbox_inches='tight'):
    """
    Plots a network of cell-cell communication. Edges represent
    communication scores between cells. These scores could be an overall
    communication score or a specific ligand-receptor pair score.

    Parameters
    ----------
    ccc_scores : pandas.DataFrame
        DataFrame containing the cell-cell communication scores.
        It should contain columns for the sender cell, receiver cell,
        and the communication score.

    sender_col : str
        Column name for the sender cell.

    receiver_col : str
        Column name for the receiver cell.

    score_col : str
        Column name for the communication score.

    score_threshold : float, optional (default: None)
        Threshold for the communication score. If provided, only scores
        greater than or equal to this threshold are plotted.

    panel_size : tuple, optional (default: (12, 8))
        Size of the plot panel. Only works if `ax` is None.

    network_layout : str, optional (default: 'spring')
        Layout of the network graph. Should be either 'spring' or 'circular'.

    edge_color : str, optional (default: 'magenta')
        Color of the edges.

    edge_width : float, optional (default: 25)
        Width of the edges. Each edge is drawn with this width multiplied
        by its score rescaled to the range [0.2, 1].

    edge_arrow_size : float, optional (default: 20)
        Size of the edge arrows.

    edge_alpha : float, optional (default: 0.25)
        Transparency of the edges.

    node_color : str, optional (default: '#210070')
        Color of the nodes.

    node_size : int, optional (default: 1000)
        Size of the nodes.

    node_alpha : float, optional (default: 0.9)
        Transparency of the nodes.

    node_label_size : int, optional (default: 12)
        Font size of the node labels.

    node_label_alpha : float, optional (default: 0.7)
        Transparency of the node labels.

    node_label_offset : tuple, optional (default: (0.05, -0.2))
        Offset of the node labels.

    title : str, optional (default: None)
        Title of the plot.

    title_fontsize : int, optional (default: 14)
        Font size of the title.

    ax : matplotlib.axes.Axes, optional (default: None)
        Axes object where the plot will be drawn. If None, a new figure
        is created.

    save : str, optional (default: None)
        Filepath to save the plot. If None, the plot is not saved.

    dpi : int, optional (default: 300)
        Resolution of the saved plot.

    tight_layout : bool, optional (default: True)
        Whether to use tight layout for the plot.

    bbox_inches : str, optional (default: 'tight')
        Bounding box in inches. Only used if `save` is provided.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib figure object (the figure that contains `ax`).

    ax : matplotlib.axes.Axes
        The matplotlib axes object.

    Raises
    ------
    ValueError
        If `network_layout` is neither 'spring' nor 'circular'.
    """
    # Filter by threshold if specified
    if score_threshold is not None:
        ccc_scores = ccc_scores[ccc_scores[score_col] >= score_threshold].copy()

    # Create figure if needed
    if ax is None:
        fig, ax = plt.subplots(figsize=panel_size)
    else:
        # Use the figure that actually owns `ax`; pyplot's "current figure"
        # may be a different one. When `ax` lives in a sub-figure, step up
        # to the top-level figure, which is the one that can be resized,
        # laid out and saved.
        parent = ax.figure
        fig = getattr(parent, 'figure', parent)
        if fig is None:
            # Detached axes: fall back to pyplot's current figure
            fig = plt.gcf()

    # Calculate figure-dependent scaling factor (1.0 for the default 12x8 panel)
    fig_width, fig_height = fig.get_size_inches()
    figsize_factor = np.sqrt(fig_width * fig_height) / np.sqrt(12 * 8)

    # Create network graph
    G = nx.DiGraph()
    all_cells = pd.concat([ccc_scores[sender_col], ccc_scores[receiver_col]]).unique()
    G.add_nodes_from(all_cells)
    for sender, receiver, score in zip(ccc_scores[sender_col],
                                       ccc_scores[receiver_col],
                                       ccc_scores[score_col]):
        G.add_edge(sender, receiver, weight=score)

    # Set layout
    if network_layout == 'spring':
        pos = nx.spring_layout(G, k=1., seed=888)
    elif network_layout == 'circular':
        pos = nx.circular_layout(G)
    else:
        raise ValueError("network_layout should be either 'spring' or 'circular'")

    # Get edge weights and rescale them to [0.2, 1]. When every weight is
    # identical the span is replaced by 1, so all edges get 0.2.
    weights = np.array([G.edges[e]['weight'] for e in G.edges()])
    if len(weights) > 0:
        w_min, w_max = weights.min(), weights.max()
        weight_span = w_max - w_min if w_max != w_min else 1
        weights_norm = 0.2 + 0.8 * (weights - w_min) / weight_span
    else:
        weights_norm = []

    # Calculate node radius for offset calculations
    base_node_radius = np.sqrt(node_size / np.pi) / 150
    node_radius = base_node_radius * figsize_factor

    # Separate self-loops from regular edges
    self_loops = [(u, v) for (u, v) in G.edges() if u == v]
    regular_edges = [(u, v) for (u, v) in G.edges() if u != v]

    def get_offset_positions(pos_src, pos_dst, offset):
        """Calculate offset positions for arrow endpoints"""
        direction = pos_dst - pos_src
        length = np.linalg.norm(direction)
        if length == 0:
            return pos_src, pos_dst
        unit_vec = direction / length
        scaled_offset = offset * figsize_factor
        pos_src_offset = pos_src + unit_vec * scaled_offset
        pos_dst_offset = pos_dst - unit_vec * scaled_offset
        return pos_src_offset, pos_dst_offset

    # Normalized weight of every edge, keyed by (source, target)
    edge_weights = dict(zip(G.edges(), weights_norm))

    # Draw regular edges
    edge_offset = node_radius * 0.8
    for u, v in regular_edges:
        pos_u, pos_v = np.array(pos[u]), np.array(pos[v])
        pos_src_offset, pos_dst_offset = get_offset_positions(pos_u, pos_v, edge_offset)

        arrow = FancyArrowPatch(
            posA=pos_src_offset,
            posB=pos_dst_offset,
            arrowstyle='-|>',
            connectionstyle="arc3,rad=-0.2",
            color=edge_color,
            alpha=edge_alpha,
            linewidth=edge_width * edge_weights[(u, v)],
            mutation_scale=edge_arrow_size * figsize_factor
        )
        ax.add_patch(arrow)

    # Draw self-loops using the improved function
    # NOTE(review): unlike regular edges, the arrow size of self-loops is
    # not multiplied by `figsize_factor` -- confirm this is intended.
    for u, v in self_loops:
        pos_u = np.array(pos[u])
        draw_self_loop(
            point=pos_u,
            ax=ax,
            node_radius=node_radius,
            padding=1.2,
            width=0.1,
            linewidth=edge_width * edge_weights[(u, v)],
            color=edge_color,
            alpha=edge_alpha,
            mutation_scale=edge_arrow_size
        )

    # Draw nodes
    nx.draw_networkx_nodes(G, pos,
                           node_color=node_color,
                           node_size=node_size,
                           alpha=node_alpha,
                           ax=ax)

    # Add labels, shifted from the node positions by `node_label_offset`
    label_options = {"ec": "k", "fc": "white", "alpha": node_label_alpha}
    label_shift = np.array(node_label_offset)
    label_pos = {k: v + label_shift for k, v in pos.items()}
    nx.draw_networkx_labels(G, label_pos,
                            font_size=node_label_size,
                            bbox=label_options,
                            ax=ax)

    # Adjust layout: hide the frame and widen the axis limits
    ax.set_frame_on(False)
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    coeff = 1.4
    ax.set_xlim((xlim[0] * coeff, xlim[1] * coeff))
    ax.set_ylim((ylim[0] * coeff, ylim[1] * coeff))

    if title is not None:
        ax.set_title(title, fontsize=title_fontsize, y=0.9)

    # Act on `fig` itself rather than on pyplot's current figure, so the
    # layout and the saved file always correspond to the returned figure.
    if tight_layout:
        fig.tight_layout()

    if save is not None:
        from sccellfie.plotting.plot_utils import _get_file_format, _get_file_dir
        out_dir, basename = _get_file_dir(save)
        os.makedirs(out_dir, exist_ok=True)
        file_format = _get_file_format(save)
        fig.savefig(f'{out_dir}/ccc_{basename}.{file_format}', dpi=dpi, bbox_inches=bbox_inches)

    return fig, ax