Source code for scenicplus.RSS

"""Calculate the specificity of eRegulons in clusters of cells.

Calculates the distance between the real distribution of eRegulon AUC values and
a fictional distribution where the eRegulon is only expressed/accessible in cells
of a certain cluster.

"""

import pandas as pd
import numpy as np
from scipy.spatial.distance import jensenshannon
from math import ceil, floor
import matplotlib.pyplot as plt
import matplotlib
from typing import List, Tuple, Union
from adjustText import adjust_text
from mudata import MuData
from scenicplus.scenicplus_mudata import ScenicPlusMuData

def rss(aucs, labels):
    """
    Return the specificity score of one AUC vector for one label vector.

    Both inputs are divided by their own sum before being compared, and the
    score is one minus the Jensen-Shannon distance between the two results.

    Parameters
    ---------
    aucs:
        Numeric vector of AUC values (must support ``.sum()`` and division).
    labels:
        Numeric indicator vector of the same length as ``aucs``.
    """
    observed = aucs / aucs.sum()
    expected = labels / labels.sum()
    # scipy's jensenshannon returns the JS *distance*, i.e. the square root
    # of the JS divergence, not the divergence itself.
    distance = jensenshannon(observed, expected)
    return 1.0 - distance

def regulon_specificity_scores_df(
        data_matrix: pd.DataFrame,
        variable_matrix: pd.Series):
    """
    Calculate the Regulon Specificity Scores (RSS). [doi: 10.1016/j.celrep.2018.10.045]

    Every column of ``data_matrix`` is scored against every unique value of
    ``variable_matrix`` with :func:`rss`. No index alignment is performed
    here, so both inputs are expected to list the cells in the same order.

    Parameters
    ---------
    data_matrix: `class::pd.DataFrame`
        A pandas dataframe containing regulon scores per cell
        (one column per regulon).
    variable_matrix: `class::pd.Series`
        A pandas series with an annotation per cell.

    Returns
    -------
    `class::pd.DataFrame`
        RSS values with one row per unique annotation value and one column
        per regulon.
    """
    cell_types = list(variable_matrix.unique())
    regulons = list(data_matrix.columns)

    # The 0/1 membership vector of a cell type does not depend on the regulon,
    # so build each one once instead of once per (regulon, cell type) pair.
    membership_per_type = [
        (variable_matrix == cell_type).astype(int)
        for cell_type in cell_types
    ]

    rss_values = np.empty(shape=(len(cell_types), len(regulons)), dtype=float)
    for col_idx, regulon_name in enumerate(regulons):
        regulon_scores = data_matrix[regulon_name]
        for row_idx, membership in enumerate(membership_per_type):
            rss_values[row_idx, col_idx] = rss(regulon_scores, membership)

    return pd.DataFrame(data=rss_values, index=cell_types, columns=regulons)
def regulon_specificity_scores(
        scplus_mudata: Union[MuData, ScenicPlusMuData],
        variable: str,
        modalities: list,
        selected_regulons: Union[List[str], None] = None):
    """
    Calculate the Regulon Specificity Scores (RSS). [doi: 10.1016/j.celrep.2018.10.045]

    Parameters
    ---------
    scplus_mudata: `class::MuData` or `class::ScenicPlusMuData`
        A MuData object with eRegulons AUC computed.
    variable: str
        Variable (column of ``scplus_mudata.obs``) to calculate the RSS values for.
    modalities: List,
        A list of modalities (keys of ``scplus_mudata.mod``) to calculate RSS values for.
    selected_regulons: List, optional
        Regulons to calculate RSS values for. Names that are not present in a
        modality are skipped for that modality. Default: None (all regulons
        of each modality).

    Returns
    -------
    `class::pd.DataFrame`
        The per-modality RSS data frames concatenated column-wise.

    Raises
    ------
    ValueError
        If ``modalities`` is empty.
    KeyError
        If a modality is not in ``scplus_mudata.mod`` or ``variable`` is not a
        column of ``scplus_mudata.obs``.
    """
    modalities = list(modalities)
    # Fail early with explicit messages instead of deep inside pandas/mudata.
    if not modalities:
        raise ValueError("`modalities` must contain at least one modality.")
    unknown_modalities = [
        modality for modality in modalities
        if modality not in scplus_mudata.mod
    ]
    if unknown_modalities:
        raise KeyError(
            f"Modalities not found in scplus_mudata.mod: {unknown_modalities}")
    if variable not in scplus_mudata.obs.columns:
        raise KeyError(
            f"Variable '{variable}' not found in scplus_mudata.obs.")

    rss_values_per_modality = []
    for modality in modalities:
        modality_data = scplus_mudata.mod[modality]
        if selected_regulons is not None:
            # Keep the caller's ordering; silently drop regulons that this
            # modality does not contain.
            modality_regulons = [
                regulon for regulon in selected_regulons
                if regulon in modality_data.var_names
            ]
        else:
            modality_regulons = list(modality_data.var_names)
        data_matrix = modality_data[:, modality_regulons].to_df()
        # Look the annotation up by cell name so it follows the row order
        # of data_matrix.
        variable_matrix = scplus_mudata.obs.loc[data_matrix.index, variable]
        rss_values_per_modality.append(
            regulon_specificity_scores_df(
                data_matrix=data_matrix,
                variable_matrix=variable_matrix))
    return pd.concat(rss_values_per_modality, axis=1)
