# Source code for WORC.plotting.plot_ROC

#!/usr/bin/env python

# Copyright 2016-2021 Biomedical Imaging Group Rotterdam, Departments of
# Medical Informatics and Radiology, Erasmus MC, Rotterdam, The Netherlands
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

import tikzplotlib
import pandas as pd
import argparse
from WORC.plotting.compute_CI import compute_confidence as CI
import numpy as np
from sklearn.metrics import roc_auc_score, auc
from sklearn.metrics import precision_recall_curve
import csv
from WORC.plotting.plot_estimator_performance import plot_estimator_performance


def plot_single_ROC(y_truth, y_score, verbose=False, returnplot=False):
    '''
    Get the False Positive Ratio (FPR) and True Positive Ratio (TPR)
    for the ground truth and score of a single estimator. These ratios
    can be used to plot a Receiver Operator Characteristic (ROC) curve.

    Returns (fpr, tpr, thresholds), each ordered by descending
    threshold, or empty lists when only one class is present. When
    returnplot is True, the matplotlib figure is returned as well.
    '''
    # Order the samples by ascending score so that each unique score
    # can act as a candidate decision threshold.
    y_truth = np.int_(np.asarray(y_truth))
    y_score = np.asarray(y_score)
    order = y_score.argsort()
    y_truth_sorted = y_truth[order]
    y_score = y_score[order]

    # Class sizes: N negatives (label 0), P positives (label 1).
    counts = np.bincount(y_truth)
    N = float(counts[0])
    P = 0 if len(counts) == 1 else float(counts[1])

    fpr = list()
    tpr = list()
    thresholds = list()
    if N == 0:
        print('[WORC Warning] No negative class samples found, cannot determine ROC. Skipping iteration.')
        return fpr, tpr, thresholds
    if P == 0:
        print('[WORC Warning] No positive class samples found, cannot determine ROC. Skipping iteration.')
        return fpr, tpr, thresholds

    # Sweep the threshold from low to high. Samples strictly below the
    # current threshold are predicted negative, so below_neg counts true
    # negatives and below_pos counts false negatives at that threshold:
    # FPR = 1 - below_neg / N and TPR = 1 - below_pos / P.
    below_neg = 0
    below_pos = 0
    previous = -np.inf
    for truth, score in zip(y_truth_sorted, y_score):
        if score != previous:
            fpr.append(1 - below_neg / N)
            tpr.append(1 - below_pos / P)
            thresholds.append(score)
            previous = score
        if truth == 1:
            below_pos += 1
        else:
            below_neg += 1

    if verbose or returnplot:
        # Optionally draw the single-estimator ROC curve.
        roc_auc = auc(fpr, tpr)
        f = plt.figure()
        ax = plt.subplot(111)
        lw = 2
        ax.plot(fpr, tpr, color='darkorange', lw=lw,
                label='ROC curve (area = %0.2f)' % roc_auc)
        ax.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.0])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic example')
        plt.legend(loc="lower right")

    # Reverse so the curves run from highest to lowest threshold.
    if returnplot:
        return fpr[::-1], tpr[::-1], thresholds[::-1], f
    return fpr[::-1], tpr[::-1], thresholds[::-1]
