from __future__ import annotations
"""Planner for dxaws-cloudfront.
Pure planning logic:
- desired + current -> actions
This file deliberately does NOT call AWS. Provider wiring happens in executor/module.
"""
from dataclasses import replace
import hashlib
import re
from typing import Any
from .models import (
Action,
ActionType,
CacheBehaviorDesired,
CustomErrorResponsesDesired,
DistributionCurrent,
DistributionDesired,
HttpVersion,
Plan,
PriceClass,
S3OriginDesired,
ViewerCertificateDesired,
)
[docs]
def truncate_with_hash(s: str, *, max_len: int) -> str:
    """Sanitize ``s`` and bound it to ``max_len`` characters deterministically.

    CloudFront resource names have strict length limits. Runs of characters
    outside ``[A-Za-z0-9_.-]`` are collapsed to ``-`` and leading/trailing
    ``-``/``.`` are stripped. If the sanitized value fits, it is returned
    unchanged. Otherwise it is truncated and suffixed with ``-`` plus the first
    8 hex characters of the SHA-1 of the *original* (unsanitized) string. This
    keeps long inputs that share a prefix apart.

    If sanitization leaves nothing (e.g. ``s`` is only punctuation), the hash
    alone is returned rather than an empty string, since an empty name is not
    a usable resource name.

    Args:
        s: Raw name to sanitize.
        max_len: Maximum length of the result (expected to be non-negative).

    Returns:
        The sanitized, length-bounded name.
    """
    # Sanitize: keep alnum, dash, underscore, dot. Replace runs of other chars with '-'.
    safe = re.sub(r"[^A-Za-z0-9_.-]+", "-", s).strip("-.")
    # Hash the *original* string so inputs that sanitize alike still differ.
    digest = hashlib.sha1(s.encode("utf-8")).hexdigest()[:8]
    if not safe:
        # Nothing survived sanitization; the hash is the only stable identifier.
        return digest[:max_len]
    if len(safe) <= max_len:
        return safe
    # Reserve space for '-' + hash.
    keep = max_len - (1 + len(digest))
    if keep <= 0:
        return digest[:max_len]
    # safe[0] is never '-' or '.', so the prefix keeps at least one character.
    prefix = safe[:keep].rstrip("-.")
    return f"{prefix}-{digest}"
[docs]
def resolve_oac_name(d: DistributionDesired) -> str:
    """Return the Origin Access Control name to use for distribution ``d``.

    An explicit ``d.oac_name`` takes precedence. Otherwise a name is derived
    from ``d.name`` and passed through ``truncate_with_hash`` with a
    64-character bound, so it stays deterministic and length-limited.
    """
    explicit = d.oac_name
    if not explicit:
        # CloudFront enforces a max length for OAC names; keep this safely bounded.
        derived = f"dxaws-cloudfront-{d.name}-oac"
        return truncate_with_hash(derived, max_len=64)
    return explicit
[docs]
def resolve_origin_domain(origin: S3OriginDesired) -> str:
    """Resolve the origin domain name for ``origin``.

    Precedence: an explicit ``domain_name``, then the global S3 endpoint
    derived from ``bucket_name``.

    Raises:
        ValueError: if neither ``domain_name`` nor ``bucket_name`` is set.
    """
    if origin.domain_name:
        return origin.domain_name
    bucket = origin.bucket_name
    if not bucket:
        raise ValueError("origin requires either bucket_name or domain_name")
    # A regional endpoint is preferable for S3 origins behind OAC; the global
    # pattern is accepted for the MVP until the s3 module passes a regional name.
    return f"{bucket}.s3.amazonaws.com"
[docs]
def synth_viewer_certificate(vc: ViewerCertificateDesired) -> dict[str, Any]:
    """Build the ``ViewerCertificate`` section of a DistributionConfig.

    Without an ACM certificate ARN (falsy value), the CloudFront default
    certificate is used.

    With an ARN, the section carries:

    * the ARN itself;
    * the SSL support method and minimum protocol version from ``vc``;
    * the ``Certificate`` / ``CertificateSource`` pair, which the CloudFront
      API reference lists as deprecated counterparts of the ACM fields.
    """
    arn = vc.acm_certificate_arn
    if not arn:
        return {"CloudFrontDefaultCertificate": True}
    acm_section: dict[str, Any] = dict(
        ACMCertificateArn=arn,
        SSLSupportMethod=vc.ssl_support_method,
        MinimumProtocolVersion=vc.minimum_protocol_version,
        Certificate=arn,
        CertificateSource="acm",
    )
    return acm_section
