"""AWS provider for dxaws-s3.
This provider is responsible for translating boto3 state into our AWS-agnostic
`S3BucketCurrent` model, and applying changes requested by the executor.
We implement `get_current()` first because it's the most subtle/important part
of convergence.
"""
from __future__ import annotations
from dataclasses import replace
from typing import Any
import boto3
from botocore.exceptions import ClientError
from .base import S3Provider
from ..models import S3BucketCurrent, S3BucketDesired, S3BucketResult
def _err_code(e: ClientError) -> str:
    """Return the AWS error code carried by a botocore ``ClientError``.

    The code is read from ``e.response["Error"]["Code"]``. An empty string is
    returned when either level is missing. The value is always converted to
    ``str`` so callers can compare uniformly against named codes
    (``"NoSuchBucket"``) and bare status strings (``"404"``).
    """
    error_info = e.response.get("Error", {})
    raw_code = error_info.get("Code", "")
    return str(raw_code)
def _bucket_region_from_location(location: str | None) -> str:
    """Normalise a GetBucketLocation ``LocationConstraint`` into a region name.

    GetBucketLocation has two legacy quirks:

    * buckets in us-east-1 report no constraint (``None``);
    * the legacy value ``"EU"`` stands for eu-west-1.

    Any falsy value maps to ``"us-east-1"``. ``"EU"`` maps to
    ``"eu-west-1"``. Every other value is returned unchanged.
    """
    if location:
        return "eu-west-1" if location == "EU" else location
    return "us-east-1"
def _regional_domain_name(bucket: str, region: str) -> str:
    """Build the S3 hostname used to address *bucket* in *region*.

    This is deliberately conservative. It always uses the ``amazonaws.com``
    suffix, so partitions with a different DNS suffix (e.g. China regions)
    and alternative endpoint styles are not handled. That is sufficient for
    uses such as a CloudFront origin. us-east-1 uses the legacy hostname
    without a region label.
    """
    host = "s3.amazonaws.com" if region == "us-east-1" else f"s3.{region}.amazonaws.com"
    return f"{bucket}.{host}"
