# Source code for dxaws_s3.planner

"""Planner for dxaws-s3.

Computes a deterministic list of actions by comparing desired vs current state.
"""

from __future__ import annotations

import json
import re
from dataclasses import dataclass
from typing import Any

from .constants import POLICY_VERSION
from .exceptions import InvalidDesiredStateError, MissingCloudFrontInfoError
from .models import (
    ActionType,
    BucketPurpose,
    OperationType,
    S3BucketCurrent,
    S3BucketDesired,
    S3BucketOperation,
    S3BucketPlan,
)


# Shape of a CloudFront distribution ARN:
#   - the "aws" partition,
#   - an empty region field,
#   - a 12-digit account ID,
#   - an upper-case alphanumeric distribution ID.
# Apply it with ``fullmatch``: with ``match``/``search`` the trailing ``$``
# also accepts a single trailing newline after the ID.
_CLOUDFRONT_DIST_ARN_RE = re.compile(r"^arn:aws:cloudfront::\d{12}:distribution\/[A-Z0-9]+$")


def _validate_cloudfront_distribution_arns(arns: list[str]) -> None:
    """Validate CloudFront distribution ARNs.

    This lives in the planner module so any caller (CLI, library, other program)
    gets consistent validation.

    Common failure we want to catch: users passing placeholder values like
    `arn:aws:cloudfront::...:distribution/...`.

    Empty entries are skipped rather than rejected.

    Raises:
        InvalidDesiredStateError: an entry contains the ``...`` placeholder, or
            does not match ``_CLOUDFRONT_DIST_ARN_RE`` over its whole length.
    """
    for a in arns:
        if not a:
            continue
        if "..." in a:
            raise InvalidDesiredStateError(
                "Invalid CloudFront distribution ARN (contains '...'). Provide the full ARN, e.g. "
                "arn:aws:cloudfront::123456789012:distribution/EXAMPLE123"
            )
        # fullmatch rather than match: the pattern's ``$`` alone would let a
        # value with a trailing newline through.
        if not _CLOUDFRONT_DIST_ARN_RE.fullmatch(a):
            raise InvalidDesiredStateError(
                "Invalid CloudFront distribution ARN format. Expected "
                "arn:aws:cloudfront::123456789012:distribution/EXAMPLE123"
            )


@dataclass(frozen=True)
class PlanOptions:
    """Flags that tune how the planner validates its input.

    Instances are immutable.

    Attributes:
        strict_cloudfront: When True, ``_validate`` additionally demands
            ``desired.cloudfront.distribution_arn`` for buckets whose purpose is
            ``BucketPurpose.CLOUDFRONT_ORIGIN``.
    """

    # Stricter CloudFront-origin validation is opt-in.
    strict_cloudfront: bool = False
def _validate(desired: S3BucketDesired, *, options: PlanOptions) -> None:
    """Reject desired states that the planner cannot turn into a valid plan.

    Raises:
        InvalidDesiredStateError: a CloudFront distribution ARN is malformed;
            SSE-KMS is requested without ``kms_key_arn``; or a CloudFront-origin
            bucket has ``block_public_access`` explicitly set to False.
        MissingCloudFrontInfoError: ``options.strict_cloudfront`` is set for a
            CloudFront-origin bucket whose ``desired.cloudfront`` has no
            ``distribution_arn``.
    """
    # Validate CloudFront distribution ARNs if provided.
    arns = desired.cloudfront_distribution_arns
    if arns:
        _validate_cloudfront_distribution_arns(arns)

    if desired.encryption == "SSE-KMS" and not desired.kms_key_arn:
        raise InvalidDesiredStateError(
            "encryption=SSE-KMS requires kms_key_arn"
        )

    if desired.purpose == BucketPurpose.CLOUDFRONT_ORIGIN:
        # Only an explicit False is rejected; a None value is let through.
        if desired.block_public_access is False:
            raise InvalidDesiredStateError(
                "purpose=cloudfront_origin requires block_public_access=True"
            )
        if options.strict_cloudfront:
            # NOTE(review): strict mode checks ``desired.cloudfront.distribution_arn``,
            # while the bucket policy is built from
            # ``desired.cloudfront_distribution_arns``. Confirm these are
            # intentionally separate inputs.
            cf = desired.cloudfront
            if not cf or not cf.distribution_arn:
                raise MissingCloudFrontInfoError(
                    "strict_cloudfront=True requires cloudfront.distribution_arn"
                )


def _policy_deny_insecure_transport(bucket_arn: str) -> dict[str, Any]:
    """Return a statement denying every S3 action on the bucket and its objects
    for requests where ``aws:SecureTransport`` is false."""
    return {
        "Sid": "DenyInsecureTransport",
        "Effect": "Deny",
        "Principal": "*",
        "Action": "s3:*",
        "Resource": [bucket_arn, f"{bucket_arn}/*"],
        "Condition": {
            "Bool": {
                "aws:SecureTransport": "false"
            }
        },
    }


def _policy_allow_cloudfront_oac_read(
    *,
    bucket_arn: str,
    distribution_arns: list[str],
) -> dict[str, Any]:
    """Allow CloudFront OAC to read objects from this bucket.

    We scope access to specific distribution ARNs via AWS:SourceArn. When
    multiple distributions are provided, we emit a single statement with a
    list-valued StringEquals condition (CloudFront/S3 support this).

    Empty entries are dropped. The remaining ARNs are de-duplicated and sorted,
    so the statement does not depend on input order.
    """
    arns = sorted({a for a in distribution_arns if a})
    return {
        "Sid": "AllowCloudFrontServiceRead",
        "Effect": "Allow",
        "Principal": {"Service": "cloudfront.amazonaws.com"},
        "Action": ["s3:GetObject"],
        "Resource": [f"{bucket_arn}/*"],
        "Condition": {"StringEquals": {"AWS:SourceArn": arns}},
    }


