From 77de6a902f9e303c2505546f3c85a6e7d0662295 Mon Sep 17 00:00:00 2001 From: LuizFCDuarte Date: Mon, 10 Jun 2024 21:52:50 -0300 Subject: [PATCH] :sparkles: Add OCSB test WIP the code is using PyCall --- Project.toml | 3 ++- src/Sarimax.jl | 1 + src/models/sarima.jl | 2 +- src/utils.jl | 9 +++++++++ 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index fb3d6ea..9e0d203 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" SCIP = "82193955-e24f-5292-bf16-6f2c5261a85f" @@ -29,4 +30,4 @@ TimeSeries = "9e3dc215-6440-5c97-bce1-76c03772f85e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test", "JSON"] diff --git a/src/Sarimax.jl b/src/Sarimax.jl index b46b5f6..181d557 100644 --- a/src/Sarimax.jl +++ b/src/Sarimax.jl @@ -20,6 +20,7 @@ using SCIP using StateSpaceModels using Statistics using TimeSeries +using PyCall # using GLMNet # using Lasso diff --git a/src/models/sarima.jl b/src/models/sarima.jl index 60df9ec..d69c62e 100644 --- a/src/models/sarima.jl +++ b/src/models/sarima.jl @@ -982,7 +982,7 @@ function auto( @assert maxQ >= 0 @assert informationCriteria ∈ ["aic","aicc","bic"] @assert integrationTest ∈ ["kpss"] - @assert seasonalIntegrationTest ∈ ["seas","ch"] + @assert seasonalIntegrationTest ∈ ["seas","ch","ocsb"] @assert objectiveFunction ∈ ["mae","mse","ml","bilevel"] ModelFl = eltype(values(y)) diff --git a/src/utils.jl b/src/utils.jl index b958108..edf72b0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -156,6 +156,15 @@ function selectSeasonalIntegrationOrder( return StateSpaceModels.seasonal_strength_test(y, seasonality) elseif test == "ch" return StateSpaceModels.canova_hansen_test(y, seasonality) + elseif test == "ocsb" + py""" + import pmdarima as pm + import numpy as np + def seasonal_diffs(ts, seasonal_period): + ts_np = np.array(ts) + return pm.arima.nsdiffs(ts_np, m=seasonal_period) + """ + return py"seasonal_diffs"(y, seasonality) end throw(ArgumentError("The test $test is not supported"))