"""AWS provider implementation for dxaws-cloudfront.

This module is intentionally focused on AWS API access + low-level helpers.
Higher-level convergence logic lives in planner/executor.

We start by wiring an AWS CloudFront client and a few pragmatic helper methods
that we will use for:

- discovering distributions by tag/name
- creating/updating distributions
- creating/updating Origin Access Control (OAC)
- waiting for distribution deployment

The exact desired/current models will be layered on top in later steps.
"""

from __future__ import annotations

import time
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any

from botocore.exceptions import ClientError

from dxaws_core.providers.aws import AwsProviderBase

def _err_code(e: ClientError) -> str:
return str(e.response.get("Error", {}).get("Code", ""))
@dataclass(frozen=True)
class DistributionRef:
    """Minimal, immutable reference to a CloudFront distribution.

    Field values are copied from CloudFront API responses; optional fields
    are ``None`` when the response omitted them.
    """

    id: str  # distribution id (``Id``)
    arn: str | None  # distribution ARN (``ARN``)
    domain_name: str | None  # CloudFront-assigned ``DomainName``
    status: str | None  # ``Status`` as reported by CloudFront (e.g. "Deployed")
    comment: str | None  # distribution ``Comment``
@dataclass(frozen=True)
class OacRef:
    """Minimal, immutable reference to a CloudFront Origin Access Control."""

    id: str  # OAC id (``Id``)
    name: str  # OAC ``Name``
