diff --git a/lib/charms/tls_certificates_interface/v3/tls_certificates.py b/lib/charms/tls_certificates_interface/v3/tls_certificates.py deleted file mode 100644 index aa4704c..0000000 --- a/lib/charms/tls_certificates_interface/v3/tls_certificates.py +++ /dev/null @@ -1,2010 +0,0 @@ -# Copyright 2024 Canonical Ltd. -# See LICENSE file for licensing details. - - -"""Library for the tls-certificates relation. - -This library contains the Requires and Provides classes for handling the tls-certificates -interface. - -Pre-requisites: - - Juju >= 3.0 - -## Getting Started -From a charm directory, fetch the library using `charmcraft`: - -```shell -charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates -``` - -Add the following libraries to the charm's `requirements.txt` file: -- jsonschema -- cryptography >= 42.0.0 - -Add the following section to the charm's `charmcraft.yaml` file: -```yaml -parts: - charm: - build-packages: - - libffi-dev - - libssl-dev - - rustc - - cargo -``` - -### Provider charm -The provider charm is the charm providing certificates to another charm that requires them. In -this example, the provider charm is storing its private key using a peer relation interface called -`replicas`. - -Example: -```python -from charms.tls_certificates_interface.v3.tls_certificates import ( - CertificateCreationRequestEvent, - CertificateRevocationRequestEvent, - TLSCertificatesProvidesV3, - generate_private_key, -) -from ops.charm import CharmBase, InstallEvent -from ops.main import main -from ops.model import ActiveStatus, WaitingStatus - - -def generate_ca(private_key: bytes, subject: str) -> str: - return "whatever ca content" - - -def generate_certificate(ca: str, private_key: str, csr: str) -> str: - return "Whatever certificate" - - -class ExampleProviderCharm(CharmBase): - - def __init__(self, *args): - super().__init__(*args) - self.certificates = TLSCertificatesProvidesV3(self, "certificates") - self.framework.observe( - self.certificates.on.certificate_request, - self._on_certificate_request - ) - self.framework.observe( - self.certificates.on.certificate_revocation_request, - self._on_certificate_revocation_request - ) - self.framework.observe(self.on.install, self._on_install) - - def _on_install(self, event: InstallEvent) -> None: - private_key_password = b"banana" - private_key = generate_private_key(password=private_key_password) - ca_certificate = generate_ca(private_key=private_key, subject="whatever") - replicas_relation = self.model.get_relation("replicas") - if not replicas_relation: - self.unit.status = WaitingStatus("Waiting for peer relation to be created") - event.defer() - return - replicas_relation.data[self.app].update( - { - "private_key_password": "banana", - "private_key": private_key, - "ca_certificate": ca_certificate, - } - ) - self.unit.status = ActiveStatus() - - def _on_certificate_request(self, event: CertificateCreationRequestEvent) -> None: - replicas_relation = self.model.get_relation("replicas") - if not replicas_relation: - self.unit.status = WaitingStatus("Waiting for peer relation to be created") - event.defer() - return - ca_certificate = replicas_relation.data[self.app].get("ca_certificate") - private_key = replicas_relation.data[self.app].get("private_key") - certificate = generate_certificate( - ca=ca_certificate, - private_key=private_key, - csr=event.certificate_signing_request, - ) - - self.certificates.set_relation_certificate( - certificate=certificate, - certificate_signing_request=event.certificate_signing_request, - ca=ca_certificate, - chain=[ca_certificate, certificate], - relation_id=event.relation_id, - recommended_expiry_notification_time=720, - ) - - def _on_certificate_revocation_request(self, event: CertificateRevocationRequestEvent) -> None: - # Do what you want to do with this information - pass - - -if __name__ == "__main__": - main(ExampleProviderCharm) -``` - -### Requirer charm -The requirer charm is the charm requiring certificates from another charm that provides them. In -this example, the requirer charm is storing its certificates using a peer relation interface called -`replicas`. - -Example: -```python -from charms.tls_certificates_interface.v3.tls_certificates import ( - CertificateAvailableEvent, - CertificateExpiringEvent, - CertificateRevokedEvent, - TLSCertificatesRequiresV3, - generate_csr, - generate_private_key, -) -from ops.charm import CharmBase, RelationCreatedEvent -from ops.main import main -from ops.model import ActiveStatus, WaitingStatus -from typing import Union - - -class ExampleRequirerCharm(CharmBase): - - def __init__(self, *args): - super().__init__(*args) - self.cert_subject = "whatever" - self.certificates = TLSCertificatesRequiresV3(self, "certificates") - self.framework.observe(self.on.install, self._on_install) - self.framework.observe( - self.on.certificates_relation_created, self._on_certificates_relation_created - ) - self.framework.observe( - self.certificates.on.certificate_available, self._on_certificate_available - ) - self.framework.observe( - self.certificates.on.certificate_expiring, self._on_certificate_expiring - ) - self.framework.observe( - self.certificates.on.certificate_invalidated, self._on_certificate_invalidated - ) - self.framework.observe( - self.certificates.on.all_certificates_invalidated, - self._on_all_certificates_invalidated - ) - - def _on_install(self, event) -> None: - private_key_password = b"banana" - private_key = generate_private_key(password=private_key_password) - replicas_relation = self.model.get_relation("replicas") - if not replicas_relation: - self.unit.status = WaitingStatus("Waiting for peer relation to be created") - event.defer() - return - replicas_relation.data[self.app].update( - {"private_key_password": "banana", "private_key": private_key.decode()} - ) - - def _on_certificates_relation_created(self, event: RelationCreatedEvent) -> None: - replicas_relation = self.model.get_relation("replicas") - if not replicas_relation: - self.unit.status = WaitingStatus("Waiting for peer relation to be created") - event.defer() - return - private_key_password = replicas_relation.data[self.app].get("private_key_password") - private_key = replicas_relation.data[self.app].get("private_key") - csr = generate_csr( - private_key=private_key.encode(), - private_key_password=private_key_password.encode(), - subject=self.cert_subject, - ) - replicas_relation.data[self.app].update({"csr": csr.decode()}) - self.certificates.request_certificate_creation(certificate_signing_request=csr) - - def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: - replicas_relation = self.model.get_relation("replicas") - if not replicas_relation: - self.unit.status = WaitingStatus("Waiting for peer relation to be created") - event.defer() - return - replicas_relation.data[self.app].update({"certificate": event.certificate}) - replicas_relation.data[self.app].update({"ca": event.ca}) - replicas_relation.data[self.app].update({"chain": event.chain}) - self.unit.status = ActiveStatus() - - def _on_certificate_expiring( - self, event: Union[CertificateExpiringEvent, CertificateInvalidatedEvent] - ) -> None: - replicas_relation = self.model.get_relation("replicas") - if not replicas_relation: - self.unit.status = WaitingStatus("Waiting for peer relation to be created") - event.defer() - return - old_csr = replicas_relation.data[self.app].get("csr") - private_key_password = replicas_relation.data[self.app].get("private_key_password") - private_key = replicas_relation.data[self.app].get("private_key") - new_csr = generate_csr( - private_key=private_key.encode(), - private_key_password=private_key_password.encode(), - subject=self.cert_subject, - ) - self.certificates.request_certificate_renewal( - old_certificate_signing_request=old_csr, - new_certificate_signing_request=new_csr, - ) - replicas_relation.data[self.app].update({"csr": new_csr.decode()}) - - def _certificate_revoked(self) -> None: - old_csr = replicas_relation.data[self.app].get("csr") - private_key_password = replicas_relation.data[self.app].get("private_key_password") - private_key = replicas_relation.data[self.app].get("private_key") - new_csr = generate_csr( - private_key=private_key.encode(), - private_key_password=private_key_password.encode(), - subject=self.cert_subject, - ) - self.certificates.request_certificate_renewal( - old_certificate_signing_request=old_csr, - new_certificate_signing_request=new_csr, - ) - replicas_relation.data[self.app].update({"csr": new_csr.decode()}) - replicas_relation.data[self.app].pop("certificate") - replicas_relation.data[self.app].pop("ca") - replicas_relation.data[self.app].pop("chain") - self.unit.status = WaitingStatus("Waiting for new certificate") - - def _on_certificate_invalidated(self, event: CertificateInvalidatedEvent) -> None: - replicas_relation = self.model.get_relation("replicas") - if not replicas_relation: - self.unit.status = WaitingStatus("Waiting for peer relation to be created") - event.defer() - return - if event.reason == "revoked": - self._certificate_revoked() - if event.reason == "expired": - self._on_certificate_expiring(event) - - def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEvent) -> None: - # Do what you want with this information, probably remove all certificates. - pass - - -if __name__ == "__main__": - main(ExampleRequirerCharm) -``` - -You can relate both charms by running: - -```bash -juju relate -``` - -""" # noqa: D405, D410, D411, D214, D416 - -import copy -import ipaddress -import json -import logging -import uuid -from contextlib import suppress -from dataclasses import dataclass -from datetime import datetime, timedelta, timezone -from typing import List, Literal, Optional, Union - -from cryptography import x509 -from cryptography.hazmat._oid import ExtensionOID -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from jsonschema import exceptions, validate -from ops.charm import ( - CharmBase, - CharmEvents, - RelationBrokenEvent, - RelationChangedEvent, - SecretExpiredEvent, -) -from ops.framework import EventBase, EventSource, Handle, Object -from ops.jujuversion import JujuVersion -from ops.model import ( - Application, - ModelError, - Relation, - RelationDataContent, - SecretNotFoundError, - Unit, -) - -# The unique Charmhub library identifier, never change it -LIBID = "afd8c2bccf834997afce12c2706d2ede" - -# Increment this major API version when introducing breaking changes -LIBAPI = 3 - -# Increment this PATCH version before using `charmcraft publish-lib` or reset -# to 0 if you are raising the major API version -LIBPATCH = 17 - -PYDEPS = ["cryptography", "jsonschema"] - -REQUIRER_JSON_SCHEMA = { - "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/requirer.json", - "type": "object", - "title": "`tls_certificates` requirer root schema", - "description": "The `tls_certificates` root schema comprises the entire requirer databag for this interface.", # noqa: E501 - "examples": [ - { - "certificate_signing_requests": [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----\\nMIICWjCCAUICAQAwFTETMBEGA1UEAwwKYmFuYW5hLmNvbTCCASIwDQYJKoZIhvcN\\nAQEBBQADggEPADCCAQoCggEBANWlx9wE6cW7Jkb4DZZDOZoEjk1eDBMJ+8R4pyKp\\nFBeHMl1SQSDt6rAWsrfL3KOGiIHqrRY0B5H6c51L8LDuVrJG0bPmyQ6rsBo3gVke\\nDSivfSLtGvHtp8lwYnIunF8r858uYmblAR0tdXQNmnQvm+6GERvURQ6sxpgZ7iLC\\npPKDoPt+4GKWL10FWf0i82FgxWC2KqRZUtNbgKETQuARLig7etBmCnh20zmynorA\\ncY7vrpTPAaeQpGLNqqYvKV9W6yWVY08V+nqARrFrjk3vSioZSu8ZJUdZ4d9++SGl\\nbH7A6e77YDkX9i/dQ3Pa/iDtWO3tXS2MvgoxX1iSWlGNOHcCAwEAAaAAMA0GCSqG\\nSIb3DQEBCwUAA4IBAQCW1fKcHessy/ZhnIwAtSLznZeZNH8LTVOzkhVd4HA7EJW+\\nKVLBx8DnN7L3V2/uPJfHiOg4Rx7fi7LkJPegl3SCqJZ0N5bQS/KvDTCyLG+9E8Y+\\n7wqCmWiXaH1devimXZvazilu4IC2dSks2D8DPWHgsOdVks9bme8J3KjdNMQudegc\\newWZZ1Dtbd+Rn7cpKU3jURMwm4fRwGxbJ7iT5fkLlPBlyM/yFEik4SmQxFYrZCQg\\n0f3v4kBefTh5yclPy5tEH+8G0LMsbbo3dJ5mPKpAShi0QEKDLd7eR1R/712lYTK4\\ndi4XaEfqERgy68O4rvb4PGlJeRGS7AmL7Ss8wfAq\\n-----END CERTIFICATE REQUEST-----\\n" # noqa: E501 - }, - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----\\nMIICWjCCAUICAQAwFTETMBEGA1UEAwwKYmFuYW5hLmNvbTCCASIwDQYJKoZIhvcN\\nAQEBBQADggEPADCCAQoCggEBAMk3raaX803cHvzlBF9LC7KORT46z4VjyU5PIaMb\\nQLIDgYKFYI0n5hf2Ra4FAHvOvEmW7bjNlHORFEmvnpcU5kPMNUyKFMTaC8LGmN8z\\nUBH3aK+0+FRvY4afn9tgj5435WqOG9QdoDJ0TJkjJbJI9M70UOgL711oU7ql6HxU\\n4d2ydFK9xAHrBwziNHgNZ72L95s4gLTXf0fAHYf15mDA9U5yc+YDubCKgTXzVySQ\\nUx73VCJLfC/XkZIh559IrnRv5G9fu6BMLEuBwAz6QAO4+/XidbKWN4r2XSq5qX4n\\n6EPQQWP8/nd4myq1kbg6Q8w68L/0YdfjCmbyf2TuoWeImdUCAwEAAaAAMA0GCSqG\\nSIb3DQEBCwUAA4IBAQBIdwraBvpYo/rl5MH1+1Um6HRg4gOdQPY5WcJy9B9tgzJz\\nittRSlRGTnhyIo6fHgq9KHrmUthNe8mMTDailKFeaqkVNVvk7l0d1/B90Kz6OfmD\\nxN0qjW53oP7y3QB5FFBM8DjqjmUnz5UePKoX4AKkDyrKWxMwGX5RoET8c/y0y9jp\\nvSq3Wh5UpaZdWbe1oVY8CqMVUEVQL2DPjtopxXFz2qACwsXkQZxWmjvZnRiP8nP8\\nbdFaEuh9Q6rZ2QdZDEtrU4AodPU3NaukFr5KlTUQt3w/cl+5//zils6G5zUWJ2pN\\ng7+t9PTvXHRkH+LnwaVnmsBFU2e05qADQbfIn7JA\\n-----END CERTIFICATE REQUEST-----\\n" # noqa: E501 - }, - ] - } - ], - "properties": { - "certificate_signing_requests": { - "type": "array", - "items": { - "type": "object", - "properties": { - "certificate_signing_request": {"type": "string"}, - "ca": {"type": "boolean"}, - }, - "required": ["certificate_signing_request"], - }, - } - }, - "required": ["certificate_signing_requests"], - "additionalProperties": True, -} - -PROVIDER_JSON_SCHEMA = { - "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/provider.json", - "type": "object", - "title": "`tls_certificates` provider root schema", - "description": "The `tls_certificates` root schema comprises the entire provider databag for this interface.", # noqa: E501 - "examples": [ - { - "certificates": [ - { - "ca": "-----BEGIN CERTIFICATE-----\\nMIIDJTCCAg2gAwIBAgIUMsSK+4FGCjW6sL/EXMSxColmKw8wDQYJKoZIhvcNAQEL\\nBQAwIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdoYXRldmVyMB4XDTIyMDcyOTIx\\nMTgyN1oXDTIzMDcyOTIxMTgyN1owIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdo\\nYXRldmVyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA55N9DkgFWbJ/\\naqcdQhso7n1kFvt6j/fL1tJBvRubkiFMQJnZFtekfalN6FfRtA3jq+nx8o49e+7t\\nLCKT0xQ+wufXfOnxv6/if6HMhHTiCNPOCeztUgQ2+dfNwRhYYgB1P93wkUVjwudK\\n13qHTTZ6NtEF6EzOqhOCe6zxq6wrr422+ZqCvcggeQ5tW9xSd/8O1vNID/0MTKpy\\nET3drDtBfHmiUEIBR3T3tcy6QsIe4Rz/2sDinAcM3j7sG8uY6drh8jY3PWar9til\\nv2l4qDYSU8Qm5856AB1FVZRLRJkLxZYZNgreShAIYgEd0mcyI2EO/UvKxsIcxsXc\\nd45GhGpKkwIDAQABo1cwVTAfBgNVHQ4EGAQWBBRXBrXKh3p/aFdQjUcT/UcvICBL\\nODAhBgNVHSMEGjAYgBYEFFcGtcqHen9oV1CNRxP9Ry8gIEs4MA8GA1UdEwEB/wQF\\nMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGmCEvcoFUrT9e133SHkgF/ZAgzeIziO\\nBjfAdU4fvAVTVfzaPm0yBnGqzcHyacCzbZjKQpaKVgc5e6IaqAQtf6cZJSCiJGhS\\nJYeosWrj3dahLOUAMrXRr8G/Ybcacoqc+osKaRa2p71cC3V6u2VvcHRV7HDFGJU7\\noijbdB+WhqET6Txe67rxZCJG9Ez3EOejBJBl2PJPpy7m1Ml4RR+E8YHNzB0lcBzc\\nEoiJKlDfKSO14E2CPDonnUoWBJWjEvJys3tbvKzsRj2fnLilytPFU0gH3cEjCopi\\nzFoWRdaRuNHYCqlBmso1JFDl8h4fMmglxGNKnKRar0WeGyxb4xXBGpI=\\n-----END CERTIFICATE-----\\n", # noqa: E501 - "chain": [ - "-----BEGIN CERTIFICATE-----\\nMIIDJTCCAg2gAwIBAgIUMsSK+4FGCjW6sL/EXMSxColmKw8wDQYJKoZIhvcNAQEL\\nBQAwIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdoYXRldmVyMB4XDTIyMDcyOTIx\\nMTgyN1oXDTIzMDcyOTIxMTgyN1owIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdo\\nYXRldmVyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA55N9DkgFWbJ/\\naqcdQhso7n1kFvt6j/fL1tJBvRubkiFMQJnZFtekfalN6FfRtA3jq+nx8o49e+7t\\nLCKT0xQ+wufXfOnxv6/if6HMhHTiCNPOCeztUgQ2+dfNwRhYYgB1P93wkUVjwudK\\n13qHTTZ6NtEF6EzOqhOCe6zxq6wrr422+ZqCvcggeQ5tW9xSd/8O1vNID/0MTKpy\\nET3drDtBfHmiUEIBR3T3tcy6QsIe4Rz/2sDinAcM3j7sG8uY6drh8jY3PWar9til\\nv2l4qDYSU8Qm5856AB1FVZRLRJkLxZYZNgreShAIYgEd0mcyI2EO/UvKxsIcxsXc\\nd45GhGpKkwIDAQABo1cwVTAfBgNVHQ4EGAQWBBRXBrXKh3p/aFdQjUcT/UcvICBL\\nODAhBgNVHSMEGjAYgBYEFFcGtcqHen9oV1CNRxP9Ry8gIEs4MA8GA1UdEwEB/wQF\\nMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGmCEvcoFUrT9e133SHkgF/ZAgzeIziO\\nBjfAdU4fvAVTVfzaPm0yBnGqzcHyacCzbZjKQpaKVgc5e6IaqAQtf6cZJSCiJGhS\\nJYeosWrj3dahLOUAMrXRr8G/Ybcacoqc+osKaRa2p71cC3V6u2VvcHRV7HDFGJU7\\noijbdB+WhqET6Txe67rxZCJG9Ez3EOejBJBl2PJPpy7m1Ml4RR+E8YHNzB0lcBzc\\nEoiJKlDfKSO14E2CPDonnUoWBJWjEvJys3tbvKzsRj2fnLilytPFU0gH3cEjCopi\\nzFoWRdaRuNHYCqlBmso1JFDl8h4fMmglxGNKnKRar0WeGyxb4xXBGpI=\\n-----END CERTIFICATE-----\\n" # noqa: E501, W505 - ], - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----\nMIICWjCCAUICAQAwFTETMBEGA1UEAwwKYmFuYW5hLmNvbTCCASIwDQYJKoZIhvcN\nAQEBBQADggEPADCCAQoCggEBANWlx9wE6cW7Jkb4DZZDOZoEjk1eDBMJ+8R4pyKp\nFBeHMl1SQSDt6rAWsrfL3KOGiIHqrRY0B5H6c51L8LDuVrJG0bPmyQ6rsBo3gVke\nDSivfSLtGvHtp8lwYnIunF8r858uYmblAR0tdXQNmnQvm+6GERvURQ6sxpgZ7iLC\npPKDoPt+4GKWL10FWf0i82FgxWC2KqRZUtNbgKETQuARLig7etBmCnh20zmynorA\ncY7vrpTPAaeQpGLNqqYvKV9W6yWVY08V+nqARrFrjk3vSioZSu8ZJUdZ4d9++SGl\nbH7A6e77YDkX9i/dQ3Pa/iDtWO3tXS2MvgoxX1iSWlGNOHcCAwEAAaAAMA0GCSqG\nSIb3DQEBCwUAA4IBAQCW1fKcHessy/ZhnIwAtSLznZeZNH8LTVOzkhVd4HA7EJW+\nKVLBx8DnN7L3V2/uPJfHiOg4Rx7fi7LkJPegl3SCqJZ0N5bQS/KvDTCyLG+9E8Y+\n7wqCmWiXaH1devimXZvazilu4IC2dSks2D8DPWHgsOdVks9bme8J3KjdNMQudegc\newWZZ1Dtbd+Rn7cpKU3jURMwm4fRwGxbJ7iT5fkLlPBlyM/yFEik4SmQxFYrZCQg\n0f3v4kBefTh5yclPy5tEH+8G0LMsbbo3dJ5mPKpAShi0QEKDLd7eR1R/712lYTK4\ndi4XaEfqERgy68O4rvb4PGlJeRGS7AmL7Ss8wfAq\n-----END CERTIFICATE REQUEST-----\n", # noqa: E501 - "certificate": "-----BEGIN CERTIFICATE-----\nMIICvDCCAaQCFFPAOD7utDTsgFrm0vS4We18OcnKMA0GCSqGSIb3DQEBCwUAMCAx\nCzAJBgNVBAYTAlVTMREwDwYDVQQDDAh3aGF0ZXZlcjAeFw0yMjA3MjkyMTE5Mzha\nFw0yMzA3MjkyMTE5MzhaMBUxEzARBgNVBAMMCmJhbmFuYS5jb20wggEiMA0GCSqG\nSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDVpcfcBOnFuyZG+A2WQzmaBI5NXgwTCfvE\neKciqRQXhzJdUkEg7eqwFrK3y9yjhoiB6q0WNAeR+nOdS/Cw7layRtGz5skOq7Aa\nN4FZHg0or30i7Rrx7afJcGJyLpxfK/OfLmJm5QEdLXV0DZp0L5vuhhEb1EUOrMaY\nGe4iwqTyg6D7fuBili9dBVn9IvNhYMVgtiqkWVLTW4ChE0LgES4oO3rQZgp4dtM5\nsp6KwHGO766UzwGnkKRizaqmLylfVusllWNPFfp6gEaxa45N70oqGUrvGSVHWeHf\nfvkhpWx+wOnu+2A5F/Yv3UNz2v4g7Vjt7V0tjL4KMV9YklpRjTh3AgMBAAEwDQYJ\nKoZIhvcNAQELBQADggEBAChjRzuba8zjQ7NYBVas89Oy7u++MlS8xWxh++yiUsV6\nWMk3ZemsPtXc1YmXorIQohtxLxzUPm2JhyzFzU/sOLmJQ1E/l+gtZHyRCwsb20fX\nmphuJsMVd7qv/GwEk9PBsk2uDqg4/Wix0Rx5lf95juJP7CPXQJl5FQauf3+LSz0y\nwF/j+4GqvrwsWr9hKOLmPdkyKkR6bHKtzzsxL9PM8GnElk2OpaPMMnzbL/vt2IAt\nxK01ZzPxCQCzVwHo5IJO5NR/fIyFbEPhxzG17QsRDOBR9fl9cOIvDeSO04vyZ+nz\n+kA2c3fNrZFAtpIlOOmFh8Q12rVL4sAjI5mVWnNEgvI=\n-----END CERTIFICATE-----\n", # noqa: E501 - } - ] - }, - { - "certificates": [ - { - "ca": "-----BEGIN CERTIFICATE-----\\nMIIDJTCCAg2gAwIBAgIUMsSK+4FGCjW6sL/EXMSxColmKw8wDQYJKoZIhvcNAQEL\\nBQAwIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdoYXRldmVyMB4XDTIyMDcyOTIx\\nMTgyN1oXDTIzMDcyOTIxMTgyN1owIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdo\\nYXRldmVyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA55N9DkgFWbJ/\\naqcdQhso7n1kFvt6j/fL1tJBvRubkiFMQJnZFtekfalN6FfRtA3jq+nx8o49e+7t\\nLCKT0xQ+wufXfOnxv6/if6HMhHTiCNPOCeztUgQ2+dfNwRhYYgB1P93wkUVjwudK\\n13qHTTZ6NtEF6EzOqhOCe6zxq6wrr422+ZqCvcggeQ5tW9xSd/8O1vNID/0MTKpy\\nET3drDtBfHmiUEIBR3T3tcy6QsIe4Rz/2sDinAcM3j7sG8uY6drh8jY3PWar9til\\nv2l4qDYSU8Qm5856AB1FVZRLRJkLxZYZNgreShAIYgEd0mcyI2EO/UvKxsIcxsXc\\nd45GhGpKkwIDAQABo1cwVTAfBgNVHQ4EGAQWBBRXBrXKh3p/aFdQjUcT/UcvICBL\\nODAhBgNVHSMEGjAYgBYEFFcGtcqHen9oV1CNRxP9Ry8gIEs4MA8GA1UdEwEB/wQF\\nMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGmCEvcoFUrT9e133SHkgF/ZAgzeIziO\\nBjfAdU4fvAVTVfzaPm0yBnGqzcHyacCzbZjKQpaKVgc5e6IaqAQtf6cZJSCiJGhS\\nJYeosWrj3dahLOUAMrXRr8G/Ybcacoqc+osKaRa2p71cC3V6u2VvcHRV7HDFGJU7\\noijbdB+WhqET6Txe67rxZCJG9Ez3EOejBJBl2PJPpy7m1Ml4RR+E8YHNzB0lcBzc\\nEoiJKlDfKSO14E2CPDonnUoWBJWjEvJys3tbvKzsRj2fnLilytPFU0gH3cEjCopi\\nzFoWRdaRuNHYCqlBmso1JFDl8h4fMmglxGNKnKRar0WeGyxb4xXBGpI=\\n-----END CERTIFICATE-----\\n", # noqa: E501 - "chain": [ - "-----BEGIN CERTIFICATE-----\\nMIIDJTCCAg2gAwIBAgIUMsSK+4FGCjW6sL/EXMSxColmKw8wDQYJKoZIhvcNAQEL\\nBQAwIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdoYXRldmVyMB4XDTIyMDcyOTIx\\nMTgyN1oXDTIzMDcyOTIxMTgyN1owIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdo\\nYXRldmVyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA55N9DkgFWbJ/\\naqcdQhso7n1kFvt6j/fL1tJBvRubkiFMQJnZFtekfalN6FfRtA3jq+nx8o49e+7t\\nLCKT0xQ+wufXfOnxv6/if6HMhHTiCNPOCeztUgQ2+dfNwRhYYgB1P93wkUVjwudK\\n13qHTTZ6NtEF6EzOqhOCe6zxq6wrr422+ZqCvcggeQ5tW9xSd/8O1vNID/0MTKpy\\nET3drDtBfHmiUEIBR3T3tcy6QsIe4Rz/2sDinAcM3j7sG8uY6drh8jY3PWar9til\\nv2l4qDYSU8Qm5856AB1FVZRLRJkLxZYZNgreShAIYgEd0mcyI2EO/UvKxsIcxsXc\\nd45GhGpKkwIDAQABo1cwVTAfBgNVHQ4EGAQWBBRXBrXKh3p/aFdQjUcT/UcvICBL\\nODAhBgNVHSMEGjAYgBYEFFcGtcqHen9oV1CNRxP9Ry8gIEs4MA8GA1UdEwEB/wQF\\nMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGmCEvcoFUrT9e133SHkgF/ZAgzeIziO\\nBjfAdU4fvAVTVfzaPm0yBnGqzcHyacCzbZjKQpaKVgc5e6IaqAQtf6cZJSCiJGhS\\nJYeosWrj3dahLOUAMrXRr8G/Ybcacoqc+osKaRa2p71cC3V6u2VvcHRV7HDFGJU7\\noijbdB+WhqET6Txe67rxZCJG9Ez3EOejBJBl2PJPpy7m1Ml4RR+E8YHNzB0lcBzc\\nEoiJKlDfKSO14E2CPDonnUoWBJWjEvJys3tbvKzsRj2fnLilytPFU0gH3cEjCopi\\nzFoWRdaRuNHYCqlBmso1JFDl8h4fMmglxGNKnKRar0WeGyxb4xXBGpI=\\n-----END CERTIFICATE-----\\n" # noqa: E501, W505 - ], - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----\nMIICWjCCAUICAQAwFTETMBEGA1UEAwwKYmFuYW5hLmNvbTCCASIwDQYJKoZIhvcN\nAQEBBQADggEPADCCAQoCggEBANWlx9wE6cW7Jkb4DZZDOZoEjk1eDBMJ+8R4pyKp\nFBeHMl1SQSDt6rAWsrfL3KOGiIHqrRY0B5H6c51L8LDuVrJG0bPmyQ6rsBo3gVke\nDSivfSLtGvHtp8lwYnIunF8r858uYmblAR0tdXQNmnQvm+6GERvURQ6sxpgZ7iLC\npPKDoPt+4GKWL10FWf0i82FgxWC2KqRZUtNbgKETQuARLig7etBmCnh20zmynorA\ncY7vrpTPAaeQpGLNqqYvKV9W6yWVY08V+nqARrFrjk3vSioZSu8ZJUdZ4d9++SGl\nbH7A6e77YDkX9i/dQ3Pa/iDtWO3tXS2MvgoxX1iSWlGNOHcCAwEAAaAAMA0GCSqG\nSIb3DQEBCwUAA4IBAQCW1fKcHessy/ZhnIwAtSLznZeZNH8LTVOzkhVd4HA7EJW+\nKVLBx8DnN7L3V2/uPJfHiOg4Rx7fi7LkJPegl3SCqJZ0N5bQS/KvDTCyLG+9E8Y+\n7wqCmWiXaH1devimXZvazilu4IC2dSks2D8DPWHgsOdVks9bme8J3KjdNMQudegc\newWZZ1Dtbd+Rn7cpKU3jURMwm4fRwGxbJ7iT5fkLlPBlyM/yFEik4SmQxFYrZCQg\n0f3v4kBefTh5yclPy5tEH+8G0LMsbbo3dJ5mPKpAShi0QEKDLd7eR1R/712lYTK4\ndi4XaEfqERgy68O4rvb4PGlJeRGS7AmL7Ss8wfAq\n-----END CERTIFICATE REQUEST-----\n", # noqa: E501 - "certificate": "-----BEGIN CERTIFICATE-----\nMIICvDCCAaQCFFPAOD7utDTsgFrm0vS4We18OcnKMA0GCSqGSIb3DQEBCwUAMCAx\nCzAJBgNVBAYTAlVTMREwDwYDVQQDDAh3aGF0ZXZlcjAeFw0yMjA3MjkyMTE5Mzha\nFw0yMzA3MjkyMTE5MzhaMBUxEzARBgNVBAMMCmJhbmFuYS5jb20wggEiMA0GCSqG\nSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDVpcfcBOnFuyZG+A2WQzmaBI5NXgwTCfvE\neKciqRQXhzJdUkEg7eqwFrK3y9yjhoiB6q0WNAeR+nOdS/Cw7layRtGz5skOq7Aa\nN4FZHg0or30i7Rrx7afJcGJyLpxfK/OfLmJm5QEdLXV0DZp0L5vuhhEb1EUOrMaY\nGe4iwqTyg6D7fuBili9dBVn9IvNhYMVgtiqkWVLTW4ChE0LgES4oO3rQZgp4dtM5\nsp6KwHGO766UzwGnkKRizaqmLylfVusllWNPFfp6gEaxa45N70oqGUrvGSVHWeHf\nfvkhpWx+wOnu+2A5F/Yv3UNz2v4g7Vjt7V0tjL4KMV9YklpRjTh3AgMBAAEwDQYJ\nKoZIhvcNAQELBQADggEBAChjRzuba8zjQ7NYBVas89Oy7u++MlS8xWxh++yiUsV6\nWMk3ZemsPtXc1YmXorIQohtxLxzUPm2JhyzFzU/sOLmJQ1E/l+gtZHyRCwsb20fX\nmphuJsMVd7qv/GwEk9PBsk2uDqg4/Wix0Rx5lf95juJP7CPXQJl5FQauf3+LSz0y\nwF/j+4GqvrwsWr9hKOLmPdkyKkR6bHKtzzsxL9PM8GnElk2OpaPMMnzbL/vt2IAt\nxK01ZzPxCQCzVwHo5IJO5NR/fIyFbEPhxzG17QsRDOBR9fl9cOIvDeSO04vyZ+nz\n+kA2c3fNrZFAtpIlOOmFh8Q12rVL4sAjI5mVWnNEgvI=\n-----END CERTIFICATE-----\n", # noqa: E501 - "revoked": True, - } - ] - }, - ], - "properties": { - "certificates": { - "$id": "#/properties/certificates", - "type": "array", - "items": { - "$id": "#/properties/certificates/items", - "type": "object", - "required": ["certificate_signing_request", "certificate", "ca", "chain"], - "properties": { - "certificate_signing_request": { - "$id": "#/properties/certificates/items/certificate_signing_request", - "type": "string", - }, - "certificate": { - "$id": "#/properties/certificates/items/certificate", - "type": "string", - }, - "ca": {"$id": "#/properties/certificates/items/ca", "type": "string"}, - "chain": { - "$id": "#/properties/certificates/items/chain", - "type": "array", - "items": { - "type": "string", - "$id": "#/properties/certificates/items/chain/items", - }, - }, - "revoked": { - "$id": "#/properties/certificates/items/revoked", - "type": "boolean", - }, - }, - "additionalProperties": True, - }, - } - }, - "required": ["certificates"], - "additionalProperties": True, -} - - -logger = logging.getLogger(__name__) - - -@dataclass -class RequirerCSR: - """This class represents a certificate signing request from an interface Requirer.""" - - relation_id: int - application_name: str - unit_name: str - csr: str - is_ca: bool - - -@dataclass -class ProviderCertificate: - """This class represents a certificate from an interface Provider.""" - - relation_id: int - application_name: str - csr: str - certificate: str - ca: str - chain: List[str] - revoked: bool - expiry_time: datetime - expiry_notification_time: Optional[datetime] = None - - def chain_as_pem(self) -> str: - """Return full certificate chain as a PEM string.""" - return "\n\n".join(reversed(self.chain)) - - def to_json(self) -> str: - """Return the object as a JSON string. - - Returns: - str: JSON representation of the object - """ - return json.dumps( - { - "relation_id": self.relation_id, - "application_name": self.application_name, - "csr": self.csr, - "certificate": self.certificate, - "ca": self.ca, - "chain": self.chain, - "revoked": self.revoked, - "expiry_time": self.expiry_time.isoformat(), - "expiry_notification_time": self.expiry_notification_time.isoformat() - if self.expiry_notification_time - else None, - } - ) - - -class CertificateAvailableEvent(EventBase): - """Charm Event triggered when a TLS certificate is available.""" - - def __init__( - self, - handle: Handle, - certificate: str, - certificate_signing_request: str, - ca: str, - chain: List[str], - ): - super().__init__(handle) - self.certificate = certificate - self.certificate_signing_request = certificate_signing_request - self.ca = ca - self.chain = chain - - def snapshot(self) -> dict: - """Return snapshot.""" - return { - "certificate": self.certificate, - "certificate_signing_request": self.certificate_signing_request, - "ca": self.ca, - "chain": self.chain, - } - - def restore(self, snapshot: dict): - """Restore snapshot.""" - self.certificate = snapshot["certificate"] - self.certificate_signing_request = snapshot["certificate_signing_request"] - self.ca = snapshot["ca"] - self.chain = snapshot["chain"] - - def chain_as_pem(self) -> str: - """Return full certificate chain as a PEM string.""" - return "\n\n".join(reversed(self.chain)) - - -class CertificateExpiringEvent(EventBase): - """Charm Event triggered when a TLS certificate is almost expired.""" - - def __init__(self, handle, certificate: str, expiry: str): - """CertificateExpiringEvent. - - Args: - handle (Handle): Juju framework handle - certificate (str): TLS Certificate - expiry (str): Datetime string representing the time at which the certificate - won't be valid anymore. - """ - super().__init__(handle) - self.certificate = certificate - self.expiry = expiry - - def snapshot(self) -> dict: - """Return snapshot.""" - return {"certificate": self.certificate, "expiry": self.expiry} - - def restore(self, snapshot: dict): - """Restore snapshot.""" - self.certificate = snapshot["certificate"] - self.expiry = snapshot["expiry"] - - -class CertificateInvalidatedEvent(EventBase): - """Charm Event triggered when a TLS certificate is invalidated.""" - - def __init__( - self, - handle: Handle, - reason: Literal["expired", "revoked"], - certificate: str, - certificate_signing_request: str, - ca: str, - chain: List[str], - ): - super().__init__(handle) - self.reason = reason - self.certificate_signing_request = certificate_signing_request - self.certificate = certificate - self.ca = ca - self.chain = chain - - def snapshot(self) -> dict: - """Return snapshot.""" - return { - "reason": self.reason, - "certificate_signing_request": self.certificate_signing_request, - "certificate": self.certificate, - "ca": self.ca, - "chain": self.chain, - } - - def restore(self, snapshot: dict): - """Restore snapshot.""" - self.reason = snapshot["reason"] - self.certificate_signing_request = snapshot["certificate_signing_request"] - self.certificate = snapshot["certificate"] - self.ca = snapshot["ca"] - self.chain = snapshot["chain"] - - -class AllCertificatesInvalidatedEvent(EventBase): - """Charm Event triggered when all TLS certificates are invalidated.""" - - def __init__(self, handle: Handle): - super().__init__(handle) - - def snapshot(self) -> dict: - """Return snapshot.""" - return {} - - def restore(self, snapshot: dict): - """Restore snapshot.""" - pass - - -class CertificateCreationRequestEvent(EventBase): - """Charm Event triggered when a TLS certificate is required.""" - - def __init__( - self, - handle: Handle, - certificate_signing_request: str, - relation_id: int, - is_ca: bool = False, - ): - super().__init__(handle) - self.certificate_signing_request = certificate_signing_request - self.relation_id = relation_id - self.is_ca = is_ca - - def snapshot(self) -> dict: - """Return snapshot.""" - return { - "certificate_signing_request": self.certificate_signing_request, - "relation_id": self.relation_id, - "is_ca": self.is_ca, - } - - def restore(self, snapshot: dict): - """Restore snapshot.""" - self.certificate_signing_request = snapshot["certificate_signing_request"] - self.relation_id = snapshot["relation_id"] - self.is_ca = snapshot["is_ca"] - - -class CertificateRevocationRequestEvent(EventBase): - """Charm Event triggered when a TLS certificate needs to be revoked.""" - - def __init__( - self, - handle: Handle, - certificate: str, - certificate_signing_request: str, - ca: str, - chain: str, - ): - super().__init__(handle) - self.certificate = certificate - self.certificate_signing_request = certificate_signing_request - self.ca = ca - self.chain = chain - - def snapshot(self) -> dict: - """Return snapshot.""" - return { - "certificate": self.certificate, - "certificate_signing_request": self.certificate_signing_request, - "ca": self.ca, - "chain": self.chain, - } - - def restore(self, snapshot: dict): - """Restore snapshot.""" - self.certificate = snapshot["certificate"] - self.certificate_signing_request = snapshot["certificate_signing_request"] - self.ca = snapshot["ca"] - self.chain = snapshot["chain"] - - -def _load_relation_data(relation_data_content: RelationDataContent) -> dict: - """Load relation data from the relation data bag. - - Json loads all data. - - Args: - relation_data_content: Relation data from the databag - - Returns: - dict: Relation data in dict format. - """ - certificate_data = {} - try: - for key in relation_data_content: - try: - certificate_data[key] = json.loads(relation_data_content[key]) - except (json.decoder.JSONDecodeError, TypeError): - certificate_data[key] = relation_data_content[key] - except ModelError: - pass - return certificate_data - - -def _get_closest_future_time( - expiry_notification_time: datetime, expiry_time: datetime -) -> datetime: - """Return expiry_notification_time if not in the past, otherwise return expiry_time. - - Args: - expiry_notification_time (datetime): Notification time of impending expiration - expiry_time (datetime): Expiration time - - Returns: - datetime: expiry_notification_time if not in the past, expiry_time otherwise - """ - return ( - expiry_notification_time - if datetime.now(timezone.utc) < expiry_notification_time - else expiry_time - ) - - -def calculate_expiry_notification_time( - validity_start_time: datetime, - expiry_time: datetime, - provider_recommended_notification_time: Optional[int], - requirer_recommended_notification_time: Optional[int], -) -> datetime: - """Calculate a reasonable time to notify the user about the certificate expiry. - - It takes into account the time recommended by the provider and by the requirer. - Time recommended by the provider is preferred, - then time recommended by the requirer, - then dynamically calculated time. - - Args: - validity_start_time: Certificate validity time - expiry_time: Certificate expiry time - provider_recommended_notification_time: - Time in hours prior to expiry to notify the user. - Recommended by the provider. - requirer_recommended_notification_time: - Time in hours prior to expiry to notify the user. - Recommended by the requirer. - - Returns: - datetime: Time to notify the user about the certificate expiry. - """ - if provider_recommended_notification_time is not None: - provider_recommended_notification_time = abs(provider_recommended_notification_time) - provider_recommendation_time_delta = ( - expiry_time - timedelta(hours=provider_recommended_notification_time) - ) - if validity_start_time < provider_recommendation_time_delta: - return provider_recommendation_time_delta - - if requirer_recommended_notification_time is not None: - requirer_recommended_notification_time = abs(requirer_recommended_notification_time) - requirer_recommendation_time_delta = ( - expiry_time - timedelta(hours=requirer_recommended_notification_time) - ) - if validity_start_time < requirer_recommendation_time_delta: - return requirer_recommendation_time_delta - calculated_hours = (expiry_time - validity_start_time).total_seconds() / (3600 * 3) - return expiry_time - timedelta(hours=calculated_hours) - - -def generate_ca( - private_key: bytes, - subject: str, - private_key_password: Optional[bytes] = None, - validity: int = 365, - country: str = "US", -) -> bytes: - """Generate a CA Certificate. - - Args: - private_key (bytes): Private key - subject (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN). - private_key_password (bytes): Private key password - validity (int): Certificate validity time (in days) - country (str): Certificate Issuing country - - Returns: - bytes: CA Certificate. - """ - private_key_object = serialization.load_pem_private_key( - private_key, password=private_key_password - ) - subject_name = x509.Name( - [ - x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country), - x509.NameAttribute(x509.NameOID.COMMON_NAME, subject), - ] - ) - subject_identifier_object = x509.SubjectKeyIdentifier.from_public_key( - private_key_object.public_key() # type: ignore[arg-type] - ) - subject_identifier = key_identifier = subject_identifier_object.public_bytes() - key_usage = x509.KeyUsage( - digital_signature=True, - key_encipherment=True, - key_cert_sign=True, - key_agreement=False, - content_commitment=False, - data_encipherment=False, - crl_sign=False, - encipher_only=False, - decipher_only=False, - ) - cert = ( - x509.CertificateBuilder() - .subject_name(subject_name) - .issuer_name(subject_name) - .public_key(private_key_object.public_key()) # type: ignore[arg-type] - .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.now(timezone.utc)) - .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) - .add_extension(x509.SubjectKeyIdentifier(digest=subject_identifier), critical=False) - .add_extension( - x509.AuthorityKeyIdentifier( - key_identifier=key_identifier, - authority_cert_issuer=None, - authority_cert_serial_number=None, - ), - critical=False, - ) - .add_extension(key_usage, critical=True) - .add_extension( - x509.BasicConstraints(ca=True, path_length=None), - critical=True, - ) - .sign(private_key_object, hashes.SHA256()) # type: ignore[arg-type] - ) - return cert.public_bytes(serialization.Encoding.PEM) - - -def get_certificate_extensions( - authority_key_identifier: bytes, - csr: x509.CertificateSigningRequest, - alt_names: Optional[List[str]], - is_ca: bool, -) -> List[x509.Extension]: - """Generate a list of certificate extensions from a CSR and other known information. - - Args: - authority_key_identifier (bytes): Authority key identifier - csr (x509.CertificateSigningRequest): CSR - alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR - is_ca (bool): Whether the certificate is a CA certificate - - Returns: - List[x509.Extension]: List of extensions - """ - cert_extensions_list: List[x509.Extension] = [ - x509.Extension( - oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, - value=x509.AuthorityKeyIdentifier( - key_identifier=authority_key_identifier, - authority_cert_issuer=None, - authority_cert_serial_number=None, - ), - critical=False, - ), - x509.Extension( - oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, - value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), - critical=False, - ), - x509.Extension( - oid=ExtensionOID.BASIC_CONSTRAINTS, - critical=True, - value=x509.BasicConstraints(ca=is_ca, path_length=None), - ), - ] - - sans: List[x509.GeneralName] = [] - san_alt_names = [x509.DNSName(name) for name in alt_names] if alt_names else [] - sans.extend(san_alt_names) - try: - loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) - sans.extend( - [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] - ) - sans.extend( - [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] - ) - sans.extend( - [ - x509.RegisteredID(oid) - for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) - ] - ) - except x509.ExtensionNotFound: - pass - - if sans: - cert_extensions_list.append( - x509.Extension( - oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, - critical=False, - value=x509.SubjectAlternativeName(sans), - ) - ) - - if is_ca: - cert_extensions_list.append( - x509.Extension( - ExtensionOID.KEY_USAGE, - critical=True, - value=x509.KeyUsage( - digital_signature=False, - content_commitment=False, - key_encipherment=False, - data_encipherment=False, - key_agreement=False, - key_cert_sign=True, - crl_sign=True, - encipher_only=False, - decipher_only=False, - ), - ) - ) - - existing_oids = {ext.oid for ext in cert_extensions_list} - for extension in csr.extensions: - if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: - continue - if extension.oid in existing_oids: - logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) - continue - cert_extensions_list.append(extension) - - return cert_extensions_list - - -def generate_certificate( - csr: bytes, - ca: bytes, - ca_key: bytes, - ca_key_password: Optional[bytes] = None, - validity: int = 365, - alt_names: Optional[List[str]] = None, - is_ca: bool = False, -) -> bytes: - """Generate a TLS certificate based on a CSR. - - Args: - csr (bytes): CSR - ca (bytes): CA Certificate - ca_key (bytes): CA private key - ca_key_password: CA private key password - validity (int): Certificate validity (in days) - alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR - is_ca (bool): Whether the certificate is a CA certificate - - Returns: - bytes: Certificate - """ - csr_object = x509.load_pem_x509_csr(csr) - subject = csr_object.subject - ca_pem = x509.load_pem_x509_certificate(ca) - issuer = ca_pem.issuer - private_key = serialization.load_pem_private_key(ca_key, password=ca_key_password) - - certificate_builder = ( - x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) - .public_key(csr_object.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.now(timezone.utc)) - .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) - ) - extensions = get_certificate_extensions( - authority_key_identifier=ca_pem.extensions.get_extension_for_class( - x509.SubjectKeyIdentifier - ).value.key_identifier, - csr=csr_object, - alt_names=alt_names, - is_ca=is_ca, - ) - for extension in extensions: - try: - certificate_builder = certificate_builder.add_extension( - extval=extension.value, - critical=extension.critical, - ) - except ValueError as e: - logger.warning("Failed to add extension %s: %s", extension.oid, e) - - cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] - return cert.public_bytes(serialization.Encoding.PEM) - - -def generate_private_key( - password: Optional[bytes] = None, - key_size: int = 2048, - public_exponent: int = 65537, -) -> bytes: - """Generate a private key. - - Args: - password (bytes): Password for decrypting the private key - key_size (int): Key size in bytes - public_exponent: Public exponent. - - Returns: - bytes: Private Key - """ - private_key = rsa.generate_private_key( - public_exponent=public_exponent, - key_size=key_size, - ) - key_bytes = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=( - serialization.BestAvailableEncryption(password) - if password - else serialization.NoEncryption() - ), - ) - return key_bytes - - -def generate_csr( # noqa: C901 - private_key: bytes, - subject: str, - add_unique_id_to_subject_name: bool = True, - organization: Optional[str] = None, - email_address: Optional[str] = None, - country_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - locality_name: Optional[str] = None, - private_key_password: Optional[bytes] = None, - sans: Optional[List[str]] = None, - sans_oid: Optional[List[str]] = None, - sans_ip: Optional[List[str]] = None, - sans_dns: Optional[List[str]] = None, - additional_critical_extensions: Optional[List] = None, -) -> bytes: - """Generate a CSR using private key and subject. - - Args: - private_key (bytes): Private key - subject (str): CSR Common Name that can be an IP or a Full Qualified Domain Name (FQDN). - add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's - subject name. Always leave to "True" when the CSR is used to request certificates - using the tls-certificates relation. - organization (str): Name of organization. - email_address (str): Email address. - country_name (str): Country Name. - state_or_province_name (str): State or Province Name. - locality_name (str): Locality Name. - private_key_password (bytes): Private key password - sans (list): Use sans_dns - this will be deprecated in a future release - List of DNS subject alternative names (keeping it for now for backward compatibility) - sans_oid (list): List of registered ID SANs - sans_dns (list): List of DNS subject alternative names (similar to the arg: sans) - sans_ip (list): List of IP subject alternative names - additional_critical_extensions (list): List of critical additional extension objects. - Object must be a x509 ExtensionType. - - Returns: - bytes: CSR - """ - signing_key = serialization.load_pem_private_key(private_key, password=private_key_password) - subject_name = [x509.NameAttribute(x509.NameOID.COMMON_NAME, subject)] - if add_unique_id_to_subject_name: - unique_identifier = uuid.uuid4() - subject_name.append( - x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(unique_identifier)) - ) - if organization: - subject_name.append(x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, organization)) - if email_address: - subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) - if country_name: - subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) - if state_or_province_name: - subject_name.append( - x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) - ) - if locality_name: - subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) - csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) - - _sans: List[x509.GeneralName] = [] - if sans_oid: - _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) - if sans_ip: - _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) - if sans: - _sans.extend([x509.DNSName(san) for san in sans]) - if sans_dns: - _sans.extend([x509.DNSName(san) for san in sans_dns]) - if _sans: - csr = csr.add_extension(x509.SubjectAlternativeName(set(_sans)), critical=False) - - if additional_critical_extensions: - for extension in additional_critical_extensions: - csr = csr.add_extension(extension, critical=True) - - signed_certificate = csr.sign(signing_key, hashes.SHA256()) # type: ignore[arg-type] - return signed_certificate.public_bytes(serialization.Encoding.PEM) - - -def get_sha256_hex(data: str) -> str: - """Calculate the hash of the provided data and return the hexadecimal representation.""" - digest = hashes.Hash(hashes.SHA256()) - digest.update(data.encode()) - return digest.finalize().hex() - - -def csr_matches_certificate(csr: str, cert: str) -> bool: - """Check if a CSR matches a certificate. - - Args: - csr (str): Certificate Signing Request as a string - cert (str): Certificate as a string - Returns: - bool: True/False depending on whether the CSR matches the certificate. - """ - csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) - - if csr_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) != cert_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ): - return False - return True - - -def _relation_data_is_valid( - relation: Relation, app_or_unit: Union[Application, Unit], json_schema: dict -) -> bool: - """Check whether relation data is valid based on json schema. - - Args: - relation (Relation): Relation object - app_or_unit (Union[Application, Unit]): Application or unit object - json_schema (dict): Json schema - - Returns: - bool: Whether relation data is valid. - """ - relation_data = _load_relation_data(relation.data[app_or_unit]) - try: - validate(instance=relation_data, schema=json_schema) - return True - except exceptions.ValidationError: - return False - - -class CertificatesProviderCharmEvents(CharmEvents): - """List of events that the TLS Certificates provider charm can leverage.""" - - certificate_creation_request = EventSource(CertificateCreationRequestEvent) - certificate_revocation_request = EventSource(CertificateRevocationRequestEvent) - - -class CertificatesRequirerCharmEvents(CharmEvents): - """List of events that the TLS Certificates requirer charm can leverage.""" - - certificate_available = EventSource(CertificateAvailableEvent) - certificate_expiring = EventSource(CertificateExpiringEvent) - certificate_invalidated = EventSource(CertificateInvalidatedEvent) - all_certificates_invalidated = EventSource(AllCertificatesInvalidatedEvent) - - -class TLSCertificatesProvidesV3(Object): - """TLS certificates provider class to be instantiated by TLS certificates providers.""" - - on = CertificatesProviderCharmEvents() # type: ignore[reportAssignmentType] - - def __init__(self, charm: CharmBase, relationship_name: str): - super().__init__(charm, relationship_name) - self.framework.observe( - charm.on[relationship_name].relation_changed, self._on_relation_changed - ) - self.charm = charm - self.relationship_name = relationship_name - - def _load_app_relation_data(self, relation: Relation) -> dict: - """Load relation data from the application relation data bag. - - Json loads all data. - - Args: - relation: Relation data from the application databag - - Returns: - dict: Relation data in dict format. - """ - # If unit is not leader, it does not try to reach relation data. - if not self.model.unit.is_leader(): - return {} - return _load_relation_data(relation.data[self.charm.app]) - - def _add_certificate( - self, - relation_id: int, - certificate: str, - certificate_signing_request: str, - ca: str, - chain: List[str], - recommended_expiry_notification_time: Optional[int] = None, - ) -> None: - """Add certificate to relation data. - - Args: - relation_id (int): Relation id - certificate (str): Certificate - certificate_signing_request (str): Certificate Signing Request - ca (str): CA Certificate - chain (list): CA Chain - recommended_expiry_notification_time (int): - Time in hours before the certificate expires to notify the user. - - Returns: - None - """ - relation = self.model.get_relation( - relation_name=self.relationship_name, relation_id=relation_id - ) - if not relation: - raise RuntimeError( - f"Relation {self.relationship_name} does not exist - " - f"The certificate request can't be completed" - ) - new_certificate = { - "certificate": certificate, - "certificate_signing_request": certificate_signing_request, - "ca": ca, - "chain": chain, - "recommended_expiry_notification_time": recommended_expiry_notification_time, - } - provider_relation_data = self._load_app_relation_data(relation) - provider_certificates = provider_relation_data.get("certificates", []) - certificates = copy.deepcopy(provider_certificates) - if new_certificate in certificates: - logger.info("Certificate already in relation data - Doing nothing") - return - certificates.append(new_certificate) - relation.data[self.model.app]["certificates"] = json.dumps(certificates) - - def _remove_certificate( - self, - relation_id: int, - certificate: Optional[str] = None, - certificate_signing_request: Optional[str] = None, - ) -> None: - """Remove certificate from a given relation based on user provided certificate or csr. - - Args: - relation_id (int): Relation id - certificate (str): Certificate (optional) - certificate_signing_request: Certificate signing request (optional) - - Returns: - None - """ - relation = self.model.get_relation( - relation_name=self.relationship_name, - relation_id=relation_id, - ) - if not relation: - raise RuntimeError( - f"Relation {self.relationship_name} with relation id {relation_id} does not exist" - ) - provider_relation_data = self._load_app_relation_data(relation) - provider_certificates = provider_relation_data.get("certificates", []) - certificates = copy.deepcopy(provider_certificates) - for certificate_dict in certificates: - if certificate and certificate_dict["certificate"] == certificate: - certificates.remove(certificate_dict) - if ( - certificate_signing_request - and certificate_dict["certificate_signing_request"] == certificate_signing_request - ): - certificates.remove(certificate_dict) - relation.data[self.model.app]["certificates"] = json.dumps(certificates) - - def revoke_all_certificates(self) -> None: - """Revoke all certificates of this provider. - - This method is meant to be used when the Root CA has changed. - """ - for relation in self.model.relations[self.relationship_name]: - provider_relation_data = self._load_app_relation_data(relation) - provider_certificates = copy.deepcopy(provider_relation_data.get("certificates", [])) - for certificate in provider_certificates: - certificate["revoked"] = True - relation.data[self.model.app]["certificates"] = json.dumps(provider_certificates) - - def set_relation_certificate( - self, - certificate: str, - certificate_signing_request: str, - ca: str, - chain: List[str], - relation_id: int, - recommended_expiry_notification_time: Optional[int] = None, - ) -> None: - """Add certificates to relation data. - - Args: - certificate (str): Certificate - certificate_signing_request (str): Certificate signing request - ca (str): CA Certificate - chain (list): CA Chain - relation_id (int): Juju relation ID - recommended_expiry_notification_time (int): - Recommended time in hours before the certificate expires to notify the user. - - Returns: - None - """ - if not self.model.unit.is_leader(): - return - certificates_relation = self.model.get_relation( - relation_name=self.relationship_name, relation_id=relation_id - ) - if not certificates_relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") - self._remove_certificate( - certificate_signing_request=certificate_signing_request.strip(), - relation_id=relation_id, - ) - self._add_certificate( - relation_id=relation_id, - certificate=certificate.strip(), - certificate_signing_request=certificate_signing_request.strip(), - ca=ca.strip(), - chain=[cert.strip() for cert in chain], - recommended_expiry_notification_time=recommended_expiry_notification_time, - ) - - def remove_certificate(self, certificate: str) -> None: - """Remove a given certificate from relation data. - - Args: - certificate (str): TLS Certificate - - Returns: - None - """ - certificates_relation = self.model.relations[self.relationship_name] - if not certificates_relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") - for certificate_relation in certificates_relation: - self._remove_certificate(certificate=certificate, relation_id=certificate_relation.id) - - def get_issued_certificates( - self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return a List of issued (non revoked) certificates. - - Returns: - List: List of ProviderCertificate objects - """ - provider_certificates = self.get_provider_certificates(relation_id=relation_id) - return [certificate for certificate in provider_certificates if not certificate.revoked] - - def get_provider_certificates( - self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return a List of issued certificates. - - Returns: - List: List of ProviderCertificate objects - """ - certificates: List[ProviderCertificate] = [] - relations = ( - [ - relation - for relation in self.model.relations[self.relationship_name] - if relation.id == relation_id - ] - if relation_id is not None - else self.model.relations.get(self.relationship_name, []) - ) - for relation in relations: - if not relation.app: - logger.warning("Relation %s does not have an application", relation.id) - continue - provider_relation_data = self._load_app_relation_data(relation) - provider_certificates = provider_relation_data.get("certificates", []) - for certificate in provider_certificates: - try: - certificate_object = x509.load_pem_x509_certificate( - data=certificate["certificate"].encode() - ) - except ValueError as e: - logger.error("Could not load certificate - Skipping: %s", e) - continue - provider_certificate = ProviderCertificate( - relation_id=relation.id, - application_name=relation.app.name, - csr=certificate["certificate_signing_request"], - certificate=certificate["certificate"], - ca=certificate["ca"], - chain=certificate["chain"], - revoked=certificate.get("revoked", False), - expiry_time=certificate_object.not_valid_after_utc, - expiry_notification_time=certificate.get( - "recommended_expiry_notification_time" - ), - ) - certificates.append(provider_certificate) - return certificates - - def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handle relation changed event. - - Looks at the relation data and either emits: - - certificate request event: If the unit relation data contains a CSR for which - a certificate does not exist in the provider relation data. - - certificate revocation event: If the provider relation data contains a CSR for which - a csr does not exist in the requirer relation data. - - Args: - event: Juju event - - Returns: - None - """ - if event.unit is None: - logger.error("Relation_changed event does not have a unit.") - return - if not self.model.unit.is_leader(): - return - if not _relation_data_is_valid(event.relation, event.unit, REQUIRER_JSON_SCHEMA): - logger.debug("Relation data did not pass JSON Schema validation") - return - provider_certificates = self.get_provider_certificates(relation_id=event.relation.id) - requirer_csrs = self.get_requirer_csrs(relation_id=event.relation.id) - provider_csrs = [ - certificate_creation_request.csr - for certificate_creation_request in provider_certificates - ] - for certificate_request in requirer_csrs: - if certificate_request.csr not in provider_csrs: - self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_request.csr, - relation_id=certificate_request.relation_id, - is_ca=certificate_request.is_ca, - ) - self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) - - def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None: - """Revoke certificates for which no unit has a CSR. - - Goes through all generated certificates and compare against the list of CSRs for all units. - - Returns: - None - """ - provider_certificates = self.get_provider_certificates(relation_id) - requirer_csrs = self.get_requirer_csrs(relation_id) - list_of_csrs = [csr.csr for csr in requirer_csrs] - for certificate in provider_certificates: - if certificate.csr not in list_of_csrs: - self.on.certificate_revocation_request.emit( - certificate=certificate.certificate, - certificate_signing_request=certificate.csr, - ca=certificate.ca, - chain=certificate.chain, - ) - self.remove_certificate(certificate=certificate.certificate) - - def get_outstanding_certificate_requests( - self, relation_id: Optional[int] = None - ) -> List[RequirerCSR]: - """Return CSR's for which no certificate has been issued. - - Args: - relation_id (int): Relation id - - Returns: - list: List of RequirerCSR objects. - """ - requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) - outstanding_csrs: List[RequirerCSR] = [] - for relation_csr in requirer_csrs: - if not self.certificate_issued_for_csr( - app_name=relation_csr.application_name, - csr=relation_csr.csr, - relation_id=relation_id, - ): - outstanding_csrs.append(relation_csr) - return outstanding_csrs - - def get_requirer_csrs(self, relation_id: Optional[int] = None) -> List[RequirerCSR]: - """Return a list of requirers' CSRs. - - It returns CSRs from all relations if relation_id is not specified. - CSRs are returned per relation id, application name and unit name. - - Returns: - list: List[RequirerCSR] - """ - relation_csrs: List[RequirerCSR] = [] - relations = ( - [ - relation - for relation in self.model.relations[self.relationship_name] - if relation.id == relation_id - ] - if relation_id is not None - else self.model.relations.get(self.relationship_name, []) - ) - - for relation in relations: - for unit in relation.units: - requirer_relation_data = _load_relation_data(relation.data[unit]) - unit_csrs_list = requirer_relation_data.get("certificate_signing_requests", []) - for unit_csr in unit_csrs_list: - csr = unit_csr.get("certificate_signing_request") - if not csr: - logger.warning("No CSR found in relation data - Skipping") - continue - ca = unit_csr.get("ca", False) - if not relation.app: - logger.warning("No remote app in relation - Skipping") - continue - relation_csr = RequirerCSR( - relation_id=relation.id, - application_name=relation.app.name, - unit_name=unit.name, - csr=csr, - is_ca=ca, - ) - relation_csrs.append(relation_csr) - return relation_csrs - - def certificate_issued_for_csr( - self, app_name: str, csr: str, relation_id: Optional[int] - ) -> bool: - """Check whether a certificate has been issued for a given CSR. - - Args: - app_name (str): Application name that the CSR belongs to. - csr (str): Certificate Signing Request. - relation_id (Optional[int]): Relation ID - - Returns: - bool: True/False depending on whether a certificate has been issued for the given CSR. - """ - issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) - for issued_certificate in issued_certificates_per_csr: - if issued_certificate.csr == csr and issued_certificate.application_name == app_name: - return csr_matches_certificate(csr, issued_certificate.certificate) - return False - - -class TLSCertificatesRequiresV3(Object): - """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" - - on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] - - def __init__( - self, - charm: CharmBase, - relationship_name: str, - expiry_notification_time: Optional[int] = None, - ): - """Generate/use private key and observes relation changed event. - - Args: - charm: Charm object - relationship_name: Juju relation name - expiry_notification_time (int): Number of hours prior to certificate expiry. - Used to trigger the CertificateExpiring event. - This value is used as a recommendation only, - The actual value is calculated taking into account the provider's recommendation. - """ - super().__init__(charm, relationship_name) - if not JujuVersion.from_environ().has_secrets: - logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") - self.relationship_name = relationship_name - self.charm = charm - self.expiry_notification_time = expiry_notification_time - self.framework.observe( - charm.on[relationship_name].relation_changed, self._on_relation_changed - ) - self.framework.observe( - charm.on[relationship_name].relation_broken, self._on_relation_broken - ) - self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - - def get_requirer_csrs(self) -> List[RequirerCSR]: - """Return list of requirer's CSRs from relation unit data. - - Returns: - list: List of RequirerCSR objects. - """ - relation = self.model.get_relation(self.relationship_name) - if not relation: - return [] - requirer_csrs = [] - requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - requirer_csrs_dict = requirer_relation_data.get("certificate_signing_requests", []) - for requirer_csr_dict in requirer_csrs_dict: - csr = requirer_csr_dict.get("certificate_signing_request") - if not csr: - logger.warning("No CSR found in relation data - Skipping") - continue - ca = requirer_csr_dict.get("ca", False) - relation_csr = RequirerCSR( - relation_id=relation.id, - application_name=self.model.app.name, - unit_name=self.model.unit.name, - csr=csr, - is_ca=ca, - ) - requirer_csrs.append(relation_csr) - return requirer_csrs - - def get_provider_certificates(self) -> List[ProviderCertificate]: - """Return list of certificates from the provider's relation data.""" - provider_certificates: List[ProviderCertificate] = [] - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.debug("No relation: %s", self.relationship_name) - return [] - if not relation.app: - logger.debug("No remote app in relation: %s", self.relationship_name) - return [] - provider_relation_data = _load_relation_data(relation.data[relation.app]) - provider_certificate_dicts = provider_relation_data.get("certificates", []) - for provider_certificate_dict in provider_certificate_dicts: - certificate = provider_certificate_dict.get("certificate") - if not certificate: - logger.warning("No certificate found in relation data - Skipping") - continue - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - except ValueError as e: - logger.error("Could not load certificate - Skipping: %s", e) - continue - ca = provider_certificate_dict.get("ca") - chain = provider_certificate_dict.get("chain", []) - csr = provider_certificate_dict.get("certificate_signing_request") - recommended_expiry_notification_time = provider_certificate_dict.get( - "recommended_expiry_notification_time" - ) - expiry_time = certificate_object.not_valid_after_utc - validity_start_time = certificate_object.not_valid_before_utc - expiry_notification_time = calculate_expiry_notification_time( - validity_start_time=validity_start_time, - expiry_time=expiry_time, - provider_recommended_notification_time=recommended_expiry_notification_time, - requirer_recommended_notification_time=self.expiry_notification_time, - ) - if not csr: - logger.warning("No CSR found in relation data - Skipping") - continue - revoked = provider_certificate_dict.get("revoked", False) - provider_certificate = ProviderCertificate( - relation_id=relation.id, - application_name=relation.app.name, - csr=csr, - certificate=certificate, - ca=ca, - chain=chain, - revoked=revoked, - expiry_time=expiry_time, - expiry_notification_time=expiry_notification_time, - ) - provider_certificates.append(provider_certificate) - return provider_certificates - - def _add_requirer_csr_to_relation_data(self, csr: str, is_ca: bool) -> None: - """Add CSR to relation data. - - Args: - csr (str): Certificate Signing Request - is_ca (bool): Whether the certificate is a CA certificate - - Returns: - None - """ - relation = self.model.get_relation(self.relationship_name) - if not relation: - raise RuntimeError( - f"Relation {self.relationship_name} does not exist - " - f"The certificate request can't be completed" - ) - for requirer_csr in self.get_requirer_csrs(): - if requirer_csr.csr == csr and requirer_csr.is_ca == is_ca: - logger.info("CSR already in relation data - Doing nothing") - return - new_csr_dict = { - "certificate_signing_request": csr, - "ca": is_ca, - } - requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) - new_relation_data = copy.deepcopy(existing_relation_data) - new_relation_data.append(new_csr_dict) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( - new_relation_data - ) - - def _remove_requirer_csr_from_relation_data(self, csr: str) -> None: - """Remove CSR from relation data. - - Args: - csr (str): Certificate signing request - - Returns: - None - """ - relation = self.model.get_relation(self.relationship_name) - if not relation: - raise RuntimeError( - f"Relation {self.relationship_name} does not exist - " - f"The certificate request can't be completed" - ) - if not self.get_requirer_csrs(): - logger.info("No CSRs in relation data - Doing nothing") - return - requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) - new_relation_data = copy.deepcopy(existing_relation_data) - for requirer_csr in new_relation_data: - if requirer_csr["certificate_signing_request"] == csr: - new_relation_data.remove(requirer_csr) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( - new_relation_data - ) - - def request_certificate_creation( - self, certificate_signing_request: bytes, is_ca: bool = False - ) -> None: - """Request TLS certificate to provider charm. - - Args: - certificate_signing_request (bytes): Certificate Signing Request - is_ca (bool): Whether the certificate is a CA certificate - - Returns: - None - """ - relation = self.model.get_relation(self.relationship_name) - if not relation: - raise RuntimeError( - f"Relation {self.relationship_name} does not exist - " - f"The certificate request can't be completed" - ) - self._add_requirer_csr_to_relation_data( - certificate_signing_request.decode().strip(), is_ca=is_ca - ) - logger.info("Certificate request sent to provider") - - def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: - """Remove CSR from relation data. - - The provider of this relation is then expected to remove certificates associated to this - CSR from the relation data as well and emit a request_certificate_revocation event for the - provider charm to interpret. - - Args: - certificate_signing_request (bytes): Certificate Signing Request - - Returns: - None - """ - self._remove_requirer_csr_from_relation_data(certificate_signing_request.decode().strip()) - logger.info("Certificate revocation sent to provider") - - def request_certificate_renewal( - self, old_certificate_signing_request: bytes, new_certificate_signing_request: bytes - ) -> None: - """Renew certificate. - - Removes old CSR from relation data and adds new one. - - Args: - old_certificate_signing_request: Old CSR - new_certificate_signing_request: New CSR - - Returns: - None - """ - try: - self.request_certificate_revocation( - certificate_signing_request=old_certificate_signing_request - ) - except RuntimeError: - logger.warning("Certificate revocation failed.") - self.request_certificate_creation( - certificate_signing_request=new_certificate_signing_request - ) - logger.info("Certificate renewal request completed.") - - def get_assigned_certificates(self) -> List[ProviderCertificate]: - """Get a list of certificates that were assigned to this unit. - - Returns: - List: List[ProviderCertificate] - """ - assigned_certificates = [] - for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): - if cert := self._find_certificate_in_relation_data(requirer_csr.csr): - assigned_certificates.append(cert) - return assigned_certificates - - def get_expiring_certificates(self) -> List[ProviderCertificate]: - """Get a list of certificates that were assigned to this unit that are expiring or expired. - - Returns: - List: List[ProviderCertificate] - """ - expiring_certificates: List[ProviderCertificate] = [] - for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): - if cert := self._find_certificate_in_relation_data(requirer_csr.csr): - if not cert.expiry_time or not cert.expiry_notification_time: - continue - if datetime.now(timezone.utc) > cert.expiry_notification_time: - expiring_certificates.append(cert) - return expiring_certificates - - def get_certificate_signing_requests( - self, - fulfilled_only: bool = False, - unfulfilled_only: bool = False, - ) -> List[RequirerCSR]: - """Get the list of CSR's that were sent to the provider. - - You can choose to get only the CSR's that have a certificate assigned or only the CSR's - that don't. - - Args: - fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. - unfulfilled_only (bool): This option will discard CSRs that have certificates signed. - - Returns: - List of RequirerCSR objects. - """ - csrs = [] - for requirer_csr in self.get_requirer_csrs(): - cert = self._find_certificate_in_relation_data(requirer_csr.csr) - if (unfulfilled_only and cert) or (fulfilled_only and not cert): - continue - csrs.append(requirer_csr) - - return csrs - - def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handle relation changed event. - - Goes through all providers certificates that match a requested CSR. - - If the provider certificate is revoked, emit a CertificateInvalidateEvent, - otherwise emit a CertificateAvailableEvent. - - Remove the secret for revoked certificate, or add a secret with the correct expiry - time for new certificates. - - Args: - event: Juju event - - Returns: - None - """ - if not event.app: - logger.warning("No remote app in relation - Skipping") - return - if not _relation_data_is_valid(event.relation, event.app, PROVIDER_JSON_SCHEMA): - logger.debug("Relation data did not pass JSON Schema validation") - return - provider_certificates = self.get_provider_certificates() - requirer_csrs = [ - certificate_creation_request.csr - for certificate_creation_request in self.get_requirer_csrs() - ] - for certificate in provider_certificates: - if certificate.csr in requirer_csrs: - csr_in_sha256_hex = get_sha256_hex(certificate.csr) - if certificate.revoked: - with suppress(SecretNotFoundError): - logger.debug( - "Removing secret with label %s", - f"{LIBID}-{csr_in_sha256_hex}", - ) - secret = self.model.get_secret( - label=f"{LIBID}-{csr_in_sha256_hex}") - secret.remove_all_revisions() - self.on.certificate_invalidated.emit( - reason="revoked", - certificate=certificate.certificate, - certificate_signing_request=certificate.csr, - ca=certificate.ca, - chain=certificate.chain, - ) - else: - try: - logger.debug( - "Setting secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" - ) - secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") - secret.set_content( - {"certificate": certificate.certificate, "csr": certificate.csr} - ) - secret.set_info( - expire=self._get_next_secret_expiry_time(certificate), - ) - except SecretNotFoundError: - logger.debug( - "Creating new secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" - ) - secret = self.charm.unit.add_secret( - {"certificate": certificate.certificate, "csr": certificate.csr}, - label=f"{LIBID}-{csr_in_sha256_hex}", - expire=self._get_next_secret_expiry_time(certificate), - ) - self.on.certificate_available.emit( - certificate_signing_request=certificate.csr, - certificate=certificate.certificate, - ca=certificate.ca, - chain=certificate.chain, - ) - - def _get_next_secret_expiry_time(self, certificate: ProviderCertificate) -> Optional[datetime]: - """Return the expiry time or expiry notification time. - - Extracts the expiry time from the provided certificate, calculates the - expiry notification time and return the closest of the two, that is in - the future. - - Args: - certificate: ProviderCertificate object - - Returns: - Optional[datetime]: None if the certificate expiry time cannot be read, - next expiry time otherwise. - """ - if not certificate.expiry_time or not certificate.expiry_notification_time: - return None - return _get_closest_future_time( - certificate.expiry_notification_time, - certificate.expiry_time, - ) - - def _on_relation_broken(self, event: RelationBrokenEvent) -> None: - """Handle Relation Broken Event. - - Emitting `all_certificates_invalidated` from `relation-broken` rather - than `relation-departed` since certs are stored in app data. - - Args: - event: Juju event - - Returns: - None - """ - self.on.all_certificates_invalidated.emit() - - def _on_secret_expired(self, event: SecretExpiredEvent) -> None: - """Handle Secret Expired Event. - - Loads the certificate from the secret, and will emit 1 of 2 - events. - - If the certificate is not yet expired, emits CertificateExpiringEvent - and updates the expiry time of the secret to the exact expiry time on - the certificate. - - If the certificate is expired, emits CertificateInvalidedEvent and - deletes the secret. - - Args: - event (SecretExpiredEvent): Juju event - """ - if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): - return - csr = event.secret.get_content()["csr"] - provider_certificate = self._find_certificate_in_relation_data(csr) - if not provider_certificate: - # A secret expired but we did not find matching certificate. Cleaning up - event.secret.remove_all_revisions() - return - - if not provider_certificate.expiry_time: - # A secret expired but matching certificate is invalid. Cleaning up - event.secret.remove_all_revisions() - return - - if datetime.now(timezone.utc) < provider_certificate.expiry_time: - logger.warning("Certificate almost expired") - self.on.certificate_expiring.emit( - certificate=provider_certificate.certificate, - expiry=provider_certificate.expiry_time.isoformat(), - ) - event.secret.set_info( - expire=provider_certificate.expiry_time, - ) - else: - logger.warning("Certificate is expired") - self.on.certificate_invalidated.emit( - reason="expired", - certificate=provider_certificate.certificate, - certificate_signing_request=provider_certificate.csr, - ca=provider_certificate.ca, - chain=provider_certificate.chain, - ) - self.request_certificate_revocation(provider_certificate.certificate.encode()) - event.secret.remove_all_revisions() - - def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCertificate]: - """Return the certificate that match the given CSR.""" - for provider_certificate in self.get_provider_certificates(): - if provider_certificate.csr != csr: - continue - return provider_certificate - return None diff --git a/lib/charms/tls_certificates_interface/v4/tls_certificates.py b/lib/charms/tls_certificates_interface/v4/tls_certificates.py new file mode 100644 index 0000000..9839759 --- /dev/null +++ b/lib/charms/tls_certificates_interface/v4/tls_certificates.py @@ -0,0 +1,1470 @@ +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Charm library for managing TLS certificates (V4) - BETA. + +> Warning: This is a beta version of the tls-certificates interface library. +> Use at your own risk. + +This library contains the Requires and Provides classes for handling the tls-certificates +interface. + +Pre-requisites: + - Juju >= 3.0 + +## Getting Started +From a charm directory, fetch the library using `charmcraft`: + +```shell +charmcraft fetch-lib charms.tls_certificates_interface.v4.tls_certificates +``` + +Add the following libraries to the charm's `requirements.txt` file: +- cryptography >= 42.0.0 +- pydantic >= 2.0.0 + +Add the following section to the charm's `charmcraft.yaml` file: +```yaml +parts: + charm: + build-packages: + - libffi-dev + - libssl-dev + - rustc + - cargo +``` + +### Requirer charm +The requirer charm is the charm requiring certificates from another charm that provides them. + +#### Example + +In the following example, the requiring charm requests a certificate using attributes +from the Juju configuration options. + +```python +from typing import List, Optional, cast + +from ops.charm import ActionEvent, CharmBase +from ops.main import main + +from lib.charms.tls_certificates_interface.v4.tls_certificates import ( + CertificateAvailableEvent, + CertificateRequest, + Mode, + TLSCertificatesRequiresV4, +) + + +class DummyTLSCertificatesRequirerCharm(CharmBase): + def __init__(self, *args): + super().__init__(*args) + certificate_requests = self._get_certificate_requests() + self.certificates = TLSCertificatesRequiresV4( + charm=self, + relationship_name="certificates", + certificate_requests=certificate_requests, + mode=Mode.UNIT, + refresh_events=[self.on.config_changed], + ) + self.framework.observe( + self.certificates.on.certificate_available, self._on_certificate_available + ) + self.framework.observe( + self.on.regenerate_private_key_action, self._on_regenerate_private_key_action + ) + self.framework.observe(self.on.get_certificate_action, self._on_get_certificate_action) + + def _get_certificate_requests(self) -> List[CertificateRequest]: + if not self._get_config_common_name(): + return [] + return [ + CertificateRequest( + common_name=self._get_config_common_name(), + sans_dns=self._get_config_sans_dns(), + organization=self._get_config_organization_name(), + organizational_unit=self._get_config_organization_unit_name(), + email_address=self._get_config_email_address(), + country_name=self._get_config_country_name(), + state_or_province_name=self._get_config_state_or_province_name(), + locality_name=self._get_config_locality_name(), + ) + ] + + def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: + print("Certificate available") + + def _on_regenerate_private_key_action(self, event: ActionEvent) -> None: + self.certificates.regenerate_private_key() + + def _on_get_certificate_action(self, event: ActionEvent) -> None: + certificate, _ = self.certificates.get_assigned_certificate( + certificate_request=self._get_certificate_requests()[0] + ) + if not certificate: + event.fail("Certificate not available") + return + event.set_results( + { + "certificate": str(certificate.certificate), + "ca": str(certificate.ca), + "csr": str(certificate.certificate_signing_request), + } + ) + + def _get_config_common_name(self) -> str: + return cast(str, self.model.config.get("common_name")) + + def _get_config_sans_dns(self) -> List[str]: + config_sans_dns = cast(str, self.model.config.get("sans_dns", "")) + return config_sans_dns.split(",") if config_sans_dns else [] + + def _get_config_organization_name(self) -> Optional[str]: + return cast(str, self.model.config.get("organization_name")) + + def _get_config_organization_unit_name(self) -> Optional[str]: + return cast(str, self.model.config.get("organization_unit_name")) + + def _get_config_email_address(self) -> Optional[str]: + return cast(str, self.model.config.get("email_address")) + + def _get_config_country_name(self) -> Optional[str]: + return cast(str, self.model.config.get("country_name")) + + def _get_config_state_or_province_name(self) -> Optional[str]: + return cast(str, self.model.config.get("state_or_province_name")) + + def _get_config_locality_name(self) -> Optional[str]: + return cast(str, self.model.config.get("locality_name")) + + +if __name__ == "__main__": + main(DummyTLSCertificatesRequirerCharm) +``` + +You can integrate both charms by running: + +```bash +juju integrate +``` +""" # noqa: D214, D405, D411, D416 + +import copy +import ipaddress +import json +import logging +import uuid +from contextlib import suppress +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import List, MutableMapping, Optional, Tuple, Union + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID +from ops import BoundEvent, CharmBase, CharmEvents, SecretExpiredEvent +from ops.framework import EventBase, EventSource, Handle, Object +from ops.jujuversion import JujuVersion +from ops.model import ( + Application, + ModelError, + Relation, + SecretNotFoundError, + Unit, +) +from pydantic import BaseModel, ConfigDict, ValidationError + +# The unique Charmhub library identifier, never change it +LIBID = "afd8c2bccf834997afce12c2706d2ede" + +# Increment this major API version when introducing breaking changes +LIBAPI = 4 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 1 + +PYDEPS = ["cryptography", "pydantic"] + +logger = logging.getLogger(__name__) + + +class TLSCertificatesError(Exception): + """Base class for custom errors raised by this library.""" + + +class DataValidationError(TLSCertificatesError): + """Raised when data validation fails.""" + + +class _DatabagModel(BaseModel): + """Base databag model.""" + + model_config = ConfigDict( + # tolerate additional keys in databag + extra="ignore", + # Allow instantiating this class by field name (instead of forcing alias). + populate_by_name=True, + # Custom config key: whether to nest the whole datastructure (as json) + # under a field or spread it out at the toplevel. + _NEST_UNDER=None, + ) # type: ignore + """Pydantic config.""" + + @classmethod + def load(cls, databag: MutableMapping): + """Load this model from a Juju databag.""" + nest_under = cls.model_config.get("_NEST_UNDER") + if nest_under: + return cls.model_validate(json.loads(databag[nest_under])) + + try: + data = { + k: json.loads(v) + for k, v in databag.items() + # Don't attempt to parse model-external values + if k in {(f.alias or n) for n, f in cls.model_fields.items()} + } + except json.JSONDecodeError as e: + msg = f"invalid databag contents: expecting json. {databag}" + logger.error(msg) + raise DataValidationError(msg) from e + + try: + return cls.model_validate_json(json.dumps(data)) + except ValidationError as e: + msg = f"failed to validate databag: {databag}" + logger.debug(msg, exc_info=True) + raise DataValidationError(msg) from e + + def dump(self, databag: Optional[MutableMapping] = None, clear: bool = True): + """Write the contents of this model to Juju databag. + + Args: + databag: The databag to write to. + clear: Whether to clear the databag before writing. + + Returns: + MutableMapping: The databag. + """ + if clear and databag: + databag.clear() + + if databag is None: + databag = {} + nest_under = self.model_config.get("_NEST_UNDER") + if nest_under: + databag[nest_under] = self.model_dump_json( + by_alias=True, + # skip keys whose values are default + exclude_defaults=True, + ) + return databag + + dct = self.model_dump(mode="json", by_alias=True, exclude_defaults=True) + databag.update({k: json.dumps(v) for k, v in dct.items()}) + return databag + + +class _Certificate(BaseModel): + """Certificate model.""" + + ca: str + certificate_signing_request: str + certificate: str + chain: Optional[List[str]] = None + recommended_expiry_notification_time: Optional[int] = None + revoked: Optional[bool] = None + + def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate": + """Convert to a ProviderCertificate.""" + return ProviderCertificate( + relation_id=relation_id, + certificate=Certificate.from_string(self.certificate), + certificate_signing_request=CertificateSigningRequest.from_string( + self.certificate_signing_request + ), + ca=Certificate.from_string(self.ca), + chain=[Certificate.from_string(certificate) for certificate in self.chain] + if self.chain + else [], + recommended_expiry_notification_time=self.recommended_expiry_notification_time, + revoked=self.revoked, + ) + + +class _CertificateSigningRequest(BaseModel): + """Certificate signing request model.""" + + certificate_signing_request: str + ca: Optional[bool] + + +class _ProviderApplicationData(_DatabagModel): + """Provider application data model.""" + + certificates: List[_Certificate] + + +class _RequirerData(_DatabagModel): + """Requirer data model. + + The same model is used for the unit and application data. + """ + + certificate_signing_requests: List[_CertificateSigningRequest] + + +class Mode(Enum): + """Enum representing the mode of the certificate request. + + UNIT (default): Request a certificate for the unit. + Each unit will have its own private key and certificate. + APP: Request a certificate for the application. + The private key and certificate will be shared by all units. + """ + + UNIT = 1 + APP = 2 + + +@dataclass(frozen=True) +class PrivateKey: + """This class represents a private key.""" + + raw: str + + def __str__(self): + """Return the private key as a string.""" + return self.raw + + @classmethod + def from_string(cls, private_key: str) -> "PrivateKey": + """Create a PrivateKey object from a private key.""" + return cls(raw=private_key.strip()) + + +@dataclass(frozen=True) +class Certificate: + """This class represents a certificate.""" + + raw: str + common_name: str + sans_dns: Optional[Tuple[str, ...]] = None + sans_ip: Optional[Tuple[str, ...]] = None + sans_oid: Optional[Tuple[str, ...]] = None + email_address: Optional[str] = None + organization: Optional[str] = None + organizational_unit: Optional[str] = None + country_name: Optional[str] = None + state_or_province_name: Optional[str] = None + locality_name: Optional[str] = None + expiry_time: Optional[datetime] = None + validity_start_time: Optional[datetime] = None + + def __str__(self) -> str: + """Return the certificate as a string.""" + return self.raw + + @classmethod + def from_string(cls, certificate: str) -> "Certificate": + """Create a Certificate object from a certificate.""" + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + except ValueError as e: + logger.error("Could not load certificate: %s", e) + raise TLSCertificatesError("Could not load certificate") + common_name = certificate_object.subject.get_attributes_for_oid(NameOID.COMMON_NAME) + country_name = certificate_object.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) + state_or_province_name = certificate_object.subject.get_attributes_for_oid( + NameOID.STATE_OR_PROVINCE_NAME + ) + locality_name = certificate_object.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) + organization_name = certificate_object.subject.get_attributes_for_oid( + NameOID.ORGANIZATION_NAME + ) + email_address = certificate_object.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) + try: + sans = certificate_object.extensions.get_extension_for_class( + x509.SubjectAlternativeName + ).value + sans_dns = tuple( + str(san) + for san in sans.get_values_for_type(x509.DNSName) + if isinstance(san, x509.DNSName) + ) + sans_ip = tuple( + str(san) + for san in sans.get_values_for_type(x509.IPAddress) + if isinstance(san, x509.IPAddress) + ) + sans_oid = tuple( + str(san) + for san in sans.get_values_for_type(x509.RegisteredID) + if isinstance(san, x509.RegisteredID) + ) + except x509.ExtensionNotFound: + logger.debug("No SANs found in certificate") + sans_dns = None + sans_ip = None + sans_oid = None + expiry_time = certificate_object.not_valid_after_utc + validity_start_time = certificate_object.not_valid_before_utc + + return cls( + raw=certificate.strip(), + common_name=str(common_name[0].value), + country_name=str(country_name[0].value) if country_name else None, + state_or_province_name=str(state_or_province_name[0].value) + if state_or_province_name + else None, + locality_name=str(locality_name[0].value) if locality_name else None, + organization=str(organization_name[0].value) if organization_name else None, + email_address=str(email_address[0].value) if email_address else None, + sans_dns=sans_dns, + sans_ip=sans_ip, + sans_oid=sans_oid, + expiry_time=expiry_time, + validity_start_time=validity_start_time, + ) + + +@dataclass(frozen=True) +class CertificateSigningRequest: + """This class represents a certificate signing request.""" + + raw: str + common_name: str + sans_dns: Optional[Tuple[str, ...]] = None + sans_ip: Optional[Tuple[str, ...]] = None + sans_oid: Optional[Tuple[str, ...]] = None + email_address: Optional[str] = None + organization: Optional[str] = None + organizational_unit: Optional[str] = None + country_name: Optional[str] = None + state_or_province_name: Optional[str] = None + locality_name: Optional[str] = None + is_ca: bool = False + + def __eq__(self, other: object) -> bool: + """Check if two CertificateSigningRequest objects are equal.""" + if not isinstance(other, CertificateSigningRequest): + return NotImplemented + return self.raw.strip() == other.raw.strip() + + def __str__(self) -> str: + """Return the CSR as a string.""" + return self.raw + + def to_certificate_request(self) -> "CertificateRequest": + """Convert to a CertificateRequest object.""" + return CertificateRequest( + common_name=self.common_name, + sans_dns=self.sans_dns, + sans_ip=self.sans_ip, + sans_oid=self.sans_oid, + email_address=self.email_address, + organization=self.organization, + organizational_unit=self.organizational_unit, + country_name=self.country_name, + state_or_province_name=self.state_or_province_name, + locality_name=self.locality_name, + is_ca=self.is_ca, + ) + + @classmethod + def from_string(cls, csr: str) -> "CertificateSigningRequest": + """Create a CertificateSigningRequest object from a CSR.""" + try: + csr_object = x509.load_pem_x509_csr(csr.encode()) + except ValueError as e: + logger.error("Could not load CSR: %s", e) + raise TLSCertificatesError("Could not load CSR") + common_name = csr_object.subject.get_attributes_for_oid(NameOID.COMMON_NAME) + country_name = csr_object.subject.get_attributes_for_oid(NameOID.COUNTRY_NAME) + state_or_province_name = csr_object.subject.get_attributes_for_oid( + NameOID.STATE_OR_PROVINCE_NAME + ) + locality_name = csr_object.subject.get_attributes_for_oid(NameOID.LOCALITY_NAME) + organization_name = csr_object.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) + email_address = csr_object.subject.get_attributes_for_oid(NameOID.EMAIL_ADDRESS) + try: + sans = csr_object.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + sans_dns = tuple(sans.get_values_for_type(x509.DNSName)) + sans_ip = tuple([str(san) for san in sans.get_values_for_type(x509.IPAddress)]) + sans_oid = tuple([str(san) for san in sans.get_values_for_type(x509.RegisteredID)]) + except x509.ExtensionNotFound: + sans = () + sans_dns = () + sans_ip = () + sans_oid = () + return cls( + raw=csr.strip(), + common_name=str(common_name[0].value), + country_name=str(country_name[0].value) if country_name else None, + state_or_province_name=str(state_or_province_name[0].value) + if state_or_province_name + else None, + locality_name=str(locality_name[0].value) if locality_name else None, + organization=str(organization_name[0].value) if organization_name else None, + email_address=str(email_address[0].value) if email_address else None, + sans_dns=sans_dns, + sans_ip=sans_ip if sans_ip else None, + sans_oid=sans_oid if sans_oid else None, + ) + + def matches_private_key(self, key: PrivateKey) -> bool: + """Check if a CSR matches a private key. + + This function only works with RSA keys. + + Args: + key (PrivateKey): Private key + Returns: + bool: True/False depending on whether the CSR matches the private key. + """ + try: + csr_object = x509.load_pem_x509_csr(self.raw.encode("utf-8")) + key_object = serialization.load_pem_private_key( + data=key.raw.encode("utf-8"), password=None + ) + key_object_public_key = key_object.public_key() + csr_object_public_key = csr_object.public_key() + if not isinstance(key_object_public_key, rsa.RSAPublicKey): + logger.warning("Key is not an RSA key") + return False + if not isinstance(csr_object_public_key, rsa.RSAPublicKey): + logger.warning("CSR is not an RSA key") + return False + if ( + csr_object_public_key.public_numbers().n + != key_object_public_key.public_numbers().n + ): + logger.warning("Public key numbers between CSR and key do not match") + return False + except ValueError: + logger.warning("Could not load certificate or CSR.") + return False + return True + + def matches_certificate(self, certificate: Certificate) -> bool: + """Check if a CSR matches a certificate. + + Args: + certificate (Certificate): Certificate + Returns: + bool: True/False depending on whether the CSR matches the certificate. + """ + csr_object = x509.load_pem_x509_csr(self.raw.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(certificate.raw.encode("utf-8")) + return csr_object.public_key() == cert_object.public_key() + + def get_sha256_hex(self) -> str: + """Calculate the hash of the provided data and return the hexadecimal representation.""" + digest = hashes.Hash(hashes.SHA256()) + digest.update(self.raw.encode()) + return digest.finalize().hex() + + +@dataclass(frozen=True) +class CertificateRequest: + """This class represents a certificate request. + + This class should be used inside the requirer charm to specify the requested + attributes for the certificate. + """ + + common_name: str + sans_dns: Optional[Tuple[str, ...]] = None + sans_ip: Optional[Tuple[str, ...]] = None + sans_oid: Optional[Tuple[str, ...]] = None + email_address: Optional[str] = None + organization: Optional[str] = None + organizational_unit: Optional[str] = None + country_name: Optional[str] = None + state_or_province_name: Optional[str] = None + locality_name: Optional[str] = None + is_ca: bool = False + + def is_valid(self) -> bool: + """Check whether the certificate request is valid.""" + if not self.common_name: + return False + return True + + def generate_csr( # noqa: C901 + self, + private_key: PrivateKey, + add_unique_id_to_subject_name: bool = True, + ) -> CertificateSigningRequest: + """Generate a CSR using private key and subject. + + Args: + private_key (PrivateKey): Private key + add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's + subject name. Always leave to "True" when the CSR is used to request certificates + using the tls-certificates relation. + + Returns: + CertificateSigningRequest: CSR + """ + signing_key = serialization.load_pem_private_key(str(private_key).encode(), password=None) + subject_name = [x509.NameAttribute(x509.NameOID.COMMON_NAME, self.common_name)] + if add_unique_id_to_subject_name: + unique_identifier = uuid.uuid4() + subject_name.append( + x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(unique_identifier)) + ) + if self.organization: + subject_name.append( + x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, self.organization) + ) + if self.email_address: + subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, self.email_address)) + if self.country_name: + subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, self.country_name)) + if self.state_or_province_name: + subject_name.append( + x509.NameAttribute( + x509.NameOID.STATE_OR_PROVINCE_NAME, self.state_or_province_name + ) + ) + if self.locality_name: + subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, self.locality_name)) + csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) + + _sans: List[x509.GeneralName] = [] + if self.sans_oid: + _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in self.sans_oid]) + if self.sans_ip: + _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in self.sans_ip]) + if self.sans_dns: + _sans.extend([x509.DNSName(san) for san in self.sans_dns]) + if _sans: + csr = csr.add_extension(x509.SubjectAlternativeName(set(_sans)), critical=False) + signed_certificate = csr.sign(signing_key, hashes.SHA256()) # type: ignore[arg-type] + csr_str = signed_certificate.public_bytes(serialization.Encoding.PEM).decode() + return CertificateSigningRequest.from_string(csr_str) + + +@dataclass(frozen=True) +class ProviderCertificate: + """This class represents a certificate provided by the TLS provider.""" + + relation_id: int + certificate: Certificate + certificate_signing_request: CertificateSigningRequest + ca: Certificate + chain: List[Certificate] + recommended_expiry_notification_time: Optional[int] = None + revoked: Optional[bool] = None + + def to_json(self) -> str: + """Return the object as a JSON string. + + Returns: + str: JSON representation of the object + """ + return json.dumps( + { + "csr": str(self.certificate_signing_request), + "certificate": str(self.certificate), + "ca": str(self.ca), + "chain": [str(cert) for cert in self.chain], + "revoked": self.revoked, + } + ) + + +@dataclass(frozen=True) +class RequirerCSR: + """This class represents a certificate signing request requested by the TLS requirer.""" + + relation_id: int + certificate_signing_request: CertificateSigningRequest + + +class CertificateAvailableEvent(EventBase): + """Charm Event triggered when a TLS certificate is available.""" + + def __init__( + self, + handle: Handle, + certificate: Certificate, + certificate_signing_request: CertificateSigningRequest, + ca: Certificate, + chain: List[Certificate], + ): + super().__init__(handle) + self.certificate = certificate + self.certificate_signing_request = certificate_signing_request + self.ca = ca + self.chain = chain + + def snapshot(self) -> dict: + """Return snapshot.""" + return { + "certificate": str(self.certificate), + "certificate_signing_request": str(self.certificate_signing_request), + "ca": str(self.ca), + "chain": json.dumps([str(certificate) for certificate in self.chain]), + } + + def restore(self, snapshot: dict): + """Restore snapshot.""" + self.certificate = Certificate.from_string(snapshot["certificate"]) + self.certificate_signing_request = CertificateSigningRequest.from_string( + snapshot["certificate_signing_request"] + ) + self.ca = Certificate.from_string(snapshot["ca"]) + chain_strs = json.loads(snapshot["chain"]) + self.chain = [Certificate.from_string(chain_str) for chain_str in chain_strs] + + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join([str(cert) for cert in self.chain]) + + +def _get_closest_future_time( + expiry_notification_time: datetime, expiry_time: datetime +) -> datetime: + """Return expiry_notification_time if not in the past, otherwise return expiry_time. + + Args: + expiry_notification_time (datetime): Notification time of impending expiration + expiry_time (datetime): Expiration time + + Returns: + datetime: expiry_notification_time if not in the past, expiry_time otherwise + """ + return ( + expiry_notification_time + if datetime.now(timezone.utc) < expiry_notification_time + else expiry_time + ) + + +def calculate_expiry_notification_time( + validity_start_time: datetime, + expiry_time: datetime, + provider_recommended_notification_time: Optional[int], +) -> datetime: + """Calculate a reasonable time to notify the user about the certificate expiry. + + It takes into account the time recommended by the provider. + Time recommended by the provider is preferred, + then dynamically calculated time. + + Args: + validity_start_time: Certificate validity time + expiry_time: Certificate expiry time + provider_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the provider. + + Returns: + datetime: Time to notify the user about the certificate expiry. + """ + if provider_recommended_notification_time is not None: + provider_recommended_notification_time = abs(provider_recommended_notification_time) + provider_recommendation_time_delta = expiry_time - timedelta( + hours=provider_recommended_notification_time + ) + if validity_start_time < provider_recommendation_time_delta: + return provider_recommendation_time_delta + calculated_hours = (expiry_time - validity_start_time).total_seconds() / (3600 * 3) + return expiry_time - timedelta(hours=calculated_hours) + + +def _generate_private_key( + key_size: int = 2048, + public_exponent: int = 65537, +) -> PrivateKey: + """Generate a private key with the RSA algorithm. + + Args: + key_size (int): Key size in bytes + public_exponent: Public exponent. + + Returns: + str: Private Key + """ + private_key = rsa.generate_private_key( + public_exponent=public_exponent, + key_size=key_size, + ) + key_bytes = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + return PrivateKey.from_string(key_bytes.decode()) + + +class CertificatesRequirerCharmEvents(CharmEvents): + """List of events that the TLS Certificates requirer charm can leverage.""" + + certificate_available = EventSource(CertificateAvailableEvent) + + +class TLSCertificatesRequiresV4(Object): + """A class to manage the TLS certificates interface for a unit or app.""" + + on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] + + def __init__( + self, + charm: CharmBase, + relationship_name: str, + certificate_requests: List[CertificateRequest], + mode: Mode = Mode.UNIT, + refresh_events: List[BoundEvent] = [], + ): + """Create a new instance of the TLSCertificatesRequiresV4 class. + + Args: + charm (CharmBase): The charm instance to relate to. + relationship_name (str): The name of the relation that provides the certificates. + certificate_requests (List[CertificateRequest]): A list of certificate requests. + mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT. + refresh_events (List[BoundEvent]): A list of events to trigger a refresh of + the certificates. + """ + super().__init__(charm, relationship_name) + if not JujuVersion.from_environ().has_secrets: + logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") + if not self._mode_is_valid(mode): + raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP") + for certificate_request in certificate_requests: + if not certificate_request.is_valid(): + raise TLSCertificatesError("Invalid certificate request") + self.charm = charm + self.relationship_name = relationship_name + self.certificate_requests = certificate_requests + self.mode = mode + self.framework.observe(charm.on[relationship_name].relation_created, self._configure) + self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) + for event in refresh_events: + self.framework.observe(event, self._configure) + + def _configure(self, _: EventBase): + """Handle TLS Certificates Relation Data. + + This method is called during any TLS relation event. + It will generate a private key if it doesn't exist yet. + It will send certificate requests if they haven't been sent yet. + It will find available certificates and emit events. + """ + if not self._tls_relation_created(): + logger.debug("TLS relation not created yet.") + return + self._generate_private_key() + self._send_certificate_requests() + self._find_available_certificates() + self._cleanup_certificate_requests() + + def _mode_is_valid(self, mode) -> bool: + return mode in [Mode.UNIT, Mode.APP] + + def _on_secret_expired(self, event: SecretExpiredEvent) -> None: + """Handle Secret Expired Event. + + Renews certificate requests and removes the expired secret. + """ + if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-certificate"): + return + try: + csr_str = event.secret.get_content(refresh=True)["csr"] + except ModelError: + logger.error("Failed to get CSR from secret - Skipping renewal") + return + csr = CertificateSigningRequest.from_string(csr_str) + self._renew_certificate_request(csr) + event.secret.remove_all_revisions() + + def _renew_certificate_request(self, csr: CertificateSigningRequest): + """Remove existing CSR from relation data and create a new one.""" + self._remove_requirer_csr_from_relation_data(csr) + self._send_certificate_requests() + logger.info("Renewed certificate request") + + def _remove_requirer_csr_from_relation_data(self, csr: CertificateSigningRequest) -> None: + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return + if not self.get_csrs_from_requirer_relation_data(): + logger.info("No CSRs in relation data - Doing nothing") + return + app_or_unit = self._get_app_or_unit() + try: + requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) + except DataValidationError: + logger.warning("Invalid relation data - Skipping removal of CSR") + return + new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests) + for requirer_csr in new_relation_data: + if requirer_csr.certificate_signing_request.strip() == str(csr).strip(): + new_relation_data.remove(requirer_csr) + try: + _RequirerData(certificate_signing_requests=new_relation_data).dump( + relation.data[app_or_unit] + ) + logger.info("Removed CSR from relation data") + except ModelError: + logger.warning("Failed to update relation data") + + def _get_app_or_unit(self) -> Union[Application, Unit]: + """Return the unit or app object based on the mode.""" + if self.mode == Mode.UNIT: + return self.model.unit + elif self.mode == Mode.APP: + return self.model.app + raise TLSCertificatesError("Invalid mode") + + @property + def private_key(self) -> PrivateKey | None: + """Return the private key.""" + if not self._private_key_generated(): + return None + secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) + private_key = secret.get_content(refresh=True)["private-key"] + return PrivateKey.from_string(private_key) + + def _generate_private_key(self) -> None: + if self._private_key_generated(): + return + private_key = _generate_private_key() + self.charm.unit.add_secret( + content={"private-key": str(private_key)}, + label=self._get_private_key_secret_label(), + ) + logger.info("Private key generated") + + def regenerate_private_key(self) -> None: + """Regenerate the private key. + + Generate a new private key, remove old certificate requests and send new ones. + """ + if not self._private_key_generated(): + logger.warning("No private key to regenerate") + return + self._regenerate_private_key() + self._cleanup_certificate_requests() + self._send_certificate_requests() + + def _regenerate_private_key(self) -> None: + secret = self.charm.model.get_secret(label=self._get_private_key_secret_label()) + secret.set_content({"private-key": str(_generate_private_key())}) + + def _private_key_generated(self) -> bool: + try: + self.charm.model.get_secret(label=self._get_private_key_secret_label()) + except (SecretNotFoundError, KeyError): + return False + return True + + def _csr_matches_certificate_request(self, csr: CertificateSigningRequest) -> bool: + for certificate_request in self.certificate_requests: + if csr.to_certificate_request() == certificate_request: + return True + return False + + def _certificate_requested(self, certificate_request: CertificateRequest) -> bool: + if not self.private_key: + return False + csr = self._certificate_requested_for_attributes(certificate_request) + if not csr: + return False + if not csr.matches_private_key(key=self.private_key): + return False + return True + + def _certificate_requested_for_attributes( + self, certificate_request: CertificateRequest + ) -> Optional[CertificateSigningRequest]: + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if requirer_csr.to_certificate_request() == certificate_request: + return requirer_csr + return None + + def get_csrs_from_requirer_relation_data(self) -> List[CertificateSigningRequest]: + """Return list of requirer's CSRs from relation data.""" + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return [] + app_or_unit = self._get_app_or_unit() + try: + requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) + except DataValidationError: + logger.warning("Invalid relation data") + return [] + return [ + CertificateSigningRequest.from_string(csr.certificate_signing_request) + for csr in requirer_relation_data.certificate_signing_requests + ] + + def get_provider_certificates(self) -> List[ProviderCertificate]: + """Return list of certificates from the provider's relation data.""" + return self._load_provider_certificates() + + def _load_provider_certificates(self) -> List[ProviderCertificate]: + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return [] + if not relation.app: + logger.debug("No remote app in relation: %s", self.relationship_name) + return [] + try: + provider_relation_data = _ProviderApplicationData.load(relation.data[relation.app]) + except DataValidationError: + logger.warning("Invalid relation data") + return [] + return [ + certificate.to_provider_certificate(relation_id=relation.id) + for certificate in provider_relation_data.certificates + ] + + def _request_certificate(self, csr: CertificateSigningRequest, is_ca: bool) -> None: + """Add CSR to relation data.""" + relation = self.model.get_relation(self.relationship_name) + if not relation: + logger.debug("No relation: %s", self.relationship_name) + return + new_csr = _CertificateSigningRequest( + certificate_signing_request=str(csr).strip(), ca=is_ca + ) + app_or_unit = self._get_app_or_unit() + try: + requirer_relation_data = _RequirerData.load(relation.data[app_or_unit]) + except DataValidationError: + requirer_relation_data = _RequirerData( + certificate_signing_requests=[], + ) + new_relation_data = copy.deepcopy(requirer_relation_data.certificate_signing_requests) + new_relation_data.append(new_csr) + try: + _RequirerData(certificate_signing_requests=new_relation_data).dump( + relation.data[app_or_unit] + ) + logger.info("Certificate signing request added to relation data.") + except ModelError: + logger.warning("Failed to update relation data") + + def _send_certificate_requests(self): + if not self.private_key: + logger.debug("Private key not generated yet.") + return + for certificate_request in self.certificate_requests: + if not self._certificate_requested(certificate_request): + csr = certificate_request.generate_csr( + private_key=self.private_key, + ) + if not csr: + logger.warning("Failed to generate CSR") + continue + self._request_certificate(csr=csr, is_ca=certificate_request.is_ca) + + def get_assigned_certificate( + self, certificate_request: CertificateRequest + ) -> Tuple[ProviderCertificate | None, PrivateKey | None]: + """Get the certificate that was assigned to the given certificate request.""" + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if certificate_request == requirer_csr.to_certificate_request(): + return self._find_certificate_in_relation_data(requirer_csr), self.private_key + return None, None + + def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateKey | None]: + """Get a list of certificates that were assigned to this or app.""" + assigned_certificates = [] + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if cert := self._find_certificate_in_relation_data(requirer_csr): + assigned_certificates.append(cert) + return assigned_certificates, self.private_key + + def _find_certificate_in_relation_data( + self, csr: CertificateSigningRequest + ) -> Optional[ProviderCertificate]: + """Return the certificate that match the given CSR.""" + for provider_certificate in self.get_provider_certificates(): + if provider_certificate.certificate_signing_request == csr: + return provider_certificate + return None + + def _find_available_certificates(self): + """Find available certificates and emit events. + + This method will find certificates that are available for the requirer's CSRs. + If a certificate is found, it will be set as a secret and an event will be emitted. + If a certificate is revoked, the secret will be removed and an event will be emitted. + """ + requirer_csrs = self.get_csrs_from_requirer_relation_data() + provider_certificates = self.get_provider_certificates() + for provider_certificate in provider_certificates: + if provider_certificate.certificate_signing_request in requirer_csrs: + secret_label = self._get_csr_secret_label( + provider_certificate.certificate_signing_request + ) + if provider_certificate.revoked: + with suppress(SecretNotFoundError): + logger.debug( + "Removing secret with label %s", + secret_label, + ) + secret = self.model.get_secret(label=secret_label) + secret.remove_all_revisions() + else: + if not self._csr_matches_certificate_request( + provider_certificate.certificate_signing_request + ): + logger.debug("Certificate requested for different attributes - Skipping") + continue + try: + logger.debug("Setting secret with label %s", secret_label) + secret = self.model.get_secret(label=secret_label) + secret.set_content( + content={ + "certificate": str(provider_certificate.certificate), + "csr": str(provider_certificate.certificate_signing_request), + } + ) + secret.set_info( + expire=self._get_next_secret_expiry_time(provider_certificate), + ) + except SecretNotFoundError: + logger.debug("Creating new secret with label %s", secret_label) + secret = self.charm.unit.add_secret( + content={ + "certificate": str(provider_certificate.certificate), + "csr": str(provider_certificate.certificate_signing_request), + }, + label=secret_label, + expire=self._get_next_secret_expiry_time(provider_certificate), + ) + self.on.certificate_available.emit( + certificate_signing_request=provider_certificate.certificate_signing_request, + certificate=provider_certificate.certificate, + ca=provider_certificate.ca, + chain=provider_certificate.chain, + ) + + def _cleanup_certificate_requests(self): + """Clean up certificate requests. + + Remove any certificate requests that falls into one of the following categories: + - The CSR attributes do not match any of the certificate requests defined in + the charm's certificate_requests attribute. + - The CSR public key does not match the private key. + """ + for requirer_csr in self.get_csrs_from_requirer_relation_data(): + if not self._csr_matches_certificate_request(requirer_csr): + self._remove_requirer_csr_from_relation_data(requirer_csr) + logger.info( + "Removed CSR from relation data because \ + it did not match any certificate request" + ) + elif self.private_key and not requirer_csr.matches_private_key(self.private_key): + self._remove_requirer_csr_from_relation_data(requirer_csr) + logger.info( + "Removed CSR from relation data because \ + it did not match the private key" + ) + + def _get_next_secret_expiry_time( + self, provider_certificate: ProviderCertificate + ) -> Optional[datetime]: + """Return the expiry time or expiry notification time. + + Extracts the expiry time from the provided certificate, calculates the + expiry notification time and return the closest of the two, that is in + the future. + + Args: + provider_certificate: ProviderCertificate object + + Returns: + Optional[datetime]: None if the certificate expiry time cannot be read, + next expiry time otherwise. + """ + if not provider_certificate.certificate.expiry_time: + logger.warning("Certificate has no expiry time") + return None + if not provider_certificate.certificate.validity_start_time: + logger.warning("Certificate has no validity start time") + return None + expiry_notification_time = calculate_expiry_notification_time( + validity_start_time=provider_certificate.certificate.validity_start_time, + expiry_time=provider_certificate.certificate.expiry_time, + provider_recommended_notification_time=provider_certificate.recommended_expiry_notification_time, + ) + if not expiry_notification_time: + logger.warning("Could not calculate expiry notification time") + return None + return _get_closest_future_time( + expiry_notification_time, + provider_certificate.certificate.expiry_time, + ) + + def _tls_relation_created(self) -> bool: + relation = self.model.get_relation(self.relationship_name) + if not relation: + return False + return True + + def _get_private_key_secret_label(self) -> str: + if self.mode == Mode.UNIT: + return f"{LIBID}-private-key-{self._get_unit_number()}" + elif self.mode == Mode.APP: + return f"{LIBID}-private-key" + else: + raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.") + + def _get_csr_secret_label(self, csr: CertificateSigningRequest) -> str: + csr_in_sha256_hex = csr.get_sha256_hex() + if self.mode == Mode.UNIT: + return f"{LIBID}-certificate-{self._get_unit_number()}-{csr_in_sha256_hex}" + elif self.mode == Mode.APP: + return f"{LIBID}-certificate-{csr_in_sha256_hex}" + else: + raise TLSCertificatesError("Invalid mode. Must be Mode.UNIT or Mode.APP.") + + def _get_unit_number(self) -> str: + return self.model.unit.name.split("/")[1] + + +class TLSCertificatesProvidesV4(Object): + """TLS certificates provider class to be instantiated by TLS certificates providers.""" + + def __init__(self, charm: CharmBase, relationship_name: str): + super().__init__(charm, relationship_name) + self.framework.observe(charm.on[relationship_name].relation_joined, self._configure) + self.framework.observe(charm.on[relationship_name].relation_changed, self._configure) + self.framework.observe(charm.on.update_status, self._configure) + self.charm = charm + self.relationship_name = relationship_name + + def _configure(self, _: EventBase) -> None: + """Handle update status and tls relation changed events. + + This is a common hook triggered on a regular basis. + + Revoke certificates for which no csr exists + """ + if not self.model.unit.is_leader(): + return + self._remove_certificates_for_which_no_csr_exists() + + def _remove_certificates_for_which_no_csr_exists(self) -> None: + provider_certificates = self._get_provider_certificates() + requirer_csrs = [ + request.certificate_signing_request for request in self.get_certificate_requests() + ] + for provider_certificate in provider_certificates: + if provider_certificate.certificate_signing_request not in requirer_csrs: + tls_relation = self._get_tls_relations( + relation_id=provider_certificate.relation_id + ) + self._remove_provider_certificate( + certificate=provider_certificate.certificate, + relation=tls_relation[0], + ) + + def _get_tls_relations(self, relation_id: Optional[int] = None) -> List[Relation]: + return ( + [ + relation + for relation in self.model.relations[self.relationship_name] + if relation.id == relation_id + ] + if relation_id is not None + else self.model.relations.get(self.relationship_name, []) + ) + + def get_certificate_requests(self, relation_id: Optional[int] = None) -> List[RequirerCSR]: + """Load certificate requests from the relation data.""" + relations = self._get_tls_relations(relation_id) + requirer_csrs: List[RequirerCSR] = [] + for relation in relations: + for unit in relation.units: + requirer_csrs.extend(self._load_requirer_databag(relation, unit)) + requirer_csrs.extend(self._load_requirer_databag(relation, relation.app)) + return requirer_csrs + + def _load_requirer_databag( + self, relation: Relation, unit_or_app: Union[Application, Unit] + ) -> List[RequirerCSR]: + try: + requirer_relation_data = _RequirerData.load(relation.data[unit_or_app]) + except DataValidationError: + logger.debug("Invalid requirer relation data for %s", unit_or_app.name) + return [] + return [ + RequirerCSR( + relation_id=relation.id, + certificate_signing_request=CertificateSigningRequest.from_string( + csr.certificate_signing_request + ), + ) + for csr in requirer_relation_data.certificate_signing_requests + ] + + def _add_provider_certificate( + self, + relation: Relation, + provider_certificate: ProviderCertificate, + ) -> None: + new_certificate = _Certificate( + certificate=str(provider_certificate.certificate), + certificate_signing_request=str(provider_certificate.certificate_signing_request), + ca=str(provider_certificate.ca), + chain=[str(certificate) for certificate in provider_certificate.chain], + recommended_expiry_notification_time=provider_certificate.recommended_expiry_notification_time, + ) + provider_certificates = self._load_provider_certificates(relation) + if new_certificate in provider_certificates: + logger.info("Certificate already in relation data - Doing nothing") + return + provider_certificates.append(new_certificate) + self._dump_provider_certificates(relation=relation, certificates=provider_certificates) + + def _load_provider_certificates(self, relation: Relation) -> List[_Certificate]: + try: + provider_relation_data = _ProviderApplicationData.load(relation.data[self.charm.app]) + except DataValidationError: + logger.debug("Invalid provider relation data") + return [] + return copy.deepcopy(provider_relation_data.certificates) + + def _dump_provider_certificates(self, relation: Relation, certificates: List[_Certificate]): + try: + _ProviderApplicationData(certificates=certificates).dump(relation.data[self.model.app]) + logger.info("Certificate relation data updated") + except ModelError: + logger.warning("Failed to update relation data") + + def _remove_provider_certificate( + self, + relation: Relation, + certificate: Optional[Certificate] = None, + certificate_signing_request: Optional[CertificateSigningRequest] = None, + ) -> None: + """Remove certificate based on certificate or certificate signing request.""" + provider_certificates = self._load_provider_certificates(relation) + for provider_certificate in provider_certificates: + if certificate and provider_certificate.certificate == str(certificate): + provider_certificates.remove(provider_certificate) + if ( + certificate_signing_request + and provider_certificate.certificate_signing_request + == str(certificate_signing_request) + ): + provider_certificates.remove(provider_certificate) + self._dump_provider_certificates(relation=relation, certificates=provider_certificates) + + def revoke_all_certificates(self) -> None: + """Revoke all certificates of this provider. + + This method is meant to be used when the Root CA has changed. + """ + if not self.model.unit.is_leader(): + logger.warning("Unit is not a leader - will not set relation data") + return + relations = self._get_tls_relations() + for relation in relations: + provider_certificates = self._load_provider_certificates(relation) + for certificate in provider_certificates: + certificate.revoked = True + self._dump_provider_certificates(relation=relation, certificates=provider_certificates) + + def set_relation_certificate( + self, + provider_certificate: ProviderCertificate, + ) -> None: + """Add certificates to relation data. + + Args: + provider_certificate (ProviderCertificate): ProviderCertificate object + + Returns: + None + """ + if not self.model.unit.is_leader(): + logger.warning("Unit is not a leader - will not set relation data") + return + certificates_relation = self.model.get_relation( + relation_name=self.relationship_name, relation_id=provider_certificate.relation_id + ) + if not certificates_relation: + raise TLSCertificatesError(f"Relation {self.relationship_name} does not exist") + self._remove_provider_certificate( + relation=certificates_relation, + certificate_signing_request=provider_certificate.certificate_signing_request, + ) + self._add_provider_certificate( + relation=certificates_relation, + provider_certificate=provider_certificate, + ) + + def get_issued_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued (non revoked) certificates. + + Returns: + List: List of ProviderCertificate objects + """ + if not self.model.unit.is_leader(): + logger.warning("Unit is not a leader - will not read relation data") + return [] + provider_certificates = self._get_provider_certificates(relation_id=relation_id) + return [certificate for certificate in provider_certificates if not certificate.revoked] + + def _get_provider_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued certificates.""" + certificates: List[ProviderCertificate] = [] + relations = self._get_tls_relations(relation_id) + for relation in relations: + if not relation.app: + logger.warning("Relation %s does not have an application", relation.id) + continue + for certificate in self._load_provider_certificates(relation): + certificates.append(certificate.to_provider_certificate(relation_id=relation.id)) + return certificates + + def get_outstanding_certificate_requests( + self, relation_id: Optional[int] = None + ) -> List[RequirerCSR]: + """Return CSR's for which no certificate has been issued. + + Args: + relation_id (int): Relation id + + Returns: + list: List of RequirerCSR objects. + """ + requirer_csrs = self.get_certificate_requests(relation_id=relation_id) + outstanding_csrs: List[RequirerCSR] = [] + for relation_csr in requirer_csrs: + if not self._certificate_issued_for_csr( + csr=relation_csr.certificate_signing_request, + relation_id=relation_id, + ): + outstanding_csrs.append(relation_csr) + return outstanding_csrs + + def _certificate_issued_for_csr( + self, csr: CertificateSigningRequest, relation_id: Optional[int] + ) -> bool: + """Check whether a certificate has been issued for a given CSR.""" + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) + for issued_certificate in issued_certificates_per_csr: + if issued_certificate.certificate_signing_request == csr: + return csr.matches_certificate(issued_certificate.certificate) + return False diff --git a/src/certificates.py b/src/certificates.py new file mode 100644 index 0000000..80d0567 --- /dev/null +++ b/src/certificates.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Utilities for generating certificates.""" + +import logging +from datetime import datetime, timedelta, timezone +from typing import List + +from cryptography import x509 +from cryptography.hazmat._oid import ExtensionOID +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa + +logger = logging.getLogger(__name__) + + +def generate_private_key( + key_size: int = 2048, + public_exponent: int = 65537, +) -> str: + """Generate a private key. + + Args: + key_size (int): Key size in bytes + public_exponent: Public exponent. + + Returns: + str: Private Key + """ + private_key = rsa.generate_private_key( + public_exponent=public_exponent, + key_size=key_size, + ) + key_bytes = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + return key_bytes.decode().strip() + + +def get_certificate_request_extensions( + authority_key_identifier: bytes, + csr: x509.CertificateSigningRequest, + is_ca: bool, +) -> List[x509.Extension]: + """Generate a list of certificate extensions from a CSR and other known information. + + Args: + authority_key_identifier (bytes): Authority key identifier + csr (x509.CertificateSigningRequest): CSR + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + List[x509.Extension]: List of extensions + """ + cert_extensions_list: List[x509.Extension] = [ + x509.Extension( + oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, + value=x509.AuthorityKeyIdentifier( + key_identifier=authority_key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, + value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.BASIC_CONSTRAINTS, + critical=True, + value=x509.BasicConstraints(ca=is_ca, path_length=None), + ), + ] + sans: List[x509.GeneralName] = [] + try: + loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) + sans.extend( + [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] + ) + sans.extend( + [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] + ) + sans.extend( + [ + x509.RegisteredID(oid) + for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) + ] + ) + except x509.ExtensionNotFound: + pass + + if sans: + cert_extensions_list.append( + x509.Extension( + oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + critical=False, + value=x509.SubjectAlternativeName(sans), + ) + ) + + if is_ca: + cert_extensions_list.append( + x509.Extension( + ExtensionOID.KEY_USAGE, + critical=True, + value=x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + ) + ) + + existing_oids = {ext.oid for ext in cert_extensions_list} + for extension in csr.extensions: + if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + continue + if extension.oid in existing_oids: + logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) + continue + cert_extensions_list.append(extension) + + return cert_extensions_list + + +def generate_certificate( + csr: str, + ca: str, + ca_key: str, + validity: int = 365, + is_ca: bool = False, +) -> str: + """Generate a TLS certificate based on a CSR. + + Args: + csr (str): CSR + ca (str): CA Certificate + ca_key (str): CA private key + validity (int): Certificate validity (in days) + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + str: Certificate + """ + csr_object = x509.load_pem_x509_csr(csr.encode()) + subject = csr_object.subject + ca_pem = x509.load_pem_x509_certificate(ca.encode()) + issuer = ca_pem.issuer + private_key = serialization.load_pem_private_key(ca_key.encode(), password=None) + + certificate_builder = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(csr_object.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) + ) + extensions = get_certificate_request_extensions( + authority_key_identifier=ca_pem.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier + ).value.key_identifier, + csr=csr_object, + is_ca=is_ca, + ) + for extension in extensions: + try: + certificate_builder = certificate_builder.add_extension( + extval=extension.value, + critical=extension.critical, + ) + except ValueError as e: + logger.warning("Failed to add extension %s: %s", extension.oid, e) + + cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] + return cert.public_bytes(serialization.Encoding.PEM).decode().strip() + + +def generate_ca( + private_key: str, + subject: str, + validity: int = 365, + country: str = "US", +) -> str: + """Generate a CA Certificate. + + Args: + private_key (bytes): Private key + subject (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN). + validity (int): Certificate validity time (in days) + country (str): Certificate Issuing country + + Returns: + str: CA Certificate. + """ + private_key_object = serialization.load_pem_private_key( + private_key.encode(), + password=None, + ) + subject_name = x509.Name( + [ + x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country), + x509.NameAttribute(x509.NameOID.COMMON_NAME, subject), + ] + ) + subject_identifier_object = x509.SubjectKeyIdentifier.from_public_key( + private_key_object.public_key() # type: ignore[arg-type] + ) + subject_identifier = key_identifier = subject_identifier_object.public_bytes() + key_usage = x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + key_cert_sign=True, + key_agreement=False, + content_commitment=False, + data_encipherment=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ) + cert = ( + x509.CertificateBuilder() + .subject_name(subject_name) + .issuer_name(subject_name) + .public_key(private_key_object.public_key()) # type: ignore[arg-type] + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) + .add_extension(x509.SubjectKeyIdentifier(digest=subject_identifier), critical=False) + .add_extension( + x509.AuthorityKeyIdentifier( + key_identifier=key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ) + .add_extension(key_usage, critical=True) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(private_key_object, hashes.SHA256()) # type: ignore[arg-type] + ) + return cert.public_bytes(serialization.Encoding.PEM).decode().strip() diff --git a/src/charm.py b/src/charm.py index 7f5abd6..4deca67 100755 --- a/src/charm.py +++ b/src/charm.py @@ -6,20 +6,19 @@ import datetime import logging -import secrets from typing import Optional, cast +from certificates import generate_ca, generate_certificate, generate_private_key from charms.certificate_transfer_interface.v0.certificate_transfer import ( CertificateTransferProvides, ) from charms.tempo_k8s.v1.charm_tracing import trace_charm from charms.tempo_k8s.v2.tracing import TracingEndpointRequirer -from charms.tls_certificates_interface.v3.tls_certificates import ( - CertificateCreationRequestEvent, - TLSCertificatesProvidesV3, - generate_ca, - generate_certificate, - generate_private_key, +from charms.tls_certificates_interface.v4.tls_certificates import ( + Certificate, + CertificateSigningRequest, + ProviderCertificate, + TLSCertificatesProvidesV4, ) from cryptography import x509 from ops.charm import ActionEvent, CharmBase, CollectStatusEvent, RelationJoinedEvent @@ -46,7 +45,7 @@ def certificate_has_common_name(certificate: bytes, common_name: str) -> bool: @trace_charm( tracing_endpoint="tempo_otlp_http_endpoint", - extra_types=(TLSCertificatesProvidesV3,), + extra_types=(TLSCertificatesProvidesV4,), ) class SelfSignedCertificatesCharm(CharmBase): """Main class to handle Juju events.""" @@ -54,17 +53,14 @@ class SelfSignedCertificatesCharm(CharmBase): def __init__(self, *args): """Observe config change and certificate request events.""" super().__init__(*args) - self.tls_certificates = TLSCertificatesProvidesV3(self, "certificates") + self.tls_certificates = TLSCertificatesProvidesV4(self, "certificates") self.tracing = TracingEndpointRequirer(self, protocols=["otlp_http"]) self.framework.observe(self.on.collect_unit_status, self._on_collect_unit_status) self.framework.observe(self.on.update_status, self._configure) self.framework.observe(self.on.config_changed, self._configure) self.framework.observe(self.on.secret_expired, self._configure) self.framework.observe(self.on.secret_changed, self._configure) - self.framework.observe( - self.tls_certificates.on.certificate_creation_request, - self._on_certificate_creation_request, - ) + self.framework.observe(self.on.certificates_relation_changed, self._configure) self.framework.observe(self.on.get_ca_certificate_action, self._on_get_ca_certificate) self.framework.observe( self.on.get_issued_certificates_action, self._on_get_issued_certificates @@ -158,17 +154,14 @@ def _generate_root_certificate(self) -> None: """ if not self._config_ca_common_name: raise ValueError("CA common name should not be empty") - private_key_password = generate_password() - private_key = generate_private_key(password=private_key_password.encode()) + private_key = generate_private_key() ca_certificate = generate_ca( private_key=private_key, subject=self._config_ca_common_name, - private_key_password=private_key_password.encode(), ) secret_content = { - "private-key-password": private_key_password, - "private-key": private_key.decode(), - "ca-certificate": ca_certificate.decode(), + "private-key": private_key, + "ca-certificate": ca_certificate, } if self._root_certificate_is_stored: secret = self.model.get_secret(label=CA_CERTIFICATES_SECRET_LABEL) @@ -214,8 +207,8 @@ def _process_outstanding_certificate_requests(self) -> None: """Process outstanding certificate requests.""" for request in self.tls_certificates.get_outstanding_certificate_requests(): self._generate_self_signed_certificate( - csr=request.csr, - is_ca=request.is_ca, + csr=str(request.certificate_signing_request), + is_ca=request.certificate_signing_request.is_ca, relation_id=request.relation_id, ) @@ -234,28 +227,6 @@ def _invalid_configs(self) -> list[str]: invalid_configs.append("certificate-validity") return invalid_configs - def _on_certificate_creation_request(self, event: CertificateCreationRequestEvent) -> None: - """Handle certificate requests. - - Args: - event (CertificateCreationRequestEvent): Juju event - """ - if not self.unit.is_leader(): - return - if self._invalid_configs(): - logger.warning("Invalid configuration. Certificate cannot be generated.") - return - if not self._root_certificate_is_stored: - logger.warning( - "Root certificate is not yet generated. Certificate cannot be generated." - ) - return - self._generate_self_signed_certificate( - csr=event.certificate_signing_request, - is_ca=event.is_ca, - relation_id=event.relation_id, - ) - def _generate_self_signed_certificate(self, csr: str, is_ca: bool, relation_id: int) -> None: """Generate self-signed certificate. @@ -267,19 +238,23 @@ def _generate_self_signed_certificate(self, csr: str, is_ca: bool, relation_id: ca_certificate_secret = self.model.get_secret(label=CA_CERTIFICATES_SECRET_LABEL) ca_certificate_secret_content = ca_certificate_secret.get_content(refresh=True) certificate = generate_certificate( - ca=ca_certificate_secret_content["ca-certificate"].encode(), - ca_key=ca_certificate_secret_content["private-key"].encode(), - ca_key_password=ca_certificate_secret_content["private-key-password"].encode(), - csr=csr.encode(), + ca=ca_certificate_secret_content["ca-certificate"], + ca_key=ca_certificate_secret_content["private-key"], + csr=csr, validity=self._config_certificate_validity, is_ca=is_ca, - ).decode() + ) self.tls_certificates.set_relation_certificate( - certificate_signing_request=csr, - certificate=certificate, - ca=ca_certificate_secret_content["ca-certificate"], - chain=[ca_certificate_secret_content["ca-certificate"], certificate], - relation_id=relation_id, + provider_certificate=ProviderCertificate( + relation_id=relation_id, + certificate=Certificate.from_string(certificate), + certificate_signing_request=CertificateSigningRequest.from_string(csr), + ca=Certificate.from_string(ca_certificate_secret_content["ca-certificate"]), + chain=[ + Certificate.from_string(ca_certificate_secret_content["ca-certificate"]), + Certificate.from_string(certificate), + ], + ), ) logger.info("Generated certificate for relation %s", relation_id) @@ -328,14 +303,5 @@ def tempo_otlp_http_endpoint(self) -> Optional[str]: return None -def generate_password() -> str: - """Generate a random string containing 64 bytes. - - Returns: - str: Password - """ - return secrets.token_hex(64) - - if __name__ == "__main__": main(SelfSignedCertificatesCharm) diff --git a/tests/integration/certificates.py b/tests/integration/certificate.py similarity index 100% rename from tests/integration/certificates.py rename to tests/integration/certificate.py diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index c43d2ae..043e1a4 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -11,7 +11,7 @@ import pytest import yaml -from certificates import get_common_name_from_certificate +from certificate import get_common_name_from_certificate from pytest_operator.plugin import OpsTest logger = logging.getLogger(__name__) diff --git a/tests/unit/certificates_helpers.py b/tests/unit/certificates_helpers.py new file mode 100644 index 0000000..21860b0 --- /dev/null +++ b/tests/unit/certificates_helpers.py @@ -0,0 +1,169 @@ +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +import uuid +from datetime import datetime, timedelta, timezone +from typing import List, Optional + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa + + +def generate_private_key( + key_size: int = 2048, + public_exponent: int = 65537, +) -> str: + """Generate a private key. + + Args: + key_size (int): Key size in bytes + public_exponent: Public exponent. + + Returns: + str: Private Key + """ + private_key = rsa.generate_private_key( + public_exponent=public_exponent, + key_size=key_size, + ) + key_bytes = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + return key_bytes.decode().strip() + + +def generate_csr( + private_key: str, + common_name: str, + sans_dns: Optional[List[str]] = None, +) -> str: + """Generate a CSR using private key and subject. + + Args: + private_key (str): Private key + common_name (str): CSR common name. + sans_dns (list): List of subject alternative names + + Returns: + str: CSR + """ + signing_key = serialization.load_pem_private_key(private_key.encode(), password=None) + unique_identifier = uuid.uuid4() + subject_name = [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)] + subject_name.append( + x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(unique_identifier)) + ) + csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) + + _sans: List[x509.GeneralName] = [] + if sans_dns: + _sans.extend([x509.DNSName(san) for san in sans_dns]) + if _sans: + csr = csr.add_extension(x509.SubjectAlternativeName(set(_sans)), critical=False) + signed_certificate = csr.sign(signing_key, hashes.SHA256()) # type: ignore[arg-type] + return signed_certificate.public_bytes(serialization.Encoding.PEM).decode().strip() + + +def generate_certificate( + csr: str, + ca: str, + ca_key: str, + validity: int = 24 * 365, +) -> str: + """Generate a TLS certificate based on a CSR. + + Args: + csr (str): CSR + ca (str): CA Certificate + ca_key (str): CA private key + validity (int): Certificate validity (in hours) + + Returns: + str: Certificate + """ + csr_object = x509.load_pem_x509_csr(csr.encode()) + csr_subject = csr_object.subject + csr_common_name = csr_subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value + issuer = x509.load_pem_x509_certificate(ca.encode()).issuer + private_key = serialization.load_pem_private_key(ca_key.encode(), password=None) + subject = x509.Name( + [ + x509.NameAttribute(x509.NameOID.COMMON_NAME, csr_common_name), + ] + ) + + if validity > 0: + not_valid_before = datetime.now(timezone.utc) + not_valid_after = datetime.now(timezone.utc) + timedelta(hours=validity) + else: + not_valid_before = datetime.now(timezone.utc) + timedelta(hours=validity) + not_valid_after = datetime.now(timezone.utc) - timedelta(seconds=1) + certificate_builder = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(csr_object.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(not_valid_before) + .not_valid_after(not_valid_after) + ) + + certificate_builder._version = x509.Version.v3 + cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] + return cert.public_bytes(serialization.Encoding.PEM).decode().strip() + + +def generate_ca( + private_key: str, + common_name: str, + validity: int = 365, +) -> str: + """Generate a CA Certificate. + + Args: + private_key (bytes): Private key + common_name (str): Certificate common name. + validity (int): Certificate validity time (in days) + country (str): Certificate Issuing country + + Returns: + str: CA Certificate + """ + private_key_object = serialization.load_pem_private_key(private_key.encode(), password=None) + subject = issuer = x509.Name( + [ + x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name), + ] + ) + subject_identifier_object = x509.SubjectKeyIdentifier.from_public_key( + private_key_object.public_key() # type: ignore[arg-type] + ) + subject_identifier = key_identifier = subject_identifier_object.public_bytes() + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key_object.public_key()) # type: ignore[arg-type] + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) + .add_extension(x509.SubjectKeyIdentifier(digest=subject_identifier), critical=False) + .add_extension( + x509.AuthorityKeyIdentifier( + key_identifier=key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), + critical=True, + ) + .sign(private_key_object, hashes.SHA256()) # type: ignore[arg-type] + ) + return cert.public_bytes(serialization.Encoding.PEM).decode().strip() diff --git a/tests/unit/test_certificates.py b/tests/unit/test_certificates.py new file mode 100644 index 0000000..6266ade --- /dev/null +++ b/tests/unit/test_certificates.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + + +from charm import ( + generate_ca, + generate_certificate, + generate_private_key, +) +from cryptography import x509 +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.hazmat.primitives.serialization import load_pem_private_key + +from tests.unit.certificates_helpers import ( + generate_ca as generate_ca_helper, +) +from tests.unit.certificates_helpers import ( + generate_csr as generate_csr_helper, +) +from tests.unit.certificates_helpers import ( + generate_private_key as generate_private_key_helper, +) + + +def test_given_no_password_when_generate_private_key_then_key_is_generated_and_loadable(): + private_key = generate_private_key() + + load_pem_private_key(data=private_key.encode(), password=None) + + +def test_given_key_size_provided_when_generate_private_key_then_private_key_is_generated(): + key_size = 1234 + + private_key = generate_private_key(key_size=key_size) + + private_key_object = load_pem_private_key(private_key.encode(), password=None) + assert isinstance(private_key_object, rsa.RSAPrivateKeyWithSerialization) + assert private_key_object.key_size == key_size + + +def test_given_private_key_and_subject_when_generate_ca_then_ca_is_generated_correctly(): + subject = "certifier.example.com" + private_key = generate_private_key_helper() + + certifier_pem = generate_ca(private_key=private_key, subject=subject) + + cert = x509.load_pem_x509_certificate(certifier_pem.encode()) + private_key_object = load_pem_private_key(private_key.encode(), password=None) + certificate_public_key = cert.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.PKCS1, + ) + initial_public_key = private_key_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.PKCS1, + ) + + assert cert.issuer == x509.Name( + [ + x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(x509.NameOID.COMMON_NAME, subject), + ] + ) + assert cert.subject == x509.Name( + [ + x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(x509.NameOID.COMMON_NAME, subject), + ] + ) + assert certificate_public_key == initial_public_key + assert ( + x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + key_cert_sign=True, + key_agreement=False, + content_commitment=False, + data_encipherment=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ) + == cert.extensions.get_extension_for_class(x509.KeyUsage).value + ) + assert cert.extensions.get_extension_for_class(x509.KeyUsage).critical + + +def test_given_csr_and_ca_when_generate_certificate_then_certificate_is_generated_with_correct_subject_and_issuer(): # noqa: E501 + ca_subject = "whatever.ca.subject" + csr_subject = "whatever.csr.subject" + ca_key = generate_private_key_helper() + ca = generate_ca_helper(private_key=ca_key, common_name=ca_subject) + csr_private_key = generate_private_key_helper() + csr = generate_csr_helper( + private_key=csr_private_key, + common_name=csr_subject, + ) + + certificate = generate_certificate( + csr=csr, + ca=ca, + ca_key=ca_key, + ) + + certificate_object = x509.load_pem_x509_certificate(certificate.encode()) + assert certificate_object.issuer == x509.Name( + [ + x509.NameAttribute(x509.NameOID.COMMON_NAME, ca_subject), + ] + ) + subject_name_attributes = certificate_object.subject.get_attributes_for_oid( + x509.NameOID.COMMON_NAME + ) + assert subject_name_attributes[0] == x509.NameAttribute(x509.NameOID.COMMON_NAME, csr_subject) + + +def test_given_csr_and_ca_when_generate_certificate_then_certificate_is_generated_with_correct_sans(): # noqa: E501 + ca_subject = "ca.subject" + csr_subject = "csr.subject" + sans = ["www.localhost.com", "www.test.com"] + + ca_key = generate_private_key_helper() + ca = generate_ca_helper( + private_key=ca_key, + common_name=ca_subject, + ) + csr_private_key = generate_private_key_helper() + csr = generate_csr_helper( + private_key=csr_private_key, + common_name=csr_subject, + sans_dns=sans, + ) + + certificate = generate_certificate(csr=csr, ca=ca, ca_key=ca_key) + + cert = x509.load_pem_x509_certificate(certificate.encode()) + result_all_sans = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName) + + result_sans_dns = sorted(result_all_sans.value.get_values_for_type(x509.DNSName)) + assert result_sans_dns == sorted(set(sans)) + + +def test_given_private_key_when_generate_ca_then_basic_constraints_extension_is_correctly_populated(): # noqa: E501 + subject = "whatever.ca.subject" + private_key = generate_private_key_helper() + + ca = generate_ca( + private_key=private_key, + subject=subject, + ) + + certificate_object = x509.load_pem_x509_certificate(ca.encode()) + basic_constraints = certificate_object.extensions.get_extension_for_class( + x509.BasicConstraints + ) + assert basic_constraints.value.ca is True + + +def test_given_certificate_created_when_generate_certificate_then_verify_public_key_then_doesnt_throw_exception(): # noqa: E501 + ca_subject = "whatever.ca.subject" + csr_subject = "whatever.csr.subject" + ca_key = generate_private_key_helper() + ca = generate_ca_helper( + private_key=ca_key, + common_name=ca_subject, + ) + csr_private_key = generate_private_key_helper() + csr = generate_csr_helper( + private_key=csr_private_key, + common_name=csr_subject, + ) + + certificate = generate_certificate( + csr=csr, + ca=ca, + ca_key=ca_key, + ) + + certificate_object = x509.load_pem_x509_certificate(certificate.encode()) + private_key_object = load_pem_private_key(ca_key.encode(), password=None) + public_key = private_key_object.public_key() + + public_key.verify( # type: ignore[call-arg, union-attr] + certificate_object.signature, + certificate_object.tbs_certificate_bytes, + padding.PKCS1v15(), # type: ignore[arg-type] + certificate_object.signature_hash_algorithm, # type: ignore[arg-type] + ) + + +def test_given_request_is_for_ca_certificate_when_generate_certificate_then_certificate_is_generated(): # noqa: E501 + ca_private_key = generate_private_key_helper() + ca = generate_ca( + private_key=ca_private_key, + subject="my.demo.ca", + ) + server_private_key = generate_private_key_helper() + + server_csr = generate_csr_helper( + private_key=server_private_key, + common_name="10.10.10.10", + sans_dns=[], + ) + + server_cert = generate_certificate( + csr=server_csr, + ca=ca, + ca_key=ca_private_key, + is_ca=True, + ) + + loaded_server_cert = x509.load_pem_x509_certificate(server_cert.encode()) + + assert ( + loaded_server_cert.extensions.get_extension_for_class(x509.BasicConstraints).value.ca + is True + ) + assert ( + loaded_server_cert.extensions.get_extension_for_class(x509.KeyUsage).value.key_cert_sign + is True + ) + assert ( + loaded_server_cert.extensions.get_extension_for_class(x509.KeyUsage).value.crl_sign is True + ) diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py index ddd9149..af9aa19 100644 --- a/tests/unit/test_charm.py +++ b/tests/unit/test_charm.py @@ -3,16 +3,27 @@ import json import unittest -from datetime import datetime -from unittest.mock import Mock, patch +from unittest.mock import patch import ops import ops.testing from charm import SelfSignedCertificatesCharm -from charms.tls_certificates_interface.v3.tls_certificates import ProviderCertificate, RequirerCSR +from charms.tls_certificates_interface.v4.tls_certificates import ( + Certificate, + CertificateSigningRequest, + ProviderCertificate, + RequirerCSR, +) from ops.model import ActiveStatus, BlockedStatus -TLS_LIB_PATH = "charms.tls_certificates_interface.v3.tls_certificates" +from tests.unit.certificates_helpers import ( + generate_ca, + generate_certificate, + generate_csr, + generate_private_key, +) + +TLS_LIB_PATH = "charms.tls_certificates_interface.v4.tls_certificates" class TestCharm(unittest.TestCase): @@ -33,23 +44,32 @@ def test_given_invalid_config_when_config_changed_then_status_is_blocked(self): BlockedStatus("The following configuration values are not valid: ['ca-common-name']"), ) + def test_given_invalid_validity_config_when_config_changed_then_status_is_blocked(self): + self.harness.set_leader(is_leader=True) + key_values = {"ca-common-name": "pizza.com", "certificate-validity": 0} + + self.harness.update_config(key_values=key_values) + + self.harness.evaluate_status() + + self.assertEqual( + self.harness.model.unit.status, + BlockedStatus( + "The following configuration values are not valid: ['certificate-validity']" + ), + ) + @patch("charm.generate_private_key") - @patch("charm.generate_password") @patch("charm.generate_ca") def test_given_valid_config_when_config_changed_then_ca_certificate_is_stored_in_juju_secret( self, patch_generate_ca, - patch_generate_password, patch_generate_private_key, ): ca_certificate_string = "whatever CA certificate" private_key_string = "whatever private key" - private_key_password = "banana" - ca_certificate_bytes = ca_certificate_string.encode() - private_key_bytes = private_key_string.encode() - patch_generate_ca.return_value = ca_certificate_bytes - patch_generate_password.return_value = private_key_password - patch_generate_private_key.return_value = private_key_bytes + patch_generate_ca.return_value = ca_certificate_string + patch_generate_private_key.return_value = private_key_string key_values = {"ca-common-name": "pizza.com", "certificate-validity": 100} self.harness.set_leader(is_leader=True) @@ -61,29 +81,22 @@ def test_given_valid_config_when_config_changed_then_ca_certificate_is_stored_in ca_certificates_secret["ca-certificate"], ca_certificate_string, ) - self.assertEqual( - ca_certificates_secret["private-key-password"], - private_key_password, - ) self.assertEqual( ca_certificates_secret["private-key"], private_key_string, ) - @patch(f"{TLS_LIB_PATH}.TLSCertificatesProvidesV3.revoke_all_certificates") + @patch(f"{TLS_LIB_PATH}.TLSCertificatesProvidesV4.revoke_all_certificates") @patch("charm.generate_private_key") - @patch("charm.generate_password") @patch("charm.generate_ca") def test_given_valid_config_when_config_changed_then_existing_certificates_are_revoked( self, patch_generate_ca, - patch_generate_password, patch_generate_private_key, patch_revoke_all_certificates, ): - patch_generate_ca.return_value = b"whatever CA certificate" - patch_generate_password.return_value = "password" - patch_generate_private_key.return_value = b"whatever private key" + patch_generate_ca.return_value = "whatever CA certificate" + patch_generate_private_key.return_value = "whatever private key" key_values = {"ca-common-name": "pizza.com", "certificate-validity": 100} self.harness.set_leader(is_leader=True) @@ -92,17 +105,14 @@ def test_given_valid_config_when_config_changed_then_existing_certificates_are_r patch_revoke_all_certificates.assert_called() @patch("charm.generate_private_key") - @patch("charm.generate_password") @patch("charm.generate_ca") def test_given_valid_config_when_config_changed_then_status_is_active( self, patch_generate_ca, - patch_generate_password, patch_generate_private_key, ): - patch_generate_ca.return_value = b"whatever CA certificate" - patch_generate_password.return_value = "password" - patch_generate_private_key.return_value = b"whatever private key" + patch_generate_ca.return_value = "whatever CA certificate" + patch_generate_private_key.return_value = "whatever private key" key_values = {"ca-common-name": "pizza.com", "certificate-validity": 100} self.harness.set_leader(is_leader=True) self.harness.update_config(key_values=key_values) @@ -113,19 +123,16 @@ def test_given_valid_config_when_config_changed_then_status_is_active( @patch("charm.certificate_has_common_name") @patch("charm.generate_private_key") - @patch("charm.generate_password") @patch("charm.generate_ca") def test_given_new_common_name_when_config_changed_then_new_root_ca_is_stored( self, patch_generate_ca, - patch_generate_password, patch_generate_private_key, patch_certificate_has_common_name, ): validity = 100 initial_ca = "whatever initial CA certificate" new_ca = "whatever CA certificate" - private_key_password = "password" private_key = "whatever private key" patch_certificate_has_common_name.return_value = False self.harness._backend.secret_add( @@ -133,12 +140,10 @@ def test_given_new_common_name_when_config_changed_then_new_root_ca_is_stored( content={ "ca-certificate": initial_ca, "private-key": private_key, - "private-key-password": private_key_password, }, ) - patch_generate_ca.return_value = new_ca.encode() - patch_generate_password.return_value = private_key_password - patch_generate_private_key.return_value = private_key.encode() + patch_generate_ca.return_value = new_ca + patch_generate_private_key.return_value = private_key key_values = {"ca-common-name": "pizza.com", "certificate-validity": validity} self.harness.set_leader(is_leader=True) @@ -150,8 +155,8 @@ def test_given_new_common_name_when_config_changed_then_new_root_ca_is_stored( assert secret_content["ca-certificate"] == new_ca @patch("charm.certificate_has_common_name") - @patch(f"{TLS_LIB_PATH}.TLSCertificatesProvidesV3.set_relation_certificate") - @patch(f"{TLS_LIB_PATH}.TLSCertificatesProvidesV3.get_outstanding_certificate_requests") + @patch(f"{TLS_LIB_PATH}.TLSCertificatesProvidesV4.set_relation_certificate") + @patch(f"{TLS_LIB_PATH}.TLSCertificatesProvidesV4.get_outstanding_certificate_requests") @patch("charm.generate_certificate") def test_given_outstanding_certificate_requests_when_secret_changed_then_certificates_are_generated( # noqa: E501 self, @@ -160,12 +165,18 @@ def test_given_outstanding_certificate_requests_when_secret_changed_then_certifi patch_set_relation_certificate, patch_certificate_has_common_name, ): - private_key = "whatever" - private_key_password = "whatever" - ca = "whatever CA certificate" - requirer_csr = "whatever CSR" - requirer_is_ca = False - generated_certificate = "whatever certificate" + requirer_private_key = generate_private_key() + provider_private_key = generate_private_key() + provider_ca = generate_ca( + private_key=provider_private_key, + common_name="example.com", + ) + requirer_csr = generate_csr(private_key=requirer_private_key, common_name="example.com") + certificate = generate_certificate( + csr=requirer_csr, + ca=provider_ca, + ca_key=provider_private_key, + ) patch_certificate_has_common_name.return_value = True self.harness.set_leader(is_leader=True) relation_id = self.harness.add_relation( @@ -175,65 +186,43 @@ def test_given_outstanding_certificate_requests_when_secret_changed_then_certifi patch_get_outstanding_certificate_requests.return_value = [ RequirerCSR( relation_id=relation_id, - application_name="tls-requirer", - unit_name="tls-requirer/0", - csr=requirer_csr, - is_ca=requirer_is_ca, + certificate_signing_request=CertificateSigningRequest.from_string(requirer_csr), ), ] - patch_generate_certificate.return_value = generated_certificate.encode() + patch_generate_certificate.return_value = certificate self.harness._backend.secret_add( label="ca-certificates", content={ - "ca-certificate": ca, - "private-key": private_key, - "private-key-password": private_key_password, + "ca-certificate": provider_ca, + "private-key": provider_private_key, }, ) self.harness.update_config() - patch_set_relation_certificate.assert_called_with( - certificate_signing_request=requirer_csr, - certificate=generated_certificate, - ca=ca, - chain=[ca, generated_certificate], + expected_provider_certificate = ProviderCertificate( relation_id=relation_id, + certificate=Certificate.from_string(certificate), + certificate_signing_request=CertificateSigningRequest.from_string(requirer_csr), + ca=Certificate.from_string(provider_ca), + chain=[Certificate.from_string(provider_ca), Certificate.from_string(certificate)], ) - - def test_given_invalid_config_when_certificate_request_then_status_is_blocked(self): - self.harness.set_leader(is_leader=True) - key_values = {"ca-common-name": "pizza.com", "certificate-validity": 0} - self.harness.update_config(key_values=key_values) - self.harness.charm._on_certificate_creation_request(event=Mock()) # type: ignore[reportAttributeAccessIssue] - - self.harness.evaluate_status() - - self.assertEqual( - self.harness.model.unit.status, - BlockedStatus( - "The following configuration values are not valid: ['certificate-validity']" - ), + patch_set_relation_certificate.assert_called_with( + provider_certificate=expected_provider_certificate, ) @patch("charm.generate_private_key") - @patch("charm.generate_password") @patch("charm.generate_ca") def test_given_valid_config_and_unit_is_leader_when_secret_expired_then_new_ca_certificate_is_stored_in_juju_secret( # noqa: E501 self, patch_generate_ca, - patch_generate_password, patch_generate_private_key, ): ca_certificate_string = "whatever CA certificate" private_key_string = "whatever private key" - private_key_password = "banana" - ca_certificate_bytes = ca_certificate_string.encode() - private_key_bytes = private_key_string.encode() - patch_generate_ca.return_value = ca_certificate_bytes - patch_generate_password.return_value = private_key_password - patch_generate_private_key.return_value = private_key_bytes + patch_generate_ca.return_value = ca_certificate_string + patch_generate_private_key.return_value = private_key_string self.harness.set_leader(is_leader=True) mock_secret_id = self.harness.add_model_secret( @@ -250,71 +239,17 @@ def test_given_valid_config_and_unit_is_leader_when_secret_expired_then_new_ca_c ca_certificates_secret["ca-certificate"], ca_certificate_string, ) - self.assertEqual( - ca_certificates_secret["private-key-password"], - private_key_password, - ) self.assertEqual( ca_certificates_secret["private-key"], private_key_string, ) - @patch(f"{TLS_LIB_PATH}.TLSCertificatesProvidesV3.set_relation_certificate") - @patch("charm.generate_certificate") - def test_given_root_certificates_when_certificate_request_then_certificates_are_generated( - self, patch_generate_certificate, patch_set_certificate - ): - is_ca = True - self.harness.set_leader(is_leader=True) - ca_certificate = "whatever CA certificate" - private_key = "whatever private key" - private_key_password = "whatever private_key_password" - certificate = "new certificate" - certificate_signing_request = "whatever CSR" - relation_id = 123 - patch_generate_certificate.return_value = certificate.encode() - - self.harness._backend.secret_add( - label="ca-certificates", - content={ - "ca-certificate": ca_certificate, - "private-key": private_key, - "private-key-password": private_key_password, - }, - ) - - self.harness.charm._on_certificate_creation_request( # type: ignore[reportAttributeAccessIssue] - event=Mock( - relation_id=relation_id, - certificate_signing_request=certificate_signing_request, - is_ca=is_ca, - ) - ) - - patch_generate_certificate.assert_called_with( - ca=ca_certificate.encode(), - ca_key=private_key.encode(), - ca_key_password=private_key_password.encode(), - csr=certificate_signing_request.encode(), - validity=365, - is_ca=is_ca, - ) - patch_set_certificate.assert_called_with( - certificate="new certificate", - ca=ca_certificate, - chain=[ca_certificate, certificate], - relation_id=relation_id, - certificate_signing_request=certificate_signing_request, - ) - @patch("charm.certificate_has_common_name") @patch("charm.generate_private_key") - @patch("charm.generate_password") @patch("charm.generate_ca") def test_given_initial_config_when_config_changed_then_stored_ca_common_name_uses_new_config( self, patch_generate_ca, - patch_generate_password, patch_generate_private_key, patch_certificate_has_common_name, ): @@ -325,15 +260,8 @@ def test_given_initial_config_when_config_changed_then_stored_ca_common_name_use ca_certificate_2_string = "whatever CA certificate 2" private_key_string_1 = "whatever private key 1" private_key_string_2 = "whatever private key 2" - private_key_password_1 = "banana" - private_key_password_2 = "apple" - ca_certificate_bytes_1 = ca_certificate_1_string.encode() - ca_certificate_bytes_2 = ca_certificate_2_string.encode() - private_key_bytes_1 = private_key_string_1.encode() - private_key_bytes_2 = private_key_string_2.encode() - patch_generate_ca.side_effect = [ca_certificate_bytes_1, ca_certificate_bytes_2] - patch_generate_password.side_effect = [private_key_password_1, private_key_password_2] - patch_generate_private_key.side_effect = [private_key_bytes_1, private_key_bytes_2] + patch_generate_ca.side_effect = [ca_certificate_1_string, ca_certificate_2_string] + patch_generate_private_key.side_effect = [private_key_string_1, private_key_string_2] self.harness.set_leader(is_leader=True) self.harness.update_config(key_values={"ca-common-name": initial_common_name}) @@ -346,10 +274,7 @@ def test_given_initial_config_when_config_changed_then_stored_ca_common_name_use secret_content["ca-certificate"], ca_certificate_2_string, ) - self.assertEqual( - secret_content["private-key-password"], - private_key_password_2, - ) + self.assertEqual( secret_content["private-key"], private_key_string_2, @@ -363,56 +288,47 @@ def test_given_no_certificates_issued_when_get_issued_certificates_action_then_a self.assertEqual(e.exception.message, "No certificates issued yet.") - @patch(f"{TLS_LIB_PATH}.TLSCertificatesProvidesV3.get_issued_certificates") + @patch(f"{TLS_LIB_PATH}.TLSCertificatesProvidesV4.get_issued_certificates") def test_given_certificates_issued_when_get_issued_certificates_action_then_action_returns_certificates( # noqa: E501 self, patch_get_issued_certificates, ): - relation_id = 123 - application_name = "tls-requirer" - csr = "whatever csr" - certificate = "whatever certificate" - ca_certificate = "whatever CA certificate" - chain = ["whatever cert 1", "whatever cert 2"] + ca_private_key = generate_private_key() + ca_certificate = generate_ca( + private_key=ca_private_key, + common_name="example.com", + ) + requirer_private_key = generate_private_key() + csr = generate_csr(private_key=requirer_private_key, common_name="example.com") + certificate = generate_certificate( + csr=csr, + ca=ca_certificate, + ca_key=ca_private_key, + ) + chain = [ca_certificate, certificate] revoked = False - expiry_time = datetime.now() - expiry_notification_time = None + cert = Certificate.from_string(certificate) self.harness.set_leader(is_leader=True) patch_get_issued_certificates.return_value = [ ProviderCertificate( - relation_id=relation_id, - application_name=application_name, - csr=csr, - certificate=certificate, - ca=ca_certificate, - chain=chain, + relation_id=1, + certificate_signing_request=CertificateSigningRequest.from_string(csr), + certificate=cert, + ca=Certificate.from_string(ca_certificate), + chain=[Certificate.from_string(c) for c in chain], revoked=revoked, - expiry_time=expiry_time, - expiry_notification_time=expiry_notification_time, ) ] action_output = self.harness.run_action("get-issued-certificates") - expected_certificates = { - "certificates": [ - json.dumps( - { - "relation_id": relation_id, - "application_name": application_name, - "csr": csr, - "certificate": certificate, - "ca": ca_certificate, - "chain": chain, - "revoked": revoked, - "expiry_time": expiry_time.isoformat(), - "expiry_notification_time": expiry_notification_time, - } - ) - ] - } + output_certificate = json.loads(action_output.results["certificates"][0]) - self.assertEqual(action_output.results, expected_certificates) + assert output_certificate["csr"] == csr + assert output_certificate["certificate"] == certificate + assert output_certificate["ca"] == ca_certificate + assert output_certificate["chain"] == chain + assert output_certificate["revoked"] == revoked def test_given_ca_cert_generated_when_get_ca_certificate_action_then_returns_ca_certificate( self,