import numpy as np
import pandas as pd
import scanpy as sc
from scipy.sparse import issparse, csr_matrix, hstack
def _agg_min(x, axis):
    """Minimum of ``x`` along ``axis``."""
    return np.min(x, axis=axis)


def _agg_mean(x, axis):
    """Arithmetic mean of ``x`` along ``axis``."""
    return np.mean(x, axis=axis)


def _agg_gmean(x, axis):
    """Geometric mean of ``x`` along ``axis``, computed in log space."""
    # Values are floored at 1e-10 before taking the log so zeros stay finite.
    floored = np.clip(x, 1e-10, None)
    return np.exp(np.mean(np.log(floored), axis=axis))


# Registry of subunit aggregation functions, keyed by the ``agg_method`` name.
# Every entry is called as ``fn(x, axis=...)``.
COMPLEX_AGG_METHODS = {
    'min': _agg_min,
    'mean': _agg_mean,
    'gmean': _agg_gmean,
}
[docs]
def make_complex_name(subunits, separator='&'):
    """
    Build the canonical name of a complex from its subunit names.

    The subunits are sorted before joining, so the same set of subunits
    yields the same name regardless of the order it was supplied in.

    Parameters
    ----------
    subunits : list of str
        Names of the subunit genes/tasks.
    separator : str, default='&'
        String placed between consecutive sorted subunit names.

    Returns
    -------
    str
        The sorted subunit names joined by ``separator``.
    """
    ordered = list(subunits)
    ordered.sort()
    return separator.join(ordered)
[docs]
def _aggregate_complexes(X, var_names, complexes, agg_method):
    """Return a dense (n_rows, n_complexes) array with one aggregated column per complex.

    Columns follow the iteration order of ``complexes``. Each column is
    ``COMPLEX_AGG_METHODS[agg_method]`` applied across that complex's subunit
    columns of ``X``.
    """
    aggregate = COMPLEX_AGG_METHODS[agg_method]
    columns = []
    for subunits in complexes.values():
        subunit_indices = [var_names.get_loc(s) for s in subunits]
        sub_X = X[:, subunit_indices]
        if issparse(sub_X):
            # Only the few subunit columns are densified, never the full matrix.
            sub_X = sub_X.toarray()
        columns.append(aggregate(sub_X, axis=1).reshape(-1, 1))
    return np.hstack(columns)


def _append_columns(matrix, extra_cols):
    """Append dense ``extra_cols`` to ``matrix``, keeping ``matrix``'s own sparse/dense storage."""
    if issparse(matrix):
        return hstack([matrix, csr_matrix(extra_cols)], format='csr')
    return np.hstack([matrix, extra_cols])


def _append_zero_columns(matrix, n_new):
    """Append ``n_new`` all-zero columns to ``matrix`` without changing its dtype or storage."""
    n_rows = matrix.shape[0]
    # Zeros are created in the matrix's dtype so integer layers (e.g. counts)
    # are not upcast to float64 by the concatenation.
    dtype = getattr(matrix, 'dtype', None)
    if issparse(matrix):
        return hstack([matrix, csr_matrix((n_rows, n_new), dtype=dtype)], format='csr')
    return np.hstack([matrix, np.zeros((n_rows, n_new), dtype=dtype)])


