diff --git a/lib/charms/observability_libs/v1/cert_handler.py b/lib/charms/observability_libs/v1/cert_handler.py index b0c9c6e..6e693ff 100644 --- a/lib/charms/observability_libs/v1/cert_handler.py +++ b/lib/charms/observability_libs/v1/cert_handler.py @@ -59,7 +59,7 @@ import logging from ops.charm import CharmBase -from ops.framework import EventBase, EventSource, Object, ObjectEvents +from ops.framework import BoundEvent, EventBase, EventSource, Object, ObjectEvents, StoredState from ops.jujuversion import JujuVersion from ops.model import Relation, Secret, SecretNotFoundError @@ -67,7 +67,7 @@ LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" LIBAPI = 1 -LIBPATCH = 11 +LIBPATCH = 12 VAULT_SECRET_LABEL = "cert-handler-private-vault" @@ -273,6 +273,7 @@ class CertHandler(Object): """A wrapper for the requirer side of the TLS Certificates charm library.""" on = CertHandlerEvents() # pyright: ignore + _stored = StoredState() def __init__( self, @@ -283,6 +284,7 @@ def __init__( peer_relation_name: str = "peers", cert_subject: Optional[str] = None, sans: Optional[List[str]] = None, + refresh_events: Optional[List[BoundEvent]] = None, ): """CertHandler is used to wrap TLS Certificates management operations for charms. @@ -299,8 +301,17 @@ def __init__( Must match metadata.yaml. cert_subject: Custom subject. Name collisions are under the caller's responsibility. sans: DNS names. If none are given, use FQDN. + refresh_events: an optional list of bound events which + will be observed to replace the current CSR with a new one + if there are changes in the CSR's DNS SANs or IP SANs. + Then, subsequently, replace its corresponding certificate with a new one. """ super().__init__(charm, key) + # use StoredState to store the hash of the CSR + # to potentially trigger a CSR renewal on `refresh_events` + self._stored.set_default( + csr_hash=None, + ) self.charm = charm # We need to sanitize the unit name, otherwise route53 complains: @@ -355,6 +366,15 @@ def __init__( self._on_upgrade_charm, ) + if refresh_events: + for ev in refresh_events: + self.framework.observe(ev, self._on_refresh_event) + + def _on_refresh_event(self, _): + """Replace the latest current CSR with a new one if there are any SANs changes.""" + if self._stored.csr_hash != self._csr_hash: + self._generate_csr(renew=True) + def _on_upgrade_charm(self, _): has_privkey = self.vault.get_value("private-key") @@ -419,6 +439,20 @@ def enabled(self) -> bool: return True + @property + def _csr_hash(self) -> int: + """A hash of the config that constructs the CSR. + + Only include here the config options that, should they change, should trigger a renewal of + the CSR. + """ + return hash( + ( + tuple(self.sans_dns), + tuple(self.sans_ip), + ) + ) + @property def available(self) -> bool: """Return True if all certs are available in relation data; False otherwise.""" @@ -484,6 +518,8 @@ def _generate_csr( ) self.certificates.request_certificate_creation(certificate_signing_request=csr) + self._stored.csr_hash = self._csr_hash + if clear_cert: self.vault.clear() diff --git a/tests/scenario/test_cert_handler/test_cert_handler_v1.py b/tests/scenario/test_cert_handler/test_cert_handler_v1.py index db3ad3c..589995f 100644 --- a/tests/scenario/test_cert_handler/test_cert_handler_v1.py +++ b/tests/scenario/test_cert_handler/test_cert_handler_v1.py @@ -1,8 +1,13 @@ import socket import sys +from contextlib import contextmanager from pathlib import Path +from unittest.mock import patch import pytest +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.x509.oid import ExtensionOID from ops import CharmBase from scenario import Context, PeerRelation, Relation, State @@ -12,6 +17,7 @@ libs = str(Path(__file__).parent.parent.parent.parent / "lib") sys.path.append(libs) +MOCK_HOSTNAME = "mock-hostname" class MyCharm(CharmBase): @@ -22,8 +28,28 @@ class MyCharm(CharmBase): def __init__(self, fw): super().__init__(fw) + sans = [socket.getfqdn()] + if hostname := self._mock_san: + sans.append(hostname) - self.ch = CertHandler(self, key="ch", sans=[socket.getfqdn()]) + self.ch = CertHandler(self, key="ch", sans=sans, refresh_events=[self.on.config_changed]) + + @property + def _mock_san(self): + """This property is meant to be mocked to return a mock string hostname to be used as SAN. + + By default, it returns None. + """ + return None + + +def get_csr_obj(csr: str): + return x509.load_pem_x509_csr(csr.encode(), default_backend()) + + +def get_sans_from_csr(csr): + san_extension = csr.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME) + return set(san_extension.value.get_values_for_type(x509.DNSName)) @pytest.fixture @@ -36,6 +62,20 @@ def certificates(): return Relation("certificates") +@contextmanager +def _sans_patch(hostname=MOCK_HOSTNAME): + with patch.object(MyCharm, "_mock_san", hostname): + yield + + +@contextmanager +def _cert_renew_patch(): + with patch( + "charms.tls_certificates_interface.v3.tls_certificates.TLSCertificatesRequiresV3.request_certificate_renewal" + ) as patcher: + yield patcher + + @pytest.mark.parametrize("leader", (True, False)) def test_cert_joins(ctx, certificates, leader): with ctx.manager( @@ -72,3 +112,65 @@ def test_cert_joins_peer_vault_backend(ctx_juju2, certificates, leader): ) as mgr: mgr.run() assert mgr.charm.ch.private_key + + +def test_renew_csr_on_sans_change(ctx, certificates): + # generate a CSR + with ctx.manager( + certificates.joined_event, + State(leader=True, relations=[certificates]), + ) as mgr: + charm = mgr.charm + state_out = mgr.run() + orig_csr = get_csr_obj(charm.ch._csr) + assert get_sans_from_csr(orig_csr) == {socket.getfqdn()} + + # trigger a config_changed with a modified SAN + with _sans_patch(): + with ctx.manager("config_changed", state_out) as mgr: + charm = mgr.charm + state_out = mgr.run() + csr = get_csr_obj(charm.ch._csr) + # assert CSR contains updated SAN + assert get_sans_from_csr(csr) == {socket.getfqdn(), MOCK_HOSTNAME} + + +def test_csr_no_change_on_wrong_refresh_event(ctx, certificates): + with _cert_renew_patch() as renew_patch: + with ctx.manager( + "config_changed", + State(leader=True, relations=[certificates]), + ) as mgr: + charm = mgr.charm + state_out = mgr.run() + orig_csr = get_csr_obj(charm.ch._csr) + assert get_sans_from_csr(orig_csr) == {socket.getfqdn()} + + with _sans_patch(): + with _cert_renew_patch() as renew_patch: + with ctx.manager("update_status", state_out) as mgr: + charm = mgr.charm + state_out = mgr.run() + csr = get_csr_obj(charm.ch._csr) + assert get_sans_from_csr(csr) == {socket.getfqdn()} + assert renew_patch.call_count == 0 + + +def test_csr_no_change(ctx, certificates): + + with ctx.manager( + "config_changed", + State(leader=True, relations=[certificates]), + ) as mgr: + charm = mgr.charm + state_out = mgr.run() + orig_csr = get_csr_obj(charm.ch._csr) + assert get_sans_from_csr(orig_csr) == {socket.getfqdn()} + + with _cert_renew_patch() as renew_patch: + with ctx.manager("config_changed", state_out) as mgr: + charm = mgr.charm + state_out = mgr.run() + csr = get_csr_obj(charm.ch._csr) + assert get_sans_from_csr(csr) == {socket.getfqdn()} + assert renew_patch.call_count == 0 diff --git a/tests/unit/test_kubernetes_compute_resources.py b/tests/unit/test_kubernetes_compute_resources.py index e64df26..c4830cf 100644 --- a/tests/unit/test_kubernetes_compute_resources.py +++ b/tests/unit/test_kubernetes_compute_resources.py @@ -2,9 +2,10 @@ # See LICENSE file for licensing details. import unittest from unittest import mock -from unittest.mock import MagicMock, Mock +from unittest.mock import MagicMock, Mock, patch import httpx +import tenacity import yaml from charms.observability_libs.v0.kubernetes_compute_resources_patch import ( KubernetesComputeResourcesPatch, @@ -16,12 +17,22 @@ from ops import BlockedStatus, WaitingStatus from ops.charm import CharmBase from ops.testing import Harness +from pytest import fixture from tests.unit.helpers import PROJECT_DIR CL_PATH = "charms.observability_libs.v0.kubernetes_compute_resources_patch.KubernetesComputeResourcesPatch" +@fixture(autouse=True) +def patch_retry(): + with patch.multiple( + KubernetesComputeResourcesPatch, + PATCH_RETRY_STOP=tenacity.stop_after_delay(0), + ): + yield + + class TestKubernetesComputeResourcesPatch(unittest.TestCase): class _TestCharm(CharmBase): def __init__(self, *args): diff --git a/tox.ini b/tox.ini index e7f1122..722427f 100644 --- a/tox.ini +++ b/tox.ini @@ -111,5 +111,6 @@ allowlist_externals = rm commands = charmcraft fetch-lib charms.tls_certificates_interface.v2.tls_certificates + charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates pytest -v --tb native {[vars]tst_path}/scenario --log-cli-level=INFO -s {posargs} rm -rf ./lib/charms/tls_certificates_interface