From 58dadc5b1eb8adddb33ae7ae0273bfbc687fb619 Mon Sep 17 00:00:00 2001 From: Danglewood <85772166+deeleeramone@users.noreply.github.com> Date: Mon, 6 Jan 2025 19:31:46 -0800 Subject: [PATCH] some cleanup --- .../provider/utils/websockets/broadcast.py | 23 ++--- .../provider/utils/websockets/client.py | 12 +-- .../provider/utils/websockets/database.py | 74 +++++++-------- .../fmp/openbb_fmp/utils/websocket_client.py | 47 +++++++++- .../openbb_polygon/utils/websocket_client.py | 8 +- .../models/websocket_connection.py | 17 ++-- .../openbb_tiingo/utils/websocket_client.py | 90 +++++++++++++------ 7 files changed, 171 insertions(+), 100 deletions(-) diff --git a/openbb_platform/core/openbb_core/provider/utils/websockets/broadcast.py b/openbb_platform/core/openbb_core/provider/utils/websockets/broadcast.py index 0d53eddba84c..fdd6e9bf7214 100644 --- a/openbb_platform/core/openbb_core/provider/utils/websockets/broadcast.py +++ b/openbb_platform/core/openbb_core/provider/utils/websockets/broadcast.py @@ -4,27 +4,22 @@ import json import logging import os -import signal import sys from pathlib import Path from typing import Optional import uvicorn - from fastapi import FastAPI, WebSocket, WebSocketDisconnect from openbb_core.provider.utils.websockets.database import ( CHECK_FOR, Database, - kill_thread, ) from openbb_core.provider.utils.websockets.helpers import ( get_logger, - handle_termination_signal, parse_kwargs, ) from starlette.websockets import WebSocketState - kwargs = parse_kwargs() HOST = kwargs.pop("host", None) or "localhost" @@ -41,7 +36,7 @@ app = FastAPI() -CONNECTED_CLIENTS = set() +CONNECTED_CLIENTS: set = set() MAIN_CLIENT = None STDIN_TASK = None LOGGER = get_logger("broadcast-server") @@ -58,7 +53,7 @@ async def read_stdin(): continue if line.strip() == "numclients": - MAIN_CLIENT.logger.info( + MAIN_CLIENT.logger.info( # type: ignore "Number of connected clients: %i", len(CONNECTED_CLIENTS) ) continue @@ -101,7 +96,7 @@ async def websocket_endpoint( # noqa: PLR0915 str(AUTH_TOKEN), sql=sql, ) - broadcast_server.replay = replay + broadcast_server.replay = replay # type: ignore auth_token = str(auth_token) if sql and ( @@ -244,7 +239,7 @@ async def stream_results( # noqa: PLR0915 # pylint: disable=too-many-branches last_id = ( 0 - if hasattr(self, "replay") and self.replay is True or replay is True + if hasattr(self, "replay") and self.replay is True or replay is True # type: ignore else last_id ) @@ -260,10 +255,10 @@ async def stream_results( # noqa: PLR0915 # pylint: disable=too-many-branches any(x.lower() in sql.lower() for x in CHECK_FOR) or (self.table_name not in sql and "message" not in sql) ): - await self.websocket.accept() - await self.websocket.send_text("Invalid SQL query passed.") - await self.websocket.close(code=1008, reason="Invalid query") - self.logger.error( + await self.websocket.accept() # type: ignore + await self.websocket.send_text("Invalid SQL query passed.") # type: ignore + await self.websocket.close(code=1008, reason="Invalid query") # type: ignore + self.logger.error( # type: ignore "Invalid query passed to the stream_results method: %s", sql ) return @@ -291,7 +286,7 @@ async def stream_results( # noqa: PLR0915 # pylint: disable=too-many-branches await self.websocket.send_json( json.dumps(json.loads(row[1])) ) - if self.replay is True: + if self.replay is True: # type: ignore await asyncio.sleep(self.sleep_time / 10) await cursor.close() diff --git a/openbb_platform/core/openbb_core/provider/utils/websockets/client.py b/openbb_platform/core/openbb_core/provider/utils/websockets/client.py index 704e3aa7be82..c7dae8e53b29 100644 --- a/openbb_platform/core/openbb_core/provider/utils/websockets/client.py +++ b/openbb_platform/core/openbb_core/provider/utils/websockets/client.py @@ -233,12 +233,12 @@ def _atexit(self) -> None: self.logger.info("Websocket results saved to, %s\n", str(self.results_path)) if os.path.exists(self.results_file) and not self.save_database: # type: ignore os.remove(self.results_file) # type: ignore - if os.path.exists(self.results_file + "-journal"): - os.remove(self.results_file + "-journal") - if os.path.exists(self.results_file + "-shm"): - os.remove(self.results_file + "-shm") - if os.path.exists(self.results_file + "-wal"): - os.remove(self.results_file + "-wal") + if os.path.exists(self.results_file + "-journal"): # type: ignore + os.remove(self.results_file + "-journal") # type: ignore + if os.path.exists(self.results_file + "-shm"): # type: ignore + os.remove(self.results_file + "-shm") # type: ignore + if os.path.exists(self.results_file + "-wal"): # type: ignore + os.remove(self.results_file + "-wal") # type: ignore def _log_provider_output(self, output_queue) -> None: """Log output from the provider logger, handling exceptions, errors, and messages that are not data.""" diff --git a/openbb_platform/core/openbb_core/provider/utils/websockets/database.py b/openbb_platform/core/openbb_core/provider/utils/websockets/database.py index 36d418f3ab74..4e380c3b5531 100644 --- a/openbb_platform/core/openbb_core/provider/utils/websockets/database.py +++ b/openbb_platform/core/openbb_core/provider/utils/websockets/database.py @@ -166,10 +166,10 @@ def __init__( # pylint: disable=too-many-positional-arguments self.results_path = Path(results_file).absolute() self.results_file = results_file - if ( - " " in table_name - or table_name.isupper() - or any(x.lower() in table_name.lower() for x in CHECK_FOR) + if table_name and ( + " " in table_name # type: ignore + or table_name.isupper() # type: ignore + or any(x.lower() in table_name.lower() for x in CHECK_FOR) # type: ignore ): raise OpenBBError(ProgrammingError(f"Invalid table name, {table_name}.")) @@ -177,7 +177,7 @@ def __init__( # pylint: disable=too-many-positional-arguments self.limit = limit self.loop = loop self.kwargs = kwargs if kwargs else {} - self._connections = {} + self._connections: dict = {} run_async(self._setup_database) self.data_model = data_model @@ -242,29 +242,29 @@ async def get_connection(self, name: str = "read"): conn_kwargs = self.kwargs.copy() if name == "read": - if ":" not in self.results_file: - results_file = ( + if ":" not in self.results_file: # type: ignore + results_file = ( # type: ignore "file:" + ( - self.results_file - if self.results_file.startswith("/") - else "/" + self.results_file + self.results_file # type: ignore + if self.results_file.startswith("/") # type: ignore + else "/" + self.results_file # type: ignore ) + "?mode=ro" ) else: - results_file = ( - self.results_file - + f"{'&mode=ro' if '?' in self.results_file else '?mode=ro'}" + results_file = ( # type: ignore + self.results_file # type: ignore + + f"{'&mode=ro' if '?' in self.results_file else '?mode=ro'}" # type: ignore ) conn_kwargs["uri"] = True elif name == "write": - results_file = self.results_file + results_file = self.results_file # type: ignore conn_kwargs["check_same_thread"] = False if name not in self._connections: - conn = await aiosqlite.connect(results_file, **conn_kwargs) + conn = await aiosqlite.connect(results_file, **conn_kwargs) # type: ignore pragmas = [ "PRAGMA journal_mode=WAL", "PRAGMA synchronous=off", @@ -355,8 +355,8 @@ async def _fetch_all(self, limit: Optional[int] = None) -> list: query += " LIMIT ?" params = (limit,) else: - params = () - async with conn.execute(query, params) as cursor: + params = None + async with conn.execute(query, params) as cursor: # type: ignore async for row in cursor: rows.append(await self._deserialize_row(row, cursor)) @@ -444,6 +444,7 @@ def get_latest_results( f" {e.__class__.__name__ if hasattr(e, '__class__') else e} -> {e.args}" ) self.logger.error(msg) + return [] async def _query_db(self, sql, parameters: Optional[Iterable[Any]] = None) -> list: """Query the SQLite database.""" @@ -472,10 +473,11 @@ async def _query_db(self, sql, parameters: Optional[Iterable[Any]] = None) -> li ) as cursor: async for row in cursor: rows.append(await self._deserialize_row(row, cursor)) - return rows except Exception as e: # pylint: disable=broad-except raise OpenBBError(e) from e + return rows + def query(self, sql: str, parameters: Optional[Iterable[Any]] = None) -> list: """ Run a SELECT query to the database. @@ -654,7 +656,7 @@ def __init__( self._last_processed_timestamp = None self._conn = None self.num_workers = 60 - self.write_tasks = [] + self.write_tasks: list = [] self._export_running = False self._prune_running = False self.batch_processor = BatchProcessor(self) @@ -688,7 +690,7 @@ async def stop_writer(self): async def _process_queue(self): """Process queue with parallel writers.""" - batch = [] + batch: list = [] while self.writer_running: try: @@ -833,6 +835,8 @@ async def _export_database(self): for key in json.loads(row[0]): headers[key] = None + new_rows: list = [] + if self.compress_export: with gzip.open(path, "wt") as gz_file: writer = csv.DictWriter(gz_file, fieldnames=list(headers)) @@ -841,7 +845,7 @@ async def _export_database(self): while True: rows = await cursor.fetchmany(chunk_size) - new_rows: list = [] + if not rows: break @@ -864,7 +868,7 @@ async def _export_database(self): while True: rows = await cursor.fetchmany(chunk_size) - new_rows: list = [] + if not rows: break @@ -915,17 +919,17 @@ def start_export_task(self): return self._export_running = True - self.export_thread = threading.Thread( + self.export_thread = threading.Thread( # type: ignore target=self._run_export_event, name="ExportThread", daemon=True ) - self.export_thread.start() + self.export_thread.start() # type: ignore def stop_export_task(self): """Public method to stop the background export task.""" if hasattr(self, "export_thread") and self.export_thread: - self.export_thread.join(timeout=1) - if self.export_thread.is_alive(): - kill_thread(self.export_thread) + self.export_thread.join(timeout=1) # type: ignore + if self.export_thread.is_alive(): # type: ignore + kill_thread(self.export_thread) # type: ignore self._export_running = False self.export_thread = None @@ -1091,12 +1095,12 @@ def __init__( def run(self): """Run the batch processor as tasks.""" try: - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) + self.loop = asyncio.new_event_loop() # type: ignore + asyncio.set_event_loop(self.loop) # type: ignore # Create worker tasks while self.running and not self._shutdown.is_set(): try: - self.loop.run_until_complete(self._worker()) + self.loop.run_until_complete(self._worker()) # type: ignore except (SystemExit, KeyboardInterrupt): self.running = False break @@ -1112,19 +1116,19 @@ def stop(self): """Signal thread to stop gracefully.""" self.running = False self._shutdown.set() - if self.loop and self.loop.is_running(): - self.loop.call_soon_threadsafe(self.loop.stop) + if self.loop and self.loop.is_running(): # type: ignore + self.loop.call_soon_threadsafe(self.loop.stop) # type: ignore def _cleanup(self): """Clean up resources on shutdown""" if self.loop: - pending = asyncio.all_tasks(self.loop) + pending = asyncio.all_tasks(self.loop) # type: ignore for task in pending: task.cancel() - self.loop.run_until_complete( + self.loop.run_until_complete( # type: ignore asyncio.gather(*pending, return_exceptions=True) ) - self.loop.close() + self.loop.close() # type: ignore async def _worker(self): # pylint: disable=import-outside-toplevel diff --git a/openbb_platform/providers/fmp/openbb_fmp/utils/websocket_client.py b/openbb_platform/providers/fmp/openbb_fmp/utils/websocket_client.py index 98a3ff2febf9..b608ff7154e5 100644 --- a/openbb_platform/providers/fmp/openbb_fmp/utils/websocket_client.py +++ b/openbb_platform/providers/fmp/openbb_fmp/utils/websocket_client.py @@ -1,10 +1,38 @@ -"""FMP WebSocket client.""" +""" +FMP WebSocket Client. + +This file should be run as a script, and is intended to be run as a subprocess of FmpWebSocketFetcher. + +Keyword arguments are passed from the command line as space-delimited, `key=value`, pairs. + +Required Keyword Arguments +-------------------------- + api_key: str + The API key for the Polygon WebSocket. + asset_type: str + The asset type to subscribe to. Default is "crypto". + Options: "stock", "crypto", "fx" + symbol: str + The symbol to subscribe to. Example: "AAPL" or "AAPL,MSFT". + results_file: str + The path to the file where the results will be stored. + +Optional Keyword Arguments +-------------------------- + table_name: str + The name of the table to store the data in. Default is "records". + limit: int + The maximum number of rows to store in the database. + connect_kwargs: dict + Additional keyword arguments to pass directly to `websockets.connect()`. + Example: {"ping_timeout": 300} +""" import asyncio +import json import signal import sys -import orjson as json import websockets from openbb_core.provider.utils.websockets.database import Database, DatabaseWriter from openbb_core.provider.utils.websockets.helpers import ( @@ -31,6 +59,8 @@ CONNECT_KWARGS = kwargs.pop("connect_kwargs", {}) URL = URL_MAP.get(kwargs.pop("asset_type"), None) +SUBSCRIBED_SYMBOLS: set = set() + if not URL: raise ValueError("Invalid asset type provided.") @@ -87,6 +117,14 @@ async def subscribe(websocket, symbol, event): } try: await websocket.send(json.dumps(subscribe_event)) + + for t in ticker: + if event == "subscribe": + SUBSCRIBED_SYMBOLS.add(t) + else: + SUBSCRIBED_SYMBOLS.discard(t) + + kwargs["symbol"] = ",".join(SUBSCRIBED_SYMBOLS) except Exception as e: # pylint: disable=broad-except msg = f"PROVIDER ERROR: {e.__class__.__name__ if hasattr(e, '__class__') else e}: {e}" logger.error(msg) @@ -201,7 +239,10 @@ async def connect_and_stream(): sys.exit(1) except Exception as e: # pylint: disable=broad-except - msg = f"PROVIDER ERROR: Unexpected error -> {e.__class__.__name__}: {e}" + msg = ( + "PROVIDER ERROR: Unexpected error ->" + f" {e.__class__.__name__ if hasattr(e, '__class__') else e}: {e.args}" + ) logger.error(msg) sys.exit(1) diff --git a/openbb_platform/providers/polygon/openbb_polygon/utils/websocket_client.py b/openbb_platform/providers/polygon/openbb_polygon/utils/websocket_client.py index 0878d95bc114..dea05d4df59e 100644 --- a/openbb_platform/providers/polygon/openbb_polygon/utils/websocket_client.py +++ b/openbb_platform/providers/polygon/openbb_polygon/utils/websocket_client.py @@ -32,12 +32,10 @@ import asyncio import json -import os import signal import sys import time -import orjson import websockets from openbb_core.provider.utils.websockets.database import Database, DatabaseWriter from openbb_core.provider.utils.websockets.helpers import ( @@ -178,7 +176,7 @@ async def login(websocket): try: await websocket.send(login_event) res = await websocket.recv(decode=False) - response = orjson.loads(res) + response = json.loads(res) messages = response if isinstance(response, list) else [response] for msg in messages: if msg.get("status") == "connected": @@ -323,7 +321,7 @@ async def process_input_messages(message): def _process_in_thread(): global LAST_MINUTE_COUNT, MESSAGE_COUNT # pylint: disable=global-statement # noqa - message_data = orjson.loads(message) + message_data = json.loads(message) if isinstance(message_data, list): MESSAGE_COUNT += len(message_data) LAST_MINUTE_COUNT += len(message_data) @@ -387,7 +385,7 @@ def _process_in_thread(): response = await websocket.recv(decode=False) - await process_message(orjson.loads(response)) + await process_message(json.loads(response)) await subscribe(websocket, kwargs["symbol"], "subscribe") diff --git a/openbb_platform/providers/tiingo/openbb_tiingo/models/websocket_connection.py b/openbb_platform/providers/tiingo/openbb_tiingo/models/websocket_connection.py index 8cba93413693..989999e77747 100644 --- a/openbb_platform/providers/tiingo/openbb_tiingo/models/websocket_connection.py +++ b/openbb_platform/providers/tiingo/openbb_tiingo/models/websocket_connection.py @@ -92,7 +92,7 @@ class TiingoWebSocketQueryParams(WebSocketQueryParams): description="The asset type for the feed. Choices are 'stock', 'fx', or 'crypto'.", ) feed: Literal["trade", "trade_and_quote"] = Field( - default="trade_and_quote", + default="trade", description="The asset type associated with the symbol. Choices are 'trade' or 'trade_and_quote'." + " FX only supports quote.", ) @@ -102,6 +102,7 @@ class TiingoWebSocketData(WebSocketData): """Tiingo WebSocket data model.""" timestamp: Optional[datetime] = Field( + default=None, description="The timestamp of the data.", ) type: Literal["quote", "trade", "break"] = Field( @@ -242,28 +243,20 @@ async def aextract_data( from asyncio import sleep api_key = credentials.get("tiingo_token") if credentials else "" - threshold_level = ( - 5 - if query.asset_type == "fx" or query.feed == "trade" - else ( - 2 - if query.asset_type == "crypto" and query.feed == "trade_and_quote" - else 0 - ) - ) + symbol = query.symbol.lower() kwargs = { "api_key": api_key, "asset_type": query.asset_type, - "threshold_level": threshold_level, + "feed": query.feed, "connect_kwargs": query.connect_kwargs, } client = WebSocketClient( name=query.name, module="openbb_tiingo.utils.websocket_client", - symbol=symbol.lower(), + symbol=symbol, limit=query.limit, results_file=query.results_file, table_name=query.table_name, diff --git a/openbb_platform/providers/tiingo/openbb_tiingo/utils/websocket_client.py b/openbb_platform/providers/tiingo/openbb_tiingo/utils/websocket_client.py index ae213457219f..0049847d9a96 100644 --- a/openbb_platform/providers/tiingo/openbb_tiingo/utils/websocket_client.py +++ b/openbb_platform/providers/tiingo/openbb_tiingo/utils/websocket_client.py @@ -1,12 +1,41 @@ -"""FMP WebSocket server.""" +""" +Tiingo WebSocket Client. + +This file should be run as a script, and is intended to be run as a subprocess of TiingoWebSocketFetcher. + +Keyword arguments are passed from the command line as space-delimited, `key=value`, pairs. + +Required Keyword Arguments +-------------------------- + api_key: str + The API key for the Polygon WebSocket. + asset_type: str + The asset type to subscribe to. Default is "crypto". + Options: "stock", "crypto", "fx" + symbol: str + The symbol to subscribe to. Example: "AAPL" or "AAPL,MSFT". Use "*" to subscribe to all symbols. + feed: str + The feed to subscribe to. One of: "trade" or "trade_and_quote". + results_file: str + The path to the file where the results will be stored. + +Optional Keyword Arguments +-------------------------- + table_name: str + The name of the table to store the data in. Default is "records". + limit: int + The maximum number of rows to store in the database. + connect_kwargs: dict + Additional keyword arguments to pass directly to `websockets.connect()`. + Example: {"ping_timeout": 300} +""" import asyncio -import os +import json import signal import sys -from datetime import UTC, datetime +from datetime import datetime -import orjson import websockets from openbb_core.provider.utils.errors import UnauthorizedError from openbb_core.provider.utils.websockets.database import Database, DatabaseWriter @@ -20,6 +49,7 @@ from openbb_tiingo.models.websocket_connection import TiingoWebSocketData from pandas import to_datetime from pydantic import ValidationError +from pytz import UTC from websockets.asyncio.client import connect URL_MAP = { @@ -79,14 +109,28 @@ ] SUBSCRIPTION_ID = "" logger = get_logger("openbb.websocket.tiingo") -input_queue = MessageQueue(logger=logger, backoff_factor=0) -db_queue = MessageQueue(logger=logger, backoff_factor=0) +input_queue = MessageQueue(logger=logger) +db_queue = MessageQueue(logger=logger) kwargs = parse_kwargs() CONNECT_KWARGS = kwargs.pop("connect_kwargs", {}) -kwargs["results_file"] = os.path.abspath(kwargs["results_file"]) -URL = URL_MAP.get(kwargs.pop("asset_type", "crypto")) +ASSET_TYPE = kwargs.pop("asset_type", "crypto") +FEED = kwargs.pop("feed", "trade") + + SUBSCRIBED_SYMBOLS: set = set() +THRESHOLD_LEVEL = ( + 5 + if ASSET_TYPE == "fx" or FEED == "trade" + else (2 if ASSET_TYPE == "crypto" and FEED == "trade_and_quote" else 0) +) + +URL = URL_MAP.get(ASSET_TYPE) + +if not kwargs.get("api_key"): + raise ValueError("No API key provided.") + + if not URL: raise ValueError("Invalid asset type provided.") @@ -120,9 +164,9 @@ async def update_symbols(symbol, event): } async with connect(URL) as websocket: - await websocket.send(orjson.dumps(update_event)) + await websocket.send(json.dumps(update_event)) response = await websocket.recv(decode=False) - message = orjson.loads(response) + message = json.loads(response) if "tickers" in message.get("data", {}): tickers = message["data"]["tickers"] threshold_level = message["data"].get("thresholdLevel") @@ -155,7 +199,7 @@ async def read_stdin_and_update_symbols(): f" Database Queue : {db_queue.queue.qsize()}" ) else: - line = orjson.loads(line.strip()) + line = json.loads(line.strip()) if line: symbol = line.get("symbol") @@ -167,7 +211,7 @@ async def process_message(message): # pylint: disable=too-many-branches """Process the message and write to the database.""" result: dict = {} data_message: dict = {} - message = message if isinstance(message, (dict, list)) else orjson.loads(message) + message = message if isinstance(message, (dict, list)) else json.loads(message) msg: str = "" if message.get("messageType") == "E": response = message.get("response", {}) @@ -243,16 +287,12 @@ async def connect_and_stream(): tasks: set = set() ticker: list = [] - - conn_kwargs = CONNECT_KWARGS.copy() - - conn_kwargs.update( - { - "ping_interval": 8, - "ping_timeout": 8, - "close_timeout": 1, - } - ) + conn_kwargs = { + "ping_interval": 8, + "ping_timeout": 8, + "close_timeout": 1, + } + conn_kwargs.update(CONNECT_KWARGS) if isinstance(kwargs["symbol"], str): ticker = kwargs["symbol"].lower().split(",") @@ -261,7 +301,7 @@ async def connect_and_stream(): "eventName": "subscribe", "authorization": kwargs["api_key"], "eventData": { - "thresholdLevel": kwargs["threshold_level"], + "thresholdLevel": THRESHOLD_LEVEL, "tickers": ticker, }, } @@ -270,7 +310,7 @@ async def message_receiver(websocket): """Receive messages from the WebSocket.""" while True: message = await websocket.recv(decode=False) - input_queue.queue.put_nowait(orjson.loads(message)) + input_queue.queue.put_nowait(json.loads(message)) stdin_task = asyncio.create_task(read_stdin_and_update_symbols()) tasks.add(stdin_task) @@ -290,7 +330,7 @@ async def message_receiver(websocket): ) tasks.add(receiver_task) - await websocket.send(orjson.dumps(subscribe_event)) + await websocket.send(json.dumps(subscribe_event)) logger.info("PROVIDER INFO: WebSocket connection established.") for _ in range(9): process_task = asyncio.create_task(