def plot_single_PRC(y_truth, y_score, verbose=False, returnplot=False):
    '''
    Get the precision and recall (=true positive rate) for the ground
    truth and score of a single estimator. These ratios can be used to
    plot a Precision Recall Curve (PRC).

    Returns (tpr, precision, thresholds) reversed with respect to
    sklearn's ordering, or empty lists when only one class is present.
    When returnplot is True, the matplotlib figure is returned as well.
    '''
    # Order the samples by ascending score.
    y_truth = np.int_(np.asarray(y_truth))
    y_score = np.asarray(y_score)
    order = y_score.argsort()
    y_truth_sorted = y_truth[order]
    y_score = y_score[order]

    # Class sizes: N negatives (label 0), P positives (label 1).
    counts = np.bincount(y_truth)
    N = float(counts[0])
    P = 0 if len(counts) == 1 else float(counts[1])
    if N == 0:
        print('[WORC Warning] No negative class samples found, cannot determine PRC. Skipping iteration.')
        return list(), list(), list()
    if P == 0:
        print('[WORC Warning] No positive class samples found, cannot determine PRC. Skipping iteration.')
        return list(), list(), list()

    # Let sklearn compute precision/recall at every threshold.
    precision, tpr, thresholds = precision_recall_curve(y_truth_sorted,
                                                        y_score)

    # Work with plain Python lists from here on.
    precision = precision.tolist()
    tpr = tpr.tolist()
    thresholds = thresholds.tolist()

    if verbose or returnplot:
        # Optionally draw the single-estimator PR curve.
        prc_auc = auc(tpr, precision)
        f = plt.figure()
        ax = plt.subplot(111)
        lw = 2
        ax.plot(tpr, precision, color='darkorange', lw=lw,
                label='PR curve (area = %0.2f)' % prc_auc)
        ax.plot([0, 1], [1, 0], color='navy', lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.0])
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall curve')
        plt.legend(loc="lower right")

    if returnplot:
        return tpr[::-1], precision[::-1], thresholds[::-1], f
    return tpr[::-1], precision[::-1], thresholds[::-1]
def curve_thresholding(metric1t, metric2t, thresholds, nsamples=20):
    '''
    Construct metric1 and metric2 (either FPR and TPR, or TPR and
    Precision) ratios at different thresholds for the scores of an
    estimator.

    metric1t/metric2t/thresholds are lists with one list per curve;
    each curve's thresholds are assumed sorted descending. Returns two
    (nsamples, ncurves) arrays and the list of sampled thresholds.
    '''
    # Pool every threshold of every curve and pick nsamples thresholds
    # evenly spaced by rank in the pooled, ascending list.
    pooled = sorted(t for curve in thresholds for t in curve)
    sample_positions = np.linspace(0, len(pooled) - 1, nsamples)

    ncurves = len(metric1t)
    metric1 = np.zeros((nsamples, ncurves))
    metric2 = np.zeros((nsamples, ncurves))
    th = list()
    for row, pos in enumerate(sample_positions):
        threshold = pooled[int(pos)]
        th.append(threshold)
        for col in range(0, ncurves):
            # Advance to the first entry at or below the sampled
            # threshold (the curve is descending); stop at the last
            # entry if none qualifies.
            idx = 0
            last = len(thresholds[col]) - 1
            while float(thresholds[col][idx]) > float(threshold) and idx < last:
                idx += 1
            metric1[row, col] = metric1t[col][idx]
            metric2[row, col] = metric2t[col][idx]
    return metric1, metric2, th
