from __future__ import annotations
import hashlib
from dataclasses import dataclass
from typing import Optional
from dxaws_acm.exceptions import ProviderError
from dxaws_core import AwsSession
from dxaws_core.providers.aws import AwsProviderBase
[docs]
@dataclass(frozen=True, slots=True, kw_only=True)
class CertificateSummary:
"""Normalized certificate summary returned by list operations.
This is intentionally *not* a boto3 shape.
Note:
- `status` may be empty when returned from list operations unless status enrichment is requested.
"""
certificate_arn: str
domain_name: str
status: str
[docs]
@dataclass(frozen=True, slots=True, kw_only=True)
class ValidationRecord:
"""Normalized DNS validation record required by ACM.
Notes:
- `name` is canonical FQDN with trailing dot.
- `value` is canonical FQDN with trailing dot (ACM typically returns this already).
"""
domain_name: str
name: str
type: str # usually "CNAME"
value: str
[docs]
@dataclass(frozen=True, slots=True, kw_only=True)
class CertificateInfo:
"""Normalized certificate information returned by describe operations.
This is intentionally *not* a boto3 shape.
"""
certificate_arn: str
status: str
domain_name: str # primary DomainName (normalized lower-case)
domains: list[str] # DomainName + SANs (normalized lower-case)
validation_records: list[ValidationRecord]
failure_reason: Optional[str]
class AwsProvider(AwsProviderBase):
    """ACM provider.

    All boto3 and AWS-specific behavior lives in this class.
    Planner/executor code must not import boto3 directly.
    """

    def __init__(self, *, aws: AwsSession) -> None:
        super().__init__(aws=aws, provider_ns="acm")
        # ``client`` is inherited from AwsProviderBase (not visible here);
        # presumably it returns a boto3 ACM client bound to the session.
        self._client = self.client("acm")

    # ---------------------------------------------------------------------
    # Read operations (used by planner)
    # ---------------------------------------------------------------------

    def list_certificates(
        self,
        *,
        statuses: Optional[list[str]] = None,
        include_status: bool = False,
    ) -> list[CertificateSummary]:
        """List certificate summaries (normalized).

        Args:
            statuses: ACM statuses to filter on (sent as ``CertificateStatuses``).
                ``None`` or an empty list applies no status filter.
            include_status: Populate ``CertificateSummary.status``. When False,
                ``status`` is always ``""``.

        Returns:
            Summaries in the order ACM pages them. Items lacking an ARN or a
            domain name are skipped. Domain names are lower-cased with
            trailing dots removed.

        Raises:
            Any error from the ACM client. With ``include_status=True`` this
            includes ``describe_certificate`` failures other than
            ``ResourceNotFoundException``. Errors such as throttling or access
            denied are surfaced rather than silently dropping the certificate.

        Notes:
            - Older ACM API versions omit ``Status`` from list results. With
              ``include_status=True`` the listed ``Status`` is used when present.
              Otherwise each certificate is described to obtain it. A
              certificate that disappears between list and describe is omitted.
        """
        kwargs: dict[str, object] = {}
        if statuses:
            kwargs["CertificateStatuses"] = statuses

        paginator = self._client.get_paginator("list_certificates")
        out: list[CertificateSummary] = []
        for page in paginator.paginate(**kwargs):
            for item in page.get("CertificateSummaryList", []):
                arn = item.get("CertificateArn")
                domain = item.get("DomainName")
                if not arn or not domain:
                    continue

                status = ""
                if include_status:
                    resolved = self._resolve_status(str(arn), item.get("Status"))
                    if resolved is None:
                        # The certificate vanished after it was listed.
                        continue
                    status = resolved

                out.append(
                    CertificateSummary(
                        certificate_arn=str(arn),
                        domain_name=self._normalize_domain(domain),
                        status=status,
                    )
                )
        return out

    def describe_certificate(self, *, certificate_arn: str) -> CertificateInfo:
        """Describe an ACM certificate and return normalized info.

        Args:
            certificate_arn: ARN of the certificate to describe. It is echoed
                back unchanged in ``CertificateInfo.certificate_arn``.

        Returns:
            A ``CertificateInfo`` with these properties:

            - ``domain_name`` and ``domains`` are lower-cased without trailing
              dots.
            - ``domains`` lists the primary domain followed by the SANs, with
              empties and duplicates removed (first occurrence wins).
            - ``validation_records`` holds one entry per validation option
              whose ``ResourceRecord`` has a Name, Value and Type.

        Raises:
            Any error raised by the ACM ``describe_certificate`` call.
        """
        resp = self._client.describe_certificate(CertificateArn=certificate_arn)
        cert = resp.get("Certificate") or {}

        status = str(cert.get("Status") or "")
        domain_name = self._normalize_domain(cert.get("DomainName") or "")
        sans = [
            self._normalize_domain(d)
            for d in (cert.get("SubjectAlternativeNames") or [])
            if d
        ]
        # Primary first, then SANs; drop empties and duplicates, keep order.
        domains = list(dict.fromkeys(d for d in (domain_name, *sans) if d))

        raw_reason = cert.get("FailureReason")
        failure_reason = str(raw_reason) if raw_reason else None

        validation_records: list[ValidationRecord] = []
        for dvo in cert.get("DomainValidationOptions") or []:
            rr = dvo.get("ResourceRecord")
            if not rr:
                # No resource record on this option (yet); nothing to publish.
                continue
            name = rr.get("Name")
            value = rr.get("Value")
            rr_type = rr.get("Type")
            if not (name and value and rr_type):
                continue
            validation_records.append(
                ValidationRecord(
                    domain_name=self._normalize_domain(dvo.get("DomainName") or ""),
                    name=self._canonical_fqdn(str(name)),
                    type=str(rr_type),
                    value=self._canonical_fqdn(str(value)),
                )
            )

        return CertificateInfo(
            certificate_arn=certificate_arn,
            status=status,
            domain_name=domain_name,
            domains=domains,
            validation_records=validation_records,
            failure_reason=failure_reason,
        )

    def get_dns_validation_records(self, *, certificate_arn: str) -> list[ValidationRecord]:
        """Return the normalized validation records for ``certificate_arn``.

        This is a convenience wrapper over ``describe_certificate``. The
        returned list is a fresh copy.
        """
        info = self.describe_certificate(certificate_arn=certificate_arn)
        return list(info.validation_records)

    def delete_certificate(self, *, certificate_arn: str) -> None:
        """Delete the ACM certificate identified by ``certificate_arn``.

        Errors from the ACM client are propagated unchanged.
        """
        self._client.delete_certificate(CertificateArn=certificate_arn)

    # ---------------------------------------------------------------------
    # Write operations (used by executor)
    # ---------------------------------------------------------------------

    def request_dns_validated_certificate(
        self,
        *,
        primary_domain: str,
        sans: list[str],
        tags: dict[str, str],
    ) -> str:
        """Request a DNS-validated certificate and return its ARN.

        Domains are normalized: whitespace trimmed, trailing dots removed,
        lower-cased. SANs are also deduplicated and sorted, so the idempotency
        token does not depend on input order. SANs that normalize to an empty
        string are dropped.

        ACM IdempotencyToken constraints:
        - max length 32
        - must match \\w+ (no hyphens)

        Raises:
            ProviderError: if ``primary_domain`` is empty after normalization,
                or ACM does not return a ``CertificateArn``.
        """
        primary = self._normalize_domain(primary_domain)
        if not primary:
            raise ProviderError("primary_domain must be a non-empty domain name")

        normalized_sans = {self._normalize_domain(s) for s in (sans or ()) if s}
        # Values such as "." normalize to "", which ACM would reject as a SAN.
        normalized_sans.discard("")
        san_norm = sorted(normalized_sans)

        # ACM-safe idempotency token: stable, <=32 chars, [0-9a-f] only.
        # Only the domains feed the token; tags do not.
        token_src = f"request_cert|{primary}|{','.join(san_norm)}"
        token = hashlib.sha256(token_src.encode("utf-8")).hexdigest()[:32]

        req: dict[str, object] = {
            "DomainName": primary,
            "ValidationMethod": "DNS",
            "IdempotencyToken": token,
        }
        tag_list = [{"Key": k, "Value": v} for k, v in (tags or {}).items()]
        if tag_list:
            req["Tags"] = tag_list
        # SubjectAlternativeNames is optional, but if provided it must have >= 1 item.
        if san_norm:
            req["SubjectAlternativeNames"] = san_norm

        resp = self._client.request_certificate(**req)
        arn = resp.get("CertificateArn")
        if not arn:
            raise ProviderError("ACM request_certificate did not return CertificateArn")
        return str(arn)

    # ---------------------------------------------------------------------
    # Helpers
    # ---------------------------------------------------------------------

    def _resolve_status(self, arn: str, listed_status: object) -> Optional[str]:
        """Return the certificate's status, or None if it no longer exists.

        ``listed_status`` (the ``Status`` from the list response) is used
        when non-empty. Otherwise the certificate is described. Only ACM's
        ``ResourceNotFoundException`` is treated as "gone"; other errors
        propagate.
        """
        if listed_status:
            return str(listed_status)
        try:
            resp = self._client.describe_certificate(CertificateArn=arn)
        except self._client.exceptions.ResourceNotFoundException:
            # Certificates can disappear between list and describe (eventual consistency).
            return None
        cert = resp.get("Certificate") or {}
        return str(cert.get("Status") or "")

    @staticmethod
    def _normalize_domain(value: object) -> str:
        """Return ``value`` as a domain name: trimmed, no trailing dots, lower-case."""
        return str(value).strip().rstrip(".").lower()

    @staticmethod
    def _canonical_fqdn(value: str) -> str:
        """Return ``value`` as a lower-case FQDN with exactly one trailing dot.

        Surrounding whitespace is trimmed; an empty/blank input returns "".
        """
        v = value.strip()
        if not v:
            return v
        return v.rstrip(".").lower() + "."