# Source code for dxaws_cloudfront.planner

from __future__ import annotations

"""Planner for dxaws-cloudfront.

Pure planning logic:
- desired + current -> actions

This file deliberately does NOT call AWS. Provider wiring happens in executor/module.
"""

from dataclasses import replace
import hashlib
import re
from typing import Any

from .models import (
    Action,
    ActionType,
    CacheBehaviorDesired,
    CustomErrorResponsesDesired,
    DistributionCurrent,
    DistributionDesired,
    HttpVersion,
    Plan,
    PriceClass,
    S3OriginDesired,
    ViewerCertificateDesired,
)


def stable_comment(name: str) -> str:
    """Return the deterministic distribution comment for *name*.

    The same *name* always yields the same string, which is what makes the
    comment usable as a marker for idempotent discovery.
    """
    marker = "dxaws-cloudfront"
    return f"{marker}:{name}"
def truncate_with_hash(s: str, *, max_len: int) -> str:
    """Sanitize *s* and cap it at *max_len* characters, deterministically.

    Each run of characters outside ``[A-Za-z0-9_.-]`` becomes a single ``-``,
    and leading/trailing ``-`` / ``.`` are stripped. If the sanitized value
    fits in *max_len* it is returned as-is.

    Otherwise the value is shortened and suffixed with ``-`` plus the first 8
    hex digits of the SHA-1 of the *original* (unsanitized) input. Distinct
    long inputs that share a prefix therefore still map to distinct names.
    When *max_len* leaves no room for a prefix, only the hash (clipped to
    *max_len*) is returned.

    Note: inputs short enough to skip truncation are not hashed. Two inputs
    that sanitize to the same string yield the same result.
    """
    cleaned = re.sub(r"[^A-Za-z0-9_.-]+", "-", s).strip("-.")
    if len(cleaned) <= max_len:
        return cleaned

    # Hash the raw input, not the sanitized one, to keep long names distinct.
    digest = hashlib.sha1(s.encode("utf-8")).hexdigest()[:8]
    room = max_len - len(digest) - 1  # one char reserved for the "-" separator
    if room > 0:
        head = cleaned[:room].rstrip("-.")
        return f"{head}-{digest}"
    return digest[:max_len]
def resolve_oac_name(d: DistributionDesired) -> str:
    """Return the Origin Access Control name to use for distribution *d*.

    An explicit (truthy) ``d.oac_name`` wins. Otherwise a name is derived as
    ``dxaws-cloudfront-<name>-oac`` and passed through ``truncate_with_hash``.
    The 64-character cap keeps it within CloudFront's OAC name length limit
    while staying deterministic.
    """
    explicit = d.oac_name
    if not explicit:
        derived = f"dxaws-cloudfront-{d.name}-oac"
        return truncate_with_hash(derived, max_len=64)
    return explicit
def resolve_origin_domain(origin: S3OriginDesired) -> str:
    """Return the domain name CloudFront should use for *origin*.

    Precedence:
      1. an explicit (truthy) ``domain_name``;
      2. the global S3 endpoint ``<bucket_name>.s3.amazonaws.com``.

    Raises:
        ValueError: if neither ``domain_name`` nor ``bucket_name`` is set.
    """
    explicit_domain = origin.domain_name
    if explicit_domain:
        return explicit_domain

    bucket = origin.bucket_name
    if not bucket:
        raise ValueError("origin requires either bucket_name or domain_name")
    # A regional S3 endpoint is preferable with OAC. The global form is an MVP
    # shortcut until callers pass a regional domain via domain_name.
    return f"{bucket}.s3.amazonaws.com"
def synth_viewer_certificate(vc: ViewerCertificateDesired) -> dict[str, Any]:
    """Build the CloudFront ``ViewerCertificate`` block from *vc*.

    With a (truthy) ACM certificate ARN, emits the ARN, SSL support method and
    minimum protocol version. It also emits ``Certificate`` (same ARN) and
    ``CertificateSource="acm"``. Without an ARN, the CloudFront default
    certificate is requested.
    """
    arn = vc.acm_certificate_arn
    if not arn:
        return {"CloudFrontDefaultCertificate": True}

    cert: dict[str, Any] = {}
    cert["ACMCertificateArn"] = arn
    cert["SSLSupportMethod"] = vc.ssl_support_method
    cert["MinimumProtocolVersion"] = vc.minimum_protocol_version
    # NOTE(review): the CloudFront API documents Certificate/CertificateSource
    # as deprecated in favour of ACMCertificateArn. They are kept here so the
    # emitted config stays unchanged.
    cert["Certificate"] = arn
    cert["CertificateSource"] = "acm"
    return cert