class AwsProvider(AwsProviderBase):
    """AWS implementation for dxaws-cloudfront.

    Wraps the boto3 CloudFront client with discovery, Origin Access Control,
    distribution create/update/disable/delete and deployment-wait helpers.
    Higher-level convergence logic lives in the planner/executor.
    """

    def __init__(self, *, aws) -> None:
        # provider_ns is used for tagging, logging, and idempotency token namespaces.
        super().__init__(aws=aws, provider_ns="cloudfront")

    # ---------------------------------------------------------------------
    # Clients
    # ---------------------------------------------------------------------
    @property
    def cf(self):
        """CloudFront boto3 client, obtained via ``self.aws.client("cloudfront")`` on each access."""
        return self.aws.client("cloudfront")

    # ---------------------------------------------------------------------
    # Internal response/shape helpers
    # ---------------------------------------------------------------------
    @staticmethod
    def _normalize_alias(alias: Any) -> str:
        """Canonicalize a hostname for comparison: trimmed, lower-cased, exactly one trailing dot."""
        return str(alias).strip().lower().rstrip(".") + "."

    @staticmethod
    def _tag_items(tags: dict[str, str]) -> list[dict[str, str]]:
        """Convert ``{key: value}`` tags into CloudFront's ``[{"Key", "Value"}]`` list, sorted by key."""
        return [{"Key": k, "Value": v} for k, v in sorted(tags.items())]

    @staticmethod
    def _ref_from_summary(summary: dict[str, Any]) -> DistributionRef:
        """Build a DistributionRef from a ListDistributions item (``Comment`` is top-level there)."""
        return DistributionRef(
            id=summary.get("Id"),
            arn=summary.get("ARN"),
            domain_name=summary.get("DomainName"),
            status=summary.get("Status"),
            comment=summary.get("Comment"),
        )

    @staticmethod
    def _ref_from_distribution(dist: dict[str, Any]) -> DistributionRef:
        """Build a DistributionRef from a full ``Distribution`` object (``Comment`` lives in its config)."""
        dist_cfg = dist.get("DistributionConfig") or {}
        return DistributionRef(
            id=dist.get("Id"),
            arn=dist.get("ARN"),
            domain_name=dist.get("DomainName"),
            status=dist.get("Status"),
            comment=dist_cfg.get("Comment"),
        )

    # ---------------------------------------------------------------------
    # Discovery helpers (read-only)
    # ---------------------------------------------------------------------
    def _iter_distribution_summaries(self) -> Iterator[dict[str, Any]]:
        """Yield raw distribution summary dicts from every ListDistributions page."""
        paginator = self.cf.get_paginator("list_distributions")
        for page in paginator.paginate():
            yield from (page.get("DistributionList") or {}).get("Items") or []

    def list_distributions(self) -> list[DistributionRef]:
        """List distributions (best-effort summary).

        CloudFront is a global service; this uses the global endpoint.
        Summary items without an ``Id`` are dropped.
        """
        return [
            self._ref_from_summary(summary)
            for summary in self._iter_distribution_summaries()
            if summary.get("Id")
        ]

    def get_distribution(self, dist_id: str) -> tuple[dict[str, Any], str]:
        """Fetch a distribution's config and ETag via GetDistributionConfig.

        Despite the method name, this returns the ``DistributionConfig`` (not
        the full ``Distribution`` object) together with the ETag needed as
        ``IfMatch`` for update/delete calls.

        Returns:
            ``(DistributionConfig, ETag)``.

        Raises:
            RuntimeError: if the response lacks an ETag or a config.
        """
        resp = self.cf.get_distribution_config(Id=dist_id)
        etag = resp.get("ETag")
        cfg = resp.get("DistributionConfig")
        if not etag or not cfg:
            raise RuntimeError("CloudFront get_distribution_config returned no ETag/config")
        return cfg, etag

    def get_distribution_status(self, dist_id: str) -> str:
        """Return the distribution's ``Status`` string, or ``""`` if the response has none."""
        resp = self.cf.get_distribution(Id=dist_id)
        dist = resp.get("Distribution") or {}
        return str(dist.get("Status") or "")

    def find_distribution_by_alias(self, alias: str) -> DistributionRef | None:
        """Return the first distribution whose aliases include ``alias``, else ``None``.

        Matching ignores surrounding whitespace, case and a trailing dot.
        Aliases are read directly from the ListDistributions summaries (which
        include an ``Aliases`` list) instead of issuing one
        GetDistributionConfig call per distribution, and the scan stops at the
        first match.
        """
        wanted = self._normalize_alias(alias)
        for summary in self._iter_distribution_summaries():
            if not summary.get("Id"):
                continue
            items = (summary.get("Aliases") or {}).get("Items") or []
            if any(self._normalize_alias(a) == wanted for a in items):
                return self._ref_from_summary(summary)
        return None

    # ---------------------------------------------------------------------
    # Tag helpers
    # ---------------------------------------------------------------------
    def tag_resource(self, resource_arn: str, tags: dict[str, str]) -> None:
        """Apply ``tags`` to the CloudFront resource identified by ``resource_arn``."""
        self.cf.tag_resource(Resource=resource_arn, Tags={"Items": self._tag_items(tags)})

    # ---------------------------------------------------------------------
    # Origin Access Control (OAC)
    # ---------------------------------------------------------------------
    def list_oacs(self) -> list[OacRef]:
        """List Origin Access Controls that have both an Id and a non-blank Name.

        Names are whitespace-stripped.
        """
        out: list[OacRef] = []
        paginator = self.cf.get_paginator("list_origin_access_controls")
        for page in paginator.paginate():
            items = (page.get("OriginAccessControlList") or {}).get("Items") or []
            for it in items:
                # AWS may return either a summary shape:
                #   {"Id": "...", "Name": "...", ...}
                # or a nested shape (older/alternate):
                #   {"OriginAccessControl": {"Id": "...", "Name": "...", ...}}
                if isinstance(it, dict) and "Id" in it and "Name" in it:
                    src = it
                elif isinstance(it, dict):
                    src = it.get("OriginAccessControl") or {}
                else:
                    src = {}
                oac_id = src.get("Id")
                name = (src.get("Name") or "").strip()
                if oac_id and name:
                    out.append(OacRef(id=oac_id, name=name))
        return out

    def find_oac_by_name(self, name: str) -> OacRef | None:
        """Return the first OAC whose (stripped) name equals ``name`` exactly, else ``None``."""
        return next((o for o in self.list_oacs() if o.name == name), None)

    def ensure_oac(
        self,
        *,
        name: str,
        description: str,
        signing_behavior: str = "always",
        signing_protocol: str = "sigv4",
        origin_type: str = "s3",
    ) -> OacRef:
        """Ensure an OAC named ``name`` exists and return a reference to it.

        For MVP we create it if missing. An existing OAC is returned as-is;
        its settings are not compared or updated. If CloudFront reports
        ``OriginAccessControlAlreadyExists`` (a create race), we re-query by
        name and return that OAC.

        Raises:
            botocore.exceptions.ClientError: any other create failure, or an
                already-exists error when the re-query still finds nothing.
            RuntimeError: if the create response carries no Id.
        """
        existing = self.find_oac_by_name(name)
        if existing:
            return existing
        try:
            resp = self.cf.create_origin_access_control(
                OriginAccessControlConfig={
                    "Name": name,
                    "Description": description,
                    "SigningProtocol": signing_protocol,
                    "SigningBehavior": signing_behavior,
                    "OriginAccessControlOriginType": origin_type,
                }
            )
        except ClientError as e:
            if _err_code(e) == "OriginAccessControlAlreadyExists":
                # Race/duplicate: someone created it between our lookup and create.
                existing = self.find_oac_by_name(name)
                if existing:
                    return existing
            raise
        oac = resp.get("OriginAccessControl") or {}
        oac_id = oac.get("Id")
        if not oac_id:
            raise RuntimeError("CloudFront create_origin_access_control returned no Id")
        oac_cfg = oac.get("OriginAccessControlConfig") or {}
        return OacRef(id=oac_id, name=oac_cfg.get("Name") or name)

    # ---------------------------------------------------------------------
    # Distribution create/update (low-level)
    # ---------------------------------------------------------------------
    @staticmethod
    def _normalize_quantity_list(value: Any) -> dict[str, Any]:
        """Return a copy of a CloudFront ``{"Quantity", "Items"}`` wrapper with consistent fields.

        - A non-dict ``value`` is treated as an empty wrapper.
        - When ``Items`` is a list/tuple, ``Quantity`` is derived from its
          length, so caller-supplied items are never silently dropped.
        - Otherwise ``Quantity`` is coerced to int (missing -> 0) and a
          ``None`` ``Items`` entry is removed.
        - ``Items`` is only included when ``Quantity > 0`` (CloudFront is strict).
        """
        out = dict(value) if isinstance(value, dict) else {}
        items = out.get("Items")
        if isinstance(items, (list, tuple)):
            out["Items"] = list(items)
            out["Quantity"] = len(items)
        else:
            out["Quantity"] = int(out.get("Quantity") or 0)
            if items is None:
                out.pop("Items", None)
        if out["Quantity"] == 0:
            out.pop("Items", None)
        return out

    @classmethod
    def _normalize_forwarded_values(cls, value: Any) -> dict[str, Any]:
        """Complete a legacy ``ForwardedValues`` structure with its required nested shapes.

        CloudFront requires ``ForwardedValues.Headers`` (and the other nested
        fields filled in here) when ``ForwardedValues`` is used.
        """
        fv = dict(value) if isinstance(value, dict) else {}
        fv.setdefault("QueryString", False)
        cookies = fv.get("Cookies")
        cookies = dict(cookies) if isinstance(cookies, dict) else {}
        cookies.setdefault("Forward", "none")
        if "WhitelistedNames" in cookies:
            cookies["WhitelistedNames"] = cls._normalize_quantity_list(cookies["WhitelistedNames"])
        fv["Cookies"] = cookies
        fv["Headers"] = cls._normalize_quantity_list(fv.get("Headers"))
        fv["QueryStringCacheKeys"] = cls._normalize_quantity_list(fv.get("QueryStringCacheKeys"))
        return fv

    @classmethod
    def _normalize_behavior(cls, behavior: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of a (default) cache behavior with CloudFront-mandatory fields filled in."""
        beh = dict(behavior)
        beh.setdefault("SmoothStreaming", False)
        beh.setdefault("Compress", False)
        beh.setdefault("FieldLevelEncryptionId", "")
        # With a modern cache policy attached, legacy ForwardedValues is not required.
        if not beh.get("CachePolicyId"):
            beh["ForwardedValues"] = cls._normalize_forwarded_values(beh.get("ForwardedValues"))
        beh["LambdaFunctionAssociations"] = cls._normalize_quantity_list(
            beh.get("LambdaFunctionAssociations")
        )
        beh["FunctionAssociations"] = cls._normalize_quantity_list(beh.get("FunctionAssociations"))
        return beh

    def _ensure_required_flags(self, config: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of ``config`` with fields the CloudFront API insists on filled in.

        Some fields are mandatory even when unused. For example, a cache
        behavior requires SmoothStreaming to be explicitly set. We defensively
        fill these in so planners/executors can stay focused on convergence
        logic.

        - ``CacheBehaviors`` is always present with a consistent Quantity/Items.
        - ``DefaultCacheBehavior`` (if a dict) and every dict in
          ``CacheBehaviors.Items`` get SmoothStreaming/Compress defaulted to
          False, FieldLevelEncryptionId to ``""``, completed legacy
          ForwardedValues when no CachePolicyId is set, and consistent
          LambdaFunctionAssociations/FunctionAssociations.
        - Every Quantity is derived from its Items list when one is given.

        The input dict is not mutated.
        """
        cfg = dict(config)

        # CloudFront requires CacheBehaviors to be present even when empty.
        cbs = self._normalize_quantity_list(cfg.get("CacheBehaviors"))
        items = cbs.get("Items")
        if isinstance(items, list):
            cbs["Items"] = [
                self._normalize_behavior(it) if isinstance(it, dict) else it for it in items
            ]
        cfg["CacheBehaviors"] = cbs

        dcb = cfg.get("DefaultCacheBehavior")
        if isinstance(dcb, dict):
            cfg["DefaultCacheBehavior"] = self._normalize_behavior(dcb)
        return cfg

    def create_distribution(self, *, config: dict[str, Any], tags: dict[str, str] | None = None) -> DistributionRef:
        """Create a distribution from a raw DistributionConfig, optionally tagged at creation."""
        config = self._ensure_required_flags(config)
        # CloudFront has two different create APIs:
        #   - create_distribution(DistributionConfig=...)
        #   - create_distribution_with_tags(DistributionConfigWithTags=...)
        # botocore validates input params strictly, so we must call the correct one.
        if tags:
            resp = self.cf.create_distribution_with_tags(
                DistributionConfigWithTags={
                    "DistributionConfig": config,
                    "Tags": {"Items": self._tag_items(tags)},
                }
            )
        else:
            resp = self.cf.create_distribution(DistributionConfig=config)
        return self._ref_from_distribution(resp.get("Distribution") or {})

    def update_distribution(self, *, dist_id: str, if_match: str, config: dict[str, Any]) -> DistributionRef:
        """Replace a distribution's config; ``if_match`` must be the current ETag."""
        config = self._ensure_required_flags(config)
        resp = self.cf.update_distribution(Id=dist_id, IfMatch=if_match, DistributionConfig=config)
        return self._ref_from_distribution(resp.get("Distribution") or {})

    def disable_distribution(
        self,
        dist_id: str,
        *,
        timeout: float = 1200.0,
        poll_interval: float = 10.0,
    ) -> None:
        """Disable a distribution and wait until that state is deployed (required before delete).

        CloudFront is eventually consistent: a distribution can report
        Status=Deployed from a previous state even after we submit an update.
        We therefore poll until a single GetDistribution response shows both
        ``Enabled=False`` and ``Status=Deployed``.

        If the distribution is already disabled, no update is submitted, but we
        still wait for Deployed so that a follow-up delete does not race a
        disable that is still propagating.

        Args:
            dist_id: CloudFront distribution id.
            timeout: Maximum seconds to wait (default 20 minutes).
            poll_interval: Seconds to sleep between polls.

        Raises:
            TimeoutError: if the disabled state is not deployed within ``timeout``.
            botocore.exceptions.ClientError: ``NoSuchDistribution`` / ``AccessDenied``
                while polling, or any error from the update call.
        """
        cfg, etag = self.get_distribution(dist_id)
        if cfg.get("Enabled") is not False:
            cfg2 = dict(cfg)
            cfg2["Enabled"] = False
            cfg2 = self._ensure_required_flags(cfg2)
            # Submit the disable update.
            self.cf.update_distribution(Id=dist_id, IfMatch=etag, DistributionConfig=cfg2)

        deadline = time.monotonic() + timeout
        while True:
            try:
                # One call yields Status and Enabled from the same snapshot.
                dist = self.cf.get_distribution(Id=dist_id).get("Distribution") or {}
            except ClientError as e:
                # These cannot heal by waiting; anything else is retried until the deadline.
                if _err_code(e) in ("NoSuchDistribution", "AccessDenied"):
                    raise
                dist = {}
            status = str(dist.get("Status") or "")
            enabled = (dist.get("DistributionConfig") or {}).get("Enabled")
            if enabled is False and status == "Deployed":
                return
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Timed out waiting for distribution to disable (status={status}, enabled={enabled})"
                )
            time.sleep(poll_interval)

    def delete_distribution(self, dist_id: str) -> None:
        """Delete a distribution using its current ETag.

        CloudFront requires the distribution to be disabled and deployed before
        delete; call ``disable_distribution`` first.
        """
        _, etag = self.get_distribution(dist_id)
        self.cf.delete_distribution(Id=dist_id, IfMatch=etag)

    # ---------------------------------------------------------------------
    # Waiters
    # ---------------------------------------------------------------------
    def wait_deployed(self, dist_id: str, *, waiter_config: dict[str, int] | None = None) -> None:
        """Block until the distribution reaches Deployed using boto3's ``distribution_deployed`` waiter.

        Args:
            dist_id: CloudFront distribution id.
            waiter_config: Optional botocore ``WaiterConfig`` (``Delay``,
                ``MaxAttempts``); the waiter's defaults are used when omitted.
        """
        waiter = self.cf.get_waiter("distribution_deployed")
        if waiter_config:
            waiter.wait(Id=dist_id, WaiterConfig=waiter_config)
        else:
            waiter.wait(Id=dist_id)