Skip to content

Commit

Permalink
m2u
Browse files Browse the repository at this point in the history
Signed-off-by: oneJue <[email protected]>
  • Loading branch information
oneJue committed Feb 1, 2025
1 parent fb32430 commit 32332a3
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 24 deletions.
7 changes: 6 additions & 1 deletion ts_benchmark/baselines/duet/duet.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,12 @@ def validate(self, valid_data_loader, criterion):
self.model.train()
return total_loss

def forecast_fit(self, train_valid_data: pd.DataFrame, train_ratio_in_tv: float) -> "ModelBase":
def forecast_fit(
self,
train_valid_data: pd.DataFrame,
covariates: Optional[Dict],
train_ratio_in_tv: float,
) -> "ModelBase":
"""
Train the model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,21 +262,22 @@ def validate(
def forecast_fit(
self,
train_valid_data: pd.DataFrame,
covariates: Optional[Dict],
covariates: dict,
train_ratio_in_tv: float,
) -> "ModelBase":
"""
Train the model.
:param train_valid_data: Time series data used for training.
:param covariates: Additional external variables
:param covariates: Additional external variables.
:param train_ratio_in_tv: Represents the splitting ratio between the training set and the validation set. If it is equal to 1, it means that the validation set is not partitioned.
:return: The fitted model object.
"""
exog_dim = -1
if covariates["exog"] is not None:
exog_dim = covariates["exog"].shape[-1]
train_valid_data = pd.concat([train_valid_data, covariates["exog"]], axis=1)
exog_data = covariates.get("exog")
if exog_data is not None:
exog_dim = exog_data.shape[-1]
train_valid_data = pd.concat([train_valid_data, exog_data], axis=1)

if train_valid_data.shape[1] == 1:
train_drop_last = False
Expand Down Expand Up @@ -396,36 +397,40 @@ def forecast_fit(
adjust_learning_rate(optimizer, epoch + 1, config)

def forecast(
self, horizon: int, covariates: Optional[Dict], train: pd.DataFrame
self, horizon: int, covariates: dict, series: pd.DataFrame
) -> np.ndarray:
"""
Make predictions.
:param horizon: The length of the forecast horizon (number of time steps to predict).
:param covariates: Additional external variables
:param testdata: Time series data used for prediction.
:param series: Time series data used for prediction.
:return: An array of predicted results.
"""
exog_dim = -1
if covariates["exog"] is not None:
exog_dim = covariates["exog"].shape[-1]
train = pd.concat([train, covariates["exog"]], axis=1)
series = pd.concat([series, covariates["exog"]], axis=1)
if exog_dim != -1 and horizon != self.config.output_chunk_length:
raise ValueError(
f"Error: 'exog' is enabled during training, but horizon ({horizon}) != output_chunk_length ({self.config.output_chunk_length}) during forecast."
)

if self.early_stopping.check_point is not None:
self.model.load_state_dict(self.early_stopping.check_point)

if self.config.norm:
train = pd.DataFrame(
self.scaler.transform(train.values),
columns=train.columns,
index=train.index,
series = pd.DataFrame(
self.scaler.transform(series.values),
columns=series.columns,
index=series.index,
)

if self.model is None:
raise ValueError("Model not trained. Call the fit() function first.")

config = self.config
train, test = split_time(train, len(train) - config.seq_len)
series, test = split_time(series, len(series) - config.seq_len)

# Generate the additional timestamp marks required by transformer-class methods
test = self.padding_data_for_forecast(test)
Expand Down Expand Up @@ -517,6 +522,10 @@ def batch_forecast(
input_np = np.concatenate(
(input_np, input_data["covariates"]["exog"]), axis=2
)
if exog_dim != -1 and horizon != self.config.output_chunk_length:
raise ValueError(
f"Error: 'exog' is enabled during training, but horizon ({horizon}) != output_chunk_length ({self.config.output_chunk_length}) during forecast."
)

if self.config.norm:
origin_shape = input_np.shape
Expand Down
10 changes: 3 additions & 7 deletions ts_benchmark/evaluation/strategy/rolling_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
self,
series: pd.DataFrame,
index_list: List[int],
covariates: Optional[Dict] = None,
covariates: dict,
):
self.series = series
self.index_list = index_list
Expand Down Expand Up @@ -294,12 +294,8 @@ def _eval_sample(
all_rolling_predict = []
for i, index in itertools.islice(enumerate(index_list), num_rollings):
train, rest = split_time(series, index)
target_train, exog_train = split_channel(train)
test, _ = split_channel(
split_time(rest, horizon)[0].iloc[
:, : target_train_valid_data.shape[-1]
]
)
test, _ = split_channel(split_time(rest, horizon)[0], target_channel)
target_train, exog_train = split_channel(train, target_channel)
covariates = {"exog": exog_train}

start_inference_time = time.time()
Expand Down
5 changes: 2 additions & 3 deletions ts_benchmark/models/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def forecast_fit(
self,
train_data: pd.DataFrame,
*,
covariates: Optional[Dict] = None,
covariates: dict,
train_ratio_in_tv: float = 1.0,
**kwargs
) -> "ModelBase":
Expand All @@ -75,9 +75,8 @@ def forecast(
self,
horizon: int,
*,
covariates: Optional[Dict] = None,
series: pd.DataFrame,
**kwargs
covariates: dict,
) -> np.ndarray:
"""
Forecasting with the model
Expand Down
1 change: 1 addition & 0 deletions ts_benchmark/utils/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def split_channel(
:param df: Input DataFrame to split.
:param target_channel: Rules for selecting target columns. Can include:
- Integers (positive/negative) for single column indices.
- Lists/tuples of two integers representing slices (e.g., `[2,4]` selects columns 2-3).
- If `None`, all columns are treated as target columns (exog becomes None).
Expand Down

0 comments on commit 32332a3

Please sign in to comment.