def _finalize_rss_figure(fig):
    """Add the shared axis captions to ``fig``, tighten the layout and set the plot rcParams."""
    fig.text(0.5, 0.0, 'eRegulon rank',
             ha='center', va='center', size='x-large')
    fig.text(0.00, 0.5, 'eRegulon specificity score (eRSS)',
             ha='center', va='center', rotation='vertical', size='x-large')
    plt.tight_layout()
    # NOTE: this mutates matplotlib's global rcParams, so the settings
    # persist after plot_rss returns.
    plt.rcParams.update({
        'figure.autolayout': True,
        'figure.titlesize': 'large',
        'axes.labelsize': 'medium',
        'axes.titlesize': 'large',
        'xtick.labelsize': 'medium',
        'ytick.labelsize': 'medium'
    })


def plot_rss(data_matrix: pd.DataFrame,
             top_n: int = 5,
             selected_groups: List[str] = None,
             num_columns: int = 1,
             figsize: Tuple[float, float] = (6.4, 4.8),
             fontsize: int = 12,
             save: str = None):
    """
    Plot RSS values per group

    Parameters
    ---------
    data_matrix: `class::pd.DataFrame`
        A pandas dataframe with RSS scores per variable
        (one row per group, one column per eRegulon).
    top_n: int, optional
        Number of top eRegulons to highlight.
    selected_groups: List, optional
        Groups to plot. Default: None (all, in sorted order)
    num_columns: int, optional
        Number of columns for multiplotting
    figsize: tuple, optional
        Size of the figure. If num_columns is 1, this is the size for each figure;
        if num_columns is above 1, this is the overall size of the figure
        (if keeping default, it will be the size of each subplot in the figure).
        Default: (6.4, 4.8)
    fontsize: int, optional
        Size of the eRegulons names in plot.
    save: str, optional
        Path to save plot. If num_columns is 1 a multi-page PDF is written
        (one page per group); if num_columns is above 1 the single figure is
        saved with ``Figure.savefig``. Default: None.
    """
    if selected_groups is None:
        cats = sorted(data_matrix.index.tolist())
    else:
        cats = selected_groups

    multi_panel = num_columns > 1
    single_panel = num_columns == 1

    if multi_panel:
        num_rows = int(np.ceil(len(cats) / num_columns))
        figsize = (figsize[0] * num_columns, figsize[1] * num_rows)
        fig = plt.figure(figsize=figsize)

    pdf = None
    if save is not None and single_panel:
        # `import matplotlib` alone does not load the backend_pdf submodule,
        # so import PdfPages explicitly rather than via attribute access.
        from matplotlib.backends.backend_pdf import PdfPages
        pdf = PdfPages(save)

    try:
        for panel_idx, c in enumerate(cats, start=1):
            x = data_matrix.loc[c]
            if multi_panel:
                ax = fig.add_subplot(num_rows, num_columns, panel_idx)
            else:
                fig = plt.figure(figsize=figsize)
                ax = plt.axes()
            _plot_rss_internal(data_matrix, c, top_n=top_n, max_n=None, ax=ax)
            # Pad the y-range by 5% of the value spread on both sides.
            margin = (x.max() - x.min()) * 0.05
            ax.set_ylim(x.min() - margin, x.max() + margin)
            for t in ax.texts:
                t.set_fontsize(fontsize)
            ax.set_ylabel('')
            ax.set_xlabel('')
            adjust_text(ax.texts)
            if single_panel:
                _finalize_rss_figure(fig)
                if pdf is not None:
                    pdf.savefig(fig, bbox_inches='tight')
                plt.show()
    finally:
        # Always finalize the PDF, even when plotting a group failed.
        if pdf is not None:
            pdf.close()

    if multi_panel:
        _finalize_rss_figure(fig)
        if save is not None:
            fig.savefig(save, bbox_inches='tight')
        plt.show()
def _plot_rss_internal(rss, cell_type, top_n=5, max_n=None, ax=None):
    """
    Helper function to plot RSS

    Plots the RSS values of ``cell_type`` in decreasing order against their
    rank, then re-draws the ``top_n`` highest values in red and labels them
    with their regulon name.

    Parameters
    ---------
    rss: `class::pd.DataFrame`
        RSS values; ``cell_type`` is looked up as a row label.
    cell_type:
        Row of ``rss`` to plot; also used as the axes title.
    top_n: int, optional
        Number of highest-scoring regulons to highlight and label.
    max_n: int, optional
        Maximum number of regulons to plot. Default: None (all columns).
    ax: optional
        Axes to draw on. Default: None (a new 4x4 figure is created).
    """
    if ax is None:
        ax = plt.subplots(1, 1, figsize=(4, 4))[1]
    if max_n is None:
        max_n = rss.shape[1]

    ranked = rss.T[cell_type].sort_values(ascending=False)[0:max_n]
    ranks = np.arange(len(ranked))
    ax.plot(ranks, ranked, '.')

    # Snap the y-limits outwards to two decimals.
    lower = floor(ranked.min() * 100.0) / 100.0
    upper = ceil(ranked.max() * 100.0) / 100.0
    ax.set_ylim([lower, upper])
    ax.set_title(cell_type)
    ax.set_xticklabels([])

    label_style = {
        'color': 'red',
        'weight': 'normal'
    }
    # Horizontal gap between a highlighted point and its label.
    label_shift = max_n / 25
    top_hits = ranked[0:top_n]
    for rank, (regulon_name, score) in enumerate(
            zip(top_hits.index, top_hits.values)):
        ax.plot([rank, rank], [score, score], 'r.')
        ax.text(
            rank + label_shift,
            score,
            regulon_name,
            fontdict=label_style,
            horizontalalignment='left',
            verticalalignment='center',
        )