[docs]
def synth_custom_errors(ce: CustomErrorResponsesDesired) -> dict[str, Any] | None:
    """Build the ``CustomErrorResponses`` section, or ``None`` when disabled.

    Every code in ``ce.error_codes`` gets an entry that shares
    ``ce.response_page_path``, ``ce.response_code`` and
    ``ce.error_caching_min_ttl``. ``ResponseCode`` is emitted as a string;
    ``ErrorCode`` and ``ErrorCachingMinTTL`` as ints.
    """
    if ce.enabled:
        entries: list[dict[str, Any]] = [
            {
                "ErrorCode": int(code),
                "ResponsePagePath": ce.response_page_path,
                "ResponseCode": str(int(ce.response_code)),
                "ErrorCachingMinTTL": int(ce.error_caching_min_ttl),
            }
            for code in ce.error_codes
        ]
        return {"Quantity": len(entries), "Items": entries}
    return None
[docs]
def synth_default_cache_behavior(
    cb: CacheBehaviorDesired,
    *,
    origin_id: str,
) -> dict[str, Any]:
    """Build the ``DefaultCacheBehavior`` section of a DistributionConfig.

    MVP shape for an S3 origin:

    * only GET and HEAD are allowed and cached;
    * no cache policy object is used yet; the legacy ``ForwardedValues``
      block forwards no query string and no cookies instead.

    Viewer protocol policy, compression and the three TTLs come from ``cb``.
    Requests are routed to ``origin_id``.
    """

    def read_only_methods() -> dict[str, Any]:
        # Fresh dict/list per call so the two method sets never share objects.
        return {"Quantity": 2, "Items": ["GET", "HEAD"]}

    allowed_methods = read_only_methods()
    allowed_methods["CachedMethods"] = read_only_methods()
    no_forwarding = {"QueryString": False, "Cookies": {"Forward": "none"}}
    return dict(
        TargetOriginId=origin_id,
        ViewerProtocolPolicy=cb.viewer_protocol_policy.value,
        AllowedMethods=allowed_methods,
        Compress=bool(cb.compress),
        ForwardedValues=no_forwarding,
        MinTTL=int(cb.min_ttl),
        DefaultTTL=int(cb.default_ttl),
        MaxTTL=int(cb.max_ttl),
    )
[docs]
def synth_distribution_config(
    d: DistributionDesired,
    *,
    origin_domain: str,
    oac_id: str | None,
) -> dict[str, Any]:
    """Assemble a deterministic CloudFront ``DistributionConfig`` dict (MVP).

    Args:
        d: Desired distribution settings.
        origin_domain: Domain name of the single S3 origin.
        oac_id: Origin Access Control id to attach to that origin; not set on
            the origin when falsy.

    Returns:
        A DistributionConfig dict containing:

        * one S3 origin with id ``"s3-origin"``;
        * the default cache behavior, viewer certificate and custom error
          responses built by the ``synth_*`` helpers;
        * logging, geo restriction and WAF left disabled or empty.

    NOTE(review): when ``d.comment`` is empty this calls ``stable_comment``,
    which is neither defined nor imported in the visible part of this module.
    Confirm it exists.
    """
    comment = d.comment or stable_comment(d.name)
    alias_list = list(d.aliases)
    origin_id = "s3-origin"

    s3_origin: dict[str, Any] = {
        "Id": origin_id,
        "DomainName": origin_domain,
        "OriginPath": "",
        "CustomHeaders": {"Quantity": 0},
        "ConnectionAttempts": 3,
        "ConnectionTimeout": 10,
        "OriginShield": {"Enabled": False},
        # S3 origin: an empty legacy OAI; the OAC id (if any) is set separately.
        "S3OriginConfig": {"OriginAccessIdentity": ""},
    }
    if oac_id:
        s3_origin["OriginAccessControlId"] = oac_id

    default_behavior = synth_default_cache_behavior(d.default_cache_behavior, origin_id=origin_id)
    error_responses = synth_custom_errors(d.custom_errors)
    certificate = synth_viewer_certificate(d.viewer_certificate)

    aliases_section: dict[str, Any] = (
        {"Quantity": len(alias_list), "Items": alias_list} if alias_list else {"Quantity": 0}
    )

    return {
        # NOTE(review): UpdateDistribution expects the CallerReference the
        # distribution was created with; confirm the executor preserves the
        # existing value on update rather than relying on d.name.
        "CallerReference": d.name,
        "Comment": comment,
        "Enabled": bool(d.enabled),
        "IsIPV6Enabled": bool(d.ipv6_enabled),
        "PriceClass": d.price_class.value,
        "HttpVersion": d.http_version.value,
        "DefaultRootObject": d.default_root_object,
        "Aliases": aliases_section,
        "Origins": {"Quantity": 1, "Items": [s3_origin]},
        "DefaultCacheBehavior": default_behavior,
        "ViewerCertificate": certificate,
        "Restrictions": {"GeoRestriction": {"RestrictionType": "none", "Quantity": 0}},
        "Logging": {"Enabled": False, "IncludeCookies": False, "Bucket": "", "Prefix": ""},
        "WebACLId": "",
        # Disabled/absent custom errors are still sent as an explicit empty list.
        "CustomErrorResponses": error_responses or {"Quantity": 0},
    }
