Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add refresh_events to CertHandler #108

Merged
merged 8 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions lib/charms/observability_libs/v1/cert_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@
import logging

from ops.charm import CharmBase
from ops.framework import EventBase, EventSource, Object, ObjectEvents
from ops.framework import BoundEvent, EventBase, EventSource, Object, ObjectEvents, StoredState
from ops.jujuversion import JujuVersion
from ops.model import Relation, Secret, SecretNotFoundError

logger = logging.getLogger(__name__)

LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a"
LIBAPI = 1
LIBPATCH = 11
LIBPATCH = 12

VAULT_SECRET_LABEL = "cert-handler-private-vault"

Expand Down Expand Up @@ -273,6 +273,7 @@ class CertHandler(Object):
"""A wrapper for the requirer side of the TLS Certificates charm library."""

on = CertHandlerEvents() # pyright: ignore
_stored = StoredState()

def __init__(
self,
Expand All @@ -283,6 +284,7 @@ def __init__(
peer_relation_name: str = "peers",
cert_subject: Optional[str] = None,
sans: Optional[List[str]] = None,
refresh_events: Optional[List[BoundEvent]] = None,
):
"""CertHandler is used to wrap TLS Certificates management operations for charms.

Expand All @@ -299,8 +301,16 @@ def __init__(
Must match metadata.yaml.
cert_subject: Custom subject. Name collisions are under the caller's responsibility.
sans: DNS names. If none are given, use FQDN.
refresh_events: an optional list of bound events which
will be observed to replace the current CSR with a new one
if there are changes in the CSR's DNS SANs or IP SANs.
Then, subsequently, replace its corresponding certificate with a new one.
"""
super().__init__(charm, key)
self._stored.set_default(
current_sans_ip=None,
current_sans_dns=None,
)
PietroPasotti marked this conversation as resolved.
Show resolved Hide resolved
self.charm = charm

# We need to sanitize the unit name, otherwise route53 complains:
Expand Down Expand Up @@ -355,6 +365,15 @@ def __init__(
self._on_upgrade_charm,
)

if refresh_events:
michaeldmitry marked this conversation as resolved.
Show resolved Hide resolved
for ev in refresh_events:
self.framework.observe(ev, self._on_refresh_event)

def _on_refresh_event(self, _):
"""Replace the latest current CSR with a new one if there are any SANs changes."""
if self.sans_ip != self._stored.sans_ip or self.sans_dns != self._stored.sans_dns:
self._generate_csr(renew=True)

def _on_upgrade_charm(self, _):
has_privkey = self.vault.get_value("private-key")

Expand Down Expand Up @@ -484,6 +503,9 @@ def _generate_csr(
)
self.certificates.request_certificate_creation(certificate_signing_request=csr)

self._stored.sans_ip = self.sans_ip
self._stored.sans_dns = self.sans_dns

if clear_cert:
self.vault.clear()

Expand Down
104 changes: 103 additions & 1 deletion tests/scenario/test_cert_handler/test_cert_handler_v1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import socket
import sys
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch

import pytest
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.x509.oid import ExtensionOID
from ops import CharmBase
from scenario import Context, PeerRelation, Relation, State

Expand All @@ -12,6 +17,7 @@

libs = str(Path(__file__).parent.parent.parent.parent / "lib")
sys.path.append(libs)
MOCK_HOSTNAME = "mock-hostname"


class MyCharm(CharmBase):
Expand All @@ -22,8 +28,28 @@ class MyCharm(CharmBase):

def __init__(self, fw):
super().__init__(fw)
sans = [socket.getfqdn()]
if hostname := self._mock_san:
sans.append(hostname)

self.ch = CertHandler(self, key="ch", sans=[socket.getfqdn()])
self.ch = CertHandler(self, key="ch", sans=sans, refresh_events=[self.on.config_changed])

@property
def _mock_san(self):
"""This property is meant to be mocked to return a mock string hostname to be used as SAN.

By default, it returns None.
"""
return None


def get_csr_obj(csr: str):
return x509.load_pem_x509_csr(csr.encode(), default_backend())


def get_sans_from_csr(csr):
san_extension = csr.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME)
return set(san_extension.value.get_values_for_type(x509.DNSName))


