Update Version 3.2.8
shinny-pack authored and shinny-mayanqiong committed Apr 29, 2022
1 parent 43ccfa2 commit 6d48354
Showing 7 changed files with 104 additions and 92 deletions.
2 changes: 1 addition & 1 deletion PKG-INFO
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: tqsdk
Version: 3.2.7
Version: 3.2.8
Summary: TianQin SDK
Home-page: https://www.shinnytech.com/tqsdk
Author: TianQin
4 changes: 2 additions & 2 deletions doc/conf.py
@@ -48,9 +48,9 @@
# built documents.
#
# The short X.Y version.
version = u'3.2.7'
version = u'3.2.8'
# The full version, including alpha/beta/rc tags.
release = u'3.2.7'
release = u'3.2.8'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
6 changes: 6 additions & 0 deletions doc/version.rst
@@ -2,6 +2,12 @@

Version Changes
=============================
3.2.8 (2022/04/29)

* Fix: data could be incompletely received when downloading klines for multiple contracts
* Fix: support calling the :py:meth:`~tqsdk.TqApi.get_kline_data_series` and :py:meth:`~tqsdk.TqApi.get_tick_data_series` interfaces from multiple processes


3.2.7 (2022/04/22)

* Improvement: add possible error hints for multi-threaded use cases
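For context, a minimal sketch (not part of this commit) of the multi-process usage that the second fix above targets, assuming the documented signature `get_kline_data_series(symbol, duration_seconds, start_dt, end_dt)`; the credentials, contract symbol, and date range are placeholders:

```python
# Hypothetical example: each worker process creates its own TqApi instance
# and requests the same locally cached kline series. Before 3.2.8, concurrent
# access to the cache files could conflict; the FileLock added in this commit
# serializes it.
from datetime import datetime
from multiprocessing import Process

from tqsdk import TqApi, TqAuth


def fetch_klines():
    api = TqApi(auth=TqAuth("user", "password"))  # placeholder credentials
    df = api.get_kline_data_series(
        symbol="SHFE.cu2209",  # placeholder contract
        duration_seconds=60,
        start_dt=datetime(2022, 4, 1),
        end_dt=datetime(2022, 4, 28),
    )
    print(df.shape)
    api.close()


if __name__ == "__main__":
    workers = [Process(target=fetch_klines) for _ in range(2)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
```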
4 changes: 2 additions & 2 deletions setup.py
@@ -36,7 +36,7 @@ def get_tag(self):

setuptools.setup(
name='tqsdk',
version="3.2.7",
version="3.2.8",
description='TianQin SDK',
author='TianQin',
author_email='[email protected]',
@@ -47,7 +47,7 @@ def get_tag(self):
zip_safe=False,
python_requires='>=3.6.4',
install_requires=["websockets>=8.1", "requests", "numpy", "pandas>=1.1.0", "scipy", "simplejson", "aiohttp",
"certifi", "pyjwt", "psutil", "shinny_structlog", "sgqlc"],
"certifi", "pyjwt", "psutil", "shinny_structlog", "sgqlc", "filelock"],
cmdclass={'bdist_wheel': bdist_wheel},
classifiers=[
"Programming Language :: Python :: 3",
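The only dependency change is the new `filelock` package, which backs the cross-process locking introduced in `data_series.py` below. A minimal sketch of the part of its API used there (`timeout=-1` means block indefinitely until the lock is acquired); the lock path is a placeholder:

```python
from filelock import FileLock

# A lock file guards a shared resource across processes.
lock = FileLock("/tmp/example.lock", timeout=-1)  # placeholder path
with lock:
    # Only one process at a time executes this block; others wait on
    # the lock file until it is released.
    pass
```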
2 changes: 1 addition & 1 deletion tqsdk/__version__.py
@@ -1 +1 @@
__version__ = '3.2.7'
__version__ = '3.2.8'
169 changes: 88 additions & 81 deletions tqsdk/data_series.py
@@ -8,11 +8,12 @@

import numpy as np
import pandas as pd
from filelock import FileLock

from tqsdk.channel import TqChan
from tqsdk.diff import _get_obj
from tqsdk.rangeset import _rangeset_difference, _rangeset_intersection
from tqsdk.tafunc import get_dividend_df, get_dividend_factor
from tqsdk.tafunc import get_dividend_df
from tqsdk.utils import _generate_uuid

CACHE_DIR = os.path.join(os.path.expanduser('~'), ".tqsdk/data_series")
@@ -73,91 +74,93 @@ def __init__(self, api, symbol_list, dur_nano, start_dt_nano, end_dt_nano, adj_t

async def _run(self):
symbol = self._symbol_list[0] # todo: currently only handles the single-contract case
# Check the cache files and compute the data segments that need downloading
rangeset_id = DataSeries._get_rangeset_id(symbol, self._dur_nano)
rangeset_dt = DataSeries._get_rangeset_dt(symbol, self._dur_nano, rangeset_id)
DataSeries._assert_rangeset_asce_sorted(rangeset_id)
DataSeries._assert_rangeset_asce_sorted(rangeset_dt)
# Compute the difference set. todo: the part covered by rangeset_id[0][0] ~ rangeset_id[-1][1] of the downloaded data can be fetched by id range; the uncovered part is fetched by dt range
# The last kline of the last range in rangeset_dt must be re-downloaded, because if the user-supplied end_dt_nano is a future time that kline's data may be stale
if len(rangeset_dt) > 0:
rangeset_dt[-1] = (rangeset_dt[-1][0], rangeset_dt[-1][1] - self._dur_nano)
if rangeset_dt[-1][0] == rangeset_dt[-1][1]:
# If the last range contains only that last kline, drop the range; the last kline needs to be re-downloaded
rangeset_dt.pop(-1)
diff_rangeset = _rangeset_difference([(self._start_dt_nano, self._end_dt_nano)], rangeset_dt)

# Download the data and wait for all of it to complete
if len(diff_rangeset) > 0:
await self._download_data_series(diff_rangeset)
self._merge_rangeset() # merge cache files
lock_path = DataSeries._get_lock_path(symbol, self._dur_nano)
with FileLock(lock_path, timeout=-1):
# Check the cache files and compute the data segments that need downloading
rangeset_id = DataSeries._get_rangeset_id(symbol, self._dur_nano)
rangeset_dt = DataSeries._get_rangeset_dt(symbol, self._dur_nano, rangeset_id)
DataSeries._assert_rangeset_asce_sorted(rangeset_id)
DataSeries._assert_rangeset_asce_sorted(rangeset_dt)

assert len(rangeset_id) == len(rangeset_dt) > 0

# Find the intersection of the user-requested time range and the existing data ranges
target_rangeset_dt = _rangeset_intersection([(self._start_dt_nano, self._end_dt_nano)], rangeset_dt)
assert len(target_rangeset_dt) <= 1 # the user request should fall within a single range, or the requested range contains no data at all
if len(target_rangeset_dt) == 0: # no data at all within the user-requested range
# Compute the difference set. todo: the part covered by rangeset_id[0][0] ~ rangeset_id[-1][1] of the downloaded data can be fetched by id range; the uncovered part is fetched by dt range
# The last kline of the last range in rangeset_dt must be re-downloaded, because if the user-supplied end_dt_nano is a future time that kline's data may be stale
if len(rangeset_dt) > 0:
rangeset_dt[-1] = (rangeset_dt[-1][0], rangeset_dt[-1][1] - self._dur_nano)
if rangeset_dt[-1][0] == rangeset_dt[-1][1]:
# If the last range contains only that last kline, drop the range; the last kline needs to be re-downloaded
rangeset_dt.pop(-1)
diff_rangeset = _rangeset_difference([(self._start_dt_nano, self._end_dt_nano)], rangeset_dt)

# Download the data and wait for all of it to complete
if len(diff_rangeset) > 0:
await self._download_data_series(diff_rangeset)
self._merge_rangeset() # merge cache files
rangeset_id = DataSeries._get_rangeset_id(symbol, self._dur_nano)
rangeset_dt = DataSeries._get_rangeset_dt(symbol, self._dur_nano, rangeset_id)
DataSeries._assert_rangeset_asce_sorted(rangeset_id)
DataSeries._assert_rangeset_asce_sorted(rangeset_dt)

assert len(rangeset_id) == len(rangeset_dt) > 0

# Find the intersection of the user-requested time range and the existing data ranges
target_rangeset_dt = _rangeset_intersection([(self._start_dt_nano, self._end_dt_nano)], rangeset_dt)
assert len(target_rangeset_dt) <= 1 # the user request should fall within a single range, or the requested range contains no data at all
if len(target_rangeset_dt) == 0: # no data at all within the user-requested range
self.is_ready = True
return

# At this point the user-requested time range has been reduced to target_rangeset_dt[0]
# Find the range_id corresponding to self._start_dt_nano, self._end_dt_nano
(start_dt, end_dt) = target_rangeset_dt[0]
range_id = None
for index in range(len(rangeset_dt)):
range_dt = rangeset_dt[index]
if range_dt[0] <= start_dt and end_dt <= range_dt[1]:
range_id = rangeset_id[index]
break
assert range_id

# Return a dataframe; the target file is
filename = os.path.join(CACHE_DIR, f"{symbol}.{self._dur_nano}.{range_id[0]}.{range_id[1]}")
data_cols = DataSeries._get_data_cols(symbol=self._symbol_list[0], dur_nano=self._dur_nano)
dtype = np.dtype([('id', 'i8'), ('datetime', 'i8')] + [(col, 'f8') for col in data_cols])
fp = np.memmap(filename, dtype=dtype, mode='r', shape=range_id[1] - range_id[0])

# the id range corresponding to target_rangeset_dt[0] is [start_id, end_id)
start_id = fp[fp['datetime'] <= start_dt][-1]["id"]
end_id = fp[fp['datetime'] < end_dt][-1]["id"]
rows = end_id - start_id + 1 # number of rows to read
array = fp[start_id - range_id[0]: start_id - range_id[0] + rows]

# Assign column by column: df.append returns a new df, and the df built at init may (under async code) already have been returned to user code; df.update would make every updated column's dtype object
self.df["id"] = array["id"]
self.df["datetime"] = array["datetime"]
for c in data_cols:
self.df[c] = array[c]
self.df["symbol"] = symbol
self.df["duration"] = self._dur_nano

# Adjustment: for STOCK / FUND and adj_type is not None, the adjustment factors for the download period must be prepared here in advance
quote = self._api.get_quote(symbol)
if self._adj_type and quote.ins_class in ["STOCK", "FUND"]:
factor_df = await self._update_dividend_factor(symbol) # adjustment derives ex-dividend factors from daily klines; todo: if daily klines were already downloaded locally, fetching them from the server again is unnecessary
if self._adj_type == "F":
# Loop over factor_df in reverse; rows with datetime less than the current factor_df[datetime] are multiplied by factor_df[factor]
for i in range(factor_df.shape[0] - 1, -1, -1):
dt = factor_df.iloc[i].datetime
factor = factor_df.iloc[i].factor
adj_cols = DataSeries._get_adj_cols(symbol, self._dur_nano)
lt = self.df["datetime"].lt(dt)
self.df.loc[lt, adj_cols] = self.df.loc[lt, adj_cols] * factor
if self._adj_type == "B":
# Loop over factor_df in order; rows with datetime greater than or equal to the current factor_df[datetime] are multiplied by 1 / factor_df[factor]
for i in range(factor_df.shape[0]):
dt = factor_df.iloc[i].datetime
factor = factor_df.iloc[i].factor
adj_cols = DataSeries._get_adj_cols(symbol, self._dur_nano)
ge = self.df["datetime"].ge(dt)
self.df.loc[ge, adj_cols] = self.df.loc[ge, adj_cols] / factor
# final state
self.is_ready = True
return

# At this point the user-requested time range has been reduced to target_rangeset_dt[0]
# Find the range_id corresponding to self._start_dt_nano, self._end_dt_nano
(start_dt, end_dt) = target_rangeset_dt[0]
range_id = None
for index in range(len(rangeset_dt)):
range_dt = rangeset_dt[index]
if range_dt[0] <= start_dt and end_dt <= range_dt[1]:
range_id = rangeset_id[index]
break
assert range_id

# Return a dataframe; the target file is
filename = os.path.join(CACHE_DIR, f"{symbol}.{self._dur_nano}.{range_id[0]}.{range_id[1]}")
data_cols = DataSeries._get_data_cols(symbol=self._symbol_list[0], dur_nano=self._dur_nano)
dtype = np.dtype([('id', 'i8'), ('datetime', 'i8')] + [(col, 'f8') for col in data_cols])
fp = np.memmap(filename, dtype=dtype, mode='r', shape=range_id[1] - range_id[0])

# the id range corresponding to target_rangeset_dt[0] is [start_id, end_id)
start_id = fp[fp['datetime'] <= start_dt][-1]["id"]
end_id = fp[fp['datetime'] < end_dt][-1]["id"]
rows = end_id - start_id + 1 # number of rows to read
array = fp[start_id - range_id[0]: start_id - range_id[0] + rows]

# Assign column by column: df.append returns a new df, and the df built at init may (under async code) already have been returned to user code; df.update would make every updated column's dtype object
self.df["id"] = array["id"]
self.df["datetime"] = array["datetime"]
for c in data_cols:
self.df[c] = array[c]
self.df["symbol"] = symbol
self.df["duration"] = self._dur_nano

# Adjustment: for STOCK / FUND and adj_type is not None, the adjustment factors for the download period must be prepared here in advance
quote = self._api.get_quote(symbol)
if self._adj_type and quote.ins_class in ["STOCK", "FUND"]:
factor_df = await self._update_dividend_factor(symbol) # adjustment derives ex-dividend factors from daily klines; todo: if daily klines were already downloaded locally, fetching them from the server again is unnecessary
if self._adj_type == "F":
# Loop over factor_df in reverse; rows with datetime less than the current factor_df[datetime] are multiplied by factor_df[factor]
for i in range(factor_df.shape[0] - 1, -1, -1):
dt = factor_df.iloc[i].datetime
factor = factor_df.iloc[i].factor
adj_cols = DataSeries._get_adj_cols(symbol, self._dur_nano)
lt = self.df["datetime"].lt(dt)
self.df.loc[lt, adj_cols] = self.df.loc[lt, adj_cols] * factor
if self._adj_type == "B":
# Loop over factor_df in order; rows with datetime greater than or equal to the current factor_df[datetime] are multiplied by 1 / factor_df[factor]
for i in range(factor_df.shape[0]):
dt = factor_df.iloc[i].datetime
factor = factor_df.iloc[i].factor
adj_cols = DataSeries._get_adj_cols(symbol, self._dur_nano)
ge = self.df["datetime"].ge(dt)
self.df.loc[ge, adj_cols] = self.df.loc[ge, adj_cols] / factor
# final state
self.is_ready = True

async def _download_data_series(self, rangeset):
symbol = self._symbol_list[0]
@@ -382,6 +385,10 @@ def _ensure_cache_dir():
if not os.path.exists(CACHE_DIR):
os.makedirs(CACHE_DIR, exist_ok=True)

@staticmethod
def _get_lock_path(symbol, dur_nano):
return os.path.join(CACHE_DIR, f".{symbol}.{dur_nano}.lock")

@staticmethod
def _get_data_cols(symbol, dur_nano):
# get the data columns
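As an aside, the forward-adjustment loop in `_run` above is easier to follow with toy numbers. A self-contained sketch with made-up data (not tqsdk code), mirroring the same pandas operations:

```python
import pandas as pd

# Four bars at a constant price, with one dividend event at datetime=3
# whose ex-dividend factor is 0.5.
df = pd.DataFrame({"datetime": [1, 2, 3, 4], "close": [10.0] * 4})
factor_df = pd.DataFrame({"datetime": [3], "factor": [0.5]})
adj_cols = ["close"]

# Forward adjustment ("F"): iterate factors newest-first and scale every
# bar strictly before the dividend datetime by that factor.
for i in range(factor_df.shape[0] - 1, -1, -1):
    dt = factor_df.iloc[i].datetime
    factor = factor_df.iloc[i].factor
    lt = df["datetime"].lt(dt)
    df.loc[lt, adj_cols] = df.loc[lt, adj_cols] * factor

print(df["close"].tolist())  # [5.0, 5.0, 10.0, 10.0]
```

Backward adjustment ("B") is the mirror image: bars at or after each dividend datetime are divided by the factor instead.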
9 changes: 4 additions & 5 deletions tqsdk/tools/downloader.py
@@ -303,13 +303,12 @@ async def _download_data(self):
continue
left_id = chart.get("left_id", -1)
right_id = chart.get("right_id", -1)
if (left_id == -1 and right_id == -1) or self._api._data.get("mdhis_more_data", True):
if (left_id == -1 and right_id == -1) or chart.get("more_data", True):
# positioning info not yet received, or the data series not yet received
continue
for serial in serials:
# check whether the contract's data has been received
if serial.get("last_id", -1) == -1:
continue
# check whether every contract's data has been received
if any([serial.get("last_id", -1) == -1 for serial in serials]):
continue
if current_id is None:
current_id = max(left_id, 0)
while current_id <= right_id:
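The downloader change is what fixes the first changelog entry. In the old form, `continue` only skipped to the next serial in the inner loop, so the download step still ran even when some contract's data had not arrived; the new `any(...)` check skips the whole step instead. A detached sketch of the difference:

```python
# One serial has data, the other does not (no "last_id" key yet).
serials = [{"last_id": 100}, {}]

# Old form: the loop body is just the check, so the `continue` is a no-op
# for anything outside the loop, and processing proceeded with incomplete data.
for serial in serials:
    if serial.get("last_id", -1) == -1:
        continue  # only skips this serial, not the surrounding step

# New form: treat the step as not ready until every serial reports data.
ready = not any(serial.get("last_id", -1) == -1 for serial in serials)
print(ready)  # False, because the second serial has no last_id yet
```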