[docs]
def normalize_current_for_compare(cur: DistributionCurrent) -> DistributionCurrent:
    """Return a copy of ``cur`` with unset collections replaced by empty ones.

    ``aliases`` becomes ``()`` and ``tags`` becomes ``{}`` when they are
    ``None``, so comparisons need no ``None`` checks. Every other field is
    carried over unchanged by ``dataclasses.replace``.
    """
    return replace(
        cur,
        aliases=() if cur.aliases is None else cur.aliases,
        tags={} if cur.tags is None else cur.tags,
    )
[docs]
def desired_summary_for_compare(desired: DistributionDesired) -> dict[str, Any]:
    """Build the flat summary of ``desired`` used for drift comparison.

    Comparing summaries instead of full configs keeps comparisons stable even
    if AWS returns extra fields in raw_config.

    The certificate fields use the same truthiness rule as
    ``synth_viewer_certificate``: a falsy ARN (``None`` or ``""``) means "use
    the CloudFront default certificate". Both cases are therefore reported as
    ``acm_certificate_arn=None`` / ``using_default_certificate=True``. An empty
    string no longer yields ``""`` / ``False``, which never matched a
    distribution actually using the default certificate.

    NOTE(review): the comment fallback calls ``stable_comment``, which is not
    defined or imported in the visible part of this module. Confirm it exists.
    """
    comment = desired.comment or stable_comment(desired.name)
    cert_arn = desired.viewer_certificate.acm_certificate_arn
    return {
        "enabled": bool(desired.enabled),
        "comment": comment,
        "aliases": tuple(desired.aliases),
        "origin_domain": resolve_origin_domain(desired.origin),
        "default_root_object": desired.default_root_object,
        "viewer_protocol_policy": desired.default_cache_behavior.viewer_protocol_policy.value,
        # Normalize "" to None so it agrees with the synthesized default-cert config.
        "acm_certificate_arn": cert_arn or None,
        "using_default_certificate": not cert_arn,
        "price_class": desired.price_class.value,
        "http_version": desired.http_version.value,
        "ipv6_enabled": bool(desired.ipv6_enabled),
    }
[docs]
def current_summary_for_compare(cur: DistributionCurrent) -> dict[str, Any]:
    """Project the observed distribution onto the flat comparison summary.

    ``cur`` is first passed through ``normalize_current_for_compare`` so unset
    aliases compare as an empty tuple. The keys mirror those produced for the
    desired state.
    """
    norm = normalize_current_for_compare(cur)
    alias_tuple = tuple(norm.aliases or ())
    return dict(
        enabled=norm.enabled,
        comment=norm.comment,
        aliases=alias_tuple,
        origin_domain=norm.origin_domain_name,
        default_root_object=norm.default_root_object,
        viewer_protocol_policy=norm.viewer_protocol_policy,
        acm_certificate_arn=norm.acm_certificate_arn,
        using_default_certificate=norm.using_default_certificate,
        price_class=norm.price_class,
        http_version=norm.http_version,
        ipv6_enabled=norm.ipv6_enabled,
    )
[docs]
def _wait_deployed_action() -> Action:
    """Return the WAIT_DEPLOYED action appended after deploy-triggering steps."""
    return Action(
        type=ActionType.WAIT_DEPLOYED,
        reason="wait for CloudFront distribution deployment",
        payload={},
    )


