Source code for dxaws_acm.providers.aws

from __future__ import annotations

import hashlib
from dataclasses import dataclass
from typing import Optional

from dxaws_acm.exceptions import ProviderError
from dxaws_core import AwsSession
from dxaws_core.providers.aws import AwsProviderBase


[docs] @dataclass(frozen=True, slots=True, kw_only=True) class CertificateSummary: """Normalized certificate summary returned by list operations. This is intentionally *not* a boto3 shape. Note: - `status` may be empty when returned from list operations unless status enrichment is requested. """ certificate_arn: str domain_name: str status: str
[docs] @dataclass(frozen=True, slots=True, kw_only=True) class ValidationRecord: """Normalized DNS validation record required by ACM. Notes: - `name` is canonical FQDN with trailing dot. - `value` is canonical FQDN with trailing dot (ACM typically returns this already). """ domain_name: str name: str type: str # usually "CNAME" value: str
[docs] @dataclass(frozen=True, slots=True, kw_only=True) class CertificateInfo: """Normalized certificate information returned by describe operations. This is intentionally *not* a boto3 shape. """ certificate_arn: str status: str domain_name: str # primary DomainName (normalized lower-case) domains: list[str] # DomainName + SANs (normalized lower-case) validation_records: list[ValidationRecord] failure_reason: Optional[str]
class AwsProvider(AwsProviderBase):
    """ACM provider.

    All boto3 and AWS-specific behavior lives in this class.
    Planner/executor code must not import boto3 directly.
    """

    def __init__(self, *, aws: AwsSession) -> None:
        """Initialise the base provider under the ``acm`` namespace and create the ACM client.

        Args:
            aws: Session object forwarded to the base provider.
        """
        super().__init__(aws=aws, provider_ns="acm")
        self._client = self.client("acm")

    # ---------------------------------------------------------------------
    # Read operations (used by planner)
    # ---------------------------------------------------------------------

    def list_certificates(
        self,
        *,
        statuses: Optional[list[str]] = None,
        include_status: bool = False,
    ) -> list[CertificateSummary]:
        """List certificate summaries (normalized).

        Args:
            statuses: Optional status filter, passed to ACM as
                ``CertificateStatuses``. Omitted from the request when empty
                or ``None``.
            include_status: When ``True``, populate ``status`` on every
                returned summary; when ``False``, ``status`` is ``""``.

        Returns:
            One ``CertificateSummary`` per listed certificate that has both an
            ARN and a domain name. Domain names are lower-cased with any
            trailing dot removed.

        Notes:
            - With ``include_status=True`` the ``Status`` carried by the list
              response is used when present. Only when it is missing is the
              certificate described individually to obtain its status.
            - A certificate that no longer exists when it is described is
              skipped. Any other describe failure propagates to the caller
              instead of silently shrinking the result.
        """
        paginator = self._client.get_paginator("list_certificates")
        kwargs: dict[str, object] = {}
        if statuses:
            kwargs["CertificateStatuses"] = statuses

        out: list[CertificateSummary] = []
        for page in paginator.paginate(**kwargs):
            for item in page.get("CertificateSummaryList") or []:
                arn = item.get("CertificateArn")
                domain = item.get("DomainName")
                if not arn or not domain:
                    continue

                status = ""
                if include_status:
                    # Prefer the status already present in the list response;
                    # this avoids one describe call per certificate.
                    status = str(item.get("Status") or "")
                    if not status:
                        try:
                            described = self._client.describe_certificate(CertificateArn=arn)
                        except self._client.exceptions.ResourceNotFoundException:
                            # Certificates can disappear between list and
                            # describe (eventual consistency).
                            continue
                        cert = described.get("Certificate") or {}
                        status = str(cert.get("Status") or "")

                out.append(
                    CertificateSummary(
                        certificate_arn=str(arn),
                        domain_name=self._normalize_domain(domain),
                        status=status,
                    )
                )
        return out

    def describe_certificate(self, *, certificate_arn: str) -> CertificateInfo:
        """Describe an ACM certificate and return normalized info.

        Args:
            certificate_arn: ARN of the certificate to describe.

        Returns:
            A ``CertificateInfo`` where domain names are lower-cased without a
            trailing dot, ``domains`` holds the primary domain followed by the
            SANs (deduplicated, first occurrence wins), and
            ``validation_records`` contains only the validation options that
            carry a complete resource record (name, type and value).
        """
        resp = self._client.describe_certificate(CertificateArn=certificate_arn)
        cert = resp.get("Certificate") or {}

        status = str(cert.get("Status") or "")
        domain_name = self._normalize_domain(cert.get("DomainName"))
        sans = [
            self._normalize_domain(d)
            for d in (cert.get("SubjectAlternativeNames") or [])
            if d
        ]

        # Primary + SANs, empty names dropped, duplicates removed while
        # keeping first-seen order (dict preserves insertion order).
        domains = list(dict.fromkeys(d for d in (domain_name, *sans) if d))

        failure_reason = cert.get("FailureReason")
        failure_reason = str(failure_reason) if failure_reason else None

        validation_records: list[ValidationRecord] = []
        for dvo in cert.get("DomainValidationOptions") or []:
            rr = dvo.get("ResourceRecord")
            if not rr:
                # No DNS record attached to this validation option.
                continue
            name = rr.get("Name")
            value = rr.get("Value")
            rr_type = rr.get("Type")
            if not name or not value or not rr_type:
                continue
            validation_records.append(
                ValidationRecord(
                    domain_name=self._normalize_domain(dvo.get("DomainName")),
                    name=self._canonical_fqdn(str(name)),
                    type=str(rr_type),
                    value=self._canonical_fqdn(str(value)),
                )
            )

        return CertificateInfo(
            certificate_arn=certificate_arn,
            status=status,
            domain_name=domain_name,
            domains=domains,
            validation_records=validation_records,
            failure_reason=failure_reason,
        )

    def get_dns_validation_records(self, *, certificate_arn: str) -> list[ValidationRecord]:
        """Return a new list of the DNS validation records for a certificate.

        Args:
            certificate_arn: ARN of the certificate to describe.
        """
        info = self.describe_certificate(certificate_arn=certificate_arn)
        return list(info.validation_records)

    def delete_certificate(self, *, certificate_arn: str) -> None:
        """Delete the certificate identified by *certificate_arn*."""
        self._client.delete_certificate(CertificateArn=certificate_arn)

    # ---------------------------------------------------------------------
    # Write operations (used by executor)
    # ---------------------------------------------------------------------

    def request_dns_validated_certificate(
        self,
        *,
        primary_domain: str,
        sans: list[str],
        tags: dict[str, str],
    ) -> str:
        """Request a DNS-validated certificate and return its ARN.

        Args:
            primary_domain: Primary domain; sent lower-cased without a
                trailing dot.
            sans: Subject alternative names; normalized the same way,
                deduplicated and sorted. Omitted from the request when empty.
            tags: Tags to attach; omitted from the request when empty.

        Returns:
            The ARN of the requested certificate.

        Raises:
            ProviderError: If the response carries no ``CertificateArn``.

        ACM IdempotencyToken constraints:
            - max length 32
            - must match \\w+ (no hyphens)
        """
        primary = primary_domain.rstrip(".").lower()
        san_norm = sorted({s.rstrip(".").lower() for s in sans if s})

        # ACM-safe idempotency token: stable, <=32 chars, [0-9a-f] only.
        # Derived from the normalized domains only (tags are not part of it).
        token_src = f"request_cert|{primary}|{','.join(san_norm)}"
        token = hashlib.sha256(token_src.encode("utf-8")).hexdigest()[:32]

        tag_list = [{"Key": k, "Value": v} for k, v in (tags or {}).items()]

        req: dict[str, object] = {
            "DomainName": primary,
            "ValidationMethod": "DNS",
            "IdempotencyToken": token,
        }
        if tag_list:
            req["Tags"] = tag_list
        # SubjectAlternativeNames is optional, but if provided it must have >= 1 item.
        if san_norm:
            req["SubjectAlternativeNames"] = san_norm

        resp = self._client.request_certificate(**req)
        arn = resp.get("CertificateArn")
        if not arn:
            raise ProviderError("ACM request_certificate did not return CertificateArn")
        return str(arn)

    # ---------------------------------------------------------------------
    # Helpers
    # ---------------------------------------------------------------------

    @staticmethod
    def _normalize_domain(value: object) -> str:
        """Return *value* as a lower-case domain without trailing dots ("" if falsy)."""
        return str(value or "").rstrip(".").lower()

    @staticmethod
    def _canonical_fqdn(value: str) -> str:
        """Return *value* as a lower-case FQDN with exactly one trailing dot.

        Surrounding whitespace is stripped first; a blank input yields "".
        """
        v = value.strip()
        if not v:
            return v
        return v.rstrip(".").lower() + "."