# Source code for dxaws_cloudfront.providers.aws

from __future__ import annotations

"""AWS provider implementation for dxaws-cloudfront.

This module is intentionally focused on AWS API access + low-level helpers.
Higher-level convergence logic lives in planner/executor.

We start by wiring an AWS CloudFront client and a few pragmatic helper methods
that we will use for:
- discovering distributions by tag/name
- creating/updating distributions
- creating/updating Origin Access Control (OAC)
- waiting for distribution deployment

The exact desired/current models will be layered on top in later steps.
"""

from dataclasses import dataclass
from typing import Any

from botocore.exceptions import ClientError

from dxaws_core.providers.aws import AwsProviderBase


def _err_code(e: ClientError) -> str:
    return str(e.response.get("Error", {}).get("Code", ""))


@dataclass(frozen=True)
class DistributionRef:
    """Minimal reference for a CloudFront distribution.

    Immutable value object; every field except ``id`` may be ``None``.
    """

    # Distribution identifier (the ``Id`` field of CloudFront responses).
    id: str
    # Distribution ARN (``ARN``).
    arn: str | None
    # Distribution domain name (``DomainName``).
    domain_name: str | None
    # Status string as reported by CloudFront (``Status``).
    status: str | None
    # Free-form comment stored on the distribution (``Comment``).
    comment: str | None
@dataclass(frozen=True)
class OacRef:
    """Minimal reference for an Origin Access Control.

    Immutable value object pairing an OAC identifier with its name.
    """

    # Origin Access Control identifier (``Id``).
    id: str
    # Origin Access Control name (``Name``).
    name: str