def _desired_bucket_policy_json(
    desired: S3BucketDesired,
    current: S3BucketCurrent,
) -> str | None:
    """Build the planner-managed bucket policy as canonical JSON.

    Returns:
        The policy serialized with sorted keys and compact separators, or None
        when neither TLS enforcement nor any CloudFront distribution ARN calls
        for a statement.
    """
    # Prefer the observed ARN when available, but derive it deterministically
    # from the desired name for first-run planning (bucket may not exist yet).
    bucket_arn = current.arn or f"arn:aws:s3:::{desired.name}"

    statements: list[dict[str, Any]] = []
    if desired.enforce_tls:
        statements.append(_policy_deny_insecure_transport(bucket_arn))

    # If CloudFront distribution ARNs are provided, allow CloudFront OAC read access.
    # Do this regardless of bucket purpose so the CLI can opt-in via flags without
    # requiring a specific purpose value.
    # Empty entries are filtered first: a list such as [""] is truthy, but would
    # otherwise yield an Allow statement with an empty AWS:SourceArn list.
    arns = [a for a in (desired.cloudfront_distribution_arns or []) if a]
    if arns:
        statements.append(
            _policy_allow_cloudfront_oac_read(
                bucket_arn=bucket_arn,
                distribution_arns=arns,
            )
        )

    if not statements:
        return None

    policy = {
        "Version": POLICY_VERSION,
        "Statement": statements,
    }
    return json.dumps(
        policy,
        sort_keys=True,
        separators=(",", ":"),
    )
def plan(
    desired: S3BucketDesired,
    current: S3BucketCurrent,
    *,
    options: PlanOptions | None = None,
) -> S3BucketPlan:
    """Compute the operations needed to move ``current`` toward ``desired``.

    Checks run in a fixed order, so identical inputs always yield the same
    operation list:
    create, public access block, encryption, versioning, tags, bucket policy.

    Args:
        desired: Target bucket configuration.
        current: Observed bucket state.
        options: Planner flags; ``PlanOptions()`` when omitted.

    Returns:
        An ``S3BucketPlan`` holding ``desired``, ``current`` and the operations.

    Raises:
        InvalidDesiredStateError, MissingCloudFrontInfoError: propagated from
            ``_validate`` when ``desired`` is not plannable.
    """
    if options is None:
        options = PlanOptions()
    _validate(desired, options=options)

    operations: list[S3BucketOperation] = []

    if not current.exists:
        operations.append(
            S3BucketOperation.for_bucket(
                op_type=OperationType.create,
                type=ActionType.CREATE_BUCKET,
                bucket_name=desired.name,
                reason="bucket does not exist",
                payload={
                    "name": desired.name,
                    "region": desired.region,
                },
            )
        )

    if current.block_public_access != desired.block_public_access:
        operations.append(
            S3BucketOperation.for_bucket(
                op_type=OperationType.update,
                type=ActionType.PUT_PUBLIC_ACCESS_BLOCK,
                bucket_name=desired.name,
                reason="public access block differs",
                payload={"enabled": desired.block_public_access},
            )
        )

    # NOTE(review): only the encryption mode is compared. A change of
    # kms_key_arn under an unchanged SSE-KMS mode is not detected here.
    if current.encryption != desired.encryption:
        operations.append(
            S3BucketOperation.for_bucket(
                op_type=OperationType.update,
                type=ActionType.PUT_ENCRYPTION,
                bucket_name=desired.name,
                reason="encryption differs",
                payload={
                    "encryption": desired.encryption,
                    "kms_key_arn": desired.kms_key_arn,
                },
            )
        )

    if current.versioning != desired.versioning:
        operations.append(
            S3BucketOperation.for_bucket(
                op_type=OperationType.update,
                type=ActionType.PUT_VERSIONING,
                bucket_name=desired.name,
                reason="versioning differs",
                payload={"enabled": desired.versioning},
            )
        )

    if current.tags != desired.tags:
        operations.append(
            S3BucketOperation.for_bucket(
                op_type=OperationType.update,
                type=ActionType.PUT_TAGS,
                bucket_name=desired.name,
                reason="tags differ",
                payload={"tags": desired.tags},
            )
        )

    policy_op = _bucket_policy_operation(desired, current)
    if policy_op is not None:
        operations.append(policy_op)

    return S3BucketPlan(
        desired=desired,
        current=current,
        operations=operations,
    )


def _bucket_policy_operation(
    desired: S3BucketDesired,
    current: S3BucketCurrent,
) -> S3BucketOperation | None:
    """Return a PUT_BUCKET_POLICY operation when the bucket policy must change.

    Returns None when no policy is desired, which leaves any existing policy
    untouched. Also returns None when the observed policy is semantically equal
    to the desired one.
    """
    desired_policy = _desired_bucket_policy_json(desired, current)
    if desired_policy is None:
        return None

    if not current.bucket_policy_json:
        reason = "policy required but missing"
    else:
        # Only parsing the observed policy may legitimately fail. Any other
        # error should surface instead of being mislabelled "unparsable".
        try:
            cur_obj = json.loads(current.bucket_policy_json)
        except (TypeError, ValueError):
            reason = "policy unparsable; overwriting"
        else:
            # Compare parsed objects so key order and whitespace differences
            # in the observed JSON do not count as drift.
            if cur_obj == json.loads(desired_policy):
                return None
            reason = "policy differs"

    return S3BucketOperation.for_bucket(
        op_type=OperationType.update,
        type=ActionType.PUT_BUCKET_POLICY,
        bucket_name=desired.name,
        reason=reason,
        payload={"policy_json": desired_policy},
    )