class AwsS3Provider(S3Provider):
    """S3 provider backed by boto3.

    Reads live bucket state into ``S3BucketCurrent`` (see ``get_current``) and
    applies the individual mutations requested by the executor: create,
    public-access block, encryption, versioning, bucket policy, and destroy.
    """

    # S3 DeleteObjects accepts at most 1000 keys per request.
    _DELETE_BATCH_SIZE = 1000

    def __init__(self, *, region: str | None = None, aws: AwsSession | None = None):
        """Create a provider.

        Args:
            region: Region for the boto3 session. ``None`` lets boto3 resolve
                it from its usual configuration.
            aws: Accepted for interface compatibility (see note below).
        """
        # NOTE(review): ``aws`` is accepted but never used. The provider always
        # builds its own boto3 session from ``region``. Confirm whether callers
        # expect an injected session to be honoured.
        self._region = region
        self._session_obj: boto3.session.Session | None = None
        # S3 clients cached per explicit region. Key None = session default.
        self._clients: dict[str | None, Any] = {}

    def _session(self):
        """Return the lazily created boto3 session (one per provider instance)."""
        if self._session_obj is None:
            self._session_obj = boto3.session.Session(region_name=self._region)
        return self._session_obj

    def _s3(self, region: str | None = None):
        """Return an S3 client, creating it at most once per region.

        Args:
            region: Region to pin the client to. ``None`` (the default) uses
                the session's default region.
        """
        # Building a boto3 client is comparatively expensive, so reuse it
        # instead of creating one per call.
        # For S3, many calls are global-ish but still benefit from a region.
        # Bucket-specific reads for other regions are redirected by botocore.
        client = self._clients.get(region)
        if client is None:
            session = self._session()
            if region is None:
                client = session.client("s3")
            else:
                client = session.client("s3", region_name=region)
            self._clients[region] = client
        return client

    # --- read ---------------------------------------------------------------

    def get_current(self, name: str) -> S3BucketCurrent:
        """Read the live state of bucket *name*.

        Each attribute is read with its own API call, and the outcome maps to
        the model as follows:

        * success: the field is populated from the response;
        * a "not configured" error code: the field gets its empty value
          (``{}`` tags, ``None`` policy / public-access block, ``"NONE"``
          encryption);
        * ``AccessDenied``: the field is left at the model's default, because
          the value is unknown rather than absent;
        * anything else: the ``ClientError`` is re-raised.

        If the bucket is missing, only ``exists=False`` and ``name`` are set.
        If the existence check itself is denied (403), the bucket is reported
        as existing with only ``name`` and ``arn`` populated.
        """
        s3 = self._s3()
        # The ARN is derivable from the name alone.
        bucket_arn = f"arn:aws:s3:::{name}"
        # Start with the minimal shape.
        current = S3BucketCurrent(exists=False, name=name)

        # --- existence check --------------------------------------------
        try:
            s3.head_bucket(Bucket=name)
        except ClientError as e:
            code = _err_code(e)
            # HEAD responses carry no body, so codes may be bare HTTP
            # statuses. Accept both spellings.
            if code in {"404", "NoSuchBucket", "NotFound"}:
                return current
            # Access denied means it may exist but we can't inspect details.
            if code in {"403", "AccessDenied"}:
                return replace(current, exists=True, arn=bucket_arn)
            raise
        current = replace(current, exists=True, arn=bucket_arn)

        # --- region / domain --------------------------------------------
        try:
            resp = s3.get_bucket_location(Bucket=name)
        except ClientError as e:
            # Non-fatal: without the region, the domain name stays unset.
            if _err_code(e) != "AccessDenied":
                raise
        else:
            bucket_region = _bucket_region_from_location(resp.get("LocationConstraint"))
            current = replace(
                current,
                regional_domain_name=_regional_domain_name(name, bucket_region),
            )

        # --- tags -------------------------------------------------------
        try:
            resp = s3.get_bucket_tagging(Bucket=name)
        except ClientError as e:
            code = _err_code(e)
            # Several spellings are accepted defensively.
            if code in {
                "NoSuchTagSet",
                "NoSuchTagSetError",
                "NoSuchTagSetConfiguration",
                "NoSuchTagSetException",
            }:
                current = replace(current, tags={})
            elif code != "AccessDenied":
                raise
        else:
            tags = {
                t["Key"]: t["Value"]
                for t in resp.get("TagSet", [])
                if "Key" in t and "Value" in t
            }
            current = replace(current, tags=tags)

        # --- public access block ----------------------------------------
        try:
            resp = s3.get_public_access_block(Bucket=name)
        except ClientError as e:
            code = _err_code(e)
            if code in {
                "NoSuchPublicAccessBlockConfiguration",
                "NoSuchPublicAccessBlock",
                "NoSuchConfiguration",
            }:
                current = replace(current, block_public_access=None)
            elif code != "AccessDenied":
                raise
        else:
            cfg = resp.get("PublicAccessBlockConfiguration") or {}
            # "Enabled" only when all four flags are on. This mirrors what
            # put_public_access_block(enabled=True) writes; a partial
            # configuration reads as False.
            enabled = all(
                cfg.get(flag, False)
                for flag in (
                    "BlockPublicAcls",
                    "IgnorePublicAcls",
                    "BlockPublicPolicy",
                    "RestrictPublicBuckets",
                )
            )
            current = replace(current, block_public_access=enabled)

        # --- encryption -------------------------------------------------
        try:
            resp = s3.get_bucket_encryption(Bucket=name)
        except ClientError as e:
            code = _err_code(e)
            if code in {
                "ServerSideEncryptionConfigurationNotFoundError",
                "NoSuchServerSideEncryptionConfiguration",
            }:
                current = replace(current, encryption="NONE", kms_key_arn=None)
            elif code != "AccessDenied":
                raise
        else:
            rules = (resp.get("ServerSideEncryptionConfiguration", {}) or {}).get("Rules", [])
            mode: str | None = None
            kms_arn: str | None = None
            if rules:
                # Only the first rule is inspected.
                sse = rules[0].get("ApplyServerSideEncryptionByDefault") or {}
                algo = sse.get("SSEAlgorithm")
                if algo == "AES256":
                    mode = "SSE-S3"
                elif algo == "aws:kms":
                    mode = "SSE-KMS"
                    # May be absent when the AWS-managed key is used.
                    kms_arn = sse.get("KMSMasterKeyID")
                # NOTE(review): any other algorithm (e.g. "aws:kms:dsse")
                # leaves mode=None, which is indistinguishable from "unknown".
            current = replace(current, encryption=mode, kms_key_arn=kms_arn)

        # --- versioning -------------------------------------------------
        try:
            resp = s3.get_bucket_versioning(Bucket=name)
        except ClientError as e:
            if _err_code(e) != "AccessDenied":
                raise
        else:
            # No Status (never versioned) and "Suspended" both read as False.
            current = replace(current, versioning=resp.get("Status") == "Enabled")

        # --- bucket policy ----------------------------------------------
        try:
            resp = s3.get_bucket_policy(Bucket=name)
        except ClientError as e:
            code = _err_code(e)
            if code in {"NoSuchBucketPolicy", "NoSuchPolicy"}:
                current = replace(current, bucket_policy_json=None)
            elif code != "AccessDenied":
                raise
        else:
            current = replace(current, bucket_policy_json=resp.get("Policy"))

        return current

    # --- create / mutate ----------------------------------------------------

    def create_bucket(self, desired: S3BucketDesired) -> None:
        """Create the bucket described by *desired* in its target region.

        The target region is the first non-empty value among
        ``desired.region``, the provider's region, the session's region, and
        ``"us-east-1"``.
        """
        region = desired.region or self._region or self._session().region_name or "us-east-1"
        kwargs: dict[str, Any] = {"Bucket": desired.name}
        # us-east-1 is selected by omitting CreateBucketConfiguration; every
        # other region is named through LocationConstraint.
        if region != "us-east-1":
            kwargs["CreateBucketConfiguration"] = {"LocationConstraint": region}
        # Send the request to the target region's endpoint. S3 rejects a
        # location constraint that doesn't match the endpoint it was sent to
        # (IllegalLocationConstraintException), so a session-region client
        # fails whenever desired.region differs from it.
        s3 = self._s3(region)
        s3.create_bucket(**kwargs)

    def put_public_access_block(self, name: str, enabled: bool) -> None:
        """Set all four public-access-block flags on *name* to *enabled*."""
        s3 = self._s3()
        cfg = {
            "BlockPublicAcls": enabled,
            "IgnorePublicAcls": enabled,
            "BlockPublicPolicy": enabled,
            "RestrictPublicBuckets": enabled,
        }
        s3.put_public_access_block(Bucket=name, PublicAccessBlockConfiguration=cfg)

    def put_encryption(self, name: str, encryption: str, kms_key_arn: str | None) -> None:
        """Apply default encryption *encryption* to bucket *name*.

        Args:
            encryption: ``"NONE"``, ``"SSE-S3"`` or ``"SSE-KMS"``.
            kms_key_arn: Required when *encryption* is ``"SSE-KMS"``.

        Raises:
            ValueError: For an unknown mode, or SSE-KMS without a key.
        """
        s3 = self._s3()
        if encryption == "NONE":
            # NOTE(review): S3 now applies SSE-S3 to every bucket by default,
            # and DeleteBucketEncryption reverts to SSE-S3 rather than turning
            # encryption off. get_current() may then report "SSE-S3" instead
            # of "NONE"; confirm convergence expectations for this mode.
            s3.delete_bucket_encryption(Bucket=name)
            return
        if encryption == "SSE-S3":
            rule = {"ApplyServerSideEncryptionByDefault": {"SSEAlgorithm": "AES256"}}
        elif encryption == "SSE-KMS":
            if not kms_key_arn:
                raise ValueError("SSE-KMS requires kms_key_arn")
            rule = {
                "ApplyServerSideEncryptionByDefault": {
                    "SSEAlgorithm": "aws:kms",
                    "KMSMasterKeyID": kms_key_arn,
                }
            }
        else:
            raise ValueError(f"Unknown encryption mode: {encryption}")
        s3.put_bucket_encryption(
            Bucket=name,
            ServerSideEncryptionConfiguration={"Rules": [rule]},
        )

    def put_versioning(self, name: str, enabled: bool) -> None:
        """Enable versioning on *name*, or suspend it when *enabled* is False."""
        s3 = self._s3()
        s3.put_bucket_versioning(
            Bucket=name,
            VersioningConfiguration={"Status": "Enabled" if enabled else "Suspended"},
        )

    def put_bucket_policy(self, name: str, policy_json: str) -> None:
        """Replace the bucket policy of *name* with the JSON text *policy_json*."""
        s3 = self._s3()
        s3.put_bucket_policy(Bucket=name, Policy=policy_json)

    # --- destroy / cleanup --------------------------------------------------

    @staticmethod
    def _delete_keys(s3, name: str, keys: list[dict[str, str]]) -> None:
        """Issue DeleteObjects for *keys* in batches of at most 1000 entries."""
        batch = AwsS3Provider._DELETE_BATCH_SIZE
        # Quiet mode: per-key failures come back in the response's "Errors"
        # list and are not raised here. Anything left behind will make a later
        # delete_bucket fail (S3 refuses to delete a non-empty bucket).
        for start in range(0, len(keys), batch):
            s3.delete_objects(
                Bucket=name,
                Delete={"Objects": keys[start : start + batch], "Quiet": True},
            )

    def delete_all_objects(self, name: str) -> None:
        """Delete all objects (and versions, if present) from the bucket.

        Best effort: ``NoSuchBucket`` and ``AccessDenied`` raised while
        listing or deleting are ignored. Other ``ClientError``s propagate.
        """
        s3 = self._s3()
        ignorable = {"NoSuchBucket", "AccessDenied"}

        # Pass 1: every object version and delete marker. This needs
        # s3:ListBucketVersions. Unversioned buckets' objects are also listed
        # here (with VersionId "null").
        try:
            for page in s3.get_paginator("list_object_versions").paginate(Bucket=name):
                entries = (page.get("Versions") or []) + (page.get("DeleteMarkers") or [])
                keys = [{"Key": item["Key"], "VersionId": item["VersionId"]} for item in entries]
                self._delete_keys(s3, name, keys)
        except ClientError as e:
            if _err_code(e) not in ignorable:
                raise

        # Pass 2: current objects. Covers non-versioned buckets and the case
        # where listing versions was denied.
        try:
            for page in s3.get_paginator("list_objects_v2").paginate(Bucket=name):
                keys = [{"Key": obj["Key"]} for obj in page.get("Contents") or []]
                self._delete_keys(s3, name, keys)
        except ClientError as e:
            if _err_code(e) not in ignorable:
                raise

    def delete_bucket(self, name: str) -> None:
        """Delete the bucket (must be empty first).

        A bucket that no longer exists (``NoSuchBucket``) is treated as
        already deleted. Other ``ClientError``s propagate.
        """
        s3 = self._s3()
        try:
            s3.delete_bucket(Bucket=name)
        except ClientError as e:
            if _err_code(e) == "NoSuchBucket":
                return
            raise

    def destroy_bucket(self, name: str) -> None:
        """Best-effort bucket destroy: empty then delete."""
        self.delete_all_objects(name)
        self.delete_bucket(name)

    # --- outputs ------------------------------------------------------------

    def get_outputs(self, name: str) -> S3BucketResult:
        """Return the name, ARN and regional domain name of bucket *name*.

        Raises:
            ValueError: If the bucket does not exist.
        """
        cur = self.get_current(name)
        if not cur.exists:
            raise ValueError(f"Bucket does not exist: {name}")
        arn = cur.arn or f"arn:aws:s3:::{name}"
        # The region can be unreadable (AccessDenied). Fall back to the
        # us-east-1 style hostname.
        regional_domain_name = cur.regional_domain_name or f"{name}.s3.amazonaws.com"
        return S3BucketResult(name=name, arn=arn, regional_domain_name=regional_domain_name)