Skip to content

Commit

Permalink
Remove fastapi socketio (#29)
Browse files Browse the repository at this point in the history
This PR removes the dependency on the fastapi-socketio library (which is
no longer maintained).
It is replaced by using the socketio library.

---------

Co-authored-by: Niklas Neugebauer <[email protected]>
  • Loading branch information
denniswittich and NiklasNeugebauer authored Sep 30, 2024
1 parent 3e90660 commit c7dbc33
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 30 deletions.
1 change: 0 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"--disable=C0301", // Line too long (exceeds character limit)
"--disable=W0718", // Catching too general exception
"--disable=W0719", // Raising too general exception
"--disable=W1203", // Use % formatting in logging functions and pass the % parameters as arguments
"--disable=W1514", // Using open without explicitly specifying an encoding
"--disable=R0902", // Too many instance attributes
"--disable=R0903", // Too few public methods
Expand Down
62 changes: 33 additions & 29 deletions learning_loop_node/detector/detector_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from typing import Dict, List, Optional, Union

import numpy as np
import socketio
from dacite import from_dict
from fastapi.encoders import jsonable_encoder
from fastapi_socketio import SocketManager
from socketio import AsyncClient

from ..data_classes import Category, Context, Detections, DetectionStatus, ModelInformation, Shape
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(self, name: str, detector: DetectorLogic, uuid: Optional[str] = Non
self.organization = environment_reader.organization()
self.project = environment_reader.project()
assert self.organization and self.project, 'Detector node needs an organization and an project'
self.log.info(f'Using {self.organization}/{self.project}')
self.log.info('Using %s/%s', self.organization, self.project)
self.operation_mode: OperationMode = OperationMode.Startup
self.connected_clients: List[str] = []

Expand Down Expand Up @@ -126,10 +126,19 @@ async def on_repeat(self) -> None:

def setup_sio_server(self) -> None:
"""The DetectorNode acts as a SocketIO server. This method sets up the server and defines the event handlers."""

# pylint: disable=unused-argument

async def _detect(sid, data: Dict) -> Dict:
# Initialize the Socket.IO server
self.sio = socketio.AsyncServer(async_mode='asgi')
# Initialize and mount the ASGI app
self.sio_app = socketio.ASGIApp(self.sio, socketio_path='/socket.io')
self.mount('/ws', self.sio_app)
# Register event handlers

self.log.info('>>>>>>>>>>>>>>>>>>>>>>> Setting up the SIO server')

@self.sio.event
async def detect(sid, data: Dict) -> Dict:
self.log.info('running detect via socketio')
try:
np_image = np.frombuffer(data['image'], np.uint8)
Expand All @@ -149,12 +158,14 @@ async def _detect(sid, data: Dict) -> Dict:
f.write(data['image'])
return {'error': str(e)}

async def _info(sid) -> Union[str, Dict]:
@self.sio.event
async def info(sid) -> Union[str, Dict]:
if self.detector_logic.is_initialized:
return asdict(self.detector_logic.model_info)
return 'No model loaded'

async def _upload(sid, data: Dict) -> Optional[Dict]:
@self.sio.event
async def upload(sid, data: Dict) -> Optional[Dict]:
'''upload an image with detections'''

detection_data = data.get('detections', {})
Expand All @@ -179,42 +190,35 @@ async def _upload(sid, data: Dict) -> Optional[Dict]:
return {'error': str(e)}
return None

def _connect(sid, environ, auth) -> None:
@self.sio.event
def connect(sid, environ, auth) -> None:
self.connected_clients.append(sid)

print('>>>>>>>>>>>>>>>>>>>>>>> setting up sio server', flush=True)

self.sio_server = SocketManager(app=self)
self.sio_server.on('detect', _detect)
self.sio_server.on('info', _info)
self.sio_server.on('upload', _upload)
self.sio_server.on('connect', _connect)

async def _check_for_update(self) -> None:
if self.operation_mode == OperationMode.Startup:
return
try:
self.log.info(f'Current operation mode is {self.operation_mode}')
self.log.info('Current operation mode is %s', self.operation_mode)
try:
await self.sync_status_with_learning_loop()
except Exception as e:
self.log.error(f'Could not check for updates: {e}')
self.log.error('Could not check for updates: %s', e)
return

if self.operation_mode != OperationMode.Idle:
self.log.info(f'not checking for updates; operation mode is {self.operation_mode}')
self.log.info('not checking for updates; operation mode is %s', self.operation_mode)
return

self.status.reset_error('update_model')
if self.target_model is None:
self.log.info('not checking for updates; no target model selected')
return

current_version = self.detector_logic._model_info.version if self.detector_logic._model_info is not None else None
current_version = self.detector_logic._model_info.version if self.detector_logic._model_info is not None else None # pylint: disable=protected-access

if not self.detector_logic.is_initialized or self.target_model.version != current_version:
self.log.info(
f'Current model "{current_version or "-"}" needs to be updated to {self.target_model.version}')
self.log.info('Current model "%s" needs to be updated to %s',
current_version or "-", self.target_model.version)

with step_into(GLOBALS.data_folder):
model_symlink = 'model'
Expand All @@ -232,7 +236,7 @@ async def _check_for_update(self) -> None:
except Exception:
pass
os.symlink(target_model_folder, model_symlink)
self.log.info(f'Updated symlink for model to {os.readlink(model_symlink)}')
self.log.info('Updated symlink for model to %s', os.readlink(model_symlink))

self.detector_logic.load_model()
try:
Expand Down Expand Up @@ -283,13 +287,13 @@ async def sync_status_with_learning_loop(self) -> None:
model_format=self.detector_logic.model_format,
)

self.log.info(f'sending status {status}')
self.log.info('sending status %s', status)
response = await self.sio_client.call('update_detector', (self.organization, self.project, jsonable_encoder(asdict(status))))

assert response is not None
socket_response = from_dict(data_class=SocketResponse, data=response)
if not socket_response.success:
self.log.error(f'Statusupdate failed: {response}')
self.log.error('Statusupdate failed: %s', response)
raise Exception(f'Statusupdate failed: {response}')

assert socket_response.payload is not None
Expand All @@ -303,19 +307,19 @@ async def sync_status_with_learning_loop(self) -> None:

if self.version_control == rest_version_control.VersionMode.FollowLoop:
self.target_model = self.loop_deployment_target
self.log.info(f'After sending status. Target_model is {self.target_model.version}')
self.log.info('After sending status. Target_model is %s', self.target_model.version)

async def set_operation_mode(self, mode: OperationMode):
self.operation_mode = mode
try:
await self.sync_status_with_learning_loop()
except Exception as e:
self.log.warning(f'Operation mode set to {mode}, but sync failed: {e}')
self.log.warning('Operation mode set to %s, but sync failed: %s', mode, e)

def reload(self, reason: str):
'''provide a cause for the reload'''

self.log.info(f'########## reloading app because {reason}')
self.log.info('########## reloading app because %s', reason)
if os.path.isfile('/app/app_code/restart/restart.py'):
subprocess.call(['touch', '/app/app_code/restart/restart.py'])
elif os.path.isfile('/app/main.py'):
Expand All @@ -340,7 +344,7 @@ async def get_detections(self, raw_image: np.ndarray, camera_id: Optional[str],

n_bo, n_cl = len(detections.box_detections), len(detections.classification_detections)
n_po, n_se = len(detections.point_detections), len(detections.segmentation_detections)
self.log.info(f'detected:{n_bo} boxes, {n_po} points, {n_se} segs, {n_cl} classes')
self.log.info('Detected %d boxes, %d points, %d segs, %d classes', n_bo, n_po, n_se, n_cl)

if autoupload is None or autoupload == 'filtered': # NOTE default is filtered
Thread(target=self.relevance_filter.may_upload_detections,
Expand All @@ -350,7 +354,7 @@ async def get_detections(self, raw_image: np.ndarray, camera_id: Optional[str],
elif autoupload == 'disabled':
pass
else:
self.log.error(f'unknown autoupload value {autoupload}')
self.log.error('unknown autoupload value %s', autoupload)
return jsonable_encoder(asdict(detections))

async def upload_images(self, images: List[bytes]):
Expand Down

0 comments on commit c7dbc33

Please sign in to comment.