How-to on Spark: Cross Validation

Run TimeGPT in a distributed fashion on top of Spark.
TimeGPT works on top of Spark, Dask, and Ray through Fugue. TimeGPT reads the input DataFrame and uses the corresponding engine. For example, if the input is a Spark DataFrame, TimeGPT will use the existing Spark session to run the forecast.
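
To make the idea of "using the corresponding engine" concrete, here is a toy illustration of choosing a backend from the type of the input. It is not how Fugue or TimeGPT implement it; the function name infer_engine and the returned labels are hypothetical.

```python
import pandas as pd

def infer_engine(df):
    """Guess a backend label from the top-level package that defines df's type."""
    # e.g. "pandas.core.frame" -> "pandas", "pyspark.sql.dataframe" -> "pyspark"
    root = type(df).__module__.split(".")[0]
    # Hypothetical labels, used only for this illustration.
    labels = {"pandas": "local", "pyspark": "spark", "dask": "dask", "ray": "ray"}
    return labels.get(root, "unknown")

print(infer_engine(pd.DataFrame({"y": [1.0]})))
```

The point is only that the caller does not pick the engine explicitly: the kind of DataFrame passed in decides where the work runs.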

Installation

As long as Spark is installed and configured, TimeGPT will be able to use it. If executing on a distributed Spark cluster, make sure the nixtlats library is installed on all the workers.
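
A quick way to check that the required packages can be found is sketched below. It uses only the standard library; the helper name missing_packages is hypothetical, and the package list is an assumption based on the imports used in this guide.

```python
import importlib.util

def missing_packages(names):
    """Return the top-level package names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# An empty list means every package was found in the current environment.
print(missing_packages(["nixtlats", "pyspark", "pandas"]))
```

Note that this only inspects the environment where it runs (typically the driver). To say anything about the workers, the same check would have to be executed on each of them.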

Executing on Spark

To run the forecasts distributed on Spark, pass in a Spark DataFrame instead of a pandas DataFrame. First, instantiate the TimeGPT class.


from nixtlats import TimeGPT


timegpt = TimeGPT(
    # defaults to os.environ.get("TIMEGPT_TOKEN")
    token = 'my_token_provided_by_nixtla'
)

Use Spark as an engine.


from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()


Cross validation


import pandas as pd

url_df = 'https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/electricity-short.csv'
spark_df = spark.createDataFrame(pd.read_csv(url_df))
spark_df.show(5)

+---------+-------------------+-----+
|unique_id|                 ds|    y|
+---------+-------------------+-----+
|       BE|2016-12-01 00:00:00| 72.0|
|       BE|2016-12-01 01:00:00| 65.8|
|       BE|2016-12-01 02:00:00|59.99|
|       BE|2016-12-01 03:00:00|50.69|
|       BE|2016-12-01 04:00:00|52.58|
+---------+-------------------+-----+
only showing top 5 rows

fcst_df = timegpt.cross_validation(spark_df, h=12, n_windows=5, step_size=2)
fcst_df.show(5)

INFO:nixtlats.timegpt:Validating inputs...
INFO:nixtlats.timegpt:Inferred freq: H
INFO:nixtlats.timegpt:Validating inputs...
INFO:nixtlats.timegpt:Preprocessing dataframes...
INFO:nixtlats.timegpt:Inferred freq: H
INFO:nixtlats.timegpt:Calling Forecast Endpoint...

+---------+-------------------+-------------------+------------------+
|unique_id|                 ds|             cutoff|           TimeGPT|
+---------+-------------------+-------------------+------------------+
|       FR|2016-12-30 04:00:00|2016-12-30 03:00:00| 44.89374542236328|
|       FR|2016-12-30 05:00:00|2016-12-30 03:00:00| 46.05792999267578|
|       FR|2016-12-30 06:00:00|2016-12-30 03:00:00|48.790077209472656|
|       FR|2016-12-30 07:00:00|2016-12-30 03:00:00| 54.39702606201172|
|       FR|2016-12-30 08:00:00|2016-12-30 03:00:00| 57.59300231933594|
+---------+-------------------+-------------------+------------------+
only showing top 5 rows
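
The arguments h, n_windows, and step_size describe rolling forecast windows. The sketch below shows one common rolling-origin convention for deriving the cutoff of each window. It is an assumption about the window layout, not code taken from nixtlats. The helper name rolling_cutoffs is hypothetical, and so is the last timestamp of the series (2016-12-30 23:00:00), since the end of the data is not shown above.

```python
import pandas as pd

def rolling_cutoffs(last_ts, h, n_windows, step_size=None, step=pd.Timedelta(hours=1)):
    """Return the cutoff timestamp of each window, oldest first.

    Assumed convention: the newest window ends at last_ts, and each older
    window is shifted back by step_size observations.
    """
    step_size = h if step_size is None else step_size
    last_ts = pd.Timestamp(last_ts)
    cutoffs = []
    for i in range(n_windows):
        # Observations between this window's cutoff and the end of the series.
        steps_back = h + (n_windows - 1 - i) * step_size
        cutoffs.append(last_ts - steps_back * step)
    return cutoffs

# Assumed end of the series; h, n_windows, and step_size are taken from the call above.
cutoffs = rolling_cutoffs("2016-12-30 23:00:00", h=12, n_windows=5, step_size=2)
for c in cutoffs:
    print(c)
```

Under these assumptions the oldest cutoff is 2016-12-30 03:00:00, which agrees with the cutoff column in the sample rows above. Each window then forecasts the h timestamps that follow its cutoff.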

Cross validation with exogenous variables

Exogenous variables, or external factors, are crucial in time series forecasting because they provide additional information that might influence the prediction. These variables could include holiday markers, marketing spending, weather data, or any other external data that correlate with the time series you are forecasting. For example, if you’re forecasting ice cream sales, temperature data could serve as a useful exogenous variable: on hotter days, ice cream sales may increase. To incorporate exogenous variables in TimeGPT, you’ll need to pair each point in your time series data with the corresponding external data. Let’s see an example.


df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/electricity-short-with-ex-vars.csv')
spark_df = spark.createDataFrame(df)
spark_df.show(5)

+---------+-------------------+-----+----------+----------+-----+-----+-----+-----+-----+-----+-----+
|unique_id|                 ds|    y|Exogenous1|Exogenous2|day_0|day_1|day_2|day_3|day_4|day_5|day_6|
+---------+-------------------+-----+----------+----------+-----+-----+-----+-----+-----+-----+-----+
|       BE|2016-12-01 00:00:00| 72.0|   61507.0|   71066.0|  0.0|  0.0|  0.0|  1.0|  0.0|  0.0|  0.0|
|       BE|2016-12-01 01:00:00| 65.8|   59528.0|   67311.0|  0.0|  0.0|  0.0|  1.0|  0.0|  0.0|  0.0|
|       BE|2016-12-01 02:00:00|59.99|   58812.0|   67470.0|  0.0|  0.0|  0.0|  1.0|  0.0|  0.0|  0.0|
|       BE|2016-12-01 03:00:00|50.69|   57676.0|   64529.0|  0.0|  0.0|  0.0|  1.0|  0.0|  0.0|  0.0|
|       BE|2016-12-01 04:00:00|52.58|   56804.0|   62773.0|  0.0|  0.0|  0.0|  1.0|  0.0|  0.0|  0.0|
+---------+-------------------+-----+----------+----------+-----+-----+-----+-----+-----+-----+-----+
only showing top 5 rows
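
The day_0 to day_6 columns shown above look like one-hot day-of-week indicators: 2016-12-01 was a Thursday and day_3 is set, which matches the pandas convention of Monday=0. How the dataset's columns were actually built is not stated here, so the following is only a minimal sketch of how similar calendar features could be paired with each timestamp using pandas. The small frame is copied from the first rows of the output above.

```python
import pandas as pd

# First three rows of the BE series, copied from the output above.
df = pd.DataFrame({
    "unique_id": ["BE", "BE", "BE"],
    "ds": pd.to_datetime([
        "2016-12-01 00:00:00",
        "2016-12-01 01:00:00",
        "2016-12-01 02:00:00",
    ]),
    "y": [72.0, 65.8, 59.99],
})

# One indicator column per weekday (pandas convention: Monday=0, ..., Sunday=6).
for d in range(7):
    df[f"day_{d}"] = (df["ds"].dt.dayofweek == d).astype(float)

print(df)
```

Every row keeps its unique_id and ds, so each observation of y stays aligned with its own exogenous values.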

Let’s call the cross_validation method, adding this information:


timegpt_cv_ex_vars_df = timegpt.cross_validation(
    df=spark_df,
    h=48, 
    level=[80, 90],
    n_windows=5,
)
timegpt_cv_ex_vars_df.show(5)

INFO:nixtlats.timegpt:Validating inputs...
INFO:nixtlats.timegpt:Inferred freq: H
INFO:nixtlats.timegpt:Validating inputs...
INFO:nixtlats.timegpt:Preprocessing dataframes...
INFO:nixtlats.timegpt:Inferred freq: H
WARNING:nixtlats.timegpt:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtlats.timegpt:Restricting input...
INFO:nixtlats.timegpt:Calling Forecast Endpoint...

+---------+-------------------+-------------------+------------------+------------------+------------------+------------------+------------------+
|unique_id|                 ds|             cutoff|           TimeGPT|     TimeGPT-lo-90|     TimeGPT-lo-80|     TimeGPT-hi-80|     TimeGPT-hi-90|
+---------+-------------------+-------------------+------------------+------------------+------------------+------------------+------------------+
|       FR|2016-12-21 00:00:00|2016-12-20 23:00:00| 57.46266174316406| 54.32243190002441|54.725050598144534| 60.20027288818359|60.602891586303706|
|       FR|2016-12-21 01:00:00|2016-12-20 23:00:00|52.549095153808594|50.111817771911625| 50.20576373291016| 54.89242657470703| 54.98637253570556|
|       FR|2016-12-21 02:00:00|2016-12-20 23:00:00| 49.98523712158203|47.396572181701664| 48.40804647827149|51.562427764892576|  52.5739020614624|
|       FR|2016-12-21 03:00:00|2016-12-20 23:00:00|   49.146240234375| 46.38533438110352| 46.51724838256836| 51.77523208618164| 51.90714608764648|
|       FR|2016-12-21 04:00:00|2016-12-20 23:00:00| 47.01085662841797| 42.29354175567627|42.783941421508786|51.237771835327145|51.728171501159665|
+---------+-------------------+-------------------+------------------+------------------+------------------+------------------+------------------+
only showing top 5 rows

spark.stop()