def plot_ROC_CIc(y_truth, y_score, N_1, N_2, plot='default', alpha=0.95, verbose=False, DEBUG=False, tsamples=20):
    '''
    Plot a Receiver Operator Characteristic (ROC) curve
    with confidence intervals.

    y_truth and y_score hold one ground-truth/score sequence per
    cross-validation iteration; N_1 and N_2 are passed through to the
    confidence-interval computation (CI), alpha is the CI level.

    tsamples: number of sample points on which to determine the confidence
    intervals. The sample points are used on the thresholds for y_score.

    Returns the (second) matplotlib figure plus the (tsamples + 1, 2)
    arrays of FPR and TPR confidence intervals.
    '''
    # Compute ROC curve and ROC area for each class
    fprt = list()
    tprt = list()
    roc_auc = list()
    thresholds = list()
    for yt, ys in zip(y_truth, y_score):
        fpr_temp, tpr_temp, thresholds_temp = plot_single_ROC(yt, ys)
        # plot_single_ROC returns empty lists when an iteration contains
        # only one class; those iterations are skipped here.
        if fpr_temp:
            roc_auc.append(roc_auc_score(yt, ys))
            fprt.append(fpr_temp)
            tprt.append(tpr_temp)
            thresholds.append(thresholds_temp)

    # Sample FPR and TPR at numerous points
    fpr, tpr, th = curve_thresholding(fprt, tprt, thresholds, tsamples)

    # Compute the confidence intervals for the ROC
    CIs_tpr = list()
    CIs_fpr = list()
    for i in range(0, tsamples):
        if i == 0:
            # Point (1, 1) is always in there, but shows as (nan, nan)
            CIs_fpr.append([1, 1])
            CIs_tpr.append([1, 1])
        else:
            cit_fpr = CI(fpr[i, :], N_1, N_2, alpha)
            CIs_fpr.append([cit_fpr[0], cit_fpr[1]])
            cit_tpr = CI(tpr[i, :], N_1, N_2, alpha)
            CIs_tpr.append([cit_tpr[0], cit_tpr[1]])

    # The point (0, 0) is also always there but not computed
    CIs_fpr.append([0, 0])
    CIs_tpr.append([0, 0])

    # Calculate also means of CIs after converting to array; the means
    # form the central curve that is actually drawn.
    CIs_tpr = np.asarray(CIs_tpr)
    CIs_fpr = np.asarray(CIs_fpr)
    CIs_tpr_means = np.mean(CIs_tpr, axis=1).tolist()
    CIs_fpr_means = np.mean(CIs_fpr, axis=1).tolist()

    # compute AUC CI
    roc_auc = CI(roc_auc, N_1, N_2, alpha)

    # First figure ('orange'): shown interactively when verbose is True.
    f = plt.figure()
    lw = 2
    subplot = f.add_subplot(111)
    subplot.plot(CIs_fpr_means, CIs_tpr_means, color='orange', lw=lw,
                 label='ROC curve (AUC = (%0.2f, %0.2f))' % (roc_auc[0], roc_auc[1]))
    for i in range(0, len(CIs_fpr_means)):
        # Clip the CI bounds and the means to the [0, 1] plotting box
        # before drawing the horizontal/vertical error bars.
        if CIs_tpr[i, 1] <= 1:
            ymax = CIs_tpr[i, 1]
        else:
            ymax = 1
        if CIs_tpr[i, 0] <= 0:
            ymin = 0
        else:
            ymin = CIs_tpr[i, 0]
        if CIs_tpr_means[i] <= 1:
            ymean = CIs_tpr_means[i]
        else:
            ymean = 1
        if CIs_fpr[i, 1] <= 1:
            xmax = CIs_fpr[i, 1]
        else:
            xmax = 1
        if CIs_fpr[i, 0] <= 0:
            xmin = 0
        else:
            xmin = CIs_fpr[i, 0]
        if CIs_fpr_means[i] <= 1:
            xmean = CIs_fpr_means[i]
        else:
            xmean = 1
        if DEBUG:
            print(xmin, xmax, ymean)
            print(ymin, ymax, xmean)
        subplot.plot([xmin, xmax], [ymean, ymean], color='black', alpha=0.15)
        subplot.plot([xmean, xmean], [ymin, ymax], color='black', alpha=0.15)
    subplot.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate (1 - Specificity)')
    plt.ylabel('True Positive Rate (Sensitivity)')
    plt.title('Receiver operating characteristic')
    plt.legend(loc="lower right")
    if verbose:
        plt.show()

    # Second figure ('darkorange'): the figure that is returned to the
    # caller; same curve and error bars, without the DEBUG prints.
    f = plt.figure()
    lw = 2
    subplot = f.add_subplot(111)
    subplot.plot(CIs_fpr_means, CIs_tpr_means, color='darkorange', lw=lw,
                 label='ROC curve (AUC = (%0.2f, %0.2f))' % (roc_auc[0], roc_auc[1]))
    for i in range(0, len(CIs_fpr_means)):
        if CIs_tpr[i, 1] <= 1:
            ymax = CIs_tpr[i, 1]
        else:
            ymax = 1
        if CIs_tpr[i, 0] <= 0:
            ymin = 0
        else:
            ymin = CIs_tpr[i, 0]
        if CIs_tpr_means[i] <= 1:
            ymean = CIs_tpr_means[i]
        else:
            ymean = 1
        if CIs_fpr[i, 1] <= 1:
            xmax = CIs_fpr[i, 1]
        else:
            xmax = 1
        if CIs_fpr[i, 0] <= 0:
            xmin = 0
        else:
            xmin = CIs_fpr[i, 0]
        if CIs_fpr_means[i] <= 1:
            xmean = CIs_fpr_means[i]
        else:
            xmean = 1
        subplot.plot([xmin, xmax], [ymean, ymean], color='black', alpha=0.15)
        subplot.plot([xmean, xmean], [ymin, ymax], color='black', alpha=0.15)
    subplot.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate (1 - Specificity)')
    plt.ylabel('True Positive Rate (Sensitivity)')
    plt.title('Receiver operating characteristic')
    plt.legend(loc="lower right")

    return f, CIs_fpr, CIs_tpr
