from __future__ import division, print_function
__all__ = ['feature_importance_map', 'confusion_matrices',
'freq_hist_misclassifications', 'compare_distributions',
'compare_misclf_pairwise_parallel_coord_plot',
'compare_misclf_pairwise', ]
import itertools
import warnings
from sys import version_info
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from matplotlib import cm
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.colors import ListedColormap
if version_info.major > 2:
from neuropredict import config as cfg
from neuropredict.utils import round_
else:
raise NotImplementedError('neuropredict requires Python 3+.')
def feature_importance_map(feat_imp,
method_labels,
base_output_path,
feature_names=None,
show_distr=False,
plot_title='feature importance',
show_all=False):
"""
Generates a map/barplot of feature importance.
feat_imp must be a list of length num_datasets,
each an ndarray of size [num_repetitions,num_features[idx]]
    where num_features[idx] refers to the dimensionality of the idx-th dataset.
    method_labels must be a list of strings of the same length as feat_imp.
    feature_names must be a list (of ndarrays of strings) of the same length
    as feat_imp, each element being another list of labels corresponding to
    num_features[idx].
Parameters
----------
feat_imp : list
List of numpy arrays, each of length num_features
method_labels : list
List of names for each method (or feature set).
base_output_path : str
feature_names : list
List of names for each feature.
show_distr : bool
plots the distribution (over different trials of cross-validation)
of feature importance for each feature.
plot_title : str
Title of the importance map figure.
    show_all : bool
        If True, this will attempt to show the importance values for all the
        features. Be advised that with more than 50 features, the figure may
        be illegible. The default is to show only the few most important
        features (ranked by their median importance) when there are more
        than 25 features.
Returns
-------
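
    A minimal usage sketch (hypothetical shapes and output path):

        imp_a = np.random.rand(10, 4)  # 10 CV repetitions x 4 features
        imp_b = np.random.rand(10, 6)
        feature_importance_map([imp_a, imp_b],
                               method_labels=['featset_A', 'featset_B'],
                               base_output_path='/tmp/feat_imp')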
"""
num_datasets = len(feat_imp)
if num_datasets > 1:
fig, ax = plt.subplots(num_datasets, 1,
sharex=True,
figsize=[9, 12])
ax = ax.flatten()
else:
fig, ax_h = plt.subplots(figsize=[9, 12])
ax = [ax_h] # to support indexing
for dd in range(num_datasets):
scaled_imp = feat_imp[dd]
# some models do not provide importance values
is_nan_imp_values = np.isnan(scaled_imp.flatten())
if np.all(is_nan_imp_values):
print('unusable feature importance values for {} : '
'all NaNs!\n Skipping it.'.format(method_labels[dd]))
continue
num_features = feat_imp[dd].shape[1]
if feature_names is None:
feat_labels = np.array(["f{}".format(ix) for ix in range(num_features)])
else:
feat_labels = feature_names[dd]
if len(feat_labels) < num_features:
raise ValueError('Insufficient number of feature labels.')
usable_imp, freq_sel, median_feat_imp, stdev_feat_imp, conf_interval \
= compute_median_std_feat_imp(scaled_imp)
if num_features > cfg.max_allowed_num_features_importance_map:
            print('Too many (n={}) features detected for {}.\n'
                  'Showing only the top {} to keep the map legible.\n'
                  'Use the exported results to make your own importance maps.'
                  ''.format(num_features, method_labels[dd],
                            cfg.max_allowed_num_features_importance_map))
            # argsort is ascending; reversing gives descending importance
            sort_indices = np.argsort(median_feat_imp)[::-1]
selected_idx_display = sort_indices[
:cfg.max_allowed_num_features_importance_map]
usable_imp_display = [usable_imp[ix] for ix in selected_idx_display]
selected_feat_imp = median_feat_imp[selected_idx_display]
selected_imp_stdev = stdev_feat_imp[selected_idx_display]
selected_conf_interval = conf_interval[selected_idx_display]
selected_feat_names = feat_labels[selected_idx_display]
effective_num_features = cfg.max_allowed_num_features_importance_map
else:
selected_idx_display = None
usable_imp_display = usable_imp
selected_feat_imp = median_feat_imp
selected_imp_stdev = stdev_feat_imp
selected_conf_interval = conf_interval
selected_feat_names = feat_labels
effective_num_features = num_features
feat_ticks = range(effective_num_features)
plt.sca(ax[dd])
        # checking whether all features were selected an equal number of times
        # (needed for the violin plot)
        # violin distribution or plain bar plot?
if show_distr:
line_coll = ax[dd].violinplot(usable_imp_display,
positions=feat_ticks,
widths=0.8, bw_method=0.2,
vert=False,
showmedians=True, showextrema=False)
cmap = cm.get_cmap(cfg.CMAP_FEAT_IMP, effective_num_features)
for cc, ln in enumerate(line_coll['bodies']):
ln.set_facecolor(cmap(cc))
# ln.set_label(feat_labels[cc])
else:
barwidth = max(0.05, min(0.9, 8.0 / effective_num_features))
rects = ax[dd].barh(feat_ticks, selected_feat_imp,
height=barwidth, xerr=selected_conf_interval)
ax[dd].tick_params(axis='both', which='major', labelsize=10)
ax[dd].grid(axis='x', which='major')
ax[dd].set_yticks(feat_ticks)
ax[dd].set_ylim(np.min(feat_ticks) - 1, np.max(feat_ticks) + 1)
ax[dd].set_yticklabels(selected_feat_names) # , rotation=45) # 'vertical'
ax[dd].set_title(method_labels[dd])
print()
if num_datasets < len(ax):
fig.delaxes(ax[-1])
plt.xlabel('feature importance', fontsize=14)
# plt.suptitle(plot_title, fontsize=16)
fig.tight_layout()
    base_output_path = base_output_path.replace(' ', '_')
pp1 = PdfPages(base_output_path + '.pdf')
pp1.savefig()
pp1.close()
plt.close()
return
def mean_confidence_interval(data, confidence=0.95):
"""Computes mean and CI
From: https://stackoverflow.com/questions/15033511/compute-a-confidence-interval-from-sample-data/
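
    For example (illustrative): for data = [1, 2, 3, 4, 5] at the default 95%
    confidence, this returns the mean 3.0 and the half-width
    h = SEM * t(0.975, df=4), i.e. the interval is (3.0 - h, 3.0 + h).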
"""
arr = 1.0 * np.array(data, dtype=float)
n = len(arr)
mu = np.mean(arr)
se = scipy.stats.sem(arr)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
return mu, h
def compute_median_std_feat_imp(imp,
ignore_value=cfg.importance_value_to_treated_as_not_selected,
never_tested_value=cfg.importance_value_never_tested,
never_tested_stdev=cfg.importance_value_never_tested_stdev):
"Calculates the median/SD of feature importance, ignoring NaNs and zeros"
num_features = imp.shape[1]
usable_values = list()
freq_selection = list()
conf_interval = list()
median_values = list()
stdev_values = list()
for feat in range(num_features):
index_nan_or_0 = np.logical_or(np.isnan(imp[:, feat]),
np.isclose(ignore_value, imp[:, feat],
rtol=1e-4, atol=1e-5))
index_usable = np.logical_not(index_nan_or_0)
this_feat_values = imp[index_usable, feat].flatten()
if len(this_feat_values) > 0:
usable_values.append(this_feat_values)
freq_selection.append(len(this_feat_values))
median_values.append(np.median(this_feat_values))
stdev_values.append(np.std(this_feat_values))
mean_, CI_sym = mean_confidence_interval(this_feat_values)
conf_interval.append(CI_sym)
else: # never ever selected
usable_values.append(None)
freq_selection.append(0)
median_values.append(never_tested_value)
stdev_values.append(never_tested_stdev)
conf_interval.append(never_tested_stdev)
return usable_values, np.array(freq_selection), \
np.array(median_values), np.array(stdev_values), np.array(conf_interval)
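# Illustrative sketch of compute_median_std_feat_imp (hypothetical values): for
#   imp = np.array([[0.5, np.nan], [0.7, np.nan]])
# feature 0 yields median 0.6 (stdev 0.1), while feature 1 was never selected
# and receives the configured placeholder values instead.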
def confusion_matrices(cfmat_array, class_labels,
method_names, base_output_path,
cmap=cfg.CMAP_CONFMATX):
"""
    Display routine for the confusion matrix.
    Entries in the confusion matrix can be turned into percentages with
    `display_perc=True`.
    Use a separate method to iterate over multiple datasets.
    confusion matrix dims: [num_repetitions, num_classes, num_classes, num_datasets]
Parameters
----------
cfmat_array
class_labels
method_names
base_output_path
cmap
Returns
-------
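
    A minimal usage sketch (hypothetical shapes and output path):

        cfmat = np.random.randint(1, 20, size=(10, 2, 2, 1))
        confusion_matrices(cfmat, ['ctrl', 'patient'], ['featset_A'],
                           '/tmp/confmat')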
"""
num_datasets = cfmat_array.shape[3]
num_classes = cfmat_array.shape[1]
if num_classes != cfmat_array.shape[2]:
raise ValueError("Invalid dimensions of confusion matrix.\nNeed "
"[num_repetitions, num_classes, num_classes, num_datasets]."
" Given shape : {}".format(cfmat_array.shape))
np.set_printoptions(2)
for dd in range(num_datasets):
output_path = base_output_path + '_' + str(method_names[dd])
        output_path = output_path.replace(' ', '_')
avg_cfmat = mean_over_cv_trials(cfmat_array[:, :, :, dd], num_classes)
fig, ax = plt.subplots(figsize=cfg.COMMON_FIG_SIZE)
vis_single_confusion_matrix(avg_cfmat, class_labels=class_labels,
title=method_names[dd], cmap=cmap, ax=ax)
fig.tight_layout()
pp1 = PdfPages(output_path + '.pdf')
pp1.savefig()
pp1.close()
plt.close()
return
def vis_single_confusion_matrix(conf_mat,
class_labels=('A', 'B'),
title='Confusion Matrix',
cmap='cividis',
ax=None,
y_label='True class',
x_label='Predicted class'):
"""Helper to plot a single CM"""
if not isinstance(cmap, ListedColormap):
cmap = cm.get_cmap(cmap)
annot_color_low_values = cmap.colors[0]
annot_color_high_values = cmap.colors[-1]
else:
annot_color_low_values = 'white'
annot_color_high_values = 'black'
if ax is None:
fig, ax = plt.subplots(figsize=cfg.COMMON_FIG_SIZE)
num_classes = conf_mat.shape[0]
    if num_classes != conf_mat.shape[1]:
        raise ValueError('Confusion matrix shape must be square!')
    if len(class_labels) < num_classes:
        raise ValueError('Need {} labels, given {}.'
                         ''.format(num_classes, len(class_labels)))
im = ax.imshow(conf_mat, interpolation='nearest', cmap=cmap)
plt.colorbar(im, fraction=0.046, pad=0.04)
tick_marks = np.arange(len(class_labels))
plt.xticks(tick_marks, class_labels, rotation=45)
plt.yticks(tick_marks, class_labels)
left, right, bottom, top = im.get_extent()
ax.set(xlim=(left, right), ylim=(bottom, top),
xlabel=x_label, ylabel=y_label, title=title)
max_val = conf_mat.max()
val_25p, val_75p = max_val/4, (3*max_val)/4
    for i, j in itertools.product(range(num_classes), range(num_classes)):
        # dark annotation on light cells, light annotation on dark cells
        if conf_mat[i, j] >= val_75p:
            val_annot_color = annot_color_low_values
        elif conf_mat[i, j] >= val_25p:
            val_annot_color = 'white'  # hardcoded!
        else:
            val_annot_color = annot_color_high_values
annot_str = "{:.{prec}f}%".format(conf_mat[i, j], prec=cfg.PRECISION_METRICS)
ax.text(j, i, annot_str, color=val_annot_color,
horizontalalignment="center") # , fontsize='large')
plt.tight_layout()
return ax
def mean_over_cv_trials(conf_mat_array, num_classes):
"""
    Common method to average over different CV trials,
    to ensure it is done over the right axis (the first one, axis=0)
    for all confusion matrix methods.
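
    For example (illustrative), an averaged confusion matrix [[8, 2], [1, 9]]
    with class sizes of 10 each becomes [[80.0, 20.0], [10.0, 90.0]] after
    the row-wise normalisation to percentages below.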
"""
if conf_mat_array.shape[1] != num_classes or \
conf_mat_array.shape[2] != num_classes or \
len(conf_mat_array.shape) != 3:
raise ValueError('Invalid shape of confusion matrix array! '
'It must be num_rep x {nc} x {nc}'.format(nc=num_classes))
    # NaNs are not expected here; if present, it is a bug elsewhere
avg_cfmat = np.mean(conf_mat_array, axis=0)
# percentage confusion relative to class size
    class_size_elementwise = np.transpose(np.tile(np.sum(avg_cfmat, axis=1),
                                                  (num_classes, 1)))
avg_cfmat_perc = np.divide(avg_cfmat, class_size_elementwise)
# making it human readable : 0-100%, with only 2 decimals
return np.around(100*avg_cfmat_perc, decimals=cfg.PRECISION_METRICS)
def compute_pairwise_misclf(cfmat_array):
"Merely computes the misclassification rates, for pairs of classes."
num_datasets = cfmat_array.shape[3]
num_classes = cfmat_array.shape[1]
if num_classes != cfmat_array.shape[2]:
raise ValueError("Invalid dimensions of confusion matrix.\n Shape must be: "
"[num_repetitions, num_classes, num_classes, num_datasets]")
num_misclf_axes = num_classes * (num_classes - 1)
avg_cfmat = np.full([num_datasets, num_classes, num_classes], np.nan)
misclf_rate = np.full([num_datasets, num_misclf_axes], np.nan)
for dd in range(num_datasets):
# mean confusion over CV trials
avg_cfmat[dd, :, :] = mean_over_cv_trials(cfmat_array[:, :, :, dd],
num_classes)
count = 0
for ii, jj in itertools.product(range(num_classes), range(num_classes)):
if ii != jj:
misclf_rate[dd, count] = avg_cfmat[dd, ii, jj]
count = count + 1
return avg_cfmat, misclf_rate
def label_misclf_axes(class_labels):
"""Method to generate labels for misclf axes!"""
num_classes = len(class_labels)
labels = list()
    # the iteration below must match that in compute_pairwise_misclf() exactly
for ii, jj in itertools.product(range(num_classes), range(num_classes)):
if ii != jj:
labels.append("{} --> {}".format(class_labels[ii], class_labels[jj]))
return labels
def compare_misclf_pairwise_parallel_coord_plot(cfmat_array,
class_labels, method_labels,
out_path):
"""
    Produces a parallel coordinate plot (unravelling the cobweb plot)
    comparing the misclassification rates of all feature sets
    for different pairwise classifications.
Parameters
----------
cfmat_array
class_labels
method_labels
out_path
Returns
-------
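
    A minimal usage sketch (hypothetical shapes and output path):

        cfmat = np.random.randint(1, 20, size=(10, 3, 3, 2))
        compare_misclf_pairwise_parallel_coord_plot(
            cfmat, ['A', 'B', 'C'], ['featset_1', 'featset_2'],
            '/tmp/pairwise_parallel')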
"""
num_datasets = cfmat_array.shape[3]
num_classes = cfmat_array.shape[1]
if num_classes != cfmat_array.shape[2]:
raise ValueError("Invalid dimensions of confusion matrix.\n Shape must be: "
"[num_repetitions, num_classes, num_classes, num_datasets]")
num_misclf_axes = num_classes * (num_classes - 1)
    out_path = out_path.replace(' ', '_')
avg_cfmat, misclf_rate = compute_pairwise_misclf(cfmat_array)
misclf_ax_labels = label_misclf_axes(class_labels)
fig = plt.figure(figsize=cfg.COMMON_FIG_SIZE)
ax = fig.add_subplot(1, 1, 1)
cmap = cm.get_cmap(cfg.CMAP_DATASETS, num_datasets)
    handles = list()
    misclf_ax_labels_loc = range(1, num_misclf_axes + 1)
for dd in range(num_datasets):
h = ax.plot(misclf_ax_labels_loc, misclf_rate[dd, :], color=cmap(dd))
handles.append(h[0])
ax.legend(handles, method_labels)
ax.set_xticks(misclf_ax_labels_loc)
ax.set_xticklabels(misclf_ax_labels)
ax.set_ylabel('misclassification rate (in %)')
ax.set_xlabel('misclassification type')
ax.set_xlim([0.75, num_misclf_axes + 0.25])
fig.tight_layout()
pp1 = PdfPages(out_path + '.pdf')
pp1.savefig()
pp1.close()
plt.close()
return
def compare_misclf_pairwise_barplot(cfmat_array, class_labels, method_labels,
out_path):
"""
    Produces a bar plot comparing the misclassification rates of all feature
    sets for different pairwise classifications.
Parameters
----------
cfmat_array
class_labels
method_labels
out_path
Returns
-------
"""
num_datasets = cfmat_array.shape[3]
num_classes = cfmat_array.shape[1]
if num_classes != cfmat_array.shape[2]:
raise ValueError("Invalid dimensions of confusion matrix.\n Shape must be: "
"[num_repetitions, num_classes, num_classes, num_datasets]")
num_misclf_axes = num_classes * (num_classes - 1)
    out_path = out_path.replace(' ', '_')
avg_cfmat, misclf_rate = compute_pairwise_misclf(cfmat_array)
misclf_ax_labels = label_misclf_axes(class_labels)
fig = plt.figure(figsize=cfg.COMMON_FIG_SIZE)
ax = fig.add_subplot(1, 1, 1)
cmap = cm.get_cmap(cfg.CMAP_DATASETS, num_datasets)
misclf_ax_labels_loc = list()
handles = list()
for mca in range(num_misclf_axes):
x_pos = np.array(range(num_datasets)) + mca * (num_datasets + 1)
h = ax.bar(x_pos, misclf_rate[:, mca], color=cmap(range(num_datasets)))
handles.append(h)
misclf_ax_labels_loc.append(np.mean(x_pos))
ax.legend(handles[0], method_labels)
ax.set_xticks(misclf_ax_labels_loc)
ax.set_xticklabels(misclf_ax_labels)
ax.set_ylabel('misclassification rate (in %)')
ax.set_xlabel('misclassification type')
fig.tight_layout()
pp1 = PdfPages(out_path + '.pdf')
pp1.savefig()
pp1.close()
plt.close()
return
def compare_misclf_pairwise(cfmat_array, class_labels, method_labels, out_path):
"""
    Produces a cobweb plot comparing the misclassification rates
    of all feature sets for different pairwise classifications.
Parameters
----------
cfmat_array
class_labels
method_labels
out_path
Returns
-------
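
    A minimal sketch (hypothetical shapes): with 3 classes, 2 feature sets and
    10 CV repetitions, cfmat_array has shape (10, 3, 3, 2) and the resulting
    cobweb has 3 * 2 = 6 misclassification axes.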
"""
num_datasets = cfmat_array.shape[3]
num_classes = cfmat_array.shape[1]
if num_classes != cfmat_array.shape[2]:
raise ValueError("Invalid dimensions of confusion matrix.\n Shape must be: "
"[num_repetitions, num_classes, num_classes, num_datasets]")
num_misclf_axes = num_classes * (num_classes - 1)
avg_cfmat, misclf_rate = compute_pairwise_misclf(cfmat_array)
misclf_ax_labels = label_misclf_axes(class_labels)
theta = 2 * np.pi * np.linspace(0, 1 - 1.0 / num_misclf_axes, num_misclf_axes)
fig = plt.figure(figsize=[9, 9])
cmap = cm.get_cmap(cfg.CMAP_DATASETS, num_datasets)
ax = fig.add_subplot(1, 1, 1, projection='polar')
# clock-wise
ax.set_theta_direction(-1)
# starting at top
ax.set_theta_offset(np.pi / 2.0)
for dd in range(num_datasets):
ax.plot(theta, misclf_rate[dd, :], color=cmap(dd),
linewidth=cfg.LINE_WIDTH)
# connecting the last axis to the first to close the loop
ax.plot([theta[-1], theta[0]],
[misclf_rate[dd, -1], misclf_rate[dd, 0]],
color=cmap(dd), linewidth=cfg.LINE_WIDTH)
lbl_handles = ax.set_thetagrids(theta * 360 / (2 * np.pi),
labels=misclf_ax_labels,
va='top',
ha='center',
fontsize=cfg.FONT_SIZE)
ax.grid(linewidth=cfg.LINE_WIDTH)
tick_perc = ['{:.2f}%'.format(tt) for tt in ax.get_yticks()]
ax.set_yticklabels(tick_perc, fontsize=cfg.FONT_SIZE)
# ax.set_yticks(np.arange(100 / num_classes, 100, 10))
plt.tick_params(axis='both', which='major')
# putting legends outside the plot below.
fig.subplots_adjust(bottom=0.2)
leg = ax.legend(method_labels, ncol=2, loc=9,
bbox_to_anchor=(0.5, -0.1))
    # setting colors manually, as the plot has been through arbitrary jumps
for ix, lh in enumerate(leg.legendHandles):
lh.set_color(cmap(ix))
leg.set_frame_on(False) # making leg background transparent
fig.tight_layout()
    out_path = out_path.replace(' ', '_')
# fig.savefig(out_path + '.png', transparent=True, dpi=300,
# bbox_extra_artists=(leg,), bbox_inches='tight')
fig.savefig(out_path + '.pdf',
bbox_extra_artists=(leg,), bbox_inches='tight')
plt.close()
return
def compute_perc_misclf_per_sample(num_times_misclfd, num_times_tested):
"Utility function to compute subject-wise percentage of misclassification."
num_samples = len(num_times_tested[0].keys())
num_datasets = len(num_times_tested)
perc_misclsfd = [None] * num_datasets
never_tested = list() # since train/test samples are the same
# across different feature sets
for dd in range(num_datasets):
perc_misclsfd[dd] = dict()
for sid in num_times_misclfd[dd].keys():
if num_times_tested[dd][sid] > 0:
perc_misclsfd[dd][sid] = np.float64(num_times_misclfd[dd][sid]) \
/ np.float64(num_times_tested[dd][sid])
else:
never_tested.append(sid)
never_tested = list(set(never_tested))
return perc_misclsfd, never_tested, num_samples, num_datasets
def freq_hist_misclassifications(num_times_misclfd, num_times_tested, method_labels,
outpath, separate_plots=False):
"""
Summary of most/least frequently misclassified subjects for further analysis
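
    A minimal usage sketch (hypothetical inputs): both count arguments are
    lists (one dict per feature set) keyed by sample id, e.g.

        num_times_misclfd = [{'s1': 3, 's2': 0}]
        num_times_tested = [{'s1': 10, 's2': 10}]
        freq_hist_misclassifications(num_times_misclfd, num_times_tested,
                                     ['featset_A'], '/tmp/misclf_freq')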
"""
num_bins = cfg.MISCLF_HIST_NUM_BINS
count_thresh = cfg.MISCLF_PERC_THRESH
    def annotate_plots(ax_h):
"Adds axes labels and helpful highlights"
cur_ylim = ax_h.get_ylim()
line_thresh, = ax_h.plot([count_thresh, count_thresh],
cur_ylim, 'k--',
linewidth=cfg.MISCLF_HIST_ANNOT_LINEWIDTH)
ax_h.set_ylim(cur_ylim)
ax_h.set_ylabel('number of subjects')
ax_h.set_xlabel('percentage of misclassification')
# computing the percentage of misclassification per subject
perc_misclsfd, never_tested, num_samples, num_datasets = \
compute_perc_misclf_per_sample(num_times_misclfd, num_times_tested)
if len(never_tested) > 0:
warnings.warn(' {} subjects were never selected for testing.'
''.format(len(never_tested)))
nvpath = outpath + '_never_tested_samples.txt'
with open(nvpath, 'w') as nvf:
nvf.writelines('\n'.join(never_tested))
# plot frequency histogram per dataset
if num_datasets > 1 and separate_plots:
fig, ax = plt.subplots(int(np.ceil(num_datasets / 2.0)), 2,
sharey=True,
figsize=[12, 9])
ax = ax.flatten()
else:
fig, ax_h = plt.subplots(figsize=[12, 9])
for dd in range(num_datasets):
# calculating perc of most frequently misclassified subjects in each dataset
most_freq_misclfd = [sid for sid in perc_misclsfd[dd].keys()
if perc_misclsfd[dd][sid] > count_thresh]
perc_most_freq_misclsfd = 100 * len(most_freq_misclfd) / len(
perc_misclsfd[dd])
this_method_label = "{} - {:.1f}%" \
"".format(method_labels[dd], perc_most_freq_misclsfd)
if dd == 0:
            this_method_label = this_method_label + ' : most frequently misclassified'
# for plotting
if num_datasets > 1 and separate_plots:
ax_h = ax[dd]
plt.sca(ax_h)
            ax_h.hist(list(perc_misclsfd[dd].values()), num_bins)
else:
ax_h.hist(list(perc_misclsfd[dd].values()), num_bins,
histtype='stepfilled', alpha=cfg.MISCLF_HIST_ALPHA,
label=this_method_label)
# for annotation
if num_datasets > 1 and separate_plots:
ax_h.set_title(this_method_label)
            annotate_plots(ax_h)
else:
if dd == num_datasets - 1:
ax_h.legend(loc=2)
                annotate_plots(ax_h)
txt_path = '_'.join([outpath, method_labels[dd], 'ids_most_frequent.txt'])
with open(txt_path, 'w') as mfm:
mfm.writelines('\n'.join(most_freq_misclfd))
if separate_plots and num_datasets < len(ax):
fig.delaxes(ax[-1])
fig.tight_layout()
pp1 = PdfPages(outpath + '_frequency_histogram.pdf')
pp1.savefig()
pp1.close()
plt.close()
return
def compare_distributions(metric, labels, output_path, y_label='metric',
horiz_line_loc=None, horiz_line_label=None,
upper_lim_y=1.01, lower_lim_y=-0.01,
ytick_step=None):
"""
    Distribution plots of various metrics, such as balanced accuracy.
metric is expected to be ndarray of size [num_repetitions, num_datasets]
upper_lim_y = None would make it automatic and adapt to given metric distribution
upper_lim_y = 1.01 and ytick_step = 0.05 are targeted for Accuracy/AUC metrics,
in classification applications
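
    A minimal usage sketch (hypothetical values and output path):

        metric = np.random.uniform(0.5, 0.9, size=(10, 2))
        compare_distributions(metric, ['method_A', 'method_B'],
                              '/tmp/compare_auc', y_label='AUC')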
"""
if not np.isfinite(metric).all():
raise ValueError('NaN or Inf found in the input metric array!')
num_repetitions = metric.shape[0]
num_datasets = metric.shape[1]
if len(labels) < num_datasets:
raise ValueError("Insufficient number of labels for {} features!"
"".format(num_datasets))
method_ticks = 1.0 + np.arange(num_datasets)
fig, ax = plt.subplots(figsize=cfg.COMMON_FIG_SIZE)
line_coll = ax.violinplot(metric, widths=cfg.violin_width,
bw_method=cfg.violin_bandwidth,
showmedians=True, showextrema=False,
positions=method_ticks)
cmap = cm.get_cmap(cfg.CMAP_DATASETS, num_datasets)
for cc, ln in enumerate(line_coll['bodies']):
ln.set_facecolor(cmap(cc))
ln.set_label(labels[cc])
ax.tick_params(axis='both', which='major', labelsize=15)
ax.grid(axis='y', which='major', linewidth=cfg.LINE_WIDTH, zorder=0)
# ---- setting y-axis limits
if upper_lim_y is not None:
upper_lim = round_(np.min([upper_lim_y, metric.max()]))
else:
upper_lim = round_(metric.max())
if lower_lim_y is not None:
lower_lim = round_(np.max([lower_lim_y, metric.min()]))
else:
lower_lim = round_(metric.min())
ax.set_ylim(lower_lim, upper_lim)
# ----
ax.set_xlim(np.min(method_ticks) - 1, np.max(method_ticks) + 1)
ax.set_xticks(method_ticks)
# ax.set_xticklabels(labels, rotation=45) # 'vertical'
if ytick_step is None:
ytick_loc = ax.get_yticks()
else:
ytick_loc = np.arange(lower_lim, upper_lim, ytick_step)
if horiz_line_loc is not None:
ytick_loc = np.append(ytick_loc, horiz_line_loc)
plt.text(0.05, horiz_line_loc, horiz_line_label)
ytick_loc = round_(ytick_loc)
ax.set_yticks(ytick_loc)
ax.set_yticklabels(ytick_loc)
plt.ylabel(y_label, fontsize=cfg.FONT_SIZE)
plt.tick_params(axis='both', which='major', labelsize=cfg.FONT_SIZE)
# numbered labels
numbered_labels = ['{} {}'.format(int(ix), lbl)
for ix, lbl in zip(method_ticks, labels)]
# putting legends outside the plot below.
fig.subplots_adjust(bottom=0.2)
leg = ax.legend(numbered_labels, ncol=2, loc=9, bbox_to_anchor=(0.5, -0.1))
    # setting colors manually, as the plot has been through arbitrary jumps
for ix, lh in enumerate(leg.legendHandles):
lh.set_color(cmap(ix))
leg.set_frame_on(False) # making leg background transparent
# fig.savefig(output_path + '.png', transparent=True, dpi=300,
# bbox_extra_artists=(leg,), bbox_inches='tight')
fig.savefig(output_path + '.pdf', bbox_extra_artists=(leg,), bbox_inches='tight')
plt.close()
return
def multi_scatter_plot(y_data, x_data, fig_out_path,
y_label='Residuals',
x_label='True targets',
show_zero_line=False,
trend_line=None,
show_hist=True):
"""Important diagnostic plot for predictive regression analysis"""
if show_hist:
fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True,
gridspec_kw=dict(width_ratios=(4.5, 1)),
figsize=cfg.COMMON_FIG_SIZE)
ax = axes[0]
hist_ax = axes[1]
else:
fig, ax = plt.subplots(figsize=cfg.COMMON_FIG_SIZE)
num_datasets = len(y_data)
    cmap = cm.get_cmap(cfg.CMAP_DATASETS, max(num_datasets + 1, 9))
colors = np.array(cmap.colors)
ds_labels = list(y_data.keys())
for index, ds_id in enumerate(ds_labels):
color = colors[index, np.newaxis, :]
h_path_coll = ax.scatter(x_data[ds_id], y_data[ds_id],
alpha=cfg.alpha_regression_targets,
label=ds_id, c=color)
if show_hist:
hist_ax.hist(y_data[ds_id], density=True, orientation="horizontal",
color=color, bins=cfg.num_bins_hist,
alpha=cfg.alpha_regression_targets, )
if show_hist:
hist_ax.yaxis.tick_right()
hist_ax.grid(False, axis="x")
hist_ax.set_xlabel("Density")
# switching focus to the right axis
plt.sca(ax)
leg = ax.legend(ds_labels)
extra_artists = [leg, ]
    if show_zero_line:  # helpful for residuals plot
        baseline = ax.axhline(y=0, color='black')
        extra_artists.append(baseline)
        if show_hist:
            baseline_hist = hist_ax.axhline(y=0, c='black')
            # extra_artists.append(baseline_hist)
if trend_line is not None:
tline = ax.axhline(y=trend_line, color='black', label='median of medians')
extra_artists.append(tline)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
fig.tight_layout()
fig.savefig(fig_out_path + '.pdf',
# bbox_extra_artists=extra_artists,
bbox_inches='tight')
plt.close()
if __name__ == '__main__':
pass