From 15c38f25af70f469dc84a8c15a360fca87f2cf75 Mon Sep 17 00:00:00 2001 From: Robert Gildein Date: Thu, 22 Aug 2024 11:17:35 +0200 Subject: [PATCH] chore(lib): Update cert_handler library to v1 (#517) Update the cert_handler to v1. fixes: #516 --------- Signed-off-by: Robert Gildein --- .../observability_libs/v0/cert_handler.py | 437 ------ .../observability_libs/v1/cert_handler.py | 594 ++++++++ .../{v2 => v3}/tls_certificates.py | 1243 ++++++++++------- charms/istio-pilot/src/charm.py | 17 +- charms/istio-pilot/tests/unit/test_charm.py | 28 +- tests/test_bundle.py | 14 +- 6 files changed, 1384 insertions(+), 949 deletions(-) delete mode 100644 charms/istio-pilot/lib/charms/observability_libs/v0/cert_handler.py create mode 100644 charms/istio-pilot/lib/charms/observability_libs/v1/cert_handler.py rename charms/istio-pilot/lib/charms/tls_certificates_interface/{v2 => v3}/tls_certificates.py (63%) diff --git a/charms/istio-pilot/lib/charms/observability_libs/v0/cert_handler.py b/charms/istio-pilot/lib/charms/observability_libs/v0/cert_handler.py deleted file mode 100644 index 275cf7db..00000000 --- a/charms/istio-pilot/lib/charms/observability_libs/v0/cert_handler.py +++ /dev/null @@ -1,437 +0,0 @@ -# Copyright 2023 Canonical Ltd. -# See LICENSE file for licensing details. -"""## Overview. - -This document explains how to use the `CertHandler` class to -create and manage TLS certificates through the `tls_certificates` interface. - -The goal of the CertHandler is to provide a wrapper to the `tls_certificates` -library functions to make the charm integration smoother. - -## Library Usage - -This library should be used to create a `CertHandler` object, as per the -following example: - -```python -self.cert_handler = CertHandler( - charm=self, - key="my-app-cert-manager", - peer_relation_name="replicas", - cert_subject="unit_name", # Optional -) -``` - -You can then observe the library's custom event and make use of the key and cert: -```python -self.framework.observe(self.cert_handler.on.cert_changed, self._on_server_cert_changed) - -container.push(keypath, self.cert_handler.key) -container.push(certpath, self.cert_handler.cert) -``` - -This library requires a peer relation to be declared in the requirer's metadata. Peer relation data -is used for "persistent storage" of the private key and certs. -""" -import ipaddress -import json -import socket -from itertools import filterfalse -from typing import List, Optional, Union, cast - -try: - from charms.tls_certificates_interface.v2.tls_certificates import ( # type: ignore - AllCertificatesInvalidatedEvent, - CertificateAvailableEvent, - CertificateExpiringEvent, - CertificateInvalidatedEvent, - TLSCertificatesRequiresV2, - generate_csr, - generate_private_key, - ) -except ImportError as e: - raise ImportError( - "failed to import charms.tls_certificates_interface.v2.tls_certificates; " - "Either the library itself is missing (please get it through charmcraft fetch-lib) " - "or one of its dependencies is unmet." - ) from e - -import logging - -from ops.charm import CharmBase -from ops.framework import EventBase, EventSource, Object, ObjectEvents -from ops.model import Relation - -logger = logging.getLogger(__name__) - - -LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" -LIBAPI = 0 -LIBPATCH = 14 - - -def is_ip_address(value: str) -> bool: - """Return True if the input value is a valid IPv4 address; False otherwise.""" - try: - ipaddress.IPv4Address(value) - return True - except ipaddress.AddressValueError: - return False - - -class CertChanged(EventBase): - """Event raised when a cert is changed (becomes available or revoked).""" - - -class CertHandlerEvents(ObjectEvents): - """Events for CertHandler.""" - - cert_changed = EventSource(CertChanged) - - -class CertHandler(Object): - """A wrapper for the requirer side of the TLS Certificates charm library.""" - - on = CertHandlerEvents() # pyright: ignore - - def __init__( - self, - charm: CharmBase, - *, - key: str, - peer_relation_name: str, - certificates_relation_name: str = "certificates", - cert_subject: Optional[str] = None, - extra_sans_dns: Optional[List[str]] = None, # TODO: in v1, rename arg to `sans` - ): - """CertHandler is used to wrap TLS Certificates management operations for charms. - - CerHandler manages one single cert. - - Args: - charm: The owning charm. - key: A manually-crafted, static, unique identifier used by ops to identify events. - It shouldn't change between one event to another. - peer_relation_name: Must match metadata.yaml. - certificates_relation_name: Must match metadata.yaml. - cert_subject: Custom subject. Name collisions are under the caller's responsibility. - extra_sans_dns: DNS names. If none are given, use FQDN. - """ - super().__init__(charm, key) - - self.charm = charm - # We need to sanitize the unit name, otherwise route53 complains: - # "urn:ietf:params:acme:error:malformed" :: Domain name contains an invalid character - self.cert_subject = charm.unit.name.replace("/", "-") if not cert_subject else cert_subject - - # Use fqdn only if no SANs were given, and drop empty/duplicate SANs - sans = list(set(filter(None, (extra_sans_dns or [socket.getfqdn()])))) - self.sans_ip = list(filter(is_ip_address, sans)) - self.sans_dns = list(filterfalse(is_ip_address, sans)) - - self.peer_relation_name = peer_relation_name - self.certificates_relation_name = certificates_relation_name - - self.certificates = TLSCertificatesRequiresV2(self.charm, self.certificates_relation_name) - - self.framework.observe( - self.charm.on.config_changed, - self._on_config_changed, - ) - self.framework.observe( - self.charm.on.certificates_relation_joined, # pyright: ignore - self._on_certificates_relation_joined, - ) - self.framework.observe( - self.certificates.on.certificate_available, # pyright: ignore - self._on_certificate_available, - ) - self.framework.observe( - self.certificates.on.certificate_expiring, # pyright: ignore - self._on_certificate_expiring, - ) - self.framework.observe( - self.certificates.on.certificate_invalidated, # pyright: ignore - self._on_certificate_invalidated, - ) - self.framework.observe( - self.certificates.on.all_certificates_invalidated, # pyright: ignore - self._on_all_certificates_invalidated, - ) - - # Peer relation events - self.framework.observe( - self.charm.on[self.peer_relation_name].relation_created, self._on_peer_relation_created - ) - - @property - def enabled(self) -> bool: - """Boolean indicating whether the charm has a tls_certificates relation.""" - # We need to check for units as a temporary workaround because of https://bugs.launchpad.net/juju/+bug/2024583 - # This could in theory not work correctly on scale down to 0 but it is necessary for the moment. - return ( - len(self.charm.model.relations[self.certificates_relation_name]) > 0 - and len(self.charm.model.get_relation(self.certificates_relation_name).units) > 0 # type: ignore - ) - - @property - def _peer_relation(self) -> Optional[Relation]: - """Return the peer relation.""" - return self.charm.model.get_relation(self.peer_relation_name, None) - - def _on_peer_relation_created(self, _): - """Generate the CSR if the certificates relation is ready.""" - self._generate_privkey() - - # check cert relation is ready - if not (self.charm.model.get_relation(self.certificates_relation_name)): - # peer relation event happened to fire before tls-certificates events. - # Abort, and let the "certificates joined" observer create the CSR. - logger.info("certhandler waiting on certificates relation") - return - - logger.debug("certhandler has peer and certs relation: proceeding to generate csr") - self._generate_csr() - - def _on_certificates_relation_joined(self, _) -> None: - """Generate the CSR if the peer relation is ready.""" - self._generate_privkey() - - # check peer relation is there - if not self._peer_relation: - # tls-certificates relation event happened to fire before peer events. - # Abort, and let the "peer joined" relation create the CSR. - logger.info("certhandler waiting on peer relation") - return - - logger.debug("certhandler has peer and certs relation: proceeding to generate csr") - self._generate_csr() - - def _generate_privkey(self): - # Generate priv key unless done already - # TODO figure out how to go about key rotation. - if not self._private_key: - private_key = generate_private_key() - self._private_key = private_key.decode() - - def _on_config_changed(self, _): - # FIXME on config changed, the web_external_url may or may not change. But because every - # call to `generate_csr` appends a uuid, CSRs cannot be easily compared to one another. - # so for now, will be overwriting the CSR (and cert) every config change. This is not - # great. We could avoid this problem if: - # - we extract the external_url from the existing cert and compare to current; or - # - we drop the web_external_url from the list of SANs. - # Generate a CSR only if the necessary relations are already in place. - if self._peer_relation and self.charm.model.get_relation(self.certificates_relation_name): - self._generate_csr(renew=True) - - def _generate_csr( - self, overwrite: bool = False, renew: bool = False, clear_cert: bool = False - ): - """Request a CSR "creation" if renew is False, otherwise request a renewal. - - Without overwrite=True, the CSR would be created only once, even if calling the method - multiple times. This is useful needed because the order of peer-created and - certificates-joined is not predictable. - - This method intentionally does not emit any events, leave it for caller's responsibility. - """ - # if we are in a relation-broken hook, we might not have a relation to publish the csr to. - if not self.charm.model.get_relation(self.certificates_relation_name): - logger.warning( - f"No {self.certificates_relation_name!r} relation found. " f"Cannot generate csr." - ) - return - - # At this point, assuming "peer joined" and "certificates joined" have already fired - # (caller must guard) so we must have a private_key entry in relation data at our disposal. - # Otherwise, traceback -> debug. - - # In case we already have a csr, do not overwrite it by default. - if overwrite or renew or not self._csr: - private_key = self._private_key - if private_key is None: - # FIXME: raise this in a less nested scope by - # generating privkey and csr in the same method. - raise RuntimeError( - "private key unset. call _generate_privkey() before you call this method." - ) - csr = generate_csr( - private_key=private_key.encode(), - subject=self.cert_subject, - sans_dns=self.sans_dns, - sans_ip=self.sans_ip, - ) - - if renew and self._csr: - self.certificates.request_certificate_renewal( - old_certificate_signing_request=self._csr.encode(), - new_certificate_signing_request=csr, - ) - else: - logger.info( - "Creating CSR for %s with DNS %s and IPs %s", - self.cert_subject, - self.sans_dns, - self.sans_ip, - ) - self.certificates.request_certificate_creation(certificate_signing_request=csr) - - # Note: CSR is being replaced with a new one, so until we get the new cert, we'd have - # a mismatch between the CSR and the cert. - # For some reason the csr contains a trailing '\n'. TODO figure out why - self._csr = csr.decode().strip() - - if clear_cert: - self._ca_cert = "" - self._server_cert = "" - self._chain = [] - - def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: - """Get the certificate from the event and store it in a peer relation. - - Note: assuming "limit: 1" in metadata - """ - # We need to store the ca cert and server cert somewhere it would persist across upgrades. - # While we support Juju 2.9, the only option is peer data. When we drop 2.9, then secrets. - - # I think juju guarantees that a peer-created always fires before any regular - # relation-changed. If that is not the case, we would need more guards and more paths. - - # Process the cert only if it belongs to the unit that requested it (this unit) - event_csr = ( - event.certificate_signing_request.strip() - if event.certificate_signing_request - else None - ) - if event_csr == self._csr: - self._ca_cert = event.ca - self._server_cert = event.certificate - self._chain = event.chain - self.on.cert_changed.emit() # pyright: ignore - - @property - def key(self): - """Return the private key.""" - return self._private_key - - @property - def _private_key(self) -> Optional[str]: - if self._peer_relation: - return self._peer_relation.data[self.charm.unit].get("private_key", None) - return None - - @_private_key.setter - def _private_key(self, value: str): - # Caller must guard. We want the setter to fail loudly. Failure must have a side effect. - rel = self._peer_relation - assert rel is not None # For type checker - rel.data[self.charm.unit].update({"private_key": value}) - - @property - def _csr(self) -> Optional[str]: - if self._peer_relation: - return self._peer_relation.data[self.charm.unit].get("csr", None) - return None - - @_csr.setter - def _csr(self, value: str): - # Caller must guard. We want the setter to fail loudly. Failure must have a side effect. - rel = self._peer_relation - assert rel is not None # For type checker - rel.data[self.charm.unit].update({"csr": value}) - - @property - def _ca_cert(self) -> Optional[str]: - if self._peer_relation: - return self._peer_relation.data[self.charm.unit].get("ca", None) - return None - - @_ca_cert.setter - def _ca_cert(self, value: str): - # Caller must guard. We want the setter to fail loudly. Failure must have a side effect. - rel = self._peer_relation - assert rel is not None # For type checker - rel.data[self.charm.unit].update({"ca": value}) - - @property - def cert(self): - """Return the server cert.""" - return self._server_cert - - @property - def ca(self): - """Return the CA cert.""" - return self._ca_cert - - @property - def _server_cert(self) -> Optional[str]: - if self._peer_relation: - return self._peer_relation.data[self.charm.unit].get("certificate", None) - return None - - @_server_cert.setter - def _server_cert(self, value: str): - # Caller must guard. We want the setter to fail loudly. Failure must have a side effect. - rel = self._peer_relation - assert rel is not None # For type checker - rel.data[self.charm.unit].update({"certificate": value}) - - @property - def _chain(self) -> List[str]: - if self._peer_relation: - if chain := self._peer_relation.data[self.charm.unit].get("chain", []): - return cast(list, json.loads(cast(str, chain))) - return [] - - @_chain.setter - def _chain(self, value: List[str]): - # Caller must guard. We want the setter to fail loudly. Failure must have a side effect. - rel = self._peer_relation - assert rel is not None # For type checker - rel.data[self.charm.unit].update({"chain": json.dumps(value)}) - - @property - def chain(self) -> List[str]: - """Return the ca chain.""" - return self._chain - - def _on_certificate_expiring( - self, event: Union[CertificateExpiringEvent, CertificateInvalidatedEvent] - ) -> None: - """Generate a new CSR and request certificate renewal.""" - if event.certificate == self._server_cert: - self._generate_csr(renew=True) - - def _certificate_revoked(self, event) -> None: - """Remove the certificate from the peer relation and generate a new CSR.""" - # Note: assuming "limit: 1" in metadata - if event.certificate == self._server_cert: - self._generate_csr(overwrite=True, clear_cert=True) - self.on.cert_changed.emit() # pyright: ignore - - def _on_certificate_invalidated(self, event: CertificateInvalidatedEvent) -> None: - """Deal with certificate revocation and expiration.""" - if event.certificate != self._server_cert: - return - - # if event.reason in ("revoked", "expired"): - # Currently, the reason does not matter to us because the action is the same. - self._generate_csr(overwrite=True, clear_cert=True) - self.on.cert_changed.emit() # pyright: ignore - - def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEvent) -> None: - """Clear the certificates data when removing the relation.""" - # Note: assuming "limit: 1" in metadata - # The "certificates_relation_broken" event is converted to "all invalidated" custom - # event by the tls-certificates library. Per convention, we let the lib manage the - # relation and we do not observe "certificates_relation_broken" directly. - if self._peer_relation: - private_key = self._private_key - # This is a workaround for https://bugs.launchpad.net/juju/+bug/2024583 - self._peer_relation.data[self.charm.unit].clear() - if private_key: - self._peer_relation.data[self.charm.unit].update({"private_key": private_key}) - - # We do not generate a CSR here because the relation is gone. - self.on.cert_changed.emit() # pyright: ignore diff --git a/charms/istio-pilot/lib/charms/observability_libs/v1/cert_handler.py b/charms/istio-pilot/lib/charms/observability_libs/v1/cert_handler.py new file mode 100644 index 00000000..3b87ad46 --- /dev/null +++ b/charms/istio-pilot/lib/charms/observability_libs/v1/cert_handler.py @@ -0,0 +1,594 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. +"""## Overview. + +This document explains how to use the `CertHandler` class to +create and manage TLS certificates through the `tls_certificates` interface. + +The goal of the CertHandler is to provide a wrapper to the `tls_certificates` +library functions to make the charm integration smoother. + +## Library Usage + +This library should be used to create a `CertHandler` object, as per the +following example: + +```python +self.cert_handler = CertHandler( + charm=self, + key="my-app-cert-manager", + cert_subject="unit_name", # Optional +) +``` + +You can then observe the library's custom event and make use of the key and cert: +```python +self.framework.observe(self.cert_handler.on.cert_changed, self._on_server_cert_changed) + +container.push(keypath, self.cert_handler.private_key) +container.push(certpath, self.cert_handler.servert_cert) +``` + +Since this library uses [Juju Secrets](https://juju.is/docs/juju/secret) it requires Juju >= 3.0.3. +""" +import abc +import ipaddress +import json +import socket +from itertools import filterfalse +from typing import Dict, List, Optional, Union + +try: + from charms.tls_certificates_interface.v3.tls_certificates import ( # type: ignore + AllCertificatesInvalidatedEvent, + CertificateAvailableEvent, + CertificateExpiringEvent, + CertificateInvalidatedEvent, + ProviderCertificate, + TLSCertificatesRequiresV3, + generate_csr, + generate_private_key, + ) +except ImportError as e: + raise ImportError( + "failed to import charms.tls_certificates_interface.v3.tls_certificates; " + "Either the library itself is missing (please get it through charmcraft fetch-lib) " + "or one of its dependencies is unmet." + ) from e + +import logging + +from ops.charm import CharmBase +from ops.framework import EventBase, EventSource, Object, ObjectEvents +from ops.jujuversion import JujuVersion +from ops.model import Relation, Secret, SecretNotFoundError + +logger = logging.getLogger(__name__) + +LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" +LIBAPI = 1 +LIBPATCH = 11 + +VAULT_SECRET_LABEL = "cert-handler-private-vault" + + +def is_ip_address(value: str) -> bool: + """Return True if the input value is a valid IPv4 address; False otherwise.""" + try: + ipaddress.IPv4Address(value) + return True + except ipaddress.AddressValueError: + return False + + +class CertChanged(EventBase): + """Event raised when a cert is changed (becomes available or revoked).""" + + +class CertHandlerEvents(ObjectEvents): + """Events for CertHandler.""" + + cert_changed = EventSource(CertChanged) + + +class _VaultBackend(abc.ABC): + """Base class for a single secret manager. + + Assumptions: + - A single secret (label) is managed by a single instance. + - Secret is per-unit (not per-app, i.e. may differ from unit to unit). + """ + + def store(self, contents: Dict[str, str], clear: bool = False): ... + + def get_value(self, key: str) -> Optional[str]: ... + + def retrieve(self) -> Dict[str, str]: ... + + def clear(self): ... + + +class _RelationVaultBackend(_VaultBackend): + """Relation backend for Vault. + + Use it to store data in a relation databag. + Assumes that a single relation exists and its data is readable. + If not, it will raise RuntimeErrors as soon as you try to read/write. + It will store the data, in plaintext (json-dumped) nested under a configurable + key in the **unit databag** of this relation. + + Typically, you'll use this with peer relations. + + Note: it is assumed that this object has exclusive access to the data, even though in practice it does not. + Modifying relation data yourself would go unnoticed and disrupt consistency. + """ + + _NEST_UNDER = "lib.charms.observability_libs.v1.cert_handler::vault" + # This key needs to be relation-unique. If someone ever creates multiple Vault(_RelationVaultBackend) + # instances backed by the same (peer) relation, they'll need to set different _NEST_UNDERs + # for each _RelationVaultBackend instance or they'll be fighting over it. + + def __init__(self, charm: CharmBase, relation_name: str): + self.charm = charm + self.relation_name = relation_name + + def _check_ready(self): + relation = self.charm.model.get_relation(self.relation_name) + if not relation or not relation.data: + # if something goes wrong here, the peer-backed vault is not ready to operate + # it can be because you are trying to use it too soon, i.e. before the (peer) + # relation has been created (or has joined). + raise RuntimeError("Relation backend not ready.") + + @property + def _relation(self) -> Optional[Relation]: + self._check_ready() + return self.charm.model.get_relation(self.relation_name) + + @property + def _databag(self): + self._check_ready() + # _check_ready verifies that there is a relation + return self._relation.data[self.charm.unit] # type: ignore + + def _read(self) -> Dict[str, str]: + value = self._databag.get(self._NEST_UNDER) + if value: + return json.loads(value) + return {} + + def _write(self, value: Dict[str, str]): + if not all(isinstance(x, str) for x in value.values()): + # the caller has to take care of encoding + raise TypeError("You can only store strings in Vault.") + + self._databag[self._NEST_UNDER] = json.dumps(value) + + def store(self, contents: Dict[str, str], clear: bool = False): + """Create a new revision by updating the previous one with ``contents``.""" + current = self._read() + + if clear: + current.clear() + + current.update(contents) + self._write(current) + + def get_value(self, key: str) -> Optional[str]: + """Like retrieve, but single-value.""" + return self._read().get(key) + + def retrieve(self): + """Return the full vault content.""" + return self._read() + + def clear(self): + del self._databag[self._NEST_UNDER] + + +class _SecretVaultBackend(_VaultBackend): + """Relation backend for Vault. + + Use it to store data in a Juju secret. + Assumes that Juju supports secrets. + If not, it will raise some exception as soon as you try to read/write. + + Note: it is assumed that this object has exclusive access to the data, even though in practice it does not. + Modifying secret's data yourself would go unnoticed and disrupt consistency. + """ + + _uninitialized_key = "uninitialized-secret-key" + + def __init__(self, charm: CharmBase, secret_label: str): + self.charm = charm + self.secret_label = secret_label # needs to be charm-unique. + + @property + def _secret(self) -> Secret: + try: + # we are owners, so we don't need to grant it to ourselves + return self.charm.model.get_secret(label=self.secret_label) + except SecretNotFoundError: + # we need to set SOME contents when we're creating the secret, so we do it. + return self.charm.unit.add_secret( + {self._uninitialized_key: "42"}, label=self.secret_label + ) + + def store(self, contents: Dict[str, str], clear: bool = False): + """Create a new revision by updating the previous one with ``contents``.""" + secret = self._secret + current = secret.get_content(refresh=True) + + if clear: + current.clear() + elif current.get(self._uninitialized_key): + # is this the first revision? clean up the mock contents we created instants ago. + del current[self._uninitialized_key] + + current.update(contents) + secret.set_content(current) + + def get_value(self, key: str) -> Optional[str]: + """Like retrieve, but single-value.""" + return self._secret.get_content(refresh=True).get(key) + + def retrieve(self): + """Return the full vault content.""" + return self._secret.get_content(refresh=True) + + def clear(self): + self._secret.remove_all_revisions() + + +class Vault: + """Simple application secret wrapper for local usage.""" + + def __init__(self, backend: _VaultBackend): + self._backend = backend + + def store(self, contents: Dict[str, str], clear: bool = False): + """Store these contents in the vault overriding whatever is there.""" + self._backend.store(contents, clear=clear) + + def get_value(self, key: str): + """Like retrieve, but single-value.""" + return self._backend.get_value(key) + + def retrieve(self) -> Dict[str, str]: + """Return the full vault content.""" + return self._backend.retrieve() + + def clear(self): + """Clear the vault.""" + try: + self._backend.clear() + except SecretNotFoundError: + # guard against: https://github.com/canonical/observability-libs/issues/95 + # this is fine, it might mean an earlier hook had already called .clear() + # not sure what exactly the root cause is, might be a juju bug + logger.debug("Could not clear vault: secret is gone already.") + + +class CertHandler(Object): + """A wrapper for the requirer side of the TLS Certificates charm library.""" + + on = CertHandlerEvents() # pyright: ignore + + def __init__( + self, + charm: CharmBase, + *, + key: str, + certificates_relation_name: str = "certificates", + peer_relation_name: str = "peers", + cert_subject: Optional[str] = None, + sans: Optional[List[str]] = None, + ): + """CertHandler is used to wrap TLS Certificates management operations for charms. + + CerHandler manages one single cert. + + Args: + charm: The owning charm. + key: A manually-crafted, static, unique identifier used by ops to identify events. + It shouldn't change between one event to another. + certificates_relation_name: Name of the certificates relation over which we obtain TLS certificates. + Must match metadata.yaml. + peer_relation_name: Name of a peer relation used to store our secrets. + Only used on older Juju versions where secrets are not supported. + Must match metadata.yaml. + cert_subject: Custom subject. Name collisions are under the caller's responsibility. + sans: DNS names. If none are given, use FQDN. + """ + super().__init__(charm, key) + self.charm = charm + + # We need to sanitize the unit name, otherwise route53 complains: + # "urn:ietf:params:acme:error:malformed" :: Domain name contains an invalid character + self.cert_subject = charm.unit.name.replace("/", "-") if not cert_subject else cert_subject + + # Use fqdn only if no SANs were given, and drop empty/duplicate SANs + sans = list(set(filter(None, (sans or [socket.getfqdn()])))) + self.sans_ip = list(filter(is_ip_address, sans)) + self.sans_dns = list(filterfalse(is_ip_address, sans)) + + if self._check_juju_supports_secrets(): + vault_backend = _SecretVaultBackend(charm, secret_label=VAULT_SECRET_LABEL) + + # TODO: gracefully handle situations where the + # secret is gone because the admin has removed it manually + # self.framework.observe(self.charm.on.secret_remove, self._rotate_csr) + + else: + vault_backend = _RelationVaultBackend(charm, relation_name=peer_relation_name) + self.vault = Vault(vault_backend) + + self.certificates_relation_name = certificates_relation_name + self.certificates = TLSCertificatesRequiresV3(self.charm, self.certificates_relation_name) + + self.framework.observe( + self.charm.on.config_changed, + self._on_config_changed, + ) + self.framework.observe( + self.charm.on[self.certificates_relation_name].relation_joined, # pyright: ignore + self._on_certificates_relation_joined, + ) + self.framework.observe( + self.certificates.on.certificate_available, # pyright: ignore + self._on_certificate_available, + ) + self.framework.observe( + self.certificates.on.certificate_expiring, # pyright: ignore + self._on_certificate_expiring, + ) + self.framework.observe( + self.certificates.on.certificate_invalidated, # pyright: ignore + self._on_certificate_invalidated, + ) + self.framework.observe( + self.certificates.on.all_certificates_invalidated, # pyright: ignore + self._on_all_certificates_invalidated, + ) + self.framework.observe( + self.charm.on.upgrade_charm, # pyright: ignore + self._on_upgrade_charm, + ) + + def _on_upgrade_charm(self, _): + has_privkey = self.vault.get_value("private-key") + + self._migrate_vault() + + # If we already have a csr, but the pre-migration vault has no privkey stored, + # the csr must have been signed with a privkey that is now outdated and utterly lost. + # So we throw away the csr and generate a new one (and a new privkey along with it). + if not has_privkey and self._csr: + logger.debug("CSR and privkey out of sync after charm upgrade. Renewing CSR.") + # this will call `self.private_key` which will generate a new privkey. + self._generate_csr(renew=True) + + def _migrate_vault(self): + peer_backend = _RelationVaultBackend(self.charm, relation_name="peers") + + if self._check_juju_supports_secrets(): + # we are on recent juju + if self.vault.retrieve(): + # we already were on recent juju: nothing to migrate + logger.debug( + "Private key is already stored as a juju secret. Skipping private key migration." + ) + return + + # we used to be on old juju: our secret stuff is in peer data + if contents := peer_backend.retrieve(): + logger.debug( + "Private key found in relation data. " + "Migrating private key to a juju secret." + ) + # move over to secret-backed storage + self.vault.store(contents) + + # clear the peer storage + peer_backend.clear() + return + + # if we are downgrading, i.e. from juju with secrets to juju without, + # we have lost all that was in the secrets backend. + + @property + def enabled(self) -> bool: + """Boolean indicating whether the charm has a tls_certificates relation. + + See also the `available` property. + """ + # We need to check for units as a temporary workaround because of https://bugs.launchpad.net/juju/+bug/2024583 + # This could in theory not work correctly on scale down to 0 but it is necessary for the moment. + + if not self.relation: + return False + + if not self.relation.units: # pyright: ignore + return False + + if not self.relation.app: # pyright: ignore + return False + + if not self.relation.data: # pyright: ignore + return False + + return True + + @property + def available(self) -> bool: + """Return True if all certs are available in relation data; False otherwise.""" + return ( + self.enabled + and self.server_cert is not None + and self.private_key is not None + and self.ca_cert is not None + ) + + def _on_certificates_relation_joined(self, _) -> None: + # this will only generate a csr if we don't have one already + self._generate_csr() + + def _on_config_changed(self, _): + # this will only generate a csr if we don't have one already + self._generate_csr() + + @property + def relation(self): + """The "certificates" relation.""" + return self.charm.model.get_relation(self.certificates_relation_name) + + def _generate_csr( + self, overwrite: bool = False, renew: bool = False, clear_cert: bool = False + ): + """Request a CSR "creation" if renew is False, otherwise request a renewal. + + Without overwrite=True, the CSR would be created only once, even if calling the method + multiple times. This is useful needed because the order of peer-created and + certificates-joined is not predictable. + + This method intentionally does not emit any events, leave it for caller's responsibility. + """ + # if we are in a relation-broken hook, we might not have a relation to publish the csr to. + if not self.relation: + logger.warning( + f"No {self.certificates_relation_name!r} relation found. " f"Cannot generate csr." + ) + return + + # In case we already have a csr, do not overwrite it by default. + if overwrite or renew or not self._csr: + private_key = self.private_key + csr = generate_csr( + private_key=private_key.encode(), + subject=self.cert_subject, + sans_dns=self.sans_dns, + sans_ip=self.sans_ip, + ) + + if renew and self._csr: + self.certificates.request_certificate_renewal( + old_certificate_signing_request=self._csr.encode(), + new_certificate_signing_request=csr, + ) + else: + logger.info( + "Creating CSR for %s with DNS %s and IPs %s", + self.cert_subject, + self.sans_dns, + self.sans_ip, + ) + self.certificates.request_certificate_creation(certificate_signing_request=csr) + + if clear_cert: + self.vault.clear() + + def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: + """Emit cert-changed.""" + self.on.cert_changed.emit() # pyright: ignore + + @property + def private_key(self) -> str: + """Private key. + + BEWARE: if the vault misbehaves, the backing secret is removed, the peer relation dies + or whatever, we might be calling generate_private_key() again and cause a desync + with the CSR because it's going to be signed with an outdated key we have no way of retrieving. + The caller needs to ensure that if the vault backend gets reset, then so does the csr. + + TODO: we could consider adding a way to verify if the csr was signed by our privkey, + and do that on collect_unit_status as a consistency check + """ + private_key = self.vault.get_value("private-key") + if private_key is None: + private_key = generate_private_key().decode() + self.vault.store({"private-key": private_key}) + return private_key + + @property + def _csr(self) -> Optional[str]: + csrs = self.certificates.get_requirer_csrs() + if not csrs: + return None + + # in principle we only ever need one cert. + # we might want to complicate this a bit once we get into cert rotations: during the rotation, we may need to + # keep the old one around for a little while. If there's multiple certs, at the moment we're + # ignoring all but the last one. + if len(csrs) > 1: + logger.warning( + f"Multiple CSRs found in {self.certificates_relation_name!r} relation. " + "cert_handler is not ready to expect it." + ) + + return csrs[-1].csr + + def get_cert(self) -> Optional[ProviderCertificate]: + """Get the certificate from relation data.""" + all_certs = self.certificates.get_provider_certificates() + # search for the cert matching our csr. + matching_cert = [c for c in all_certs if c.csr == self._csr] + return matching_cert[0] if matching_cert else None + + @property + def ca_cert(self) -> Optional[str]: + """CA Certificate.""" + cert = self.get_cert() + return cert.ca if cert else None + + @property + def server_cert(self) -> Optional[str]: + """Server Certificate.""" + cert = self.get_cert() + return cert.certificate if cert else None + + @property + def chain(self) -> Optional[str]: + """Return the ca chain bundled as a single PEM string.""" + cert = self.get_cert() + return cert.chain_as_pem() if cert else None + + def _on_certificate_expiring( + self, event: Union[CertificateExpiringEvent, CertificateInvalidatedEvent] + ) -> None: + """Generate a new CSR and request certificate renewal.""" + if event.certificate == self.server_cert: + self._generate_csr(renew=True) + # FIXME why are we not emitting cert_changed here? + + def _certificate_revoked(self, event) -> None: + """Remove the certificate and generate a new CSR.""" + # Note: assuming "limit: 1" in metadata + if event.certificate == self.server_cert: + self._generate_csr(overwrite=True, clear_cert=True) + self.on.cert_changed.emit() # pyright: ignore + + def _on_certificate_invalidated(self, event: CertificateInvalidatedEvent) -> None: + """Deal with certificate revocation and expiration.""" + if event.certificate == self.server_cert: + # if event.reason in ("revoked", "expired"): + # Currently, the reason does not matter to us because the action is the same. + self._generate_csr(overwrite=True, clear_cert=True) + self.on.cert_changed.emit() # pyright: ignore + + def _on_all_certificates_invalidated(self, _: AllCertificatesInvalidatedEvent) -> None: + """Clear all secrets data when removing the relation.""" + # Note: assuming "limit: 1" in metadata + # The "certificates_relation_broken" event is converted to "all invalidated" custom + # event by the tls-certificates library. Per convention, we let the lib manage the + # relation and we do not observe "certificates_relation_broken" directly. + self.vault.clear() + # We do not generate a CSR here because the relation is gone. + self.on.cert_changed.emit() # pyright: ignore + + def _check_juju_supports_secrets(self) -> bool: + version = JujuVersion.from_environ() + if not JujuVersion(version=str(version)).has_secrets: + msg = f"Juju version {version} does not supports Secrets. Juju >= 3.0.3 is needed" + logger.debug(msg) + return False + return True diff --git a/charms/istio-pilot/lib/charms/tls_certificates_interface/v2/tls_certificates.py b/charms/istio-pilot/lib/charms/tls_certificates_interface/v3/tls_certificates.py similarity index 63% rename from charms/istio-pilot/lib/charms/tls_certificates_interface/v2/tls_certificates.py rename to charms/istio-pilot/lib/charms/tls_certificates_interface/v3/tls_certificates.py index f4a08366..aa4704c7 100644 --- a/charms/istio-pilot/lib/charms/tls_certificates_interface/v2/tls_certificates.py +++ b/charms/istio-pilot/lib/charms/tls_certificates_interface/v3/tls_certificates.py @@ -1,4 +1,4 @@ -# Copyright 2021 Canonical Ltd. +# Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. @@ -7,16 +7,19 @@ This library contains the Requires and Provides classes for handling the tls-certificates interface. +Pre-requisites: + - Juju >= 3.0 + ## Getting Started From a charm directory, fetch the library using `charmcraft`: ```shell -charmcraft fetch-lib charms.tls_certificates_interface.v2.tls_certificates +charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates ``` Add the following libraries to the charm's `requirements.txt` file: - jsonschema -- cryptography +- cryptography >= 42.0.0 Add the following section to the charm's `charmcraft.yaml` file: ```yaml @@ -36,10 +39,10 @@ Example: ```python -from charms.tls_certificates_interface.v2.tls_certificates import ( +from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateCreationRequestEvent, CertificateRevocationRequestEvent, - TLSCertificatesProvidesV2, + TLSCertificatesProvidesV3, generate_private_key, ) from ops.charm import CharmBase, InstallEvent @@ -59,7 +62,7 @@ class ExampleProviderCharm(CharmBase): def __init__(self, *args): super().__init__(*args) - self.certificates = TLSCertificatesProvidesV2(self, "certificates") + self.certificates = TLSCertificatesProvidesV3(self, "certificates") self.framework.observe( self.certificates.on.certificate_request, self._on_certificate_request @@ -108,6 +111,7 @@ def _on_certificate_request(self, event: CertificateCreationRequestEvent) -> Non ca=ca_certificate, chain=[ca_certificate, certificate], relation_id=event.relation_id, + recommended_expiry_notification_time=720, ) def _on_certificate_revocation_request(self, event: CertificateRevocationRequestEvent) -> None: @@ -126,15 +130,15 @@ def _on_certificate_revocation_request(self, event: CertificateRevocationRequest Example: ```python -from charms.tls_certificates_interface.v2.tls_certificates import ( +from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateAvailableEvent, CertificateExpiringEvent, CertificateRevokedEvent, - TLSCertificatesRequiresV2, + TLSCertificatesRequiresV3, generate_csr, generate_private_key, ) -from ops.charm import CharmBase, RelationJoinedEvent +from ops.charm import CharmBase, RelationCreatedEvent from ops.main import main from ops.model import ActiveStatus, WaitingStatus from typing import Union @@ -145,10 +149,10 @@ class ExampleRequirerCharm(CharmBase): def __init__(self, *args): super().__init__(*args) self.cert_subject = "whatever" - self.certificates = TLSCertificatesRequiresV2(self, "certificates") + self.certificates = TLSCertificatesRequiresV3(self, "certificates") self.framework.observe(self.on.install, self._on_install) self.framework.observe( - self.on.certificates_relation_joined, self._on_certificates_relation_joined + self.on.certificates_relation_created, self._on_certificates_relation_created ) self.framework.observe( self.certificates.on.certificate_available, self._on_certificate_available @@ -176,7 +180,7 @@ def _on_install(self, event) -> None: {"private_key_password": "banana", "private_key": private_key.decode()} ) - def _on_certificates_relation_joined(self, event: RelationJoinedEvent) -> None: + def _on_certificates_relation_created(self, event: RelationCreatedEvent) -> None: replicas_relation = self.model.get_relation("replicas") if not replicas_relation: self.unit.status = WaitingStatus("Waiting for peer relation to be created") @@ -273,48 +277,53 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven """ # noqa: D405, D410, D411, D214, D416 import copy +import ipaddress import json import logging import uuid from contextlib import suppress -from datetime import datetime, timedelta -from ipaddress import IPv4Address -from typing import Any, Dict, List, Literal, Optional, Union +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import List, Literal, Optional, Union from cryptography import x509 from cryptography.hazmat._oid import ExtensionOID from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.serialization import pkcs12 -from cryptography.x509.extensions import Extension, ExtensionNotFound -from jsonschema import exceptions, validate # type: ignore[import] +from jsonschema import exceptions, validate from ops.charm import ( CharmBase, CharmEvents, RelationBrokenEvent, RelationChangedEvent, SecretExpiredEvent, - UpdateStatusEvent, ) from ops.framework import EventBase, EventSource, Handle, Object from ops.jujuversion import JujuVersion -from ops.model import Relation, SecretNotFoundError +from ops.model import ( + Application, + ModelError, + Relation, + RelationDataContent, + SecretNotFoundError, + Unit, +) # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" # Increment this major API version when introducing breaking changes -LIBAPI = 2 +LIBAPI = 3 # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 16 +LIBPATCH = 17 PYDEPS = ["cryptography", "jsonschema"] REQUIRER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v2/schemas/requirer.json", # noqa: E501 + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/requirer.json", "type": "object", "title": "`tls_certificates` requirer root schema", "description": "The `tls_certificates` root schema comprises the entire requirer databag for this interface.", # noqa: E501 @@ -335,7 +344,10 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven "type": "array", "items": { "type": "object", - "properties": {"certificate_signing_request": {"type": "string"}}, + "properties": { + "certificate_signing_request": {"type": "string"}, + "ca": {"type": "boolean"}, + }, "required": ["certificate_signing_request"], }, } @@ -346,7 +358,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven PROVIDER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v2/schemas/provider.json", # noqa: E501 + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/provider.json", "type": "object", "title": "`tls_certificates` provider root schema", "description": "The `tls_certificates` root schema comprises the entire provider databag for this interface.", # noqa: E501 @@ -420,6 +432,58 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven logger = logging.getLogger(__name__) +@dataclass +class RequirerCSR: + """This class represents a certificate signing request from an interface Requirer.""" + + relation_id: int + application_name: str + unit_name: str + csr: str + is_ca: bool + + +@dataclass +class ProviderCertificate: + """This class represents a certificate from an interface Provider.""" + + relation_id: int + application_name: str + csr: str + certificate: str + ca: str + chain: List[str] + revoked: bool + expiry_time: datetime + expiry_notification_time: Optional[datetime] = None + + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join(reversed(self.chain)) + + def to_json(self) -> str: + """Return the object as a JSON string. + + Returns: + str: JSON representation of the object + """ + return json.dumps( + { + "relation_id": self.relation_id, + "application_name": self.application_name, + "csr": self.csr, + "certificate": self.certificate, + "ca": self.ca, + "chain": self.chain, + "revoked": self.revoked, + "expiry_time": self.expiry_time.isoformat(), + "expiry_notification_time": self.expiry_notification_time.isoformat() + if self.expiry_notification_time + else None, + } + ) + + class CertificateAvailableEvent(EventBase): """Charm Event triggered when a TLS certificate is available.""" @@ -438,7 +502,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate": self.certificate, "certificate_signing_request": self.certificate_signing_request, @@ -447,12 +511,16 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] self.chain = snapshot["chain"] + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join(reversed(self.chain)) + class CertificateExpiringEvent(EventBase): """Charm Event triggered when a TLS certificate is almost expired.""" @@ -471,11 +539,11 @@ def __init__(self, handle, certificate: str, expiry: str): self.expiry = expiry def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return {"certificate": self.certificate, "expiry": self.expiry} def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.expiry = snapshot["expiry"] @@ -500,7 +568,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "reason": self.reason, "certificate_signing_request": self.certificate_signing_request, @@ -510,7 +578,7 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.reason = snapshot["reason"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.certificate = snapshot["certificate"] @@ -525,33 +593,42 @@ def __init__(self, handle: Handle): super().__init__(handle) def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return {} def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" pass class CertificateCreationRequestEvent(EventBase): """Charm Event triggered when a TLS certificate is required.""" - def __init__(self, handle: Handle, certificate_signing_request: str, relation_id: int): + def __init__( + self, + handle: Handle, + certificate_signing_request: str, + relation_id: int, + is_ca: bool = False, + ): super().__init__(handle) self.certificate_signing_request = certificate_signing_request self.relation_id = relation_id + self.is_ca = is_ca def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate_signing_request": self.certificate_signing_request, "relation_id": self.relation_id, + "is_ca": self.is_ca, } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate_signing_request = snapshot["certificate_signing_request"] self.relation_id = snapshot["relation_id"] + self.is_ca = snapshot["is_ca"] class CertificateRevocationRequestEvent(EventBase): @@ -572,7 +649,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate": self.certificate, "certificate_signing_request": self.certificate_signing_request, @@ -581,33 +658,100 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] self.chain = snapshot["chain"] -def _load_relation_data(raw_relation_data: dict) -> dict: - """Loads relation data from the relation data bag. +def _load_relation_data(relation_data_content: RelationDataContent) -> dict: + """Load relation data from the relation data bag. Json loads all data. Args: - raw_relation_data: Relation data from the databag + relation_data_content: Relation data from the databag Returns: dict: Relation data in dict format. """ - certificate_data = dict() - for key in raw_relation_data: - try: - certificate_data[key] = json.loads(raw_relation_data[key]) - except (json.decoder.JSONDecodeError, TypeError): - certificate_data[key] = raw_relation_data[key] + certificate_data = {} + try: + for key in relation_data_content: + try: + certificate_data[key] = json.loads(relation_data_content[key]) + except (json.decoder.JSONDecodeError, TypeError): + certificate_data[key] = relation_data_content[key] + except ModelError: + pass return certificate_data +def _get_closest_future_time( + expiry_notification_time: datetime, expiry_time: datetime +) -> datetime: + """Return expiry_notification_time if not in the past, otherwise return expiry_time. + + Args: + expiry_notification_time (datetime): Notification time of impending expiration + expiry_time (datetime): Expiration time + + Returns: + datetime: expiry_notification_time if not in the past, expiry_time otherwise + """ + return ( + expiry_notification_time + if datetime.now(timezone.utc) < expiry_notification_time + else expiry_time + ) + + +def calculate_expiry_notification_time( + validity_start_time: datetime, + expiry_time: datetime, + provider_recommended_notification_time: Optional[int], + requirer_recommended_notification_time: Optional[int], +) -> datetime: + """Calculate a reasonable time to notify the user about the certificate expiry. + + It takes into account the time recommended by the provider and by the requirer. + Time recommended by the provider is preferred, + then time recommended by the requirer, + then dynamically calculated time. + + Args: + validity_start_time: Certificate validity time + expiry_time: Certificate expiry time + provider_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the provider. + requirer_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the requirer. + + Returns: + datetime: Time to notify the user about the certificate expiry. + """ + if provider_recommended_notification_time is not None: + provider_recommended_notification_time = abs(provider_recommended_notification_time) + provider_recommendation_time_delta = ( + expiry_time - timedelta(hours=provider_recommended_notification_time) + ) + if validity_start_time < provider_recommendation_time_delta: + return provider_recommendation_time_delta + + if requirer_recommended_notification_time is not None: + requirer_recommended_notification_time = abs(requirer_recommended_notification_time) + requirer_recommendation_time_delta = ( + expiry_time - timedelta(hours=requirer_recommended_notification_time) + ) + if validity_start_time < requirer_recommendation_time_delta: + return requirer_recommendation_time_delta + calculated_hours = (expiry_time - validity_start_time).total_seconds() / (3600 * 3) + return expiry_time - timedelta(hours=calculated_hours) + + def generate_ca( private_key: bytes, subject: str, @@ -615,11 +759,11 @@ def generate_ca( validity: int = 365, country: str = "US", ) -> bytes: - """Generates a CA Certificate. + """Generate a CA Certificate. Args: private_key (bytes): Private key - subject (str): Certificate subject + subject (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN). private_key_password (bytes): Private key password validity (int): Certificate validity time (in days) country (str): Certificate Issuing country @@ -630,7 +774,7 @@ def generate_ca( private_key_object = serialization.load_pem_private_key( private_key, password=private_key_password ) - subject = issuer = x509.Name( + subject_name = x509.Name( [ x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country), x509.NameAttribute(x509.NameOID.COMMON_NAME, subject), @@ -653,12 +797,12 @@ def generate_ca( ) cert = ( x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) + .subject_name(subject_name) + .issuer_name(subject_name) .public_key(private_key_object.public_key()) # type: ignore[arg-type] .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.utcnow()) - .not_valid_after(datetime.utcnow() + timedelta(days=validity)) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) .add_extension(x509.SubjectKeyIdentifier(digest=subject_identifier), critical=False) .add_extension( x509.AuthorityKeyIdentifier( @@ -678,6 +822,105 @@ def generate_ca( return cert.public_bytes(serialization.Encoding.PEM) +def get_certificate_extensions( + authority_key_identifier: bytes, + csr: x509.CertificateSigningRequest, + alt_names: Optional[List[str]], + is_ca: bool, +) -> List[x509.Extension]: + """Generate a list of certificate extensions from a CSR and other known information. + + Args: + authority_key_identifier (bytes): Authority key identifier + csr (x509.CertificateSigningRequest): CSR + alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + List[x509.Extension]: List of extensions + """ + cert_extensions_list: List[x509.Extension] = [ + x509.Extension( + oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, + value=x509.AuthorityKeyIdentifier( + key_identifier=authority_key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, + value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.BASIC_CONSTRAINTS, + critical=True, + value=x509.BasicConstraints(ca=is_ca, path_length=None), + ), + ] + + sans: List[x509.GeneralName] = [] + san_alt_names = [x509.DNSName(name) for name in alt_names] if alt_names else [] + sans.extend(san_alt_names) + try: + loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) + sans.extend( + [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] + ) + sans.extend( + [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] + ) + sans.extend( + [ + x509.RegisteredID(oid) + for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) + ] + ) + except x509.ExtensionNotFound: + pass + + if sans: + cert_extensions_list.append( + x509.Extension( + oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + critical=False, + value=x509.SubjectAlternativeName(sans), + ) + ) + + if is_ca: + cert_extensions_list.append( + x509.Extension( + ExtensionOID.KEY_USAGE, + critical=True, + value=x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + ) + ) + + existing_oids = {ext.oid for ext in cert_extensions_list} + for extension in csr.extensions: + if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + continue + if extension.oid in existing_oids: + logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) + continue + cert_extensions_list.append(extension) + + return cert_extensions_list + + def generate_certificate( csr: bytes, ca: bytes, @@ -685,8 +928,9 @@ def generate_certificate( ca_key_password: Optional[bytes] = None, validity: int = 365, alt_names: Optional[List[str]] = None, + is_ca: bool = False, ) -> bytes: - """Generates a TLS certificate based on a CSR. + """Generate a TLS certificate based on a CSR. Args: csr (bytes): CSR @@ -695,6 +939,7 @@ def generate_certificate( ca_key_password: CA private key password validity (int): Certificate validity (in days) alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR + is_ca (bool): Whether the certificate is a CA certificate Returns: bytes: Certificate @@ -711,96 +956,36 @@ def generate_certificate( .issuer_name(issuer) .public_key(csr_object.public_key()) .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.utcnow()) - .not_valid_after(datetime.utcnow() + timedelta(days=validity)) - .add_extension( - x509.AuthorityKeyIdentifier( - key_identifier=ca_pem.extensions.get_extension_for_class( - x509.SubjectKeyIdentifier - ).value.key_identifier, - authority_cert_issuer=None, - authority_cert_serial_number=None, - ), - critical=False, - ) - .add_extension( - x509.SubjectKeyIdentifier.from_public_key(csr_object.public_key()), critical=False - ) - .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=False) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) ) - - extensions_list = csr_object.extensions - san_ext: Optional[x509.Extension] = None - if alt_names: - full_sans_dns = alt_names.copy() + extensions = get_certificate_extensions( + authority_key_identifier=ca_pem.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier + ).value.key_identifier, + csr=csr_object, + alt_names=alt_names, + is_ca=is_ca, + ) + for extension in extensions: try: - loaded_san_ext = csr_object.extensions.get_extension_for_class( - x509.SubjectAlternativeName - ) - full_sans_dns.extend(loaded_san_ext.value.get_values_for_type(x509.DNSName)) - except ExtensionNotFound: - pass - finally: - san_ext = Extension( - ExtensionOID.SUBJECT_ALTERNATIVE_NAME, - False, - x509.SubjectAlternativeName([x509.DNSName(name) for name in full_sans_dns]), + certificate_builder = certificate_builder.add_extension( + extval=extension.value, + critical=extension.critical, ) - if not extensions_list: - extensions_list = x509.Extensions([san_ext]) + except ValueError as e: + logger.warning("Failed to add extension %s: %s", extension.oid, e) - for extension in extensions_list: - if extension.value.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME and san_ext: - extension = san_ext - - certificate_builder = certificate_builder.add_extension( - extension.value, - critical=extension.critical, - ) - - certificate_builder._version = x509.Version.v3 cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] return cert.public_bytes(serialization.Encoding.PEM) -def generate_pfx_package( - certificate: bytes, - private_key: bytes, - package_password: str, - private_key_password: Optional[bytes] = None, -) -> bytes: - """Generates a PFX package to contain the TLS certificate and private key. - - Args: - certificate (bytes): TLS certificate - private_key (bytes): Private key - package_password (str): Password to open the PFX package - private_key_password (bytes): Private key password - - Returns: - bytes: - """ - private_key_object = serialization.load_pem_private_key( - private_key, password=private_key_password - ) - certificate_object = x509.load_pem_x509_certificate(certificate) - name = certificate_object.subject.rfc4514_string() - pfx_bytes = pkcs12.serialize_key_and_certificates( - name=name.encode(), - cert=certificate_object, - key=private_key_object, # type: ignore[arg-type] - cas=None, - encryption_algorithm=serialization.BestAvailableEncryption(package_password.encode()), - ) - return pfx_bytes - - def generate_private_key( password: Optional[bytes] = None, key_size: int = 2048, public_exponent: int = 65537, ) -> bytes: - """Generates a private key. + """Generate a private key. Args: password (bytes): Password for decrypting the private key @@ -817,20 +1002,24 @@ def generate_private_key( key_bytes = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.BestAvailableEncryption(password) - if password - else serialization.NoEncryption(), + encryption_algorithm=( + serialization.BestAvailableEncryption(password) + if password + else serialization.NoEncryption() + ), ) return key_bytes -def generate_csr( +def generate_csr( # noqa: C901 private_key: bytes, subject: str, add_unique_id_to_subject_name: bool = True, organization: Optional[str] = None, email_address: Optional[str] = None, country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, private_key_password: Optional[bytes] = None, sans: Optional[List[str]] = None, sans_oid: Optional[List[str]] = None, @@ -838,17 +1027,19 @@ def generate_csr( sans_dns: Optional[List[str]] = None, additional_critical_extensions: Optional[List] = None, ) -> bytes: - """Generates a CSR using private key and subject. + """Generate a CSR using private key and subject. Args: private_key (bytes): Private key - subject (str): CSR Subject. + subject (str): CSR Common Name that can be an IP or a Full Qualified Domain Name (FQDN). add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's subject name. Always leave to "True" when the CSR is used to request certificates using the tls-certificates relation. organization (str): Name of organization. email_address (str): Email address. country_name (str): Country Name. + state_or_province_name (str): State or Province Name. + locality_name (str): Locality Name. private_key_password (bytes): Private key password sans (list): Use sans_dns - this will be deprecated in a future release List of DNS subject alternative names (keeping it for now for backward compatibility) @@ -874,13 +1065,19 @@ def generate_csr( subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) if country_name: subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) + if state_or_province_name: + subject_name.append( + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) + ) + if locality_name: + subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) _sans: List[x509.GeneralName] = [] if sans_oid: _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) if sans_ip: - _sans.extend([x509.IPAddress(IPv4Address(san)) for san in sans_ip]) + _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) if sans: _sans.extend([x509.DNSName(san) for san in sans]) if sans_dns: @@ -896,6 +1093,57 @@ def generate_csr( return signed_certificate.public_bytes(serialization.Encoding.PEM) +def get_sha256_hex(data: str) -> str: + """Calculate the hash of the provided data and return the hexadecimal representation.""" + digest = hashes.Hash(hashes.SHA256()) + digest.update(data.encode()) + return digest.finalize().hex() + + +def csr_matches_certificate(csr: str, cert: str) -> bool: + """Check if a CSR matches a certificate. + + Args: + csr (str): Certificate Signing Request as a string + cert (str): Certificate as a string + Returns: + bool: True/False depending on whether the CSR matches the certificate. + """ + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): + return False + return True + + +def _relation_data_is_valid( + relation: Relation, app_or_unit: Union[Application, Unit], json_schema: dict +) -> bool: + """Check whether relation data is valid based on json schema. + + Args: + relation (Relation): Relation object + app_or_unit (Union[Application, Unit]): Application or unit object + json_schema (dict): Json schema + + Returns: + bool: Whether relation data is valid. + """ + relation_data = _load_relation_data(relation.data[app_or_unit]) + try: + validate(instance=relation_data, schema=json_schema) + return True + except exceptions.ValidationError: + return False + + class CertificatesProviderCharmEvents(CharmEvents): """List of events that the TLS Certificates provider charm can leverage.""" @@ -912,10 +1160,10 @@ class CertificatesRequirerCharmEvents(CharmEvents): all_certificates_invalidated = EventSource(AllCertificatesInvalidatedEvent) -class TLSCertificatesProvidesV2(Object): +class TLSCertificatesProvidesV3(Object): """TLS certificates provider class to be instantiated by TLS certificates providers.""" - on = CertificatesProviderCharmEvents() + on = CertificatesProviderCharmEvents() # type: ignore[reportAssignmentType] def __init__(self, charm: CharmBase, relationship_name: str): super().__init__(charm, relationship_name) @@ -926,12 +1174,12 @@ def __init__(self, charm: CharmBase, relationship_name: str): self.relationship_name = relationship_name def _load_app_relation_data(self, relation: Relation) -> dict: - """Loads relation data from the application relation data bag. + """Load relation data from the application relation data bag. Json loads all data. Args: - relation_object: Relation data from the application databag + relation: Relation data from the application databag Returns: dict: Relation data in dict format. @@ -948,8 +1196,9 @@ def _add_certificate( certificate_signing_request: str, ca: str, chain: List[str], + recommended_expiry_notification_time: Optional[int] = None, ) -> None: - """Adds certificate to relation data. + """Add certificate to relation data. Args: relation_id (int): Relation id @@ -957,6 +1206,8 @@ def _add_certificate( certificate_signing_request (str): Certificate Signing Request ca (str): CA Certificate chain (list): CA Chain + recommended_expiry_notification_time (int): + Time in hours before the certificate expires to notify the user. Returns: None @@ -974,6 +1225,7 @@ def _add_certificate( "certificate_signing_request": certificate_signing_request, "ca": ca, "chain": chain, + "recommended_expiry_notification_time": recommended_expiry_notification_time, } provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) @@ -990,7 +1242,7 @@ def _remove_certificate( certificate: Optional[str] = None, certificate_signing_request: Optional[str] = None, ) -> None: - """Removes certificate from a given relation based on user provided certificate or csr. + """Remove certificate from a given relation based on user provided certificate or csr. Args: relation_id (int): Relation id @@ -1021,24 +1273,8 @@ def _remove_certificate( certificates.remove(certificate_dict) relation.data[self.model.app]["certificates"] = json.dumps(certificates) - @staticmethod - def _relation_data_is_valid(certificates_data: dict) -> bool: - """Uses JSON schema validator to validate relation data content. - - Args: - certificates_data (dict): Certificate data dictionary as retrieved from relation data. - - Returns: - bool: True/False depending on whether the relation data follows the json schema. - """ - try: - validate(instance=certificates_data, schema=REQUIRER_JSON_SCHEMA) - return True - except exceptions.ValidationError: - return False - def revoke_all_certificates(self) -> None: - """Revokes all certificates of this provider. + """Revoke all certificates of this provider. This method is meant to be used when the Root CA has changed. """ @@ -1056,8 +1292,9 @@ def set_relation_certificate( ca: str, chain: List[str], relation_id: int, + recommended_expiry_notification_time: Optional[int] = None, ) -> None: - """Adds certificates to relation data. + """Add certificates to relation data. Args: certificate (str): Certificate @@ -1065,6 +1302,8 @@ def set_relation_certificate( ca (str): CA Certificate chain (list): CA Chain relation_id (int): Juju relation ID + recommended_expiry_notification_time (int): + Recommended time in hours before the certificate expires to notify the user. Returns: None @@ -1086,10 +1325,11 @@ def set_relation_certificate( certificate_signing_request=certificate_signing_request.strip(), ca=ca.strip(), chain=[cert.strip() for cert in chain], + recommended_expiry_notification_time=recommended_expiry_notification_time, ) def remove_certificate(self, certificate: str) -> None: - """Removes a given certificate from relation data. + """Remove a given certificate from relation data. Args: certificate (str): TLS Certificate @@ -1105,16 +1345,24 @@ def remove_certificate(self, certificate: str) -> None: def get_issued_certificates( self, relation_id: Optional[int] = None - ) -> Dict[str, List[Dict[str, str]]]: - """Returns a dictionary of issued certificates. + ) -> List[ProviderCertificate]: + """Return a List of issued (non revoked) certificates. - It returns certificates from all relations if relation_id is not specified. - Certificates are returned per application name and CSR. + Returns: + List: List of ProviderCertificate objects + """ + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + return [certificate for certificate in provider_certificates if not certificate.revoked] + + def get_provider_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued certificates. Returns: - dict: Certificates per application name. + List: List of ProviderCertificate objects """ - certificates: Dict[str, List[Dict[str, str]]] = {} + certificates: List[ProviderCertificate] = [] relations = ( [ relation @@ -1125,23 +1373,37 @@ def get_issued_certificates( else self.model.relations.get(self.relationship_name, []) ) for relation in relations: + if not relation.app: + logger.warning("Relation %s does not have an application", relation.id) + continue provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) - - certificates[relation.app.name] = [] # type: ignore[union-attr] for certificate in provider_certificates: - if not certificate.get("revoked", False): - certificates[relation.app.name].append( # type: ignore[union-attr] - { - "csr": certificate["certificate_signing_request"], - "certificate": certificate["certificate"], - } + try: + certificate_object = x509.load_pem_x509_certificate( + data=certificate["certificate"].encode() ) - + except ValueError as e: + logger.error("Could not load certificate - Skipping: %s", e) + continue + provider_certificate = ProviderCertificate( + relation_id=relation.id, + application_name=relation.app.name, + csr=certificate["certificate_signing_request"], + certificate=certificate["certificate"], + ca=certificate["ca"], + chain=certificate["chain"], + revoked=certificate.get("revoked", False), + expiry_time=certificate_object.not_valid_after_utc, + expiry_notification_time=certificate.get( + "recommended_expiry_notification_time" + ), + ) + certificates.append(provider_certificate) return certificates def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handler triggered on relation changed event. + """Handle relation changed event. Looks at the relation data and either emits: - certificate request event: If the unit relation data contains a CSR for which @@ -1160,107 +1422,77 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: return if not self.model.unit.is_leader(): return - requirer_relation_data = _load_relation_data(event.relation.data[event.unit]) - provider_relation_data = self._load_app_relation_data(event.relation) - if not self._relation_data_is_valid(requirer_relation_data): + if not _relation_data_is_valid(event.relation, event.unit, REQUIRER_JSON_SCHEMA): logger.debug("Relation data did not pass JSON Schema validation") return - provider_certificates = provider_relation_data.get("certificates", []) - requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) + provider_certificates = self.get_provider_certificates(relation_id=event.relation.id) + requirer_csrs = self.get_requirer_csrs(relation_id=event.relation.id) provider_csrs = [ - certificate_creation_request["certificate_signing_request"] + certificate_creation_request.csr for certificate_creation_request in provider_certificates ] - requirer_unit_csrs = [ - certificate_creation_request["certificate_signing_request"] - for certificate_creation_request in requirer_csrs - ] - for certificate_signing_request in requirer_unit_csrs: - if certificate_signing_request not in provider_csrs: + for certificate_request in requirer_csrs: + if certificate_request.csr not in provider_csrs: self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_signing_request, - relation_id=event.relation.id, + certificate_signing_request=certificate_request.csr, + relation_id=certificate_request.relation_id, + is_ca=certificate_request.is_ca, ) self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None: - """Revokes certificates for which no unit has a CSR. + """Revoke certificates for which no unit has a CSR. - Goes through all generated certificates and compare against the list of CSRs for all units - of a given relationship. - - Args: - relation_id (int): Relation id + Goes through all generated certificates and compare against the list of CSRs for all units. Returns: None """ - certificates_relation = self.model.get_relation( - relation_name=self.relationship_name, relation_id=relation_id - ) - if not certificates_relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") - provider_relation_data = self._load_app_relation_data(certificates_relation) - list_of_csrs: List[str] = [] - for unit in certificates_relation.units: - requirer_relation_data = _load_relation_data(certificates_relation.data[unit]) - requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) - list_of_csrs.extend(csr["certificate_signing_request"] for csr in requirer_csrs) - provider_certificates = provider_relation_data.get("certificates", []) + provider_certificates = self.get_provider_certificates(relation_id) + requirer_csrs = self.get_requirer_csrs(relation_id) + list_of_csrs = [csr.csr for csr in requirer_csrs] for certificate in provider_certificates: - if certificate["certificate_signing_request"] not in list_of_csrs: + if certificate.csr not in list_of_csrs: self.on.certificate_revocation_request.emit( - certificate=certificate["certificate"], - certificate_signing_request=certificate["certificate_signing_request"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate=certificate.certificate, + certificate_signing_request=certificate.csr, + ca=certificate.ca, + chain=certificate.chain, ) - self.remove_certificate(certificate=certificate["certificate"]) + self.remove_certificate(certificate=certificate.certificate) - def get_requirer_csrs_with_no_certs( + def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None - ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Filters the requirer's units csrs. - - Keeps the ones for which no certificate was provided. + ) -> List[RequirerCSR]: + """Return CSR's for which no certificate has been issued. Args: relation_id (int): Relation id Returns: - list: List of dictionaries that contain the unit's csrs - that don't have a certificate issued. + list: List of RequirerCSR objects. """ - all_unit_csr_mappings = copy.deepcopy(self.get_requirer_csrs(relation_id=relation_id)) - filtered_all_unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] - for unit_csr_mapping in all_unit_csr_mappings: - csrs_without_certs = [] - for csr in unit_csr_mapping["unit_csrs"]: # type: ignore[union-attr] - if not self.certificate_issued_for_csr( - app_name=unit_csr_mapping["application_name"], # type: ignore[arg-type] - csr=csr["certificate_signing_request"], # type: ignore[index] - ): - csrs_without_certs.append(csr) - if csrs_without_certs: - unit_csr_mapping["unit_csrs"] = csrs_without_certs # type: ignore[assignment] - filtered_all_unit_csr_mappings.append(unit_csr_mapping) - return filtered_all_unit_csr_mappings - - def get_requirer_csrs( - self, relation_id: Optional[int] = None - ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Returns a list of requirers' CSRs grouped by unit. + requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) + outstanding_csrs: List[RequirerCSR] = [] + for relation_csr in requirer_csrs: + if not self.certificate_issued_for_csr( + app_name=relation_csr.application_name, + csr=relation_csr.csr, + relation_id=relation_id, + ): + outstanding_csrs.append(relation_csr) + return outstanding_csrs + + def get_requirer_csrs(self, relation_id: Optional[int] = None) -> List[RequirerCSR]: + """Return a list of requirers' CSRs. It returns CSRs from all relations if relation_id is not specified. CSRs are returned per relation id, application name and unit name. Returns: - list: List of dictionaries that contain the unit's csrs - with the following information - relation_id, application_name and unit_name. + list: List[RequirerCSR] """ - unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] - + relation_csrs: List[RequirerCSR] = [] relations = ( [ relation @@ -1275,53 +1507,69 @@ def get_requirer_csrs( for unit in relation.units: requirer_relation_data = _load_relation_data(relation.data[unit]) unit_csrs_list = requirer_relation_data.get("certificate_signing_requests", []) - unit_csr_mappings.append( - { - "relation_id": relation.id, - "application_name": relation.app.name, # type: ignore[union-attr] - "unit_name": unit.name, - "unit_csrs": unit_csrs_list, - } - ) - return unit_csr_mappings + for unit_csr in unit_csrs_list: + csr = unit_csr.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + ca = unit_csr.get("ca", False) + if not relation.app: + logger.warning("No remote app in relation - Skipping") + continue + relation_csr = RequirerCSR( + relation_id=relation.id, + application_name=relation.app.name, + unit_name=unit.name, + csr=csr, + is_ca=ca, + ) + relation_csrs.append(relation_csr) + return relation_csrs - def certificate_issued_for_csr(self, app_name: str, csr: str) -> bool: - """Checks whether a certificate has been issued for a given CSR. + def certificate_issued_for_csr( + self, app_name: str, csr: str, relation_id: Optional[int] + ) -> bool: + """Check whether a certificate has been issued for a given CSR. Args: app_name (str): Application name that the CSR belongs to. csr (str): Certificate Signing Request. + relation_id (Optional[int]): Relation ID Returns: bool: True/False depending on whether a certificate has been issued for the given CSR. """ - issued_certificates_per_csr = self.get_issued_certificates()[app_name] - for issued_pair in issued_certificates_per_csr: - if "csr" in issued_pair and issued_pair["csr"] == csr: - return csr_matches_certificate(csr, issued_pair["certificate"]) + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) + for issued_certificate in issued_certificates_per_csr: + if issued_certificate.csr == csr and issued_certificate.application_name == app_name: + return csr_matches_certificate(csr, issued_certificate.certificate) return False -class TLSCertificatesRequiresV2(Object): +class TLSCertificatesRequiresV3(Object): """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" - on = CertificatesRequirerCharmEvents() + on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] def __init__( self, charm: CharmBase, relationship_name: str, - expiry_notification_time: int = 168, + expiry_notification_time: Optional[int] = None, ): - """Generates/use private key and observes relation changed event. + """Generate/use private key and observes relation changed event. Args: charm: Charm object relationship_name: Juju relation name - expiry_notification_time (int): Time difference between now and expiry (in hours). - Used to trigger the CertificateExpiring event. Default: 7 days. + expiry_notification_time (int): Number of hours prior to certificate expiry. + Used to trigger the CertificateExpiring event. + This value is used as a recommendation only, + The actual value is calculated taking into account the provider's recommendation. """ super().__init__(charm, relationship_name) + if not JujuVersion.from_environ().has_secrets: + logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") self.relationship_name = relationship_name self.charm = charm self.expiry_notification_time = expiry_notification_time @@ -1331,23 +1579,39 @@ def __init__( self.framework.observe( charm.on[relationship_name].relation_broken, self._on_relation_broken ) - if JujuVersion.from_environ().has_secrets: - self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - else: - self.framework.observe(charm.on.update_status, self._on_update_status) + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - @property - def _requirer_csrs(self) -> List[Dict[str, str]]: - """Returns list of requirer's CSRs from relation data.""" + def get_requirer_csrs(self) -> List[RequirerCSR]: + """Return list of requirer's CSRs from relation unit data. + + Returns: + list: List of RequirerCSR objects. + """ relation = self.model.get_relation(self.relationship_name) if not relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") + return [] + requirer_csrs = [] requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - return requirer_relation_data.get("certificate_signing_requests", []) + requirer_csrs_dict = requirer_relation_data.get("certificate_signing_requests", []) + for requirer_csr_dict in requirer_csrs_dict: + csr = requirer_csr_dict.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + ca = requirer_csr_dict.get("ca", False) + relation_csr = RequirerCSR( + relation_id=relation.id, + application_name=self.model.app.name, + unit_name=self.model.unit.name, + csr=csr, + is_ca=ca, + ) + requirer_csrs.append(relation_csr) + return requirer_csrs - @property - def _provider_certificates(self) -> List[Dict[str, str]]: - """Returns list of certificates from the provider's relation data.""" + def get_provider_certificates(self) -> List[ProviderCertificate]: + """Return list of certificates from the provider's relation data.""" + provider_certificates: List[ProviderCertificate] = [] relation = self.model.get_relation(self.relationship_name) if not relation: logger.debug("No relation: %s", self.relationship_name) @@ -1356,16 +1620,55 @@ def _provider_certificates(self) -> List[Dict[str, str]]: logger.debug("No remote app in relation: %s", self.relationship_name) return [] provider_relation_data = _load_relation_data(relation.data[relation.app]) - if not self._relation_data_is_valid(provider_relation_data): - logger.warning("Provider relation data did not pass JSON Schema validation") - return [] - return provider_relation_data.get("certificates", []) + provider_certificate_dicts = provider_relation_data.get("certificates", []) + for provider_certificate_dict in provider_certificate_dicts: + certificate = provider_certificate_dict.get("certificate") + if not certificate: + logger.warning("No certificate found in relation data - Skipping") + continue + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + except ValueError as e: + logger.error("Could not load certificate - Skipping: %s", e) + continue + ca = provider_certificate_dict.get("ca") + chain = provider_certificate_dict.get("chain", []) + csr = provider_certificate_dict.get("certificate_signing_request") + recommended_expiry_notification_time = provider_certificate_dict.get( + "recommended_expiry_notification_time" + ) + expiry_time = certificate_object.not_valid_after_utc + validity_start_time = certificate_object.not_valid_before_utc + expiry_notification_time = calculate_expiry_notification_time( + validity_start_time=validity_start_time, + expiry_time=expiry_time, + provider_recommended_notification_time=recommended_expiry_notification_time, + requirer_recommended_notification_time=self.expiry_notification_time, + ) + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + revoked = provider_certificate_dict.get("revoked", False) + provider_certificate = ProviderCertificate( + relation_id=relation.id, + application_name=relation.app.name, + csr=csr, + certificate=certificate, + ca=ca, + chain=chain, + revoked=revoked, + expiry_time=expiry_time, + expiry_notification_time=expiry_notification_time, + ) + provider_certificates.append(provider_certificate) + return provider_certificates - def _add_requirer_csr(self, csr: str) -> None: - """Adds CSR to relation data. + def _add_requirer_csr_to_relation_data(self, csr: str, is_ca: bool) -> None: + """Add CSR to relation data. Args: csr (str): Certificate Signing Request + is_ca (bool): Whether the certificate is a CA certificate Returns: None @@ -1376,16 +1679,24 @@ def _add_requirer_csr(self, csr: str) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - new_csr_dict = {"certificate_signing_request": csr} - if new_csr_dict in self._requirer_csrs: - logger.info("CSR already in relation data - Doing nothing") - return - requirer_csrs = copy.deepcopy(self._requirer_csrs) - requirer_csrs.append(new_csr_dict) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) + for requirer_csr in self.get_requirer_csrs(): + if requirer_csr.csr == csr and requirer_csr.is_ca == is_ca: + logger.info("CSR already in relation data - Doing nothing") + return + new_csr_dict = { + "certificate_signing_request": csr, + "ca": is_ca, + } + requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) + existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) + new_relation_data = copy.deepcopy(existing_relation_data) + new_relation_data.append(new_csr_dict) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( + new_relation_data + ) - def _remove_requirer_csr(self, csr: str) -> None: - """Removes CSR from relation data. + def _remove_requirer_csr_from_relation_data(self, csr: str) -> None: + """Remove CSR from relation data. Args: csr (str): Certificate signing request @@ -1399,19 +1710,27 @@ def _remove_requirer_csr(self, csr: str) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - requirer_csrs = copy.deepcopy(self._requirer_csrs) - csr_dict = {"certificate_signing_request": csr} - if csr_dict not in requirer_csrs: - logger.info("CSR not in relation data - Doing nothing") + if not self.get_requirer_csrs(): + logger.info("No CSRs in relation data - Doing nothing") return - requirer_csrs.remove(csr_dict) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) + requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) + existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) + new_relation_data = copy.deepcopy(existing_relation_data) + for requirer_csr in new_relation_data: + if requirer_csr["certificate_signing_request"] == csr: + new_relation_data.remove(requirer_csr) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( + new_relation_data + ) - def request_certificate_creation(self, certificate_signing_request: bytes) -> None: + def request_certificate_creation( + self, certificate_signing_request: bytes, is_ca: bool = False + ) -> None: """Request TLS certificate to provider charm. Args: certificate_signing_request (bytes): Certificate Signing Request + is_ca (bool): Whether the certificate is a CA certificate Returns: None @@ -1422,11 +1741,13 @@ def request_certificate_creation(self, certificate_signing_request: bytes) -> No f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - self._add_requirer_csr(certificate_signing_request.decode().strip()) + self._add_requirer_csr_to_relation_data( + certificate_signing_request.decode().strip(), is_ca=is_ca + ) logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: - """Removes CSR from relation data. + """Remove CSR from relation data. The provider of this relation is then expected to remove certificates associated to this CSR from the relation data as well and emit a request_certificate_revocation event for the @@ -1438,13 +1759,13 @@ def request_certificate_revocation(self, certificate_signing_request: bytes) -> Returns: None """ - self._remove_requirer_csr(certificate_signing_request.decode().strip()) + self._remove_requirer_csr_from_relation_data(certificate_signing_request.decode().strip()) logger.info("Certificate revocation sent to provider") def request_certificate_renewal( self, old_certificate_signing_request: bytes, new_certificate_signing_request: bytes ) -> None: - """Renews certificate. + """Renew certificate. Removes old CSR from relation data and adds new one. @@ -1466,33 +1787,69 @@ def request_certificate_renewal( ) logger.info("Certificate renewal request completed.") - @staticmethod - def _relation_data_is_valid(certificates_data: dict) -> bool: - """Checks whether relation data is valid based on json schema. + def get_assigned_certificates(self) -> List[ProviderCertificate]: + """Get a list of certificates that were assigned to this unit. + + Returns: + List: List[ProviderCertificate] + """ + assigned_certificates = [] + for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): + if cert := self._find_certificate_in_relation_data(requirer_csr.csr): + assigned_certificates.append(cert) + return assigned_certificates + + def get_expiring_certificates(self) -> List[ProviderCertificate]: + """Get a list of certificates that were assigned to this unit that are expiring or expired. + + Returns: + List: List[ProviderCertificate] + """ + expiring_certificates: List[ProviderCertificate] = [] + for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): + if cert := self._find_certificate_in_relation_data(requirer_csr.csr): + if not cert.expiry_time or not cert.expiry_notification_time: + continue + if datetime.now(timezone.utc) > cert.expiry_notification_time: + expiring_certificates.append(cert) + return expiring_certificates + + def get_certificate_signing_requests( + self, + fulfilled_only: bool = False, + unfulfilled_only: bool = False, + ) -> List[RequirerCSR]: + """Get the list of CSR's that were sent to the provider. + + You can choose to get only the CSR's that have a certificate assigned or only the CSR's + that don't. Args: - certificates_data: Certificate data in dict format. + fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. + unfulfilled_only (bool): This option will discard CSRs that have certificates signed. Returns: - bool: Whether relation data is valid. + List of RequirerCSR objects. """ - try: - validate(instance=certificates_data, schema=PROVIDER_JSON_SCHEMA) - return True - except exceptions.ValidationError: - return False + csrs = [] + for requirer_csr in self.get_requirer_csrs(): + cert = self._find_certificate_in_relation_data(requirer_csr.csr) + if (unfulfilled_only and cert) or (fulfilled_only and not cert): + continue + csrs.append(requirer_csr) + + return csrs def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handler triggered on relation changed events. + """Handle relation changed event. Goes through all providers certificates that match a requested CSR. If the provider certificate is revoked, emit a CertificateInvalidateEvent, otherwise emit a CertificateAvailableEvent. - When Juju secrets are available, remove the secret for revoked certificate, - or add a secret with the correct expiry time for new certificates. - + Remove the secret for revoked certificate, or add a secret with the correct expiry + time for new certificates. Args: event: Juju event @@ -1500,54 +1857,65 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: Returns: None """ + if not event.app: + logger.warning("No remote app in relation - Skipping") + return + if not _relation_data_is_valid(event.relation, event.app, PROVIDER_JSON_SCHEMA): + logger.debug("Relation data did not pass JSON Schema validation") + return + provider_certificates = self.get_provider_certificates() requirer_csrs = [ - certificate_creation_request["certificate_signing_request"] - for certificate_creation_request in self._requirer_csrs + certificate_creation_request.csr + for certificate_creation_request in self.get_requirer_csrs() ] - for certificate in self._provider_certificates: - if certificate["certificate_signing_request"] in requirer_csrs: - if certificate.get("revoked", False): - if JujuVersion.from_environ().has_secrets: - with suppress(SecretNotFoundError): - secret = self.model.get_secret( - label=f"{LIBID}-{certificate['certificate_signing_request']}" - ) - secret.remove_all_revisions() + for certificate in provider_certificates: + if certificate.csr in requirer_csrs: + csr_in_sha256_hex = get_sha256_hex(certificate.csr) + if certificate.revoked: + with suppress(SecretNotFoundError): + logger.debug( + "Removing secret with label %s", + f"{LIBID}-{csr_in_sha256_hex}", + ) + secret = self.model.get_secret( + label=f"{LIBID}-{csr_in_sha256_hex}") + secret.remove_all_revisions() self.on.certificate_invalidated.emit( reason="revoked", - certificate=certificate["certificate"], - certificate_signing_request=certificate["certificate_signing_request"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate=certificate.certificate, + certificate_signing_request=certificate.csr, + ca=certificate.ca, + chain=certificate.chain, ) else: - if JujuVersion.from_environ().has_secrets: - try: - secret = self.model.get_secret( - label=f"{LIBID}-{certificate['certificate_signing_request']}" - ) - secret.set_content({"certificate": certificate["certificate"]}) - secret.set_info( - expire=self._get_next_secret_expiry_time( - certificate["certificate"] - ), - ) - except SecretNotFoundError: - secret = self.charm.unit.add_secret( - {"certificate": certificate["certificate"]}, - label=f"{LIBID}-{certificate['certificate_signing_request']}", - expire=self._get_next_secret_expiry_time( - certificate["certificate"] - ), - ) + try: + logger.debug( + "Setting secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" + ) + secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") + secret.set_content( + {"certificate": certificate.certificate, "csr": certificate.csr} + ) + secret.set_info( + expire=self._get_next_secret_expiry_time(certificate), + ) + except SecretNotFoundError: + logger.debug( + "Creating new secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" + ) + secret = self.charm.unit.add_secret( + {"certificate": certificate.certificate, "csr": certificate.csr}, + label=f"{LIBID}-{csr_in_sha256_hex}", + expire=self._get_next_secret_expiry_time(certificate), + ) self.on.certificate_available.emit( - certificate_signing_request=certificate["certificate_signing_request"], - certificate=certificate["certificate"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate_signing_request=certificate.csr, + certificate=certificate.certificate, + ca=certificate.ca, + chain=certificate.chain, ) - def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: + def _get_next_secret_expiry_time(self, certificate: ProviderCertificate) -> Optional[datetime]: """Return the expiry time or expiry notification time. Extracts the expiry time from the provided certificate, calculates the @@ -1555,20 +1923,21 @@ def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: the future. Args: - certificate: x509 certificate + certificate: ProviderCertificate object Returns: Optional[datetime]: None if the certificate expiry time cannot be read, next expiry time otherwise. """ - expiry_time = _get_certificate_expiry_time(certificate) - if not expiry_time: + if not certificate.expiry_time or not certificate.expiry_notification_time: return None - expiry_notification_time = expiry_time - timedelta(hours=self.expiry_notification_time) - return _get_closest_future_time(expiry_notification_time, expiry_time) + return _get_closest_future_time( + certificate.expiry_notification_time, + certificate.expiry_time, + ) def _on_relation_broken(self, event: RelationBrokenEvent) -> None: - """Handler triggered on relation broken event. + """Handle Relation Broken Event. Emitting `all_certificates_invalidated` from `relation-broken` rather than `relation-departed` since certs are stored in app data. @@ -1582,7 +1951,7 @@ def _on_relation_broken(self, event: RelationBrokenEvent) -> None: self.on.all_certificates_invalidated.emit() def _on_secret_expired(self, event: SecretExpiredEvent) -> None: - """Triggered when a certificate is set to expire. + """Handle Secret Expired Event. Loads the certificate from the secret, and will emit 1 of 2 events. @@ -1599,145 +1968,43 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: """ if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): return - csr = event.secret.label[len(f"{LIBID}-") :] - certificate_dict = self._find_certificate_in_relation_data(csr) - if not certificate_dict: + csr = event.secret.get_content()["csr"] + provider_certificate = self._find_certificate_in_relation_data(csr) + if not provider_certificate: # A secret expired but we did not find matching certificate. Cleaning up event.secret.remove_all_revisions() return - expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) - if not expiry_time: + if not provider_certificate.expiry_time: # A secret expired but matching certificate is invalid. Cleaning up event.secret.remove_all_revisions() return - if datetime.utcnow() < expiry_time: + if datetime.now(timezone.utc) < provider_certificate.expiry_time: logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( - certificate=certificate_dict["certificate"], - expiry=expiry_time.isoformat(), + certificate=provider_certificate.certificate, + expiry=provider_certificate.expiry_time.isoformat(), ) event.secret.set_info( - expire=_get_certificate_expiry_time(certificate_dict["certificate"]), + expire=provider_certificate.expiry_time, ) else: logger.warning("Certificate is expired") self.on.certificate_invalidated.emit( reason="expired", - certificate=certificate_dict["certificate"], - certificate_signing_request=certificate_dict["certificate_signing_request"], - ca=certificate_dict["ca"], - chain=certificate_dict["chain"], + certificate=provider_certificate.certificate, + certificate_signing_request=provider_certificate.csr, + ca=provider_certificate.ca, + chain=provider_certificate.chain, ) - self.request_certificate_revocation(certificate_dict["certificate"].encode()) + self.request_certificate_revocation(provider_certificate.certificate.encode()) event.secret.remove_all_revisions() - def _find_certificate_in_relation_data(self, csr: str) -> Optional[Dict[str, Any]]: - """Returns the certificate that match the given CSR.""" - for certificate_dict in self._provider_certificates: - if certificate_dict["certificate_signing_request"] != csr: + def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCertificate]: + """Return the certificate that match the given CSR.""" + for provider_certificate in self.get_provider_certificates(): + if provider_certificate.csr != csr: continue - return certificate_dict - return None - - def _on_update_status(self, event: UpdateStatusEvent) -> None: - """Triggered on update status event. - - Goes through each certificate in the "certificates" relation and checks their expiry date. - If they are close to expire (<7 days), emits a CertificateExpiringEvent event and if - they are expired, emits a CertificateExpiredEvent. - - Args: - event (UpdateStatusEvent): Juju event - - Returns: - None - """ - for certificate_dict in self._provider_certificates: - expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) - if not expiry_time: - continue - time_difference = expiry_time - datetime.utcnow() - if time_difference.total_seconds() < 0: - logger.warning("Certificate is expired") - self.on.certificate_invalidated.emit( - reason="expired", - certificate=certificate_dict["certificate"], - certificate_signing_request=certificate_dict["certificate_signing_request"], - ca=certificate_dict["ca"], - chain=certificate_dict["chain"], - ) - self.request_certificate_revocation(certificate_dict["certificate"].encode()) - continue - if time_difference.total_seconds() < (self.expiry_notification_time * 60 * 60): - logger.warning("Certificate almost expired") - self.on.certificate_expiring.emit( - certificate=certificate_dict["certificate"], - expiry=expiry_time.isoformat(), - ) - - -def csr_matches_certificate(csr: str, cert: str) -> bool: - """Check if a CSR matches a certificate. - - expects to get the original string representations. - - Args: - csr (str): Certificate Signing Request - cert (str): Certificate - Returns: - bool: True/False depending on whether the CSR matches the certificate. - """ - try: - csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) - - if csr_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) != cert_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ): - return False - if csr_object.subject != cert_object.subject: - return False - except ValueError: - logger.warning("Could not load certificate or CSR.") - return False - return True - - -def _get_closest_future_time( - expiry_notification_time: datetime, expiry_time: datetime -) -> datetime: - """Return expiry_notification_time if not in the past, otherwise return expiry_time. - - Args: - expiry_notification_time (datetime): Notification time of impending expiration - expiry_time (datetime): Expiration time - - Returns: - datetime: expiry_notification_time if not in the past, expiry_time otherwise - """ - return ( - expiry_notification_time if datetime.utcnow() < expiry_notification_time else expiry_time - ) - - -def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: - """Extract expiry time from a certificate string. - - Args: - certificate (str): x509 certificate as a string - - Returns: - Optional[datetime]: Expiry datetime or None - """ - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - return certificate_object.not_valid_after - except ValueError: - logger.warning("Could not load certificate.") + return provider_certificate return None diff --git a/charms/istio-pilot/src/charm.py b/charms/istio-pilot/src/charm.py index a018720c..3f83460b 100755 --- a/charms/istio-pilot/src/charm.py +++ b/charms/istio-pilot/src/charm.py @@ -17,7 +17,7 @@ DEFAULT_RELATION_NAME as GATEWAY_INFO_RELATION_NAME, ) from charms.istio_pilot.v0.istio_gateway_info import GatewayProvider -from charms.observability_libs.v0.cert_handler import CertHandler +from charms.observability_libs.v1.cert_handler import CertHandler from charms.prometheus_k8s.v0.prometheus_scrape import MetricsEndpointProvider from lightkube import Client from lightkube.core.exceptions import ApiError @@ -112,7 +112,6 @@ def __init__(self, *args): self._cert_handler = CertHandler( self, key="istio-cert", - peer_relation_name=self.peer_relation_name, cert_subject=self._cert_subject, ) @@ -862,8 +861,12 @@ def _tls_info(self) -> Dict[str, str]: ).decode("utf-8"), } return { - "tls-crt": base64.b64encode(self._cert_handler.cert.encode("ascii")).decode("utf-8"), - "tls-key": base64.b64encode(self._cert_handler.key.encode("ascii")).decode("utf-8"), + "tls-crt": base64.b64encode(self._cert_handler.server_cert.encode("ascii")).decode( + "utf-8" + ), + "tls-key": base64.b64encode(self._cert_handler.private_key.encode("ascii")).decode( + "utf-8" + ), } # ---- Start of the block @@ -964,16 +967,16 @@ def _use_https_with_tls_provider(self) -> bool: # If the certificates relation is established, we can assume # that we want to configure TLS - if _xor(self._cert_handler.cert, self._cert_handler.key): + if _xor(self._cert_handler.server_cert, self._cert_handler.private_key): # Fail if tls is only partly configured as this is probably a mistake missing = "pkey" - if not self._cert_handler.cert: + if not self._cert_handler.server_cert: missing = "CA cert" raise ErrorWithStatus( f"Missing {missing}, cannot configure TLS", BlockedStatus, ) - if self._cert_handler.cert and self._cert_handler.key: + if self._cert_handler.server_cert and self._cert_handler.private_key: return True def _log_and_set_status(self, status): diff --git a/charms/istio-pilot/tests/unit/test_charm.py b/charms/istio-pilot/tests/unit/test_charm.py index 00161b11..78e09f48 100644 --- a/charms/istio-pilot/tests/unit/test_charm.py +++ b/charms/istio-pilot/tests/unit/test_charm.py @@ -140,15 +140,21 @@ class TestCharmEvents: their handling, etc). """ - def test_event_observing(self, harness, mocker, mocked_cert_subject): - harness.begin() + @patch("charm.Istioctl") + def test_event_observing(self, mocked_istioctl, harness, mocker, mocked_cert_subject): + harness.begin_with_initial_hooks() mocked_install = mocker.patch("charm.Operator.install") mocked_remove = mocker.patch("charm.Operator.remove") mocked_upgrade_charm = mocker.patch("charm.Operator.upgrade_charm") mocked_reconcile = mocker.patch("charm.Operator.reconcile") - RelationCreatedEvent harness.charm.on.install.emit() + mocked_istioctl.assert_called_once_with( + "./istioctl", + harness.charm.model.name, + "minimal", + istioctl_extra_flags=harness.charm._istioctl_extra_flags, + ) assert_called_once_and_reset(mocked_install) harness.charm.on.remove.emit() @@ -582,8 +588,8 @@ def test_gateway_port( harness.charm._cert_handler = MagicMock() harness.charm._cert_handler.enabled = cert_handler_enabled - harness.charm._cert_handler.cert = tls_cert - harness.charm._cert_handler.key = tls_key + harness.charm._cert_handler.server_cert = tls_cert + harness.charm._cert_handler.private_key = tls_key with expected_context: gateway_port = harness.charm._gateway_port @@ -883,8 +889,8 @@ def test_reconcile_gateway_with_tls( harness.begin() harness.charm._cert_handler = MagicMock() harness.charm._cert_handler.enabled = True - harness.charm._cert_handler.cert = "some-cert" - harness.charm._cert_handler.key = "some-key" + harness.charm._cert_handler.server_cert = "some-cert" + harness.charm._cert_handler.private_key = "some-key" # Act harness.charm._reconcile_gateway() @@ -1297,8 +1303,8 @@ def test_use_https_with_tls_provider( harness.begin() harness.charm._cert_handler = MagicMock() harness.charm._cert_handler.enabled = cert_handler_enabled - harness.charm._cert_handler.cert = tls_cert - harness.charm._cert_handler.key = tls_key + harness.charm._cert_handler.server_cert = tls_cert + harness.charm._cert_handler.private_key = tls_key with expected_context: assert harness.charm._use_https_with_tls_provider() == expected_return @@ -1512,8 +1518,8 @@ def test_tls_info_cert_provider(self, harness, mocked_lightkube_client): """Test the method returns a populated dictionary with TLS information.""" harness.begin() harness.charm._cert_handler = MagicMock() - harness.charm._cert_handler.cert = "cert-value" - harness.charm._cert_handler.key = "key-value" + harness.charm._cert_handler.server_cert = "cert-value" + harness.charm._cert_handler.private_key = "key-value" tls_crt_encoded = base64.b64encode("cert-value".encode("ascii")).decode("utf-8") tls_key_encoded = base64.b64encode("key-value".encode("ascii")).decode("utf-8") diff --git a/tests/test_bundle.py b/tests/test_bundle.py index b5a013c9..58cc5c86 100644 --- a/tests/test_bundle.py +++ b/tests/test_bundle.py @@ -264,11 +264,13 @@ async def test_enable_ingress_auth(ops_test: OpsTest): trust=OIDC_GATEKEEPER_TRUST, ) - await ops_test.model.add_relation(f"{ISTIO_PILOT}:ingress", f"{DEX_AUTH}:ingress") - await ops_test.model.add_relation(f"{ISTIO_PILOT}:ingress", f"{OIDC_GATEKEEPER}:ingress") - await ops_test.model.add_relation(f"{OIDC_GATEKEEPER}:oidc-client", f"{DEX_AUTH}:oidc-client") - await ops_test.model.add_relation(f"{OIDC_GATEKEEPER}:dex-oidc-config", f"{DEX_AUTH}:dex-oidc-config") - await ops_test.model.add_relation( + await ops_test.model.integrate(f"{ISTIO_PILOT}:ingress", f"{DEX_AUTH}:ingress") + await ops_test.model.integrate(f"{ISTIO_PILOT}:ingress", f"{OIDC_GATEKEEPER}:ingress") + await ops_test.model.integrate(f"{OIDC_GATEKEEPER}:oidc-client", f"{DEX_AUTH}:oidc-client") + await ops_test.model.integrate( + f"{OIDC_GATEKEEPER}:dex-oidc-config", f"{DEX_AUTH}:dex-oidc-config" + ) + await ops_test.model.integrate( f"{ISTIO_PILOT}:ingress-auth", f"{OIDC_GATEKEEPER}:ingress-auth" ) @@ -276,7 +278,7 @@ async def test_enable_ingress_auth(ops_test: OpsTest): await ops_test.model.wait_for_idle( status="active", raise_on_blocked=False, - timeout=90 * 10, + timeout=60 * 15, ) # Wait for the pods from our secondary workload, just in case. This should be faster than