From 6a7018145aaa50973ca3ef35cb4777c805fe83e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Guillot?= Date: Fri, 1 Mar 2024 17:59:22 -0800 Subject: [PATCH] Add more specific exceptions --- .gitignore | 1 + LICENSE.txt | 2 +- miniflux.py | 146 +++++++++++++++++++++++++++++-------------- pyproject.toml | 2 +- tests/test_client.py | 94 ++++++++++++++++++++++++++-- 5 files changed, 190 insertions(+), 55 deletions(-) diff --git a/.gitignore b/.gitignore index 0bf74f7..fe7714d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ __pycache__ .cache +.idea .venv .vscode *.egg* diff --git a/LICENSE.txt b/LICENSE.txt index 3d8b343..aa951f9 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2018-2023 Frédéric Guillot +Copyright (c) 2018-2024 Frédéric Guillot Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/miniflux.py b/miniflux.py index 42b5c06..fc9861c 100644 --- a/miniflux.py +++ b/miniflux.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # -# Copyright (c) 2018-2023 Frederic Guillot +# Copyright (c) 2018-2024 Frederic Guillot # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -55,6 +55,41 @@ def get_error_reason(self) -> str: return default_reason +class ResourceNotFound(ClientError): + """ + Exception raised when the API client receives a 404 response from the server. + """ + pass + + +class AccessForbidden(ClientError): + """ + Exception raised when the API client receives a 403 response from the server. + """ + pass + + +class AccessUnauthorized(ClientError): + """ + Exception raised when the API client receives a 401 response from the server. + """ + pass + + +class BadRequest(ClientError): + """ + Exception raised when the API client receives a 400 response from the server. + """ + pass + + +class ServerError(ClientError): + """ + Exception raised when the API client receives a 500 response from the server. + """ + pass + + class Client: """ Miniflux API client. @@ -64,10 +99,10 @@ class Client: def __init__( self, base_url: str, - username: str = None, - password: str = None, + username: Optional[str] = None, + password: Optional[str] = None, timeout: float = 30.0, - api_key: str = None, + api_key: Optional[str] = None, user_agent: str = DEFAULT_USER_AGENT, ): self._base_url = base_url @@ -75,7 +110,7 @@ def __init__( self._username = username self._password = password self._timeout = timeout - self._auth = (self._username, self._password) if not api_key else None + self._auth: Optional[tuple] = (self._username, self._password) if not api_key else None self._headers = {"User-Agent": user_agent} if api_key: self._headers["X-Auth-Token"] = api_key @@ -93,6 +128,19 @@ def _get_params(self, **kwargs) -> Optional[Dict]: def _get_modification_params(self, **kwargs) -> Dict: return {k: v for k, v in kwargs.items() if v is not None} + def _handle_error_response(self, response: requests.Response): + if response.status_code == 404: + raise ResourceNotFound(response) + if response.status_code == 403: + raise AccessForbidden(response) + if response.status_code == 401: + raise AccessUnauthorized(response) + if response.status_code == 400: + raise BadRequest(response) + if response.status_code >= 500: + raise ServerError(response) + raise ClientError(response) + def flush_history(self) -> bool: """ Mark all read entries as removed excepted the starred ones. @@ -104,7 +152,9 @@ def flush_history(self) -> bool: response = requests.delete( endpoint, headers=self._headers, auth=self._auth, timeout=self._timeout ) - return response.status_code == 202 + if response.status_code == 202: + return True + self._handle_error_response(response) def get_version(self) -> Dict: """ @@ -121,7 +171,7 @@ def get_version(self) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def me(self) -> Dict: """ @@ -156,7 +206,7 @@ def me(self) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def export(self) -> str: """ @@ -184,7 +234,7 @@ def export_feeds(self) -> str: ) if response.status_code == 200: return response.text - raise ClientError(response) + self._handle_error_response(response) def import_feeds(self, opml: str) -> Dict: """ @@ -208,7 +258,7 @@ def import_feeds(self, opml: str) -> Dict: ) if response.status_code == 201: return response.json() - raise ClientError(response) + self._handle_error_response(response) def discover(self, website_url: str, **kwargs) -> List[Dict]: """ @@ -234,7 +284,7 @@ def discover(self, website_url: str, **kwargs) -> List[Dict]: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def get_category_feeds(self, category_id: int) -> List[Dict]: """ @@ -253,7 +303,7 @@ def get_category_feeds(self, category_id: int) -> List[Dict]: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def get_feeds(self) -> List[Dict]: """ @@ -270,7 +320,7 @@ def get_feeds(self) -> List[Dict]: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def get_feed(self, feed_id: int) -> Dict: """ @@ -289,7 +339,7 @@ def get_feed(self, feed_id: int) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def get_feed_icon(self, feed_id: int) -> Dict: """ @@ -308,7 +358,7 @@ def get_feed_icon(self, feed_id: int) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def get_icon(self, icon_id: int) -> Dict: """ @@ -327,7 +377,7 @@ def get_icon(self, icon_id: int) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def get_icon_by_feed_id(self, feed_id: int) -> Dict: """ @@ -342,7 +392,7 @@ def get_icon_by_feed_id(self, feed_id: int) -> Dict: """ return self.get_feed_icon(feed_id) - def create_feed(self, feed_url: str, category_id: int = None, **kwargs) -> int: + def create_feed(self, feed_url: str, category_id: Optional[int] = None, **kwargs) -> int: """ Create a new feed. @@ -367,7 +417,7 @@ def create_feed(self, feed_url: str, category_id: int = None, **kwargs) -> int: ) if response.status_code == 201: return response.json()["feed_id"] - raise ClientError(response) + self._handle_error_response(response) def update_feed(self, feed_id: int, **kwargs) -> Dict: """ @@ -391,7 +441,7 @@ def update_feed(self, feed_id: int, **kwargs) -> Dict: ) if response.status_code == 201: return response.json() - raise ClientError(response) + self._handle_error_response(response) def refresh_all_feeds(self) -> bool: """ @@ -407,7 +457,7 @@ def refresh_all_feeds(self) -> bool: endpoint, headers=self._headers, auth=self._auth, timeout=self._timeout ) if response.status_code >= 400: - raise ClientError(response) + self._handle_error_response(response) return True def refresh_feed(self, feed_id: int) -> bool: @@ -426,7 +476,7 @@ def refresh_feed(self, feed_id: int) -> bool: endpoint, headers=self._headers, auth=self._auth, timeout=self._timeout ) if response.status_code >= 400: - raise ClientError(response) + self._handle_error_response(response) return True def refresh_category(self, category_id: int) -> bool: @@ -445,7 +495,7 @@ def refresh_category(self, category_id: int) -> bool: endpoint, headers=self._headers, auth=self._auth, timeout=self._timeout ) if response.status_code >= 400: - raise ClientError(response) + self._handle_error_response(response) return True def delete_feed(self, feed_id: int) -> None: @@ -462,7 +512,7 @@ def delete_feed(self, feed_id: int) -> None: endpoint, headers=self._headers, auth=self._auth, timeout=self._timeout ) if response.status_code != 204: - raise ClientError(response) + self._handle_error_response(response) def get_feed_entry(self, feed_id: int, entry_id: int) -> Dict: """ @@ -482,7 +532,7 @@ def get_feed_entry(self, feed_id: int, entry_id: int) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def get_feed_entries(self, feed_id: int, **kwargs) -> Dict: """ @@ -506,7 +556,7 @@ def get_feed_entries(self, feed_id: int, **kwargs) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def mark_feed_entries_as_read(self, feed_id: int) -> None: """ @@ -524,7 +574,7 @@ def mark_feed_entries_as_read(self, feed_id: int) -> None: endpoint, headers=self._headers, auth=self._auth, timeout=self._timeout ) if response.status_code != 204: - raise ClientError(response) + self._handle_error_response(response) def get_entry(self, entry_id: int) -> Dict: """ @@ -543,7 +593,7 @@ def get_entry(self, entry_id: int) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def get_entries(self, **kwargs) -> Dict: """ @@ -565,9 +615,9 @@ def get_entries(self, **kwargs) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) - def update_entry(self, entry_id: int, title: str = None, content: str = None) -> Dict: + def update_entry(self, entry_id: int, title: Optional[str] = None, content: Optional[str] = None) -> Dict: """ Update an entry. @@ -594,7 +644,7 @@ def update_entry(self, entry_id: int, title: str = None, content: str = None) -> ) if response.status_code == 201: return response.json() - raise ClientError(response) + self._handle_error_response(response) def update_entries(self, entry_ids: List[int], status: str) -> bool: """ @@ -618,7 +668,7 @@ def update_entries(self, entry_ids: List[int], status: str) -> bool: timeout=self._timeout, ) if response.status_code >= 400: - raise ClientError(response) + self._handle_error_response(response) return True def fetch_entry_content(self, entry_id: int) -> Dict: @@ -638,7 +688,7 @@ def fetch_entry_content(self, entry_id: int) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def toggle_bookmark(self, entry_id: int) -> bool: """ @@ -656,7 +706,7 @@ def toggle_bookmark(self, entry_id: int) -> bool: endpoint, headers=self._headers, auth=self._auth, timeout=self._timeout ) if response.status_code >= 400: - raise ClientError(response) + self._handle_error_response(response) return True def save_entry(self, entry_id: int) -> bool: @@ -675,7 +725,7 @@ def save_entry(self, entry_id: int) -> bool: endpoint, headers=self._headers, auth=self._auth, timeout=self._timeout ) if response.status_code != 202: - raise ClientError(response) + self._handle_error_response(response) return True def get_categories(self) -> List[Dict]: @@ -693,7 +743,7 @@ def get_categories(self) -> List[Dict]: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def get_category_entry(self, category_id: int, entry_id: int) -> Dict: """ @@ -713,7 +763,7 @@ def get_category_entry(self, category_id: int, entry_id: int) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def get_category_entries(self, category_id: int, **kwargs) -> Dict: """ @@ -737,7 +787,7 @@ def get_category_entries(self, category_id: int, **kwargs) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def create_category(self, title: str) -> Dict: """ @@ -761,7 +811,7 @@ def create_category(self, title: str) -> Dict: ) if response.status_code == 201: return response.json() - raise ClientError(response) + self._handle_error_response(response) def update_category(self, category_id: int, title: str) -> Dict: """ @@ -786,7 +836,7 @@ def update_category(self, category_id: int, title: str) -> Dict: ) if response.status_code == 201: return response.json() - raise ClientError(response) + self._handle_error_response(response) def delete_category(self, category_id: int) -> None: """ @@ -802,7 +852,7 @@ def delete_category(self, category_id: int) -> None: endpoint, headers=self._headers, auth=self._auth, timeout=self._timeout ) if response.status_code != 204: - raise ClientError(response) + self._handle_error_response(response) def mark_category_entries_as_read(self, category_id: int) -> None: """ @@ -818,7 +868,7 @@ def mark_category_entries_as_read(self, category_id: int) -> None: endpoint, headers=self._headers, auth=self._auth, timeout=self._timeout ) if response.status_code != 204: - raise ClientError(response) + self._handle_error_response(response) def get_users(self) -> List[Dict]: """ @@ -835,7 +885,7 @@ def get_users(self) -> List[Dict]: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def get_user_by_id(self, user_id: int) -> Dict: """ @@ -870,7 +920,7 @@ def _get_user(self, user_id_or_username: Union[str, int]) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) def create_user(self, username: str, password: str, is_admin: bool = False) -> Dict: """ @@ -896,7 +946,7 @@ def create_user(self, username: str, password: str, is_admin: bool = False) -> D ) if response.status_code == 201: return response.json() - raise ClientError(response) + self._handle_error_response(response) def update_user(self, user_id: int, **kwargs) -> Dict: """ @@ -920,7 +970,7 @@ def update_user(self, user_id: int, **kwargs) -> Dict: ) if response.status_code == 201: return response.json() - raise ClientError(response) + self._handle_error_response(response) def delete_user(self, user_id: int) -> None: """ @@ -936,7 +986,7 @@ def delete_user(self, user_id: int) -> None: endpoint, headers=self._headers, auth=self._auth, timeout=self._timeout ) if response.status_code != 204: - raise ClientError(response) + self._handle_error_response(response) def mark_user_entries_as_read(self, user_id: int) -> None: """ @@ -952,7 +1002,7 @@ def mark_user_entries_as_read(self, user_id: int) -> None: endpoint, headers=self._headers, auth=self._auth, timeout=self._timeout ) if response.status_code != 204: - raise ClientError(response) + self._handle_error_response(response) def get_feed_counters(self) -> Dict: """ @@ -969,4 +1019,4 @@ def get_feed_counters(self) -> Dict: ) if response.status_code == 200: return response.json() - raise ClientError(response) + self._handle_error_response(response) diff --git a/pyproject.toml b/pyproject.toml index c33c9b6..3d783d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "miniflux" -version = "1.0.0" +version = "1.0.1" description = "Client library for Miniflux" readme = "README.rst" requires-python = ">=3.7" diff --git a/tests/test_client.py b/tests/test_client.py index 309f1f9..c1ae40e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # -# Copyright (c) 2018-2023 Frederic Guillot +# Copyright (c) 2018-2024 Frederic Guillot # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,7 +26,7 @@ from unittest import mock import miniflux -from miniflux import ClientError +from miniflux import AccessForbidden, AccessUnauthorized, BadRequest, ClientError, ResourceNotFound, ServerError from requests.exceptions import Timeout @@ -35,7 +35,7 @@ def test_get_error_reason(self): response = mock.Mock() response.status_code = 404 response.json.return_value = {"error_message": "some error"} - error = ClientError(response) + error = ResourceNotFound(response) self.assertEqual(error.status_code, 404) self.assertEqual(error.get_error_reason(), "some error") @@ -43,7 +43,7 @@ def test_get_error_without_reason(self): response = mock.Mock() response.status_code = 404 response.json.return_value = {} - error = ClientError(response) + error = ResourceNotFound(response) self.assertEqual(error.status_code, 404) self.assertEqual(error.get_error_reason(), "status_code=404") @@ -51,7 +51,7 @@ def test_get_error_with_bad_response(self): response = mock.Mock() response.status_code = 404 response.json.return_value = None - error = ClientError(response) + error = ResourceNotFound(response) self.assertEqual(error.status_code, 404) self.assertEqual(error.get_error_reason(), "status_code=404") @@ -930,6 +930,20 @@ def test_get_user_by_id(self): assert result == expected_result + def test_get_inexisting_user(self): + requests = _get_request_mock() + + response = mock.Mock() + response.status_code = 404 + response.json.return_value = {"error_message": "some error"} + + requests.get.return_value = response + + client = miniflux.Client("http://localhost", "username", "password") + + with self.assertRaises(ResourceNotFound): + client.get_user_by_id(123) + def test_get_user_by_username(self): requests = _get_request_mock() expected_result = [] @@ -1160,6 +1174,76 @@ def test_update_entries_status(self): self.assertEqual(payload.get("status"), "read") self.assertTrue(result) + def test_not_found_response(self): + requests = _get_request_mock() + + response = mock.Mock() + response.status_code = 404 + response.json.return_value = {"error_message": "Not found"} + + requests.get.return_value = response + + client = miniflux.Client("http://localhost", "username", "password") + + with self.assertRaises(ResourceNotFound): + client.get_version() + + def test_unauthorized_response(self): + requests = _get_request_mock() + + response = mock.Mock() + response.status_code = 401 + response.json.return_value = {"error_message": "Unauthorized"} + + requests.get.return_value = response + + client = miniflux.Client("http://localhost", "username", "password") + + with self.assertRaises(AccessUnauthorized): + client.get_version() + + def test_forbidden_response(self): + requests = _get_request_mock() + + response = mock.Mock() + response.status_code = 403 + response.json.return_value = {"error_message": "Forbidden"} + + requests.get.return_value = response + + client = miniflux.Client("http://localhost", "username", "password") + + with self.assertRaises(AccessForbidden): + client.get_version() + + def test_bad_request_response(self): + requests = _get_request_mock() + + response = mock.Mock() + response.status_code = 400 + response.json.return_value = {"error_message": "Bad request"} + + requests.get.return_value = response + + client = miniflux.Client("http://localhost", "username", "password") + + with self.assertRaises(BadRequest): + client.get_version() + + def test_server_error_response(self): + requests = _get_request_mock() + + response = mock.Mock() + response.status_code = 500 + response.json.return_value = {"error_message": "Server error"} + + requests.get.return_value = response + + client = miniflux.Client("http://localhost", "username", "password") + + with self.assertRaises(ServerError): + client.get_version() + def _get_request_mock(): patcher = mock.patch("miniflux.requests")