"""Manager for dxaws-dns record sets."""
from __future__ import annotations

import time
from collections.abc import Mapping
from dataclasses import dataclass, field, replace
from typing import Any, Sequence

from dxaws_core import AwsSession
from dxaws_core.o11y import O11y

from dxaws_dns.models import DnsManagerResult, DnsRecordCurrent, DnsRecordDesired, DnsRecordPlan
from dxaws_dns.record_executor import apply_record_plan
from dxaws_dns.record_planner import normalize_record_name, normalize_record_type, plan_record
from dxaws_dns.providers.base import DnsRecordProvider
from dxaws_dns.providers.aws import Route53Provider
@dataclass(frozen=True, slots=True, kw_only=True)
class ApplyOptions:
emit_events: bool = True
@dataclass(frozen=True, slots=True, kw_only=True)
class ExecuteOptions:
apply_options: ApplyOptions | None = None
emit_events: bool = True
@dataclass(frozen=True, slots=True, kw_only=True)
class DnsManager:
    """High-level entry point for planning and applying a single DNS record.

    The manager normalizes a ``DnsRecordDesired``: it resolves a missing hosted-zone
    id and quotes TXT values. It then reads the current record from the provider,
    plans the change with ``plan_record``, and applies it with ``apply_record_plan``,
    emitting observability events through ``o11y`` along the way.
    """

    # Provider is optional for the common case.
    # If omitted, the manager defaults to an AWS Route53 provider, which is
    # created on first use and reused for the lifetime of this manager.
    provider: DnsRecordProvider | None = None
    o11y: O11y = field(default_factory=O11y.noop)
    # Cache for the lazily-created default provider. It is excluded from
    # __init__, repr and equality/hash. It is written via object.__setattr__
    # because the dataclass is frozen.
    _default_provider: DnsRecordProvider | None = field(
        default=None, init=False, repr=False, compare=False
    )

    def _provider(self) -> DnsRecordProvider:
        """Return the explicit provider, or the cached default Route53 provider."""
        if self.provider is not None:
            return self.provider
        cached = self._default_provider
        if cached is None:
            # Build the default provider only once. Constructing it per call would
            # create several AwsSession objects for a single execute().
            cached = Route53Provider(aws=AwsSession())
            object.__setattr__(self, "_default_provider", cached)
        return cached

    def _emit(self, event: str, **fields: Any) -> None:
        """Emit an info-level observability event with structured fields."""
        self.o11y.info(event, **fields)

    def _plan_summary(self, plan: DnsRecordPlan) -> dict[str, Any]:
        """Return the subset of plan attributes attached to plan/execute events."""
        return {
            "action": plan.action,
            "record_fqdn": plan.record_fqdn,
            "record_type": plan.record_type,
            "changes": plan.changes,
        }

    @staticmethod
    def _quote_txt_value(value: Any) -> str:
        """Return ``value`` as a double-quoted TXT string.

        Values that are already wrapped in double quotes are returned unchanged.
        Otherwise, stray leading/trailing quotes are stripped before re-wrapping.
        """
        text = str(value)
        # Require at least two characters. A lone '"' both starts and ends with
        # a quote but is not a quoted string; it must become '""'.
        if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
            return text
        return '"' + text.strip('"') + '"'

    def _resolve_missing_zone_id(self, desired: DnsRecordDesired) -> Any:
        """Determine a hosted-zone id for ``desired`` when none was supplied.

        Uses the provider's ``resolve_zone_id(zone_name=...)`` when ``zone_name``
        is set. Otherwise it falls back to ``resolve_zone_for_fqdn`` on the record name.

        Raises:
            ValueError: if ``zone_name`` is given but the provider cannot resolve it,
                or if neither a zone nor a fully-qualified record name is available.
        """
        zone_name = getattr(desired, "zone_name", None)
        if zone_name:
            resolver = getattr(self._provider(), "resolve_zone_id", None)
            if not callable(resolver):
                raise ValueError(
                    "zone_name was provided but the provider does not implement resolve_zone_id(zone_name=...)"
                )
            return resolver(zone_name=zone_name)

        # No explicit zone supplied: resolve from the record FQDN.
        record_name = str(desired.record_name)
        if "." not in record_name.rstrip("."):
            raise ValueError(
                "zone_id or zone_name is required when record_name is not a fully-qualified domain name"
            )
        return self.resolve_zone_for_fqdn(record_name)

    def _normalize_desired(self, desired: DnsRecordDesired) -> DnsRecordDesired:
        """Normalize/complete desired inputs before calling provider/planner.

        In particular, callers (including acceptance tests) may specify only
        ``zone_name``, or only a fully-qualified ``record_name``. If ``zone_id``
        is missing, it is resolved via the provider. TXT values are normalized
        to quoted form so idempotency is stable.

        Returns:
            ``desired`` itself when nothing changed, otherwise a copy created with
            ``dataclasses.replace``.

        Raises:
            ValueError: if the hosted zone cannot be determined from the inputs.
        """
        zone_id = getattr(desired, "zone_id", None)
        if not zone_id:
            zone_id = self._resolve_missing_zone_id(desired)

        # Normalize TXT values for stable diffs.
        record_type = normalize_record_type(desired.record_type)
        values = desired.values
        if values is not None and record_type == "TXT":
            values = [self._quote_txt_value(v) for v in values]

        if zone_id == desired.zone_id and values == desired.values:
            return desired
        return replace(desired, zone_id=zone_id, values=values)

    def get_current(self, desired: DnsRecordDesired) -> DnsRecordCurrent:
        """Fetch the provider's current state for the record described by ``desired``."""
        desired = self._normalize_desired(desired)
        record_fqdn = normalize_record_name(desired.record_name, zone_name=desired.zone_name)
        record_type = normalize_record_type(desired.record_type)
        return self._provider().get_record_current(
            zone_id=desired.zone_id,
            record_fqdn=record_fqdn,
            record_type=record_type,
        )

    def plan(
        self,
        desired: DnsRecordDesired,
        current: DnsRecordCurrent | None = None,
    ) -> DnsRecordPlan:
        """Compute the change needed to move ``current`` to ``desired``.

        If ``current`` is None, it is fetched via :meth:`get_current`. The events
        ``manager.plan.start`` and ``manager.plan.done`` are always emitted.
        """
        desired = self._normalize_desired(desired)
        if current is None:
            # Explicit None check: a supplied-but-falsy current must not be refetched.
            current = self.get_current(desired)
        start = time.monotonic()
        self._emit(
            "manager.plan.start",
            module="dxaws-dns",
            zone_id=desired.zone_id,
            record_name=desired.record_name,
            record_type=desired.record_type,
        )
        plan = plan_record(desired=desired, current=current)
        summary = self._plan_summary(plan)
        self._emit(
            "manager.plan.done",
            module="dxaws-dns",
            zone_id=desired.zone_id,
            record_name=desired.record_name,
            record_type=plan.record_type,
            action=plan.action,
            duration_ms=int((time.monotonic() - start) * 1000),
            plan_summary=summary,
        )
        return plan

    def apply(
        self,
        plan: DnsRecordPlan,
        *,
        options: ApplyOptions | None = None,
    ):
        """Apply ``plan`` through the provider and return ``apply_record_plan``'s result.

        No-op plans are delegated directly, without emitting events. For other
        plans, when ``options.emit_events`` is true, ``manager.apply.start`` is
        emitted. It is followed by ``manager.apply.done`` with outcome ``"applied"``,
        or ``"failed"`` (plus ``error``) if applying raises. Exceptions are re-raised.
        """
        options = options or ApplyOptions()
        if plan.is_noop:
            return apply_record_plan(provider=self._provider(), plan=plan)

        event_fields: dict[str, Any] = {
            "module": "dxaws-dns",
            "zone_id": plan.desired.zone_id,
            "record_name": plan.desired.record_name,
            "record_type": plan.record_type,
            "action": plan.action,
        }
        if options.emit_events:
            self._emit("manager.apply.start", **event_fields)
        start = time.monotonic()
        try:
            outputs = apply_record_plan(provider=self._provider(), plan=plan)
        except Exception as exc:
            # Close the start event so failures are visible in observability.
            if options.emit_events:
                self._emit(
                    "manager.apply.done",
                    **event_fields,
                    outcome="failed",
                    duration_ms=int((time.monotonic() - start) * 1000),
                    error=str(exc),
                )
            raise
        if options.emit_events:
            self._emit(
                "manager.apply.done",
                **event_fields,
                outcome="applied",
                duration_ms=int((time.monotonic() - start) * 1000),
            )
        return outputs

    def execute(
        self,
        desired: DnsRecordDesired,
        *,
        options: ExecuteOptions | None = None,
    ) -> DnsManagerResult:
        """Normalize, fetch current state, plan and apply ``desired`` in one call.

        ``options.emit_events`` only controls the ``manager.execute.done`` event.
        Plan events are always emitted by :meth:`plan`, and apply events are
        controlled by ``options.apply_options``.

        Returns:
            A ``DnsManagerResult`` whose ``outcome`` is ``"noop"`` or ``"applied"``.

        Raises:
            Any exception from :meth:`apply`, re-raised after emitting a
            ``"failed"`` execute event.
        """
        options = options or ExecuteOptions()
        desired = self._normalize_desired(desired)
        current = self.get_current(desired)
        plan = self.plan(desired, current)
        summary = self._plan_summary(plan)
        event_fields: dict[str, Any] = {
            "module": "dxaws-dns",
            "zone_id": desired.zone_id,
            "record_name": desired.record_name,
            "record_type": plan.record_type,
            "action": plan.action,
        }
        # Keep the try body minimal. A failure while emitting the success event
        # must not be reported as a failed apply (or emit two done events).
        try:
            outputs = self.apply(plan, options=options.apply_options)
        except Exception as exc:
            if options.emit_events:
                self._emit(
                    "manager.execute.done",
                    **event_fields,
                    outcome="failed",
                    plan_summary=summary,
                    error=str(exc),
                )
            raise
        outcome = "noop" if plan.is_noop else "applied"
        if options.emit_events:
            self._emit(
                "manager.execute.done",
                **event_fields,
                outcome=outcome,
                plan_summary=summary,
            )
        return DnsManagerResult(
            desired=desired,
            current=current,
            plan=plan,
            outputs=outputs,
            outcome=outcome,
        )

    @staticmethod
    def _zone_field(zone: Any, *names: str) -> Any:
        """Return the first truthy value among ``names`` on ``zone``, or None.

        Each name is read as an attribute first. For mapping-shaped zones it is
        then read as a key.
        """
        for name in names:
            value = getattr(zone, name, None)
            if not value and isinstance(zone, Mapping):
                # Dict-shaped zones (e.g. {"Id": ..., "Name": ...}) expose fields
                # as keys, which getattr cannot see.
                value = zone.get(name)
            if value:
                return value
        return None

    def resolve_zone_for_fqdn(self, fqdn: str) -> str:
        """Return the hosted zone id that authoritatively covers fqdn (longest-suffix match).

        This is intended for composed modules (e.g. dxaws-acm) that are given a full
        record FQDN (like _abc.foo.example.com.) and need to know which hosted zone
        to write into.

        The provider may implement a specialized ``resolve_zone_for_fqdn(fqdn=...)``.
        If not, this falls back to ``list_zones()`` or ``list_hosted_zones()``
        (whichever is callable, in that order) and selects the zone whose name is
        the longest suffix of ``fqdn``. Matching ignores case and trailing dots.

        Raises:
            ValueError: if the provider supports neither strategy, or no zone covers fqdn.
        """
        prov = self._provider()
        # Prefer provider-native resolver if present.
        resolver = getattr(prov, "resolve_zone_for_fqdn", None)
        if callable(resolver):
            return str(resolver(fqdn=fqdn))

        # Fallback: longest-suffix match using hosted zone listing.
        lister = None
        for attr in ("list_zones", "list_hosted_zones"):
            candidate = getattr(prov, attr, None)
            if callable(candidate):
                lister = candidate
                break
        if lister is None:
            raise ValueError(
                "Provider does not support resolve_zone_for_fqdn(fqdn=...) and does not implement "
                "list_zones()/list_hosted_zones() for fallback longest-suffix resolution."
            )

        target = str(fqdn).rstrip(".").lower()
        best_zone_id: str | None = None
        best_zone_name: str | None = None
        for zone in lister():
            # Accept multiple possible shapes returned by providers.
            zone_id = self._zone_field(zone, "zone_id", "id", "Id")
            zone_name = self._zone_field(zone, "zone_name", "name", "Name")
            if not zone_id or not zone_name:
                continue
            normalized_name = str(zone_name).rstrip(".").lower()
            covers = target == normalized_name or target.endswith("." + normalized_name)
            if covers and (best_zone_name is None or len(normalized_name) > len(best_zone_name)):
                best_zone_name = normalized_name
                best_zone_id = str(zone_id)
        if best_zone_id is None:
            raise ValueError(f"No hosted zone found that covers fqdn={fqdn!r}")
        return best_zone_id

    def ensure_records(self, *, records: Sequence[Any], zone_id: str | None = None) -> None:
        """Idempotently ensure a batch of records exists.

        Records are preferred to be fully qualified. If `zone_id` is omitted, the
        manager resolves the authoritative hosted zone internally.

        Each item is either a ``DnsRecordDesired``, to which ``zone_id`` is applied
        when given, or an object exposing ``name``, ``type``, ``value`` and
        optionally ``ttl`` and ``zone_name`` attributes. Such objects that lack a
        name, a type or a value are skipped without error. TTL defaults to 60.
        Every record is passed to :meth:`execute` with default options.
        """
        for r in records:
            if isinstance(r, DnsRecordDesired):
                desired = r if zone_id is None else replace(r, zone_id=zone_id)
            else:
                name = getattr(r, "name", None)
                rr_type = getattr(r, "type", None)
                value = getattr(r, "value", None)
                ttl = getattr(r, "ttl", None)
                zone_name = getattr(r, "zone_name", None)
                if not name or not rr_type or value is None:
                    continue
                desired = DnsRecordDesired(
                    zone_id=zone_id,
                    zone_name=zone_name,
                    record_name=str(name),
                    record_type=normalize_record_type(str(rr_type)),
                    ttl=int(ttl) if ttl is not None else 60,
                    values=[str(value)],
                )
            self.execute(desired)