Source code for scenicplus.RSS

"""Calculate the specificity of eRegulons in clusters of cells.

Calculates the distance between the real distribution of eRegulon AUC values and a fictional distribution where the eRegulon is only expressed/accessible in cells of a certain cluster.

"""

import pandas as pd
import numpy as np
from scipy.spatial.distance import jensenshannon
from math import ceil, floor
import matplotlib.pyplot as plt
import matplotlib
from typing import List, Tuple
from typing import Optional
from adjustText import adjust_text
import sklearn

from .scenicplus_class import SCENICPLUS


def regulon_specificity_scores(scplus_obj: 'SCENICPLUS',
                               variable: str,
                               auc_key: Optional[str] = 'eRegulon_AUC',
                               signature_keys: Optional[List[str]] = ['Gene_based', 'Region_based'],
                               selected_regulons: Optional[List[str]] = None,
                               scale: Optional[bool] = False,
                               out_key_suffix: Optional[str] = ''):
    """
    Calculate the Regulon Specificity Scores (RSS). [doi: 10.1016/j.celrep.2018.10.045]

    For every (group, eRegulon) pair the score is ``1 - JSD`` where ``JSD`` is the
    Jensen-Shannon distance between the eRegulon AUC values normalised to sum to 1
    and an indicator vector (1 for cells in the group, 0 otherwise) normalised to
    sum to 1.

    Parameters
    ---------
    scplus_obj: `class::SCENICPLUS`
        A SCENICPLUS object with eRegulons AUC computed.
    variable: str
        Column of `scplus_obj.metadata_cell` to calculate the RSS values for.
    auc_key: str, optional
        Key in `scplus_obj.uns` to extract AUC values from. Default: 'eRegulon_AUC'
    signature_keys: List, optional
        Keys in `scplus_obj.uns[auc_key]` to extract AUC values from.
        Default: ['Gene_based', 'Region_based']
    selected_regulons: List, optional
        Names of the eRegulons to score. Names that are not columns of the AUC
        matrix are silently ignored. Default: None (all)
    scale: bool, optional
        Whether to scale the enrichment prior to the clustering. Default: False
        NOTE(review): scaled values can be negative, in which case the normalised
        vector is not a probability distribution - confirm this is intended.
    out_key_suffix: str, optional
        Suffix to add to the variable name to store the values (at scplus_obj.uns['RSS'])

    Returns
    -------
    None. The scores are stored as a DataFrame (rows: unique values of `variable`,
    columns: eRegulons) in `scplus_obj.uns['RSS'][variable + out_key_suffix]`.
    """
    if scale:
        # Import the submodule explicitly: a bare `import sklearn` does not
        # guarantee that `sklearn.preprocessing` is loaded.
        from sklearn.preprocessing import StandardScaler
        scaled_frames = []
        for key in signature_keys:
            auc_t = scplus_obj.uns[auc_key][key].T
            scaled_frames.append(
                pd.DataFrame(StandardScaler().fit_transform(auc_t),
                             index=auc_t.index.to_list(),
                             columns=auc_t.columns))
        data_mat = pd.concat(scaled_frames).T
    else:
        data_mat = pd.concat([scplus_obj.uns[auc_key][key]
                              for key in signature_keys], axis=1)

    if selected_regulons is not None:
        subset = [x for x in selected_regulons if x in data_mat.columns]
        data_mat = data_mat[subset]

    cell_data_series = scplus_obj.metadata_cell.loc[data_mat.index, variable]
    cell_types = list(cell_data_series.unique())
    regulons = list(data_mat.columns)

    # `np.float` was removed in NumPy 1.24; the builtin `float` is the same dtype.
    rss_values = np.empty(shape=(len(cell_types), len(regulons)), dtype=float)
    auc_values = data_mat.to_numpy(dtype=float)

    # Degenerate columns (e.g. all-zero AUC) yield NaN scores; silence the
    # accompanying divide warnings instead of flooding the output.
    with np.errstate(divide='ignore', invalid='ignore'):
        for ridx, cell_type in enumerate(cell_types):
            # The indicator distribution only depends on the group, so build it
            # once per group rather than once per (group, regulon) pair.
            indicator = (cell_data_series == cell_type).to_numpy(dtype=float)
            indicator = indicator / indicator.sum()
            for cidx in range(len(regulons)):
                aucs = auc_values[:, cidx]
                # jensenshannon returns the distance, i.e. the square root of
                # the Jensen-Shannon divergence.
                rss_values[ridx, cidx] = 1.0 - jensenshannon(
                    aucs / aucs.sum(), indicator)

    rss_values = pd.DataFrame(
        data=rss_values, index=cell_types, columns=regulons)

    if 'RSS' not in scplus_obj.uns.keys():
        scplus_obj.uns['RSS'] = {}
    scplus_obj.uns['RSS'][variable + out_key_suffix] = rss_values
