"""Planner for dxaws-s3.
Computes a deterministic list of actions by comparing desired vs current state.
"""
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from typing import Any
from .constants import POLICY_VERSION
from .exceptions import InvalidDesiredStateError, MissingCloudFrontInfoError
from .models import (
ActionType,
BucketPurpose,
OperationType,
S3BucketCurrent,
S3BucketDesired,
S3BucketOperation,
S3BucketPlan,
)
_CLOUDFRONT_DIST_ARN_RE = re.compile(r"^arn:aws:cloudfront::\d{12}:distribution\/[A-Z0-9]+$")


def _validate_cloudfront_distribution_arns(arns: list[str]) -> None:
    """Validate CloudFront distribution ARNs.

    This lives in the module planner so any caller (CLI, library, other program)
    gets consistent validation.

    Common failure we want to catch: users passing placeholder values like
    `arn:aws:cloudfront::...:distribution/...`.

    Empty entries are skipped rather than rejected.

    Args:
        arns: Candidate distribution ARNs.

    Raises:
        InvalidDesiredStateError: If a non-empty entry contains ``...`` or does
            not match the expected ARN shape over its entire length.
    """
    for arn in arns:
        if not arn:
            continue
        # Checked first so placeholder values get the more helpful message.
        if "..." in arn:
            raise InvalidDesiredStateError(
                "Invalid CloudFront distribution ARN (contains '...'). Provide the full ARN, e.g. "
                "arn:aws:cloudfront::123456789012:distribution/EXAMPLE123"
            )
        # fullmatch(), not match(): with match() the pattern's trailing `$`
        # also accepts a value that ends in a newline, which would then be
        # written verbatim into the bucket policy.
        if not _CLOUDFRONT_DIST_ARN_RE.fullmatch(arn):
            raise InvalidDesiredStateError(
                "Invalid CloudFront distribution ARN format. Expected "
                "arn:aws:cloudfront::123456789012:distribution/EXAMPLE123"
            )
# [docs]  (stray documentation-link marker; appears to be an extraction artifact, not program code)
@dataclass(frozen=True)
class PlanOptions:
    """Immutable switches that tune how the planner treats its input.

    Attributes:
        strict_cloudfront: Off by default. When on, validation can raise
            ``MissingCloudFrontInfoError`` if the desired state carries no
            ``cloudfront.distribution_arn``.
    """

    strict_cloudfront: bool = False
def _validate(desired: S3BucketDesired, *, options: PlanOptions) -> None:
    """Reject desired states the planner should not act on.

    Checks, in order: the format of any CloudFront distribution ARNs, that
    ``SSE-KMS`` encryption comes with a KMS key ARN, and the constraints that
    apply to CloudFront-origin buckets.

    Args:
        desired: The requested bucket configuration.
        options: Planner flags; only ``strict_cloudfront`` is read here.

    Raises:
        InvalidDesiredStateError: A distribution ARN is malformed,
            ``encryption`` is ``SSE-KMS`` without ``kms_key_arn``, or a
            ``cloudfront_origin`` bucket has ``block_public_access`` set to
            ``False``.
        MissingCloudFrontInfoError: ``options.strict_cloudfront`` is set and
            ``desired.cloudfront.distribution_arn`` is missing.
    """
    # Validate CloudFront distribution ARNs if provided.
    arns = desired.cloudfront_distribution_arns
    if arns:
        _validate_cloudfront_distribution_arns(arns)
    if desired.encryption == "SSE-KMS" and not desired.kms_key_arn:
        raise InvalidDesiredStateError(
            "encryption=SSE-KMS requires kms_key_arn"
        )
    if desired.purpose == BucketPurpose.CLOUDFRONT_ORIGIN:
        # Identity comparison: only an explicit False is rejected; any other
        # value (including None) passes this check.
        if desired.block_public_access is False:
            raise InvalidDesiredStateError(
                "purpose=cloudfront_origin requires block_public_access=True"
            )
        # NOTE(review): this check is laid out as part of the
        # cloudfront_origin branch, but the original indentation was not
        # recoverable — confirm it is not meant to apply to every purpose.
        if options.strict_cloudfront:
            cf = desired.cloudfront
            # Either a missing `cloudfront` object or an empty
            # distribution_arn counts as missing information.
            if not cf or not cf.distribution_arn:
                raise MissingCloudFrontInfoError(
                    "strict_cloudfront=True requires cloudfront.distribution_arn"
                )
