# Source code for greykite.sklearn.estimator.prophet_estimator

# BSD 2-CLAUSE LICENSE

# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:

# Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# original author: Albert Chen, Rachit Kumar, Sayan Patra
"""sklearn estimator for fbprophet"""

import fbprophet
from fbprophet.plot import plot_components
from sklearn.exceptions import NotFittedError
from sklearn.metrics import mean_squared_error

from greykite.common.constants import PREDICTED_COL
from greykite.common.constants import PREDICTED_LOWER_COL
from greykite.common.constants import PREDICTED_UPPER_COL
from greykite.common.constants import TIME_COL
from greykite.common.constants import VALUE_COL
from greykite.common.logging import LoggingLevelEnum
from greykite.common.logging import log_message
from greykite.common.logging import pprint
from greykite.sklearn.estimator.base_forecast_estimator import BaseForecastEstimator


class ProphetEstimator(BaseForecastEstimator):
    """Wrapper for Facebook Prophet model.

    Parameters
    ----------
    score_func : callable
        see BaseForecastEstimator
    coverage : float between [0.0, 1.0]
        see BaseForecastEstimator
    null_model_params : dict with arguments to define DummyRegressor null model, optional, default=None
        see BaseForecastEstimator
    add_regressor_dict: dictionary of extra regressors to be added to the model, optional, default=None
        These should be available for training and entire prediction interval.

        Dictionary format::

            add_regressor_dict={  # we can add as many regressors as we want, in the following format
                "reg_col1": {
                    "prior_scale": 10,
                    "standardize": True,
                    "mode": 'additive'
                },
                "reg_col2": {
                    "prior_scale": 20,
                    "standardize": True,
                    "mode": 'multiplicative'
                }
            }

    add_seasonality_dict: dict of custom seasonality parameters to be added to the model, optional, default=None
        parameter details: https://github.com/facebook/prophet/blob/master/python/fbprophet/forecaster.py
        - refer to add_seasonality() function.
        Key is the seasonality component name e.g. 'monthly'; parameters are specified via dict.

        Dictionary format::

            add_seasonality_dict={
                'monthly': {
                    'period': 30.5,
                    'fourier_order': 5
                },
                'weekly': {
                    'period': 7,
                    'fourier_order': 20,
                    'prior_scale': 0.6,
                    'mode': 'additive',
                    'condition_name': 'condition_col'  # takes a bool column in df with True/False values.
                    # This means that the seasonality will only be applied to dates
                    # where the condition_name column is True.
                },
                'yearly': {
                    'period': 365.25,
                    'fourier_order': 10,
                    'prior_scale': 0.2,
                    'mode': 'additive'
                }
            }

        Note: If there is a conflict in built-in and custom seasonality e.g. both have "yearly",
        then custom seasonality will be used and Model will throw a warning such as:
        "INFO:fbprophet:Found custom seasonality named "yearly", disabling built-in yearly seasonality."

        Any ``condition_name`` column that is present in the input data is passed
        through to Prophet by ``fit`` and ``predict``.

    kwargs : additional parameters
        Other parameters are the same as Prophet model, with one exception:
        ``interval_width`` is specified by ``coverage``.

        See source code ``__init__`` for the parameter names, and refer to
        Prophet documentation for a description:

            * https://facebook.github.io/prophet/docs/quick_start.html
            * https://github.com/facebook/prophet/blob/c7a150312ad26db5a8ea81d8cd32dd5fd6ce1eda/python/fbprophet/forecaster.py#L30

    Attributes
    ----------
    model : ``Prophet`` object
        Prophet model object
    forecast : `pandas.DataFrame`
        Output of predict method of ``Prophet``.
    """
    def __init__(
            self,
            score_func=mean_squared_error,
            coverage=0.80,  # to specify interval_width in Prophet
            null_model_params=None,
            growth="linear",
            changepoints=None,
            n_changepoints=25,
            changepoint_range=0.8,
            yearly_seasonality="auto",
            weekly_seasonality="auto",
            daily_seasonality="auto",
            holidays=None,
            seasonality_mode="additive",
            seasonality_prior_scale=10.0,
            holidays_prior_scale=10.0,
            changepoint_prior_scale=0.05,
            mcmc_samples=0,
            uncertainty_samples=1000,
            add_regressor_dict=None,
            add_seasonality_dict=None):
        # every subclass of BaseForecastEstimator must call super().__init__
        super().__init__(
            score_func=score_func,
            coverage=coverage,
            null_model_params=null_model_params)

        # necessary to set parameters, to ensure get_params() works (used in grid search)
        self.score_func = score_func
        self.coverage = coverage
        self.null_model_params = null_model_params
        self.growth = growth
        self.changepoints = changepoints
        self.n_changepoints = n_changepoints
        self.changepoint_range = changepoint_range
        self.yearly_seasonality = yearly_seasonality
        self.weekly_seasonality = weekly_seasonality
        self.daily_seasonality = daily_seasonality
        self.holidays = holidays
        self.seasonality_mode = seasonality_mode
        self.seasonality_prior_scale = seasonality_prior_scale
        self.holidays_prior_scale = holidays_prior_scale
        self.changepoint_prior_scale = changepoint_prior_scale
        self.mcmc_samples = mcmc_samples
        self.uncertainty_samples = uncertainty_samples

        # additional regressor names and optimization
        self.add_regressor_dict = add_regressor_dict
        # additional seasonality parameters
        self.add_seasonality_dict = add_seasonality_dict

        # set by the fit method
        self.model = None
        # set by the predict method
        self.forecast = None

    def _select_model_columns(self, X, base_columns):
        """Lists the columns of ``X`` that must be handed to Prophet.

        Parameters
        ----------
        X : `pandas.DataFrame`
            Input data passed to ``fit`` or ``predict``.
        base_columns : `list` [`str`]
            Columns that are always required (time column, and value column when fitting).

        Returns
        -------
        columns : `list` [`str`]
            ``base_columns``, followed by the keys of ``add_regressor_dict``,
            followed by every ``condition_name`` in ``add_seasonality_dict``
            that is a column of ``X``. Order is preserved and duplicates are removed.
        """
        candidates = list(base_columns)
        if self.add_regressor_dict is not None:
            candidates.extend(self.add_regressor_dict)
        if self.add_seasonality_dict is not None:
            for seasonality_params in self.add_seasonality_dict.values():
                condition_name = seasonality_params.get("condition_name")
                # A condition column missing from X is deliberately not requested here,
                # so that Prophet itself reports the missing condition.
                if condition_name is not None and condition_name in X.columns:
                    candidates.append(condition_name)

        # A condition column may also be a regressor; selecting it twice
        # would produce duplicate column labels.
        columns = []
        for col in candidates:
            if col not in columns:
                columns.append(col)
        return columns

    def fit(self, X, y=None, time_col=TIME_COL, value_col=VALUE_COL, **fit_params):
        """Fits fbprophet model.

        Parameters
        ----------
        X : `pandas.DataFrame`
            Input timeseries, with timestamp column,
            value column, and any additional regressors
            or seasonality condition columns.
            The value column is the response, included in
            X to allow transformation by `sklearn.pipeline.Pipeline`
        y : ignored
            The original timeseries values, ignored.
            (The y for fitting is included in ``X``.)
        time_col : `str`
            Time column name in ``X``
        value_col : `str`
            Value column name in ``X``
        fit_params : `dict`
            additional parameters for null model

        Returns
        -------
        self : self
            Fitted model is stored in ``self.model``.
        """
        super().fit(X, y=y, time_col=time_col, value_col=value_col, **fit_params)

        fit_columns = self._select_model_columns(X, [time_col, value_col])
        # Prophet expects these column names. Other estimators can use TIME_COL, etc.
        fit_df = (X.reset_index(drop=True)[fit_columns]
                  .rename(columns={time_col: "ds", value_col: "y"}))

        # uses coverage instead of interval_width to set prediction band width. This ensures a common
        # interface for parameters common to every BaseForecastEstimator, usually also needed for forecast evaluation
        # model must be initialized here, not in __init__, to update parameters in grid search
        self.model = fbprophet.Prophet(
            growth=self.growth,
            changepoints=self.changepoints,
            n_changepoints=self.n_changepoints,
            changepoint_range=self.changepoint_range,
            yearly_seasonality=self.yearly_seasonality,
            weekly_seasonality=self.weekly_seasonality,
            daily_seasonality=self.daily_seasonality,
            holidays=self.holidays,
            seasonality_mode=self.seasonality_mode,
            seasonality_prior_scale=self.seasonality_prior_scale,
            holidays_prior_scale=self.holidays_prior_scale,
            changepoint_prior_scale=self.changepoint_prior_scale,
            mcmc_samples=self.mcmc_samples,
            interval_width=self.coverage,
            uncertainty_samples=self.uncertainty_samples
        )

        # if extra regressors are given, we add them to temporal features data
        # This implementation assumes that the regressor(s) are provided in time series df, alongside target column.
        if self.add_regressor_dict is not None:
            for reg_col, reg_params in self.add_regressor_dict.items():
                self.model.add_regressor(name=reg_col, **reg_params)

        # if custom seasonality is provided, we supply it to Prophet model
        if self.add_seasonality_dict is not None:
            for seasonality_type, seasonality_params in self.add_seasonality_dict.items():
                self.model.add_seasonality(name=seasonality_type, **seasonality_params)

        self.model.fit(fit_df)
        return self

    def predict(self, X, y=None):
        """Creates forecast for dates specified in ``X``.

        Parameters
        ----------
        X : `pandas.DataFrame`
            Input timeseries with timestamp column and any additional regressors
            or seasonality condition columns.
            Timestamps are the dates for prediction.
            Value column, if provided in X, is ignored.
        y : ignored

        Returns
        -------
        predictions : `pandas.DataFrame`
            Forecasted values for the dates in ``X``. Columns:

                * TIME_COL dates
                * PREDICTED_COL predictions
                * PREDICTED_LOWER_COL lower bound of predictions, optional
                * PREDICTED_UPPER_COL upper bound of predictions, optional
                * [other columns], optional

            PREDICTED_LOWER_COL and PREDICTED_UPPER_COL are present iff coverage is not None
        """
        # Returns the cached result if applicable
        cached_predictions = super().predict(X=X)
        if cached_predictions is not None:
            return cached_predictions

        # if regressors are not provided, then use time column to predict future.
        # Else, use regressors (and seasonality condition columns) from predict df
        predict_columns = self._select_model_columns(X, [self.time_col_])
        # prophet expects time_col name to be "ds"
        fut_df = (X.reset_index(drop=True)[predict_columns]
                  .rename(columns={self.time_col_: "ds"}))
        pred_df = self.model.predict(fut_df)
        self.forecast = pred_df  # This is used by the plot_components

        # renames columns to standardized schema
        output_columns = {
            "ds": TIME_COL,
            "yhat": PREDICTED_COL,
            "yhat_lower": PREDICTED_LOWER_COL,
            "yhat_upper": PREDICTED_UPPER_COL
        }
        predictions = (pred_df[list(output_columns)]
                       .rename(columns=output_columns))
        # Caches the predictions
        self.cached_predictions_ = predictions
        return predictions

    def summary(self):
        """Prints input parameters and Prophet model parameters.

        Calls the parent ``summary()`` and then, if a model has been fit,
        logs the pretty-printed ``self.model.params`` at INFO level.

        Returns
        -------
        None
        """
        super().summary()
        if self.model is not None:
            log_message(pprint(self.model.params), LoggingLevelEnum.INFO)

    def plot_components(
            self,
            uncertainty=True,
            plot_cap=True,
            weekly_start=0,
            yearly_start=0,
            figsize=None):
        """Plot the ``Prophet`` forecast components on the dataset passed to ``predict``.

        Will plot whichever are available of:
        trend, holidays, weekly seasonality, and yearly seasonality.

        Parameters
        ----------
        uncertainty : `bool`, optional, default True
            Boolean to plot uncertainty intervals.
        plot_cap : `bool`, optional, default True
            Boolean indicating if the capacity should be shown
            in the figure, if available.
        weekly_start : `int`, optional, default 0
            Specifying the start day of the weekly seasonality plot.
            0 (default) starts the week on Sunday.
            1 shifts by 1 day to Monday, and so on.
        yearly_start : `int`, optional, default 0
            Specifying the start day of the yearly seasonality plot.
            0 (default) starts the year on Jan 1.
            1 shifts by 1 day to Jan 2, and so on.
        figsize : `tuple` , optional, default None
            Width, height in inches.

        Returns
        -------
        fig: `matplotlib.figure.Figure`
            A matplotlib figure.

        Raises
        ------
        NotFittedError
            If ``fit`` has not been run.
        RuntimeError
            If ``predict`` has not been run.
        Exception
            If plotting fails with an ``AttributeError`` that mentions
            ``'DatetimeIndex'`` (Prophet 0.5 with pandas 1.*).
        """
        if self.model is None:
            raise NotFittedError("The fit method has not been run yet.")
        if self.forecast is None:
            raise RuntimeError("The predict method has not been run yet.")

        try:
            return plot_components(
                m=self.model,
                fcst=self.forecast,
                uncertainty=uncertainty,
                plot_cap=plot_cap,
                weekly_start=weekly_start,
                yearly_start=yearly_start,
                figsize=figsize)
        except AttributeError as e:
            if "'DatetimeIndex'" in repr(e):
                # 'DatetimeIndex' object has no attribute 'weekday_name'
                raise Exception("Prophet 0.5 component plots are incompatible with pandas 1.*. "
                                "Upgrade to fbprophet:0.6 or higher.") from e
            # Unrelated AttributeError: re-raise unchanged, keeping the original traceback.
            raise