Skip to content

Commit

Permalink
chore: Use tls lib V4.0 (#515)
Browse files Browse the repository at this point in the history
  • Loading branch information
saltiyazan authored Oct 10, 2024
1 parent 6f63c32 commit dfc68d1
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 139 deletions.
167 changes: 41 additions & 126 deletions lib/charms/tls_certificates_interface/v4/tls_certificates.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# Copyright 2024 Canonical Ltd.
# See LICENSE file for licensing details.

"""Charm library for managing TLS certificates (V4) - BETA.
"""Charm library for managing TLS certificates (V4).
> Warning: This is a beta version of the tls-certificates interface library.
> Use at your own risk.
This library contains the Requires and Provides classes for handling the tls-certificates
interface.
Learn how-to use the TLS Certificates interface library by reading the documentation:
Pre-requisites:
- Juju >= 3.0
- cryptography >= 43.0.0
- pydantic
Learn more on how-to use the TLS Certificates interface library by reading the documentation:
- https://charmhub.io/tls-certificates-interface/
""" # noqa: D214, D405, D411, D416
Expand Down Expand Up @@ -47,7 +52,7 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 9
LIBPATCH = 0

PYDEPS = ["cryptography", "pydantic"]

Expand Down Expand Up @@ -138,7 +143,6 @@ class _Certificate(BaseModel):
certificate_signing_request: str
certificate: str
chain: Optional[List[str]] = None
recommended_expiry_notification_time: Optional[int] = None
revoked: Optional[bool] = None

def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate":
Expand All @@ -153,7 +157,6 @@ def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate":
chain=[Certificate.from_string(certificate) for certificate in self.chain]
if self.chain
else [],
recommended_expiry_notification_time=self.recommended_expiry_notification_time,
revoked=self.revoked,
)

Expand Down Expand Up @@ -215,6 +218,8 @@ class Certificate:

raw: str
common_name: str
expiry_time: datetime
validity_start_time: datetime
is_ca: bool = False
sans_dns: Optional[FrozenSet[str]] = frozenset()
sans_ip: Optional[FrozenSet[str]] = frozenset()
Expand All @@ -225,8 +230,6 @@ class Certificate:
country_name: Optional[str] = None
state_or_province_name: Optional[str] = None
locality_name: Optional[str] = None
expiry_time: Optional[datetime] = None
validity_start_time: Optional[datetime] = None

def __str__(self) -> str:
"""Return the certificate as a string."""
Expand Down Expand Up @@ -424,8 +427,8 @@ def get_sha256_hex(self) -> str:


@dataclass(frozen=True)
class CertificateRequest:
"""This class represents a certificate request.
class CertificateRequestAttributes:
"""A representation of the certificate request attributes.
This class should be used inside the requirer charm to specify the requested
attributes for the certificate.
Expand Down Expand Up @@ -477,7 +480,7 @@ def generate_csr(

@classmethod
def from_csr(cls, csr: CertificateSigningRequest, is_ca: bool):
"""Create a CertificateRequest object from a CSR."""
"""Create a CertificateRequestAttributes object from a CSR."""
return cls(
common_name=csr.common_name,
sans_dns=csr.sans_dns,
Expand All @@ -502,7 +505,6 @@ class ProviderCertificate:
certificate_signing_request: CertificateSigningRequest
ca: Certificate
chain: List[Certificate]
recommended_expiry_notification_time: Optional[int] = None
revoked: Optional[bool] = None

def to_json(self) -> str:
Expand All @@ -523,8 +525,8 @@ def to_json(self) -> str:


@dataclass(frozen=True)
class RequirerCSR:
"""This class represents a certificate signing request requested by the TLS requirer."""
class RequirerCertificateRequest:
"""This class represents a certificate signing request requested by a specific TLS requirer."""

relation_id: int
certificate_signing_request: CertificateSigningRequest
Expand Down Expand Up @@ -572,60 +574,6 @@ def chain_as_pem(self) -> str:
return "\n\n".join([str(cert) for cert in self.chain])


def _get_closest_future_time(
expiry_notification_time: datetime, expiry_time: datetime
) -> datetime:
"""Return expiry_notification_time if not in the past, otherwise return expiry_time.
Args:
expiry_notification_time (datetime): Notification time of impending expiration
expiry_time (datetime): Expiration time
Returns:
datetime: expiry_notification_time if not in the past, expiry_time otherwise
"""
return (
expiry_notification_time
if datetime.now(timezone.utc) < expiry_notification_time
else expiry_time
)


def calculate_expiry_notification_time(
not_valid_before: datetime,
not_valid_after: datetime,
provider_recommended_notification_time: Optional[int] = None,
) -> datetime:
"""Calculate a reasonable time to notify the user about the certificate expiry.
It takes into account the time recommended by the provider.
Time recommended by the provider is preferred,
then dynamically calculated time.
Args:
not_valid_before: Time when the certificate is valid from.
not_valid_after: Time when the certificate is valid until.
provider_recommended_notification_time:
Time in hours prior to expiry to notify the user.
Recommended by the provider.
Returns:
datetime: Time to notify the user about the certificate expiry.
"""
if provider_recommended_notification_time is not None:
provider_recommended_notification_time = abs(provider_recommended_notification_time)
provider_recommendation_time_delta = not_valid_after - timedelta(
hours=provider_recommended_notification_time
)
if not_valid_before < provider_recommendation_time_delta:
return provider_recommendation_time_delta
# Divide the time between not_valid_after and not_valid_before by 3
# For example, if there are 3 days between not_valid_after and not_valid_before,
# the notification time will be 1 day before not_valid_after.
calculated_time = (not_valid_after - not_valid_before) / 3
return not_valid_after - calculated_time


def generate_private_key(
key_size: int = 2048,
public_exponent: int = 65537,
Expand Down Expand Up @@ -996,7 +944,7 @@ def __init__(
self,
charm: CharmBase,
relationship_name: str,
certificate_requests: List[CertificateRequest],
certificate_requests: List[CertificateRequestAttributes],
mode: Mode = Mode.UNIT,
refresh_events: List[BoundEvent] = [],
):
Expand All @@ -1005,7 +953,8 @@ def __init__(
Args:
charm (CharmBase): The charm instance to relate to.
relationship_name (str): The name of the relation that provides the certificates.
certificate_requests (List[CertificateRequest]): A list of certificate requests.
certificate_requests (List[CertificateRequestAttributes]):
A list with the attributes of the certificate requests.
mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT.
refresh_events (List[BoundEvent]): A list of events to trigger a refresh of
the certificates.
Expand Down Expand Up @@ -1165,14 +1114,14 @@ def _csr_matches_certificate_request(
self, certificate_signing_request: CertificateSigningRequest, is_ca: bool
) -> bool:
for certificate_request in self.certificate_requests:
if certificate_request == CertificateRequest.from_csr(
if certificate_request == CertificateRequestAttributes.from_csr(
certificate_signing_request,
is_ca,
):
return True
return False

def _certificate_requested(self, certificate_request: CertificateRequest) -> bool:
def _certificate_requested(self, certificate_request: CertificateRequestAttributes) -> bool:
if not self.private_key:
return False
csr = self._certificate_requested_for_attributes(certificate_request)
Expand All @@ -1184,17 +1133,17 @@ def _certificate_requested(self, certificate_request: CertificateRequest) -> boo

def _certificate_requested_for_attributes(
self,
certificate_request: CertificateRequest,
) -> Optional[RequirerCSR]:
certificate_request: CertificateRequestAttributes,
) -> Optional[RequirerCertificateRequest]:
for requirer_csr in self.get_csrs_from_requirer_relation_data():
if certificate_request == CertificateRequest.from_csr(
if certificate_request == CertificateRequestAttributes.from_csr(
requirer_csr.certificate_signing_request,
requirer_csr.is_ca,
):
return requirer_csr
return None

def get_csrs_from_requirer_relation_data(self) -> List[RequirerCSR]:
def get_csrs_from_requirer_relation_data(self) -> List[RequirerCertificateRequest]:
"""Return list of requirer's CSRs from relation data."""
if self.mode == Mode.APP and not self.model.unit.is_leader():
logger.debug("Not a leader unit - Skipping")
Expand All @@ -1212,7 +1161,7 @@ def get_csrs_from_requirer_relation_data(self) -> List[RequirerCSR]:
requirer_csrs = []
for csr in requirer_relation_data.certificate_signing_requests:
requirer_csrs.append(
RequirerCSR(
RequirerCertificateRequest(
relation_id=relation.id,
certificate_signing_request=CertificateSigningRequest.from_string(
csr.certificate_signing_request
Expand Down Expand Up @@ -1288,11 +1237,11 @@ def _send_certificate_requests(self):
self._request_certificate(csr=csr, is_ca=certificate_request.is_ca)

def get_assigned_certificate(
self, certificate_request: CertificateRequest
self, certificate_request: CertificateRequestAttributes
) -> Tuple[ProviderCertificate | None, PrivateKey | None]:
"""Get the certificate that was assigned to the given certificate request."""
for requirer_csr in self.get_csrs_from_requirer_relation_data():
if certificate_request == CertificateRequest.from_csr(
if certificate_request == CertificateRequestAttributes.from_csr(
requirer_csr.certificate_signing_request,
requirer_csr.is_ca,
):
Expand All @@ -1308,7 +1257,7 @@ def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateK
return assigned_certificates, self.private_key

def _find_certificate_in_relation_data(
self, csr: RequirerCSR
self, csr: RequirerCertificateRequest
) -> Optional[ProviderCertificate]:
"""Return the certificate that match the given CSR."""
for provider_certificate in self.get_provider_certificates():
Expand Down Expand Up @@ -1359,7 +1308,7 @@ def _find_available_certificates(self):
}
)
secret.set_info(
expire=self._get_next_secret_expiry_time(provider_certificate),
expire=provider_certificate.certificate.expiry_time,
)
except SecretNotFoundError:
logger.debug("Creating new secret with label %s", secret_label)
Expand All @@ -1369,7 +1318,7 @@ def _find_available_certificates(self):
"csr": str(provider_certificate.certificate_signing_request),
},
label=secret_label,
expire=self._get_next_secret_expiry_time(provider_certificate),
expire=provider_certificate.certificate.expiry_time,
)
self.on.certificate_available.emit(
certificate_signing_request=provider_certificate.certificate_signing_request,
Expand Down Expand Up @@ -1410,41 +1359,6 @@ def _cleanup_certificate_requests(self):
"Removed CSR from relation data because it did not match the private key" # noqa: E501
)

def _get_next_secret_expiry_time(
self, provider_certificate: ProviderCertificate
) -> Optional[datetime]:
"""Return the expiry time or expiry notification time.
Extracts the expiry time from the provided certificate, calculates the
expiry notification time and return the closest of the two, that is in
the future.
Args:
provider_certificate: ProviderCertificate object
Returns:
Optional[datetime]: None if the certificate expiry time cannot be read,
next expiry time otherwise.
"""
if not provider_certificate.certificate.expiry_time:
logger.warning("Certificate has no expiry time")
return None
if not provider_certificate.certificate.validity_start_time:
logger.warning("Certificate has no validity start time")
return None
expiry_notification_time = calculate_expiry_notification_time(
not_valid_before=provider_certificate.certificate.validity_start_time,
not_valid_after=provider_certificate.certificate.expiry_time,
provider_recommended_notification_time=provider_certificate.recommended_expiry_notification_time,
)
if not expiry_notification_time:
logger.warning("Could not calculate expiry notification time")
return None
return _get_closest_future_time(
expiry_notification_time,
provider_certificate.certificate.expiry_time,
)

def _tls_relation_created(self) -> bool:
relation = self.model.get_relation(self.relationship_name)
if not relation:
Expand Down Expand Up @@ -1520,10 +1434,12 @@ def _get_tls_relations(self, relation_id: Optional[int] = None) -> List[Relation
else self.model.relations.get(self.relationship_name, [])
)

def get_certificate_requests(self, relation_id: Optional[int] = None) -> List[RequirerCSR]:
def get_certificate_requests(
self, relation_id: Optional[int] = None
) -> List[RequirerCertificateRequest]:
"""Load certificate requests from the relation data."""
relations = self._get_tls_relations(relation_id)
requirer_csrs: List[RequirerCSR] = []
requirer_csrs: List[RequirerCertificateRequest] = []
for relation in relations:
for unit in relation.units:
requirer_csrs.extend(self._load_requirer_databag(relation, unit))
Expand All @@ -1532,14 +1448,14 @@ def get_certificate_requests(self, relation_id: Optional[int] = None) -> List[Re

def _load_requirer_databag(
self, relation: Relation, unit_or_app: Union[Application, Unit]
) -> List[RequirerCSR]:
) -> List[RequirerCertificateRequest]:
try:
requirer_relation_data = _RequirerData.load(relation.data[unit_or_app])
except DataValidationError:
logger.debug("Invalid requirer relation data for %s", unit_or_app.name)
return []
return [
RequirerCSR(
RequirerCertificateRequest(
relation_id=relation.id,
certificate_signing_request=CertificateSigningRequest.from_string(
csr.certificate_signing_request
Expand All @@ -1559,7 +1475,6 @@ def _add_provider_certificate(
certificate_signing_request=str(provider_certificate.certificate_signing_request),
ca=str(provider_certificate.ca),
chain=[str(certificate) for certificate in provider_certificate.chain],
recommended_expiry_notification_time=provider_certificate.recommended_expiry_notification_time,
)
provider_certificates = self._load_provider_certificates(relation)
if new_certificate in provider_certificates:
Expand Down Expand Up @@ -1692,17 +1607,17 @@ def get_unsolicited_certificates(

def get_outstanding_certificate_requests(
self, relation_id: Optional[int] = None
) -> List[RequirerCSR]:
) -> List[RequirerCertificateRequest]:
"""Return CSR's for which no certificate has been issued.
Args:
relation_id (int): Relation id
Returns:
list: List of RequirerCSR objects.
list: List of RequirerCertificateRequest objects.
"""
requirer_csrs = self.get_certificate_requests(relation_id=relation_id)
outstanding_csrs: List[RequirerCSR] = []
outstanding_csrs: List[RequirerCertificateRequest] = []
for relation_csr in requirer_csrs:
if not self._certificate_issued_for_csr(
csr=relation_csr.certificate_signing_request,
Expand Down
6 changes: 3 additions & 3 deletions lib/charms/vault_k8s/v0/vault_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from charms.tls_certificates_interface.v4.tls_certificates import (
Certificate,
CertificateRequest,
CertificateRequestAttributes,
PrivateKey,
TLSCertificatesRequiresV4,
generate_ca,
Expand Down Expand Up @@ -200,11 +200,11 @@ def _configure_ca_cert_relation(self, event: EventBase):
"""Send the CA certificate to the relation."""
self.send_ca_cert()

def _get_certificate_requests(self) -> List[CertificateRequest]:
def _get_certificate_requests(self) -> List[CertificateRequestAttributes]:
if not self.common_name:
return []
return [
CertificateRequest(
CertificateRequestAttributes(
common_name=self.common_name, sans_dns=self.sans_dns, sans_ip=self.sans_ip
)
]
Expand Down
Loading

0 comments on commit dfc68d1

Please sign in to comment.