def add_complexes_to_adata(adata, complexes, agg_method='min', layer=None, copy=False):
    """
    Adds multi-gene complex expression as new variables in an AnnData object.

    Computes per-cell aggregated expression for each complex and appends
    the result as new columns in adata.X. If ``layer`` is given, the values
    are read from that layer and appended both to adata.X and to that layer;
    every other layer receives zero-filled columns.

    The AnnData object is rebuilt from X, obs, var, uns, obsm, obsp and
    layers only. ``varm``, ``varp`` and ``raw`` are not passed to the rebuilt
    object.

    Parameters
    ----------
    adata : AnnData
        AnnData object containing expression data with individual gene expression.
    complexes : dict
        Dictionary mapping complex names (str) to lists of subunit gene names.
        Example: {'ITGA4&ITGB1': ['ITGA4', 'ITGB1']}
    agg_method : str, default='min'
        Aggregation across subunits per cell. Options:
        - 'min' : Minimum expression (rate-limiting subunit).
        - 'mean' : Arithmetic mean expression.
        - 'gmean' : Geometric mean expression.
    layer : str, optional
        Layer to read subunit expression from. If None, uses adata.X.
    copy : bool, default=False
        If True, return a modified copy. If False, modify adata in place
        and return None.

    Returns
    -------
    AnnData or None
        If copy=True, returns the modified AnnData. Otherwise modifies
        adata in place and returns None.

    Raises
    ------
    ValueError
        If agg_method is not one of 'min', 'mean', 'gmean'.
        If a complex has an empty subunit list.
        If any subunit gene is not found in adata.var_names.
        If a complex name already exists in adata.var_names.
    """
    if agg_method not in COMPLEX_AGG_METHODS:
        raise ValueError(
            f"Invalid agg_method '{agg_method}'. Must be one of {list(COMPLEX_AGG_METHODS.keys())}"
        )
    if not complexes:
        return adata.copy() if copy else None

    # Validate subunits: non-empty and all present in adata.
    for complex_name, subunits in complexes.items():
        if len(subunits) == 0:
            # An empty selection would otherwise fail inside numpy ('min') or
            # silently yield a NaN column ('mean' / 'gmean').
            raise ValueError(f"Complex '{complex_name}' has no subunits")
        missing = [s for s in subunits if s not in adata.var_names]
        if missing:
            raise ValueError(
                f"Subunit(s) {missing} for complex '{complex_name}' not found in adata.var_names"
            )

    # Validate complex names don't collide with existing variables.
    existing = [name for name in complexes if name in adata.var_names]
    if existing:
        raise ValueError(
            f"Complex name(s) {existing} already exist in adata.var_names. "
            f"Remove them first or use different names."
        )

    if copy:
        adata = adata.copy()

    # Aggregate once from the chosen source; the result is reused for both
    # adata.X and the source layer.
    source = adata.X if layer is None else adata.layers[layer]
    complex_names = list(complexes)
    n_new = len(complex_names)
    complex_matrix = _aggregate_complexes(source, adata.var_names, complexes, agg_method)

    # The sparse/dense decision is made on adata.X itself, not on the source
    # layer: the two may use different storage.
    new_X = _append_columns(adata.X, complex_matrix)

    # Build new var with metadata describing the added complexes.
    new_var_entries = pd.DataFrame(
        {
            'is_complex': [True] * n_new,
            'complex_subunits': ['|'.join(complexes[n]) for n in complex_names],
            'complex_agg_method': [agg_method] * n_new,
        },
        index=complex_names,
    )
    old_var = adata.var.copy()
    if 'is_complex' not in old_var.columns:
        old_var['is_complex'] = False
        old_var['complex_subunits'] = np.nan
        old_var['complex_agg_method'] = np.nan
    new_var = pd.concat([old_var, new_var_entries])

    # Every layer must gain the same number of columns as X.
    new_layers = {}
    if adata.layers is not None:
        for layer_name, layer_data in adata.layers.items():
            if layer_name == layer:
                # Source layer: same aggregated values that were appended to X.
                new_layers[layer_name] = _append_columns(layer_data, complex_matrix)
            else:
                new_layers[layer_name] = _append_zero_columns(layer_data, n_new)

    # DataFrame.copy() keeps categorical dtypes, categories and ordering, so
    # no per-column re-casting is needed.
    new_obs = adata.obs.copy()

    # Reconstruct AnnData with the widened variable axis.
    new_obsm = adata.obsm.copy() if adata.obsm is not None else None
    new_obsp = adata.obsp.copy() if adata.obsp is not None else None
    adata_new = sc.AnnData(
        X=new_X,
        obs=new_obs,
        var=new_var,
        uns=adata.uns.copy() if adata.uns else None,
        obsm=new_obsm,
        obsp=new_obsp,
        layers=new_layers if new_layers else None,
    )

    if copy:
        return adata_new
    # In-place mode: swap the caller's object state for the rebuilt one.
    adata.__dict__.update(adata_new.__dict__)
    return None
[docs]
def prepare_var_pairs(adata, var_pairs, complex_sep='&', agg_method='min', layer=None):
    """
    Normalize variable pairs to string-only pairs, registering any complexes.

    Each element of a pair is either a string (single gene/task), which is
    passed through unchanged, or a list/tuple of strings (subunits of a
    complex). A list/tuple is replaced by the name obtained from
    make_complex_name with complex_sep. Complexes whose name is not yet in
    adata.var_names are collected and added to adata with a single call to
    add_complexes_to_adata; names already present are skipped.

    Parameters
    ----------
    adata : AnnData
        AnnData object containing expression data.
    var_pairs : list of tuples
        List of (ligand, receptor) pairs where each element can be:
        - str: single gene or task name.
        - list/tuple of str: subunits of a complex.
        Example::

            var_pairs = [
                (['TASK1', 'TASK2'], ['GENE1', 'GENE2']),  # both complex
                ('TASK3', 'GENE4'),                        # both single
                ('TASK1', ['GENE5', 'GENE6']),             # mixed
            ]
    complex_sep : str, default='&'
        Separator used to join subunit names into the complex name.
    agg_method : str, default='min'
        Aggregation method for complex subunits. See add_complexes_to_adata.
    layer : str, optional
        Layer to read subunit expression from.

    Returns
    -------
    normalized_pairs : list of tuples
        String-only (ligand, receptor) pairs ready for scoring functions.
        Complex elements are replaced by their generated names.
    """
    pending = {}

    def _as_name(element):
        # Plain names pass straight through; only list/tuple entries are complexes.
        if not isinstance(element, (list, tuple)):
            return element
        label = make_complex_name(element, separator=complex_sep)
        already_known = label in adata.var_names or label in pending
        if not already_known:
            pending[label] = list(element)
        return label

    normalized_pairs = [
        (_as_name(first), _as_name(second)) for first, second in var_pairs
    ]

    # One batched call adds every newly seen complex to adata.
    if pending:
        add_complexes_to_adata(adata, pending, agg_method=agg_method, layer=layer)
    return normalized_pairs