def plot_PRC_CIc(y_truth, y_score, N_1, N_2, plot='default', alpha=0.95,
                 verbose=False, DEBUG=False, tsamples=20):
    '''
    Plot a Precision-Recall curve with confidence intervals.

    y_truth and y_score hold one ground-truth/score sequence per
    cross-validation iteration; N_1 and N_2 are passed through to the
    confidence-interval computation (CI), alpha is the CI level.

    tsamples: number of sample points on which to determine the confidence
    intervals. The sample points are used on the thresholds for y_score.

    Returns the (second) matplotlib figure plus the (tsamples + 1, 2)
    arrays of recall (TPR) and precision confidence intervals.
    '''
    # Compute PR curve and PR area for each iteration; iterations for
    # which plot_single_PRC returns empty lists (only one class present)
    # are skipped.
    tprt = list()
    precisiont = list()
    prc_auc = list()
    thresholds = list()
    for yt, ys in zip(y_truth, y_score):
        tpr_temp, precision_temp, thresholds_temp = plot_single_PRC(yt, ys)
        if tpr_temp:
            prc_auc.append(auc(tpr_temp, precision_temp))
            tprt.append(tpr_temp)
            precisiont.append(precision_temp)
            thresholds.append(thresholds_temp)

    # Sample TPR and precision at numerous points
    tpr, precisionr, th = curve_thresholding(tprt, precisiont, thresholds,
                                             tsamples)

    # Compute the confidence intervals for the PRC
    CIs_precisionr = list()
    CIs_tpr = list()
    for i in range(0, tsamples):
        if i == 0:
            # Point (0, 0) is always in there, but shows as (nan, nan)
            CIs_tpr.append([1, 1])
            CIs_precisionr.append([0, 0])
        else:
            cit_tpr = CI(tpr[i, :], N_1, N_2, alpha)
            CIs_tpr.append([cit_tpr[0], cit_tpr[1]])
            cit_precisionr = CI(precisionr[i, :], N_1, N_2, alpha)
            CIs_precisionr.append([cit_precisionr[0], cit_precisionr[1]])

    # The point (0, 1) is also always there but not computed
    CIs_tpr.append([0, 0])
    CIs_precisionr.append([1, 1])

    # Calculate also means of CIs after converting to array; the means
    # form the central curve that is actually drawn.
    CIs_precisionr = np.asarray(CIs_precisionr)
    CIs_tpr = np.asarray(CIs_tpr)
    CIs_precisionr_means = np.mean(CIs_precisionr, axis=1).tolist()
    CIs_tpr_means = np.mean(CIs_tpr, axis=1).tolist()

    # compute AUC CI
    prc_auc = CI(prc_auc, N_1, N_2, alpha)

    def _draw_prc(color, label):
        # Draw one PR figure: the mean curve, per-point CI error bars
        # (clipped to the [0, 1] box) and the PR reference diagonal.
        f = plt.figure()
        lw = 2
        subplot = f.add_subplot(111)
        subplot.plot(CIs_tpr_means, CIs_precisionr_means, color=color,
                     lw=lw, label=label % (prc_auc[0], prc_auc[1]))
        for i in range(0, len(CIs_tpr_means)):
            ymax = min(CIs_precisionr[i, 1], 1)
            ymin = max(CIs_precisionr[i, 0], 0)
            # BUGFIX: the original second figure used CIs_precisionr[i]
            # (a two-element CI row, whose truth value is ambiguous and
            # raises ValueError) instead of the mean; use the mean, as
            # the first figure already did.
            ymean = min(CIs_precisionr_means[i], 1)
            xmax = min(CIs_tpr[i, 1], 1)
            xmin = max(CIs_tpr[i, 0], 0)
            xmean = min(CIs_tpr_means[i], 1)
            if DEBUG:
                print(xmin, xmax, ymean)
                print(ymin, ymax, xmean)
            subplot.plot([xmin, xmax], [ymean, ymean], color='black',
                         alpha=0.15)
            subplot.plot([xmean, xmean], [ymin, ymax], color='black',
                         alpha=0.15)
        # BUGFIX: the original second figure drew an ROC-style diagonal
        # ([0, 1], [0, 1]); a PR plot's reference line runs from (0, 1)
        # to (1, 0), consistent with the first figure and plot_single_PRC.
        subplot.plot([0, 1], [1, 0], color='navy', lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.0])
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall curve')
        plt.legend(loc="lower right")
        return f

    # First figure ('orange'): shown interactively when verbose is True.
    _draw_prc('orange', 'PR curve (AUC = (%0.2f, %0.2f))')
    if verbose:
        plt.show()

    # Second figure ('darkorange'): the figure returned to the caller.
    f = _draw_prc('darkorange', 'PRC curve (AUC = (%0.2f, %0.2f))')

    return f, CIs_tpr, CIs_precisionr