def _policy_deny_insecure_transport(bucket_arn: str) -> dict[str, Any]:
    """Build the policy statement that refuses S3 requests made without TLS.

    The Deny covers the bucket ARN itself and every object under it, for any
    principal and any S3 action, whenever ``aws:SecureTransport`` is
    ``"false"``.

    Args:
        bucket_arn: ARN of the bucket the statement protects.

    Returns:
        A single policy statement as a plain dict.
    """
    # Both the bucket and its objects must be listed for the Deny to cover
    # bucket-level and object-level actions.
    targets = [bucket_arn, f"{bucket_arn}/*"]
    when_not_tls = {"Bool": {"aws:SecureTransport": "false"}}

    statement: dict[str, Any] = {}
    statement["Sid"] = "DenyInsecureTransport"
    statement["Effect"] = "Deny"
    statement["Principal"] = "*"
    statement["Action"] = "s3:*"
    statement["Resource"] = targets
    statement["Condition"] = when_not_tls
    return statement
def _policy_allow_cloudfront_oac_read(
    *,
    bucket_arn: str,
    distribution_arns: list[str],
) -> dict[str, Any]:
    """Build the statement that lets CloudFront (OAC) read this bucket's objects.

    Access is granted to the ``cloudfront.amazonaws.com`` service principal
    and narrowed to the given distributions through an ``AWS:SourceArn``
    condition. All distributions share one statement; the condition value is
    a list.

    Args:
        bucket_arn: ARN of the bucket being exposed to CloudFront.
        distribution_arns: Distribution ARNs allowed to read. Blank entries
            and duplicates are dropped.

    Returns:
        A single policy statement as a plain dict.
    """
    # A set removes duplicates and blanks are filtered out; sorting then
    # makes the emitted policy independent of input order.
    source_arns = sorted({arn for arn in distribution_arns if arn})
    object_resources = [f"{bucket_arn}/*"]
    source_condition = {"StringEquals": {"AWS:SourceArn": source_arns}}

    return {
        "Sid": "AllowCloudFrontServiceRead",
        "Effect": "Allow",
        "Principal": {"Service": "cloudfront.amazonaws.com"},
        "Action": ["s3:GetObject"],
        "Resource": object_resources,
        "Condition": source_condition,
    }
def _desired_bucket_policy_json(
    desired: S3BucketDesired,
    current: S3BucketCurrent,
) -> str | None:
    """Render the bucket policy implied by the desired state.

    Args:
        desired: Requested configuration; ``name``, ``enforce_tls`` and
            ``cloudfront_distribution_arns`` are read.
        current: Observed state; only ``arn`` is read.

    Returns:
        The policy as compact JSON with sorted keys (so equal policies always
        serialize identically), or ``None`` when no statement is required.
    """
    # Prefer the observed ARN when available, but derive it deterministically
    # from the desired name for first-run planning (bucket may not exist yet).
    bucket_arn = current.arn or f"arn:aws:s3:::{desired.name}"

    statements: list[dict[str, Any]] = []
    if desired.enforce_tls:
        statements.append(_policy_deny_insecure_transport(bucket_arn))

    # If CloudFront distribution ARNs are provided, allow CloudFront OAC read access.
    # Do this regardless of bucket purpose so the CLI can opt-in via flags without
    # requiring a specific purpose value.
    #
    # Blank entries are skipped by validation and by the statement builder, so
    # drop them *before* deciding whether a statement is needed; otherwise a
    # list such as [""] would emit an Allow statement whose AWS:SourceArn
    # condition is an empty list.
    arns = [a for a in (desired.cloudfront_distribution_arns or ()) if a]
    if arns:
        statements.append(
            _policy_allow_cloudfront_oac_read(
                bucket_arn=bucket_arn,
                distribution_arns=arns,
            )
        )

    if not statements:
        return None

    policy = {
        "Version": POLICY_VERSION,
        "Statement": statements,
    }
    return json.dumps(
        policy,
        sort_keys=True,
        separators=(",", ":"),
    )