def synth_custom_errors(ce: CustomErrorResponsesDesired) -> dict[str, Any] | None:
    """Build the ``CustomErrorResponses`` block, or return ``None`` if disabled.

    Each code in ``ce.error_codes`` becomes one item. All items share the same
    response page path, response code (sent as a string, e.g. ``"200"``) and
    error-caching minimum TTL.
    """
    if not ce.enabled:
        return None

    # Fields are evaluated per item, so an empty error_codes never reads the
    # shared response settings.
    entries: list[dict[str, Any]] = [
        {
            "ErrorCode": int(code),
            "ResponsePagePath": ce.response_page_path,
            "ResponseCode": str(int(ce.response_code)),
            "ErrorCachingMinTTL": int(ce.error_caching_min_ttl),
        }
        for code in ce.error_codes
    ]
    return {"Quantity": len(entries), "Items": entries}
def synth_default_cache_behavior(
    cb: CacheBehaviorDesired,
    *,
    origin_id: str,
) -> dict[str, Any]:
    """Build the ``DefaultCacheBehavior`` block routing to *origin_id*.

    MVP behaviour for an S3 origin:
      * only GET/HEAD are allowed, and both are cached;
      * caching uses the legacy ``ForwardedValues`` settings (no query
        strings, no cookies) rather than a cache policy object.

    The viewer protocol policy (enum ``.value``), compression flag and the
    three TTLs come from *cb*.
    """
    behavior: dict[str, Any] = {}
    behavior["TargetOriginId"] = origin_id
    behavior["ViewerProtocolPolicy"] = cb.viewer_protocol_policy.value
    behavior["AllowedMethods"] = {
        "Quantity": 2,
        "Items": ["GET", "HEAD"],
        "CachedMethods": {"Quantity": 2, "Items": ["GET", "HEAD"]},
    }
    behavior["Compress"] = bool(cb.compress)
    # Legacy cache-key settings: forward neither query strings nor cookies.
    behavior["ForwardedValues"] = {
        "QueryString": False,
        "Cookies": {"Forward": "none"},
    }
    behavior["MinTTL"] = int(cb.min_ttl)
    behavior["DefaultTTL"] = int(cb.default_ttl)
    behavior["MaxTTL"] = int(cb.max_ttl)
    return behavior
def synth_distribution_config(
    d: DistributionDesired,
    *,
    origin_domain: str,
    oac_id: str | None,
) -> dict[str, Any]:
    """Assemble a deterministic CloudFront ``DistributionConfig`` dict (MVP).

    Shape of the result:
      * exactly one S3 origin with id ``"s3-origin"`` pointing at
        *origin_domain*. A truthy *oac_id* is attached as its
        ``OriginAccessControlId``; the legacy OAI field is always empty;
      * the default cache behavior, viewer certificate and custom error
        responses come from the ``synth_*`` helpers. A missing/empty
        custom-error block becomes ``{"Quantity": 0}``;
      * logging and geo restrictions are emitted disabled, and ``WebACLId``
        is empty;
      * ``Comment`` falls back to ``stable_comment(d.name)``.
    """
    comment = d.comment or stable_comment(d.name)

    aliases = list(d.aliases)
    aliases_block: dict[str, Any] = {"Quantity": len(aliases)}
    if aliases:
        aliases_block["Items"] = aliases

    origin_id = "s3-origin"
    s3_origin: dict[str, Any] = {
        "Id": origin_id,
        "DomainName": origin_domain,
        "OriginPath": "",
        "CustomHeaders": {"Quantity": 0},
        "ConnectionAttempts": 3,
        "ConnectionTimeout": 10,
        "OriginShield": {"Enabled": False},
        # S3 origins use S3OriginConfig. Access goes through OAC (when an id
        # is given), so the legacy origin access identity stays empty.
        "S3OriginConfig": {"OriginAccessIdentity": ""},
    }
    if oac_id:
        s3_origin["OriginAccessControlId"] = oac_id

    cache_behavior = synth_default_cache_behavior(
        d.default_cache_behavior, origin_id=origin_id
    )
    error_responses = synth_custom_errors(d.custom_errors)
    certificate = synth_viewer_certificate(d.viewer_certificate)

    config: dict[str, Any] = {
        "CallerReference": d.name,  # ok for create; ignored for update
        "Comment": comment,
        "Enabled": bool(d.enabled),
        "IsIPV6Enabled": bool(d.ipv6_enabled),
        "PriceClass": d.price_class.value,
        "HttpVersion": d.http_version.value,
        "DefaultRootObject": d.default_root_object,
        "Aliases": aliases_block,
        "Origins": {"Quantity": 1, "Items": [s3_origin]},
        "DefaultCacheBehavior": cache_behavior,
        "ViewerCertificate": certificate,
        "Restrictions": {"GeoRestriction": {"RestrictionType": "none", "Quantity": 0}},
        "Logging": {"Enabled": False, "IncludeCookies": False, "Bucket": "", "Prefix": ""},
        "WebACLId": "",
    }
    config["CustomErrorResponses"] = error_responses or {"Quantity": 0}
    return config
