Skip to content

Commit

Permalink
fix: Makes sure there always is an active issuer in Vault PKI (#506)
Browse files Browse the repository at this point in the history
  • Loading branch information
saltiyazan authored Oct 6, 2024
1 parent ab90f24 commit 8f3f38b
Show file tree
Hide file tree
Showing 9 changed files with 335 additions and 205 deletions.
215 changes: 56 additions & 159 deletions lib/charms/tls_certificates_interface/v4/tls_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,147 +6,9 @@
> Warning: This is a beta version of the tls-certificates interface library.
> Use at your own risk.
This library contains the Requires and Provides classes for handling the tls-certificates
interface.
Learn how-to use the TLS Certificates interface library by reading the documentation:
- https://charmhub.io/tls-certificates-interface/
Pre-requisites:
- Juju >= 3.0
## Getting Started
From a charm directory, fetch the library using `charmcraft`:
```shell
charmcraft fetch-lib charms.tls_certificates_interface.v4.tls_certificates
```
Add the following libraries to the charm's `requirements.txt` file:
- cryptography >= 42.0.0
- pydantic >= 2.0.0
Add the following section to the charm's `charmcraft.yaml` file:
```yaml
parts:
charm:
build-packages:
- libffi-dev
- libssl-dev
- rustc
- cargo
```
### Requirer charm
The requirer charm is the charm requiring certificates from another charm that provides them.
#### Example
In the following example, the requiring charm requests a certificate using attributes
from the Juju configuration options.
```python
from typing import List, Optional, cast
from ops.charm import ActionEvent, CharmBase
from ops.main import main
from lib.charms.tls_certificates_interface.v4.tls_certificates import (
CertificateAvailableEvent,
CertificateRequest,
Mode,
TLSCertificatesRequiresV4,
)
class DummyTLSCertificatesRequirerCharm(CharmBase):
def __init__(self, *args):
super().__init__(*args)
certificate_requests = self._get_certificate_requests()
self.certificates = TLSCertificatesRequiresV4(
charm=self,
relationship_name="certificates",
certificate_requests=certificate_requests,
mode=Mode.UNIT,
refresh_events=[self.on.config_changed],
)
self.framework.observe(
self.certificates.on.certificate_available, self._on_certificate_available
)
self.framework.observe(
self.on.regenerate_private_key_action, self._on_regenerate_private_key_action
)
self.framework.observe(self.on.get_certificate_action, self._on_get_certificate_action)
def _get_certificate_requests(self) -> List[CertificateRequest]:
if not self._get_config_common_name():
return []
return [
CertificateRequest(
common_name=self._get_config_common_name(),
sans_dns=self._get_config_sans_dns(),
organization=self._get_config_organization_name(),
organizational_unit=self._get_config_organization_unit_name(),
email_address=self._get_config_email_address(),
country_name=self._get_config_country_name(),
state_or_province_name=self._get_config_state_or_province_name(),
locality_name=self._get_config_locality_name(),
)
]
def _on_certificate_available(self, event: CertificateAvailableEvent) -> None:
print("Certificate available")
def _on_regenerate_private_key_action(self, event: ActionEvent) -> None:
self.certificates.regenerate_private_key()
def _on_get_certificate_action(self, event: ActionEvent) -> None:
certificate, _ = self.certificates.get_assigned_certificate(
certificate_request=self._get_certificate_requests()[0]
)
if not certificate:
event.fail("Certificate not available")
return
event.set_results(
{
"certificate": str(certificate.certificate),
"ca": str(certificate.ca),
"csr": str(certificate.certificate_signing_request),
}
)
def _get_config_common_name(self) -> str:
return cast(str, self.model.config.get("common_name"))
def _get_config_sans_dns(self) -> List[str]:
config_sans_dns = cast(str, self.model.config.get("sans_dns", ""))
return config_sans_dns.split(",") if config_sans_dns else []
def _get_config_organization_name(self) -> Optional[str]:
return cast(str, self.model.config.get("organization_name"))
def _get_config_organization_unit_name(self) -> Optional[str]:
return cast(str, self.model.config.get("organization_unit_name"))
def _get_config_email_address(self) -> Optional[str]:
return cast(str, self.model.config.get("email_address"))
def _get_config_country_name(self) -> Optional[str]:
return cast(str, self.model.config.get("country_name"))
def _get_config_state_or_province_name(self) -> Optional[str]:
return cast(str, self.model.config.get("state_or_province_name"))
def _get_config_locality_name(self) -> Optional[str]:
return cast(str, self.model.config.get("locality_name"))
if __name__ == "__main__":
main(DummyTLSCertificatesRequirerCharm)
```
You can integrate both charms by running:
```bash
juju integrate <tls-certificates provider charm> <tls-certificates requirer charm>
```
""" # noqa: D214, D405, D411, D416

import copy
Expand Down Expand Up @@ -185,7 +47,7 @@ def _get_config_locality_name(self) -> Optional[str]:

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 6
LIBPATCH = 9

PYDEPS = ["cryptography", "pydantic"]

Expand Down Expand Up @@ -730,9 +592,9 @@ def _get_closest_future_time(


def calculate_expiry_notification_time(
validity_start_time: datetime,
expiry_time: datetime,
provider_recommended_notification_time: Optional[int],
not_valid_before: datetime,
not_valid_after: datetime,
provider_recommended_notification_time: Optional[int] = None,
) -> datetime:
"""Calculate a reasonable time to notify the user about the certificate expiry.
Expand All @@ -741,8 +603,8 @@ def calculate_expiry_notification_time(
then dynamically calculated time.
Args:
validity_start_time: Certificate validity time
expiry_time: Certificate expiry time
not_valid_before: Time when the certificate is valid from.
not_valid_after: Time when the certificate is valid until.
provider_recommended_notification_time:
Time in hours prior to expiry to notify the user.
Recommended by the provider.
Expand All @@ -752,13 +614,16 @@ def calculate_expiry_notification_time(
"""
if provider_recommended_notification_time is not None:
provider_recommended_notification_time = abs(provider_recommended_notification_time)
provider_recommendation_time_delta = expiry_time - timedelta(
provider_recommendation_time_delta = not_valid_after - timedelta(
hours=provider_recommended_notification_time
)
if validity_start_time < provider_recommendation_time_delta:
if not_valid_before < provider_recommendation_time_delta:
return provider_recommendation_time_delta
calculated_hours = (expiry_time - validity_start_time).total_seconds() / (3600 * 3)
return expiry_time - timedelta(hours=calculated_hours)
# Divide the time between not_valid_after and not_valid_before by 3
# For example, if there are 3 days between not_valid_after and not_valid_before,
# the notification time will be 1 day before not_valid_after.
calculated_time = (not_valid_after - not_valid_before) / 3
return not_valid_after - calculated_time


def generate_private_key(
Expand Down Expand Up @@ -862,7 +727,7 @@ def generate_csr( # noqa: C901

def generate_ca(
private_key: PrivateKey,
validity: int,
validity: timedelta,
common_name: str,
sans_dns: Optional[FrozenSet[str]] = frozenset(),
sans_ip: Optional[FrozenSet[str]] = frozenset(),
Expand All @@ -878,7 +743,7 @@ def generate_ca(
Args:
private_key (PrivateKey): Private key
validity (int): Certificate validity time (in days)
validity (timedelta): Certificate validity time
common_name (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN).
sans_dns (FrozenSet[str]): DNS Subject Alternative Names
sans_ip (FrozenSet[str]): IP Subject Alternative Names
Expand Down Expand Up @@ -945,7 +810,7 @@ def generate_ca(
.public_key(private_key_object.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.now(timezone.utc))
.not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity))
.not_valid_after(datetime.now(timezone.utc) + validity)
.add_extension(x509.SubjectAlternativeName(set(_sans)), critical=False)
.add_extension(x509.SubjectKeyIdentifier(digest=subject_identifier), critical=False)
.add_extension(
Expand All @@ -971,7 +836,7 @@ def generate_certificate(
csr: CertificateSigningRequest,
ca: Certificate,
ca_private_key: PrivateKey,
validity: int,
validity: timedelta,
is_ca: bool = False,
) -> Certificate:
"""Generate a TLS certificate based on a CSR.
Expand All @@ -980,7 +845,7 @@ def generate_certificate(
csr (CertificateSigningRequest): CSR
ca (Certificate): CA Certificate
ca_private_key (PrivateKey): CA private key
validity (int): Certificate validity (in days)
validity (timedelta): Certificate validity time
is_ca (bool): Whether the certificate is a CA certificate
Returns:
Expand All @@ -999,7 +864,7 @@ def generate_certificate(
.public_key(csr_object.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.now(timezone.utc))
.not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity))
.not_valid_after(datetime.now(timezone.utc) + validity)
)
extensions = _get_certificate_request_extensions(
authority_key_identifier=ca_pem.extensions.get_extension_for_class(
Expand Down Expand Up @@ -1192,12 +1057,28 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None:
try:
csr_str = event.secret.get_content(refresh=True)["csr"]
except ModelError:
logger.error("Failed to get CSR from secret - Skipping renewal")
logger.error("Failed to get CSR from secret - Skipping")
return
csr = CertificateSigningRequest.from_string(csr_str)
self._renew_certificate_request(csr)
event.secret.remove_all_revisions()

def renew_certificate(self, certificate: ProviderCertificate) -> None:
"""Request the renewal of the provided certificate."""
certificate_signing_request = certificate.certificate_signing_request
secret_label = self._get_csr_secret_label(certificate_signing_request)
try:
secret = self.model.get_secret(label=secret_label)
except SecretNotFoundError:
logger.warning("No matching secret found - Skipping renewal")
return
current_csr = secret.get_content(refresh=True).get("csr", "")
if current_csr != str(certificate_signing_request):
logger.warning("No matching CSR found - Skipping renewal")
return
self._renew_certificate_request(certificate_signing_request)
secret.remove_all_revisions()

def _renew_certificate_request(self, csr: CertificateSigningRequest):
"""Remove existing CSR from relation data and create a new one."""
self._remove_requirer_csr_from_relation_data(csr)
Expand Down Expand Up @@ -1552,8 +1433,8 @@ def _get_next_secret_expiry_time(
logger.warning("Certificate has no validity start time")
return None
expiry_notification_time = calculate_expiry_notification_time(
validity_start_time=provider_certificate.certificate.validity_start_time,
expiry_time=provider_certificate.certificate.expiry_time,
not_valid_before=provider_certificate.certificate.validity_start_time,
not_valid_after=provider_certificate.certificate.expiry_time,
provider_recommended_notification_time=provider_certificate.recommended_expiry_notification_time,
)
if not expiry_notification_time:
Expand Down Expand Up @@ -1793,6 +1674,22 @@ def get_provider_certificates(
certificates.append(certificate.to_provider_certificate(relation_id=relation.id))
return certificates

def get_unsolicited_certificates(
self, relation_id: Optional[int] = None
) -> List[ProviderCertificate]:
"""Return provider certificates for which no certificate requests exists.
Those certificates should be revoked.
"""
unsolicited_certificates: List[ProviderCertificate] = []
provider_certificates = self.get_provider_certificates(relation_id=relation_id)
requirer_csrs = self.get_certificate_requests(relation_id=relation_id)
list_of_csrs = [csr.certificate_signing_request for csr in requirer_csrs]
for certificate in provider_certificates:
if certificate.certificate_signing_request not in list_of_csrs:
unsolicited_certificates.append(certificate)
return unsolicited_certificates

def get_outstanding_certificate_requests(
self, relation_id: Optional[int] = None
) -> List[RequirerCSR]:
Expand Down
Loading

0 comments on commit 8f3f38b

Please sign in to comment.