Source code for BaseITS.plotting

import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import numpy as np
from datetime import datetime
import pandas as pd
from prophet.plot import plot_cross_validation_metric
from BaseITS.utils import *

# All plots saved in a plots folder


[docs]class Plotting: """Class to handle the plotting given a fixed expectation of inputs. This class is an addition, Users can implement more plotting functions to visualize their data & outputs. Methods: plot_sphaghetti(...): Line plot with all the provided outcomes and locations. plot_cumulative(...): Cumulative plots. plot_count_its(...): Counts plots. plot_percent_change(...): Percentage plots. plot_diff(...): Difference of values plots. plot_boxplots(...): Boxplots plots_metrics_distribution(...): Kernel Density Estimate plot """ def __init__( self, intervention_end_date: str, forecast: dict(), data: pd.DataFrame, outcome_labels: dict() = {}, file_path: str = "plots/", ): """Constructor Args: intervention_end_date(str): forecast (dict): forecast output from prophet model data (pd.DataFrame): dataset to be used outcome_labels (dict, optional): label mapping of the outcomes. Defaults to {}. file_path(str, optional): File path to save the plots. Defaults to 'plots/' """ self.intervention_end_date = intervention_end_date self.forecast = forecast self.data = data self.outcome_labels = outcome_labels self.file_path = file_path self.__validate_data # TODO Add an implementation that allows calling the constructor/functions without the inputs def __validate_data(self): # TODO: Implement this function pass
[docs] def plot_sphaghetti(self, id_var: str, x_var: str, y_var: str, title: str): """ Plot sphaghetti @params data: data in wide format (dataframe) id_var: column name of unique ids of individuals (str) i.e Regions, Districts x_var: name of x variable (str) i.e Date column name y_var: name of y variable (str) i.e outcome measure to be plotted column name i.e Diabetes """ sns.set(rc={"figure.figsize": (12, 7)}) mpl.rcParams["font.size"] = 8.0 mpl.rcParams["font.weight"] = "bold" mpl.rcParams["axes.labelweight"] = "bold" ids = self.data[id_var] for i in ids: df = self.data[[x_var, y_var]][self.data[id_var] == i] df[y_var] = df[y_var].fillna(df[y_var].mean()) # z = np.abs(stats.zscore(df[y_var])) # df = df[z<3] plt.plot( df[x_var], df[y_var], marker="", color="black", linewidth=0.1, alpha=0.1 ) plt.ylabel("Count", fontsize=18, fontweight="bold") # plt.xlabel(x_var, fontsize = 14, fontweight='bold') plt.title(title, fontsize=20, fontweight="bold") plt.xticks(rotation=90, fontsize=18, fontweight="bold") plt.yticks(fontsize=18, fontweight="bold") save_plots(self.file_path + title) plt.xlabel("Date") plt.show() plt.close()
def __plot_all_cumulative_its(self): return NotImplementedError( "User to implement this themselves. Exposing the function only for plotting one outcome, location, intervention, model only." )
[docs] def plot_cumulative( self, forecast: pd.DataFrame(), unit: str, outcome: str, prediction_start_date: str, prediction_end_date: str, normalise: bool = False, ): """ Function to plot the cumulative forecast. @params: unit: Geographical/Organizational unit (string) outcome: Outcome measure (string) forecast: Dataframe with the prophet forecast output for the geographical/organizational unit and outcome (pd.DataFrame) prediction_start_date: Prediction start date (datetime) prediction_end_date: Prediction end date (datetime) normalise: Data normalised (boolean) """ sns.set(rc={"figure.figsize": (10, 8)}) sns.set_style("white") mpl.rcParams["font.size"] = 8.0 mpl.rcParams["font.weight"] = "bold" mpl.rcParams["axes.labelweight"] = "bold" prediction_start_date = datetime.strptime(prediction_start_date, "%Y-%m-%d") prediction_end_date = datetime.strptime(prediction_end_date, "%Y-%m-%d") plt.xlim([forecast["ds"].min(), prediction_end_date]) plt.axvspan( xmin=prediction_start_date, xmax=prediction_end_date, color="grey", alpha=0.2, lw=0, ) plt.plot( forecast["ds"], forecast["y"].cumsum(), color="orange", label="actual values", ) plt.plot( forecast["ds"], forecast["yhat"].cumsum(), color="#33adff", label="predicted values", ) plt.xticks(rotation=90, fontsize=18, fontweight="bold") plt.yticks(fontsize=18, fontweight="bold") plt.xlabel("") if normalise: plt.ylabel("Cumulative counts per 100k", fontsize=18, fontweight="bold") title = "{} in {} Region normalised".format( self.outcome_labels[outcome], unit ) else: plt.ylabel("Cumulative counts", fontsize=18, fontweight="bold") title = "{} in {} ".format(self.outcome_labels[outcome], unit) plt.title(title, fontsize=20, fontweight="bold") plt.legend(loc="upper left", fontsize=12) save_plots(file_path="plots/", plot_name=title) plt.show() plt.close()
def __plot_all_counts_its(self): return NotImplementedError( "User to implement this themselves. Exposing the function only for plotting one outcome, location, intervention, model only." )
[docs] def plot_count_its( self, forecast: pd.DataFrame(), unit: str, outcome: str, prediction_start_date: str, prediction_end_date: str, normalise: bool = False, ): """ Function to plot the forecast counts. @params: unit: Geographical unit (string) outcome: Outcome measure (string) forecast: Dataframe with the prophet forecast output for the geographical unit and outcome(pd.DataFrame) prediction_start_date: Prediction start date (datetime) prediction_end_date: Prediction end date (datetime) normalise: Data normalised (boolean) """ sns.set(rc={"figure.figsize": (10, 8)}) sns.set_style("white") mpl.rcParams["font.size"] = 8.0 mpl.rcParams["font.weight"] = "bold" mpl.rcParams["axes.labelweight"] = "bold" prediction_start_date = datetime.strptime(prediction_start_date, "%Y-%m-%d") prediction_end_date = datetime.strptime(prediction_end_date, "%Y-%m-%d") plt.axvspan( xmin=prediction_start_date, xmax=prediction_end_date, color="grey", alpha=0.2, lw=0, ) plt.scatter( forecast["ds"], forecast["y"], facecolors="none", edgecolors="black", s=20, label="observed values", ) plt.xlim([forecast["ds"].min(), prediction_end_date]) plt.plot( forecast["ds"], forecast["yhat"], color="#33adff", label="predicted values" ) plt.fill_between( forecast["ds"], forecast["yhat_lower"], forecast["yhat_upper"], color="#33adff", alpha=0.25, ) plt.plot( forecast["ds"], forecast["trend"], color="red", linestyle="--", label="predicted trend", ) plt.fill_between( forecast["ds"], forecast["trend_lower"], forecast["trend_upper"], color="red", alpha=0.2, ) plt.xticks(rotation=90, fontsize=18, fontweight="bold") plt.yticks(fontsize=18, fontweight="bold") plt.xlabel("") if normalise: plt.ylabel("count per 100k", fontsize=18, fontweight="bold") title = "{} in {} normalised".format(self.outcome_labels[outcome], unit) else: plt.ylabel("counts", fontsize=18, fontweight="bold") title = "{} in {} ".format(self.outcome_labels[outcome], unit) plt.title(title, fontsize=20, fontweight="bold") plt.legend(loc="upper left", fontsize=12) save_plots(file_path="plots/", plot_name=title) plt.show() plt.close()
[docs] def plot_percent_change(self, dataset: pd.DataFrame, prediction_start_date: str): """ Function to plot the percentage change of the outcomes and the predicted values @params: data: Dataframe with the selected metrics from the forecast (pd.DataFrame) prediction_start_date: Prediction start date (str) """ df = dataset[dataset["ds"] == prediction_start_date].copy() df = df.sort_values(["outcome", "mean_percent_change_after"], ascending=False) prediction_start_date = datetime.strptime(prediction_start_date, "%Y-%m-%d") sns.set(rc={"figure.figsize": (10, 9)}) mpl.rcParams["font.size"] = 12.0 mpl.rcParams["font.weight"] = "bold" mpl.rcParams["axes.labelweight"] = "bold" spot = 0 for outcome in set(df["outcome"]): spot += 1 plt.subplot(3, 1, spot) x = df["location"][df["outcome"] == outcome] y = np.array( list(df["mean_percent_change_after"][df["outcome"] == outcome]) ) conf = np.array( list(df["percent_change_conf_int_after"][df["outcome"] == outcome]) ) yconf = np.c_[y - conf[:, 0], conf[:, 1] - y].T plt.plot(y, x, ".", color="blue") plt.axvline(x=0, linewidth=1, linestyle="--", color="grey") title = "{}".format(self.outcome_labels[outcome]) plt.title(title, loc="center", fontsize=12, fontweight="bold") plt.xlabel( "percent change in {} counts".format(self.outcome_labels[outcome]), fontsize=12, fontweight="bold", ) plt.errorbar(y, x, xerr=yconf, fmt=" ", color="blue") suptitle = "Mean and 95% CI of Percent Change in Predicted vs Actual Counts ({} {} to December 2020)".format( prediction_start_date.strftime("%B"), prediction_start_date.year ) plt.suptitle(suptitle, fontsize=14, fontweight="bold") plt.tight_layout(pad=4.0) save_plots(self.file_path, "{}".format(suptitle)) # plt.savefig("plots/{}.png".format(suptitle), dpi=300, bbox_inches="tight") plt.show() plt.close()
[docs] def plot_diff(self, dataset: pd.DataFrame, prediction_start_date: str): """ Function to plot the difference of the outcomes and the predicted values @params: data: Dataframe with the metrics from the forecast (pd.DataFrame) prediction_start_date: Prediction start date (str) """ df = dataset[dataset["ds"] == prediction_start_date].copy() df = df.sort_values(["outcome", "mean_change_after"], ascending=False) prediction_start_date = datetime.strptime(prediction_start_date, "%Y-%m-%d") sns.set(rc={"figure.figsize": (11, 9)}) mpl.rcParams["font.size"] = 12.0 mpl.rcParams["font.weight"] = "bold" mpl.rcParams["axes.labelweight"] = "bold" spot = 0 for outcome in set(df["outcome"]): spot += 1 plt.subplot(3, 1, spot) x = df["location"][df["outcome"] == outcome] y = np.array(list(df["mean_change_after"][df["outcome"] == outcome])) conf = np.array(list(df["change_conf_int_after"][df["outcome"] == outcome])) yconf = np.c_[y - conf[:, 0], conf[:, 1] - y].T plt.plot(y, x, ".", color="blue") plt.axvline(x=0, linewidth=1, linestyle="--", color="grey") title = "{}".format(self.outcome_labels[outcome]) plt.title(title, loc="center", fontsize=12, fontweight="bold") plt.xlabel( "mean of the difference in predicted vs actual {} counts".format( self.outcome_labels[outcome] ), fontsize=12, fontweight="bold", ) plt.errorbar(y, x, xerr=yconf, fmt=" ", color="blue") suptitle = "Mean and 95% CI of Difference in Predicted vs Actual Counts ({} {} to December 2020)".format( prediction_start_date.strftime("%B"), prediction_start_date.year ) plt.suptitle(suptitle, fontsize=14, fontweight="bold") plt.tight_layout(pad=4.0) save_plots(self.file_path, "{}".format(suptitle)) # plt.savefig("plots/{}.png".format(suptitle), dpi=300, bbox_inches="tight") plt.show() plt.close()
[docs] def plot_boxplots(self, x: str, y: str, title: str): """Function to plot boxplots Args: x (str): date column y (str): data value column to be plotted title (str): title of plot """ fig, ax = plt.subplots(figsize=(16, 8)) sns.boxplot( self.data[x].dt.strftime("%Y-%m"), self.data[y], ax=ax, showfliers=False ) # sns.violinplot(data[x].dt.strftime('%Y-%m'), data[y], ax=ax,cut=0) plt.yticks(fontsize=14, fontweight="bold") plt.xticks(rotation="vertical") plt.xticks(rotation=90, fontsize=14, fontweight="bold") plt.ylabel("Count", fontsize=14, fontweight="bold") plt.xlabel("") plt.title(title, fontsize=20, fontweight="bold") save_plots(self.file_path, title + "time_series_boxplots.png") # plt.savefig( # self.file_path + title + "time_series_boxplots.png", dpi=300, bbox_inches="tight" # ) plt.show() plt.close()
[docs] def plots_metrics_distribution( self, prediction_start_date: str, prediction_end_date: str ): """Function to plot metrics distribution before and after interruption @params: prediction_start_date: Prediction start date (str) prediction_end_date: Prediction end date (str) """ prediction_start_date = datetime.strptime(prediction_start_date, "%Y-%m-%d") prediction_end_date = datetime.strptime(prediction_end_date, "%Y-%m-%d") df_before = self.forecast[self.forecast["ds"] < prediction_start_date] df_after = self.forecast[ (self.forecast["ds"] >= prediction_start_date) & (self.forecast["ds"] <= prediction_end_date) ] sns.set(rc={"figure.figsize": (8, 6)}) sns.set_style("white") mpl.rcParams["font.size"] = 8.0 mpl.rcParams["font.weight"] = "bold" mpl.rcParams["axes.labelweight"] = "bold" sns.kdeplot(df_before["yhat-y"], shade=True, label="Before") sns.kdeplot(df_after["yhat-y"], shade=True, label="After") title = "Distribution of yhat-y before and after interruption" plt.title( title, fontsize=16, fontweight="bold", ) plt.legend() save_plots(self.file_path, title) plt.show() plt.close()
[docs] def plot_cv_metric( self, location: str, outcome: str, df_cv: pd.DataFrame, prediction_start_date: str, outcome_labels: dict, ): """ Function to plot the cross-validation metrics. @params: location: Geographical unit (string) outcome: Outcome measure (string) df_cv: Dataframe with date and the outcome value for the geographical unit and outcome. (pd.DataFrame) seasonality_mode: tuned prophet model parameter (string) changepoint_prior_scale: tuned prophet model parameter (double) seasonality_prior_scale: tuned prophet model parameter (double) cutoff_start: Prediction start date (datetime) """ sns.set(rc={"figure.figsize": (10, 8)}) sns.set_style("white") mpl.rcParams["font.size"] = 8.0 mpl.rcParams["font.weight"] = "bold" mpl.rcParams["axes.labelweight"] = "bold" plot_cross_validation_metric(df_cv, metric="mape") plt.xticks(fontsize=18, fontweight="bold") plt.yticks(fontsize=18, fontweight="bold") plt.xlabel("Horizon", fontsize=18, fontweight="bold") plt.ylabel("MAPE", fontsize=18, fontweight="bold") prediction_start_date = datetime.strptime(prediction_start_date, "%Y-%m-%d") title = "MAPE for {} in {} ({} {})".format( outcome_labels[outcome], location, prediction_start_date.strftime("%B"), prediction_start_date.year, ) plt.title(title, fontsize=20, fontweight="bold") plt.savefig("plots/{}.png".format(title), dpi=300, bbox_inches="tight") plt.show() plt.close()