Skip to content

Commit

Permalink
download orderbook data (#1754)
Browse files Browse the repository at this point in the history
* download orderbook data

* fix CI error

* fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* optimize get_data code

* optimize get_data code

* optimize get_data code

* optimize README

---------

Co-authored-by: Linlang <[email protected]>
  • Loading branch information
SunsetWolf and Linlang authored Mar 7, 2024
1 parent 98f569e commit 39f88da
Show file tree
Hide file tree
Showing 14 changed files with 30 additions and 34 deletions.
3 changes: 0 additions & 3 deletions examples/benchmarks/TRA/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ def predict(self, dataset, segment="test"):


class LSTM(nn.Module):

"""LSTM Model
Args:
Expand Down Expand Up @@ -414,7 +413,6 @@ def forward(self, x):


class Transformer(nn.Module):

"""Transformer Model
Args:
Expand Down Expand Up @@ -475,7 +473,6 @@ def forward(self, x):


class TRA(nn.Module):

"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction errors & latent representation as inputs,
Expand Down
5 changes: 1 addition & 4 deletions examples/orderbook_data/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,11 @@ pip install arctic # NOTE: pip may fail to resolve the right package dependency
2. Please follow following steps to download example data
```bash
cd examples/orderbook_data/
wget http://fintech.msra.cn/stock_data/downloads/highfreq_orderboook_example_data.tar.bz2
tar xf highfreq_orderboook_example_data.tar.bz2
python ../../scripts/get_data.py download_data --target_dir . --file_name highfreq_orderbook_example_data.zip
```

3. Please import the example data to your mongo db
```bash
cd examples/orderbook_data/
python create_dataset.py initialize_library # Initialization Libraries
python create_dataset.py import_data # Initialization Libraries
```
Expand All @@ -42,7 +40,6 @@ python create_dataset.py import_data # Initialization Libraries

After importing these data, you run `example.py` to create some high-frequency features.
```bash
cd examples/orderbook_data/
pytest -s --disable-warnings example.py # If you want run all examples
pytest -s --disable-warnings example.py::TestClass::test_exp_10 # If you want to run specific example
```
Expand Down
16 changes: 9 additions & 7 deletions qlib/backtest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,15 @@ def create_account_instance(
init_cash=init_cash,
position_dict=position_dict,
pos_type=pos_type,
benchmark_config={}
if benchmark is None
else {
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
},
benchmark_config=(
{}
if benchmark is None
else {
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
}
),
)


Expand Down
8 changes: 5 additions & 3 deletions qlib/backtest/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,9 +622,11 @@ def cal_trade_indicators(
print(
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
freq,
trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
(
trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S")
),
fulfill_rate,
price_advantage,
positive_rate,
Expand Down
1 change: 1 addition & 0 deletions qlib/contrib/eva/alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
The interface should be redesigned carefully in the future.
"""

import pandas as pd
from typing import Tuple
from qlib import get_module_logger
Expand Down
3 changes: 0 additions & 3 deletions qlib/contrib/model/pytorch_tra.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,6 @@ def predict(self, dataset, segment="test"):


class RNN(nn.Module):

"""RNN Model
Args:
Expand Down Expand Up @@ -601,7 +600,6 @@ def forward(self, x):


class Transformer(nn.Module):

"""Transformer Model
Args:
Expand Down Expand Up @@ -649,7 +647,6 @@ def forward(self, x):


class TRA(nn.Module):

"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction errors & latent representation as inputs,
Expand Down
1 change: 0 additions & 1 deletion qlib/contrib/strategy/signal_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,6 @@ def generate_trade_decision(self, execute_result=None):


class EnhancedIndexingStrategy(WeightStrategyBase):

"""Enhanced Indexing Strategy
Enhanced indexing combines the arts of active management and passive management,
Expand Down
2 changes: 0 additions & 2 deletions qlib/model/ens/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __call__(self, ensemble_dict: dict, *args, **kwargs):


class SingleKeyEnsemble(Ensemble):

"""
Extract the object if there is only one key and value in the dict. Make the result more readable.
{Only key: Only value} -> Only value
Expand Down Expand Up @@ -64,7 +63,6 @@ def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -


class RollingEnsemble(Ensemble):

"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".
Expand Down
4 changes: 1 addition & 3 deletions qlib/model/riskmodel/shrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,7 @@ def _get_shrink_param_lw_single_factor(self, X: np.ndarray, S: np.ndarray, F: np
v1 = y.T.dot(z) / t - cov_mkt[:, None] * S
roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt
v3 = z.T.dot(z) / t - var_mkt * S
roff3 = (
np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
)
roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
roff = 2 * roff1 - roff3
rho = rdiag + roff

Expand Down
1 change: 0 additions & 1 deletion qlib/workflow/online/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def get_collector(self) -> Collector:


class RollingStrategy(OnlineStrategy):

"""
This example strategy always uses the latest rolling model sas online models.
"""
Expand Down
4 changes: 1 addition & 3 deletions scripts/dump_bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,7 @@ def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
return (
self._include_fields
if self._include_fields
else set(df_columns) - set(self._exclude_fields)
if self._exclude_fields
else df_columns
else set(df_columns) - set(self._exclude_fields) if self._exclude_fields else df_columns
)

@staticmethod
Expand Down
8 changes: 5 additions & 3 deletions scripts/dump_pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,11 @@ def get_dump_fields(self, df: Iterable[str]) -> Iterable[str]:
return (
set(self._include_fields)
if self._include_fields
else set(df[self.field_column_name]) - set(self._exclude_fields)
if self._exclude_fields
else set(df[self.field_column_name])
else (
set(df[self.field_column_name]) - set(self._exclude_fields)
if self._exclude_fields
else set(df[self.field_column_name])
)
)

def get_filenames(self, symbol, field, interval):
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def get_version(rel_path: str) -> str:
# To ensure stable operation of the experiment manager, we have limited the version of mlflow,
# and we need to verify whether version 2.0 of mlflow can serve qlib properly.
"mlflow>=1.12.1, <=1.30.0",
# mlflow 1.30.0 requires packaging<22, so we limit the packaging version, otherwise the CI will fail.
"packaging<22",
"tqdm",
"loguru",
"lightgbm>=3.3.0",
Expand Down
6 changes: 5 additions & 1 deletion tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@


class WorkflowTest(TestAutoData):
TMP_PATH = Path("./.mlruns_tmp/")
# Creating the directory manually doesn't work with mlflow,
# so we add a subfolder named .trash when we create the directory.
TMP_PATH = Path("./.mlruns_tmp/.trash")

def tearDown(self) -> None:
if self.TMP_PATH.exists():
shutil.rmtree(self.TMP_PATH)

def test_get_local_dir(self):
""" """
self.TMP_PATH.mkdir(parents=True, exist_ok=True)

with R.start(uri=str(self.TMP_PATH)):
pass

Expand Down

0 comments on commit 39f88da

Please sign in to comment.