some cleanup
deeleeramone committed Jan 7, 2025
1 parent df72101 commit 58dadc5
Showing 7 changed files with 171 additions and 100 deletions.

@@ -4,27 +4,22 @@
import json
import logging
import os
import signal
import sys
from pathlib import Path
from typing import Optional

import uvicorn

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from openbb_core.provider.utils.websockets.database import (
CHECK_FOR,
Database,
kill_thread,
)
from openbb_core.provider.utils.websockets.helpers import (
get_logger,
handle_termination_signal,
parse_kwargs,
)
from starlette.websockets import WebSocketState


kwargs = parse_kwargs()

HOST = kwargs.pop("host", None) or "localhost"
@@ -41,7 +36,7 @@

app = FastAPI()

CONNECTED_CLIENTS = set()
CONNECTED_CLIENTS: set = set()
MAIN_CLIENT = None
STDIN_TASK = None
LOGGER = get_logger("broadcast-server")
@@ -58,7 +53,7 @@ async def read_stdin():
continue

if line.strip() == "numclients":
MAIN_CLIENT.logger.info(
MAIN_CLIENT.logger.info( # type: ignore
"Number of connected clients: %i", len(CONNECTED_CLIENTS)
)
continue
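
The hunk above touches the broadcast server's stdin handler, which responds to simple commands such as `numclients`. As a rough, self-contained sketch of one way such a reader can poll stdin without blocking the event loop (the helper name and the executor-based approach are illustrative assumptions, not code from this commit):

    import asyncio
    import sys

    async def read_stdin_sketch(connected_clients: set):
        """Illustrative only: poll stdin in a worker thread so the event loop stays responsive."""
        loop = asyncio.get_running_loop()
        while True:
            line = await loop.run_in_executor(None, sys.stdin.readline)
            if not line or not line.strip():
                await asyncio.sleep(0.25)
                continue
            if line.strip() == "numclients":
                # The real server logs this through MAIN_CLIENT.logger rather than print().
                print(f"Number of connected clients: {len(connected_clients)}")
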
@@ -101,7 +96,7 @@ async def websocket_endpoint( # noqa: PLR0915
str(AUTH_TOKEN),
sql=sql,
)
broadcast_server.replay = replay
broadcast_server.replay = replay # type: ignore
auth_token = str(auth_token)

if sql and (
@@ -244,7 +239,7 @@ async def stream_results( # noqa: PLR0915 # pylint: disable=too-many-branches

last_id = (
0
if hasattr(self, "replay") and self.replay is True or replay is True
if hasattr(self, "replay") and self.replay is True or replay is True # type: ignore
else last_id
)

@@ -260,10 +255,10 @@ async def stream_results( # noqa: PLR0915 # pylint: disable=too-many-branches
any(x.lower() in sql.lower() for x in CHECK_FOR)
or (self.table_name not in sql and "message" not in sql)
):
await self.websocket.accept()
await self.websocket.send_text("Invalid SQL query passed.")
await self.websocket.close(code=1008, reason="Invalid query")
self.logger.error(
await self.websocket.accept() # type: ignore
await self.websocket.send_text("Invalid SQL query passed.") # type: ignore
await self.websocket.close(code=1008, reason="Invalid query") # type: ignore
self.logger.error( # type: ignore
"Invalid query passed to the stream_results method: %s", sql
)
return
@@ -291,7 +286,7 @@ async def stream_results( # noqa: PLR0915 # pylint: disable=too-many-branches
await self.websocket.send_json(
json.dumps(json.loads(row[1]))
)
if self.replay is True:
if self.replay is True: # type: ignore
await asyncio.sleep(self.sleep_time / 10)
await cursor.close()


@@ -233,12 +233,12 @@ def _atexit(self) -> None:
self.logger.info("Websocket results saved to, %s\n", str(self.results_path))
if os.path.exists(self.results_file) and not self.save_database: # type: ignore
os.remove(self.results_file) # type: ignore
if os.path.exists(self.results_file + "-journal"):
os.remove(self.results_file + "-journal")
if os.path.exists(self.results_file + "-shm"):
os.remove(self.results_file + "-shm")
if os.path.exists(self.results_file + "-wal"):
os.remove(self.results_file + "-wal")
if os.path.exists(self.results_file + "-journal"): # type: ignore
os.remove(self.results_file + "-journal") # type: ignore
if os.path.exists(self.results_file + "-shm"): # type: ignore
os.remove(self.results_file + "-shm") # type: ignore
if os.path.exists(self.results_file + "-wal"): # type: ignore
os.remove(self.results_file + "-wal") # type: ignore

def _log_provider_output(self, output_queue) -> None:
"""Log output from the provider logger, handling exceptions, errors, and messages that are not data."""

@@ -166,18 +166,18 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.results_path = Path(results_file).absolute()
self.results_file = results_file

if (
" " in table_name
or table_name.isupper()
or any(x.lower() in table_name.lower() for x in CHECK_FOR)
if table_name and (
" " in table_name # type: ignore
or table_name.isupper() # type: ignore
or any(x.lower() in table_name.lower() for x in CHECK_FOR) # type: ignore
):
raise OpenBBError(ProgrammingError(f"Invalid table name, {table_name}."))

self.table_name = table_name if table_name else "records"
self.limit = limit
self.loop = loop
self.kwargs = kwargs if kwargs else {}
self._connections = {}
self._connections: dict = {}
run_async(self._setup_database)
self.data_model = data_model

@@ -242,29 +242,29 @@ async def get_connection(self, name: str = "read"):
conn_kwargs = self.kwargs.copy()

if name == "read":
if ":" not in self.results_file:
results_file = (
if ":" not in self.results_file: # type: ignore
results_file = ( # type: ignore
"file:"
+ (
self.results_file
if self.results_file.startswith("/")
else "/" + self.results_file
self.results_file # type: ignore
if self.results_file.startswith("/") # type: ignore
else "/" + self.results_file # type: ignore
)
+ "?mode=ro"
)
else:
results_file = (
self.results_file
+ f"{'&mode=ro' if '?' in self.results_file else '?mode=ro'}"
results_file = ( # type: ignore
self.results_file # type: ignore
+ f"{'&mode=ro' if '?' in self.results_file else '?mode=ro'}" # type: ignore
)
conn_kwargs["uri"] = True
elif name == "write":
results_file = self.results_file
results_file = self.results_file # type: ignore

conn_kwargs["check_same_thread"] = False

if name not in self._connections:
conn = await aiosqlite.connect(results_file, **conn_kwargs)
conn = await aiosqlite.connect(results_file, **conn_kwargs) # type: ignore
pragmas = [
"PRAGMA journal_mode=WAL",
"PRAGMA synchronous=off",
@@ -355,8 +355,8 @@ async def _fetch_all(self, limit: Optional[int] = None) -> list:
query += " LIMIT ?"
params = (limit,)
else:
params = ()
async with conn.execute(query, params) as cursor:
params = None
async with conn.execute(query, params) as cursor: # type: ignore
async for row in cursor:
rows.append(await self._deserialize_row(row, cursor))

@@ -444,6 +444,7 @@ def get_latest_results(
f" {e.__class__.__name__ if hasattr(e, '__class__') else e} -> {e.args}"
)
self.logger.error(msg)
return []

async def _query_db(self, sql, parameters: Optional[Iterable[Any]] = None) -> list:
"""Query the SQLite database."""
@@ -472,10 +473,11 @@ async def _query_db(self, sql, parameters: Optional[Iterable[Any]] = None) -> list:
) as cursor:
async for row in cursor:
rows.append(await self._deserialize_row(row, cursor))
return rows
except Exception as e: # pylint: disable=broad-except
raise OpenBBError(e) from e

return rows

def query(self, sql: str, parameters: Optional[Iterable[Any]] = None) -> list:
"""
Run a SELECT query to the database.
@@ -654,7 +656,7 @@ def __init__(
self._last_processed_timestamp = None
self._conn = None
self.num_workers = 60
self.write_tasks = []
self.write_tasks: list = []
self._export_running = False
self._prune_running = False
self.batch_processor = BatchProcessor(self)
@@ -688,7 +690,7 @@ async def stop_writer(self):

async def _process_queue(self):
"""Process queue with parallel writers."""
batch = []
batch: list = []

while self.writer_running:
try:
@@ -833,6 +835,8 @@ async def _export_database(self):
for key in json.loads(row[0]):
headers[key] = None

new_rows: list = []

if self.compress_export:
with gzip.open(path, "wt") as gz_file:
writer = csv.DictWriter(gz_file, fieldnames=list(headers))
@@ -841,7 +845,7 @@ async def _export_database(self):

while True:
rows = await cursor.fetchmany(chunk_size)
new_rows: list = []

if not rows:
break

@@ -864,7 +868,7 @@ async def _export_database(self):

while True:
rows = await cursor.fetchmany(chunk_size)
new_rows: list = []

if not rows:
break

@@ -915,17 +919,17 @@ def start_export_task(self):
return

self._export_running = True
self.export_thread = threading.Thread(
self.export_thread = threading.Thread( # type: ignore
target=self._run_export_event, name="ExportThread", daemon=True
)
self.export_thread.start()
self.export_thread.start() # type: ignore

def stop_export_task(self):
"""Public method to stop the background export task."""
if hasattr(self, "export_thread") and self.export_thread:
self.export_thread.join(timeout=1)
if self.export_thread.is_alive():
kill_thread(self.export_thread)
self.export_thread.join(timeout=1) # type: ignore
if self.export_thread.is_alive(): # type: ignore
kill_thread(self.export_thread) # type: ignore
self._export_running = False
self.export_thread = None

@@ -1091,12 +1095,12 @@ def __init__(
def run(self):
"""Run the batch processor as tasks."""
try:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.loop = asyncio.new_event_loop() # type: ignore
asyncio.set_event_loop(self.loop) # type: ignore
# Create worker tasks
while self.running and not self._shutdown.is_set():
try:
self.loop.run_until_complete(self._worker())
self.loop.run_until_complete(self._worker()) # type: ignore
except (SystemExit, KeyboardInterrupt):
self.running = False
break
@@ -1112,19 +1116,19 @@ def stop(self):
"""Signal thread to stop gracefully."""
self.running = False
self._shutdown.set()
if self.loop and self.loop.is_running():
self.loop.call_soon_threadsafe(self.loop.stop)
if self.loop and self.loop.is_running(): # type: ignore
self.loop.call_soon_threadsafe(self.loop.stop) # type: ignore

def _cleanup(self):
"""Clean up resources on shutdown"""
if self.loop:
pending = asyncio.all_tasks(self.loop)
pending = asyncio.all_tasks(self.loop) # type: ignore
for task in pending:
task.cancel()
self.loop.run_until_complete(
self.loop.run_until_complete( # type: ignore
asyncio.gather(*pending, return_exceptions=True)
)
self.loop.close()
self.loop.close() # type: ignore

async def _worker(self):
# pylint: disable=import-outside-toplevel
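
The `BatchProcessor` hunks above follow the pattern of a worker thread that owns its own asyncio event loop and is signalled to stop from another thread. A stripped-down sketch of that pattern (the class and method names here are illustrative, not the commit's implementation):

    import asyncio
    import threading

    class LoopWorker(threading.Thread):
        """Toy worker thread that runs its own asyncio event loop."""

        def __init__(self):
            super().__init__(daemon=True)
            self._shutdown = threading.Event()

        def run(self):
            # Each thread needs its own loop; never reuse the main thread's loop here.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                loop.run_until_complete(self._worker())
            finally:
                loop.close()

        async def _worker(self):
            # Stand-in for draining a queue of batched writes.
            while not self._shutdown.is_set():
                await asyncio.sleep(0.1)

        def stop(self):
            # Called from another thread; the worker loop notices the flag and exits.
            self._shutdown.set()

    worker = LoopWorker()
    worker.start()
    worker.stop()
    worker.join(timeout=2)
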
47 changes: 44 additions & 3 deletions openbb_platform/providers/fmp/openbb_fmp/utils/websocket_client.py
@@ -1,10 +1,38 @@
"""FMP WebSocket client."""
"""
FMP WebSocket Client.
This file is intended to be run as a script, as a subprocess of FmpWebSocketFetcher.
Keyword arguments are passed from the command line as space-delimited `key=value` pairs.
Required Keyword Arguments
--------------------------
api_key: str
The API key for the FMP WebSocket.
asset_type: str
The asset type to subscribe to. Default is "crypto".
Options: "stock", "crypto", "fx"
symbol: str
The symbol to subscribe to. Example: "AAPL" or "AAPL,MSFT".
results_file: str
The path to the file where the results will be stored.
Optional Keyword Arguments
--------------------------
table_name: str
The name of the table to store the data in. Default is "records".
limit: int
The maximum number of rows to store in the database.
connect_kwargs: dict
Additional keyword arguments to pass directly to `websockets.connect()`.
Example: {"ping_timeout": 300}
"""

import asyncio
import json
import signal
import sys

import orjson as json
import websockets
from openbb_core.provider.utils.websockets.database import Database, DatabaseWriter
from openbb_core.provider.utils.websockets.helpers import (
@@ -31,6 +59,8 @@
CONNECT_KWARGS = kwargs.pop("connect_kwargs", {})
URL = URL_MAP.get(kwargs.pop("asset_type"), None)

SUBSCRIBED_SYMBOLS: set = set()

if not URL:
raise ValueError("Invalid asset type provided.")

@@ -87,6 +117,14 @@ async def subscribe(websocket, symbol, event):
}
try:
await websocket.send(json.dumps(subscribe_event))

for t in ticker:
if event == "subscribe":
SUBSCRIBED_SYMBOLS.add(t)
else:
SUBSCRIBED_SYMBOLS.discard(t)

kwargs["symbol"] = ",".join(SUBSCRIBED_SYMBOLS)
except Exception as e: # pylint: disable=broad-except
msg = f"PROVIDER ERROR: {e.__class__.__name__ if hasattr(e, '__class__') else e}: {e}"
logger.error(msg)
@@ -201,7 +239,10 @@ async def connect_and_stream():
sys.exit(1)

except Exception as e: # pylint: disable=broad-except
msg = f"PROVIDER ERROR: Unexpected error -> {e.__class__.__name__}: {e}"
msg = (
"PROVIDER ERROR: Unexpected error ->"
f" {e.__class__.__name__ if hasattr(e, '__class__') else e}: {e.args}"
)
logger.error(msg)
sys.exit(1)

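Per the module docstring added at the top of this file, the client is launched as a subprocess with space-delimited `key=value` arguments. A hypothetical launch, with the API key, symbols, and results path as placeholders (in practice `FmpWebSocketFetcher` builds and runs this command itself):

    import subprocess
    import sys

    # Placeholder values throughout; only the script path is taken from the commit.
    cmd = [
        sys.executable,
        "openbb_platform/providers/fmp/openbb_fmp/utils/websocket_client.py",
        "api_key=YOUR_FMP_KEY",
        "asset_type=crypto",
        "symbol=btcusd,ethusd",
        "results_file=/tmp/fmp_websocket.db",
        "table_name=records",
        "limit=1000",
    ]
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)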