class AwsProvider(AwsProviderBase):
    """AWS implementation for dxaws-cloudfront.

    Wraps the boto3 CloudFront client with read-only discovery helpers, tag
    helpers, Origin Access Control (OAC) management, and low-level
    distribution create/update/disable/delete calls.
    """

    def __init__(self, *, aws) -> None:
        # provider_ns is used for tagging, logging, and idempotency token namespaces.
        super().__init__(aws=aws, provider_ns="cloudfront")

    # ---------------------------------------------------------------------
    # Clients
    # ---------------------------------------------------------------------

    @property
    def cf(self):
        """CloudFront boto3 client."""
        return self.aws.client("cloudfront")

    # ---------------------------------------------------------------------
    # Discovery helpers (read-only)
    # ---------------------------------------------------------------------

    def _iter_distribution_summaries(self):
        """Yield every raw distribution summary from ``list_distributions``."""
        paginator = self.cf.get_paginator("list_distributions")
        for page in paginator.paginate():
            yield from ((page.get("DistributionList") or {}).get("Items")) or []

    @staticmethod
    def _ref_from_summary(summary: dict[str, Any]) -> DistributionRef:
        """Build a ``DistributionRef`` from one ``list_distributions`` item."""
        return DistributionRef(
            id=summary.get("Id"),
            arn=summary.get("ARN"),
            domain_name=summary.get("DomainName"),
            status=summary.get("Status"),
            comment=summary.get("Comment"),
        )

    @staticmethod
    def _ref_from_distribution(resp: dict[str, Any]) -> DistributionRef:
        """Build a ``DistributionRef`` from a create/update response.

        The comment is read from the nested ``DistributionConfig``; missing
        sections are treated as empty.
        """
        dist = resp.get("Distribution") or {}
        dist_cfg = dist.get("DistributionConfig") or {}
        return DistributionRef(
            id=dist.get("Id"),
            arn=dist.get("ARN"),
            domain_name=dist.get("DomainName"),
            status=dist.get("Status"),
            comment=dist_cfg.get("Comment"),
        )

    @staticmethod
    def _canonical_alias(value: Any) -> str:
        """Normalize a host name: trimmed, lower-cased, exactly one trailing dot."""
        return str(value).strip().lower().rstrip(".") + "."

    def list_distributions(self) -> list[DistributionRef]:
        """List distributions (best-effort summary).

        CloudFront is a global service; this uses the global endpoint.

        Returns:
            One ``DistributionRef`` per listed distribution; items without
            an ``Id`` are dropped.
        """
        refs = (self._ref_from_summary(s) for s in self._iter_distribution_summaries())
        # Filter out any weird empty items
        return [ref for ref in refs if ref.id]

    def get_distribution(self, dist_id: str) -> tuple[dict[str, Any], str]:
        """Fetch the full distribution config and ETag.

        Returns:
            ``(DistributionConfig, ETag)``.

        Raises:
            RuntimeError: if the response lacks either the ETag or the config.
        """
        resp = self.cf.get_distribution_config(Id=dist_id)
        etag = resp.get("ETag")
        cfg = resp.get("DistributionConfig")
        if not etag or not cfg:
            raise RuntimeError("CloudFront get_distribution_config returned no ETag/config")
        return cfg, etag

    def get_distribution_status(self, dist_id: str) -> str:
        """Return the distribution's ``Status`` string, or ``""`` if absent."""
        resp = self.cf.get_distribution(Id=dist_id)
        dist = resp.get("Distribution") or {}
        return str(dist.get("Status") or "")

    def find_distribution_by_alias(self, alias: str) -> DistributionRef | None:
        """Return the first distribution that lists *alias*, else ``None``.

        Comparison is case-insensitive and ignores surrounding whitespace and
        a trailing dot.  Aliases are read from the ``Aliases`` section of each
        ``list_distributions`` summary, so no per-distribution config fetch
        is needed.
        """
        wanted = self._canonical_alias(alias)
        for summary in self._iter_distribution_summaries():
            if not summary.get("Id"):
                # Same filter as list_distributions: skip items with no Id.
                continue
            items = ((summary.get("Aliases") or {}).get("Items")) or []
            if wanted in {self._canonical_alias(a) for a in items}:
                return self._ref_from_summary(summary)
        return None

    # ---------------------------------------------------------------------
    # Tag helpers
    # ---------------------------------------------------------------------

    @staticmethod
    def _tag_items(tags: dict[str, str]) -> list[dict[str, str]]:
        """Convert a tag mapping to CloudFront ``Items``, sorted by key."""
        return [{"Key": k, "Value": v} for k, v in sorted(tags.items())]

    def list_tags(self, resource_arn: str) -> dict[str, str]:
        """Return the tags on *resource_arn* as a ``{key: value}`` dict.

        Entries whose key or value is ``None`` are skipped; keys and values
        are converted to ``str``.
        """
        resp = self.cf.list_tags_for_resource(Resource=resource_arn)
        items = ((resp.get("Tags") or {}).get("Items")) or []
        return {
            str(it.get("Key")): str(it.get("Value"))
            for it in items
            if it.get("Key") is not None and it.get("Value") is not None
        }

    def tag_resource(self, resource_arn: str, tags: dict[str, str]) -> None:
        """Apply *tags* to *resource_arn* (items sent in key order)."""
        self.cf.tag_resource(Resource=resource_arn, Tags={"Items": self._tag_items(tags)})

    # ---------------------------------------------------------------------
    # Origin Access Control (OAC)
    # ---------------------------------------------------------------------

    def list_oacs(self) -> list[OacRef]:
        """List Origin Access Controls that have both an id and a name."""
        out: list[OacRef] = []
        paginator = self.cf.get_paginator("list_origin_access_controls")
        for page in paginator.paginate():
            items = ((page.get("OriginAccessControlList") or {}).get("Items")) or []
            for it in items:
                if not isinstance(it, dict):
                    continue
                # AWS may return either a summary shape:
                #   {"Id": "...", "Name": "...", ...}
                # or a nested shape (older/alternate):
                #   {"OriginAccessControl": {"Id": "...", "Name": "...", ...}}
                if "Id" in it and "Name" in it:
                    record = it
                else:
                    record = it.get("OriginAccessControl") or {}
                oac_id = record.get("Id")
                name = (record.get("Name") or "").strip()
                if oac_id and name:
                    out.append(OacRef(id=oac_id, name=name))
        return out

    def find_oac_by_name(self, name: str) -> OacRef | None:
        """Return the first OAC whose name equals *name* exactly, else ``None``."""
        return next((o for o in self.list_oacs() if o.name == name), None)

    def ensure_oac(
        self,
        *,
        name: str,
        description: str,
        signing_behavior: str = "always",
        signing_protocol: str = "sigv4",
        origin_type: str = "s3",
    ) -> OacRef:
        """Ensure an OAC exists and return a reference to it.

        For MVP we create if missing; if CloudFront says it already exists,
        we re-query by name and return it.

        Raises:
            ClientError: any create failure other than a resolvable
                ``OriginAccessControlAlreadyExists``.
            RuntimeError: if the create response carries no ``Id``.
        """
        existing = self.find_oac_by_name(name)
        if existing:
            return existing

        try:
            resp = self.cf.create_origin_access_control(
                OriginAccessControlConfig={
                    "Name": name,
                    "Description": description,
                    "SigningProtocol": signing_protocol,
                    "SigningBehavior": signing_behavior,
                    "OriginAccessControlOriginType": origin_type,
                }
            )
        except ClientError as e:
            if _err_code(e) == "OriginAccessControlAlreadyExists":
                # Race/duplicate: fetch again and return
                existing = self.find_oac_by_name(name)
                if existing:
                    return existing
            raise

        oac = resp.get("OriginAccessControl") or {}
        oac_cfg = oac.get("OriginAccessControlConfig") or {}
        oac_id = oac.get("Id")
        if not oac_id:
            raise RuntimeError("CloudFront create_origin_access_control returned no Id")
        # Fall back to the requested name when the response omits it.
        return OacRef(id=oac_id, name=oac_cfg.get("Name") or name)

    # ---------------------------------------------------------------------
    # Distribution create/update (low-level)
    # ---------------------------------------------------------------------

    @staticmethod
    def _quantity_shape(value: Any) -> dict[str, Any]:
        """Return a copy of a ``{"Quantity": n, "Items": [...]}`` structure.

        Non-dict input becomes ``{}``.  ``Quantity`` defaults to 0, and
        ``Items`` is dropped when ``Quantity`` is 0 (CloudFront is strict).
        """
        shape = dict(value) if isinstance(value, dict) else {}
        shape.setdefault("Quantity", 0)
        if shape["Quantity"] == 0:
            shape.pop("Items", None)
        return shape

    @staticmethod
    def _forwarded_values(value: Any) -> dict[str, Any]:
        """Return legacy ``ForwardedValues`` with the required nested shapes.

        Fills ``QueryString``, ``Cookies.Forward``, ``Headers`` and
        ``QueryStringCacheKeys`` when absent; existing values are kept.
        """
        fv = dict(value) if isinstance(value, dict) else {}
        fv.setdefault("QueryString", False)

        cookies = fv.get("Cookies")
        cookies = dict(cookies) if isinstance(cookies, dict) else {}
        cookies.setdefault("Forward", "none")
        fv["Cookies"] = cookies

        fv["Headers"] = AwsProvider._quantity_shape(fv.get("Headers"))
        fv["QueryStringCacheKeys"] = AwsProvider._quantity_shape(fv.get("QueryStringCacheKeys"))
        return fv

    @staticmethod
    def _normalize_behavior(behavior: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of one cache behavior with mandatory fields filled in."""
        out = dict(behavior)
        out.setdefault("SmoothStreaming", False)
        out.setdefault("Compress", False)
        out.setdefault("FieldLevelEncryptionId", "")
        # If the modern policy id is set, ForwardedValues is not required.
        if not out.get("CachePolicyId"):
            out["ForwardedValues"] = AwsProvider._forwarded_values(out.get("ForwardedValues"))
        out["LambdaFunctionAssociations"] = AwsProvider._quantity_shape(
            out.get("LambdaFunctionAssociations")
        )
        out["FunctionAssociations"] = AwsProvider._quantity_shape(out.get("FunctionAssociations"))
        return out

    def _ensure_required_flags(self, config: dict[str, Any]) -> dict[str, Any]:
        """CloudFront API requires certain boolean flags to be present.

        Some fields are mandatory even when unused. For example, CacheBehavior
        requires SmoothStreaming to be explicitly set. We defensively fill
        these in so planners/executors can stay focused on convergence logic.

        The input is not modified; a copy is returned with the top level,
        ``CacheBehaviors``, ``DefaultCacheBehavior`` and every dict-typed
        cache behavior item copied before being changed.
        """
        cfg = dict(config)

        # CloudFront requires CacheBehaviors to be present even when empty.
        behaviors = cfg.get("CacheBehaviors")
        behaviors = dict(behaviors) if isinstance(behaviors, dict) else {"Quantity": 0}

        # Normalize CacheBehaviors Quantity/Items.
        items = behaviors.get("Items")
        if items is None:
            behaviors["Quantity"] = int(behaviors.get("Quantity") or 0)
            if behaviors["Quantity"] == 0:
                behaviors.pop("Items", None)
        elif isinstance(items, list):
            behaviors["Quantity"] = len(items)
            if items:
                # Non-dict items are passed through untouched.
                behaviors["Items"] = [
                    self._normalize_behavior(it) if isinstance(it, dict) else it for it in items
                ]
            else:
                behaviors.pop("Items", None)
        cfg["CacheBehaviors"] = behaviors

        default_behavior = cfg.get("DefaultCacheBehavior")
        if isinstance(default_behavior, dict):
            cfg["DefaultCacheBehavior"] = self._normalize_behavior(default_behavior)

        return cfg

    def create_distribution(
        self, *, config: dict[str, Any], tags: dict[str, str] | None = None
    ) -> DistributionRef:
        """Create a distribution from a raw DistributionConfig.

        Args:
            config: raw ``DistributionConfig``; required flags are filled in
                before the call.
            tags: optional tags; when non-empty the ``with_tags`` API is used.
        """
        config = self._ensure_required_flags(config)

        # CloudFront has two different create APIs:
        # - create_distribution(DistributionConfig=...)
        # - create_distribution_with_tags(DistributionConfigWithTags=...)
        # botocore validates input params strictly, so we must call the correct one.
        if tags:
            resp = self.cf.create_distribution_with_tags(
                DistributionConfigWithTags={
                    "DistributionConfig": config,
                    "Tags": {"Items": self._tag_items(tags)},
                }
            )
        else:
            resp = self.cf.create_distribution(DistributionConfig=config)

        return self._ref_from_distribution(resp)

    def update_distribution(
        self, *, dist_id: str, if_match: str, config: dict[str, Any]
    ) -> DistributionRef:
        """Update a distribution with *config*, guarded by the *if_match* ETag."""
        config = self._ensure_required_flags(config)
        resp = self.cf.update_distribution(Id=dist_id, IfMatch=if_match, DistributionConfig=config)
        return self._ref_from_distribution(resp)

    def disable_distribution(
        self,
        dist_id: str,
        *,
        timeout_s: float = 1200.0,
        poll_interval_s: float = 10.0,
    ) -> None:
        """Disable a distribution (required before delete).

        CloudFront is eventually consistent: a distribution can report
        Status=Deployed from a previous state even after we submit an update.
        We therefore wait until the config reflects Enabled=False and the
        distribution is Deployed.

        If the config is already disabled no update is submitted, but the
        wait still runs so that callers can delete right after this returns.

        Args:
            dist_id: distribution to disable.
            timeout_s: maximum time to wait, in seconds (default 20 minutes).
            poll_interval_s: pause between polls, in seconds.

        Raises:
            TimeoutError: if the disabled+Deployed state is not reached in time.
        """
        import time

        cfg, etag = self.get_distribution(dist_id)
        if cfg.get("Enabled") is not False:
            disabled_cfg = dict(cfg)
            disabled_cfg["Enabled"] = False
            disabled_cfg = self._ensure_required_flags(disabled_cfg)
            # Submit the disable update.
            self.cf.update_distribution(Id=dist_id, IfMatch=etag, DistributionConfig=disabled_cfg)

        # Wait for CloudFront to reflect the disabled state.
        deadline = time.monotonic() + timeout_s
        while True:
            status = self.get_distribution_status(dist_id)
            try:
                cur_cfg, _ = self.get_distribution(dist_id)
            except Exception:
                # Best-effort: a failed config read counts as "not yet
                # disabled" and is retried on the next poll.
                cur_cfg = {}

            enabled = cur_cfg.get("Enabled")
            if enabled is False and status == "Deployed":
                return

            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Timed out waiting for distribution to disable (status={status}, enabled={enabled})"
                )

            time.sleep(poll_interval_s)

    def delete_distribution(self, dist_id: str) -> None:
        """Delete a distribution.

        CloudFront requires the distribution to be disabled and deployed
        before delete.  The current ETag is fetched and sent as ``IfMatch``.
        """
        _, etag = self.get_distribution(dist_id)
        self.cf.delete_distribution(Id=dist_id, IfMatch=etag)

    # ---------------------------------------------------------------------
    # Waiters
    # ---------------------------------------------------------------------

    def wait_deployed(self, dist_id: str) -> None:
        """Wait for a distribution to reach Deployed."""
        waiter = self.cf.get_waiter("distribution_deployed")
        waiter.wait(Id=dist_id)