Source code for causalpy.experiments.interrupted_time_series

#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""
Interrupted Time Series Analysis
"""

from typing import List, Union

import arviz as az
import numpy as np
import pandas as pd
import xarray as xr
from formulaic import model_matrix
from matplotlib import pyplot as plt
from sklearn.base import RegressorMixin

from causalpy.custom_exceptions import BadIndexException
from causalpy.plot_utils import get_hdi_to_df, plot_xY
from causalpy.pymc_models import PyMCModel
from causalpy.utils import round_num

from .base import BaseExperiment

LEGEND_FONT_SIZE = 12


[docs] class InterruptedTimeSeries(BaseExperiment): """ The class for interrupted time series analysis. :param data: A pandas dataframe :param treatment_time: The time when treatment occurred, should be in reference to the data index :param formula: A statistical model formula :param model: A PyMC model Example -------- >>> import causalpy as cp >>> df = ( ... cp.load_data("its") ... .assign(date=lambda x: pd.to_datetime(x["date"])) ... .set_index("date") ... ) >>> treatment_time = pd.to_datetime("2017-01-01") >>> seed = 42 >>> result = cp.InterruptedTimeSeries( ... df, ... treatment_time, ... formula="y ~ 1 + t + C(month)", ... model=cp.pymc_models.LinearRegression( ... sample_kwargs={ ... "target_accept": 0.95, ... "random_seed": seed, ... "progressbar": False, ... } ... ), ... ) """ expt_type = "Interrupted Time Series" supports_ols = True supports_bayes = True
[docs] def __init__( self, data: pd.DataFrame, treatment_time: Union[int, float, pd.Timestamp], formula: str, model=None, **kwargs, ) -> None: super().__init__(model=model) # rename the index to "obs_ind" data.index.name = "obs_ind" self.input_validation(data, treatment_time) self.treatment_time = treatment_time # set experiment type - usually done in subclasses self.expt_type = "Pre-Post Fit" # split data in to pre and post intervention self.datapre = data[data.index < self.treatment_time] self.datapost = data[data.index >= self.treatment_time] self.formula = formula # set things up with pre-intervention data dm = model_matrix(self.formula, self.datapre) self.labels = list(dm.rhs.columns) self.matrix_spec = dm.model_spec self.outcome_variable_name = dm.lhs.columns[0] self.pre_y, self.pre_X = (dm.lhs.to_numpy(), dm.rhs.to_numpy()) # process post-intervention data new_dm = model_matrix(spec=self.matrix_spec, data=self.datapost) self.post_X = new_dm.rhs.to_numpy() self.post_y = new_dm.lhs.to_numpy() # turn into xarray.DataArray's self.pre_X = xr.DataArray( self.pre_X, dims=["obs_ind", "coeffs"], coords={ "obs_ind": self.datapre.index, "coeffs": self.labels, }, ) self.pre_y = xr.DataArray( self.pre_y[:, 0], dims=["obs_ind"], coords={"obs_ind": self.datapre.index}, ) self.post_X = xr.DataArray( self.post_X, dims=["obs_ind", "coeffs"], coords={ "obs_ind": self.datapost.index, "coeffs": self.labels, }, ) self.post_y = xr.DataArray( self.post_y[:, 0], dims=["obs_ind"], coords={"obs_ind": self.datapost.index}, ) # fit the model to the observed (pre-intervention) data if isinstance(self.model, PyMCModel): COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.pre_X.shape[0])} self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS) elif isinstance(self.model, RegressorMixin): self.model.fit(X=self.pre_X, y=self.pre_y) else: raise ValueError("Model type not recognized") # score the goodness of fit to the pre-intervention data self.score = self.model.score(X=self.pre_X, y=self.pre_y) # get the model predictions of the observed (pre-intervention) data self.pre_pred = self.model.predict(X=self.pre_X) # calculate the counterfactual self.post_pred = self.model.predict(X=self.post_X) self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred) self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred) self.post_impact_cumulative = self.model.calculate_cumulative_impact( self.post_impact )
[docs] def input_validation(self, data, treatment_time): """Validate the input data and model formula for correctness""" if isinstance(data.index, pd.DatetimeIndex) and not isinstance( treatment_time, pd.Timestamp ): raise BadIndexException( "If data.index is DatetimeIndex, treatment_time must be pd.Timestamp." ) if not isinstance(data.index, pd.DatetimeIndex) and isinstance( treatment_time, pd.Timestamp ): raise BadIndexException( "If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501 )
[docs] def summary(self, round_to=None) -> None: """Print summary of main results and model coefficients. :param round_to: Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers """ print(f"{self.expt_type:=^80}") print(f"Formula: {self.formula}") self.print_coefficients(round_to)
def _bayesian_plot( self, round_to=None, **kwargs ) -> tuple[plt.Figure, List[plt.Axes]]: """ Plot the results :param round_to: Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers. """ counterfactual_label = "Counterfactual" fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8)) # TOP PLOT -------------------------------------------------- # pre-intervention period h_line, h_patch = plot_xY( self.datapre.index, self.pre_pred["posterior_predictive"].mu, ax=ax[0], plot_hdi_kwargs={"color": "C0"}, ) handles = [(h_line, h_patch)] labels = ["Pre-intervention period"] (h,) = ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations") handles.append(h) labels.append("Observations") # post intervention period h_line, h_patch = plot_xY( self.datapost.index, self.post_pred["posterior_predictive"].mu, ax=ax[0], plot_hdi_kwargs={"color": "C1"}, ) handles.append((h_line, h_patch)) labels.append(counterfactual_label) ax[0].plot(self.datapost.index, self.post_y, "k.") # Shaded causal effect h = ax[0].fill_between( self.datapost.index, y1=az.extract( self.post_pred, group="posterior_predictive", var_names="mu" ).mean("sample"), y2=np.squeeze(self.post_y), color="C0", alpha=0.25, ) handles.append(h) labels.append("Causal impact") ax[0].set( title=f""" Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)}) """ ) # MIDDLE PLOT ----------------------------------------------- plot_xY( self.datapre.index, self.pre_impact, ax=ax[1], plot_hdi_kwargs={"color": "C0"}, ) plot_xY( self.datapost.index, self.post_impact, ax=ax[1], plot_hdi_kwargs={"color": "C1"}, ) ax[1].axhline(y=0, c="k") ax[1].fill_between( self.datapost.index, y1=self.post_impact.mean(["chain", "draw"]), color="C0", alpha=0.25, label="Causal impact", ) ax[1].set(title="Causal Impact") # BOTTOM PLOT ----------------------------------------------- ax[2].set(title="Cumulative Causal Impact") plot_xY( self.datapost.index, self.post_impact_cumulative, ax=ax[2], plot_hdi_kwargs={"color": "C1"}, ) ax[2].axhline(y=0, c="k") # Intervention line for i in [0, 1, 2]: ax[i].axvline( x=self.treatment_time, ls="-", lw=3, color="r", ) ax[0].legend( handles=(h_tuple for h_tuple in handles), labels=labels, fontsize=LEGEND_FONT_SIZE, ) return fig, ax def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]: """ Plot the results :param round_to: Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers. """ counterfactual_label = "Counterfactual" fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8)) ax[0].plot(self.datapre.index, self.pre_y, "k.") ax[0].plot(self.datapost.index, self.post_y, "k.") ax[0].plot(self.datapre.index, self.pre_pred, c="k", label="model fit") ax[0].plot( self.datapost.index, self.post_pred, label=counterfactual_label, ls=":", c="k", ) ax[0].set( title=f"$R^2$ on pre-intervention data = {round_num(self.score, round_to)}" ) ax[1].plot(self.datapre.index, self.pre_impact, "k.") ax[1].plot( self.datapost.index, self.post_impact, "k.", label=counterfactual_label, ) ax[1].axhline(y=0, c="k") ax[1].set(title="Causal Impact") ax[2].plot(self.datapost.index, self.post_impact_cumulative, c="k") ax[2].axhline(y=0, c="k") ax[2].set(title="Cumulative Causal Impact") # Shaded causal effect ax[0].fill_between( self.datapost.index, y1=np.squeeze(self.post_pred), y2=np.squeeze(self.post_y), color="C0", alpha=0.25, label="Causal impact", ) ax[1].fill_between( self.datapost.index, y1=np.squeeze(self.post_impact), color="C0", alpha=0.25, label="Causal impact", ) # Intervention line # TODO: make this work when treatment_time is a datetime for i in [0, 1, 2]: ax[i].axvline( x=self.treatment_time, ls="-", lw=3, color="r", label="Treatment time", ) ax[0].legend(fontsize=LEGEND_FONT_SIZE) return (fig, ax)
[docs] def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame: """ Recover the data of the experiment along with the prediction and causal impact information. :param hdi_prob: Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function. """ if isinstance(self.model, PyMCModel): hdi_pct = int(round(hdi_prob * 100)) pred_lower_col = f"pred_hdi_lower_{hdi_pct}" pred_upper_col = f"pred_hdi_upper_{hdi_pct}" impact_lower_col = f"impact_hdi_lower_{hdi_pct}" impact_upper_col = f"impact_hdi_upper_{hdi_pct}" pre_data = self.datapre.copy() post_data = self.datapost.copy() pre_data["prediction"] = ( az.extract(self.pre_pred, group="posterior_predictive", var_names="mu") .mean("sample") .values ) post_data["prediction"] = ( az.extract(self.post_pred, group="posterior_predictive", var_names="mu") .mean("sample") .values ) pre_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df( self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob ).set_index(pre_data.index) post_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df( self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob ).set_index(post_data.index) pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values pre_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df( self.pre_impact, hdi_prob=hdi_prob ).set_index(pre_data.index) post_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df( self.post_impact, hdi_prob=hdi_prob ).set_index(post_data.index) self.plot_data = pd.concat([pre_data, post_data]) return self.plot_data else: raise ValueError("Unsupported model type")
[docs] def get_plot_data_ols(self) -> pd.DataFrame: """ Recover the data of the experiment along with the prediction and causal impact information. """ pre_data = self.datapre.copy() post_data = self.datapost.copy() pre_data["prediction"] = self.pre_pred post_data["prediction"] = self.post_pred pre_data["impact"] = self.pre_impact post_data["impact"] = self.post_impact self.plot_data = pd.concat([pre_data, post_data]) return self.plot_data