@pytest.fixture
Expand All @@ -36,6 +62,20 @@ def certificates():
return Relation("certificates")


@contextmanager
def _sans_patch(hostname=MOCK_HOSTNAME):
with patch.object(MyCharm, "_mock_san", hostname):
yield


@contextmanager
def _cert_renew_patch():
with patch(
"charms.tls_certificates_interface.v3.tls_certificates.TLSCertificatesRequiresV3.request_certificate_renewal"
) as patcher:
yield patcher


@pytest.mark.parametrize("leader", (True, False))
def test_cert_joins(ctx, certificates, leader):
with ctx.manager(
Expand Down Expand Up @@ -72,3 +112,65 @@ def test_cert_joins_peer_vault_backend(ctx_juju2, certificates, leader):
) as mgr:
mgr.run()
assert mgr.charm.ch.private_key


def test_renew_csr_on_sans_change(ctx, certificates):
# generate a CSR
with ctx.manager(
certificates.joined_event,
State(leader=True, relations=[certificates]),
) as mgr:
charm = mgr.charm
state_out = mgr.run()
orig_csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(orig_csr) == {socket.getfqdn()}

# trigger a config_changed with a modified SAN
with _sans_patch():
with ctx.manager("config_changed", state_out) as mgr:
charm = mgr.charm
state_out = mgr.run()
csr = get_csr_obj(charm.ch._csr)
# assert CSR contains updated SAN
assert get_sans_from_csr(csr) == {socket.getfqdn(), MOCK_HOSTNAME}


def test_csr_no_change_on_wrong_refresh_event(ctx, certificates):
with _cert_renew_patch() as renew_patch:
with ctx.manager(
"config_changed",
State(leader=True, relations=[certificates]),
) as mgr:
charm = mgr.charm
state_out = mgr.run()
orig_csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(orig_csr) == {socket.getfqdn()}

with _sans_patch():
with _cert_renew_patch() as renew_patch:
with ctx.manager("update_status", state_out) as mgr:
charm = mgr.charm
state_out = mgr.run()
csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(csr) == {socket.getfqdn()}
assert renew_patch.call_count == 0


def test_csr_no_change(ctx, certificates):

with ctx.manager(
"config_changed",
State(leader=True, relations=[certificates]),
) as mgr:
charm = mgr.charm
state_out = mgr.run()
orig_csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(orig_csr) == {socket.getfqdn()}

with _cert_renew_patch() as renew_patch:
with ctx.manager("config_changed", state_out) as mgr:
charm = mgr.charm
state_out = mgr.run()
csr = get_csr_obj(charm.ch._csr)
assert get_sans_from_csr(csr) == {socket.getfqdn()}
assert renew_patch.call_count == 0
13 changes: 12 additions & 1 deletion tests/unit/test_kubernetes_compute_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
# See LICENSE file for licensing details.
import unittest
from unittest import mock
from unittest.mock import MagicMock, Mock
from unittest.mock import MagicMock, Mock, patch

import httpx
import tenacity
import yaml
from charms.observability_libs.v0.kubernetes_compute_resources_patch import (
KubernetesComputeResourcesPatch,
Expand All @@ -16,12 +17,22 @@
from ops import BlockedStatus, WaitingStatus
from ops.charm import CharmBase
from ops.testing import Harness
from pytest import fixture

from tests.unit.helpers import PROJECT_DIR

CL_PATH = "charms.observability_libs.v0.kubernetes_compute_resources_patch.KubernetesComputeResourcesPatch"


@fixture(autouse=True)
def patch_retry():
with patch.multiple(
KubernetesComputeResourcesPatch,
PATCH_RETRY_STOP=tenacity.stop_after_delay(0),
):
yield


class TestKubernetesComputeResourcesPatch(unittest.TestCase):
class _TestCharm(CharmBase):
def __init__(self, *args):
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,6 @@ allowlist_externals =
rm
commands =
charmcraft fetch-lib charms.tls_certificates_interface.v2.tls_certificates
charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates
pytest -v --tb native {[vars]tst_path}/scenario --log-cli-level=INFO -s {posargs}
rm -rf ./lib/charms/tls_certificates_interface
Loading