# Source code for WORC.plotting.plot_errors

#!/usr/bin/env python

# Copyright 2016-2020 Biomedical Imaging Group Rotterdam, Departments of
# Medical Informatics and Radiology, Erasmus MC, Rotterdam, The Netherlands
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import matplotlib
import matplotlib.pyplot as plt

import os
import numpy as np
from WORC.IOparser.file_io import load_features
import pandas as pd
import WORC.addexceptions as ae
import tikzplotlib

def plot_errors(featurefiles, patientinfo, label_type, featurenames,
                posteriors_csv=None, agesex=True, output_png=None,
                output_tex=None):
    """Scatterplot of all objects with marking of errors.

    Every patient listed in ``posteriors_csv`` is drawn as a point in the 2-D
    space spanned by the two features named in ``featurenames`` (x-axis and
    y-axis respectively). Points are coloured by their true label: ``1.0``
    is light blue, anything else is blue. Misclassified points get a larger
    red marker drawn underneath, so they appear with a red border. Points
    that would coincide with an already plotted point are shifted
    horizontally by small alternating offsets so they stay visible.

    Parameters
    ----------
    featurefiles, patientinfo, label_type
        Passed unchanged to ``load_features``.
    featurenames : sequence of str
        Exactly two feature names: the x-axis and y-axis features.
    posteriors_csv : str
        Path to a CSV with at least the columns ``PatientID``, ``TrueLabel``
        and ``Probability``. Defaults to ``None`` only for backwards
        compatibility; it is required.
    agesex : bool
        If True, label the axes as 'Sex' (x, ticks 'Female'/'Male' at 0/1)
        and 'Age' (y) instead of with the feature names.
    output_png : str, optional
        If given, the figure is saved to this path.
    output_tex : str, optional
        If given, the figure is exported to this path with ``tikzplotlib``.

    Raises
    ------
    WORCValueError
        If ``featurenames`` does not contain exactly two entries, or if
        ``posteriors_csv`` is not given.
    WORCKeyError
        If a requested feature is not among the feature labels, or a patient
        in ``posteriors_csv`` has no features.
    """
    # check some input
    if len(featurenames) != 2:
        raise ae.WORCValueError(f'Featurenames should be list with two strings, got {featurenames}.')

    if posteriors_csv is None:
        # Without posteriors there are no patients, labels or predictions to
        # plot, so fail before the (potentially expensive) feature loading.
        raise ae.WORCValueError('posteriors_csv is required to determine which objects were misclassified.')

    # Read the features and classification data
    print("Reading features and label data.")
    label_data, image_features =\
        load_features(featurefiles, patientinfo, label_type)

    # Read in the scores and labels
    data = pd.read_csv(posteriors_csv)
    PIDs = data['PatientID'].values
    labels = data['TrueLabel'].values.tolist()
    probabilities = data['Probability'].values

    # A prediction is correct when a probability >= 0.5 goes with label 1.0,
    # or a probability < 0.5 goes with label 0.0; everything else (including
    # NaN probabilities or other label values) counts as incorrect.
    correct = [bool((prob >= 0.5 and label == 1.0) or
                    (prob < 0.5 and label == 0.0))
               for label, prob in zip(labels, probabilities)]

    # Select indices of features we need
    feature_labels = image_features[0][1]
    for name in featurenames:
        if name not in feature_labels:
            raise ae.WORCKeyError(f'Feature {name} not in feature names.')
    feature_1_index = feature_labels.index(featurenames[0])
    feature_2_index = feature_labels.index(featurenames[1])

    # Map each patient ID to its first position in the feature data once,
    # rather than searching the ID list again for every patient.
    patient_index = dict()
    for index, pid in enumerate(label_data['patient_IDs'].tolist()):
        patient_index.setdefault(pid, index)

    # Match features to every patient in the posteriors file
    feature_1 = list()
    feature_2 = list()
    for pid in PIDs:
        if pid not in patient_index:
            raise ae.WORCKeyError(f'Features for {pid} not provided.')
        patient_features = image_features[patient_index[pid]][0]
        feature_1.append(patient_features[feature_1_index])
        feature_2.append(patient_features[feature_2_index])

    # Resort based on PID
    order = np.argsort(PIDs)
    feature_1 = [feature_1[i] for i in order]
    feature_2 = [feature_2[i] for i in order]
    correct = [correct[i] for i in order]
    labels = [labels[i] for i in order]

    # Actual plotting
    f = plt.figure(figsize=(20, 15))
    ax = f.add_subplot(1, 1, 1)
    plotted = set()
    for x, y, label, is_correct in zip(feature_1, feature_2, labels, correct):
        # If this position is already taken, shift the point horizontally by
        # +0.01, -0.02, +0.03, ... until a free position is found.
        coordinate = (x, y)
        step = 1
        sign = 1
        while coordinate in plotted:
            coordinate = (x + sign * 0.01 * step, y)
            sign = -sign
            step += 1
        plotted.add(coordinate)

        # Red border if classification is incorrect: a larger red dot is
        # drawn first and a smaller coloured dot on top of it.
        if not is_correct:
            ax.scatter(coordinate[0], coordinate[1], s=80, marker='o',
                       color='red')
            size = 30
        else:
            size = 50

        # Plot point in feature space, coloured by true label
        colour = '#7dcfe2' if label == 1.0 else 'blue'
        ax.scatter(coordinate[0], coordinate[1], s=size, marker='o',
                   color=colour)

    # Add some labelling etc to the plot
    if agesex:
        ax.set_xlabel('Sex', size=12)
        ax.set_ylabel('Age', size=12)
        ax.set_xticks([0, 1])
        ax.set_xticklabels(['Female', 'Male'], size=8)
    else:
        ax.set_xlabel(featurenames[0], size=12)
        ax.set_ylabel(featurenames[1], size=12)

    # Save output
    if output_png is not None:
        f.savefig(output_png, bbox_inches='tight', pad_inches=0)
        print(f"Plot saved as {output_png}!")

    if output_tex is not None:
        # Export the current figure (f) as TikZ/PGFPlots code.
        tikzplotlib.save(output_tex)
        print(f"Plot saved as {output_tex}!")