predict() function definition lacks an optional argument "segment" #1821

Open
teancake opened this issue Jul 3, 2024 · 0 comments
Labels
question Further information is requested


teancake commented Jul 3, 2024

The definition of the predict() function

def predict(self, dataset):

in the files

qlib/contrib/model/pytorch_lstm_ts.py
qlib/contrib/model/pytorch_tcn_ts.py

lacks the optional argument "segment" that is defined in the base class

class Model(BaseModel):

    def predict(self, dataset: Dataset, segment: Union[Text, slice] = "test") -> object:

As a result, predict() in the LSTM and TCN models fails with a TypeError (unexpected keyword argument) when the prediction segment is passed explicitly. This makes these models incompatible with other models such as LGBModel, e.g.

model.predict(dataset=dataset, segment="pred")
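To show why that call fails, here is a minimal, self-contained sketch of the signature mismatch. `ModelSketch`, `LSTMLikeModel` and `LGBLikeModel` are hypothetical stand-ins, not the actual qlib classes. The helper `accepts_segment` is likewise a hypothetical name: it uses the standard-library `inspect.signature` to flag overrides that drop the `segment` parameter.

```python
import inspect
from typing import Text, Union


class ModelSketch:
    """Hypothetical stand-in for qlib's Model base class."""

    def predict(self, dataset, segment: Union[Text, slice] = "test"):
        return f"predicting on {segment}"


class LSTMLikeModel(ModelSketch):
    """Hypothetical stand-in for the LSTM/TCN models: the override drops `segment`."""

    def predict(self, dataset):
        return "predicting on test"


class LGBLikeModel(ModelSketch):
    """Hypothetical stand-in for LGBModel: the override keeps `segment`."""

    def predict(self, dataset, segment: Union[Text, slice] = "test"):
        return f"predicting on {segment}"


def accepts_segment(model_cls) -> bool:
    """Return True if model_cls.predict accepts a `segment` keyword argument."""
    params = inspect.signature(model_cls.predict).parameters
    # An explicit `segment` parameter or a **kwargs catch-all both accept the keyword
    return "segment" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )


if __name__ == "__main__":
    for cls in (LGBLikeModel, LSTMLikeModel):
        try:
            print(cls.__name__, cls().predict(dataset=None, segment="pred"))
        except TypeError as err:
            # The override without `segment` rejects the keyword argument
            print(cls.__name__, "TypeError:", err)
        print(cls.__name__, "accepts segment:", accepts_segment(cls))
```

The same `inspect.signature` check could be run over the real model classes to find other overrides with the same problem.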

Possible revision:

    def predict(self, dataset: Dataset, segment: Union[Text, slice] = "test") -> object:
        if not self.fitted:
            raise ValueError("model is not fitted yet!")
        # Prepare the segment requested by the caller instead of the hard-coded "test"
        dl_test = dataset.prepare(segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
        # ... remainder of the method unchanged
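The effect of the revision can be checked without torch or real data using the minimal sketch below. It assumes hypothetical stand-ins: `RecordingDataset` only records which segment `prepare()` was asked for, and `PatchedTSModel` mirrors the revised head of `predict()`. The DataLoader and forward pass are omitted, and the string `"infer"` is a placeholder for `DataHandlerLP.DK_I`.

```python
from typing import Text, Union


class RecordingDataset:
    """Hypothetical stand-in for a qlib dataset: records which segment was prepared."""

    def __init__(self):
        self.prepared = []

    def prepare(self, segment, col_set=None, data_key=None):
        self.prepared.append(segment)
        return []  # a real dataset would return the sampled data here


class PatchedTSModel:
    """Hypothetical sketch of the revised predict(); not the actual qlib class."""

    DK_I = "infer"  # placeholder for DataHandlerLP.DK_I

    def __init__(self, fitted: bool = True):
        self.fitted = fitted

    def predict(self, dataset, segment: Union[Text, slice] = "test"):
        if not self.fitted:
            raise ValueError("model is not fitted yet!")
        # Forward the caller's segment instead of hard-coding "test"
        dl_test = dataset.prepare(segment, col_set=["feature", "label"], data_key=self.DK_I)
        return dl_test  # the real method would batch dl_test and run the network


if __name__ == "__main__":
    ds = RecordingDataset()
    model = PatchedTSModel()
    model.predict(ds)                    # default segment
    model.predict(ds, segment="pred")    # explicit segment, as with LGBModel
    model.predict(ds, segment=slice("2020-01-01", "2020-12-31"))  # slice form
    print(ds.prepared)
```

Keeping `"test"` as the default preserves the current behaviour for existing callers, while explicit segments now reach `dataset.prepare()`.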

teancake added the "question" label on Jul 3, 2024