diff --git a/tests/osquery/test_setup_enrollments.py b/tests/osquery/test_setup_enrollments.py index 7af91a58cf..2aec0fab16 100644 --- a/tests/osquery/test_setup_enrollments.py +++ b/tests/osquery/test_setup_enrollments.py @@ -64,7 +64,7 @@ def test_create_enrollment_permission_denied(self): @patch("zentral.contrib.osquery.forms.get_osquery_versions") def test_create_enrollment_view_get(self, get_osquery_versions): - get_osquery_versions.returns = {} + get_osquery_versions.returns = [] self._login("osquery.add_enrollment") configuration = self._force_configuration() response = self.client.get(reverse("osquery:create_enrollment", args=(configuration.pk,))) @@ -76,7 +76,7 @@ def test_create_enrollment_view_get(self, get_osquery_versions): @patch("zentral.contrib.osquery.forms.get_osquery_versions") def test_create_enrollment_view_post(self, get_osquery_versions): - get_osquery_versions.returns = {} + get_osquery_versions.returns = [] self._login("osquery.add_enrollment", "osquery.view_configuration", "osquery.view_enrollment") configuration = self._force_configuration() response = self.client.post(reverse("osquery:create_enrollment", args=(configuration.pk,)), @@ -93,6 +93,25 @@ def test_create_enrollment_view_post(self, get_osquery_versions): self.assertContains(response, reverse(f"osquery_api:{view_name}", args=(enrollment.pk,))) get_osquery_versions.assert_called_once_with() + @patch("zentral.contrib.osquery.releases.requests.get") + def test_create_enrollment_view_get_osquery_versions_error_post(self, requests_get): + requests_get.side_effect = RuntimeError("YOLO") + self._login("osquery.add_enrollment", "osquery.view_configuration", "osquery.view_enrollment") + configuration = self._force_configuration() + response = self.client.post(reverse("osquery:create_enrollment", args=(configuration.pk,)), + {"secret-meta_business_unit": self.mbu.pk, + "configuration": configuration.pk, + "osquery_release": ""}, follow=True) + self.assertEqual(response.status_code, 200) + self.assertTemplateUsed(response, "osquery/configuration_detail.html") + self.assertEqual(response.context["object"], configuration) + enrollment = response.context["enrollments"][0] + self.assertEqual(enrollment.version, 1) + self.assertContains(response, enrollment.secret.meta_business_unit.name) + for view_name in ("enrollment_package", "enrollment_script", "enrollment_powershell_script"): + self.assertContains(response, reverse(f"osquery_api:{view_name}", args=(enrollment.pk,))) + requests_get.assert_called_once() + # bump enrollment version def test_bump_enrollment_version_redirect(self): diff --git a/zentral/contrib/osquery/releases.py b/zentral/contrib/osquery/releases.py index b2dbc5f016..f7f67a990c 100644 --- a/zentral/contrib/osquery/releases.py +++ b/zentral/contrib/osquery/releases.py @@ -4,7 +4,6 @@ import tempfile import requests from urllib.parse import urlparse -from requests.exceptions import ConnectionError, HTTPError from zentral.utils.local_dir import get_and_create_local_dir @@ -25,13 +24,13 @@ def get_osquery_versions(ignore_draft_release=True, check_urls=True, last=3): + versions = [] try: resp = requests.get(GITHUB_API_URL, timeout=2) resp.raise_for_status() - except (ConnectionError, HTTPError): + except Exception: logger.exception("Could not get versions from Github.") - return - versions = [] + return versions releases = resp.json() if last: # limit releases to check