def normalize_current_for_compare(cur: DistributionCurrent) -> DistributionCurrent:
    """Return a copy of *cur* whose ``aliases`` and ``tags`` are never ``None``.

    ``None`` aliases become ``()`` and ``None`` tags become ``{}``. Every other
    field is carried over unchanged via ``dataclasses.replace``.
    """
    return replace(
        cur,
        aliases=() if cur.aliases is None else cur.aliases,
        tags={} if cur.tags is None else cur.tags,
    )
def desired_summary_for_compare(desired: DistributionDesired) -> dict[str, Any]:
    """Summarize the desired distribution into a flat dict for drift checks.

    Comparing a small summary instead of the full config keeps the diff stable
    even if AWS returns extra fields in raw_config. The keys mirror those
    produced by ``current_summary_for_compare``.

    The certificate fields use the same truthiness rule as
    ``synth_viewer_certificate``. An empty/absent ACM ARN means "use the
    CloudFront default certificate", so it is reported as ``None`` with
    ``using_default_certificate=True``. Otherwise an empty-string ARN would
    never match and would force an update on every run.

    NOTE(review): cache TTLs, compression, custom error responses, the SSL
    support method / minimum protocol version and the OAC id are not part of
    this summary. Changing only those does not trigger an update.
    """
    comment = desired.comment or stable_comment(desired.name)
    # Normalize "" to None so the summary agrees with synth_viewer_certificate.
    acm_arn = desired.viewer_certificate.acm_certificate_arn or None
    return {
        "enabled": bool(desired.enabled),
        "comment": comment,
        "aliases": tuple(desired.aliases),
        "origin_domain": resolve_origin_domain(desired.origin),
        "default_root_object": desired.default_root_object,
        "viewer_protocol_policy": desired.default_cache_behavior.viewer_protocol_policy.value,
        "acm_certificate_arn": acm_arn,
        "using_default_certificate": acm_arn is None,
        "price_class": desired.price_class.value,
        "http_version": desired.http_version.value,
        "ipv6_enabled": bool(desired.ipv6_enabled),
    }
def current_summary_for_compare(cur: DistributionCurrent) -> dict[str, Any]:
    """Summarize the observed distribution for comparison with the desired one.

    *cur* is first passed through ``normalize_current_for_compare``. Aliases
    are turned into a tuple, with ``None`` treated as empty. All other values
    are copied from the ``DistributionCurrent`` fields as-is.
    """
    normalized = normalize_current_for_compare(cur)
    return dict(
        enabled=normalized.enabled,
        comment=normalized.comment,
        aliases=tuple(normalized.aliases or ()),
        origin_domain=normalized.origin_domain_name,
        default_root_object=normalized.default_root_object,
        viewer_protocol_policy=normalized.viewer_protocol_policy,
        acm_certificate_arn=normalized.acm_certificate_arn,
        using_default_certificate=normalized.using_default_certificate,
        price_class=normalized.price_class,
        http_version=normalized.http_version,
        ipv6_enabled=normalized.ipv6_enabled,
    )
def _wait_deployed_action() -> Action:
    """Build the WAIT_DEPLOYED action shared by every convergence path."""
    return Action(
        type=ActionType.WAIT_DEPLOYED,
        reason="wait for CloudFront distribution deployment",
        payload={},
    )


def _plan_absent(current: DistributionCurrent, *, wait_deployed: bool) -> list[Action]:
    """Plan removal of an existing distribution: disable -> (wait) -> delete."""
    # Nothing exists (or there is no id to act on): already converged.
    if not current.exists or not current.id:
        return []

    actions: list[Action] = []
    # CloudFront only deletes disabled distributions. Disable first unless the
    # distribution is already known to be disabled.
    if current.enabled is not False:
        actions.append(
            Action(
                type=ActionType.DISABLE_DISTRIBUTION,
                reason="disable distribution before delete",
                payload={"distribution_id": current.id},
            )
        )
    if wait_deployed:
        actions.append(_wait_deployed_action())
    actions.append(
        Action(
            type=ActionType.DELETE_DISTRIBUTION,
            reason="delete distribution",
            payload={"distribution_id": current.id},
        )
    )
    return actions