def _plan_absent(
    desired: DistributionDesired,
    current: DistributionCurrent,
    *,
    wait_deployed: bool,
) -> Plan:
    """Plan the destroy path (disable -> wait -> delete); noop if already gone."""
    # Already absent is a noop.
    if not current.exists or not current.id:
        return Plan(desired=desired, current=current, actions=[])
    actions: list[Action] = []
    needs_disable = current.enabled is not False
    if needs_disable:
        actions.append(
            Action(
                type=ActionType.DISABLE_DISTRIBUTION,
                reason="disable distribution before delete",
                payload={"distribution_id": current.id},
            )
        )
    # CloudFront only deletes a distribution that is disabled AND whose last
    # change has finished deploying. Wait after our own disable. Also wait when
    # the distribution is already disabled but its status is not yet "Deployed"
    # (for example, an earlier run disabled it and stopped before deleting).
    still_deploying = bool(current.status) and current.status != "Deployed"
    if wait_deployed and (needs_disable or still_deploying):
        actions.append(_wait_deployed_action())
    actions.append(
        Action(
            type=ActionType.DELETE_DISTRIBUTION,
            reason="delete distribution",
            payload={"distribution_id": current.id},
        )
    )
    return Plan(desired=desired, current=current, actions=actions)


def plan(
    desired: DistributionDesired,
    current: DistributionCurrent,
    *,
    oac_id: str | None,
    wait_deployed: bool = True,
) -> Plan:
    """Plan the actions that converge ``current`` towards ``desired``.

    Paths:

    * ``desired.present is False``: disable (if still enabled), wait for
      deployment, then delete. This is a noop when the distribution does not
      exist.
    * Distribution missing: create an OAC when ``oac_id`` is None, then create
      the distribution and optionally wait.
    * Distribution exists: update it when the comparison summaries differ,
      then optionally wait. Otherwise, optionally wait if it is still
      deploying. Finally, re-tag when explicit desired tags differ.

    Args:
        desired: Target distribution state.
        current: Observed distribution state.
        oac_id: OAC id supplied by the module/executor (ensured/created in
            AWS), or None if OAC is not used yet.
        wait_deployed: Whether to append WAIT_DEPLOYED actions.

    Returns:
        A ``Plan`` carrying ``desired``, ``current`` and the ordered actions.
    """
    if desired.present is False:
        return _plan_absent(desired, current, wait_deployed=wait_deployed)

    actions: list[Action] = []
    origin_domain = resolve_origin_domain(desired.origin)
    # Build desired config deterministically (even if current is missing).
    cfg = synth_distribution_config(desired, origin_domain=origin_domain, oac_id=oac_id)

    if not current.exists:
        if oac_id is None:
            actions.append(
                Action(
                    type=ActionType.CREATE_OAC,
                    reason="origin access control required",
                    payload={
                        "oac_name": resolve_oac_name(desired),
                        "description": f"dxaws-cloudfront OAC for {desired.name}",
                    },
                )
            )
        # NOTE(review): when oac_id is None, cfg carries no OriginAccessControlId.
        # Presumably the executor injects the id of the OAC created above; confirm.
        actions.append(
            Action(
                type=ActionType.CREATE_DISTRIBUTION,
                reason="distribution does not exist",
                payload={
                    "distribution_config": cfg,
                    "tags": desired.tags,
                    "comment": (desired.comment or stable_comment(desired.name)),
                },
            )
        )
        if wait_deployed:
            actions.append(_wait_deployed_action())
        return Plan(desired=desired, current=current, actions=actions)

    # Compare summarized current vs desired.
    ds = desired_summary_for_compare(desired)
    cs = current_summary_for_compare(current)
    # Comment is optional. Only enforce comment equality when the caller explicitly provides it.
    if desired.comment is None:
        ds.pop("comment", None)
        cs.pop("comment", None)
    # Aliases are a set of CNAMEs: compare them order-insensitively, so a pure
    # reordering (by the caller or in what AWS reports back) never forces an update.
    ds["aliases"] = tuple(sorted(ds["aliases"]))
    cs["aliases"] = tuple(sorted(cs["aliases"]))

    if cs != ds:
        # NOTE(review): if oac_id is None here, cfg has no OriginAccessControlId,
        # so the update would detach any OAC currently on the origin. Confirm
        # callers always pass the existing OAC id.
        actions.append(
            Action(
                type=ActionType.UPDATE_DISTRIBUTION,
                reason="distribution config differs",
                payload={
                    "distribution_id": current.id,
                    "distribution_config": cfg,
                },
            )
        )
        if wait_deployed:
            actions.append(_wait_deployed_action())
    elif wait_deployed and current.status and current.status != "Deployed":
        actions.append(_wait_deployed_action())

    # Tags are optional. Only plan tag changes when the caller explicitly provides desired tags.
    if desired.tags is not None:
        desired_tags = dict(desired.tags)
        if dict(current.tags or {}) != desired_tags:
            actions.append(
                Action(
                    type=ActionType.TAG_DISTRIBUTION,
                    reason="tags differ",
                    payload={"tags": desired_tags},
                )
            )
    return Plan(desired=desired, current=current, actions=actions)