def plot_rss(scplus_obj: 'SCENICPLUS',
             rss_key: str,
             top_n: Optional[int] = 5,
             selected_groups: Optional[List[str]] = None,
             num_columns: Optional[int] = 1,
             figsize: Optional[Tuple[float, float]] = (6.4, 4.8),
             fontsize: Optional[int] = 12,
             save: Optional[str] = None):
    """
    Plot RSS values per group

    Parameters
    ---------
    scplus_obj: `class::SCENICPLUS`
        A SCENICPLUS object with RSS values stored at scplus_obj.uns['RSS'].
    rss_key: str
        Key to extract RSS values from (scplus_obj.uns['RSS'][rss_key]).
    top_n: int, optional
        Number of top eRegulons to highlight.
    selected_groups: List, optional
        Groups to plot. Default: None (all, sorted)
    num_columns: int, optional
        Number of columns for multiplotting
    figsize: tuple, optional
        Size of the figure. If num_columns is 1, this is the size for each figure;
        if num_columns is above 1, this is the overall size of the figure (if keeping
        default, it will be the size of each subplot in the figure). Default: (6.4, 4.8)
    fontsize: int, optional
        Size of the eRegulons names in plot.
    save: str, optional
        Path to save plot. With num_columns == 1 every group becomes one page of a
        multi-page PDF written to this path; otherwise the single figure is saved
        with `Figure.savefig`. Default: None.
    """
    # Import the submodule explicitly: `import matplotlib` alone does not load
    # `matplotlib.backends.backend_pdf`.
    from matplotlib.backends.backend_pdf import PdfPages

    data_mat = scplus_obj.uns['RSS'][rss_key]
    if selected_groups is None:
        cats = sorted(data_mat.index.tolist())
    else:
        cats = selected_groups

    rc_updates = {
        'figure.autolayout': True,
        'figure.titlesize': 'large',
        'axes.labelsize': 'medium',
        'axes.titlesize': 'large',
        'xtick.labelsize': 'medium',
        'ytick.labelsize': 'medium'
    }

    def _draw_group(ax, group):
        # Draw one group's ranked RSS values and tidy up the labels.
        scores = data_mat.loc[group]
        _plot_rss_internal(data_mat, group, top_n=top_n, max_n=None, ax=ax)
        ax.set_ylim(scores.min() - (scores.max() - scores.min()) * 0.05,
                    scores.max() + (scores.max() - scores.min()) * 0.05)
        for label in ax.texts:
            label.set_fontsize(fontsize)
        ax.set_ylabel('')
        ax.set_xlabel('')
        adjust_text(ax.texts, autoalign='xy', ha='right', va='bottom',
                    arrowprops=dict(arrowstyle='-', color='lightgrey'),
                    precision=0.001)

    def _finalize(fig):
        # Add the shared axis captions and apply layout / rc settings.
        fig.text(0.5, 0.0, 'eRegulon rank', ha='center',
                 va='center', size='x-large')
        fig.text(0.00, 0.5, 'eRegulon specificity score (eRSS)',
                 ha='center', va='center', rotation='vertical', size='x-large')
        plt.tight_layout()
        plt.rcParams.update(rc_updates)

    if num_columns > 1:
        num_rows = int(np.ceil(len(cats) / num_columns))
        if figsize == (6.4, 4.8):
            # Default size is interpreted per subplot when multiplotting.
            figsize = (6.4 * num_columns, 4.8 * num_rows)
        fig = plt.figure(figsize=figsize)
        for position, group in enumerate(cats, start=1):
            _draw_group(fig.add_subplot(num_rows, num_columns, position), group)
        _finalize(fig)
        if save is not None:
            fig.savefig(save, bbox_inches='tight')
        plt.show()
        return

    # One figure per group. No figure is created up front, so no stray empty
    # figure is left behind.
    pdf = PdfPages(save) if save is not None else None
    try:
        for group in cats:
            fig = plt.figure(figsize=figsize)
            _draw_group(fig.add_subplot(), group)
            _finalize(fig)
            if pdf is not None:
                pdf.savefig(fig, bbox_inches='tight')
            plt.show()
    finally:
        # Always close the PDF so a failure mid-way does not leak the file
        # handle or leave an unfinalised document.
        if pdf is not None:
            pdf.close()
def _plot_rss_internal(rss, cell_type, top_n=5, max_n=None, ax=None):
    """
    Helper function to plot RSS

    Draws the RSS values of `cell_type` (a row label of `rss`) in decreasing
    order on `ax` and annotates the `top_n` highest ones in red with their
    column names. At most `max_n` values are drawn (all columns when None).
    A new 4x4 inch figure is created when `ax` is None.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(4, 4))
    if max_n is None:
        max_n = rss.shape[1]

    ranked = rss.T[cell_type].sort_values(ascending=False)
    data = ranked[0:max_n]

    ax.plot(np.arange(len(data)), data, '.')
    # Round the y-limits outwards to two decimals.
    lower = floor(data.min() * 100.0) / 100.0
    upper = ceil(data.max() * 100.0) / 100.0
    ax.set_ylim([lower, upper])
    ax.set_title(cell_type)
    ax.set_xticklabels([])

    highlight_font = {'color': 'red', 'weight': 'normal'}
    # Labels are shifted right by a fixed fraction of the plotted width.
    label_shift = max_n / 25
    top_hits = data[0:top_n]
    for rank, (regulon_name, score) in enumerate(
            zip(top_hits.index, top_hits.values)):
        ax.plot([rank, rank], [score, score], 'r.')
        ax.text(
            rank + label_shift,
            score,
            regulon_name,
            fontdict=highlight_font,
            horizontalalignment='left',
            verticalalignment='center',
        )