def main():
    '''Command-line entry point: parse arguments and call plot_ROC.'''
    parser = argparse.ArgumentParser(
        description='Plot the ROC Curve of an estimator')

    # (flag name, required, help text) for every option; every option is
    # parsed as one-or-more strings, exactly as before.
    options = [
        ('prediction', True, 'Prediction file (HDF)'),
        ('pinfo', True, 'Patient Info File (txt)'),
        ('ensemble_method', True, 'Method for creating ensemble (string)'),
        ('ensemble_size', False, 'Length of ensemble (int)'),
        ('label_type', True, 'Label name that is predicted (string)'),
        ('ROC_png', False, 'File to write ROC to (PNG)'),
        ('ROC_csv', False, 'File to write ROC to (csv)'),
        ('ROC_tex', False, 'File to write ROC to (tex)'),
        ('PRC_png', False, 'File to write PR to (PNG)'),
        ('PRC_csv', False, 'File to write PR to (csv)'),
        ('PRC_tex', False, 'File to write PR to (tex)'),
    ]
    for name, required, helptext in options:
        parser.add_argument('-' + name, '--' + name, metavar=name,
                            nargs='+', dest=name, type=str,
                            required=required, help=helptext)

    args = parser.parse_args()

    plot_ROC(prediction=args.prediction,
             pinfo=args.pinfo,
             ensemble_method=args.ensemble_method,
             ensemble_size=args.ensemble_size,
             label_type=args.label_type,
             ROC_png=args.ROC_png,
             ROC_tex=args.ROC_tex,
             ROC_csv=args.ROC_csv,
             PRC_png=args.PRC_png,
             PRC_tex=args.PRC_tex,
             PRC_csv=args.PRC_csv)
