How-to on Spark: Cross Validation
Run TimeGPT in a distributed fashion on top of Spark.
TimeGPT works on top of Spark, Dask, and Ray through Fugue. TimeGPT will read the input DataFrame and use the corresponding engine. For example, if the input is a Spark DataFrame, TimeGPT will use the existing Spark session to run the forecast.
Installation
As long as Spark is installed and configured, TimeGPT will be able to use it. If executing on a distributed Spark cluster, make sure the nixtlats library is installed across all the workers.
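A quick way to confirm that a package is importable in a given Python environment is sketched below. The helper name `is_installed` is hypothetical and not part of nixtlats. The commented lines show one possible way to run the same check on Spark executors; they are an assumption, need a live session, and do not guarantee that every worker is sampled.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if a top-level package can be found by the import system."""
    return importlib.util.find_spec(package_name) is not None


# On the driver: check a module that ships with Python.
print(is_installed("json"))

# Hypothetical check across executors (not run here; needs a SparkSession
# named `spark`, and tasks may not land on every worker):
# spark.sparkContext.parallelize(range(8), 8).map(
#     lambda _: is_installed("nixtlats")
# ).collect()
```

If any executor reports False, installing the library on that node (or shipping it with the job) would be the next step.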
Executing on Spark
To run the forecasts distributed on Spark, just pass in a Spark DataFrame instead of a pandas DataFrame. First, instantiate the TimeGPT class.
from nixtlats import TimeGPT
timegpt = TimeGPT(
    # defaults to os.environ.get("TIMEGPT_TOKEN")
    token='my_token_provided_by_nixtla'
)
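To avoid hard-coding the token in a notebook, it can be read from the TIMEGPT_TOKEN environment variable mentioned in the comment above. The following is a minimal sketch; `resolve_token` is a hypothetical helper, not part of nixtlats, and it only adds an explicit error when no token is available.

```python
import os


def resolve_token(explicit_token=None, env_var="TIMEGPT_TOKEN"):
    """Prefer an explicit token, else fall back to the environment variable."""
    token = explicit_token or os.environ.get(env_var)
    if not token:
        raise ValueError(f"No token given and {env_var} is not set")
    return token
```

The result could then be passed as `TimeGPT(token=resolve_token())`.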
Use Spark as an engine.
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
Cross validation
import pandas as pd

url_df = 'https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/electricity-short.csv'
spark_df = spark.createDataFrame(pd.read_csv(url_df))
spark_df.show(5)
+---------+-------------------+-----+
|unique_id|                 ds|    y|
+---------+-------------------+-----+
|       BE|2016-12-01 00:00:00| 72.0|
|       BE|2016-12-01 01:00:00| 65.8|
|       BE|2016-12-01 02:00:00|59.99|
|       BE|2016-12-01 03:00:00|50.69|
|       BE|2016-12-01 04:00:00|52.58|
+---------+-------------------+-----+
only showing top 5 rows
fcst_df = timegpt.cross_validation(spark_df, h=12, n_windows=5, step_size=2)
fcst_df.show(5)
INFO:nixtlats.timegpt:Validating inputs...
INFO:nixtlats.timegpt:Inferred freq: H
INFO:nixtlats.timegpt:Validating inputs...
INFO:nixtlats.timegpt:Preprocessing dataframes...
INFO:nixtlats.timegpt:Inferred freq: H
INFO:nixtlats.timegpt:Calling Forecast Endpoint...
+---------+-------------------+-------------------+------------------+
|unique_id|                 ds|             cutoff|           TimeGPT|
+---------+-------------------+-------------------+------------------+
|       FR|2016-12-30 04:00:00|2016-12-30 03:00:00| 44.89374542236328|
|       FR|2016-12-30 05:00:00|2016-12-30 03:00:00| 46.05792999267578|
|       FR|2016-12-30 06:00:00|2016-12-30 03:00:00|48.790077209472656|
|       FR|2016-12-30 07:00:00|2016-12-30 03:00:00| 54.39702606201172|
|       FR|2016-12-30 08:00:00|2016-12-30 03:00:00| 57.59300231933594|
+---------+-------------------+-------------------+------------------+
only showing top 5 rows
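The cutoff column marks the last timestamp used as history for each window. The sketch below shows how h, n_windows, and step_size could translate into cutoffs under a rolling-origin scheme. It is an illustration, not the library's implementation: `cv_cutoffs` is a hypothetical helper, the last window is assumed to end at the final timestamp of the series, step_size is assumed to fall back to h when omitted, and the series end (2016-12-30 23:00:00) is an assumption inferred from the output above rather than stated in the text.

```python
from datetime import datetime, timedelta


def cv_cutoffs(last_timestamp, h, n_windows, step_size=None, freq=timedelta(hours=1)):
    """Cutoffs for rolling-origin cross validation, oldest first.

    Assumes the last window ends at the last timestamp, and that step_size
    falls back to h when it is not given.
    """
    step = h if step_size is None else step_size
    last_cutoff = last_timestamp - h * freq
    return [last_cutoff - i * step * freq for i in reversed(range(n_windows))]


# Assumed end of the series (inferred, not stated in the text).
end = datetime(2016, 12, 30, 23)
cutoffs = cv_cutoffs(end, h=12, n_windows=5, step_size=2)
for cutoff in cutoffs:
    print(cutoff)
```

Under these assumptions the earliest cutoff is 2016-12-30 03:00:00, which agrees with the first rows shown for FR, and each later window starts two hours after the previous one.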
Cross validation with exogenous variables
Exogenous variables or external factors are crucial in time series forecasting as they provide additional information that might influence the prediction. These variables could include holiday markers, marketing spending, weather data, or any other external data that correlate with the time series data you are forecasting. For example, if you’re forecasting ice cream sales, temperature data could serve as a useful exogenous variable. On hotter days, ice cream sales may increase. To incorporate exogenous variables in TimeGPT, you’ll need to pair each point in your time series data with the corresponding external data. Let’s see an example.
df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/electricity-short-with-ex-vars.csv')
spark_df = spark.createDataFrame(df)
spark_df.show(5)
+---------+-------------------+-----+----------+----------+-----+-----+-----+-----+-----+-----+-----+
|unique_id|                 ds|    y|Exogenous1|Exogenous2|day_0|day_1|day_2|day_3|day_4|day_5|day_6|
+---------+-------------------+-----+----------+----------+-----+-----+-----+-----+-----+-----+-----+
|       BE|2016-12-01 00:00:00| 72.0|   61507.0|   71066.0|  0.0|  0.0|  0.0|  1.0|  0.0|  0.0|  0.0|
|       BE|2016-12-01 01:00:00| 65.8|   59528.0|   67311.0|  0.0|  0.0|  0.0|  1.0|  0.0|  0.0|  0.0|
|       BE|2016-12-01 02:00:00|59.99|   58812.0|   67470.0|  0.0|  0.0|  0.0|  1.0|  0.0|  0.0|  0.0|
|       BE|2016-12-01 03:00:00|50.69|   57676.0|   64529.0|  0.0|  0.0|  0.0|  1.0|  0.0|  0.0|  0.0|
|       BE|2016-12-01 04:00:00|52.58|   56804.0|   62773.0|  0.0|  0.0|  0.0|  1.0|  0.0|  0.0|  0.0|
+---------+-------------------+-----+----------+----------+-----+-----+-----+-----+-----+-----+-----+
only showing top 5 rows
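The day_0 to day_6 columns look like a one-hot encoding of the day of the week. That reading is an assumption based on the rows above: 2016-12-01 was a Thursday, and day_3 is the column set to 1.0, which matches Python's Monday=0 convention. A minimal sketch of building such columns for your own data follows; `day_of_week_dummies` is a hypothetical helper, not part of nixtlats.

```python
from datetime import datetime


def day_of_week_dummies(ts):
    """One-hot encode the weekday of `ts` as day_0 (Monday) .. day_6 (Sunday)."""
    return {f"day_{d}": 1.0 if ts.weekday() == d else 0.0 for d in range(7)}


# First timestamp from the table above.
row = day_of_week_dummies(datetime(2016, 12, 1, 0))
print(row)
```

Calendar features like these are convenient exogenous variables because they are known in advance for any future timestamp.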
Let’s call the cross_validation method, adding this information:
timegpt_cv_ex_vars_df = timegpt.cross_validation(
    df=spark_df,
    h=48,
    level=[80, 90],
    n_windows=5,
)
timegpt_cv_ex_vars_df.show(5)
INFO:nixtlats.timegpt:Validating inputs...
INFO:nixtlats.timegpt:Inferred freq: H
INFO:nixtlats.timegpt:Validating inputs...
INFO:nixtlats.timegpt:Preprocessing dataframes...
INFO:nixtlats.timegpt:Inferred freq: H
WARNING:nixtlats.timegpt:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtlats.timegpt:Restricting input...
INFO:nixtlats.timegpt:Calling Forecast Endpoint...
+---------+-------------------+-------------------+------------------+------------------+------------------+------------------+------------------+
|unique_id|                 ds|             cutoff|           TimeGPT|     TimeGPT-lo-90|     TimeGPT-lo-80|     TimeGPT-hi-80|     TimeGPT-hi-90|
+---------+-------------------+-------------------+------------------+------------------+------------------+------------------+------------------+
|       FR|2016-12-21 00:00:00|2016-12-20 23:00:00| 57.46266174316406| 54.32243190002441|54.725050598144534| 60.20027288818359|60.602891586303706|
|       FR|2016-12-21 01:00:00|2016-12-20 23:00:00|52.549095153808594|50.111817771911625| 50.20576373291016| 54.89242657470703| 54.98637253570556|
|       FR|2016-12-21 02:00:00|2016-12-20 23:00:00| 49.98523712158203|47.396572181701664| 48.40804647827149|51.562427764892576|  52.5739020614624|
|       FR|2016-12-21 03:00:00|2016-12-20 23:00:00|   49.146240234375| 46.38533438110352| 46.51724838256836| 51.77523208618164| 51.90714608764648|
|       FR|2016-12-21 04:00:00|2016-12-20 23:00:00| 47.01085662841797| 42.29354175567627|42.783941421508786|51.237771835327145|51.728171501159665|
+---------+-------------------+-------------------+------------------+------------------+------------------+------------------+------------------+
only showing top 5 rows
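Passing level=[80, 90] adds lower and upper bound columns for each level next to the point forecast. One quick sanity check, offered here as a sketch rather than an official validation step, is that the intervals nest: the 90% bounds should enclose the 80% bounds, which in turn should enclose the point forecast. The values below are copied from the five rows shown above; `intervals_are_nested` is a hypothetical helper.

```python
# Columns: TimeGPT, lo-90, lo-80, hi-80, hi-90 (copied from the output above).
rows = [
    (57.46266174316406, 54.32243190002441, 54.725050598144534, 60.20027288818359, 60.602891586303706),
    (52.549095153808594, 50.111817771911625, 50.20576373291016, 54.89242657470703, 54.98637253570556),
    (49.98523712158203, 47.396572181701664, 48.40804647827149, 51.562427764892576, 52.5739020614624),
    (49.146240234375, 46.38533438110352, 46.51724838256836, 51.77523208618164, 51.90714608764648),
    (47.01085662841797, 42.29354175567627, 42.783941421508786, 51.237771835327145, 51.728171501159665),
]


def intervals_are_nested(point, lo_90, lo_80, hi_80, hi_90):
    """True when lo-90 <= lo-80 <= point forecast <= hi-80 <= hi-90."""
    return lo_90 <= lo_80 <= point <= hi_80 <= hi_90


all_nested = all(intervals_are_nested(*row) for row in rows)
print(all_nested)
```

The same comparison could be applied to the full result after collecting it from Spark, for example to flag any rows where the bounds cross.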
spark.stop()