def _plan_create(
    desired: DistributionDesired,
    cfg: dict[str, Any],
    *,
    oac_id: str | None,
    wait_deployed: bool,
) -> list[Action]:
    """Plan creation of a missing distribution, plus an OAC when none is known."""
    actions: list[Action] = []
    if oac_id is None:
        # NOTE(review): cfg was synthesized with oac_id=None here, so it has no
        # OriginAccessControlId. Presumably the executor attaches the newly
        # created OAC before creating the distribution; confirm.
        actions.append(
            Action(
                type=ActionType.CREATE_OAC,
                reason="origin access control required",
                payload={
                    "oac_name": resolve_oac_name(desired),
                    "description": f"dxaws-cloudfront OAC for {desired.name}",
                },
            )
        )
    actions.append(
        Action(
            type=ActionType.CREATE_DISTRIBUTION,
            reason="distribution does not exist",
            payload={
                "distribution_config": cfg,
                "tags": desired.tags,
                "comment": (desired.comment or stable_comment(desired.name)),
            },
        )
    )
    if wait_deployed:
        actions.append(_wait_deployed_action())
    return actions


def _plan_update(
    desired: DistributionDesired,
    current: DistributionCurrent,
    cfg: dict[str, Any],
    *,
    wait_deployed: bool,
) -> list[Action]:
    """Plan an in-place update when the summarized configs differ.

    If nothing differs but the distribution has a status other than
    ``"Deployed"``, only a wait is planned (when *wait_deployed* is set).
    """
    ds = desired_summary_for_compare(desired)
    cs = current_summary_for_compare(current)

    # Comment is optional. Only enforce comment equality when the caller
    # explicitly provides it.
    if desired.comment is None:
        ds.pop("comment", None)
        cs.pop("comment", None)

    # Aliases (CNAMEs) have no meaningful order. Compare them
    # order-insensitively so an identical but reordered list does not
    # trigger a spurious update.
    ds["aliases"] = tuple(sorted(ds["aliases"]))
    cs["aliases"] = tuple(sorted(cs["aliases"]))

    if cs != ds:
        actions: list[Action] = [
            Action(
                type=ActionType.UPDATE_DISTRIBUTION,
                reason="distribution config differs",
                payload={
                    "distribution_id": current.id,
                    "distribution_config": cfg,
                },
            )
        ]
        if wait_deployed:
            actions.append(_wait_deployed_action())
        return actions

    if wait_deployed and current.status and current.status != "Deployed":
        return [_wait_deployed_action()]
    return []


def _plan_tags(desired: DistributionDesired, current: DistributionCurrent) -> list[Action]:
    """Plan a tag replacement when the caller provided tags and they differ."""
    # Tags are optional. Only plan tag changes when the caller explicitly
    # provides desired tags.
    if desired.tags is None:
        return []
    desired_tags = dict(desired.tags)
    if dict(current.tags or {}) == desired_tags:
        return []
    # NOTE(review): unlike the other actions, this payload carries no
    # distribution_id. Presumably the executor takes it from plan.current;
    # confirm.
    return [
        Action(
            type=ActionType.TAG_DISTRIBUTION,
            reason="tags differ",
            payload={"tags": desired_tags},
        )
    ]


def plan(
    desired: DistributionDesired,
    current: DistributionCurrent,
    *,
    oac_id: str | None,
    wait_deployed: bool = True,
) -> Plan:
    """Plan convergence actions from *current* toward *desired* (no AWS calls).

    Paths:
      * ``desired.present is False``: disable -> wait -> delete an existing
        distribution, or no actions if it is already absent;
      * distribution missing: optionally create an OAC (when *oac_id* is
        None), then create the distribution and wait;
      * distribution exists: update it when the summarized config differs
        (or just wait if it is still rolling out), then re-tag if the
        explicitly provided tags differ.

    Args:
        desired: Target distribution state.
        current: Observed distribution state.
        oac_id: Supplied by the module/executor (ensure/create in AWS), or
            None if OAC is not used yet.
        wait_deployed: When true, insert WAIT_DEPLOYED actions after changes.

    Returns:
        A ``Plan`` holding *desired*, *current* and the ordered actions.

    Raises:
        ValueError: from ``resolve_origin_domain`` when the origin has neither
            a bucket name nor a domain name (only on the present path).
    """
    if desired.present is False:
        return Plan(
            desired=desired,
            current=current,
            actions=_plan_absent(current, wait_deployed=wait_deployed),
        )

    # Build the desired config deterministically, even if current is missing.
    origin_domain = resolve_origin_domain(desired.origin)
    cfg = synth_distribution_config(desired, origin_domain=origin_domain, oac_id=oac_id)

    if not current.exists:
        actions = _plan_create(desired, cfg, oac_id=oac_id, wait_deployed=wait_deployed)
    else:
        actions = _plan_update(desired, current, cfg, wait_deployed=wait_deployed)
        actions += _plan_tags(desired, current)
    return Plan(desired=desired, current=current, actions=actions)