import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
[docs]
def create_radial_plot(metabolic_df, task_info_df, cell_type=None, tissue=None,
task_col='metabolic_task', category_col='System', value_col='scaled_trimean',
tissue_col='tissue', cell_type_col='cell_type', figsize=(6, 6),
title='Metabolic activities', palette='Dark2', title_fontsize=24,
legend_fontsize=14, legend_loc="center left",
legend_bbox_to_anchor=(1.1, 0.5), alpha_fill=0.25, alpha_bg=0.1, ylim=1.0,
sort_by_value=False, ax=None, show_legend=True, save=None, dpi=300,
bbox_inches='tight', tight_layout=True):
"""
Creates a radial plot of metabolic task activities grouped by category.
Parameters
----------
metabolic_df : pandas.DataFrame
DataFrame containing metabolic task activities. Typically, it corresponds
to the 'melted' dataframe in the outputs from `sccellfie.reports.summary.generate_report_from_adata()`.
Required columns: task_col, value_col, cell_type_col, tissue_col.
task_info_df : pandas.DataFrame
DataFrame containing task categorization information.
Required columns: task_col and category_col.
cell_type : str, optional (default: None)
The specific cell type to plot. If None, the maximum activity across all cell types
within the specified tissue is used.
tissue : str, optional (default: None)
The specific tissue to plot. If None, all tissues are included.
task_col : str, optional (default: 'metabolic_task')
The column name in metabolic_df containing task identifiers.
category_col : str, optional (default: 'System')
The column name in task_info_df containing category information.
value_col : str, optional (default: 'scaled_trimean')
The column name in metabolic_df containing activity values.
tissue_col : str, optional (default: 'tissue')
The column name in metabolic_df containing tissue information.
cell_type_col : str, optional (default: 'cell_type')
The column name in metabolic_df containing cell type information.
figsize : tuple, optional (default: (6, 6))
The size of the figure. Only used if ax is None.
title : str, optional (default: 'Metabolic activities')
The title for the plot. Set to None to disable the title.
palette : str, optional (default: 'Dark2)
Name of a palette for coloring the categories of metabolic tasks.
title_fontsize : int, optional (default: 24)
Font size for the title.
legend_fontsize : int, optional (default: 14)
Font size for the legend.
legend_loc : str, optional (default: "center left")
Location of the legend.
legend_bbox_to_anchor : tuple, optional (default: (1.1, 0.5))
Position of the legend relative to the legend_loc.
alpha_fill : float, optional (default: 0.25)
Alpha transparency for the filled areas.
alpha_bg : float, optional (default: 0.1)
Alpha transparency for the background areas.
ylim : float, optional (default: 1.0)
Limit value for the y-axis (radial direction). If None, the maximum value across
all tasks is used instead.
sort_by_value : bool, optional (default: False)
If True, tasks within each category are sorted by their value.
If False, tasks are sorted alphabetically within each category.
ax : matplotlib.axes.Axes, optional (default: None)
A matplotlib axes with polar projection to draw the plot on.
If None, a new figure and axes are created.
show_legend : bool, optional (default: True)
Whether to display the legend.
save : str, optional (default: None)
The filepath to save the figure. If None, the figure is not saved.
dpi : int, optional (default: 300)
The resolution of the saved figure.
bbox_inches : str, optional (default: 'tight')
The bbox_inches parameter for saving the figure.
tight_layout : bool, optional (default: True)
Whether to use tight layout for the plot. Only applied if ax is None.
Returns
-------
fig : matplotlib.figure.Figure
The matplotlib figure object.
ax : matplotlib.axes.Axes
The matplotlib axes object.
Examples
--------
>>> import pandas as pd
>>> from sccellfie.plotting import create_radial_plot
>>>
>>> # Load example data
>>> metabolic_df = pd.read_csv('Melted.csv')
>>> task_info_df = pd.read_csv('TaskInfo.csv')
>>>
>>> # Create radial plot for maximum activities across all cell types in a tissue
>>> fig, ax = create_radial_plot(metabolic_df, task_info_df, tissue='Blood')
>>> plt.show()
>>>
>>> # Create radial plot for a specific cell type in a specific tissue
>>> fig, ax = create_radial_plot(metabolic_df, task_info_df, cell_type='T cell', tissue='Blood')
>>> plt.show()
>>>
>>> # Create multiple subplots with shared legend
>>> fig = plt.figure(figsize=(20, 10))
>>> ax1 = fig.add_subplot(121, projection='polar')
>>> ax2 = fig.add_subplot(122, projection='polar')
>>>
>>> # First subplot with legend
>>> create_radial_plot(metabolic_df, task_info_df, tissue='Blood', ax=ax1, show_legend=True)
>>> # Second subplot without legend
>>> create_radial_plot(metabolic_df, task_info_df, tissue='Liver', ax=ax2, show_legend=False)
>>> plt.tight_layout()
>>> plt.show()
"""
# Copy dataframes to avoid modifying originals
metabolic_df = metabolic_df.copy()
task_info_df = task_info_df.copy()
# Check required columns
required_cols = [task_col, value_col, cell_type_col]
if tissue is not None:
required_cols.append(tissue_col)
missing_cols = [col for col in required_cols if col not in metabolic_df.columns]
if missing_cols:
raise ValueError(f"Missing required columns in metabolic_df: {', '.join(missing_cols)}")
# Filter by tissue if specified
if tissue is not None:
if tissue_col not in metabolic_df.columns:
raise ValueError(f"'{tissue_col}' column is missing from metabolic_df")
data = metabolic_df[metabolic_df[tissue_col] == tissue].copy()
if len(data) == 0:
raise ValueError(f"No data found for tissue '{tissue}'")
else:
data = metabolic_df.copy()
# Filter by cell type or get max across cell types
if cell_type is not None:
if cell_type_col not in data.columns:
raise ValueError(f"'{cell_type_col}' column is missing from metabolic_df")
filtered_data = data[data[cell_type_col] == cell_type].copy()
if len(filtered_data) == 0:
raise ValueError(f"No data found for cell_type '{cell_type}'")
# Group by task to get average values (in case there are multiple entries per task for this cell type)
radial_df = filtered_data.groupby(task_col)[value_col].mean().reset_index()
else:
# Group by task and calculate maximum across all cell types
radial_df = data.groupby(task_col)[value_col].max().reset_index()
# Ensure task_info_df has columns we need
if task_col not in task_info_df.columns:
# Try to find a different column with task names
if 'Task' in task_info_df.columns:
task_info_df = task_info_df.rename(columns={'Task': task_col})
else:
raise ValueError(f"'{task_col}' column not found in task_info_df")
# Merge with task categories
radial_df = pd.merge(radial_df, task_info_df[[task_col, category_col]], on=task_col, how='left')
# Check for tasks without category information
missing_categories = radial_df[radial_df[category_col].isna()][task_col].unique()
if len(missing_categories) > 0:
# Assign uncategorized tasks to "Other" category
radial_df.loc[radial_df[category_col].isna(), category_col] = "Other"
print(f"Warning: {len(missing_categories)} tasks have no category information and were assigned to 'Other'")
# Get unique categories ordered by count (descending)
category_counts = radial_df[category_col].value_counts()
categories_by_size = category_counts.index.tolist() # Categories ordered by descending count
# Simplified color assignment approach
import matplotlib.cm as cm
from matplotlib.colors import to_rgba
# Check if we need to extend the palette
n_categories = len(categories_by_size)
if isinstance(palette, str):
# Get the built-in colormap
cmap = cm.get_cmap(palette)
# Check if the colormap has enough colors
if cmap.N >= n_categories:
# Colormap has enough colors, use it directly
color_map = {cat: to_rgba(cmap(i / cmap.N)) for i, cat in enumerate(categories_by_size)}
else:
# Need more colors, try to extend with glasbey if available
try:
import glasbey
palette_size = n_categories
try:
extended_palette = glasbey.extend_palette(palette, palette_size=palette_size)
color_map = {cat: extended_palette[i] for i, cat in enumerate(categories_by_size)}
except Exception as e:
# Glasbey extension failed, fall back to cycling the original colormap
print(f"Warning: Could not extend palette: {str(e)}. Using color cycling instead.")
color_map = {cat: to_rgba(cmap(i % cmap.N)) for i, cat in enumerate(categories_by_size)}
except ImportError:
# Glasbey not available, use cycling of the original colormap
print("Warning: glasbey module not available. Using color cycling instead.")
color_map = {cat: to_rgba(cmap(i % cmap.N)) for i, cat in enumerate(categories_by_size)}
else:
# If palette is already a list of colors, use it with cycling if needed
color_map = {cat: palette[i % len(palette)] for i, cat in enumerate(categories_by_size)}
# Create figure if ax is not provided
if ax is None:
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111, projection='polar')
else:
# Make sure the provided axis has polar projection
from matplotlib.projections.polar import PolarAxes
if not isinstance(ax, PolarAxes):
raise ValueError("The provided ax must have polar projection")
fig = ax.figure
if ylim is None:
ylim = radial_df[value_col].max()
# Create a DataFrame with pathway and category information
pathway_data = radial_df.copy()
# Sort by category count (descending)
pathway_data['category_count'] = pathway_data[category_col].map(category_counts)
# Get categories in the order determined by our sorting (by size)
categories_ordered = pathway_data.sort_values('category_count', ascending=False)[category_col].unique()
# Drop the temporary column
pathway_data = pathway_data.drop(columns=['category_count'])
# Calculate angle for each pathway
total_pathways = len(pathway_data)
angle_per_pathway = (2 * np.pi) / total_pathways
# Initialize for angle assignment
pathway_data_grouped = []
current_angle = 0
# Process each category in our determined order
for category in categories_ordered:
# Record the starting angle for this category
category_start_angle = current_angle
# Get data for this category
category_data = pathway_data[pathway_data[category_col] == category].copy()
# Sort within category: alphabetically by default, by value if requested
if sort_by_value:
category_data = category_data.sort_values(value_col, ascending=False)
else:
# Sort alphabetically by task name
category_data = category_data.sort_values(task_col)
# Create a range of angles for this category
num_pathways = len(category_data)
category_angles = np.linspace(current_angle,
current_angle + (num_pathways * angle_per_pathway),
num_pathways,
endpoint=False)
# Assign these angles
category_data['angle'] = category_angles
# Update current angle for the next category
current_angle += num_pathways * angle_per_pathway
# Store category boundary
pathway_data_grouped.append({
'category': category,
'start_angle': category_start_angle,
'end_angle': current_angle,
'data': category_data
})
# Create a new dataframe with all the angle-assigned data
pathway_data = pd.concat([group_info['data'] for group_info in pathway_data_grouped])
# Set up the polar plot
ax.set_theta_offset(np.pi / 2)
ax.set_theta_direction(-1)
ax.set_ylim(0, ylim)
# Add the axis from the center
ax.spines['polar'].set_visible(True)
ax.spines['polar'].set_linewidth(2)
ax.spines['polar'].set_color('black')
# Set the position of the score ticks
ax.set_rlabel_position(180)
ax.set_yticks([np.min([1., ylim]), ylim])
ax.set_yticklabels(['', ''], fontsize=16)
ax.tick_params(axis='y', which='major', width=1., color='red')
# Color the circular area behind each category with exact alignment
for group_info in pathway_data_grouped:
category = group_info['category']
start_angle = group_info['start_angle']
end_angle = group_info['end_angle']
# Create angles that span exactly from the first to last pathway in this category
cat_angles = np.linspace(start_angle, end_angle, 50)
# Fill the background
ax.fill_between(cat_angles, 0, ylim, color=color_map[category], alpha=alpha_bg, zorder=0)
# Plot the data on the radial plot with connected bars within each category
for group_info in pathway_data_grouped:
category = group_info['category']
# Get all pathways in this category
category_paths = pathway_data[pathway_data[category_col] == category]
# Get angles and scores for this category
cat_angles = category_paths['angle'].values
cat_scores = category_paths[value_col].values
# For each pathway in the category, draw the radial line
for angle, score in zip(cat_angles, cat_scores):
ax.plot([angle, angle], [0, score], color=color_map[category], linewidth=2)
# Connect all bars within this category with a polygon
polygon_angles = []
polygon_radii = []
# Start the polygon at the first bar, at the base
for angle, score in zip(cat_angles, cat_scores):
# Add a point at the base (radius 0)
polygon_angles.append(angle)
polygon_radii.append(0)
# Add a point at the top of the bar
polygon_angles.append(angle)
polygon_radii.append(score)
# Convert to numpy arrays for matplotlib
polygon_angles = np.array(polygon_angles)
polygon_radii = np.array(polygon_radii)
# Plot the polygon
ax.fill(polygon_angles, polygon_radii, color=color_map[category], alpha=alpha_fill)
# Remove theta ticks
ax.set_xticks([])
# Create a legend only if show_legend is True
if show_legend:
# Create a legend using the same category order as the colors
legend_labels = categories_by_size
legend_handles = [plt.Rectangle((0, 0), 1, 1, color=color_map[cat], lw=0) for cat in legend_labels]
legend_labels = [l.upper() for l in legend_labels]
# Create the legend
ax.legend(legend_handles, legend_labels,
loc=legend_loc, bbox_to_anchor=legend_bbox_to_anchor,
fontsize=legend_fontsize, frameon=False, borderaxespad=0,
ncol=1)
# Add title with appropriate information (if not None)
if title is not None:
title_parts = []
title_parts.append(title)
if tissue is not None:
title_parts.append(tissue)
if cell_type is not None:
title_parts.append(cell_type)
else:
title_parts.append("across cell types")
# Join title parts with newlines
full_title = '\n'.join(title_parts)
# Use set_title if this is a subplot, otherwise use suptitle
if ax is not None and ax.get_subplotspec() is not None:
ax.set_title(full_title, fontsize=title_fontsize, fontweight='bold', pad=20)
else:
plt.suptitle(full_title, fontsize=title_fontsize, fontweight='bold', y=1.025)
# Apply tight layout if requested and we created the figure
if tight_layout and ax is None:
plt.tight_layout()
# Save if requested
if save is not None:
try:
from sccellfie.plotting.plot_utils import _get_file_format, _get_file_dir
dir, basename = _get_file_dir(save)
os.makedirs(dir, exist_ok=True)
format = _get_file_format(save)
plt.savefig(f'{dir}/radial_{basename}.{format}', dpi=dpi, bbox_inches=bbox_inches)
except ImportError:
# Fall back to basic save if plot_utils is not available
plt.savefig(save, dpi=dpi, bbox_inches=bbox_inches)
return fig, ax