# [docs]  (stray documentation-link marker; appears to be an extraction artifact, not program code)
def _policy_change_reason(
    current_policy_json: str | None,
    desired_policy_json: str,
) -> str | None:
    """Return why the bucket policy must be rewritten, or None if it already matches."""
    if not current_policy_json:
        return "policy required but missing"
    try:
        # Only the parse of the observed policy is guarded: it is the one
        # step that can legitimately fail on bad input.
        observed = json.loads(current_policy_json)
    except (TypeError, ValueError, RecursionError):
        # JSONDecodeError is a ValueError; TypeError covers a non-string
        # value. Best effort: replace whatever is there.
        return "policy unparsable; overwriting"
    # Compare parsed documents so key order and whitespace do not count as drift.
    if observed != json.loads(desired_policy_json):
        return "policy differs"
    return None


def plan(
    desired: S3BucketDesired,
    current: S3BucketCurrent,
    *,
    options: PlanOptions | None = None,
) -> S3BucketPlan:
    """Compute the operations that move a bucket from ``current`` to ``desired``.

    Operations are emitted in a fixed order: bucket creation, public access
    block, encryption, versioning, tags, bucket policy. Each is included only
    when the corresponding current and desired values differ. A bucket policy
    operation is emitted only when the desired state requires a policy; no
    operation is emitted to remove an existing policy.

    Args:
        desired: Requested bucket configuration.
        current: Observed bucket state.
        options: Planner flags; defaults to ``PlanOptions()``.

    Returns:
        An ``S3BucketPlan`` carrying ``desired``, ``current`` and the ordered
        list of operations (possibly empty).

    Raises:
        InvalidDesiredStateError, MissingCloudFrontInfoError: Propagated from
            validation of ``desired``.
    """
    options = options or PlanOptions()
    _validate(desired, options=options)

    operations: list[S3BucketOperation] = []

    if not current.exists:
        operations.append(
            S3BucketOperation.for_bucket(
                op_type=OperationType.create,
                type=ActionType.CREATE_BUCKET,
                bucket_name=desired.name,
                reason="bucket does not exist",
                payload={
                    "name": desired.name,
                    "region": desired.region,
                },
            )
        )

    if current.block_public_access != desired.block_public_access:
        operations.append(
            S3BucketOperation.for_bucket(
                op_type=OperationType.update,
                type=ActionType.PUT_PUBLIC_ACCESS_BLOCK,
                bucket_name=desired.name,
                reason="public access block differs",
                payload={"enabled": desired.block_public_access},
            )
        )

    if current.encryption != desired.encryption:
        operations.append(
            S3BucketOperation.for_bucket(
                op_type=OperationType.update,
                type=ActionType.PUT_ENCRYPTION,
                bucket_name=desired.name,
                reason="encryption differs",
                payload={
                    "encryption": desired.encryption,
                    "kms_key_arn": desired.kms_key_arn,
                },
            )
        )

    if current.versioning != desired.versioning:
        operations.append(
            S3BucketOperation.for_bucket(
                op_type=OperationType.update,
                type=ActionType.PUT_VERSIONING,
                bucket_name=desired.name,
                reason="versioning differs",
                payload={"enabled": desired.versioning},
            )
        )

    if current.tags != desired.tags:
        operations.append(
            S3BucketOperation.for_bucket(
                op_type=OperationType.update,
                type=ActionType.PUT_TAGS,
                bucket_name=desired.name,
                reason="tags differ",
                payload={"tags": desired.tags},
            )
        )

    desired_policy = _desired_bucket_policy_json(desired, current)
    if desired_policy is not None:
        policy_reason = _policy_change_reason(
            current.bucket_policy_json, desired_policy
        )
        if policy_reason is not None:
            # Built outside any try block so a failure constructing the
            # operation surfaces as-is instead of being reported as an
            # unparsable policy.
            operations.append(
                S3BucketOperation.for_bucket(
                    op_type=OperationType.update,
                    type=ActionType.PUT_BUCKET_POLICY,
                    bucket_name=desired.name,
                    reason=policy_reason,
                    payload={"policy_json": desired_policy},
                )
            )

    return S3BucketPlan(
        desired=desired,
        current=current,
        operations=operations,
    )