def plot_ROC(prediction, pinfo, ensemble_method='top_N', ensemble_size=1,
             label_type=None, ROC_png=None, ROC_tex=None, ROC_csv=None,
             PRC_png=None, PRC_tex=None, PRC_csv=None):
    '''
    Plot the ROC and PR curves (with confidence intervals unless the
    cross-validation type is LOO) for a WORC prediction file, and
    optionally save them as PNG, LaTeX/TikZ and/or CSV.

    prediction: path to the HDF prediction file; pinfo: patient info
    file. The remaining arguments may each be given as a list of
    strings (as produced by argparse nargs='+'), which is joined below.

    Returns the last created figure plus the plotted x/y data of the
    ROC curve.
    '''
    # Convert the inputs to the correct format: argparse hands every
    # option over as a list of strings.
    if type(prediction) is list:
        prediction = ''.join(prediction)
    if type(pinfo) is list:
        pinfo = ''.join(pinfo)
    if type(ensemble_method) is list:
        ensemble_method = ''.join(ensemble_method)
    if type(ensemble_size) is list:
        ensemble_size = int(ensemble_size[0])
    if type(ROC_png) is list:
        ROC_png = ''.join(ROC_png)
    if type(ROC_csv) is list:
        ROC_csv = ''.join(ROC_csv)
    if type(ROC_tex) is list:
        ROC_tex = ''.join(ROC_tex)
    if type(PRC_png) is list:
        PRC_png = ''.join(PRC_png)
    if type(PRC_csv) is list:
        PRC_csv = ''.join(PRC_csv)
    if type(PRC_tex) is list:
        PRC_tex = ''.join(PRC_tex)
    if type(label_type) is list:
        label_type = ''.join(label_type)

    # Read the inputs
    prediction = pd.read_hdf(prediction)
    if label_type is None:
        # Assume we want to have the first key
        label_type = prediction.keys()[0]
    elif len(label_type.split(',')) != 1:
        # Multiclass, just take the prediction label
        label_type = prediction.keys()[0]
    # Train/test set sizes of the first iteration, used for the CIs.
    N_1 = len(prediction[label_type].Y_train[0])
    N_2 = len(prediction[label_type].Y_test[0])

    # Determine the predicted score per patient
    print('Determining score per patient.')
    y_truths, y_scores, _, _ =\
        plot_estimator_performance(prediction, pinfo, [label_type],
                                   alpha=0.95,
                                   ensemble_method=ensemble_method,
                                   ensemble_size=ensemble_size,
                                   output='decision')

    # Check if we can compute confidence intervals: with
    # leave-one-out cross-validation there is a single iteration,
    # so only a plain curve can be drawn.
    config = prediction[label_type].config
    crossval_type = config['CrossValidation']['Type']

    # --------------------------------------------------------------
    # ROC Curve
    if crossval_type == 'LOO':
        print("LOO: Plotting the ROC without confidence intervals.")
        # Each LOO iteration holds a single sample: flatten.
        y_truths = [i[0] for i in y_truths]
        y_scores = [i[0] for i in y_scores]
        fpr, tpr, _, f = plot_single_ROC(y_truths, y_scores,
                                         returnplot=True)
    else:
        # Plot the ROC with confidence intervals
        print("Plotting the ROC with confidence intervals.")
        f, fpr, tpr = plot_ROC_CIc(y_truths, y_scores, N_1, N_2)

    # Save the outputs
    if ROC_png is not None:
        f.savefig(ROC_png)
        print(("ROC saved as {} !").format(ROC_png))
    if ROC_tex is not None:
        tikzplotlib.save(ROC_tex)
        print(("ROC saved as {} !").format(ROC_tex))
    if ROC_csv is not None:
        with open(ROC_csv, 'w') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['FPR', 'TPR'])
            for i in range(0, len(fpr)):
                data = [str(fpr[i]), str(tpr[i])]
                writer.writerow(data)
        print(("ROC saved as {} !").format(ROC_csv))

    # --------------------------------------------------------------
    # PR Curve
    if crossval_type == 'LOO':
        tpr, precisionr, _, f = plot_single_PRC(y_truths, y_scores,
                                                returnplot=True)
    else:
        f, tpr, precisionr = plot_PRC_CIc(y_truths, y_scores, N_1, N_2)

    if PRC_png is not None:
        f.savefig(PRC_png)
        print(("PRC saved as {} !").format(PRC_png))
    if PRC_tex is not None:
        tikzplotlib.save(PRC_tex)
        print(("PRC saved as {} !").format(PRC_tex))
    if PRC_csv is not None:
        with open(PRC_csv, 'w') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['Recall', 'Precision'])
            for i in range(0, len(tpr)):
                data = [str(tpr[i]), str(precisionr[i])]
                writer.writerow(data)
        print(("PRC saved as {} !").format(PRC_csv))

    return f, fpr, tpr
# Allow the module to be run as a script.
if __name__ == '__main__':
    main()