# Source code for dxaws_s3.providers.aws

"""AWS provider for dxaws-s3.

This provider is responsible for translating boto3 state into our AWS-agnostic
`S3BucketCurrent` model, and applying changes requested by the executor.

We implement `get_current()` first because it's the most subtle/important part
of convergence.
"""

from __future__ import annotations

from dataclasses import replace
from typing import Any

import boto3
from botocore.exceptions import ClientError

from .base import S3Provider
from ..models import S3BucketCurrent, S3BucketDesired, S3BucketResult


def _err_code(e: ClientError) -> str:
    return str(e.response.get("Error", {}).get("Code", ""))


def _bucket_region_from_location(location: str | None) -> str:
    # S3 GetBucketLocation quirks:
    # - us-east-1 returns None
    # - legacy "EU" means eu-west-1
    if not location:
        return "us-east-1"
    if location == "EU":
        return "eu-west-1"
    return location


def _regional_domain_name(bucket: str, region: str) -> str:
    # Keep this conservative. S3 has several endpoint patterns and partitions.
    # For our purposes (CloudFront origin, etc.) this is sufficient.
    if region == "us-east-1":
        return f"{bucket}.s3.amazonaws.com"
    return f"{bucket}.s3.{region}.amazonaws.com"


class AwsS3Provider(S3Provider):
    """boto3-backed :class:`S3Provider`.

    Reads translate S3 API responses into :class:`S3BucketCurrent`; each
    mutator applies one bucket setting requested by the executor.
    """

    def __init__(self, *, region: str | None = None, aws: Any = None):
        """Create a provider.

        Args:
            region: Region for the underlying boto3 session. ``None`` lets
                boto3 resolve the region from its usual configuration.
            aws: Accepted so existing callers keep working, but currently
                unused. It was previously annotated ``AwsSession``, a name
                this module neither defines nor imports.
        """
        self._region = region
        self._session_obj: boto3.session.Session | None = None
        # Default-region client, created on first use by _s3().
        self._client_obj: Any = None

    def _session(self):
        """Return the boto3 session, creating it on first use."""
        if self._session_obj is None:
            self._session_obj = boto3.session.Session(region_name=self._region)
        return self._session_obj

    def _s3(self, region: str | None = None):
        """Return an S3 client.

        Without *region*, one client for the session's region is created on
        first use and reused afterwards, so each call does not build a fresh
        client. With *region*, a new client pinned to that region is returned.
        """
        # For S3, many calls are global-ish but still benefit from a region.
        # Bucket-specific reads can redirect; boto3 handles that.
        if region is not None:
            return self._session().client("s3", region_name=region)
        if self._client_obj is None:
            self._client_obj = self._session().client("s3")
        return self._client_obj

    def get_current(self, name: str) -> S3BucketCurrent:
        """Read the live state of bucket *name* into an ``S3BucketCurrent``.

        - Missing bucket (HeadBucket 404 / NoSuchBucket / NotFound): returns
          ``exists=False`` with nothing else filled in.
        - HeadBucket access denied (403 / AccessDenied): the bucket may exist
          but cannot be inspected; returns ``exists=True`` plus the ARN only.
        - Otherwise each setting is read in turn. A setting whose GET is
          access-denied is left as-is. A setting that is simply not
          configured is recorded explicitly: ``tags={}``,
          ``block_public_access=None``, ``encryption="NONE"``,
          ``bucket_policy_json=None``.

        Raises:
            ClientError: for any other error code returned by S3.
        """
        s3 = self._s3()
        current = S3BucketCurrent(exists=False, name=name)
        # The ARN depends only on the bucket name.
        bucket_arn = f"arn:aws:s3:::{name}"

        # --- existence check ------------------------------------------------
        try:
            s3.head_bucket(Bucket=name)
        except ClientError as e:
            code = _err_code(e)
            # HeadBucket returns no body, so the code is often just the HTTP
            # status; accept the named variants as well.
            if code in {"404", "NoSuchBucket", "NotFound"}:
                return current
            # Access denied means it may exist but we can't inspect details.
            # Still report the ARN, like the normal "exists" path does.
            if code in {"403", "AccessDenied"}:
                return replace(current, exists=True, arn=bucket_arn)
            raise

        current = replace(current, exists=True, arn=bucket_arn)

        # --- per-setting reads (order matters only for the S3 call order) ---
        readers = (
            self._read_domain,
            self._read_tags,
            self._read_public_access_block,
            self._read_encryption,
            self._read_versioning,
            self._read_policy,
        )
        for read in readers:
            # Each reader returns only the fields it could determine; an empty
            # dict leaves `current` unchanged.
            current = replace(current, **read(s3, name))
        return current

    @staticmethod
    def _read_domain(s3: Any, name: str) -> dict[str, Any]:
        """Derive ``regional_domain_name`` from the bucket's location."""
        try:
            resp = s3.get_bucket_location(Bucket=name)
        except ClientError as e:
            # Non-fatal when we simply lack permission.
            if _err_code(e) == "AccessDenied":
                return {}
            raise
        region = _bucket_region_from_location(resp.get("LocationConstraint"))
        return {"regional_domain_name": _regional_domain_name(name, region)}

    @staticmethod
    def _read_tags(s3: Any, name: str) -> dict[str, Any]:
        """Read the bucket tag set as ``{key: value}``."""
        try:
            resp = s3.get_bucket_tagging(Bucket=name)
        except ClientError as e:
            code = _err_code(e)
            if code in {
                "NoSuchTagSet",
                "NoSuchTagSetError",
                "NoSuchTagSetConfiguration",
                "NoSuchTagSetException",
            }:
                return {"tags": {}}
            if code == "AccessDenied":
                return {}
            raise
        tagset = resp.get("TagSet", [])
        return {"tags": {t["Key"]: t["Value"] for t in tagset if "Key" in t and "Value" in t}}

    @staticmethod
    def _read_public_access_block(s3: Any, name: str) -> dict[str, Any]:
        """Report ``block_public_access`` as True only if all four flags are set."""
        try:
            resp = s3.get_public_access_block(Bucket=name)
        except ClientError as e:
            code = _err_code(e)
            if code in {
                "NoSuchPublicAccessBlockConfiguration",
                "NoSuchPublicAccessBlock",
                "NoSuchConfiguration",
            }:
                return {"block_public_access": None}
            if code == "AccessDenied":
                return {}
            raise
        cfg = resp.get("PublicAccessBlockConfiguration") or {}
        enabled = bool(
            cfg.get("BlockPublicAcls", False)
            and cfg.get("IgnorePublicAcls", False)
            and cfg.get("BlockPublicPolicy", False)
            and cfg.get("RestrictPublicBuckets", False)
        )
        return {"block_public_access": enabled}

    @staticmethod
    def _read_encryption(s3: Any, name: str) -> dict[str, Any]:
        """Read ``encryption``/``kms_key_arn`` from the first default SSE rule."""
        try:
            resp = s3.get_bucket_encryption(Bucket=name)
        except ClientError as e:
            code = _err_code(e)
            if code in {
                "ServerSideEncryptionConfigurationNotFoundError",
                "NoSuchServerSideEncryptionConfiguration",
            }:
                return {"encryption": "NONE", "kms_key_arn": None}
            if code == "AccessDenied":
                return {}
            raise
        rules = (resp.get("ServerSideEncryptionConfiguration", {}) or {}).get("Rules", [])
        mode: str | None = None
        kms_arn: str | None = None
        if rules:
            sse = rules[0].get("ApplyServerSideEncryptionByDefault") or {}
            algo = sse.get("SSEAlgorithm")
            if algo == "AES256":
                mode = "SSE-S3"
            elif algo == "aws:kms":
                mode = "SSE-KMS"
                kms_arn = sse.get("KMSMasterKeyID")
            # Any other algorithm (e.g. "aws:kms:dsse") is reported as None.
        return {"encryption": mode, "kms_key_arn": kms_arn}

    @staticmethod
    def _read_versioning(s3: Any, name: str) -> dict[str, Any]:
        """Report ``versioning`` True only when the status is "Enabled"."""
        try:
            resp = s3.get_bucket_versioning(Bucket=name)
        except ClientError as e:
            if _err_code(e) == "AccessDenied":
                return {}
            raise
        # A never-versioned bucket has no Status; "Suspended" is also False.
        return {"versioning": resp.get("Status") == "Enabled"}

    @staticmethod
    def _read_policy(s3: Any, name: str) -> dict[str, Any]:
        """Read the raw bucket policy JSON string, or None when there is none."""
        try:
            resp = s3.get_bucket_policy(Bucket=name)
        except ClientError as e:
            code = _err_code(e)
            if code in {"NoSuchBucketPolicy", "NoSuchPolicy"}:
                return {"bucket_policy_json": None}
            if code == "AccessDenied":
                return {}
            raise
        return {"bucket_policy_json": resp.get("Policy")}

    # --- create / mutate ---------------------------------------------------

    def create_bucket(self, desired: S3BucketDesired) -> None:
        """Create bucket ``desired.name``.

        The target region is resolved in this order: ``desired.region``, the
        provider's region, the session's region, then ``"us-east-1"``.
        """
        region = desired.region or self._region or self._session().region_name or "us-east-1"
        # Send the request to the target region's endpoint. A regional S3
        # endpoint rejects a LocationConstraint that does not match its own
        # region (and us-east-1 requests carry no constraint), so the
        # session-region client cannot create buckets elsewhere.
        s3 = self._s3(region)
        kwargs: dict[str, Any] = {"Bucket": desired.name}
        # CreateBucketConfiguration is not used for us-east-1.
        if region != "us-east-1":
            kwargs["CreateBucketConfiguration"] = {"LocationConstraint": region}
        s3.create_bucket(**kwargs)

    def put_tags(self, name: str, tags: dict[str, str]) -> None:
        """Replace the bucket's tag set with *tags*. An empty dict clears all tags."""
        s3 = self._s3()
        if not tags:
            # Removing every tag is what DeleteBucketTagging is for. Sending
            # an empty TagSet to PutBucketTagging risks a MalformedXML error.
            s3.delete_bucket_tagging(Bucket=name)
            return
        # Sorted for a deterministic request body.
        tagset = [{"Key": k, "Value": v} for k, v in sorted(tags.items())]
        s3.put_bucket_tagging(Bucket=name, Tagging={"TagSet": tagset})

    def put_public_access_block(self, name: str, enabled: bool) -> None:
        """Set all four public-access-block flags to *enabled*."""
        s3 = self._s3()
        cfg = {
            "BlockPublicAcls": enabled,
            "IgnorePublicAcls": enabled,
            "BlockPublicPolicy": enabled,
            "RestrictPublicBuckets": enabled,
        }
        s3.put_public_access_block(Bucket=name, PublicAccessBlockConfiguration=cfg)

    def put_encryption(self, name: str, encryption: str, kms_key_arn: str | None) -> None:
        """Apply default encryption mode "NONE", "SSE-S3" or "SSE-KMS".

        "NONE" deletes the bucket's encryption configuration.

        Raises:
            ValueError: "SSE-KMS" without *kms_key_arn*, or an unknown mode.
        """
        s3 = self._s3()
        if encryption == "NONE":
            s3.delete_bucket_encryption(Bucket=name)
            return
        if encryption == "SSE-S3":
            rule = {"ApplyServerSideEncryptionByDefault": {"SSEAlgorithm": "AES256"}}
        elif encryption == "SSE-KMS":
            if not kms_key_arn:
                raise ValueError("SSE-KMS requires kms_key_arn")
            rule = {
                "ApplyServerSideEncryptionByDefault": {
                    "SSEAlgorithm": "aws:kms",
                    "KMSMasterKeyID": kms_key_arn,
                }
            }
        else:
            raise ValueError(f"Unknown encryption mode: {encryption}")
        s3.put_bucket_encryption(
            Bucket=name,
            ServerSideEncryptionConfiguration={"Rules": [rule]},
        )

    def put_versioning(self, name: str, enabled: bool) -> None:
        """Set versioning to "Enabled", or "Suspended" when *enabled* is False."""
        s3 = self._s3()
        s3.put_bucket_versioning(
            Bucket=name,
            VersioningConfiguration={"Status": "Enabled" if enabled else "Suspended"},
        )

    def put_bucket_policy(self, name: str, policy_json: str) -> None:
        """Set the bucket policy to the JSON document *policy_json*."""
        s3 = self._s3()
        s3.put_bucket_policy(Bucket=name, Policy=policy_json)

    # --- destroy / cleanup -------------------------------------------------

    @staticmethod
    def _delete_keys(s3: Any, name: str, keys: list[dict[str, str]]) -> None:
        """Delete *keys* with DeleteObjects, at most 1000 per request (the API limit)."""
        for start in range(0, len(keys), 1000):
            s3.delete_objects(
                Bucket=name,
                Delete={"Objects": keys[start : start + 1000], "Quiet": True},
            )

    def delete_all_objects(self, name: str) -> None:
        """Delete all objects (and versions, if present) from the bucket.

        Best effort: a missing bucket, or access denied while listing, ends
        that pass silently.
        NOTE(review): per-key failures reported in DeleteObjects' ``Errors``
        are not checked, so a later delete_bucket may still fail with
        BucketNotEmpty.
        """
        s3 = self._s3()

        # Pass 1: every object version and delete marker (covers buckets
        # whose versioning is enabled or suspended).
        try:
            for page in s3.get_paginator("list_object_versions").paginate(Bucket=name):
                keys = [
                    {"Key": v["Key"], "VersionId": v["VersionId"]}
                    for v in page.get("Versions") or []
                ]
                keys += [
                    {"Key": m["Key"], "VersionId": m["VersionId"]}
                    for m in page.get("DeleteMarkers") or []
                ]
                self._delete_keys(s3, name, keys)
        except ClientError as e:
            if _err_code(e) not in {"NoSuchBucket", "AccessDenied"}:
                raise

        # Pass 2: current objects (covers non-versioned buckets).
        try:
            for page in s3.get_paginator("list_objects_v2").paginate(Bucket=name):
                keys = [{"Key": o["Key"]} for o in page.get("Contents") or []]
                self._delete_keys(s3, name, keys)
        except ClientError as e:
            if _err_code(e) not in {"NoSuchBucket", "AccessDenied"}:
                raise

    def delete_bucket(self, name: str) -> None:
        """Delete the bucket (must be empty first). A missing bucket is ignored."""
        s3 = self._s3()
        try:
            s3.delete_bucket(Bucket=name)
        except ClientError as e:
            if _err_code(e) != "NoSuchBucket":
                raise

    def destroy_bucket(self, name: str) -> None:
        """Best-effort bucket destroy: empty then delete."""
        self.delete_all_objects(name)
        self.delete_bucket(name)

    # --- outputs -----------------------------------------------------------

    def get_outputs(self, name: str) -> S3BucketResult:
        """Return name/ARN/regional domain for an existing bucket.

        Falls back to the name-derived ARN and the us-east-1 style domain when
        get_current could not determine them.

        Raises:
            ValueError: if the bucket does not exist.
        """
        cur = self.get_current(name)
        if not cur.exists:
            raise ValueError(f"Bucket does not exist: {name}")
        arn = cur.arn or f"arn:aws:s3:::{name}"
        regional_domain_name = cur.regional_domain_name or f"{name}.s3.amazonaws.com"
        return S3BucketResult(name=name, arn=arn, regional_domain_name=regional_domain_name)