diff --git a/.syncignore b/.syncignore new file mode 100644 index 00000000..7bf361fc --- /dev/null +++ b/.syncignore @@ -0,0 +1,5 @@ +.git/ +__pycache__/ +.DS_Store +*.tmp +.env \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 45eb6e46..aec19884 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -12,7 +12,15 @@ "--disable=W0718", // Catching too general exception "--disable=W0719", // Raising too general exception "--disable=W1203", // Use % formatting in logging functions and pass the % parameters as arguments - "--disable=W1514" // Using open without explicitly specifying an encoding + "--disable=W1514", // Using open without explicitly specifying an encoding + "--disable=R0902", // Too many instance attributes + "--disable=R0903", // Too few public methods + "--disable=R0912", // Too many branches + "--disable=R0913", // Too many arguments + "--disable=R0914", // Too many local variables + "--disable=R0915", // Too many statements + "--disable=R1732", // Consider using with for resource-allocating operations + "--disable=R0801" // Similar lines in 2 files ], "[python]": { "editor.defaultFormatter": "ms-python.autopep8", diff --git a/learning_loop_node/__init__.py b/learning_loop_node/__init__.py index b8f0f5cd..2fa5362e 100644 --- a/learning_loop_node/__init__.py +++ b/learning_loop_node/__init__.py @@ -1,12 +1,11 @@ import logging -import os -import sys -from .converter.converter_node import ConverterNode # from . 
import log_conf from .detector.detector_logic import DetectorLogic from .detector.detector_node import DetectorNode from .globals import GLOBALS from .trainer.trainer_node import TrainerNode +__all__ = ['TrainerNode', 'DetectorNode', 'DetectorLogic', 'GLOBALS'] + logging.info('>>>>>>>>>>>>>>>>>> LOOP INITIALIZED <<<<<<<<<<<<<<<<<<<<<<<') diff --git a/learning_loop_node/annotation/annotator_logic.py b/learning_loop_node/annotation/annotator_logic.py index 932abce9..a80cc13b 100644 --- a/learning_loop_node/annotation/annotator_logic.py +++ b/learning_loop_node/annotation/annotator_logic.py @@ -7,10 +7,10 @@ class AnnotatorLogic(): - def __init__(self): + def __init__(self) -> None: self._node: Optional[Node] = None - def init(self, node: Node): + def init(self, node: Node) -> None: self._node = node @abstractmethod diff --git a/learning_loop_node/annotation/annotator_node.py b/learning_loop_node/annotation/annotator_node.py index b1781b73..94848506 100644 --- a/learning_loop_node/annotation/annotator_node.py +++ b/learning_loop_node/annotation/annotator_node.py @@ -8,7 +8,7 @@ from ..data_classes import AnnotationNodeStatus, Context, NodeState, UserInput from ..data_classes.socket_response import SocketResponse from ..data_exchanger import DataExchanger -from ..helpers.misc import create_image_folder +from ..helpers.misc import create_image_folder, create_project_folder from ..node import Node from .annotator_logic import AnnotatorLogic @@ -18,10 +18,11 @@ class AnnotatorNode(Node): def __init__(self, name: str, annotator_logic: AnnotatorLogic, uuid: Optional[str] = None): - super().__init__(name, uuid) + super().__init__(name, uuid, 'annotation_node') self.tool = annotator_logic self.histories: Dict = {} annotator_logic.init(self) + self.status_sent = False def register_sio_events(self, sio_client: AsyncClient): @@ -50,8 +51,6 @@ async def _handle_user_input(self, user_input_dict: Dict) -> str: raise if tool_result.annotation: - if not self.sio_is_initialized(): - 
raise Exception('Socket client waas not initialized') await self.sio_client.call('update_segmentation_annotation', (user_input.data.context.organization, user_input.data.context.project, jsonable_encoder(asdict(tool_result.annotation))), timeout=30) @@ -67,6 +66,9 @@ def get_history(self, frontend_id: str) -> Dict: return self.histories.setdefault(frontend_id, self.tool.create_empty_history()) async def send_status(self): + if self.status_sent: + return + status = AnnotationNodeStatus( id=self.uuid, name=self.name, @@ -75,28 +77,27 @@ async def send_status(self): ) self.log.info(f'Sending status {status}') - if self._sio_client is None: - raise Exception('No socket client') - result = await self._sio_client.call('update_annotation_node', jsonable_encoder(asdict(status)), timeout=10) + try: + result = await self.sio_client.call('update_annotation_node', jsonable_encoder(asdict(status)), timeout=10) + except Exception as e: + self.log.error(f'Error for updating: {str(e)}') + return + assert isinstance(result, Dict) response = from_dict(data_class=SocketResponse, data=result) if not response.success: self.log.error(f'Error for updating: Response from loop was : {asdict(response)}') + else: + self.status_sent = True async def download_image(self, context: Context, uuid: str): - project_folder = Node.create_project_folder(context) + project_folder = create_project_folder(context) images_folder = create_image_folder(project_folder) downloader = DataExchanger(context=context, loop_communicator=self.loop_communicator) await downloader.download_images([uuid], images_folder) - async def get_state(self): - return NodeState.Online - - def get_node_type(self): - return 'annotation_node' - async def on_startup(self): pass @@ -104,4 +105,4 @@ async def on_shutdown(self): pass async def on_repeat(self): - pass + await self.send_status() diff --git a/learning_loop_node/converter/converter_logic.py b/learning_loop_node/converter/converter_logic.py deleted file mode 100644 index 
cef82eff..00000000 --- a/learning_loop_node/converter/converter_logic.py +++ /dev/null @@ -1,68 +0,0 @@ -import json -import os -import shutil -from abc import abstractmethod -from typing import List, Optional - -from ..data_classes import ModelInformation -from ..node import Node - - -class ConverterLogic(): - - def __init__( - self, source_format: str, target_format: str): - self.source_format = source_format - self.target_format = target_format - self._node: Optional[Node] = None - self.model_folder: Optional[str] = None - - def init(self, node: Node) -> None: - self._node = node - - @property - def node(self) -> Node: - if self._node is None: - raise Exception('ConverterLogic not initialized') - return self._node - - async def convert(self, model_information: ModelInformation) -> None: - project_folder = Node.create_project_folder(model_information.context) - - self.model_folder = ConverterLogic.create_model_folder(project_folder, model_information.id) - await self.node.data_exchanger.download_model(self.model_folder, - model_information.context, - model_information.id, - self.source_format) - - with open(f'{self.model_folder}/model.json', 'r') as f: - content = json.load(f) - if 'resolution' in content: - model_information.resolution = content['resolution'] - - await self._convert(model_information) - - async def upload_model(self, context, model_id: str) -> None: - files = self.get_converted_files(model_id) - await self.node.data_exchanger.upload_model(context, files, model_id, self.target_format) - - @abstractmethod - async def _convert(self, model_information: ModelInformation) -> None: - """Converts the model in self.model_folder to the target format.""" - - @abstractmethod - def get_converted_files(self, model_id) -> List[str]: - """Returns a list of files that should be uploaded to the server.""" - - @staticmethod - def create_convert_folder(project_folder: str) -> str: - image_folder = f'{project_folder}/images' - os.makedirs(image_folder, 
exist_ok=True) - return image_folder - - @staticmethod - def create_model_folder(project_folder: str, model_id: str) -> str: - model_folder = f'{project_folder}/{model_id}' - shutil.rmtree(model_folder, ignore_errors=True) # cleanup - os.makedirs(model_folder, exist_ok=True) - return model_folder diff --git a/learning_loop_node/converter/converter_node.py b/learning_loop_node/converter/converter_node.py deleted file mode 100644 index f23dd26e..00000000 --- a/learning_loop_node/converter/converter_node.py +++ /dev/null @@ -1,125 +0,0 @@ -import logging -from dataclasses import asdict -from http import HTTPStatus -from typing import List, Optional - -from dacite import from_dict -from fastapi.encoders import jsonable_encoder -from fastapi_utils.tasks import repeat_every -from socketio import AsyncClient - -from ..data_classes import Category, ModelInformation, NodeState -from ..node import Node -from .converter_logic import ConverterLogic - - -class ConverterNode(Node): - converter: ConverterLogic - skip_check_state: bool = False - bad_model_ids: List[str] = [] - - def __init__(self, name: str, converter: ConverterLogic, uuid: Optional[str] = None): - super().__init__(name, uuid) - self.converter = converter - converter.init(self) - - @self.on_event("startup") - @repeat_every(seconds=60, raise_exceptions=True, wait_first=False) - async def check_state(): - if not self.skip_check_state: - try: - await self.check_state() - except Exception: - logging.error('could not check state. 
Is loop reachable?') - - async def convert_model(self, model_information: ModelInformation): - if model_information.id in self.bad_model_ids: - logging.info( - f'skipping bad model model {model_information.id} for {model_information.context.organization}/{model_information.context.project}.') - return - try: - logging.info( - f'converting model {jsonable_encoder(asdict(model_information))}') - await self.converter.convert(model_information) - logging.info('uploading model ') - await self.converter.upload_model(model_information.context, model_information.id) - except Exception as e: - self.bad_model_ids.append(model_information.id) - logging.error( - f'could not convert model {model_information.id} for {model_information.context.organization}/{model_information.context.project}. Details: {str(e)}.') - - async def check_state(self): - logging.info(f'checking state: {self.status.state}') - - if self.status.state == NodeState.Running: - return - self.status.state = NodeState.Running - try: - await self.convert_models() - except Exception as exc: - logging.error(str(exc)) - - self.status.state = NodeState.Idle - - async def convert_models(self) -> None: - try: - response = await self.loop_communicator.get('/projects') - assert response.status_code == 200, f'Assert statuscode 200, but was {response.status_code}.' 
- content = response.json() - projects = content['projects'] - - for project in projects: - organization_id = project['organization_id'] - project_id = project['project_id'] - - response = await self.loop_communicator.get(f'{project["resource"]}') - if response.status_code != HTTPStatus.OK: - logging.error(f'got bad response for {response.url}: {str(response.status_code)}') - continue - - project_categories = [from_dict(data_class=Category, data=c) for c in response.json()['categories']] - - path = f'{project["resource"]}/models' - models_response = await self.loop_communicator.get(path) - assert models_response.status_code == 200 - content = models_response.json() - models = content['models'] - - for model in models: - if (model['version'] - and self.converter.source_format in model['formats'] - and self.converter.target_format not in model['formats'] - ): - # if self.converter.source_format in model['formats'] and project_id == 'drawingbot' and model['version'] == "6.0": - model_information = ModelInformation( - host=self.loop_communicator.base_url, - organization=organization_id, - project=project_id, - id=model['id'], - categories=project_categories, - version=model['version'], - ) - await self.convert_model(model_information) - except Exception: - logging.exception('could not convert models') - - async def send_status(self): - pass - - async def on_startup(self): - pass - - async def on_shutdown(self): - pass - - async def on_repeat(self): - pass - - def register_sio_events(self, sio_client: AsyncClient): - pass - - async def get_state(self): - return NodeState.Idle # NOTE unused for this node type - - def get_node_type(self): - return 'converter' diff --git a/learning_loop_node/converter/tests/test_converter.py b/learning_loop_node/converter/tests/test_converter.py deleted file mode 100644 index 7328806f..00000000 --- a/learning_loop_node/converter/tests/test_converter.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging -from typing import List - -import pytest 
- -from learning_loop_node.converter.converter_logic import ConverterLogic -from learning_loop_node.converter.converter_node import ConverterNode -from learning_loop_node.data_classes import ModelInformation -from learning_loop_node.loop_communication import LoopCommunicator -from learning_loop_node.tests import test_helper - - -class TestConverter(ConverterLogic): - __test__ = False # hint for pytest - - def __init__(self, source_format: str, target_format: str, models: List[ModelInformation]): - super().__init__(source_format, target_format) - self.models = models - - async def _convert(self, model_information: ModelInformation) -> None: - self.models.append(model_information) - - def get_converted_files(self, model_id) -> List[str]: - return [] # test: test_meta_information fails because model cannot be uploaded - - -@pytest.mark.asyncio -@pytest.fixture() -async def setup_converter_test_project(glc: LoopCommunicator): - await glc.delete("/zauberzeug/projects/pytest_conv?keep_images=true") - project_configuration = { - 'project_name': 'pytest_conv', 'box_categories': 1, 'point_categories': 1, 'inbox': 0, 'annotate': 0, 'review': 0, - 'complete': 0, 'image_style': 'plain', 'thumbs': False, 'trainings': 1} - r = await glc.post("/zauberzeug/projects/generator", json=project_configuration) - assert r.status_code == 200 - yield - await glc.delete("/zauberzeug/projects/pytest?keep_images=true") - - -# pylint: disable=redefined-outer-name, unused-argument -@pytest.mark.asyncio -async def test_meta_information(setup_converter_test_project): - model_id = await test_helper.get_latest_model_id(project='pytest_conv') - - converter = TestConverter(source_format='mocked', target_format='test', models=[]) - node = ConverterNode(name='test', converter=converter) - await node.convert_models() - - pytest_project_model = [m for m in converter.models if m.id == model_id][0] - - categories = pytest_project_model.categories - assert len(categories) == 2 - category_types = 
[category.type for category in categories] - assert 'box' in category_types - assert 'point' in category_types diff --git a/learning_loop_node/data_classes/__init__.py b/learning_loop_node/data_classes/__init__.py index bc2980cd..524cb8bb 100644 --- a/learning_loop_node/data_classes/__init__.py +++ b/learning_loop_node/data_classes/__init__.py @@ -1,12 +1,19 @@ -from .annotations import (AnnotationData, AnnotationEventType, - SegmentationAnnotation, ToolOutput, UserInput) -from .detections import (BoxDetection, ClassificationDetection, Detections, - Observation, Point, PointDetection, +from .annotations import AnnotationData, AnnotationEventType, SegmentationAnnotation, ToolOutput, UserInput +from .detections import (BoxDetection, ClassificationDetection, Detections, Observation, Point, PointDetection, SegmentationDetection, Shape) -from .general import (AnnotationNodeStatus, Category, CategoryType, Context, - DetectionStatus, ErrorConfiguration, ModelInformation, - NodeState, NodeStatus) +from .general import (AnnotationNodeStatus, Category, CategoryType, Context, DetectionStatus, ErrorConfiguration, + ModelInformation, NodeState, NodeStatus) from .socket_response import SocketResponse -from .training import (BasicModel, Errors, Hyperparameter, Model, - PretrainedModel, Training, TrainingData, TrainingError, - TrainingOut, TrainingState, TrainingStatus) +from .training import (Errors, Hyperparameter, Model, PretrainedModel, TrainerState, Training, TrainingData, + TrainingError, TrainingOut, TrainingStateData, TrainingStatus) + +__all__ = [ + 'AnnotationData', 'AnnotationEventType', 'SegmentationAnnotation', 'ToolOutput', 'UserInput', + 'BoxDetection', 'ClassificationDetection', 'Detections', 'Observation', 'Point', 'PointDetection', + 'SegmentationDetection', 'Shape', + 'AnnotationNodeStatus', 'Category', 'CategoryType', 'Context', 'DetectionStatus', 'ErrorConfiguration', + 'ModelInformation', 'NodeState', 'NodeStatus', + 'SocketResponse', + 'Errors', 
'Hyperparameter', 'Model', 'PretrainedModel', 'TrainerState', 'Training', 'TrainingData', + 'TrainingError', 'TrainingOut', 'TrainingStateData', 'TrainingStatus', +] diff --git a/learning_loop_node/data_classes/detections.py b/learning_loop_node/data_classes/detections.py index 21924720..0872b256 100644 --- a/learning_loop_node/data_classes/detections.py +++ b/learning_loop_node/data_classes/detections.py @@ -13,8 +13,11 @@ @dataclass(**KWONLY_SLOTS) class BoxDetection(): + """Coordinates according to COCO format. x,y is the top left corner of the box. + x increases to the right, y increases downwards. + """ category_name: str - x: int # TODO add definition of x,y,w,h + x: int y: int width: int height: int @@ -47,6 +50,8 @@ def __str__(self): @dataclass(**KWONLY_SLOTS) class PointDetection(): + """Coordinates according to COCO format. x,y is the center of the point. + x increases to the right, y increases downwards.""" category_name: str x: float y: float @@ -111,7 +116,7 @@ class Detections(): point_detections: List[PointDetection] = field(default_factory=list) segmentation_detections: List[SegmentationDetection] = field(default_factory=list) classification_detections: List[ClassificationDetection] = field(default_factory=list) - tags: Optional[List[str]] = field(default_factory=list) + tags: List[str] = field(default_factory=list) date: Optional[str] = field(default_factory=current_datetime) image_id: Optional[str] = None # used for detection of trainers diff --git a/learning_loop_node/data_classes/general.py b/learning_loop_node/data_classes/general.py index 8404ab22..3ef5e412 100644 --- a/learning_loop_node/data_classes/general.py +++ b/learning_loop_node/data_classes/general.py @@ -34,10 +34,6 @@ def from_list(values: List[dict]) -> List['Category']: return [from_dict(data_class=Category, data=value) for value in values] -def create_category(identifier: str, name: str, ctype: Union[CategoryType, str]): # TODO: This is probably unused - return 
Category(id=identifier, name=name, description='', hotkey='', color='', type=ctype, point_size=None) - - @dataclass(**KWONLY_SLOTS) class Context(): organization: str @@ -57,6 +53,7 @@ class ModelInformation(): categories: List[Category] resolution: Optional[int] = None model_root_path: Optional[str] = None + model_size: Optional[str] = None @property def context(self): @@ -64,6 +61,8 @@ def context(self): @staticmethod def load_from_disk(model_root_path: str) -> Optional['ModelInformation']: + """Load model.json from model_root_path and return ModelInformation object. + """ model_info_file_path = f'{model_root_path}/model.json' if not os.path.exists(model_info_file_path): logging.warning(f"could not find model information file '{model_info_file_path}'") @@ -121,7 +120,7 @@ class NodeState(str, Enum): class NodeStatus(): id: str name: str - state: Optional[NodeState] = NodeState.Offline + state: Optional[NodeState] = NodeState.Online uptime: Optional[int] = 0 errors: Dict = field(default_factory=dict) capabilities: List[str] = field(default_factory=list) diff --git a/learning_loop_node/data_classes/training.py b/learning_loop_node/data_classes/training.py index 49432925..d530ae7a 100644 --- a/learning_loop_node/data_classes/training.py +++ b/learning_loop_node/data_classes/training.py @@ -1,8 +1,10 @@ import sys +import time from dataclasses import dataclass, field from enum import Enum -from typing import Dict, List, Optional, Union +from pathlib import Path +from typing import Dict, List, Optional # pylint: disable=no-name-in-module from .general import Category, Context @@ -16,6 +18,14 @@ class Hyperparameter(): flip_rl: bool flip_ud: bool + @staticmethod + def from_data(data: Dict): + return Hyperparameter( + resolution=data['resolution'], + flip_rl=data.get('flip_rl', False), + flip_ud=data.get('flip_ud', False) + ) + @dataclass(**KWONLY_SLOTS) class TrainingData(): @@ -41,14 +51,15 @@ class PretrainedModel(): description: str -class TrainingState(str, Enum): 
+class TrainerState(str, Enum): + Idle = 'idle' Initialized = 'initialized' Preparing = 'preparing' DataDownloading = 'data_downloading' DataDownloaded = 'data_downloaded' TrainModelDownloading = 'train_model_downloading' TrainModelDownloaded = 'train_model_downloaded' - TrainingRunning = 'training_running' + TrainingRunning = 'running' TrainingFinished = 'training_finished' ConfusionMatrixSyncing = 'confusion_matrix_syncing' ConfusionMatrixSynced = 'confusion_matrix_synced' @@ -62,9 +73,9 @@ class TrainingState(str, Enum): @dataclass(**KWONLY_SLOTS) class TrainingStatus(): - id: str # TODO this must not be changed, but tests wont detect it -> update tests! + id: str # NOTE this must not be changed, but tests wont detect a change -> update tests! name: str - state: Union[Optional[TrainingState], str] + state: Optional[str] errors: Optional[Dict] uptime: Optional[float] progress: Optional[float] @@ -77,13 +88,13 @@ class TrainingStatus(): architecture: Optional[str] = None context: Optional[Context] = None - def short_str(self): + def short_str(self) -> str: prgr = f'{self.progress * 100:.0f}%' if self.progress else '' trtesk = f'{self.train_image_count}/{self.test_image_count}/{self.skipped_image_count}' if self.train_image_count else 'n.a.' cntxt = f'{self.context.organization}/{self.context.project}' if self.context else '' hyps = f'({self.hyperparameters})' if self.hyperparameters else '' arch = f'.{self.architecture} - ' if self.architecture else '' - return f'[{str(self.state)} {prgr}. {self.name}({self.id}). Tr/Ts/Tsk: {trtesk} {cntxt}{arch}{hyps}]' + return f'[{str(self.state).rsplit(".", maxsplit=1)[-1]} {prgr}. {self.name}({self.id}). 
Tr/Ts/Tsk: {trtesk} {cntxt}{arch}{hyps}]' @dataclass(**KWONLY_SLOTS) @@ -91,21 +102,35 @@ class Training(): id: str context: Context - project_folder: str - images_folder: str - training_folder: str + project_folder: str # f'{GLOBALS.data_folder}/{context.organization}/{context.project}' + images_folder: str # f'{project_folder}/images' + training_folder: str # f'{project_folder}/trainings/{trainings_id}' + start_time: float = field(default_factory=time.time) + + # model uuid to download (to continue training) | is not a uuid when training from scratch (blank or pt-name from provided_pretrained_models->name) + base_model_uuid_or_name: Optional[str] = None - base_model_id: Optional[str] = None data: Optional[TrainingData] = None training_number: Optional[int] = None - training_state: Optional[Union[TrainingState, str]] = None - model_id_for_detecting: Optional[str] = None + training_state: Optional[str] = None + model_uuid_for_detecting: Optional[str] = None hyperparameters: Optional[Dict] = None + @property + def training_folder_path(self) -> Path: + return Path(self.training_folder) + + def set_values_from_data(self, data: Dict) -> None: + self.data = TrainingData(categories=Category.from_list(data['categories'])) + self.data.hyperparameter = Hyperparameter.from_data(data=data) + self.training_number = data['training_number'] + self.base_model_uuid_or_name = data['id'] + self.training_state = TrainerState.Initialized + @dataclass(**KWONLY_SLOTS) class TrainingOut(): - confusion_matrix: Optional[Dict] = None + confusion_matrix: Optional[Dict] = None # This is actually just class-wise metrics train_image_count: Optional[int] = None test_image_count: Optional[int] = None trainer_id: Optional[str] = None @@ -113,9 +138,9 @@ class TrainingOut(): @dataclass(**KWONLY_SLOTS) -class BasicModel(): - confusion_matrix: Optional[Dict] = None - meta_information: Optional[Dict] = None +class TrainingStateData(): + confusion_matrix: Dict = field(default_factory=dict) + 
meta_information: Dict = field(default_factory=dict) @dataclass(**KWONLY_SLOTS) @@ -130,8 +155,8 @@ class Model(): class Errors(): - def __init__(self): - self._errors: Dict = {} + def __init__(self) -> None: + self._errors: Dict[str, str] = {} def set(self, key: str, value: str): self._errors[key] = value @@ -140,7 +165,7 @@ def set(self, key: str, value: str): def errors(self) -> Dict: return self._errors - def reset(self, key: str): + def reset(self, key: str) -> None: try: del self._errors[key] except AttributeError: @@ -148,7 +173,7 @@ def reset(self, key: str): except KeyError: pass - def reset_all(self): + def reset_all(self) -> None: self._errors = {} def has_error_for(self, key: str) -> bool: @@ -162,3 +187,6 @@ class TrainingError(Exception): def __init__(self, cause: str, *args: object) -> None: super().__init__(*args) self.cause = cause + + def __str__(self) -> str: + return f'TrainingError: {self.cause}' diff --git a/learning_loop_node/data_exchanger.py b/learning_loop_node/data_exchanger.py index 23f19976..9e8ffdb8 100644 --- a/learning_loop_node/data_exchanger.py +++ b/learning_loop_node/data_exchanger.py @@ -2,23 +2,19 @@ import logging import os import shutil -import time import zipfile from glob import glob from http import HTTPStatus from io import BytesIO -from time import perf_counter +from time import time from typing import Dict, List, Optional -import aiofiles -from tqdm.asyncio import tqdm +import aiofiles # type: ignore from .data_classes import Context -from .helpers.misc import create_resource_paths, create_task +from .helpers.misc import create_resource_paths, create_task, is_valid_image from .loop_communication import LoopCommunicator -check_jpeg = shutil.which('jpeginfo') is not None - class DownloadError(Exception): @@ -26,201 +22,151 @@ def __init__(self, cause: str, *args: object) -> None: super().__init__(*args) self.cause = cause + def __str__(self) -> str: + return f'DownloadError: {self.cause}' + class DataExchanger(): def 
__init__(self, context: Optional[Context], loop_communicator: LoopCommunicator): - self.context = context + """Exchanges data with the learning loop via the loop_communicator (rest api). + + Args: + context (Optional[Context]): The context of the node. This is the organization and project name. + loop_communicator (LoopCommunicator): The loop_communicator to use for communication with the learning loop. + + Note: + The context can be set later with the set_context method. + """ + self.set_context(context) + self.progress = 0.0 self.loop_communicator = loop_communicator + + self.check_jpeg = shutil.which('jpeginfo') is not None + if self.check_jpeg: + logging.info('Detected command line tool "jpeginfo". Images will be checked for validity') + else: + logging.error('Missing command line tool "jpeginfo". We cannot check for validity of images.') + + def set_context(self, context: Optional[Context]) -> None: + self._context = context self.progress = 0.0 - def set_context(self, context: Context): - self.context = context + @property + def context(self) -> Context: + assert self._context, 'DataExchanger: Context was not set yet.. call set_context() first.' 
+ return self._context - async def fetch_image_ids(self, query_params: Optional[str] = '') -> List[str]: - if self.context is None: - logging.warning('context was not set yet') - return [] + # ---------------------------- END OF INIT ---------------------------- + + async def fetch_image_uuids(self, query_params: Optional[str] = '') -> List[str]: + """Fetch image uuids from the learning loop data endpoint.""" + logging.info(f'Fetching image uuids for {self.context.organization}/{self.context.project}..') response = await self.loop_communicator.get(f'/{self.context.organization}/projects/{self.context.project}/data?{query_params}') assert response.status_code == 200, response return (response.json())['image_ids'] - async def download_images_data(self, ids: List[str]) -> List[Dict]: - '''Download image annotations etc.''' - if self.context is None: - logging.warning('context was not set yet') - return [] - - return await self._download_images_data(self.context.organization, self.context.project, ids) - - async def download_images(self, image_ids: List[str], image_folder: str) -> None: - '''Download images. 
Will skip existing images''' - if self.context is None: - logging.warning('context was not set yet') - return - - new_image_ids = await asyncio.get_event_loop().run_in_executor(None, DataExchanger.filter_existing_images, image_ids, image_folder) - paths, ids = create_resource_paths(self.context.organization, self.context.project, new_image_ids) - await self._download_images(paths, ids, image_folder) - - @staticmethod - async def delete_corrupt_images(image_folder: str) -> None: - logging.info('deleting corrupt images') - n_deleted = 0 - for image in glob(f'{image_folder}/*.jpg'): - if not await DataExchanger.is_valid_image(image): - logging.debug(f' deleting image {image}') - os.remove(image) - n_deleted += 1 - - logging.info(f'deleted {n_deleted} images') - - @staticmethod - def filter_existing_images(all_image_ids, image_folder) -> List[str]: - logging.info(f'### Going to filter {len(all_image_ids)} images ids') - start = perf_counter() - ids = [os.path.splitext(os.path.basename(image))[0] - for image in glob(f'{image_folder}/*.jpg')] - logging.info(f'found {len(ids)} images on disc') - result = [id for id in all_image_ids if id not in ids] - end = perf_counter() - logging.info(f'calculated {len(result)} new image ids, which took {end-start:0.2f} seconds') - return result - - def jepeg_check_info(self): - if check_jpeg: - logging.info('Detected command line tool "jpeginfo". Images will be checked for validity') - else: - logging.error('Missing command line tool "jpeginfo". We can not check for validity of images.') + async def download_images_data(self, image_uuids: List[str], chunk_size: int = 100) -> List[Dict]: + """Download image annotations, tags, set and other information for the given image uuids.""" + logging.info(f'Fetching annotations, tags, sets, etc. 
for {len(image_uuids)} images..') - async def _download_images_data(self, organization: str, project: str, image_ids: List[str], chunk_size: int = 100) -> List[Dict]: - logging.info('fetching annotations and other image data') - num_image_ids = len(image_ids) - self.jepeg_check_info() - images_data = [] + num_image_ids = len(image_uuids) if num_image_ids == 0: logging.info('got empty list. No images were downloaded') - return images_data - starttime = time.time() + return [] + progress_factor = 0.5 / num_image_ids # 50% of progress is for downloading data - for i in tqdm(range(0, num_image_ids, chunk_size), position=0, leave=True): + images_data: List[Dict] = [] + for i in range(0, num_image_ids, chunk_size): self.progress = i * progress_factor - chunk_ids = image_ids[i:i+chunk_size] - response = await self.loop_communicator.get(f'/{organization}/projects/{project}/images?ids={",".join(chunk_ids)}') + chunk_ids = image_uuids[i:i+chunk_size] + response = await self.loop_communicator.get(f'/{self.context.organization}/projects/{self.context.project}/images?ids={",".join(chunk_ids)}') if response.status_code != 200: - logging.error( - f'Error during downloading list of images. Statuscode is {response.status_code}') + logging.error(f'Error {response.status_code} during downloading image data. Continue with next batch..') continue images_data += response.json()['images'] - total_time = round(time.time() - starttime, 1) - if images_data: - per100 = total_time / len(images_data) * 100 - logging.debug(f'[+] Performance: {total_time} sec total. Per 100 : {per100:.1f} sec') - else: - logging.debug(f'[+] Performance: {total_time} sec total.') + return images_data - async def _download_images(self, paths: List[str], image_ids: List[str], image_folder: str, chunk_size: int = 10) -> None: - num_image_ids = len(image_ids) - if num_image_ids == 0: - logging.debug('got empty list. 
No images were downloaded') + async def download_images(self, image_uuids: List[str], image_folder: str, chunk_size: int = 10) -> None: + """Downloads images (actual image data). Will skip existing images""" + logging.info(f'Downloading {len(image_uuids)} images (actual image data).. skipping existing images.') + if not image_uuids: return - logging.info('fetching image files') - starttime = time.time() + + existing_uuids = {os.path.splitext(os.path.basename(image))[0] for image in glob(f'{image_folder}/*.jpg')} + new_image_uuids = [id for id in image_uuids if id not in existing_uuids] + + paths, _ = create_resource_paths(self.context.organization, self.context.project, new_image_uuids) + num_image_ids = len(image_uuids) os.makedirs(image_folder, exist_ok=True) progress_factor = 0.5 / num_image_ids # second 50% of progress is for downloading images - for i in tqdm(range(0, num_image_ids, chunk_size), position=0, leave=True): + for i in range(0, num_image_ids, chunk_size): self.progress = 0.5 + i * progress_factor chunk_paths = paths[i:i+chunk_size] - chunk_ids = image_ids[i:i+chunk_size] + chunk_ids = new_image_uuids[i:i+chunk_size] # must slice the same filtered list the paths were built from, otherwise files are saved under the wrong uuid tasks = [] for j, chunk_j in enumerate(chunk_paths): - tasks.append(create_task(self.download_one_image(chunk_j, chunk_ids[j], image_folder))) + start = time() + tasks.append(create_task(self._download_one_image(chunk_j, chunk_ids[j], image_folder))) + await asyncio.sleep(max(0, 0.02 - (time() - start))) # prevent too many requests at once await asyncio.gather(*tasks) - total_time = round(time.time() - starttime, 1) - per100 = total_time / (i + len(tasks)) * 100 - logging.debug(f'[+] Performance (image files): {total_time} sec total.
Per 100 : {per100:.1f}') - async def download_one_image(self, path: str, image_id: str, image_folder: str) -> None: + async def _download_one_image(self, path: str, image_id: str, image_folder: str) -> None: response = await self.loop_communicator.get(path) if response.status_code != HTTPStatus.OK: - logging.error(f'bad status code {response.status_code} for {path}') + logging.error(f'bad status code {response.status_code} for {path}. Details: {response.text}') return filename = f'{image_folder}/{image_id}.jpg' async with aiofiles.open(filename, 'wb') as f: await f.write(response.content) - if not await self.is_valid_image(filename): + if not await is_valid_image(filename, self.check_jpeg): os.remove(filename) - @staticmethod - async def is_valid_image(filename: str) -> bool: - if not os.path.isfile(filename) or os.path.getsize(filename) == 0: - return False - if not check_jpeg: - return True - - info = await asyncio.create_subprocess_shell( - f'jpeginfo -c {filename}', - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE) - out, _ = await info.communicate() - return "OK" in out.decode() - - async def download_model(self, target_folder: str, context: Context, model_id: str, model_format: str) -> List[str]: - path = f'/{context.organization}/projects/{context.project}/models/{model_id}/{model_format}/file' + async def download_model(self, target_folder: str, context: Context, model_uuid: str, model_format: str) -> List[str]: + """Downloads a model (and additional meta data like model.json) and returns the paths of the downloaded files. + Used before training a model (when continuing a finished training) or before detecting images. 
+ """ + logging.info(f'Downloading model data for uuid {model_uuid} from the loop to {target_folder}..') + + path = f'/{context.organization}/projects/{context.project}/models/{model_uuid}/{model_format}/file' response = await self.loop_communicator.get(path, requires_login=False) if response.status_code != 200: content = response.json() - logging.error( - f'could not download {self.loop_communicator.base_url}/{path}: {response.status_code}, content: {content}') + logging.error(f'could not download loop/{path}: {response.status_code}, content: {content}') raise DownloadError(content['detail']) try: provided_filename = response.headers.get( "Content-Disposition").split("filename=")[1].strip('"') content = response.content except: - logging.error(f'Error during downloading model {path}:') - try: - logging.exception(response.json()) - except Exception: - pass + logging.exception(f'Error during downloading model {path}:') raise - # unzip and place downloaded model tmp_path = f'/tmp/{os.path.splitext(provided_filename)[0]}' shutil.rmtree(tmp_path, ignore_errors=True) with zipfile.ZipFile(BytesIO(content), 'r') as zip_: zip_.extractall(tmp_path) - logging.info(f'---- downloaded model {model_id} to {tmp_path}.') - created_files = [] - files = glob(f'{tmp_path}/**/*', recursive=True) - for file in files: + for file in glob(f'{tmp_path}/**/*', recursive=True): new_file = shutil.move(file, target_folder) - logging.info(f'moved model file {os.path.basename(file)} to {new_file}.') created_files.append(new_file) - return created_files - async def upload_model(self, context: Context, files: List[str], model_id: str, mformat: str) -> None: - response = await self.loop_communicator.put(f'/{context.organization}/projects/{context.project}/models/{model_id}/{mformat}/file', files=files) - if response.status_code != 200: - msg = f'---- could not upload model with id {model_id} and format {mformat}. 
Details: {response.text}' - raise Exception(msg) - logging.info(f'---- uploaded model with id {model_id} and format {mformat}.') + shutil.rmtree(tmp_path, ignore_errors=True) + logging.info(f'Downloaded model {model_uuid}({model_format}) to {target_folder}.') + return created_files - async def upload_model_for_training(self, context: Context, files: List[str], training_number: Optional[int], mformat: str) -> Optional[str]: - """Returns the new model uuid to use for detection.""" + async def upload_model_get_uuid(self, context: Context, files: List[str], training_number: Optional[int], mformat: str) -> Optional[str]: + """Used by the trainers. Function returns the new model uuid to use for detection.""" response = await self.loop_communicator.put(f'/{context.organization}/projects/{context.project}/trainings/{training_number}/models/latest/{mformat}/file', files=files) if response.status_code != 200: - msg = f'---- could not upload model for training {training_number} and format {mformat}. Details: {response.text}' - logging.error(msg) + logging.error(f'Could not upload model for training {training_number}, format {mformat}: {response.text}') response.raise_for_status() return None - else: - uploaded_model = response.json() - logging.info( - f'---- uploaded model for training {training_number} and format {mformat}. Model id is {uploaded_model}') - return uploaded_model['id'] + + uploaded_model = response.json() + logging.info(f'Uploaded model for training {training_number}, format {mformat}. 
Response is: {uploaded_model}') + return uploaded_model['id'] diff --git a/learning_loop_node/detector/__init__.py b/learning_loop_node/detector/__init__.py index 8b137891..e69de29b 100644 --- a/learning_loop_node/detector/__init__.py +++ b/learning_loop_node/detector/__init__.py @@ -1 +0,0 @@ - diff --git a/learning_loop_node/detector/detector_node.py b/learning_loop_node/detector/detector_node.py index 785a10fe..92b5fa21 100644 --- a/learning_loop_node/detector/detector_node.py +++ b/learning_loop_node/detector/detector_node.py @@ -14,7 +14,7 @@ from fastapi_socketio import SocketManager from socketio import AsyncClient -from ..data_classes import Category, Context, Detections, DetectionStatus, ModelInformation, NodeState, Shape +from ..data_classes import Category, Context, Detections, DetectionStatus, ModelInformation, Shape from ..data_classes.socket_response import SocketResponse from ..data_exchanger import DataExchanger, DownloadError from ..globals import GLOBALS @@ -34,9 +34,8 @@ class DetectorNode(Node): def __init__(self, name: str, detector: DetectorLogic, uuid: Optional[str] = None, use_backdoor_controls: bool = False) -> None: - super().__init__(name, uuid) + super().__init__(name, uuid, 'detector', False) self.detector_logic = detector - self.needs_login = False self.organization = environment_reader.organization() self.project = environment_reader.project() assert self.organization and self.project, 'Detector node needs an organization and an project' @@ -170,6 +169,8 @@ async def _upload(sid, data: Dict) -> Optional[Dict]: def _connect(sid, environ, auth) -> None: self.connected_clients.append(sid) + print('>>>>>>>>>>>>>>>>>>>>>>> setting up sio server', flush=True) + self.sio_server = SocketManager(app=self) self.sio_server.on('detect', _detect) self.sio_server.on('info', _info) @@ -185,7 +186,9 @@ async def _check_for_update(self) -> None: if not update_to_model_id: self.log.info('could not check for updates') return - if 
self.detector_logic.is_initialized: # TODO: solve race condition !!! + + # TODO: solve race condition (it should not be required to recheck if model_info is not None, but it is!) + if self.detector_logic.is_initialized: model_info = self.detector_logic._model_info # pylint: disable=protected-access if model_info is not None: self.log.info(f'Current model: {model_info.version} with id {model_info.id}') @@ -220,8 +223,7 @@ async def _check_for_update(self) -> None: await self.data_exchanger.download_model(target_model_folder, Context(organization=self.organization, project=self.project), - update_to_model_id, - self.detector_logic.model_format) + update_to_model_id, self.detector_logic.model_format) try: os.unlink(model_symlink) os.remove(model_symlink) @@ -256,7 +258,7 @@ async def send_status(self) -> Union[str, Literal[False]]: name=self.name, state=self.status.state, errors=self.status.errors, - uptime=int((datetime.now() - self.startup_time).total_seconds()), + uptime=int((datetime.now() - self.startup_datetime).total_seconds()), operation_mode=self.operation_mode, current_model=current_model, target_model=self.target_model, @@ -272,13 +274,11 @@ async def send_status(self) -> Union[str, Literal[False]]: return False assert socket_response.payload is not None + # TODO This is weird because target_model_version is stored in self and target_model_id is returned self.target_model = socket_response.payload['target_model_version'] self.log.info(f'After sending status. 
Target_model is {self.target_model}') return socket_response.payload['target_model_id'] - async def get_state(self): - return NodeState.Online # NOTE At the moment only trainer-nodes use a meaningful state - async def set_operation_mode(self, mode: OperationMode): self.operation_mode = mode await self.send_status() @@ -353,9 +353,6 @@ def find_category_id_by_name(categories: List[Category], category_name: str): classification_detection.category_id = category_id return detections - def get_node_type(self): - return 'detector' - def register_sio_events(self, sio_client: AsyncClient): pass diff --git a/learning_loop_node/detector/inbox_filter/cam_observation_history.py b/learning_loop_node/detector/inbox_filter/cam_observation_history.py index 88bbe881..a87c72ee 100644 --- a/learning_loop_node/detector/inbox_filter/cam_observation_history.py +++ b/learning_loop_node/detector/inbox_filter/cam_observation_history.py @@ -1,20 +1,17 @@ import os from typing import List, Union -from learning_loop_node.data_classes import (BoxDetection, - ClassificationDetection, - Detections, Observation, - PointDetection, - SegmentationDetection) +from learning_loop_node.data_classes import (BoxDetection, ClassificationDetection, Detections, Observation, + PointDetection, SegmentationDetection) class CamObservationHistory: - def __init__(self): + def __init__(self) -> None: self.reset_time = 3600 self.recent_observations: List[Observation] = [] self.iou_threshold = 0.5 - def forget_old_detections(self): + def forget_old_detections(self) -> None: self.recent_observations = [detection for detection in self.recent_observations if not detection.is_older_than(self.reset_time)] diff --git a/learning_loop_node/detector/outbox.py b/learning_loop_node/detector/outbox.py index 23138c85..ca1a200d 100644 --- a/learning_loop_node/detector/outbox.py +++ b/learning_loop_node/detector/outbox.py @@ -53,7 +53,6 @@ def save(self, image: bytes, detections: Optional[Detections] = None, tags: Opti with 
open(tmp + '/image.json', 'w') as f: json.dump(jsonable_encoder(asdict(detections)), f) - # TODO sometimes No such file or directory: '/tmp/learning_loop_lib_data/tmp/2023-09-07_13:27:38.399/image.jpg' with open(tmp + '/image.jpg', 'wb') as f: f.write(image) diff --git a/learning_loop_node/detector/rest/about.py b/learning_loop_node/detector/rest/about.py index c464b999..9f1e407e 100644 --- a/learning_loop_node/detector/rest/about.py +++ b/learning_loop_node/detector/rest/about.py @@ -16,6 +16,7 @@ async def get_about(request: Request): curl http://localhost/about ''' app: 'DetectorNode' = request.app + return { 'operation_mode': app.operation_mode.value, 'state': app.status.state, diff --git a/learning_loop_node/detector/tests/conftest.py b/learning_loop_node/detector/tests/conftest.py index ad183fe2..1611f265 100644 --- a/learning_loop_node/detector/tests/conftest.py +++ b/learning_loop_node/detector/tests/conftest.py @@ -12,7 +12,6 @@ import uvicorn from learning_loop_node import DetectorNode -from learning_loop_node.data_classes.general import Category, ModelInformation from learning_loop_node.detector.outbox import Outbox from learning_loop_node.globals import GLOBALS diff --git a/learning_loop_node/detector/tests/test_client_communication.py b/learning_loop_node/detector/tests/test_client_communication.py index be3d2d4b..24fbd095 100644 --- a/learning_loop_node/detector/tests/test_client_communication.py +++ b/learning_loop_node/detector/tests/test_client_communication.py @@ -5,7 +5,7 @@ import requests from learning_loop_node import DetectorNode -from learning_loop_node.data_classes import Category, ModelInformation +from learning_loop_node.data_classes import ModelInformation from learning_loop_node.detector.tests.conftest import get_outbox_files from learning_loop_node.globals import GLOBALS @@ -88,15 +88,17 @@ async def test_sio_upload(test_detector_node: DetectorNode, sio_client): assert len(get_outbox_files(test_detector_node.outbox)) == 2, 'There 
should be one image and one .json file.' +# NOTE: This test seems to be flaky. async def test_about_endpoint(test_detector_node: DetectorNode): - await asyncio.sleep(1) + await asyncio.sleep(3) response = requests.get(f'http://localhost:{GLOBALS.detector_port}/about', timeout=30) assert response.status_code == 200 response_dict = json.loads(response.content) + assert response_dict['model_info'] model_information = ModelInformation.from_dict(response_dict['model_info']) assert response_dict['operation_mode'] == 'idle' assert response_dict['state'] == 'online' assert response_dict['target_model'] == '1.1' - assert any([c.name == 'purple point' for c in model_information.categories]) + assert any(c.name == 'purple point' for c in model_information.categories) diff --git a/learning_loop_node/detector/tests/test_outbox.py b/learning_loop_node/detector/tests/test_outbox.py index 9db7dd09..adf56744 100644 --- a/learning_loop_node/detector/tests/test_outbox.py +++ b/learning_loop_node/detector/tests/test_outbox.py @@ -9,6 +9,8 @@ from learning_loop_node.detector.detector_node import DetectorNode from learning_loop_node.detector.outbox import Outbox +# pylint: disable=redefined-outer-name + @pytest.fixture() def test_outbox(): diff --git a/learning_loop_node/detector/tests/testing_detector.py b/learning_loop_node/detector/tests/testing_detector.py index ed710824..95dd1300 100644 --- a/learning_loop_node/detector/tests/testing_detector.py +++ b/learning_loop_node/detector/tests/testing_detector.py @@ -4,7 +4,7 @@ from learning_loop_node import DetectorLogic from learning_loop_node.conftest import get_dummy_detections -from learning_loop_node.data_classes import Category, Detections, ModelInformation +from learning_loop_node.data_classes import Detections class TestingDetectorLogic(DetectorLogic): @@ -20,10 +20,3 @@ def init(self) -> None: def evaluate(self, image: np.ndarray) -> Detections: logging.info('evaluating') return self.det_to_return - - # return Detections( - # 
box_detections=[BoxDetection(category_name='some_category_name', x=1, y=2, height=3, width=4, - # model_name='some_model', confidence=.42, category_id='some_id')], - # point_detections=[PointDetection(category_name='some_category_name_2', x=10, y=12, - # model_name='some_model', confidence=.42, category_id='some_id')] - # ) diff --git a/learning_loop_node/globals.py b/learning_loop_node/globals.py index eee9511a..336df3fa 100644 --- a/learning_loop_node/globals.py +++ b/learning_loop_node/globals.py @@ -1,8 +1,8 @@ class Globals(): - def __init__(self): + def __init__(self) -> None: self.data_folder: str = '/data' - self.detector_port: int = 5004 # TODO move to tests + self.detector_port: int = 5004 # NOTE used for tests GLOBALS = Globals() diff --git a/learning_loop_node/helpers/gdrive_downloader.py b/learning_loop_node/helpers/gdrive_downloader.py index 8e5b3120..deefed68 100755 --- a/learning_loop_node/helpers/gdrive_downloader.py +++ b/learning_loop_node/helpers/gdrive_downloader.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -import requests +import requests # type: ignore # https://stackoverflow.com/a/39225272/4082686 diff --git a/learning_loop_node/helpers/misc.py b/learning_loop_node/helpers/misc.py index 3eda99c5..aea20e60 100644 --- a/learning_loop_node/helpers/misc.py +++ b/learning_loop_node/helpers/misc.py @@ -1,14 +1,21 @@ """original copied from https://quantlane.com/blog/ensure-asyncio-task-exceptions-get-logged/""" import asyncio import functools +import json import logging import os +import shutil +import sys from dataclasses import asdict +from glob import glob +from time import perf_counter from typing import Any, Coroutine, List, Optional, Tuple, TypeVar +from uuid import UUID, uuid4 import pynvml -from ..data_classes import SocketResponse +from ..data_classes import Context, SocketResponse, Training +from ..globals import GLOBALS T = TypeVar('T') @@ -48,7 +55,7 @@ def _handle_task_result(task: asyncio.Task, logger.exception(message, *message_args) 
-def get_free_memory_mb() -> float: # TODO check if this is used +def get_free_memory_mb() -> float: # NOTE used by yolov5 pynvml.nvmlInit() h = pynvml.nvmlDeviceGetHandleByIndex(0) info = pynvml.nvmlDeviceGetMemoryInfo(h) @@ -56,16 +63,33 @@ def get_free_memory_mb() -> float: # TODO check if this is used return free +async def is_valid_image(filename: str, check_jpeg: bool) -> bool: + if not os.path.isfile(filename) or os.path.getsize(filename) == 0: + return False + if not check_jpeg: + return True + + info = await asyncio.create_subprocess_shell(f'jpeginfo -c {filename}', + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE) + out, _ = await info.communicate() + return "OK" in out.decode() + + +async def delete_corrupt_images(image_folder: str, check_jpeg: bool = False) -> None: + logging.info('deleting corrupt images') + n_deleted = 0 + for image in glob(f'{image_folder}/*.jpg'): + if not await is_valid_image(image, check_jpeg): + logging.debug(f' deleting image {image}') + os.remove(image) + n_deleted += 1 + + logging.info(f'deleted {n_deleted} images') + + def create_resource_paths(organization_name: str, project_name: str, image_ids: List[str]) -> Tuple[List[str], List[str]]: - # TODO: experimental: return [f'/{organization_name}/projects/{project_name}/images/{id}/main' for id in image_ids], image_ids - # if not image_ids: - # return [], [] - # url_ids: List[Tuple(str, str)] = [(f'/{organization_name}/projects/{project_name}/images/{id}/main', id) - # for id in image_ids] - # urls, ids = list(map(list, zip(*url_ids))) - - # return urls, ids def create_image_folder(project_folder: str) -> str: @@ -74,6 +98,24 @@ def create_image_folder(project_folder: str) -> str: return image_folder +def read_or_create_uuid(identifier: str) -> str: + identifier = identifier.lower().replace(' ', '_') + uuids = {} + os.makedirs(GLOBALS.data_folder, exist_ok=True) + file_path = f'{GLOBALS.data_folder}/uuids.json' + if os.path.exists(file_path): + with 
open(file_path, 'r') as f: + uuids = json.load(f) + + uuid = uuids.get(identifier, None) + if not uuid: + uuid = str(uuid4()) + uuids[identifier] = uuid + with open(file_path, 'w') as f: + json.dump(uuids, f) + return uuid + + def ensure_socket_response(func): """Decorator to ensure that the return value of a socket.io event handler is a SocketResponse. @@ -90,20 +132,85 @@ async def wrapper_ensure_socket_response(*args, **kwargs): if isinstance(value, str): return asdict(SocketResponse.for_success(value)) - elif isinstance(value, bool): + if isinstance(value, bool): return asdict(SocketResponse.from_bool(value)) - elif isinstance(value, SocketResponse): + if isinstance(value, SocketResponse): return value - elif (args[0] in ['connect', 'disconnect', 'connect_error']): + if (args[0] in ['connect', 'disconnect', 'connect_error']): return value - elif value is None: + if value is None: return None - else: - raise Exception( - f"Return type for sio must be str, bool, SocketResponse or None', but was {type(value)}'") + + raise Exception( + f"Return type for sio must be str, bool, SocketResponse or None', but was {type(value)}'") except Exception as e: logging.exception(f'An error occured for {args[0]}') return asdict(SocketResponse.for_failure(str(e))) return wrapper_ensure_socket_response + + +def is_valid_uuid4(val): + if not val: + return False + try: + _ = UUID(str(val)).version + return True + except ValueError: + return False + + +def create_project_folder(context: Context) -> str: + project_folder = f'{GLOBALS.data_folder}/{context.organization}/{context.project}' + os.makedirs(project_folder, exist_ok=True) + return project_folder + + +def activate_asyncio_warnings() -> None: + '''Produce warnings for coroutines which take too long on the main loop and hence clog the event loop''' + try: + if sys.version_info.major >= 3 and sys.version_info.minor >= 7: # most + loop = asyncio.get_running_loop() + else: + loop = asyncio.get_event_loop() + + loop.set_debug(True) 
+ loop.slow_callback_duration = 0.2 + logging.info('activated asyncio warnings') + except Exception: + logging.exception('could not activate asyncio warnings. Exception:') + + +def images_for_ids(image_ids, image_folder) -> List[str]: + logging.info(f'### Going to get images for {len(image_ids)} images ids') + start = perf_counter() + images = [img for img in glob(f'{image_folder}/**/*.*', recursive=True) + if os.path.splitext(os.path.basename(img))[0] in image_ids] + end = perf_counter() + logging.info(f'found {len(images)} images for {len(image_ids)} image ids, which took {end-start:0.2f} seconds') + return images + + +def generate_training(project_folder: str, context: Context) -> Training: + training_uuid = str(uuid4()) + return Training( + id=training_uuid, + context=context, + project_folder=project_folder, + images_folder=create_image_folder(project_folder), + training_folder=create_training_folder(project_folder, training_uuid) + ) + + +def delete_all_training_folders(project_folder: str): + if not os.path.exists(f'{project_folder}/trainings'): + return + for uuid in os.listdir(f'{project_folder}/trainings'): + shutil.rmtree(f'{project_folder}/trainings/{uuid}', ignore_errors=True) + + +def create_training_folder(project_folder: str, trainings_id: str) -> str: + training_folder = f'{project_folder}/trainings/{trainings_id}' + os.makedirs(training_folder, exist_ok=True) + return training_folder diff --git a/learning_loop_node/loop_communication.py b/learning_loop_node/loop_communication.py index d4b3dadf..a643fec4 100644 --- a/learning_loop_node/loop_communication.py +++ b/learning_loop_node/loop_communication.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import List, Optional +from typing import Awaitable, Callable, List, Optional import httpx from httpx import Cookies, Timeout @@ -24,21 +24,21 @@ def __init__(self) -> None: self.project: str = environment_reader.project() # used by mock_detector self.base_url: str = f'http{"s" if 
"learning-loop.ai" in host else ""}://' + host self.async_client: httpx.AsyncClient = httpx.AsyncClient(base_url=self.base_url, timeout=Timeout(60.0)) + self.async_client.cookies.clear() logging.info(f'Loop interface initialized with base_url: {self.base_url} / user: {self.username}') - # @property - # def project_path(self): # TODO: remove? - # return f'/{self.organization}/projects/{self.project}' + def websocket_url(self) -> str: + return f'ws{"s" if "learning-loop.ai" in self.host else ""}://' + self.host - async def ensure_login(self) -> None: + async def ensure_login(self, relogin=False) -> None: """aiohttp client session needs to be created on the event loop""" assert not self.async_client.is_closed, 'async client must not be used after shutdown' - if not self.async_client.cookies.keys(): + if not self.async_client.cookies.keys() or relogin: + self.async_client.cookies.clear() response = await self.async_client.post('/api/login', data={'username': self.username, 'password': self.password}) if response.status_code != 200: - self.async_client.cookies.clear() logging.info(f'Login failed with response: {response}') raise LoopCommunicationException('Login failed with response: ' + str(response)) self.async_client.cookies.update(response.cookies) @@ -50,8 +50,9 @@ async def logout(self) -> None: if response.status_code != 200: logging.info(f'Logout failed with response: {response}') raise LoopCommunicationException('Logout failed with response: ' + str(response)) + self.async_client.cookies.clear() - async def get_cookies(self) -> Cookies: + def get_cookies(self) -> Cookies: return self.async_client.cookies async def shutdown(self): @@ -70,37 +71,68 @@ async def backend_ready(self) -> bool: logging.info(f'backend not ready: {e}') await asyncio.sleep(10) + async def retry_on_401(self, func: Callable[..., Awaitable[httpx.Response]], *args, **kwargs) -> httpx.Response: + response = await func(*args, **kwargs) + if response.status_code == 401: + await 
self.ensure_login(relogin=True) + response = await func(*args, **kwargs) + return response + async def get(self, path: str, requires_login: bool = True, api_prefix: str = '/api') -> httpx.Response: if requires_login: await self.ensure_login() + return await self.retry_on_401(self._get, path, api_prefix) + else: + return await self._get(path, api_prefix) + + async def _get(self, path: str, api_prefix: str) -> httpx.Response: return await self.async_client.get(api_prefix+path) - async def put(self, path, files: Optional[List[str]]=None, requires_login=True, api_prefix='/api', **kwargs) -> httpx.Response: + async def put(self, path: str, files: Optional[List[str]] = None, requires_login: bool = True, api_prefix: str = '/api', **kwargs) -> httpx.Response: if requires_login: await self.ensure_login() + return await self.retry_on_401(self._put, path, files, api_prefix, **kwargs) + else: + return await self._put(path, files, api_prefix, **kwargs) + + async def _put(self, path: str, files: Optional[List[str]], api_prefix: str, **kwargs) -> httpx.Response: if files is None: return await self.async_client.put(api_prefix+path, **kwargs) - - file_list = [('files', open(f, 'rb')) for f in files] # TODO: does this properly close the files after upload? 
- return await self.async_client.put(api_prefix+path, files=file_list) - async def post(self, path, requires_login=True, api_prefix='/api', **kwargs) -> httpx.Response: + file_handles = [] + for f in files: + try: + file_handles.append(open(f, 'rb')) + except FileNotFoundError: + for fh in file_handles: + fh.close() # Ensure all files are closed + return httpx.Response(404, content=b'File not found') + + try: + file_list = [('files', fh) for fh in file_handles] # Use file handles + response = await self.async_client.put(api_prefix+path, files=file_list) + finally: + for fh in file_handles: + fh.close() # Ensure all files are closed + + return response + + async def post(self, path: str, requires_login: bool = True, api_prefix: str = '/api', **kwargs) -> httpx.Response: if requires_login: await self.ensure_login() + return await self.retry_on_401(self._post, path, api_prefix, **kwargs) + else: + return await self._post(path, api_prefix, **kwargs) + + async def _post(self, path, api_prefix='/api', **kwargs) -> httpx.Response: return await self.async_client.post(api_prefix+path, **kwargs) - async def delete(self, path, requires_login=True, api_prefix='/api', **kwargs) -> httpx.Response: + async def delete(self, path: str, requires_login: bool = True, api_prefix: str = '/api', **kwargs) -> httpx.Response: if requires_login: await self.ensure_login() - return await self.async_client.delete(api_prefix+path, **kwargs) - - # --------------------------------- unused?! --------------------------------- #TODO remove? 
+ return await self.retry_on_401(self._delete, path, api_prefix, **kwargs) + else: + return await self._delete(path, api_prefix, **kwargs) - # def get_data(self, path): - # return asyncio.get_event_loop().run_until_complete(self._get_data_async(path)) - - # async def _get_data_async(self, path) -> bytes: - # response = await self.get(f'{self.project_path}{path}') - # if response.status_code != 200: - # raise LoopCommunicationException('bad response: ' + str(response)) - # return response.content + async def _delete(self, path, api_prefix, **kwargs) -> httpx.Response: + return await self.async_client.delete(api_prefix+path, **kwargs) diff --git a/learning_loop_node/node.py b/learning_loop_node/node.py index ffce72f7..9418123e 100644 --- a/learning_loop_node/node.py +++ b/learning_loop_node/node.py @@ -1,58 +1,58 @@ import asyncio -import json import logging -import os import sys from abc import abstractmethod +from contextlib import asynccontextmanager from datetime import datetime -from typing import Optional -from uuid import uuid4 +from typing import Any, Optional import aiohttp import socketio from fastapi import FastAPI -from fastapi_utils.tasks import repeat_every from socketio import AsyncClient -from .data_classes import Context, NodeState, NodeStatus +from .data_classes import NodeStatus from .data_exchanger import DataExchanger -from .globals import GLOBALS -from .helpers import environment_reader, log_conf -from .helpers.misc import ensure_socket_response +from .helpers import log_conf +from .helpers.misc import activate_asyncio_warnings, ensure_socket_response, read_or_create_uuid from .loop_communication import LoopCommunicator class Node(FastAPI): - def __init__(self, name: str, uuid: Optional[str] = None): + def __init__(self, name: str, uuid: Optional[str] = None, node_type: str = 'node', needs_login: bool = True): """Base class for all nodes. A node is a process that communicates with the zauberzeug learning loop. 
+ This class provides the basic functionality to connect to the learning loop via socket.io and to exchange data. Args: name (str): The name of the node. This name is used to generate a uuid. uuid (Optional[str]): The uuid of the node. If None, a uuid is generated based on the name and stored in f'{GLOBALS.data_folder}/uuids.json'. - From the second run, the uuid is recovered based on the name of the node. Defaults to None. + From the second run, the uuid is recovered based on the name of the node. + needs_login (bool): If True, the node will try to login to the learning loop. """ - super().__init__() + super().__init__(lifespan=self.lifespan) log_conf.init() + self.name = name + self.uuid = uuid or read_or_create_uuid(self.name) + self.needs_login = needs_login + self.log = logging.getLogger() self.loop_communicator = LoopCommunicator() + self.websocket_url = self.loop_communicator.websocket_url() self.data_exchanger = DataExchanger(None, self.loop_communicator) - host = environment_reader.host(default='learning-loop.ai') - self.ws_url = f'ws{"s" if "learning-loop.ai" in host else ""}://' + host - - self.name = name - self.uuid = self.read_or_create_uuid(self.name) if uuid is None else uuid - self.startup_time = datetime.now() + self.startup_datetime = datetime.now() self._sio_client: Optional[AsyncClient] = None self.status = NodeStatus(id=self.uuid, name=self.name) - # NOTE this is can be set to False for Nodes which do not need to authenticate with the backend (like the DetectorNode) - self.needs_login = True - self._setup_sio_headers() - self._register_lifecycle_events() + + self.sio_headers = {'organization': self.loop_communicator.organization, + 'project': self.loop_communicator.project, + 'nodeType': node_type} + + self.repeat_task: Any = None @property def sio_client(self) -> AsyncClient: @@ -60,52 +60,25 @@ def sio_client(self) -> AsyncClient: raise Exception('sio_client not yet initialized') return self._sio_client - def sio_is_initialized(self) -> 
bool: - return self._sio_client is not None - - # --------------------------------------------------- INIT --------------------------------------------------- - - def read_or_create_uuid(self, identifier: str) -> str: - identifier = identifier.lower().replace(' ', '_') - uuids = {} - os.makedirs(GLOBALS.data_folder, exist_ok=True) - file_path = f'{GLOBALS.data_folder}/uuids.json' - if os.path.exists(file_path): - with open(file_path, 'r') as f: - uuids = json.load(f) - - uuid = uuids.get(identifier, None) - if not uuid: - uuid = str(uuid4()) - uuids[identifier] = uuid - with open(file_path, 'w') as f: - json.dump(uuids, f) - return uuid - - def _setup_sio_headers(self) -> None: - self.sio_headers = {'organization': self.loop_communicator.organization, - 'project': self.loop_communicator.project, - 'nodeType': self.get_node_type()} - # --------------------------------------------------- APPLICATION LIFECYCLE --------------------------------------------------- - - def _register_lifecycle_events(self): - @self.on_event("startup") - async def startup(): + @asynccontextmanager + async def lifespan(self, app: FastAPI): # pylint: disable=unused-argument + try: await self._on_startup() - - @self.on_event("shutdown") # NOTE only used for developent ?! 
- async def shutdown(): + self.repeat_task = asyncio.create_task(self.repeat_loop()) + yield + finally: await self._on_shutdown() - - @self.on_event("startup") - @repeat_every(seconds=5, raise_exceptions=False, wait_first=False) - async def ensure_connected() -> None: - await self._on_repeat() + if self.repeat_task is not None: + self.repeat_task.cancel() + try: + await self.repeat_task + except asyncio.CancelledError: + pass async def _on_startup(self): self.log.info('received "startup" lifecycle-event') - Node._activate_asyncio_warnings() + # activate_asyncio_warnings() if self.needs_login: await self.loop_communicator.backend_ready() self.log.info('ensuring login') @@ -123,10 +96,18 @@ async def _on_shutdown(self): self.log.info('successfully disconnected from loop.') await self.on_shutdown() + async def repeat_loop(self) -> None: + """NOTE: with the lifespan approach, we cannot use @repeat_every anymore :(""" + while True: + try: + await self._on_repeat() + except asyncio.CancelledError: + return + except Exception as e: + self.log.exception(f'error in repeat loop: {e}') + await asyncio.sleep(5) + async def _on_repeat(self): - while not self.sio_is_initialized(): - self.log.info('Waiting for sio client to be initialized') - await asyncio.sleep(1) if not self.sio_client.connected: self.log.info('Reconnecting to loop via sio') await self.connect_sio() @@ -138,8 +119,11 @@ async def _on_repeat(self): # --------------------------------------------------- SOCKET.IO --------------------------------------------------- async def create_sio_client(self): - cookies = await self.loop_communicator.get_cookies() - self._sio_client = AsyncClient(request_timeout=20, http_session=aiohttp.ClientSession(cookies=cookies)) + """Create a socket.io client that communicates with the learning loop and register the events. 
+ Note: The method is called in startup and soft restart of detector, so the _sio_client should always be available.""" + + self._sio_client = AsyncClient(request_timeout=20, + http_session=aiohttp.ClientSession(cookies=self.loop_communicator.get_cookies())) # pylint: disable=protected-access self.sio_client._trigger_event = ensure_socket_response(self.sio_client._trigger_event) @@ -147,72 +131,39 @@ async def create_sio_client(self): @self._sio_client.event async def connect(): self.log.info('received "connect" via sio from loop.') - self.status = NodeStatus(id=self.uuid, name=self.name) - state = await self.get_state() - try: - await self._update_send_state(state) - except: - self.log.exception('Error sending state. Exception:') - raise @self._sio_client.event async def disconnect(): self.log.info('received "disconnect" via sio from loop.') - await self._update_send_state(NodeState.Offline) @self._sio_client.event async def restart(): - self.log.info('received "restart" via sio from loop.') - self.restart() + self.log.info('received "restart" via sio from loop -> restarting node.') + sys.exit(0) self.register_sio_events(self._sio_client) async def connect_sio(self): - if not self.sio_is_initialized(): - self.log.warning('sio client not yet initialized') - return try: await self.sio_client.disconnect() except Exception: pass - self.log.info(f'(re)connecting to Learning Loop at {self.ws_url}') + self.log.info(f'(re)connecting to Learning Loop at {self.websocket_url}') try: - await self.sio_client.connect(f"{self.ws_url}", headers=self.sio_headers, socketio_path="/ws/socket.io") + await self.sio_client.connect(f"{self.websocket_url}", headers=self.sio_headers, socketio_path="/ws/socket.io") self.log.info('connected to Learning Loop') except socketio.exceptions.ConnectionError: # type: ignore self.log.warning('connection error') except Exception: - self.log.exception(f'error while connecting to "{self.ws_url}". 
Exception:') - - async def _update_send_state(self, state: NodeState): - self.status.state = state - if self.status.state != NodeState.Offline: - await self.send_status() + self.log.exception(f'error while connecting to "{self.websocket_url}". Exception:') # --------------------------------------------------- ABSTRACT METHODS --------------------------------------------------- - @abstractmethod - def register_sio_events(self, sio_client: AsyncClient): - """Register socket.io events for the communication with the learning loop. - The events: connect and disconnect are already registered and should not be overwritten.""" - - @abstractmethod - async def send_status(self): - """Send the current status to the learning loop. - Note that currently this method is also used to react to the response of the learning loop.""" - - @abstractmethod - async def get_state(self) -> NodeState: - """Return the current state of the node.""" - - @abstractmethod - def get_node_type(self): - pass - @abstractmethod async def on_startup(self): - """This method is called when the node is started.""" + """This method is called when the node is started. 
+ Note: In this method the sio connection is not yet established!""" @abstractmethod async def on_shutdown(self): @@ -221,32 +172,8 @@ async def on_shutdown(self): @abstractmethod async def on_repeat(self): """This method is called every 10 seconds.""" - # --------------------------------------------------- SHARED FUNCTIONS --------------------------------------------------- - - def restart(self): - """Restart the node.""" - self.log.info('restarting node') - sys.exit(0) - - # --------------------------------------------------- HELPER --------------------------------------------------- - - @staticmethod - def create_project_folder(context: Context) -> str: - project_folder = f'{GLOBALS.data_folder}/{context.organization}/{context.project}' - os.makedirs(project_folder, exist_ok=True) - return project_folder - @staticmethod - def _activate_asyncio_warnings() -> None: - '''Produce warnings for coroutines which take too long on the main loop and hence clog the event loop''' - try: - if sys.version_info.major >= 3 and sys.version_info.minor >= 7: # most - loop = asyncio.get_running_loop() - else: - loop = asyncio.get_event_loop() - - loop.set_debug(True) - loop.slow_callback_duration = 0.2 - logging.info('activated asyncio warnings') - except Exception: - logging.exception('could not activate asyncio warnings. Exception:') + @abstractmethod + def register_sio_events(self, sio_client: AsyncClient): + """Register (additional) socket.io events for the communication with the learning loop. 
+ The events: connect, disconnect and restart are already registered and should not be overwritten.""" diff --git a/learning_loop_node/converter/__init__.py b/learning_loop_node/py.typed similarity index 100% rename from learning_loop_node/converter/__init__.py rename to learning_loop_node/py.typed diff --git a/learning_loop_node/tests/test_downloader.py b/learning_loop_node/tests/test_downloader.py index bf2e10e8..43ee4c6f 100644 --- a/learning_loop_node/tests/test_downloader.py +++ b/learning_loop_node/tests/test_downloader.py @@ -2,9 +2,10 @@ import shutil from learning_loop_node.data_classes import Context -from learning_loop_node.data_exchanger import DataExchanger, check_jpeg +from learning_loop_node.data_exchanger import DataExchanger from learning_loop_node.globals import GLOBALS +from ..helpers.misc import delete_corrupt_images from . import test_helper @@ -33,26 +34,26 @@ async def test_download_model(data_exchanger: DataExchanger): # pylint: disable=redefined-outer-name async def test_fetching_image_ids(data_exchanger: DataExchanger): - ids = await data_exchanger.fetch_image_ids() + ids = await data_exchanger.fetch_image_uuids() assert len(ids) == 3 async def test_download_images(data_exchanger: DataExchanger): _, image_folder, _ = test_helper.create_needed_folders() - image_ids = await data_exchanger.fetch_image_ids() + image_ids = await data_exchanger.fetch_image_uuids() await data_exchanger.download_images(image_ids, image_folder) files = test_helper.get_files_in_folder(GLOBALS.data_folder) assert len(files) == 3 async def test_download_training_data(data_exchanger: DataExchanger): - image_ids = await data_exchanger.fetch_image_ids() + image_ids = await data_exchanger.fetch_image_uuids() image_data = await data_exchanger.download_images_data(image_ids) assert len(image_data) == 3 async def test_removal_of_corrupted_images(data_exchanger: DataExchanger): - image_ids = await data_exchanger.fetch_image_ids() + image_ids = await 
data_exchanger.fetch_image_uuids() shutil.rmtree('/tmp/img_folder', ignore_errors=True) os.makedirs('/tmp/img_folder', exist_ok=True) @@ -65,7 +66,7 @@ async def test_removal_of_corrupted_images(data_exchanger: DataExchanger): with open('/tmp/img_folder/c1.jpg', 'w') as f: f.write('I am no image') - await data_exchanger.delete_corrupt_images('/tmp/img_folder') + await delete_corrupt_images('/tmp/img_folder', True) - assert len(os.listdir('/tmp/img_folder')) == num_images if check_jpeg else num_images - 1 + assert len(os.listdir('/tmp/img_folder')) == num_images if data_exchanger.check_jpeg else num_images - 1 shutil.rmtree('/tmp/img_folder', ignore_errors=True) diff --git a/learning_loop_node/tests/test_executor.py b/learning_loop_node/tests/test_executor.py index b661c818..1dbae97c 100644 --- a/learning_loop_node/tests/test_executor.py +++ b/learning_loop_node/tests/test_executor.py @@ -21,26 +21,28 @@ def cleanup(): cleanup_process.communicate() -def test_executor_lifecycle(): +@pytest.mark.asyncio +async def test_executor_lifecycle(): assert_process_is_running('some_executable.sh', False) - executor = Executor('/tmp/test_executor/' + str(uuid4())) - cmd = executor.path + '/some_executable.sh' - with open(cmd, 'w') as f: - f.write('while true; do echo "some output"; sleep 1; done') - os.chmod(cmd, 0o755) + executor = Executor('/tmp/test_executor/' + str(uuid4())+'/') + cmd = 'bash some_executable.sh' + executable_path = executor.path+'some_executable.sh' + with open(executable_path, 'w') as f: + f.write('/bin/bash -c "while true; do sleep 1; echo some output; done"') + os.chmod(executable_path, 0o755) - executor.start(cmd) + await executor.start(cmd) - assert executor.is_process_running() + assert executor.is_running() assert_process_is_running('some_executable.sh') - sleep(1) + sleep(5) assert 'some output' in executor.get_log() - executor.stop() + await executor.stop_and_wait() - assert not executor.is_process_running() + assert not executor.is_running() 
sleep(1) assert_process_is_running('some_executable.sh', False) @@ -48,6 +50,7 @@ def test_executor_lifecycle(): def assert_process_is_running(process_name, running=True): if running: for process in psutil.process_iter(): + print(process.name(), process.cmdline()) process_name_match = process_name in process.name() process_cmd_match = process_name in str(process.cmdline()) if process_name_match or process_cmd_match: diff --git a/learning_loop_node/tests/test_helper.py b/learning_loop_node/tests/test_helper.py index 88a94af2..c52037ed 100644 --- a/learning_loop_node/tests/test_helper.py +++ b/learning_loop_node/tests/test_helper.py @@ -7,10 +7,8 @@ from typing import Callable from learning_loop_node.data_classes import Context -from learning_loop_node.helpers.misc import create_image_folder +from learning_loop_node.helpers.misc import create_image_folder, create_project_folder, create_training_folder from learning_loop_node.loop_communication import LoopCommunicator -from learning_loop_node.node import Node -from learning_loop_node.trainer.trainer_logic import TrainerLogic def get_files_in_folder(folder: str): @@ -65,8 +63,8 @@ def _update_attribute_dict(obj: dict, **kwargs) -> None: def create_needed_folders(training_uuid: str = 'some_uuid'): # pylint: disable=unused-argument - project_folder = Node.create_project_folder( + project_folder = create_project_folder( Context(organization='zauberzeug', project='pytest')) image_folder = create_image_folder(project_folder) - training_folder = TrainerLogic.create_training_folder(project_folder, training_uuid) + training_folder = create_training_folder(project_folder, training_uuid) return project_folder, image_folder, training_folder diff --git a/learning_loop_node/trainer/downloader.py b/learning_loop_node/trainer/downloader.py index 94cd0516..7deb59cf 100644 --- a/learning_loop_node/trainer/downloader.py +++ b/learning_loop_node/trainer/downloader.py @@ -12,7 +12,7 @@ def __init__(self, data_exchanger: DataExchanger, 
data_query_params: Optional[st self.data_exchanger = data_exchanger async def download_training_data(self, image_folder: str) -> Tuple[List[Dict], int]: - image_ids = await self.data_exchanger.fetch_image_ids(query_params=self.data_query_params) + image_ids = await self.data_exchanger.fetch_image_uuids(query_params=self.data_query_params) image_data, skipped_image_count = await self.download_images_and_annotations(image_ids, image_folder) return (image_data, skipped_image_count) diff --git a/learning_loop_node/trainer/executor.py b/learning_loop_node/trainer/executor.py index c768332c..082407ad 100644 --- a/learning_loop_node/trainer/executor.py +++ b/learning_loop_node/trainer/executor.py @@ -1,105 +1,109 @@ - -import ctypes +import asyncio import logging import os -import signal -import subprocess -from sys import platform +import shlex +from io import BufferedWriter from typing import List, Optional -import psutil +class Executor: + def __init__(self, base_path: str, log_name='last_training.log') -> None: + """An executor that runs a command in a separate async subprocess. + The log of the process is written to 'last_training.log' in the base_path. + Tthe process is executed in the base_path directory. + The process should be awaited to finish using `wait` or stopped using `stop` to + avoid zombie processes and close the log file.""" -def create_signal_handler(sig=signal.SIGTERM): - if platform == "linux" or platform == "linux2": - # "The system will send a signal to the child once the parent exits for any reason (even sigkill)." 
- # https://stackoverflow.com/a/19448096 - libc = ctypes.CDLL("libc.so.6") + self.path = base_path + self.log_file_path = f'{self.path}/{log_name}' + self.log_file: None | BufferedWriter = None + self._process: Optional[asyncio.subprocess.Process] = None # pylint: disable=no-member + os.makedirs(self.path, exist_ok=True) - def callable_(): - os.setsid() - return libc.prctl(1, sig) + def _get_running_process(self) -> Optional[asyncio.subprocess.Process]: # pylint: disable=no-member + """Get the running process if available.""" + if self._process is not None and self._process.returncode is None: + return self._process + return None - return callable_ - return os.setsid + async def start(self, cmd: str, env: Optional[dict[str, str]] = None) -> None: + """Start the process with the given command and environment variables.""" + full_env = os.environ.copy() + if env is not None: + full_env.update(env) -class Executor: - def __init__(self, base_path: str) -> None: - self.path = base_path - os.makedirs(self.path, exist_ok=True) - self.process: Optional[subprocess.Popen[bytes]] = None - - def start(self, cmd: str): - with open(f'{self.path}/last_training.log', 'a') as f: - f.write(f'\nStarting executor with command: {cmd}\n') - # pylint: disable=subprocess-popen-preexec-fn - self.process = subprocess.Popen( - f'cd {self.path}; {cmd} >> last_training.log 2>&1', - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - executable='/bin/bash', - preexec_fn=create_signal_handler(), - ) + logging.info(f'Starting executor with command: {cmd} in {self.path} - logging to {self.log_file_path}') + self.log_file = open(self.log_file_path, 'ab') - def is_process_running(self): - if self.process is None: - return False + self._process = await asyncio.create_subprocess_exec( + *shlex.split(cmd), + cwd=self.path, + stdout=self.log_file, + stderr=asyncio.subprocess.STDOUT, # Merge stderr with stdout + env=full_env + ) - if self.process.poll() is not None: - return False + def 
is_running(self) -> bool: + """Check if the process is still running.""" + return self._process is not None and self._process.returncode is None - try: - psutil.Process(self.process.pid) - except psutil.NoSuchProcess: - # self.process.terminate() # TODO does this make sense? - # self.process = None - return False + def terminate(self) -> None: + """Terminate the process.""" - return True + if process := self._get_running_process(): + try: + process.terminate() + return + except ProcessLookupError: + logging.error('No process to terminate') + self._process = None - def get_log(self) -> str: - try: - with open(f'{self.path}/last_training.log') as f: - return f.read() - except Exception: - return '' + async def wait(self) -> Optional[int]: + """Wait for the process to finish. Returns the return code of the process or None if no process is running.""" - def get_log_by_lines(self, since_last_start=False) -> List[str]: # TODO do not read whole log again - try: - with open(f'{self.path}/last_training.log') as f: - lines = f.readlines() - if since_last_start: - lines_since_last_start = [] - for line in reversed(lines): - lines_since_last_start.append(line) - if line.startswith('Starting executor with command:'): - break - return list(reversed(lines_since_last_start)) - return lines - except Exception: - return [] + if not self._process: + logging.info('No process to wait for') + return None - def stop(self): - if self.process is None: - logging.info('no process running ... nothing to stop') - return + return_code = await self._process.wait() - logging.info('terminating process') + self.close_log() + self._process = None - try: - os.killpg(os.getpgid(self.process.pid), signal.SIGTERM) - except ProcessLookupError: - pass + return return_code - self.process.terminate() - _, _ = self.process.communicate(timeout=3) + async def stop_and_wait(self) -> Optional[int]: + """Terminate the process and wait for it to finish. 
Returns the return code of the process.""" - @property - def return_code(self): - if not self.process: - return None - if self.is_process_running(): + if not self.is_running(): + logging.info('No process to stop') return None - return self.process.poll() + + self.terminate() + return await self.wait() + + # -------------------------------------------------------------------------------------------- LOGGING + + def get_log(self) -> str: + """Get the log of the process as a string.""" + if not os.path.exists(self.log_file_path): + return '' + with open(self.log_file_path, 'r') as f: + return f.read() + + def get_log_by_lines(self, tail: Optional[int] = None) -> List[str]: + """Get the log of the process as a list of lines.""" + if not os.path.exists(self.log_file_path): + return [] + with open(self.log_file_path) as f: + lines = f.readlines() + if tail is not None: + lines = lines[-tail:] + return lines + + def close_log(self): + """Close the log file.""" + if self.log_file is not None: + self.log_file.close() + self.log_file = None diff --git a/learning_loop_node/trainer/io_helpers.py b/learning_loop_node/trainer/io_helpers.py index 3755f2f2..4849d67a 100644 --- a/learning_loop_node/trainer/io_helpers.py +++ b/learning_loop_node/trainer/io_helpers.py @@ -1,5 +1,6 @@ import json +import logging import os from dataclasses import asdict from pathlib import Path @@ -8,8 +9,19 @@ from dacite import from_dict from fastapi.encoders import jsonable_encoder -from ..data_classes import Detections, Training +from ..data_classes import Context, Detections, Training from ..globals import GLOBALS +from ..loop_communication import LoopCommunicator + + +class EnvironmentVars: + def __init__(self) -> None: + self.restart_after_training = os.environ.get( + 'RESTART_AFTER_TRAINING', 'FALSE').lower() in ['true', '1'] + self.keep_old_trainings = os.environ.get( + 'KEEP_OLD_TRAININGS', 'FALSE').lower() in ['true', '1'] + self.inference_batch_size = int( + 
os.environ.get('INFERENCE_BATCH_SIZE', '10')) class LastTrainingIO: @@ -35,13 +47,16 @@ def exists(self) -> bool: class ActiveTrainingIO: - @staticmethod - def create_mocked_training_io() -> 'ActiveTrainingIO': - training_folder = '' - return ActiveTrainingIO(training_folder) + # @staticmethod + # def create_mocked_training_io() -> 'ActiveTrainingIO': + # training_folder = '' + # return ActiveTrainingIO(training_folder) - def __init__(self, training_folder: str): + def __init__(self, training_folder: str, loop_communicator: LoopCommunicator, context: Context) -> None: self.training_folder = training_folder + self.loop_communicator = loop_communicator + self.context = context + self.mup_path = f'{training_folder}/model_uploading_progress.txt' # string with placeholder gor index self.det_path = f'{training_folder}' + '/detections_{0}.json' @@ -63,13 +78,16 @@ def load_model_upload_progress(self) -> List[str]: # detections - def get_detection_file_names(self) -> List[Path]: + def _get_detection_file_names(self) -> List[Path]: files = [f for f in Path(self.training_folder).iterdir() if f.is_file() and f.name.startswith('detections_')] if not files: return [] return files + def get_number_of_detection_files(self) -> int: + return len(self._get_detection_file_names()) + # TODO: saving and uploading multiple files is not tested! 
def save_detections(self, detections: List[Detections], index: int = 0) -> None: with open(self.det_path.format(index), 'w') as f: @@ -81,11 +99,11 @@ def load_detections(self, index: int = 0) -> List[Detections]: return [from_dict(data_class=Detections, data=d) for d in dict_list] def delete_detections(self) -> None: - for file in self.get_detection_file_names(): + for file in self._get_detection_file_names(): os.remove(Path(self.training_folder) / file) def detections_exist(self) -> bool: - return bool(self.get_detection_file_names()) + return bool(self._get_detection_file_names()) # detections upload file index @@ -124,3 +142,42 @@ def delete_detection_upload_progress(self) -> None: def detection_upload_progress_exist(self) -> bool: return os.path.exists(self.dup_path) + + async def upload_detetions(self): + num_files = self.get_number_of_detection_files() + print(f'num_files: {num_files}', flush=True) + if not num_files: + logging.error('no detection files found') + return + current_json_file_index = self.load_detections_upload_file_index() + for i in range(current_json_file_index, num_files): + detections = self.load_detections(i) + logging.info(f'uploading detections {i}/{num_files}') + await self._upload_detections_batched(self.context, detections) + self.save_detections_upload_file_index(i+1) + + async def _upload_detections_batched(self, context: Context, detections: List[Detections]): + batch_size = 10 + skip_detections = self.load_detection_upload_progress() + for i in range(skip_detections, len(detections), batch_size): + up_progress = i+batch_size + batch_detections = detections[i:up_progress] + dict_detections = [jsonable_encoder(asdict(detection)) for detection in batch_detections] + logging.info(f'uploading detections. 
File size : {len(json.dumps(dict_detections))}') + await self._upload_detections(context, batch_detections, up_progress) + skip_detections = up_progress + + async def _upload_detections(self, context: Context, batch_detections: List[Detections], up_progress: int): + detections_json = [jsonable_encoder(asdict(detections)) for detections in batch_detections] + response = await self.loop_communicator.post( + f'/{context.organization}/projects/{context.project}/detections', json=detections_json) + if response.status_code != 200: + msg = f'could not upload detections. {str(response)}' + logging.error(msg) + raise Exception(msg) + + logging.info('successfully uploaded detections') + if up_progress > len(batch_detections): + self.save_detection_upload_progress(0) + else: + self.save_detection_upload_progress(up_progress) diff --git a/learning_loop_node/trainer/rest/backdoor_controls.py b/learning_loop_node/trainer/rest/backdoor_controls.py index 8349e737..e2dafc26 100644 --- a/learning_loop_node/trainer/rest/backdoor_controls.py +++ b/learning_loop_node/trainer/rest/backdoor_controls.py @@ -5,10 +5,10 @@ from dataclasses import asdict from typing import TYPE_CHECKING, Dict -from dacite import from_dict from fastapi import APIRouter, HTTPException, Request from ...data_classes import ErrorConfiguration, NodeState +from ..trainer_logic import TrainerLogic if TYPE_CHECKING: from ..trainer_node import TrainerNode @@ -95,7 +95,9 @@ async def add_steps(request: Request): trainer_node = trainer_node_from_request(request) trainer_logic = trainer_node.trainer_logic # NOTE: is MockTrainerLogic which has 'provide_new_model' and 'current_iteration' - if not trainer_logic._executor or not trainer_logic._executor.is_process_running(): # pylint: disable=protected-access + assert isinstance(trainer_logic, TrainerLogic), 'trainer_logic is not TrainerLogic' + + if not trainer_logic._executor or not trainer_logic._executor.is_running(): # pylint: disable=protected-access training = 
trainer_logic._training # pylint: disable=protected-access logging.error(f'cannot add steps when not running, state: {training.training_state if training else "None"}') raise HTTPException(status_code=409, detail="trainer is not running") @@ -109,7 +111,7 @@ async def add_steps(request: Request): for _ in range(steps): try: logging.warning('calling sync_confusion_matrix') - await trainer_logic.sync_confusion_matrix() + await trainer_logic._sync_confusion_matrix() # pylint: disable=protected-access except Exception: pass # Tests can force synchroniation to fail, error state is reported to backend trainer_logic.provide_new_model = previous_state # type: ignore @@ -119,11 +121,14 @@ async def add_steps(request: Request): @router.post("/kill_training_process") async def kill_process(request: Request): + # pylint: disable=protected-access trainer_node = trainer_node_from_request(request) - if not trainer_node.trainer_logic._executor or not trainer_node.trainer_logic._executor.is_process_running(): + trainer_logic = trainer_node.trainer_logic + assert isinstance(trainer_logic, TrainerLogic), 'trainer_logic is not TrainerLogic' + if not trainer_logic._executor or not trainer_logic._executor.is_running(): raise HTTPException(status_code=409, detail="trainer is not running") - trainer_node.trainer_logic._executor.stop() + await trainer_logic._executor.stop_and_wait() @router.post("/force_status_update") diff --git a/learning_loop_node/trainer/rest/controls.py b/learning_loop_node/trainer/rest/controls.py index 17434d64..6c92d9a8 100644 --- a/learning_loop_node/trainer/rest/controls.py +++ b/learning_loop_node/trainer/rest/controls.py @@ -7,6 +7,8 @@ router = APIRouter() +# pylint: disable=protected-access + @router.post("/controls/detect/{organization}/{project}/{version}") async def operation_mode(organization: str, project: str, version: str, request: Request): @@ -22,5 +24,5 @@ async def operation_mode(organization: str, project: str, version: str, request: model_id = 
next(m for m in models if m['version'] == version)['id'] logging.info(model_id) trainer: TrainerLogic = request.app.trainer - await trainer.do_detections() + await trainer._do_detections() return "OK" diff --git a/learning_loop_node/trainer/tests/conftest.py b/learning_loop_node/trainer/tests/conftest.py index 75937920..aca1919c 100644 --- a/learning_loop_node/trainer/tests/conftest.py +++ b/learning_loop_node/trainer/tests/conftest.py @@ -10,6 +10,8 @@ from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic from learning_loop_node.trainer.trainer_node import TrainerNode +# pylint: disable=protected-access + logging.basicConfig(level=logging.INFO) # show ouptut from uvicorn server https://stackoverflow.com/a/66132186/364388 log_to_stderr(logging.INFO) @@ -24,16 +26,14 @@ async def test_initialized_trainer_node(): trainer = TestingTrainerLogic() node = TrainerNode(name='test', trainer_logic=trainer, uuid='NOD30000-0000-0000-0000-000000000000') - trainer._node = node # pylint: disable=protected-access - trainer.init_new_training(context=Context(organization='zauberzeug', project='demo'), - details={'categories': [], - 'id': '917d5c7f-403d-7e92-f95f-577f79c2273a', # version 1.2 of demo project - 'training_number': 0, - 'resolution': 800, - 'flip_rl': False, - 'flip_ud': False}) - - # pylint: disable=protected-access + trainer._node = node + trainer._init_new_training(context=Context(organization='zauberzeug', project='demo'), + details={'categories': [], + 'id': '917d5c7f-403d-7e92-f95f-577f79c2273a', # version 1.2 of demo project + 'training_number': 0, + 'resolution': 800, + 'flip_rl': False, + 'flip_ud': False}) await node._on_startup() yield node await node._on_shutdown() @@ -44,19 +44,17 @@ async def test_initialized_trainer(): trainer = TestingTrainerLogic() node = TrainerNode(name='test', trainer_logic=trainer, uuid='NODE-000-0000-0000-0000-000000000000') - # pylint: disable=protected-access - await node._on_startup() - 
trainer._node = node # pylint: disable=protected-access - trainer.init_new_training(context=Context(organization='zauberzeug', project='demo'), - details={'categories': [], - 'id': '917d5c7f-403d-7e92-f95f-577f79c2273a', # version 1.2 of demo project - 'training_number': 0, - 'resolution': 800, - 'flip_rl': False, - 'flip_ud': False}) + await node._on_startup() + trainer._node = node + trainer._init_new_training(context=Context(organization='zauberzeug', project='demo'), + details={'categories': [], + 'id': '917d5c7f-403d-7e92-f95f-577f79c2273a', # version 1.2 of demo project + 'training_number': 0, + 'resolution': 800, + 'flip_rl': False, + 'flip_ud': False}) yield trainer - # await node._on_shutdown() try: await node._on_shutdown() except Exception: @@ -66,10 +64,3 @@ async def test_initialized_trainer(): def is_port_in_use(port): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex(('localhost', port)) == 0 - - -# @pytest.fixture(autouse=True, scope='session') -# def initialize_active_training(): -# from learning_loop_node.trainer import active_training_module -# active_training_module.init('00000000-0000-0000-0000-000000000000') -# yield diff --git a/learning_loop_node/trainer/tests/states/test_state_cleanup.py b/learning_loop_node/trainer/tests/states/test_state_cleanup.py index 3326d156..f3911a54 100644 --- a/learning_loop_node/trainer/tests/states/test_state_cleanup.py +++ b/learning_loop_node/trainer/tests/states/test_state_cleanup.py @@ -1,11 +1,13 @@ from learning_loop_node.trainer.tests.state_helper import create_active_training_file from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic +# pylint: disable=protected-access + async def test_cleanup_successfull(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer create_active_training_file(trainer, training_state='ready_for_cleanup') - trainer.init_from_last_training() + trainer._init_from_last_training() 
trainer.active_training_io.save_detections(detections=[]) trainer.active_training_io.save_detection_upload_progress(count=42) @@ -16,9 +18,9 @@ async def test_cleanup_successfull(test_initialized_trainer: TestingTrainerLogic assert trainer.active_training_io.detection_upload_progress_exist() is True assert trainer.active_training_io.detections_upload_file_index_exists() is True - await trainer.clear_training() + await trainer._clear_training() - assert trainer._training is None # pylint: disable=protected-access + assert trainer._training is None assert trainer.node.last_training_io.exists() is False assert trainer.active_training_io.detections_exist() is False assert trainer.active_training_io.detection_upload_progress_exist() is False diff --git a/learning_loop_node/trainer/tests/states/test_state_detecting.py b/learning_loop_node/trainer/tests/states/test_state_detecting.py index d571a665..5492f8dc 100644 --- a/learning_loop_node/trainer/tests/states/test_state_detecting.py +++ b/learning_loop_node/trainer/tests/states/test_state_detecting.py @@ -1,11 +1,12 @@ import asyncio from learning_loop_node.conftest import get_dummy_detections -from learning_loop_node.data_classes import TrainingState +from learning_loop_node.data_classes import TrainerState from learning_loop_node.trainer.tests.state_helper import assert_training_state, create_active_training_file from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic from learning_loop_node.trainer.trainer_logic import TrainerLogic +# pylint: disable=protected-access error_key = 'detecting' @@ -13,60 +14,62 @@ def trainer_has_error(trainer: TrainerLogic): return trainer.errors.has_error_for(error_key) -async def test_successful_detecting(test_initialized_trainer: TestingTrainerLogic): # TODO Flaky test +async def test_successful_detecting(test_initialized_trainer: TestingTrainerLogic): # NOTE was a flaky test trainer = test_initialized_trainer create_active_training_file(trainer, 
training_state='train_model_uploaded', - model_id_for_detecting='917d5c7f-403d-7e92-f95f-577f79c2273a') + model_uuid_for_detecting='917d5c7f-403d-7e92-f95f-577f79c2273a') # trainer.load_active_training() - _ = asyncio.get_running_loop().create_task(trainer.do_detections()) + _ = asyncio.get_running_loop().create_task( + trainer._perform_state('do_detections', TrainerState.Detecting, TrainerState.Detected, trainer._do_detections)) - await assert_training_state(trainer.training, 'detecting', timeout=1, interval=0.001) - await assert_training_state(trainer.training, 'detected', timeout=10, interval=0.001) + await assert_training_state(trainer.training, TrainerState.Detecting, timeout=1, interval=0.001) + await assert_training_state(trainer.training, TrainerState.Detected, timeout=10, interval=0.001) assert trainer_has_error(trainer) is False - assert trainer.training.training_state == 'detected' + assert trainer.training.training_state == TrainerState.Detected assert trainer.node.last_training_io.load() == trainer.training assert trainer.active_training_io.detections_exist() async def test_detecting_can_be_aborted(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state=TrainingState.TrainModelUploaded) - trainer.init_from_last_training() - trainer.training.model_id_for_detecting = '12345678-bobo-7e92-f95f-424242424242' + create_active_training_file(trainer, training_state=TrainerState.TrainModelUploaded) + trainer._init_from_last_training() + trainer.training.model_uuid_for_detecting = '12345678-bobo-7e92-f95f-424242424242' - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = asyncio.get_running_loop().create_task(trainer._run()) - await assert_training_state(trainer.training, 'detecting', timeout=5, interval=0.001) + await assert_training_state(trainer.training, TrainerState.Detecting, timeout=5, interval=0.001) await trainer.stop() await asyncio.sleep(0.1) - assert 
trainer._training is None # pylint: disable=protected-access + assert trainer._training is None assert trainer.active_training_io.detections_exist() is False assert trainer.node.last_training_io.exists() is False async def test_model_not_downloadable_error(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='train_model_uploaded', - model_id_for_detecting='00000000-0000-0000-0000-000000000000') # bad model id - trainer.init_from_last_training() + create_active_training_file(trainer, training_state=TrainerState.TrainModelUploaded, + model_uuid_for_detecting='00000000-0000-0000-0000-000000000000') # bad model id + trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = asyncio.get_running_loop().create_task(trainer._run()) await assert_training_state(trainer.training, 'detecting', timeout=1, interval=0.001) await assert_training_state(trainer.training, 'train_model_uploaded', timeout=1, interval=0.001) + await asyncio.sleep(0.1) assert trainer_has_error(trainer) - assert trainer.training.training_state == 'train_model_uploaded' - assert trainer.training.model_id_for_detecting == '00000000-0000-0000-0000-000000000000' + assert trainer.training.training_state == TrainerState.TrainModelUploaded + assert trainer.training.model_uuid_for_detecting == '00000000-0000-0000-0000-000000000000' assert trainer.node.last_training_io.load() == trainer.training def test_save_load_detections(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer create_active_training_file(trainer) - trainer.init_from_last_training() + trainer._init_from_last_training() detections = [get_dummy_detections(), get_dummy_detections()] diff --git a/learning_loop_node/trainer/tests/states/test_state_download_train_model.py b/learning_loop_node/trainer/tests/states/test_state_download_train_model.py index 687e5060..282a2288 100644 --- 
a/learning_loop_node/trainer/tests/states/test_state_download_train_model.py +++ b/learning_loop_node/trainer/tests/states/test_state_download_train_model.py @@ -2,22 +2,28 @@ import asyncio import os +from learning_loop_node.data_classes import TrainerState from learning_loop_node.trainer.tests.state_helper import assert_training_state, create_active_training_file from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic +# pylint: disable=protected-access + async def test_downloading_is_successful(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='data_downloaded') + create_active_training_file(trainer, training_state=TrainerState.DataDownloaded) trainer.model_format = 'mocked' - trainer.init_from_last_training() + trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.download_model()) + asyncio.get_running_loop().create_task( + trainer._perform_state('download_model', + TrainerState.TrainModelDownloading, + TrainerState.TrainModelDownloaded, trainer._download_model)) await assert_training_state(trainer.training, 'train_model_downloading', timeout=1, interval=0.001) await assert_training_state(trainer.training, 'train_model_downloaded', timeout=1, interval=0.001) - assert trainer.training.training_state == 'train_model_downloaded' + assert trainer.training.training_state == TrainerState.TrainModelDownloaded assert trainer.node.last_training_io.load() == trainer.training # file on disk @@ -29,9 +35,9 @@ async def test_downloading_is_successful(test_initialized_trainer: TestingTraine async def test_abort_download_model(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer create_active_training_file(trainer, training_state='data_downloaded') - trainer.init_from_last_training() + trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = 
asyncio.get_running_loop().create_task(trainer._run()) await assert_training_state(trainer.training, 'train_model_downloading', timeout=1, interval=0.001) await trainer.stop() @@ -43,15 +49,15 @@ async def test_abort_download_model(test_initialized_trainer: TestingTrainerLogi async def test_downloading_failed(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='data_downloaded', - base_model_id='00000000-0000-0000-0000-000000000000') # bad model id) - trainer.init_from_last_training() + create_active_training_file(trainer, training_state=TrainerState.DataDownloaded, + base_model_uuid_or_name='00000000-0000-0000-0000-000000000000') # bad model id) + trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = asyncio.get_running_loop().create_task(trainer._run()) await assert_training_state(trainer.training, 'train_model_downloading', timeout=1, interval=0.001) - await assert_training_state(trainer.training, 'data_downloaded', timeout=1, interval=0.001) + await assert_training_state(trainer.training, TrainerState.DataDownloaded, timeout=1, interval=0.001) assert trainer.errors.has_error_for('download_model') assert trainer._training is not None # pylint: disable=protected-access - assert trainer.training.training_state == 'data_downloaded' + assert trainer.training.training_state == TrainerState.DataDownloaded assert trainer.node.last_training_io.load() == trainer.training diff --git a/learning_loop_node/trainer/tests/states/test_state_prepare.py b/learning_loop_node/trainer/tests/states/test_state_prepare.py index 9d2eedcc..d3222f9a 100644 --- a/learning_loop_node/trainer/tests/states/test_state_prepare.py +++ b/learning_loop_node/trainer/tests/states/test_state_prepare.py @@ -1,10 +1,11 @@ import asyncio -from learning_loop_node.data_classes import Context +from learning_loop_node.data_classes import Context, TrainerState from 
learning_loop_node.trainer.tests.state_helper import assert_training_state, create_active_training_file from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic from learning_loop_node.trainer.trainer_logic import TrainerLogic +# pylint: disable=protected-access error_key = 'prepare' @@ -15,11 +16,11 @@ def trainer_has_error(trainer: TrainerLogic): async def test_preparing_is_successful(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer create_active_training_file(trainer) - trainer.init_from_last_training() + trainer._init_from_last_training() - await trainer.prepare() + await trainer._perform_state('prepare', TrainerState.DataDownloading, TrainerState.DataDownloaded, trainer._prepare) assert trainer_has_error(trainer) is False - assert trainer.training.training_state == 'data_downloaded' + assert trainer.training.training_state == TrainerState.DataDownloaded assert trainer.training.data is not None assert trainer.node.last_training_io.load() == trainer.training @@ -27,10 +28,10 @@ async def test_preparing_is_successful(test_initialized_trainer: TestingTrainerL async def test_abort_preparing(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer create_active_training_file(trainer) - trainer.init_from_last_training() + trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.run()) - await assert_training_state(trainer.training, 'data_downloading', timeout=1, interval=0.001) + _ = asyncio.get_running_loop().create_task(trainer._run()) + await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=1, interval=0.001) await trainer.stop() await asyncio.sleep(0.1) @@ -43,13 +44,13 @@ async def test_request_error(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer create_active_training_file(trainer, context=Context( organization='zauberzeug', project='some_bad_project')) - 
trainer.init_from_last_training() + trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.run()) - await assert_training_state(trainer.training, 'data_downloading', timeout=3, interval=0.001) - await assert_training_state(trainer.training, 'initialized', timeout=3, interval=0.001) + _ = asyncio.get_running_loop().create_task(trainer._run()) + await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=3, interval=0.001) + await assert_training_state(trainer.training, TrainerState.Initialized, timeout=3, interval=0.001) assert trainer_has_error(trainer) assert trainer._training is not None # pylint: disable=protected-access - assert trainer.training.training_state == 'initialized' + assert trainer.training.training_state == TrainerState.Initialized assert trainer.node.last_training_io.load() == trainer.training diff --git a/learning_loop_node/trainer/tests/states/test_state_sync_confusion_matrix.py b/learning_loop_node/trainer/tests/states/test_state_sync_confusion_matrix.py index b6cce7c2..6a292be5 100644 --- a/learning_loop_node/trainer/tests/states/test_state_sync_confusion_matrix.py +++ b/learning_loop_node/trainer/tests/states/test_state_sync_confusion_matrix.py @@ -3,12 +3,15 @@ from pytest_mock import MockerFixture # pip install pytest-mock +from learning_loop_node.data_classes import TrainerState from learning_loop_node.trainer.trainer_logic import TrainerLogic from learning_loop_node.trainer.trainer_node import TrainerNode from ..state_helper import assert_training_state, create_active_training_file from ..testing_trainer_logic import TestingTrainerLogic +# pylint: disable=protected-access + error_key = 'sync_confusion_matrix' @@ -21,14 +24,14 @@ async def test_nothing_to_sync(test_initialized_trainer: TestingTrainerLogic): # TODO this requires trainer to have _training # trainer.load_active_training() - create_active_training_file(trainer, training_state='training_finished') - 
trainer.init_from_last_training() + create_active_training_file(trainer, training_state=TrainerState.TrainingFinished) + trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = asyncio.get_running_loop().create_task(trainer._run()) - await assert_training_state(trainer.training, 'confusion_matrix_synced', timeout=1, interval=0.001) + await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSynced, timeout=1, interval=0.001) assert trainer_has_error(trainer) is False - assert trainer.training.training_state == 'confusion_matrix_synced' + assert trainer.training.training_state == TrainerState.ConfusionMatrixSynced assert trainer.node.last_training_io.load() == trainer.training @@ -37,16 +40,16 @@ async def test_unsynced_model_available__sync_successful(test_initialized_traine assert isinstance(trainer, TestingTrainerLogic) await mock_socket_io_call(mocker, test_initialized_trainer_node, {'success': True}) - create_active_training_file(trainer, training_state='training_finished') + create_active_training_file(trainer, training_state=TrainerState.TrainingFinished) - trainer.init_from_last_training() + trainer._init_from_last_training() trainer.has_new_model = True - _ = asyncio.get_running_loop().create_task(trainer.run()) - await assert_training_state(trainer.training, 'confusion_matrix_synced', timeout=1, interval=0.001) + _ = asyncio.get_running_loop().create_task(trainer._run()) + await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSynced, timeout=1, interval=0.001) assert trainer_has_error(trainer) is False -# assert trainer.training.training_state == 'confusion_matrix_synced' +# assert trainer.training.training_state == TrainerState.ConfusionMatrixSynced assert trainer.node.last_training_io.load() == trainer.training @@ -54,18 +57,18 @@ async def test_unsynced_model_available__sio_not_connected(test_initialized_trai trainer = test_initialized_trainer_node.trainer_logic assert 
isinstance(trainer, TestingTrainerLogic) - create_active_training_file(trainer, training_state='training_finished') + create_active_training_file(trainer, training_state=TrainerState.TrainingFinished) assert test_initialized_trainer_node.sio_client.connected is False trainer.has_new_model = True - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = asyncio.get_running_loop().create_task(trainer._run()) await assert_training_state(trainer.training, 'confusion_matrix_syncing', timeout=1, interval=0.001) - await assert_training_state(trainer.training, 'training_finished', timeout=1, interval=0.001) + await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=1, interval=0.001) assert trainer_has_error(trainer) - assert trainer.training.training_state == 'training_finished' + assert trainer.training.training_state == TrainerState.TrainingFinished assert trainer.node.last_training_io.load() == trainer.training @@ -75,16 +78,16 @@ async def test_unsynced_model_available__request_is_not_successful(test_initiali await mock_socket_io_call(mocker, test_initialized_trainer_node, {'success': False}) - create_active_training_file(trainer, training_state='training_finished') + create_active_training_file(trainer, training_state=TrainerState.TrainingFinished) trainer.has_new_model = True - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = asyncio.get_running_loop().create_task(trainer._run()) await assert_training_state(trainer.training, 'confusion_matrix_syncing', timeout=1, interval=0.001) - await assert_training_state(trainer.training, 'training_finished', timeout=1, interval=0.001) + await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=1, interval=0.001) assert trainer_has_error(trainer) - assert trainer.training.training_state == 'training_finished' + assert trainer.training.training_state == TrainerState.TrainingFinished assert trainer.node.last_training_io.load() == trainer.training diff 
--git a/learning_loop_node/trainer/tests/states/test_state_train.py b/learning_loop_node/trainer/tests/states/test_state_train.py index c46294ba..4e1d200c 100644 --- a/learning_loop_node/trainer/tests/states/test_state_train.py +++ b/learning_loop_node/trainer/tests/states/test_state_train.py @@ -1,48 +1,49 @@ import asyncio +from learning_loop_node.data_classes import TrainerState from learning_loop_node.tests.test_helper import condition from learning_loop_node.trainer.tests.state_helper import assert_training_state, create_active_training_file from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic +# pylint: disable=protected-access + async def test_successful_training(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='train_model_downloaded') - trainer.init_from_last_training() + create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded) + trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = asyncio.get_running_loop().create_task(trainer._run()) - await assert_training_state(trainer.training, 'training_running', timeout=1, interval=0.001) + await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01) + await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.01) assert trainer.start_training_task is not None - assert trainer.start_training_task.__name__ == 'start_training' - # pylint: disable=protected-access assert trainer._executor is not None - trainer._executor.stop() # NOTE normally a training terminates itself - await assert_training_state(trainer.training, 'training_finished', timeout=1, interval=0.001) + await trainer.stop() # NOTE normally a training terminates itself + await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=1, interval=0.001) - 
assert trainer.training.training_state == 'training_finished' + assert trainer.training.training_state == TrainerState.TrainingFinished assert trainer.node.last_training_io.load() == trainer.training async def test_stop_running_training(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='train_model_downloaded') - trainer.init_from_last_training() + create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded) + trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = asyncio.get_running_loop().create_task(trainer._run()) - await condition(lambda: trainer._executor and trainer._executor.is_process_running(), timeout=1, interval=0.01) # pylint: disable=protected-access - await assert_training_state(trainer.training, 'training_running', timeout=1, interval=0.001) + await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01) + await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.01) assert trainer.start_training_task is not None - assert trainer.start_training_task.__name__ == 'start_training' await trainer.stop() - await assert_training_state(trainer.training, 'training_finished', timeout=1, interval=0.001) + await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=2, interval=0.01) - assert trainer.training.training_state == 'training_finished' + assert trainer.training.training_state == TrainerState.TrainingFinished assert trainer.node.last_training_io.load() == trainer.training @@ -50,21 +51,19 @@ async def test_training_can_maybe_resumed(test_initialized_trainer: TestingTrain trainer = test_initialized_trainer # NOTE e.g. 
when a node-computer is restarted - create_active_training_file(trainer, training_state='train_model_downloaded') - trainer.init_from_last_training() - trainer._can_resume = True # pylint: disable=protected-access + create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded) + trainer._init_from_last_training() + trainer._can_resume_flag = True - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = asyncio.get_running_loop().create_task(trainer._run()) - await condition(lambda: trainer._executor and trainer._executor.is_process_running(), timeout=1, interval=0.01) # pylint: disable=protected-access - await assert_training_state(trainer.training, 'training_running', timeout=1, interval=0.001) + await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01) + await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.001) assert trainer.start_training_task is not None - assert trainer.start_training_task.__name__ == 'resume' - # pylint: disable=protected-access assert trainer._executor is not None - trainer._executor.stop() # NOTE normally a training terminates itself e.g - await assert_training_state(trainer.training, 'training_finished', timeout=1, interval=0.001) + await trainer._executor.stop_and_wait() # NOTE normally a training terminates itself e.g + await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=1, interval=0.001) - assert trainer.training.training_state == 'training_finished' + assert trainer.training.training_state == TrainerState.TrainingFinished assert trainer.node.last_training_io.load() == trainer.training diff --git a/learning_loop_node/trainer/tests/states/test_state_upload_detections.py b/learning_loop_node/trainer/tests/states/test_state_upload_detections.py index ca6912d1..e2784514 100644 --- a/learning_loop_node/trainer/tests/states/test_state_upload_detections.py +++ 
b/learning_loop_node/trainer/tests/states/test_state_upload_detections.py @@ -4,12 +4,13 @@ from dacite import from_dict from learning_loop_node.conftest import get_dummy_detections -from learning_loop_node.data_classes import BoxDetection, Context, Detections +from learning_loop_node.data_classes import BoxDetection, Context, Detections, TrainerState from learning_loop_node.loop_communication import LoopCommunicator from learning_loop_node.trainer.tests.state_helper import assert_training_state, create_active_training_file from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic from learning_loop_node.trainer.trainer_logic import TrainerLogic +# pylint: disable=protected-access error_key = 'upload_detections' @@ -43,13 +44,14 @@ async def create_valid_detection_file(trainer: TrainerLogic, number_of_entries: @pytest.mark.asyncio async def test_upload_successful(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='detected') - trainer.init_from_last_training() + create_active_training_file(trainer, training_state=TrainerState.Detected) + trainer._init_from_last_training() await create_valid_detection_file(trainer) - await trainer.upload_detections() + await asyncio.get_running_loop().create_task( + trainer._perform_state('upload_detections', TrainerState.DetectionUploading, TrainerState.ReadyForCleanup, trainer.active_training_io.upload_detetions)) - assert trainer.training.training_state == 'ready_for_cleanup' + assert trainer.training.training_state == TrainerState.ReadyForCleanup assert trainer.node.last_training_io.load() == trainer.training @@ -57,13 +59,16 @@ async def test_upload_successful(test_initialized_trainer: TestingTrainerLogic): async def test_detection_upload_progress_is_stored(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='detected') - 
trainer.init_from_last_training() + create_active_training_file(trainer, training_state=TrainerState.Detected) + trainer._init_from_last_training() await create_valid_detection_file(trainer) assert trainer.active_training_io.load_detections_upload_file_index() == 0 - await trainer.upload_detections() + # await trainer.upload_detections() + await asyncio.get_running_loop().create_task( + trainer._perform_state('upload_detections', TrainerState.DetectionUploading, TrainerState.ReadyForCleanup, trainer.active_training_io.upload_detetions)) + assert trainer.active_training_io.load_detection_upload_progress() == 0 # Progress is reset for every file assert trainer.active_training_io.load_detections_upload_file_index() == 1 @@ -72,8 +77,8 @@ async def test_detection_upload_progress_is_stored(test_initialized_trainer: Tes async def test_ensure_all_detections_are_uploaded(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='detected') - trainer.init_from_last_training() + create_active_training_file(trainer, training_state=TrainerState.Detected) + trainer._init_from_last_training() await create_valid_detection_file(trainer, 2, 0) await create_valid_detection_file(trainer, 2, 1) @@ -87,7 +92,7 @@ async def test_ensure_all_detections_are_uploaded(test_initialized_trainer: Test for i in range(skip_detections, len(detections), batch_size): batch_detections = detections[i:i+batch_size] # pylint: disable=protected-access - await trainer._upload_detections(trainer.training.context, batch_detections, i + batch_size) + await trainer.active_training_io._upload_detections(trainer.training.context, batch_detections, i + batch_size) expected_value = i + batch_size if i + batch_size < len(detections) else 0 # Progress is reset for every file assert trainer.active_training_io.load_detection_upload_progress() == expected_value @@ -103,7 +108,7 @@ async def 
test_ensure_all_detections_are_uploaded(test_initialized_trainer: Test for i in range(skip_detections, len(detections), batch_size): batch_detections = detections[i:i+batch_size] # pylint: disable=protected-access - await trainer._upload_detections(trainer.training.context, batch_detections, i + batch_size) + await trainer.active_training_io._upload_detections(trainer.training.context, batch_detections, i + batch_size) expected_value = i + batch_size if i + batch_size < len(detections) else 0 # Progress is reset for every file assert trainer.active_training_io.load_detection_upload_progress() == expected_value @@ -114,46 +119,43 @@ async def test_ensure_all_detections_are_uploaded(test_initialized_trainer: Test async def test_bad_status_from_LearningLoop(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='detected', context=Context( + create_active_training_file(trainer, training_state=TrainerState.Detected, context=Context( organization='zauberzeug', project='some_bad_project')) - trainer.init_from_last_training() + trainer._init_from_last_training() trainer.active_training_io.save_detections([get_dummy_detections()]) - _ = asyncio.get_running_loop().create_task(trainer.run()) - await assert_training_state(trainer.training, 'detection_uploading', timeout=1, interval=0.001) - await assert_training_state(trainer.training, 'detected', timeout=1, interval=0.001) + _ = asyncio.get_running_loop().create_task(trainer._run()) + await assert_training_state(trainer.training, TrainerState.DetectionUploading, timeout=1, interval=0.001) + await assert_training_state(trainer.training, TrainerState.Detected, timeout=1, interval=0.001) assert trainer_has_error(trainer) - assert trainer.training.training_state == 'detected' + assert trainer.training.training_state == TrainerState.Detected assert trainer.node.last_training_io.load() == trainer.training -async def 
test_other_errors(test_initialized_trainer: TestingTrainerLogic): +async def test_go_to_cleanup_if_no_detections_exist(test_initialized_trainer: TestingTrainerLogic): + """This test simulates a situation where the detection file is missing. + In this case, the trainer should report an error and move to the ReadyForCleanup state.""" trainer = test_initialized_trainer # e.g. missing detection file - create_active_training_file(trainer, training_state='detected') - trainer.init_from_last_training() - - _ = asyncio.get_running_loop().create_task(trainer.run()) - await assert_training_state(trainer.training, 'detection_uploading', timeout=1, interval=0.001) - await assert_training_state(trainer.training, 'detected', timeout=1, interval=0.001) + create_active_training_file(trainer, training_state=TrainerState.Detected) + trainer._init_from_last_training() - assert trainer_has_error(trainer) - assert trainer.training.training_state == 'detected' - assert trainer.node.last_training_io.load() == trainer.training + _ = asyncio.get_running_loop().create_task(trainer._run()) + await assert_training_state(trainer.training, TrainerState.ReadyForCleanup, timeout=1, interval=0.001) async def test_abort_uploading(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='detected') - trainer.init_from_last_training() + create_active_training_file(trainer, training_state=TrainerState.Detected) + trainer._init_from_last_training() await create_valid_detection_file(trainer) - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = asyncio.get_running_loop().create_task(trainer._run()) - await assert_training_state(trainer.training, 'detection_uploading', timeout=1, interval=0.001) + await assert_training_state(trainer.training, TrainerState.DetectionUploading, timeout=1, interval=0.001) await trainer.stop() await asyncio.sleep(0.1) diff --git 
a/learning_loop_node/trainer/tests/states/test_state_upload_model.py b/learning_loop_node/trainer/tests/states/test_state_upload_model.py index 05eaa8ed..b2bfa4c7 100644 --- a/learning_loop_node/trainer/tests/states/test_state_upload_model.py +++ b/learning_loop_node/trainer/tests/states/test_state_upload_model.py @@ -2,11 +2,12 @@ from pytest_mock import MockerFixture -from learning_loop_node.data_classes import Context +from learning_loop_node.data_classes import Context, TrainerState from learning_loop_node.trainer.tests.state_helper import assert_training_state, create_active_training_file from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic from learning_loop_node.trainer.trainer_logic import TrainerLogic +# pylint: disable=protected-access error_key = 'upload_model' @@ -19,28 +20,29 @@ async def test_successful_upload(mocker: MockerFixture, test_initialized_trainer mock_upload_model_for_training(mocker, 'new_model_id') create_active_training_file(trainer) - trainer.init_from_last_training() + trainer._init_from_last_training() - train_task = asyncio.get_running_loop().create_task(trainer.upload_model()) + train_task = asyncio.get_running_loop().create_task( + trainer._perform_state('upload_model', TrainerState.TrainModelUploading, TrainerState.TrainModelUploaded, trainer._upload_model)) - await assert_training_state(trainer.training, 'train_model_uploading', timeout=1, interval=0.001) + await assert_training_state(trainer.training, TrainerState.TrainModelUploading, timeout=1, interval=0.001) await train_task assert trainer_has_error(trainer) is False - assert trainer.training.training_state == 'train_model_uploaded' - assert trainer.training.model_id_for_detecting is not None + assert trainer.training.training_state == TrainerState.TrainModelUploaded + assert trainer.training.model_uuid_for_detecting is not None assert trainer.node.last_training_io.load() == trainer.training async def 
test_abort_upload_model(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='confusion_matrix_synced') - trainer.init_from_last_training() + create_active_training_file(trainer, training_state=TrainerState.ConfusionMatrixSynced) + trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = asyncio.get_running_loop().create_task(trainer._run()) - await assert_training_state(trainer.training, 'train_model_uploading', timeout=1, interval=0.001) + await assert_training_state(trainer.training, TrainerState.TrainModelUploading, timeout=1, interval=0.001) await trainer.stop() await asyncio.sleep(0.1) @@ -55,18 +57,18 @@ async def test_bad_server_response_content(test_initialized_trainer: TestingTrai The training should be aborted and the training state should be set to confusion_matrix_synced.""" trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='confusion_matrix_synced') - trainer.init_from_last_training() + create_active_training_file(trainer, training_state=TrainerState.ConfusionMatrixSynced) + trainer._init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.run()) + _ = asyncio.get_running_loop().create_task(trainer._run()) - await assert_training_state(trainer.training, 'train_model_uploading', timeout=1, interval=0.001) + await assert_training_state(trainer.training, TrainerState.TrainModelUploading, timeout=1, interval=0.001) # TODO goes to finished because of the error - await assert_training_state(trainer.training, 'confusion_matrix_synced', timeout=2, interval=0.001) + await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSynced, timeout=2, interval=0.001) assert trainer_has_error(trainer) - assert trainer.training.training_state == 'confusion_matrix_synced' - assert trainer.training.model_id_for_detecting is None + assert trainer.training.training_state == 
TrainerState.ConfusionMatrixSynced + assert trainer.training.model_uuid_for_detecting is None assert trainer.node.last_training_io.load() == trainer.training @@ -76,12 +78,12 @@ async def test_mock_loop_response_example(mocker: MockerFixture, test_initialize mock_upload_model_for_training(mocker, 'new_model_id') create_active_training_file(trainer) - trainer.init_from_last_training() + trainer._init_from_last_training() # pylint: disable=protected-access - result = await trainer._upload_model_return_new_id(Context(organization='zauberzeug', project='demo')) + result = await trainer._upload_model_return_new_model_uuid(Context(organization='zauberzeug', project='demo')) assert result is not None def mock_upload_model_for_training(mocker, return_value): - mocker.patch('learning_loop_node.data_exchanger.DataExchanger.upload_model_for_training', return_value=return_value) + mocker.patch('learning_loop_node.data_exchanger.DataExchanger.upload_model_get_uuid', return_value=return_value) diff --git a/learning_loop_node/trainer/tests/test_errors.py b/learning_loop_node/trainer/tests/test_errors.py index bb6b3d8a..9a9c1cd8 100644 --- a/learning_loop_node/trainer/tests/test_errors.py +++ b/learning_loop_node/trainer/tests/test_errors.py @@ -1,37 +1,45 @@ import asyncio import re +import pytest + +from learning_loop_node.data_classes import TrainerState from learning_loop_node.trainer.tests.state_helper import assert_training_state, create_active_training_file from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic +# pylint: disable=protected-access + async def test_training_process_is_stopped_when_trainer_reports_error(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='train_model_downloaded') - trainer.init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.run()) + create_active_training_file(trainer, 
training_state=TrainerState.TrainModelDownloaded) + trainer._init_from_last_training() + _ = asyncio.get_running_loop().create_task(trainer._run()) - await assert_training_state(trainer.training, 'training_running', timeout=1, interval=0.001) + await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.001) trainer.error_msg = 'some_error' - await assert_training_state(trainer.training, 'train_model_downloaded', timeout=6, interval=0.001) + await assert_training_state(trainer.training, TrainerState.TrainModelDownloaded, timeout=6, interval=0.001) +@pytest.mark.skip(reason='The since_last_start flag is deprecated.') async def test_log_can_provide_only_data_for_current_run(test_initialized_trainer: TestingTrainerLogic): trainer = test_initialized_trainer - create_active_training_file(trainer, training_state='train_model_downloaded') - trainer.init_from_last_training() - _ = asyncio.get_running_loop().create_task(trainer.run()) + create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded) + trainer._init_from_last_training() + _ = asyncio.get_running_loop().create_task(trainer._run()) + + await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.001) + await asyncio.sleep(0.1) # give tests a bit time to to check for the state - await assert_training_state(trainer.training, 'training_running', timeout=1, interval=0.001) assert trainer._executor is not None assert len(re.findall('Starting executor', str(trainer._executor.get_log_by_lines()))) == 1 trainer.error_msg = 'some_error' - await assert_training_state(trainer.training, 'train_model_downloaded', timeout=6, interval=0.001) + await assert_training_state(trainer.training, TrainerState.TrainModelDownloaded, timeout=6, interval=0.001) trainer.error_msg = None - await assert_training_state(trainer.training, 'training_running', timeout=1, interval=0.001) + await assert_training_state(trainer.training, 
TrainerState.TrainingRunning, timeout=1, interval=0.001) await asyncio.sleep(1) assert len(re.findall('Starting executor', str(trainer._executor.get_log_by_lines()))) > 1 # Here only the current run is provided - assert len(re.findall('Starting executor', str(trainer._executor.get_log_by_lines(since_last_start=True)))) == 1 + # assert len(re.findall('Starting executor', str(trainer._executor.get_log_by_lines(since_last_start=True)))) == 1 diff --git a/learning_loop_node/trainer/tests/test_trainer_states.py b/learning_loop_node/trainer/tests/test_trainer_states.py index c6e449b7..c5f2d04e 100644 --- a/learning_loop_node/trainer/tests/test_trainer_states.py +++ b/learning_loop_node/trainer/tests/test_trainer_states.py @@ -1,10 +1,9 @@ from uuid import uuid4 -from learning_loop_node.data_classes import Context, Training, TrainingState +from learning_loop_node.data_classes import Context, TrainerState, Training from learning_loop_node.trainer.io_helpers import LastTrainingIO -from learning_loop_node.trainer.tests.testing_trainer_logic import \ - TestingTrainerLogic +from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic from learning_loop_node.trainer.trainer_node import TrainerNode @@ -27,8 +26,8 @@ def test_fixture_trainer_node(test_initialized_trainer_node): def test_save_load_training(): training = create_training() last_training_io = LastTrainingIO('00000000-0000-0000-0000-000000000000') - training.training_state = TrainingState.Preparing + training.training_state = TrainerState.Preparing last_training_io.save(training) training = last_training_io.load() - assert training.training_state == 'preparing' + assert training.training_state == TrainerState.Preparing diff --git a/learning_loop_node/trainer/tests/testing_trainer_logic.py b/learning_loop_node/trainer/tests/testing_trainer_logic.py index 08589657..50171e08 100644 --- a/learning_loop_node/trainer/tests/testing_trainer_logic.py +++ 
b/learning_loop_node/trainer/tests/testing_trainer_logic.py @@ -1,8 +1,8 @@ import asyncio import time -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional -from learning_loop_node.data_classes import BasicModel, Context, Detections, ModelInformation, PretrainedModel +from learning_loop_node.data_classes import Context, Detections, ModelInformation, PretrainedModel, TrainingStateData from learning_loop_node.trainer.trainer_logic import TrainerLogic @@ -11,7 +11,7 @@ class TestingTrainerLogic(TrainerLogic): def __init__(self, can_resume: bool = False) -> None: super().__init__('mocked') - self._can_resume: bool = can_resume + self._can_resume_flag: bool = can_resume self.has_new_model: bool = False self.error_msg: Optional[str] = None @@ -25,25 +25,25 @@ def model_architecture(self) -> str: @property def provided_pretrained_models(self) -> List[PretrainedModel]: - return [ - PretrainedModel(name='small', label='Small', description='a small model'), - PretrainedModel(name='medium', label='Medium', description='a medium model'), - PretrainedModel(name='large', label='Large', description='a large model')] + return [PretrainedModel(name='small', label='Small', description='a small model'), + PretrainedModel(name='medium', label='Medium', description='a medium model'), + PretrainedModel(name='large', label='Large', description='a large model')] # pylint: disable=unused-argument - async def start_training(self, model: str = 'model.model') -> None: + async def _start_training_from_base_model(self, model: str = 'model.model') -> None: assert self._executor is not None - self._executor.start('while true; do sleep 1; done') + await self._executor.start('/bin/bash -c "while true; do sleep 1; done"') - async def start_training_from_scratch(self, base_model_id: str) -> None: - await self.start_training(model=f'model_{base_model_id}.pt') + async def _start_training_from_scratch(self) -> None: + assert self.training.base_model_uuid_or_name is 
not None, 'base_model_uuid_or_name must be set' + await self._start_training_from_base_model(model=f'model_{self.training.base_model_uuid_or_name}.pt') - def get_new_model(self) -> Optional[BasicModel]: + def _get_new_best_training_state(self) -> Optional[TrainingStateData]: if self.has_new_model: - return BasicModel(confusion_matrix={}) + return TrainingStateData(confusion_matrix={}) return None - def on_model_published(self, basic_model: BasicModel) -> None: + def _on_metrics_published(self, training_state_data: TrainingStateData) -> None: pass async def _prepare(self) -> None: @@ -54,24 +54,19 @@ async def _download_model(self) -> None: await super()._download_model() await asyncio.sleep(0.1) # give tests a bit time to to check for the state - async def ensure_confusion_matrix_synced(self): + async def _upload_model(self) -> None: await asyncio.sleep(0.1) # give tests a bit time to to check for the state - await super().ensure_confusion_matrix_synced() + await super()._upload_model() await asyncio.sleep(0.1) # give tests a bit time to to check for the state - async def upload_model(self) -> None: + async def _upload_model_return_new_model_uuid(self, context: Context) -> Optional[str]: await asyncio.sleep(0.1) # give tests a bit time to to check for the state - await super().upload_model() - await asyncio.sleep(0.1) # give tests a bit time to to check for the state - - async def _upload_model_return_new_id(self, context: Context) -> Optional[str]: - await asyncio.sleep(0.1) # give tests a bit time to to check for the state - result = await super()._upload_model_return_new_id(context) + result = await super()._upload_model_return_new_model_uuid(context) await asyncio.sleep(0.1) # give tests a bit time to to check for the state assert isinstance(result, str) return result - def get_latest_model_files(self) -> Union[List[str], Dict[str, List[str]]]: + async def _get_latest_model_files(self) -> Dict[str, List[str]]: time.sleep(1) # NOTE reduce flakyness in Backend 
tests du to wrong order of events. fake_weight_file = '/tmp/weightfile.weights' with open(fake_weight_file, 'wb') as f: @@ -82,18 +77,18 @@ def get_latest_model_files(self) -> Union[List[str], Dict[str, List[str]]]: f.write('zweiundvierzig') return {'mocked': [fake_weight_file, more_data_file], 'mocked_2': [fake_weight_file, more_data_file]} - def can_resume(self) -> bool: - return self._can_resume + def _can_resume(self) -> bool: + return self._can_resume_flag - async def resume(self) -> None: - return await self.start_training() + async def _resume(self) -> None: + return await self._start_training_from_base_model() async def _detect(self, model_information: ModelInformation, images: List[str], model_folder: str) -> List[Detections]: detections: List[Detections] = [] return detections - async def clear_training_data(self, training_folder: str) -> None: + async def _clear_training_data(self, training_folder: str) -> None: return - def get_executor_error_from_log(self) -> Optional[str]: + def _get_executor_error_from_log(self) -> Optional[str]: return self.error_msg diff --git a/learning_loop_node/trainer/trainer_logic.py b/learning_loop_node/trainer/trainer_logic.py index 1b11b4e3..ea32b6dc 100644 --- a/learning_loop_node/trainer/trainer_logic.py +++ b/learning_loop_node/trainer/trainer_logic.py @@ -3,405 +3,91 @@ import logging import os import shutil -import time from abc import abstractmethod -from dataclasses import asdict from datetime import datetime -from glob import glob -from time import perf_counter -from typing import TYPE_CHECKING, Coroutine, Dict, List, Optional, Union -from uuid import UUID, uuid4 +from typing import Coroutine, List, Optional -import socketio from dacite import from_dict -from fastapi.encoders import jsonable_encoder -from tqdm import tqdm - -from ..data_classes import (BasicModel, Category, Context, Detections, Errors, Hyperparameter, ModelInformation, - PretrainedModel, Training, TrainingData, TrainingError, TrainingState) -from 
..helpers.misc import create_image_folder -from ..node import Node -from . import training_syncronizer -from .downloader import TrainingsDownloader -from .executor import Executor -from .io_helpers import ActiveTrainingIO - -if TYPE_CHECKING: - from .trainer_node import TrainerNode - -def is_valid_uuid4(val): - try: - _ = UUID(str(val)).version - return True - except ValueError: - return False +from ..data_classes import Detections, ModelInformation, TrainerState, TrainingError +from ..helpers.misc import create_image_folder, create_project_folder, images_for_ids, is_valid_uuid4 +from .executor import Executor +from .trainer_logic_generic import TrainerLogicGeneric -class TrainerLogic(): +class TrainerLogic(TrainerLogicGeneric): def __init__(self, model_format: str) -> None: - self.model_format: str = model_format + """This class is the base class for all trainers that use an executor to run training processes. + The executor is used to run the training process in a separate process.""" + + super().__init__(model_format) + self._detection_progress: Optional[float] = None self._executor: Optional[Executor] = None - self.start_time: Optional[float] = None - self.training_task: Optional[asyncio.Task] = None self.start_training_task: Optional[Coroutine] = None - self.errors = Errors() - self.shutdown_event: asyncio.Event = asyncio.Event() - self.detection_progress = 0.0 - - self._training: Optional[Training] = None - self._active_training_io: Optional[ActiveTrainingIO] = None - self._node: Optional[TrainerNode] = None - self.restart_after_training = os.environ.get('RESTART_AFTER_TRAINING', 'FALSE').lower() in ['true', '1'] - self.keep_old_trainings = os.environ.get('KEEP_OLD_TRAININGS', 'FALSE').lower() in ['true', '1'] - self.inference_batch_size = int(os.environ.get('INFERENCE_BATCH_SIZE', '10')) - logging.info(f'INFERENCE_BATCH_SIZE: {self.inference_batch_size}') + self.inference_batch_size = 10 - @property - def executor(self) -> Executor: - assert self._executor 
is not None, 'executor must be set, call `run_training` first' - return self._executor - - @property - def training(self) -> Training: - assert self._training is not None, 'training must be set, call `init` first' - return self._training + # ---------------------------------------- IMPLEMENTED ABSTRACT PROPERTIES ---------------------------------------- @property - def active_training_io(self) -> ActiveTrainingIO: - assert self._active_training_io is not None, 'active_training_io must be set, call `init` first' - return self._active_training_io + def detection_progress(self) -> Optional[float]: + return self._detection_progress - @property - def node(self) -> 'TrainerNode': - assert self._node is not None, 'node should be set by TrainerNodes before initialization' - return self._node + # ---------------------------------------- PROPERTIES ---------------------------------------- @property - def is_initialized(self) -> bool: - """_training and _active_training_io are set in 'init_new_training' or 'init_from_last_training'""" - return self._training is not None and self._active_training_io is not None and self._node is not None - - def init_new_training(self, context: Context, details: Dict) -> None: - """Called on `begin_training` event from the Learning Loop. - Note that details needs the entries 'categories' and 'training_number'""" + def executor(self) -> Executor: + assert self._executor is not None, 'executor must be set, call `run_training` first' + return self._executor - try: - project_folder = Node.create_project_folder(context) - if not self.keep_old_trainings: - # NOTE: We delete all existing training folders because they are not needed anymore. 
- TrainerLogic.delete_all_training_folders(project_folder) - self._training = TrainerLogic.generate_training(project_folder, context) - self._training.data = TrainingData(categories=Category.from_list(details['categories'])) - self._training.data.hyperparameter = from_dict(data_class=Hyperparameter, data=details) - self._training.training_number = details['training_number'] - self._training.base_model_id = details['id'] - self._training.training_state = TrainingState.Initialized - self._active_training_io = ActiveTrainingIO(self._training.training_folder) - logging.info(f'init training: {self._training}') - except Exception: - logging.exception('Error in init') - - def init_from_last_training(self) -> None: - self._training = self.node.last_training_io.load() - assert self._training is not None and self._training.training_folder is not None, 'could not restore training folder' - self._active_training_io = ActiveTrainingIO(self._training.training_folder) - - async def run(self) -> None: - """Called on `begin_training` event from the Learning Loop.""" - - self.start_time = time.time() - self.errors.reset_all() - try: - self.training_task = asyncio.get_running_loop().create_task(self._run()) - await self.training_task # Object is used to potentially cancel the task - except asyncio.CancelledError: - if not self.shutdown_event.is_set(): - logging.info('training task was cancelled but not by shutdown event') - self.training.training_state = TrainingState.ReadyForCleanup - self.node.last_training_io.save(self.training) - await self.clear_training() - - except Exception as e: - logging.exception(f'Error in train: {e}') - finally: - self.start_time = None - - # ---------------------------------------- TRAINING STATES ---------------------------------------- - - async def _run(self) -> None: - """asyncio.CancelledError is catched in train""" - - if not self.is_initialized: - logging.error('could not start training - trainer is not initialized') - return + # 
---------------------------------------- IMPLEMENTED ABSTRACT MEHTODS ---------------------------------------- - while self._training is not None: - tstate = self.training.training_state - logging.info(f'STATE LOOP: {tstate}') - await asyncio.sleep(0.6) # Note: Required for pytests! - if tstate == TrainingState.Initialized: # -> DataDownloading -> DataDownloaded - await self.prepare() - elif tstate == TrainingState.DataDownloaded: # -> TrainModelDownloading -> TrainModelDownloaded - await self.download_model() - elif tstate == TrainingState.TrainModelDownloaded: # -> TrainingRunning -> TrainingFinished - await self.train() - elif tstate == TrainingState.TrainingFinished: # -> ConfusionMatrixSyncing -> ConfusionMatrixSynced - await self.ensure_confusion_matrix_synced() - elif tstate == TrainingState.ConfusionMatrixSynced: # -> TrainModelUploading -> TrainModelUploaded - await self.upload_model() - elif tstate == TrainingState.TrainModelUploaded: # -> Detecting -> Detected - await self.do_detections() - elif tstate == TrainingState.Detected: # -> DetectionUploading -> ReadyForCleanup - await self.upload_detections() - elif tstate == TrainingState.ReadyForCleanup: # -> RESTART or TrainingFinished - await self.clear_training() - self.may_restart() - - async def prepare(self) -> None: - previous_state = self.training.training_state - self.training.training_state = TrainingState.DataDownloading - error_key = 'prepare' - try: - await self._prepare() - except asyncio.CancelledError: - logging.warning('CancelledError in prepare') - raise - except Exception as e: - logging.exception("Unknown error in 'prepare'. 
Exception:") - self.training.training_state = previous_state - self.errors.set(error_key, str(e)) - else: - self.errors.reset(error_key) - self.training.training_state = TrainingState.DataDownloaded - self.node.last_training_io.save(self.training) - - async def _prepare(self) -> None: - self.node.data_exchanger.set_context(self.training.context) - downloader = TrainingsDownloader(self.node.data_exchanger) - image_data, skipped_image_count = await downloader.download_training_data(self.training.images_folder) - assert self.training.data is not None, 'training.data must be set' - self.training.data.image_data = image_data - self.training.data.skipped_image_count = skipped_image_count - - async def download_model(self) -> None: - logging.info('Downloading model') - previous_state = self.training.training_state - self.training.training_state = TrainingState.TrainModelDownloading - error_key = 'download_model' - try: - await self._download_model() - except asyncio.CancelledError: - logging.warning('CancelledError in download_model') - raise - except Exception as e: - logging.exception('download_model failed') - self.training.training_state = previous_state - self.errors.set(error_key, str(e)) - else: - self.errors.reset(error_key) - logging.info('download_model_task finished') - self.training.training_state = TrainingState.TrainModelDownloaded - self.node.last_training_io.save(self.training) - - async def _download_model(self) -> None: - model_id = self.training.base_model_id - assert model_id is not None, 'model_id must be set' - if is_valid_uuid4( - self.training.base_model_id): # TODO this checks if we continue a training -> make more explicit - logging.info('loading model from Learning Loop') - logging.info(f'downloading model {model_id} as {self.model_format}') - await self.node.data_exchanger.download_model(self.training.training_folder, self.training.context, model_id, self.model_format) - shutil.move(f'{self.training.training_folder}/model.json', - 
f'{self.training.training_folder}/base_model.json') - else: - logging.info(f'base_model_id {model_id} is not a valid uuid4, skipping download') - - async def train(self) -> None: - logging.info('Running training') + async def _train(self) -> None: + previous_state = TrainerState.TrainModelDownloaded error_key = 'run_training' - # NOTE normally we reset errors after the step was successful. We do not want to display an old error during the whole training. - self.errors.reset(error_key) - previous_state = self.training.training_state self._executor = Executor(self.training.training_folder) - self.training.training_state = TrainingState.TrainingRunning + self.training.training_state = TrainerState.TrainingRunning + try: await self._start_training() - last_sync_time = datetime.now() + while True: - if not self.executor.is_process_running(): + await asyncio.sleep(0.1) + if not self.executor.is_running(): break if (datetime.now() - last_sync_time).total_seconds() > 5: last_sync_time = datetime.now() - if self.get_executor_error_from_log(): + if self._get_executor_error_from_log(): break self.errors.reset(error_key) try: - await self.sync_confusion_matrix() + await self._sync_confusion_matrix() except asyncio.CancelledError: logging.warning('CancelledError in run_training') raise except Exception: - pass - else: - await asyncio.sleep(0.1) + logging.error('Error in sync_confusion_matrix (this error is ignored)') - error = self.get_executor_error_from_log() - if error: - self.errors.set(error_key, error) + if error := self._get_executor_error_from_log(): raise TrainingError(cause=error) - # TODO check if this works: + + # NOTE: This is problematic, because the return code is not 0 when executor was stoppen e.g. 
via self.stop() # if self.executor.return_code != 0: - # self.errors.set(error_key, f'Executor return code was {self.executor.return_code}') - # raise TrainingError(cause=f'Executor return code was {self.executor.return_code}') + # raise TrainingError(cause=f'Executor returned with error code: {self.executor.return_code}') - except asyncio.CancelledError: - logging.warning('CancelledError in run_training') - raise except TrainingError: - logging.exception('Error in TrainingProcess') - if self.executor.is_process_running(): - self.executor.stop() - self.training.training_state = previous_state - except Exception as e: - self.errors.set(error_key, f'Could not start training {str(e)}') + logging.exception('Exception in trainer_logic._train') + await self.executor.stop_and_wait() self.training.training_state = previous_state - logging.exception('Error in run_training') - else: - self.training.training_state = TrainingState.TrainingFinished - self.node.last_training_io.save(self.training) - - async def _start_training(self): - self.start_training_task = None # NOTE: this is used i.e. by tests - if self.can_resume(): - self.start_training_task = self.resume() - else: - base_model_id = self.training.base_model_id - if not is_valid_uuid4(base_model_id): # TODO this check was done earlier! 
- assert isinstance(base_model_id, str) - # TODO this could be removed here and accessed via self.training.base_model_id - self.start_training_task = self.start_training_from_scratch(base_model_id) - else: - self.start_training_task = self.start_training() - await self.start_training_task - - async def ensure_confusion_matrix_synced(self): - logging.info('Ensure syncing confusion matrix') - previous_state = self.training.training_state - self.training.training_state = TrainingState.ConfusionMatrixSyncing - try: - await self.sync_confusion_matrix() - except asyncio.CancelledError: - logging.warning('CancelledError in run_training') - raise - except Exception: - logging.exception('Error in ensure_confusion_matrix_synced') - self.training.training_state = previous_state - else: - self.training.training_state = TrainingState.ConfusionMatrixSynced - self.node.last_training_io.save(self.training) - - async def sync_confusion_matrix(self): - logging.info('Syncing confusion matrix') - error_key = 'sync_confusion_matrix' - try: - await training_syncronizer.try_sync_model(self, self.node.uuid, self.node.sio_client) - except socketio.exceptions.BadNamespaceError as e: # type: ignore - logging.error('Error during confusion matrix syncronization. BadNamespaceError') - self.errors.set(error_key, str(e)) - raise - except Exception as e: - logging.exception('Error during confusion matrix syncronization') - self.errors.set(error_key, str(e)) - raise - - self.errors.reset(error_key) - - async def upload_model(self) -> None: - error_key = 'upload_model' - previous_state = self.training.training_state - self.training.training_state = TrainingState.TrainModelUploading - try: - new_model_id = await self._upload_model_return_new_id(self.training.context) - if new_model_id is None: - self.training.training_state = TrainingState.ReadyForCleanup - logging.error('could not upload model - maybe training failed.. 
cleaning up') - return - assert new_model_id is not None, 'uploaded_model must be set' - logging.info(f'successfully uploaded model and received new model id: {new_model_id}') - self.training.model_id_for_detecting = new_model_id - except asyncio.CancelledError: - logging.warning('CancelledError in upload_model') raise - except Exception as e: - logging.exception('Error in upload_model. Exception:') - self.errors.set(error_key, str(e)) - self.training.training_state = previous_state # TODO... going back is pointless here as it ends in a deadlock ?! - # self.training.training_state = TrainingState.ReadyForCleanup - else: - self.errors.reset(error_key) - self.training.training_state = TrainingState.TrainModelUploaded - self.node.last_training_io.save(self.training) - - async def _upload_model_return_new_id(self, context: Context) -> Optional[str]: - """Upload model files, usually pytorch model (.pt) hyp.yaml and the converted .wts file. - Note that with the latest trainers the conversion to (.wts) is done by the trainer. - The conversion from .wts to .engine is done by the detector (needs to be done on target hardware). - Note that trainer may train with different classes, which is why we send an initial model.json file. - """ - files = await asyncio.get_running_loop().run_in_executor(None, self.get_latest_model_files) - - if files is None: - return None - - if isinstance(files, List): - files = {self.model_format: files} - assert isinstance(files, Dict), f'can only save model as list or dict, but was {files}' - - model_json_path = self.create_model_json_with_categories() - already_uploaded_formats = self.active_training_io.load_model_upload_progress() - - new_id = None - for file_format in files: - if file_format in already_uploaded_formats: - continue - _files = files[file_format] - # model.json was mandatory in previous versions. Now its forbidden to provide an own model.json file. 
- assert not any(f for f in _files if 'model.json' in f), "Upload 'model.json' not allowed (added automatically)." - _files.append(model_json_path) - new_id = await self.node.data_exchanger.upload_model_for_training(context, _files, self.training.training_number, file_format) - if new_id is None: - return None - - already_uploaded_formats.append(file_format) - self.active_training_io.save_model_upload_progress(already_uploaded_formats) - - return new_id - - async def do_detections(self): - error_key = 'detecting' - previous_state = self.training.training_state - try: - self.training.training_state = TrainingState.Detecting - await self._do_detections() - except asyncio.CancelledError: - logging.warning('CancelledError in do_detections') - raise - except Exception as e: - self.errors.set(error_key, str(e)) - logging.exception('Error in do_detections - Exception:') - self.training.training_state = previous_state - else: - self.errors.reset(error_key) - self.training.training_state = TrainingState.Detected - self.node.last_training_io.save(self.training) async def _do_detections(self) -> None: context = self.training.context - model_id = self.training.model_id_for_detecting - assert model_id, 'model_id must be set' + model_id = self.training.model_uuid_for_detecting + if not model_id: + logging.error('model_id is not set! 
Cannot do detections.') + return tmp_folder = f'/tmp/model_for_auto_detections_{model_id}_{self.model_format}' shutil.rmtree(tmp_folder, ignore_errors=True) @@ -410,111 +96,57 @@ async def _do_detections(self) -> None: await self.node.data_exchanger.download_model(tmp_folder, context, model_id, self.model_format) with open(f'{tmp_folder}/model.json', 'r') as f: - content = json.load(f) - model_information = from_dict(data_class=ModelInformation, data=content) + model_information = from_dict(data_class=ModelInformation, data=json.load(f)) - project_folder = Node.create_project_folder(context) + project_folder = create_project_folder(context) image_folder = create_image_folder(project_folder) self.node.data_exchanger.set_context(context) image_ids = [] for state, p in zip(['inbox', 'annotate', 'review', 'complete'], [0.1, 0.2, 0.3, 0.4]): - self.detection_progress = p + self._detection_progress = p logging.info(f'fetching image ids of {state}') - new_ids = await self.node.data_exchanger.fetch_image_ids(query_params=f'state={state}') + new_ids = await self.node.data_exchanger.fetch_image_uuids(query_params=f'state={state}') image_ids += new_ids logging.info(f'downloading {len(new_ids)} images') await self.node.data_exchanger.download_images(new_ids, image_folder) - self.detection_progress = 0.42 - await self.node.data_exchanger.delete_corrupt_images(image_folder) + self._detection_progress = 0.42 + # await delete_corrupt_images(image_folder) - images = await asyncio.get_event_loop().run_in_executor(None, TrainerLogic.images_for_ids, image_ids, image_folder) - num_images = len(images) - logging.info(f'running detections on {num_images} images') - batch_size = 200 - idx = 0 + images = await asyncio.get_event_loop().run_in_executor(None, images_for_ids, image_ids, image_folder) if not images: - self.active_training_io.save_detections([], idx) - for i in tqdm(range(0, num_images, batch_size), position=0, leave=True): - self.detection_progress = 0.5 + (i/num_images)*0.5 - 
batch_images = images[i:i+batch_size] + self.active_training_io.save_detections([], 0) + num_images = len(images) + + for idx, i in enumerate(range(0, num_images, self.inference_batch_size)): + self._detection_progress = 0.5 + (i/num_images)*0.5 + batch_images = images[i:i+self.inference_batch_size] batch_detections = await self._detect(model_information, batch_images, tmp_folder) self.active_training_io.save_detections(batch_detections, idx) - idx += 1 - return None + # ---------------------------------------- METHODS ---------------------------------------- - async def upload_detections(self): - error_key = 'upload_detections' - previous_state = self.training.training_state - self.training.training_state = TrainingState.DetectionUploading - await asyncio.sleep(0.1) # NOTE needed for tests - try: - json_files = self.active_training_io.get_detection_file_names() - if not json_files: - raise Exception() - current_json_file_index = self.active_training_io.load_detections_upload_file_index() - for i in range(current_json_file_index, len(json_files)): - detections = self.active_training_io.load_detections(i) - logging.info(f'uploading detections {i}/{len(json_files)}') - await self._upload_detections_batched(self.training.context, detections) - self.active_training_io.save_detections_upload_file_index(i+1) - except asyncio.CancelledError: - logging.warning('CancelledError in upload_detections') - raise - except Exception as e: - self.errors.set(error_key, str(e)) - logging.exception('Error in upload_detections') - self.training.training_state = previous_state - else: - self.errors.reset(error_key) - self.training.training_state = TrainingState.ReadyForCleanup - self.node.last_training_io.save(self.training) - - async def _upload_detections_batched(self, context: Context, detections: List[Detections]): - batch_size = 10 - skip_detections = self.active_training_io.load_detection_upload_progress() - for i in tqdm(range(skip_detections, len(detections), batch_size), 
position=0, leave=True): - up_progress = i+batch_size - batch_detections = detections[i:up_progress] - dict_detections = [jsonable_encoder(asdict(detection)) for detection in batch_detections] - logging.info(f'uploading detections. File size : {len(json.dumps(dict_detections))}') - await self._upload_detections(context, batch_detections, up_progress) - skip_detections = up_progress - - async def _upload_detections(self, context: Context, batch_detections: List[Detections], up_progress: int): - assert self._active_training_io is not None, 'active_training must be set' - - detections_json = [jsonable_encoder(asdict(detections)) for detections in batch_detections] - response = await self.node.loop_communicator.post( - f'/{context.organization}/projects/{context.project}/detections', json=detections_json) - if response.status_code != 200: - msg = f'could not upload detections. {str(response)}' - logging.error(msg) - raise Exception(msg) + async def _start_training(self): + self.start_training_task = None # NOTE: this is used i.e. 
by tests + if self._can_resume(): + self.start_training_task = self._resume() else: - logging.info('successfully uploaded detections') - if up_progress > len(batch_detections): - self._active_training_io.save_detection_upload_progress(0) + base_model_uuid_or_name = self.training.base_model_uuid_or_name + if not is_valid_uuid4(base_model_uuid_or_name): + self.start_training_task = self._start_training_from_scratch() else: - self._active_training_io.save_detection_upload_progress(up_progress) - - async def clear_training(self): - self.active_training_io.delete_detections() - self.active_training_io.delete_detection_upload_progress() - self.active_training_io.delete_detections_upload_file_index() - await self.clear_training_data(self.training.training_folder) - self.node.last_training_io.delete() - # self.training.training_state = TrainingState.TrainingFinished - assert self._node is not None - await self._node.send_status() # make sure the status is updated before we stop the training - self._training = None + self.start_training_task = self._start_training_from_base_model() + await self.start_training_task + + # ---------------------------------------- OVERWRITTEN METHODS ---------------------------------------- async def stop(self) -> None: """If executor is running, stop it. 
Else cancel training task.""" - if not self.is_initialized: + print('===============> stop received in trainer_logic.', flush=True) + + if not self.training_active: return - if self._executor and self._executor.is_process_running(): - self.executor.stop() + if self._executor and self._executor.is_running(): + await self.executor.stop_and_wait() elif self.training_task: logging.info('cancelling training task') if self.training_task.cancel(): @@ -523,175 +155,33 @@ async def stop(self) -> None: except asyncio.CancelledError: pass logging.info('cancelled training task') - self.may_restart() - - async def shutdown(self) -> None: - self.shutdown_event.set() - await self.stop() - await self.stop() # NOTE first stop may only stop training. - - def get_log(self) -> str: - return self.executor.get_log() + self._may_restart() - def may_restart(self) -> None: - if self.restart_after_training: - logging.info('restarting') - assert self._node is not None - self._node.restart() - else: - logging.info('not restarting') - - @property - def general_progress(self) -> Optional[float]: - """Represents the progress for different states.""" - if not self.is_initialized: - return None - - t_state = self.training.training_state - if t_state == TrainingState.DataDownloading: - return self.node.data_exchanger.progress - if t_state == TrainingState.TrainingRunning: - return self.training_progress - if t_state == TrainingState.Detecting: - return self.detection_progress - - return None # ---------------------------------------- ABSTRACT METHODS ---------------------------------------- - @property - @abstractmethod - def training_progress(self) -> Optional[float]: - """Represents the training progress.""" - raise NotImplementedError - - @property - @abstractmethod - def provided_pretrained_models(self) -> List[PretrainedModel]: - raise NotImplementedError - - @property @abstractmethod - def model_architecture(self) -> Optional[str]: - raise NotImplementedError + async def 
_start_training_from_base_model(self) -> None: + '''Should be used to start a training on executer, e.g. self.executor.start(cmd).''' @abstractmethod - async def start_training(self) -> None: - '''Should be used to start a training.''' + async def _start_training_from_scratch(self) -> None: + '''Should be used to start a training from scratch on executer, e.g. self.executor.start(cmd). + NOTE base_model_id is now accessible via self.training.base_model_id + the id of a pretrained model provided by self.provided_pretrained_models.''' @abstractmethod - async def start_training_from_scratch(self, base_model_id: str) -> None: - '''Should be used to start a training from scratch. - base_model_id is the id of a pretrained model provided by self.provided_pretrained_models.''' - - @abstractmethod - def can_resume(self) -> bool: + def _can_resume(self) -> bool: '''Override this method to return True if the trainer can resume training.''' @abstractmethod - async def resume(self) -> None: + async def _resume(self) -> None: '''Is called when self.can_resume() returns True. One may resume the training on a previously trained model stored by self.on_model_published(basic_model).''' @abstractmethod - def get_executor_error_from_log(self) -> Optional[str]: # TODO we should allow other options to get the error + def _get_executor_error_from_log(self) -> Optional[str]: '''Should be used to provide error informations to the Learning Loop by extracting data from self.executor.get_log().''' - @abstractmethod - def get_new_model(self) -> Optional[BasicModel]: - '''Is called frequently in `try_sync_model` to check if a new "best" model is availabe. - Returns None if no new model could be found. Otherwise BasicModel(confusion_matrix, meta_information). - `confusion_matrix` contains a dict of all classes: - - The classes must be identified by their id, not their name. - - For each class a dict with tp, fp, fn is provided (true positives, false positives, false negatives). 
- `meta_information` can hold any data which is helpful for self.on_model_published to store weight file etc for later upload via self.get_model_files - ''' - - @abstractmethod - def on_model_published(self, basic_model: BasicModel) -> None: - '''Called after a BasicModel has been successfully send to the Learning Loop. - The files for this model should be stored. - self.get_latest_model_files is used to gather all files needed for transfering the actual data from the trainer node to the Learning Loop. - In the simplest implementation this method just renames the weight file (encoded in BasicModel.meta_information) into a file name like latest_published_model - ''' - - @abstractmethod - def get_latest_model_files(self) -> Optional[Union[List[str], Dict[str, List[str]]]]: - '''Called when the Learning Loop requests to backup the latest model for the training. - Should return a list of file paths which describe the model. - These files must contain all data neccessary for the trainer to resume a training (eg. weight file, hyperparameters, etc.) - and will be stored in the Learning Loop unter the format of this trainer. - Note: by convention the weightfile should be named "model." where extension is the file format of the weightfile. - For example "model.pt" for pytorch or "model.weights" for darknet/yolo. - - If a trainer can also generate other formats (for example for an detector), - a dictionary mapping format -> list of files can be returned.''' - @abstractmethod async def _detect(self, model_information: ModelInformation, images: List[str], model_folder: str) -> List[Detections]: '''Called to run detections on a list of images.''' - - @abstractmethod - async def clear_training_data(self, training_folder: str) -> None: - '''Called after a training has finished. Deletes all data that is not needed anymore after a training run. 
- This can be old weightfiles or any additional files.''' - - # ---------------------------------------- HELPER METHODS ---------------------------------------- - - @staticmethod - def images_for_ids(image_ids, image_folder) -> List[str]: - logging.info(f'### Going to get images for {len(image_ids)} images ids') - start = perf_counter() - images = [img for img in glob(f'{image_folder}/**/*.*', recursive=True) - if os.path.splitext(os.path.basename(img))[0] in image_ids] - end = perf_counter() - logging.info(f'found {len(images)} images for {len(image_ids)} image ids, which took {end-start:0.2f} seconds') - return images - - @staticmethod - def generate_training(project_folder: str, context: Context) -> Training: - training_uuid = str(uuid4()) - return Training( - id=training_uuid, - context=context, - project_folder=project_folder, - images_folder=create_image_folder(project_folder), - training_folder=TrainerLogic.create_training_folder(project_folder, training_uuid) - ) - - @staticmethod - def delete_all_training_folders(project_folder: str): - if not os.path.exists(f'{project_folder}/trainings'): - return - for uuid in os.listdir(f'{project_folder}/trainings'): - shutil.rmtree(f'{project_folder}/trainings/{uuid}', ignore_errors=True) - - @staticmethod - def create_training_folder(project_folder: str, trainings_id: str) -> str: - training_folder = f'{project_folder}/trainings/{trainings_id}' - os.makedirs(training_folder, exist_ok=True) - return training_folder - - @property - def hyperparameters(self) -> Optional[Dict]: - if self._training and self._training.data and self._training.data.hyperparameter: - information = {} - information['resolution'] = self._training.data.hyperparameter.resolution - information['flipRl'] = self._training.data.hyperparameter.flip_rl - information['flipUd'] = self._training.data.hyperparameter.flip_ud - return information - return None - - def create_model_json_with_categories(self) -> str: - """Remaining fields are filled by the 
Learning Loop""" - if self._training and self._training.data: - content = { - 'categories': [asdict(c) for c in self._training.data.categories], - } - else: - content = None - - model_json_path = '/tmp/model.json' - with open(model_json_path, 'w') as f: - json.dump(content, f) - - return model_json_path diff --git a/learning_loop_node/trainer/trainer_logic_generic.py b/learning_loop_node/trainer/trainer_logic_generic.py new file mode 100644 index 00000000..f790bbd9 --- /dev/null +++ b/learning_loop_node/trainer/trainer_logic_generic.py @@ -0,0 +1,495 @@ +import asyncio +import json +import logging +import shutil +import sys +import time +from abc import ABC, abstractmethod +from dataclasses import asdict +from typing import TYPE_CHECKING, Callable, Coroutine, Dict, List, Optional + +from fastapi.encoders import jsonable_encoder + +from ..data_classes import (Context, Errors, Hyperparameter, PretrainedModel, TrainerState, Training, TrainingData, + TrainingOut, TrainingStateData) +from ..helpers.misc import create_project_folder, delete_all_training_folders, generate_training, is_valid_uuid4 +from .downloader import TrainingsDownloader +from .io_helpers import ActiveTrainingIO, EnvironmentVars, LastTrainingIO + +if TYPE_CHECKING: + from .trainer_node import TrainerNode + + +class TrainerLogicGeneric(ABC): + + def __init__(self, model_format: str) -> None: + + # NOTE: model_format is used in the file path for the model on the server: + # It acts as a key for list of files (cf. 
_get_latest_model_files) + # '/{context.organization}/projects/{context.project}/models/{model_id}/{model_format}/file' + self.model_format: str = model_format + self.errors = Errors() + + self.training_task: Optional[asyncio.Task] = None + self.shutdown_event: asyncio.Event = asyncio.Event() + + self._node: Optional['TrainerNode'] = None # type: ignore + self._last_training_io: Optional[LastTrainingIO] = None # type: ignore + + self._training: Optional[Training] = None + self._active_training_io: Optional[ActiveTrainingIO] = None + self._environment_vars = EnvironmentVars() + + # ---------------------------------------- PROPERTIES TO AVOID CHECKING FOR NONE ---------------------------------------- + + @property + def node(self) -> 'TrainerNode': + assert self._node is not None, 'node should be set by TrainerNode before initialization' + return self._node + + @property + def last_training_io(self) -> LastTrainingIO: + assert self._last_training_io is not None, 'last_training_io should be set by TrainerNode before initialization' + return self._last_training_io + + @property + def active_training_io(self) -> ActiveTrainingIO: + assert self._active_training_io is not None, 'active_training_io must be set, call `init` first' + return self._active_training_io + + @property + def training(self) -> Training: + assert self._training is not None, 'training must be initialized, call `init` first' + return self._training + + @property + def hyperparameter(self) -> Hyperparameter: + assert self.training_data is not None, 'Training should have data' + assert self.training_data.hyperparameter is not None, 'Training.data should have hyperparameter' + return self.training_data.hyperparameter + + # ---------------------------------------- PROPERTIES ---------------------------------------- + + @property + def training_data(self) -> Optional[TrainingData]: + if self.training_active and self.training.data: + return self.training.data + return None + + @property + def 
training_context(self) -> Optional[Context]: + if self.training_active: + return self.training.context + return None + + @property + def training_active(self) -> bool: + """_training and _active_training_io are set in '_init_new_training' or '_init_from_last_training'. + """ + return self._training is not None and self._active_training_io is not None + + @property + def state(self) -> str: + """Returns the current state of the training. Used solely by the node in send_status(). + """ + if (not self.training_active) or (self.training.training_state is None): + return TrainerState.Idle.value + return self.training.training_state + + @property + def training_uptime(self) -> Optional[float]: + """Lifetime of current Training object. Start time is set during initialization of Training object. + """ + if self.training_active: + return time.time() - self.training.start_time + return None + + @property + def hyperparameters_for_state_sync(self) -> Optional[Dict]: + """Used in sync_confusion_matrix and send_status to provide information about the training configuration. + """ + if self._training and self._training.data and self._training.data.hyperparameter: + information = {} + information['resolution'] = self._training.data.hyperparameter.resolution + information['flipRl'] = self._training.data.hyperparameter.flip_rl + information['flipUd'] = self._training.data.hyperparameter.flip_ud + return information + return None + + @property + def general_progress(self) -> Optional[float]: + """Represents the progress for different states, should run from 0 to 100 for each state. + Note that training_progress and detection_progress need to be implemented in the specific trainer.
+ """ + if not self.training_active: + return None + + t_state = self.training.training_state + if t_state == TrainerState.DataDownloading: + return self.node.data_exchanger.progress + if t_state == TrainerState.TrainingRunning: + return self.training_progress + if t_state == TrainerState.Detecting: + return self.detection_progress + + return None + + # ---------------------------------------- ABSTRACT PROPERTIES ---------------------------------------- + + @property + @abstractmethod + def training_progress(self) -> Optional[float]: + """Represents the training progress.""" + raise NotImplementedError + + @property + @abstractmethod + def detection_progress(self) -> Optional[float]: + """Represents the detection progress.""" + raise NotImplementedError + + @property + @abstractmethod + def model_architecture(self) -> Optional[str]: + """Returns the architecture name of the model if available""" + raise NotImplementedError + + @property + @abstractmethod + def provided_pretrained_models(self) -> List[PretrainedModel]: + """Returns the list of provided pretrained models. + The names of the models will come back as model_uuid_or_name in the training details. + """ + raise NotImplementedError + + # ---------------------------------------- METHODS ---------------------------------------- + + # NOTE: Trainings are started by the Learning Loop via the begin_training event + # or by the trainer itself via try_continue_run_if_incomplete. + # The trainer will then initialize a new training object and start the training loop. + # Initializing a new training object will create the folder structure for the training. + # The training loop will then run through the states of the training. + + async def try_continue_run_if_incomplete(self) -> bool: + """Tries to continue a training if the last training was not finished. 
+ """ + if not self.training_active and self.last_training_io.exists(): + self._init_from_last_training() + logging.info('found incomplete training, continuing now.') + asyncio.get_event_loop().create_task(self._run()) + return True + return False + + def _init_from_last_training(self) -> None: + """Initializes a new training object from the last training saved on disc via last_training_io. + """ + self._training = self.last_training_io.load() + assert self._training is not None and self._training.training_folder is not None, 'could not restore training folder' + self._active_training_io = ActiveTrainingIO( + self._training.training_folder, self.node.loop_communicator, self._training.context) + + async def begin_training(self, organization: str, project: str, details: Dict) -> None: + """Called on `begin_training` event from the Learning Loop. + """ + self._init_new_training(Context(organization=organization, project=project), details) + asyncio.get_event_loop().create_task(self._run()) + + def _init_new_training(self, context: Context, details: Dict) -> None: + """Called on `begin_training` event from the Learning Loop. + Note that details needs the entries 'categories' and 'training_number', + but also the hyperparameter entries. + """ + project_folder = create_project_folder(context) + if not self._environment_vars.keep_old_trainings: + delete_all_training_folders(project_folder) + self._training = generate_training(project_folder, context) + self._training.set_values_from_data(details) + + self._active_training_io = ActiveTrainingIO( + self._training.training_folder, self.node.loop_communicator, context) + logging.info(f'new training initialized: {self._training}') + + async def _run(self) -> None: + """Called on `begin_training` event from the Learning Loop. + Either via `begin_training` or `try_continue_run_if_incomplete`. 
+ """ + self.errors.reset_all() + try: + self.training_task = asyncio.get_running_loop().create_task(self._training_loop()) + await self.training_task # NOTE: Task object is used to potentially cancel the task + except asyncio.CancelledError: + if not self.shutdown_event.is_set(): + logging.info('training task was cancelled but not by shutdown event') + self.training.training_state = TrainerState.ReadyForCleanup + self.last_training_io.save(self.training) + await self._clear_training() + except Exception as e: + logging.exception(f'Error in train: {e}') + + # ---------------------------------------- TRAINING STATES ---------------------------------------- + + async def _training_loop(self) -> None: + """Cycle through the training states until the training is finished or + an asyncio.CancelledError is raised. + """ + assert self.training_active + + while self._training is not None: + tstate = self.training.training_state + await asyncio.sleep(0.6) # Note: Required for pytests! + + if tstate == TrainerState.Initialized: # -> DataDownloading -> DataDownloaded + await self._perform_state('prepare', TrainerState.DataDownloading, TrainerState.DataDownloaded, self._prepare) + elif tstate == TrainerState.DataDownloaded: # -> TrainModelDownloading -> TrainModelDownloaded + await self._perform_state('download_model', TrainerState.TrainModelDownloading, TrainerState.TrainModelDownloaded, self._download_model) + elif tstate == TrainerState.TrainModelDownloaded: # -> TrainingRunning -> TrainingFinished + await self._perform_state('run_training', TrainerState.TrainingRunning, TrainerState.TrainingFinished, self._train) + elif tstate == TrainerState.TrainingFinished: # -> ConfusionMatrixSyncing -> ConfusionMatrixSynced + await self._perform_state('sync_confusion_matrix', TrainerState.ConfusionMatrixSyncing, TrainerState.ConfusionMatrixSynced, self._sync_confusion_matrix) + elif tstate == TrainerState.ConfusionMatrixSynced: # -> TrainModelUploading -> TrainModelUploaded + await 
self._perform_state('upload_model', TrainerState.TrainModelUploading, TrainerState.TrainModelUploaded, self._upload_model) + elif tstate == TrainerState.TrainModelUploaded: # -> Detecting -> Detected + await self._perform_state('detecting', TrainerState.Detecting, TrainerState.Detected, self._do_detections) + elif tstate == TrainerState.Detected: # -> DetectionUploading -> ReadyForCleanup + await self._perform_state('upload_detections', TrainerState.DetectionUploading, TrainerState.ReadyForCleanup, self.active_training_io.upload_detetions) + elif tstate == TrainerState.ReadyForCleanup: # -> RESTART or TrainingFinished + await self._clear_training() + self._may_restart() + + async def _perform_state(self, error_key: str, state_during: TrainerState, state_after: TrainerState, action: Callable[[], Coroutine], reset_early=False): + await asyncio.sleep(0.1) + logging.info(f'Performing state: {state_during}') + previous_state = self.training.training_state + self.training.training_state = state_during + await asyncio.sleep(0.1) + if reset_early: + self.errors.reset(error_key) + + try: + if await action(): + logging.error('Something went really bad.. cleaning up') + state_after = TrainerState.ReadyForCleanup + except asyncio.CancelledError: + logging.warning(f'CancelledError in {state_during}') + raise + except Exception as e: + self.errors.set(error_key, str(e)) + logging.exception(f'Error in {state_during} - Exception:') + self.training.training_state = previous_state + else: + if not reset_early: + self.errors.reset(error_key) + self.training.training_state = state_after + self.last_training_io.save(self.training) + + async def _prepare(self) -> None: + """Downloads images to the images_folder and saves annotations to training.data.image_data. 
+ """ + self.node.data_exchanger.set_context(self.training.context) + downloader = TrainingsDownloader(self.node.data_exchanger) + image_data, skipped_image_count = await downloader.download_training_data(self.training.images_folder) + assert self.training.data is not None, 'training.data must be set' + self.training.data.image_data = image_data + self.training.data.skipped_image_count = skipped_image_count + + async def _download_model(self) -> None: + """If training is continued, the model is downloaded from the Learning Loop to the training_folder. + The downloaded model.json file is renamed to base_model.json because a new model.json will be created during training. + """ + base_model_uuid = self.training.base_model_uuid_or_name + + # TODO this checks if we continue a training -> make more explicit + if not base_model_uuid or not is_valid_uuid4(base_model_uuid): + logging.info(f'skipping model download. No base model provided (in form of uuid): {base_model_uuid}') + return + + logging.info('loading model from Learning Loop') + logging.info(f'downloading model {base_model_uuid} as {self.model_format}') + await self.node.data_exchanger.download_model(self.training.training_folder, self.training.context, base_model_uuid, self.model_format) + shutil.move(f'{self.training.training_folder}/model.json', + f'{self.training.training_folder}/base_model.json') + + async def _sync_confusion_matrix(self) -> None: + """Synchronizes the confusion matrix with the Learning Loop via the update_training endpoint. + NOTE: This stage sets the errors explicitly because it may be used inside the training stage.
+ """ + error_key = 'sync_confusion_matrix' + try: + new_best_model = self._get_new_best_training_state() + if new_best_model and self.training.data: + new_training = TrainingOut(trainer_id=self.node.uuid, + confusion_matrix=new_best_model.confusion_matrix, + train_image_count=self.training.data.train_image_count(), + test_image_count=self.training.data.test_image_count(), + hyperparameters=self.hyperparameters_for_state_sync) + await asyncio.sleep(0.1) # NOTE needed for tests. + + result = await self.node.sio_client.call('update_training', ( + self.training.context.organization, self.training.context.project, jsonable_encoder(new_training))) + if isinstance(result, dict) and result['success']: + logging.info(f'successfully updated training {asdict(new_training)}') + self._on_metrics_published(new_best_model) + else: + raise Exception(f'Error for update_training: Response from loop was : {result}') + except Exception as e: + logging.exception('Error during confusion matrix syncronization') + self.errors.set(error_key, str(e)) + raise + self.errors.reset(error_key) + + async def _upload_model(self) -> None: + """Uploads the latest model to the Learning Loop. + """ + new_model_uuid = await self._upload_model_return_new_model_uuid(self.training.context) + if new_model_uuid is None: + self.training.training_state = TrainerState.ReadyForCleanup + logging.error('could not upload model - maybe training failed.. cleaning up') + logging.info(f'Successfully uploaded model and received new model id: {new_model_uuid}') + self.training.model_uuid_for_detecting = new_model_uuid + + async def _upload_model_return_new_model_uuid(self, context: Context) -> Optional[str]: + """Upload model files, usually pytorch model (.pt) hyp.yaml and the converted .wts file. + Note that with the latest trainers the conversion to (.wts) is done by the trainer. + The conversion from .wts to .engine is done by the detector (needs to be done on target hardware). 
+ Note that trainer may train with different classes, which is why we send an initial model.json file.""" + + files = await self._get_latest_model_files() + if files is None: + return None + + if isinstance(files, List): + files = {self.model_format: files} + assert isinstance(files, Dict), f'can only upload model as list or dict, but was {files}' + + already_uploaded_formats = self.active_training_io.load_model_upload_progress() + + model_uuid = None + for file_format in [f for f in files if f not in already_uploaded_formats]: + _files = files[file_format] + [self._dump_categories_to_json()] + assert len([f for f in _files if 'model.json' in f]) == 1, "model.json must be included exactly once" + + model_uuid = await self.node.data_exchanger.upload_model_get_uuid(context, _files, self.training.training_number, file_format) + if model_uuid is None: + return None + + already_uploaded_formats.append(file_format) + self.active_training_io.save_model_upload_progress(already_uploaded_formats) + + return model_uuid + + def _dump_categories_to_json(self) -> str: + """Dumps the categories to a json file and returns the path to the file. + """ + content = {'categories': [asdict(c) for c in self.training_data.categories], } if self.training_data else None + json_path = '/tmp/model.json' + with open(json_path, 'w') as f: + json.dump(content, f) + return json_path + + async def _clear_training(self): + """Clears the training data after a training has finished. 
+ """ + self.active_training_io.delete_detections() + self.active_training_io.delete_detection_upload_progress() + self.active_training_io.delete_detections_upload_file_index() + await self._clear_training_data(self.training.training_folder) + self.last_training_io.delete() + + await self.node.send_status() + self._training = None + + # ---------------------------------------- OTHER METHODS ---------------------------------------- + + async def on_shutdown(self) -> None: + self.shutdown_event.set() + await self.stop() + await self.stop() + + async def stop(self): + """Stops the training process by canceling the training task. + """ + if not self.training_active: + return + if self.training_task: + logging.info('cancelling training task') + if self.training_task.cancel(): + try: + await self.training_task + except asyncio.CancelledError: + pass + logging.info('cancelled training task') + self._may_restart() + + def _may_restart(self) -> None: + """If the environment variable RESTART_AFTER_TRAINING is set, the trainer will restart after a training. + """ + if self._environment_vars.restart_after_training: + logging.info('restarting') + sys.exit(0) + else: + logging.info('not restarting') + # ---------------------------------------- ABSTRACT METHODS ---------------------------------------- + + @abstractmethod + async def _train(self) -> None: + """Should be used to execute a training. + At this point, images are already downloaded to the images_folder and annotations are saved in training.data.image_data. + If a training is continued, the model is already downloaded. + The model should be synchronized with the Learning Loop via self._sync_confusion_matrix() every now and then. + asyncio.CancelledError should be caught and re-raised. + """ + raise NotImplementedError + + @abstractmethod + async def _do_detections(self) -> None: + """Should be used to infer detections of all images and save them to drive. + active_training_io.save_detections(...)
should be used to store the detections. + asyncio.CancelledError should be caught and re-raised. + """ + raise NotImplementedError + + @abstractmethod + def _get_new_best_training_state(self) -> Optional[TrainingStateData]: + """Is called frequently by `_sync_confusion_matrix` to check if a new "best" model is available. + Returns None if no new model could be found. Otherwise TrainingStateData(confusion_matrix, meta_information). + `confusion_matrix` contains a dict of all classes: + - The classes must be identified by their uuid, not their name. + - For each class a dict with tp, fp, fn is provided (true positives, false positives, false negatives). + `meta_information` can hold any data which is helpful for self._on_metrics_published to store weight file etc for later upload via self._get_latest_model_files + """ + raise NotImplementedError + + @abstractmethod + def _on_metrics_published(self, training_state_data: TrainingStateData) -> None: + """Called after the metrics corresponding to TrainingStateData have been successfully sent to the Learning Loop. + Receives the TrainingStateData object which was returned by self._get_new_best_training_state. + If that function returns None, this function is not called. + The respective files for this model should be stored so they can be later uploaded in _get_latest_model_files. + """ + raise NotImplementedError + + @abstractmethod + async def _get_latest_model_files(self) -> Dict[str, List[str]]: + """Called when the Learning Loop requests to backup the latest model for the training. + This function is used to __generate and gather__ all files needed for transferring the actual data from the trainer node to the Learning Loop. + In the simplest implementation this method just renames the weight file (e.g. stored in TrainingStateData.meta_information) into a file name like latest_published_model + + The function should return a list of file paths which describe the model per format.
+ These files must contain all data necessary for the trainer to resume a training (e.g. weight file, hyperparameters, etc.) + and will be stored in the Learning Loop under the format of this trainer. + Note: by convention the weightfile should be named "model.<extension>" where extension is the file format of the weightfile. + For example "model.pt" for pytorch or "model.weights" for darknet/yolo. + + If a trainer can also generate other formats (for example for a detector), + a dictionary mapping format -> list of files can be returned. + + If the function returns an empty dict, something went wrong and the model upload will be skipped. + """ + raise NotImplementedError + + @abstractmethod + async def _clear_training_data(self, training_folder: str) -> None: + """Called after a training has finished. Deletes all data that is not needed anymore after a training run. + This can be old weightfiles or any additional files. + """ + raise NotImplementedError diff --git a/learning_loop_node/trainer/trainer_node.py b/learning_loop_node/trainer/trainer_node.py index d2ae3249..f69cf103 100644 --- a/learning_loop_node/trainer/trainer_node.py +++ b/learning_loop_node/trainer/trainer_node.py @@ -1,74 +1,59 @@ import asyncio -import time from dataclasses import asdict -from typing import Dict, Optional, Union +from typing import Dict, Optional -from dacite import from_dict from fastapi.encoders import jsonable_encoder from socketio import AsyncClient -from ..data_classes import Context, NodeState, TrainingState, TrainingStatus -from ..data_classes.socket_response import SocketResponse +from ..data_classes import TrainingStatus from ..node import Node from .io_helpers import LastTrainingIO from .rest import backdoor_controls, controls -from .trainer_logic import TrainerLogic +from .trainer_logic_generic import TrainerLogicGeneric class TrainerNode(Node): - def __init__(self, name: str, trainer_logic: TrainerLogic, uuid: Optional[str] = None, use_backdoor_controls: bool = False): -
super().__init__(name, uuid) - trainer_logic._node = self # pylint: disable=protected-access + def __init__(self, name: str, trainer_logic: TrainerLogicGeneric, uuid: Optional[str] = None, use_backdoor_controls: bool = False): + super().__init__(name, uuid, 'trainer') + trainer_logic._node = self self.trainer_logic = trainer_logic self.last_training_io = LastTrainingIO(self.uuid) + self.trainer_logic._last_training_io = self.last_training_io + self.include_router(controls.router, tags=["controls"]) if use_backdoor_controls: self.include_router(backdoor_controls.router, tags=["controls"]) - # --------------------------------------------------- STATUS --------------------------------------------------- - - @property - def progress(self) -> Union[float, None]: - return self.trainer_logic.general_progress if (self.trainer_logic is not None and - hasattr(self.trainer_logic, 'general_progress')) else None - - @property - def training_uptime(self) -> Union[float, None]: - return time.time() - self.trainer_logic.start_time if self.trainer_logic.start_time else None - - # ----------------------------------- LIVECYCLE: ABSTRACT NODE METHODS -------------------------- + # ----------------------------------- NODE LIVECYCLE METHODS -------------------------- async def on_startup(self): pass async def on_shutdown(self): self.log.info('shutdown detected, stopping training') - await self.trainer_logic.shutdown() + await self.trainer_logic.on_shutdown() async def on_repeat(self): try: - if await self.continue_run_if_incomplete(): + if await self.trainer_logic.try_continue_run_if_incomplete(): return # NOTE: we prevent sending idle status after starting a continuation await self.send_status() except Exception as e: if isinstance(e, asyncio.TimeoutError): self.log.warning('timeout when sending status to learning loop, reconnecting sio_client') - await self.sio_client.disconnect() - # NOTE: reconnect happens in node._on_repeat + await self.sio_client.disconnect() # NOTE: reconnect 
happens in node._on_repeat else: self.log.exception(f'could not send status state: {e}') - # ---------------------------------------------- NODE ABSTRACT METHODS --------------------------------------------------- + # ---------------------------------------------- NODE METHODS --------------------------------------------------- def register_sio_events(self, sio_client: AsyncClient): @sio_client.event async def begin_training(organization: str, project: str, details: Dict): - assert self._sio_client is not None self.log.info('received begin_training from server') - self.trainer_logic.init_new_training(Context(organization=organization, project=project), details) - asyncio.get_event_loop().create_task(self.trainer_logic.run()) + await self.trainer_logic.begin_training(organization, project, details) return True @sio_client.event @@ -81,93 +66,29 @@ async def stop_training(): return True async def send_status(self): - if self._sio_client is None or not self._sio_client.connected: + if not self.sio_client.connected: self.log.warning('cannot send status - not connected to the Learning Loop') return - if not self.trainer_logic.is_initialized: - state_for_learning_loop = str(NodeState.Idle.value) - else: - assert self.trainer_logic.training.training_state is not None - state_for_learning_loop = TrainerNode.state_for_learning_loop( - self.trainer_logic.training.training_state) - status = TrainingStatus(id=self.uuid, name=self.name, - state=state_for_learning_loop, + state=self.trainer_logic.state, errors={}, - uptime=self.training_uptime, - progress=self.progress) + uptime=self.trainer_logic.training_uptime, + progress=self.trainer_logic.general_progress) status.pretrained_models = self.trainer_logic.provided_pretrained_models status.architecture = self.trainer_logic.model_architecture - if self.trainer_logic.is_initialized and self.trainer_logic.training.data: - status.train_image_count = self.trainer_logic.training.data.train_image_count() - status.test_image_count = 
self.trainer_logic.training.data.test_image_count() - status.skipped_image_count = self.trainer_logic.training.data.skipped_image_count - status.hyperparameters = self.trainer_logic.hyperparameters + if data := self.trainer_logic.training_data: + status.train_image_count = data.train_image_count() + status.test_image_count = data.test_image_count() + status.skipped_image_count = data.skipped_image_count + status.hyperparameters = self.trainer_logic.hyperparameters_for_state_sync status.errors = self.trainer_logic.errors.errors - status.context = self.trainer_logic.training.context + status.context = self.trainer_logic.training_context self.log.info(f'sending status: {status.short_str()}') - result = await self._sio_client.call('update_trainer', jsonable_encoder(asdict(status)), timeout=30) - assert isinstance(result, Dict) - response = from_dict(data_class=SocketResponse, data=result) - - if not response.success: - self.log.error(f'Error when sending status update: Response from loop was:\n {asdict(response)}') - - async def continue_run_if_incomplete(self) -> bool: - if not self.trainer_logic.is_initialized and self.last_training_io.exists(): - self.log.info('found incomplete training, continuing now.') - self.trainer_logic.init_from_last_training() - asyncio.get_event_loop().create_task(self.trainer_logic.run()) - return True - return False - - async def get_state(self): - if self.trainer_logic._executor is not None and self.trainer_logic._executor.is_process_running(): # pylint: disable=protected-access - return NodeState.Running - return NodeState.Idle - - def get_node_type(self): - return 'trainer' - - # --------------------------------------------------- HELPER --------------------------------------------------- - - @staticmethod - def state_for_learning_loop(trainer_state: Union[TrainingState, str]) -> str: - if trainer_state == TrainingState.Initialized: - return 'Training is initialized' - if trainer_state == TrainingState.DataDownloading: - return 
'Downloading data' - if trainer_state == TrainingState.DataDownloaded: - return 'Data downloaded' - if trainer_state == TrainingState.TrainModelDownloading: - return 'Downloading model' - if trainer_state == TrainingState.TrainModelDownloaded: - return 'Model downloaded' - if trainer_state == TrainingState.TrainingRunning: - return NodeState.Running - if trainer_state == TrainingState.TrainingFinished: - return 'Training finished' - if trainer_state == TrainingState.Detecting: - return NodeState.Detecting - if trainer_state == TrainingState.ConfusionMatrixSyncing: - return 'Syncing confusion matrix' - if trainer_state == TrainingState.ConfusionMatrixSynced: - return 'Confusion matrix synced' - if trainer_state == TrainingState.TrainModelUploading: - return 'Uploading trained model' - if trainer_state == TrainingState.TrainModelUploaded: - return 'Trained model uploaded' - if trainer_state == TrainingState.Detecting: - return 'calculating detections' - if trainer_state == TrainingState.Detected: - return 'Detections calculated' - if trainer_state == TrainingState.DetectionUploading: - return 'Uploading detections' - if trainer_state == TrainingState.ReadyForCleanup: - return 'Cleaning training' - return 'unknown state' + result = await self.sio_client.call('update_trainer', jsonable_encoder(asdict(status)), timeout=30) + if isinstance(result, Dict) and not result['success']: + self.log.error(f'Error when sending status update: Response from loop was:\n {result}') diff --git a/learning_loop_node/trainer/training_syncronizer.py b/learning_loop_node/trainer/training_syncronizer.py deleted file mode 100644 index 1707d407..00000000 --- a/learning_loop_node/trainer/training_syncronizer.py +++ /dev/null @@ -1,52 +0,0 @@ - -import asyncio -import logging -from dataclasses import asdict -from typing import TYPE_CHECKING, Any - -import socketio -from dacite import from_dict -from fastapi.encoders import jsonable_encoder - -from ..data_classes import TrainingOut -from 
..data_classes.socket_response import SocketResponse - -if TYPE_CHECKING: - from .trainer_logic import TrainerLogic - - -async def try_sync_model(trainer: 'TrainerLogic', trainer_node_uuid: str, sio_client: socketio.AsyncClient): - try: - model = trainer.get_new_model() - except Exception as exc: - logging.exception('error while getting new model') - raise Exception(f'Could not get new model: {str(exc)}') from exc - logging.debug(f'new model {model}') - - if model: - response = await sync_model(trainer, trainer_node_uuid, sio_client, model) - - if not response.success: - error_msg = f'Error for update_training: Response from loop was : {asdict(response)}' - logging.error(error_msg) - raise Exception(error_msg) - - -async def sync_model(trainer, trainer_node_uuid, sio_client, model): - current_training = trainer.training - new_training = TrainingOut( - trainer_id=trainer_node_uuid, - confusion_matrix=model.confusion_matrix, - train_image_count=current_training.data.train_image_count(), - test_image_count=current_training.data.test_image_count(), - hyperparameters=trainer.hyperparameters) - - await asyncio.sleep(0.1) # NOTE needed for tests. 
- - result = await sio_client.call('update_training', (current_training.context.organization, current_training.context.project, jsonable_encoder(new_training))) - response = from_dict(data_class=SocketResponse, data=result) - - if response.success: - logging.info(f'successfully updated training {asdict(new_training)}') - trainer.on_model_published(model) - return response diff --git a/mock_converter/app_code/restart/restart.py b/mock_annotator/app_code/restart/restart.py similarity index 84% rename from mock_converter/app_code/restart/restart.py rename to mock_annotator/app_code/restart/restart.py index f7203baa..915175ed 100644 --- a/mock_converter/app_code/restart/restart.py +++ b/mock_annotator/app_code/restart/restart.py @@ -1,5 +1,2 @@ # add 'reload_dirs=['./app_code/restart'] to uvicorn call in main.py # save this file to trigger uvicorn restart - - -# TODO raus nehmen diff --git a/mock_annotator/start.sh b/mock_annotator/start.sh index e6d3aaac..7814999d 100755 --- a/mock_annotator/start.sh +++ b/mock_annotator/start.sh @@ -5,5 +5,5 @@ if [[ $1 = "debug" ]]; then elif [[ $1 = "profile" ]]; then kernprof -l /app/main.py else - python3 /app/main.py + uvicorn main:node --host 0.0.0.0 --port 80 --reload --lifespan on --reload-dir /app/app_code/restart fi \ No newline at end of file diff --git a/mock_converter.dockerfile b/mock_converter.dockerfile deleted file mode 100644 index 42c883c8..00000000 --- a/mock_converter.dockerfile +++ /dev/null @@ -1,7 +0,0 @@ -FROM base_node:latest - -COPY ./mock_converter/ /app -ENV PYTHONPATH "${PYTHONPATH}:/app:/usr/local/lib/python3.11/site-packages:/learning_loop_node/learning_loop_node" -ENV TZ=Europe/Amsterdam - -EXPOSE 80 diff --git a/mock_converter/app_code/__init__.py b/mock_converter/app_code/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mock_converter/app_code/backdoor_controls.py b/mock_converter/app_code/backdoor_controls.py deleted file mode 100644 index 6472d4b2..00000000 --- 
a/mock_converter/app_code/backdoor_controls.py +++ /dev/null @@ -1,55 +0,0 @@ -"""These restful endpoints are only to be used for testing purposes and are not part of the 'offical' trainer behavior.""" - -import asyncio -import logging - -from fastapi import APIRouter, HTTPException, Request - -from learning_loop_node.data_classes import NodeState - -router = APIRouter() - - -@router.put("/socketio") -async def put_socketio(request: Request): - ''' - Example Usage - - curl -X PUT -d "on" http://localhost:8005/socketio - ''' - state = str(await request.body(), 'utf-8') - if state == 'off': - if request.app.status.state != NodeState.Offline: - logging.info('turning socketio off') - asyncio.create_task(request.app.sio.disconnect()) - if state == 'on': - if request.app.status.state == NodeState.Offline: - logging.info('turning socketio on') - asyncio.create_task(request.app.connect()) - - -@router.put("/check_state") -async def put_check_state(request: Request): - value = str(await request.body(), 'utf-8') - print(f'turning automatically check_state {value}', flush=True) - - if value == 'off': - request.app.skip_check_state = True - for _ in range(5): - if request.app.status.state != NodeState.Idle: - await asyncio.sleep(0.5) - else: - break - if request.app.status.state != NodeState.Idle: - raise HTTPException(status_code=409, detail="Could not skip auto checking. 
State is still not idle") - - if value == 'on': - request.app.skip_check_state = False - - -@router.post("/step") -async def add_steps(request: Request): - if request.app.status.state == NodeState.Running: - raise HTTPException(status_code=409, detail="converter is already running") - - await request.app.check_state() diff --git a/mock_converter/app_code/mock_converter_logic.py b/mock_converter/app_code/mock_converter_logic.py deleted file mode 100644 index 7fc68579..00000000 --- a/mock_converter/app_code/mock_converter_logic.py +++ /dev/null @@ -1,18 +0,0 @@ - -import asyncio -from typing import List - -from learning_loop_node.converter.converter_logic import ConverterLogic -from learning_loop_node.data_classes import ModelInformation - - -class MockConverterLogic(ConverterLogic): - - async def _convert(self, model_information: ModelInformation) -> None: - await asyncio.sleep(1) - - def get_converted_files(self, model_id: str) -> List[str]: - fake_converted_file = '/tmp/converted_weightfile.converted' - with open(fake_converted_file, 'wb') as f: - f.write(b'\x42') - return [fake_converted_file] diff --git a/mock_converter/app_code/tests/.gitkeep b/mock_converter/app_code/tests/.gitkeep deleted file mode 100644 index e69de29b..00000000 diff --git a/mock_converter/app_code/tests/test_dummy.py b/mock_converter/app_code/tests/test_dummy.py deleted file mode 100644 index 1f00624b..00000000 --- a/mock_converter/app_code/tests/test_dummy.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_always_succeed_to_ensure_ci_of_loop_will_not_fail(): - assert True diff --git a/mock_converter/main.py b/mock_converter/main.py deleted file mode 100644 index b8bdb907..00000000 --- a/mock_converter/main.py +++ /dev/null @@ -1,24 +0,0 @@ -import logging -import os - -import uvicorn -from app_code import backdoor_controls -from app_code.mock_converter_logic import MockConverterLogic - -from learning_loop_node.converter.converter_node import ConverterNode - 
-logging.basicConfig(level=logging.DEBUG) - -mock_converter = MockConverterLogic(source_format='mocked', target_format='mocked_converted') -node = ConverterNode(uuid='85ef1a58-308d-4c80-8931-43d1f752f4f3', name='mocked converter', converter=mock_converter) -node.skip_check_state = True # do not check states auotmatically for this mock - -# setting up backdoor_controls -node.include_router(backdoor_controls.router, prefix="") - - -if __name__ == "__main__": - reload_dirs = ['./app_code/restart'] if os.environ.get('MANUAL_RESTART', None) \ - else ['./app_code', './learning-loop-node', '/usr/local/lib/python3.11/site-packages/learning_loop_node'] - uvicorn.run("main:node", host="0.0.0.0", port=80, lifespan='on', - reload=True, use_colors=True, reload_dirs=reload_dirs) diff --git a/mock_converter/pytest.ini b/mock_converter/pytest.ini deleted file mode 100644 index 0d20a612..00000000 --- a/mock_converter/pytest.ini +++ /dev/null @@ -1,8 +0,0 @@ -[pytest] -# NOTE: changing default location of pytest_cache because the uvicorn file watcher somehow triggered to many reloads -cache_dir = /tmp/pytest_cache -python_files = test_*.py -asyncio_mode = auto - -testpaths = tests - \ No newline at end of file diff --git a/mock_converter/start.sh b/mock_converter/start.sh deleted file mode 100755 index 125eee97..00000000 --- a/mock_converter/start.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env bash - -uvicorn main:node --host 0.0.0.0 --port 80 --reload --lifespan on --reload-dir /app \ No newline at end of file diff --git a/mock_detector/app_code/tests/test_detector.py b/mock_detector/app_code/tests/test_detector.py index 3d05d99e..75816212 100644 --- a/mock_detector/app_code/tests/test_detector.py +++ b/mock_detector/app_code/tests/test_detector.py @@ -5,6 +5,8 @@ from learning_loop_node.detector.detector_node import DetectorNode from learning_loop_node.globals import GLOBALS +# pylint: disable=unused-argument + @pytest.fixture(scope="session") def event_loop(request): diff --git 
a/mock_trainer/app_code/mock_trainer_logic.py b/mock_trainer/app_code/mock_trainer_logic.py index b3f1adb5..d293758e 100644 --- a/mock_trainer/app_code/mock_trainer_logic.py +++ b/mock_trainer/app_code/mock_trainer_logic.py @@ -2,11 +2,11 @@ import asyncio import logging import time -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional -from learning_loop_node.data_classes import (BasicModel, BoxDetection, CategoryType, ClassificationDetection, - Detections, ErrorConfiguration, ModelInformation, Point, PointDetection, - PretrainedModel, SegmentationDetection, Shape) +from learning_loop_node.data_classes import (BoxDetection, CategoryType, ClassificationDetection, Detections, + ErrorConfiguration, ModelInformation, Point, PointDetection, + PretrainedModel, SegmentationDetection, Shape, TrainingStateData) from learning_loop_node.trainer.trainer_logic import TrainerLogic from . import progress_simulator @@ -23,28 +23,28 @@ def __init__(self, model_format: str) -> None: self.current_iteration = 0 self.provide_new_model = True - def can_resume(self) -> bool: + def _can_resume(self) -> bool: return False - async def resume(self) -> None: + async def _resume(self) -> None: pass - async def start_training(self) -> None: + async def _start_training_from_base_model(self) -> None: self.current_iteration = 0 if self.error_configuration.begin_training: - raise Exception() - self.executor.start('while true; do sleep 1; done') + raise Exception('Could not start training') + await self.executor.start('/bin/bash -c "while true; do sleep 1; done"') - async def start_training_from_scratch(self, base_model_id: str) -> None: + async def _start_training_from_scratch(self) -> None: self.current_iteration = 0 - self.executor.start('while true; do sleep 1; done') + await self.executor.start('/bin/bash -c "while true; do sleep 1; done"') - def get_executor_error_from_log(self) -> Optional[str]: + def _get_executor_error_from_log(self) -> Optional[str]: if 
self.error_configuration.crash_training: return 'mocked crash' return None - def get_latest_model_files(self) -> Union[List[str], Dict[str, List[str]]]: + async def _get_latest_model_files(self) -> Dict[str, List[str]]: if self.error_configuration.save_model: raise Exception() @@ -66,37 +66,34 @@ async def _detect(self, model_information: ModelInformation, images: List[str], for image in images: image_id = image.split('/')[-1].replace('.jpg', '') - box_detections = [] - point_detections = [] - segmentation_detections = [] - classification_detections = [] - det_entry = { - 'image_id': image_id, 'box_detections': box_detections, 'point_detections': point_detections, - 'segmentation_detections': segmentation_detections, - 'classification_detections': classification_detections} + box_detections: List[BoxDetection] = [] + point_detections: List[PointDetection] = [] + segmentation_detections: List[SegmentationDetection] = [] + classification_detections: List[ClassificationDetection] = [] + for c in model_information.categories: if c.type == CategoryType.Box: - d = BoxDetection(category_name=c.name, x=1, y=2, width=30, height=40, - model_name=model_information.version, confidence=.99, category_id=c.id) - box_detections.append(d) + bd = BoxDetection(category_name=c.name, x=1, y=2, width=30, height=40, + model_name=model_information.version, confidence=.99, category_id=c.id) + box_detections.append(bd) elif c.type == CategoryType.Point: - d = PointDetection(category_name=c.name, x=100, y=200, - model_name=model_information.version, confidence=.97, category_id=c.id) - point_detections.append(d) + pd = PointDetection(category_name=c.name, x=100, y=200, + model_name=model_information.version, confidence=.97, category_id=c.id) + point_detections.append(pd) elif c.type == CategoryType.Segmentation: - d = SegmentationDetection(category_name=c.name, shape=Shape(points=[Point(x=1, y=2), Point( + sd = SegmentationDetection(category_name=c.name, shape=Shape(points=[Point(x=1, y=2), 
Point( x=3, y=4)]), model_name=model_information.version, confidence=.96, category_id=c.id) - segmentation_detections.append(d) + segmentation_detections.append(sd) elif c.type == CategoryType.Classification: - d = ClassificationDetection(category_name=c.name, model_name=model_information.version, - confidence=.95, category_id=c.id) - classification_detections.append(d) + cd = ClassificationDetection(category_name=c.name, model_name=model_information.version, + confidence=.95, category_id=c.id) + classification_detections.append(cd) detections.append(Detections(box_detections=box_detections, point_detections=point_detections, segmentation_detections=segmentation_detections, classification_detections=classification_detections, image_id=image_id)) return detections - async def clear_training_data(self, training_folder: str): + async def _clear_training_data(self, training_folder: str): pass @property @@ -111,18 +108,18 @@ def training_progress(self) -> float: print(f'prog. is {self.current_iteration} / {self.max_iterations} = {self.current_iteration / self.max_iterations}') return self.current_iteration / self.max_iterations - def get_new_model(self) -> Optional[BasicModel]: + def _get_new_best_training_state(self) -> Optional[TrainingStateData]: logging.warning('get_new_model called') if self.error_configuration.get_new_model: - raise Exception() + raise Exception('Could not get new model') if not self.provide_new_model: return None self.current_iteration += 1 return progress_simulator.increment_time(self, self.latest_known_confusion_matrix) - def on_model_published(self, basic_model: BasicModel) -> None: - assert isinstance(basic_model.confusion_matrix, Dict) - self.latest_known_confusion_matrix = basic_model.confusion_matrix + def _on_metrics_published(self, training_state_data: TrainingStateData) -> None: + assert isinstance(training_state_data.confusion_matrix, Dict) + self.latest_known_confusion_matrix = training_state_data.confusion_matrix @property def 
model_architecture(self) -> str: diff --git a/mock_trainer/app_code/progress_simulator.py b/mock_trainer/app_code/progress_simulator.py index 042a0b29..76f8be52 100644 --- a/mock_trainer/app_code/progress_simulator.py +++ b/mock_trainer/app_code/progress_simulator.py @@ -1,11 +1,11 @@ import random from typing import Dict, Optional -from learning_loop_node.data_classes import BasicModel +from learning_loop_node.data_classes import TrainingStateData from learning_loop_node.trainer.trainer_logic import TrainerLogic -def increment_time(trainer: TrainerLogic, latest_known_confusion_matrix: Dict) -> Optional[BasicModel]: +def increment_time(trainer: TrainerLogic, latest_known_confusion_matrix: Dict) -> Optional[TrainingStateData]: if not trainer._training or not trainer._training.data: # pylint: disable=protected-access return None @@ -23,7 +23,7 @@ def increment_time(trainer: TrainerLogic, latest_known_confusion_matrix: Dict) - 'fn': max(random.randint(10-maximum, 10-minimum), 2), } - new_model = BasicModel( + new_model = TrainingStateData( confusion_matrix=confusion_matrix, ) diff --git a/mock_trainer/app_code/tests/conftest.py b/mock_trainer/app_code/tests/conftest.py index 86c62dc2..6c23ca7e 100644 --- a/mock_trainer/app_code/tests/conftest.py +++ b/mock_trainer/app_code/tests/conftest.py @@ -1,5 +1,4 @@ import asyncio -import logging import shutil import pytest @@ -7,6 +6,8 @@ from learning_loop_node.globals import GLOBALS from learning_loop_node.loop_communication import LoopCommunicator +# pylint: disable=redefined-outer-name + @pytest.fixture() async def glc(): diff --git a/mock_trainer/app_code/tests/test_detections.py b/mock_trainer/app_code/tests/test_detections.py index 47781d3d..a1e3b471 100644 --- a/mock_trainer/app_code/tests/test_detections.py +++ b/mock_trainer/app_code/tests/test_detections.py @@ -5,16 +5,17 @@ from learning_loop_node.data_classes import Category, Context from learning_loop_node.globals import GLOBALS +from 
learning_loop_node.helpers.misc import create_project_folder, generate_training from learning_loop_node.loop_communication import LoopCommunicator -from learning_loop_node.node import Node from learning_loop_node.tests import test_helper -from learning_loop_node.trainer.trainer_logic import TrainerLogic from learning_loop_node.trainer.trainer_node import TrainerNode from ..mock_trainer_logic import MockTrainerLogic +# pylint: disable=protected-access,redefined-outer-name,unused-argument -async def test_all(setup_test_project1, glc: LoopCommunicator): # pylint: disable=unused-argument, redefined-outer-name + +async def test_all(setup_test_project1, glc: LoopCommunicator): assert_image_count(0) assert GLOBALS.data_folder == '/tmp/learning_loop_lib_data' @@ -29,14 +30,14 @@ async def test_all(setup_test_project1, glc: LoopCommunicator): # pylint: disab 'resolution': 800, 'flip_rl': False, 'flip_ud': False} - trainer._node = node # pylint: disable=protected-access - trainer.init_new_training(context=context, details=details) - - project_folder = Node.create_project_folder(context) - training = TrainerLogic.generate_training(project_folder, context) - training.model_id_for_detecting = latest_model_id - trainer._training = training # pylint: disable=protected-access - await trainer._do_detections() # pylint: disable=protected-access + trainer._node = node + trainer._init_new_training(context=context, details=details) + + project_folder = create_project_folder(context) + training = generate_training(project_folder, context) + training.model_uuid_for_detecting = latest_model_id + trainer._training = training + await trainer._do_detections() detections = trainer.active_training_io.load_detections() assert_image_count(10) # TODO This assert fails frequently on Drone diff --git a/mock_trainer/app_code/tests/test_mock_trainer.py b/mock_trainer/app_code/tests/test_mock_trainer.py index 20e43931..e2b518b0 100644 --- a/mock_trainer/app_code/tests/test_mock_trainer.py +++ 
b/mock_trainer/app_code/tests/test_mock_trainer.py @@ -1,23 +1,25 @@ from typing import Dict from uuid import uuid4 -from learning_loop_node.data_classes import (Context, Model, Training, - TrainingData) +from learning_loop_node.data_classes import Context, Model, Training, TrainingData from learning_loop_node.globals import GLOBALS from learning_loop_node.trainer.executor import Executor from ..mock_trainer_logic import MockTrainerLogic +# pylint: disable=protected-access +# pylint: disable=unused-argument + async def create_mock_trainer() -> MockTrainerLogic: mock_trainer = MockTrainerLogic(model_format='mocked') - mock_trainer._executor = Executor(GLOBALS.data_folder) # pylint: disable=protected-access + mock_trainer._executor = Executor(GLOBALS.data_folder) return mock_trainer async def test_get_model_files(setup_test_project2): mock_trainer = await create_mock_trainer() - files = mock_trainer.get_latest_model_files() + files = await mock_trainer._get_latest_model_files() assert isinstance(files, Dict) @@ -28,7 +30,7 @@ async def test_get_model_files(setup_test_project2): async def test_get_new_model(setup_test_project2): mock_trainer = await create_mock_trainer() - await mock_trainer.start_training() + await mock_trainer._start_training_from_base_model() model = Model(uuid=(str(uuid4()))) context = Context(organization="", project="") @@ -39,5 +41,5 @@ async def test_get_new_model(setup_test_project2): images_folder="", training_folder="",) mock_trainer.training.data = TrainingData(image_data=[], categories=[]) - model = mock_trainer.get_new_model() + model = mock_trainer._get_new_best_training_state() assert model is not None