Skip to content

Commit

Permalink
fix estimation bug and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andre_ramos committed Aug 4, 2024
1 parent 8b0f65d commit 8e593d4
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion paper_tests/m4_test/evaluate_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function evaluate_SSL(initialization_df::DataFrame, results_df::DataFrame, input
T= length(normalized_y)
normalized_y = normalized_y[max(1, T-sample_size+1):end]
output = StateSpaceLearning.fit_model(normalized_y;
model_input = Dict("stochastic_level" => true, "trend" => true,
model_input = Dict("level" => true, "stochastic_level" => true, "trend" => true,
"stochastic_trend" => true,
"seasonal" => true, "stochastic_seasonal" => true, "freq_seasonal" => 12,
"outlier" => outlier, "ζ_ω_threshold" => 12),
Expand Down
2 changes: 1 addition & 1 deletion paper_tests/simulation_test/evaluate_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function get_SSL_results(y_train::Vector{Float64}, true_features::Vector{Int64},
series_result=nothing

t = @elapsed output = StateSpaceLearning.fit_model(y_train; Exogenous_X=X_train,
model_input = Dict("stochastic_level" => true, "trend" => true,
model_input = Dict("level" => true, "stochastic_level" => true, "trend" => true,
"stochastic_trend" => true,
"seasonal" => true, "stochastic_seasonal" => true, "freq_seasonal" => 12,
"outlier" => false, "ζ_ω_threshold" => 12),
Expand Down
2 changes: 1 addition & 1 deletion src/estimation_procedure/default_estimation_procedure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,5 @@ function default_estimation_procedure(Estimation_X::Matrix{Tl}, estimation_y::Ve
!penalize_initial_states ? ts_penalty_factor[components_indexes["initial_states"][2:end]] .= 0 : nothing
end

return fit_lasso(Estimation_X, estimation_y, α, information_criteria, penalize_exogenous, components_indexes, penalty_factor; rm_average = false)
return fit_lasso(Estimation_X, estimation_y, α, information_criteria, penalize_exogenous, components_indexes, ts_penalty_factor; rm_average = false)
end
6 changes: 6 additions & 0 deletions test/StateSpaceLearning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,10 @@ end
@test_throws AssertionError StateSpaceLearning.forecast(output1, 10; Exogenous_Forecast = rand(5, 3))
@test_throws AssertionError StateSpaceLearning.forecast(output2, 10)
@test_throws AssertionError StateSpaceLearning.forecast(output2, 10; Exogenous_Forecast = rand(5, 3))

y3 = [4.718, 4.77, 4.882, 4.859, 4.795, 4.905, 4.997, 4.997, 4.912, 4.779, 4.644, 4.77, 4.744, 4.836, 4.948, 4.905, 4.828, 5.003, 5.135, 5.135, 5.062, 4.89, 4.736, 4.941, 4.976, 5.01, 5.181, 5.093, 5.147, 5.181, 5.293, 5.293, 5.214, 5.087, 4.983, 5.111, 5.141, 5.192, 5.262, 5.198, 5.209, 5.384, 5.438, 5.488, 5.342, 5.252, 5.147, 5.267, 5.278, 5.278, 5.463, 5.459, 5.433, 5.493, 5.575, 5.605, 5.468, 5.351, 5.192, 5.303, 5.318, 5.236, 5.459, 5.424, 5.455, 5.575, 5.71, 5.68, 5.556, 5.433, 5.313, 5.433, 5.488, 5.451, 5.587, 5.594, 5.598, 5.752, 5.897, 5.849, 5.743, 5.613, 5.468, 5.627, 5.648, 5.624, 5.758, 5.746, 5.762, 5.924, 6.023, 6.003, 5.872, 5.723, 5.602, 5.723, 5.752, 5.707, 5.874, 5.852, 5.872, 6.045, 6.142, 6.146, 6.001, 5.849, 5.72, 5.817, 5.828, 5.762, 5.891, 5.852, 5.894, 6.075, 6.196, 6.224, 6.001, 5.883, 5.736, 5.82, 5.886, 5.834, 6.006, 5.981, 6.04, 6.156, 6.306, 6.326, 6.137, 6.008, 5.891, 6.003, 6.033, 5.968, 6.037, 6.133, 6.156, 6.282, 6.432, 6.406, 6.23, 6.133, 5.966, 6.068]
output3 = StateSpaceLearning.fit_model(y3)
forecast3 = trunc.(StateSpaceLearning.forecast(output3, 18); digits = 3)
@assert forecast3 == [6.11, 6.082, 6.221, 6.19, 6.197, 6.328, 6.447, 6.44, 6.285, 6.163, 6.026, 6.142, 6.166, 6.138, 6.278, 6.246, 6.253, 6.384]

end

0 comments on commit 8e593d4

Please sign in to comment.