"""
Gecko-specific MQTT transporter implementation.
This module provides the MqttTransporter class which implements the AbstractTransporter
interface with Gecko IoT-specific logic for configuration loading, state management,
and AWS IoT shadow operations.
"""
import json
import logging
import threading
import time
import uuid
from datetime import datetime
from concurrent.futures import Future
from typing import Any, Callable, Dict, Optional
from .. import AbstractTransporter
from ..exceptions import ConfigurationError, ConnectionError
from .client import MqttClient
from .token_manager import TokenManager
from .reconnection_handler import ReconnectionHandler
from .callback_registry import CallbackRegistry
from .utils import parse_json_safely, complete_future_safely, notify_callbacks_safely
from .constants import (
NOT_CONNECTED_ERROR,
CONNECTION_TIMEOUT
)
logger = logging.getLogger(__name__)
[docs]
class MqttTransporter(AbstractTransporter):
"""
Gecko-specific MQTT transporter.
Responsibilities:
- Gecko topic structure (config, state, shadow)
- Token refresh and expiration management
- Configuration and state loading
- AbstractTransporter interface implementation
- Reconnection logic with token refresh
This class contains all Gecko IoT business logic and delegates
MQTT protocol operations to MqttClient.
"""
[docs]
def __init__(
self,
broker_url: str,
monitor_id: str,
token_refresh_callback: Optional[Callable[[str], str]] = None,
token_refresh_buffer_seconds: int = 300,
):
"""
Initialize MQTT transporter with Gecko-specific logic.
Args:
broker_url: WebSocket URL with embedded JWT token
monitor_id: Device monitor identifier
token_refresh_callback: Function to get new broker URL with fresh token
token_refresh_buffer_seconds: Seconds before expiry to refresh token
"""
if not broker_url or not monitor_id:
raise ConfigurationError("Both broker_url and monitor_id are required")
self._broker_url = broker_url
self._monitor_id = monitor_id
self._token_refresh_callback = token_refresh_callback
self._token_refresh_buffer = token_refresh_buffer_seconds
# Helper components
self._token_manager = TokenManager(broker_url, token_refresh_buffer_seconds)
self._reconnection_handler = ReconnectionHandler()
self._callback_registry = CallbackRegistry()
# MQTT client - delegates all MQTT operations
self._mqtt_client = MqttClient(
on_connected=self._on_mqtt_connected,
on_message=None, # We use specific handlers only
)
# State management
self._is_refreshing_token = False
self._state_lock = threading.RLock()
# Loading state
self._config_future: Optional[Future] = None
self._state_future: Optional[Future] = None
self._subscriptions_setup = False
# Threading for expiry monitoring
self._monitor_thread: Optional[threading.Thread] = None
self._monitor_stop_event = threading.Event()
# ========================================================================
# AbstractTransporter Interface
# ========================================================================
[docs]
def connect(self, **kwargs):
"""Connect using preformatted WebSocket URL with expiration management."""
self._monitor_stop_event.clear()
if self._mqtt_client.is_connected():
logger.debug("Already connected")
return
# Check if token is already expired before attempting connection
if self._token_manager.is_expired():
logger.warning("Token expired, refreshing before connection")
if self._token_refresh_callback:
self._refresh_token_before_connect()
try:
# Generate unique client ID
client_id = f"ha-{self._monitor_id}-{uuid.uuid4().hex}"
# Connect via MQTT client
self._mqtt_client.connect(
broker_url=self._broker_url,
client_id=client_id,
timeout=kwargs.get("timeout", CONNECTION_TIMEOUT)
)
# Start expiry monitoring after successful connection
if self._token_refresh_callback and self._token_manager.expiry:
self._start_expiry_monitoring()
except Exception as e:
logger.error(f"Connection failed: {e}")
raise ConnectionError(f"Connection failed: {e}")
[docs]
def disconnect(self):
"""Disconnect and cleanup."""
# Stop monitoring first
self._monitor_stop_event.set()
self._stop_expiry_monitoring()
# Disconnect MQTT client
self._mqtt_client.disconnect()
# Clear subscription state
with self._state_lock:
self._subscriptions_setup = False
logger.info("Transporter disconnected successfully")
[docs]
def is_connected(self) -> bool:
"""Check if connected to broker."""
return self._mqtt_client.is_connected()
[docs]
def update_broker_url(self, new_broker_url: str) -> None:
"""
Update broker URL with a fresh token.
This is useful when reusing an existing connection that needs a token refresh.
The method updates the broker URL and token manager without disconnecting.
Args:
new_broker_url: New WebSocket URL with fresh JWT token
"""
if not new_broker_url:
logger.warning("Attempted to update with empty broker URL")
return
logger.debug("Updating broker URL with fresh token")
self._broker_url = new_broker_url
self._token_manager.update_broker_url(new_broker_url)
logger.debug(f"Token expiry updated to: {self._token_manager.expiry}")
[docs]
def load_configuration(self, timeout: float = 30.0):
"""Load configuration from AWS IoT."""
if not self._mqtt_client.is_connected():
raise ConnectionError(NOT_CONNECTED_ERROR)
if self._config_future and not self._config_future.done():
logger.debug("Configuration request already in progress")
return
# Wait for subscriptions to be ready (set up by connection callback)
wait_start = time.time()
while not self._subscriptions_setup and (time.time() - wait_start) < timeout:
logger.debug("Waiting for subscriptions to be ready...")
time.sleep(0.1)
if not self._subscriptions_setup:
raise ConfigurationError("Subscriptions not ready within timeout")
logger.debug(f"Loading configuration for monitor_id: {self._monitor_id}")
# Create future BEFORE publishing request to avoid race condition
# where response arrives before future exists
self._config_future = Future()
topic = self._build_topic("config/get")
try:
logger.debug(f"Publishing configuration request to: {topic}")
publish_future = self._mqtt_client.publish(topic, "{}")
# Wait for publish to complete
try:
publish_future.result(timeout=5.0)
logger.debug("Configuration request published")
except Exception as e:
logger.error(f"Failed to publish configuration request: {e}")
raise ConfigurationError(f"Failed to publish config request: {e}")
logger.debug(f"Waiting for configuration response (timeout: {timeout}s)")
# Wait for response
result = self._config_future.result(timeout=timeout)
logger.debug("Configuration loaded successfully")
return result
except Exception as e:
self._config_future = None
logger.error(f"Configuration loading failed: {e}")
raise ConfigurationError(f"Configuration loading failed: {e}")
[docs]
def load_state(self):
"""Load state from AWS IoT shadow."""
if not self._mqtt_client.is_connected():
raise ConnectionError(NOT_CONNECTED_ERROR)
if self._state_future and not self._state_future.done():
logger.debug("State request already in progress")
return
logger.debug(f"Loading state for monitor_id: {self._monitor_id}")
self._state_future = Future()
topic = self._build_topic("shadow/name/state/get")
try:
self._mqtt_client.publish(topic, "{}")
logger.debug("State request sent")
except Exception as e:
self._state_future = None
logger.error(f"State loading failed: {e}")
raise ConfigurationError(f"State loading failed: {e}")
[docs]
def publish_desired_state(self, desired_state: Dict[str, Any]) -> Future:
"""Publish desired state update to AWS IoT shadow."""
if not self._mqtt_client.is_connected():
raise ConnectionError(NOT_CONNECTED_ERROR)
payload = {"state": {"desired": desired_state}}
topic = self._build_topic("shadow/name/state/update")
return self._mqtt_client.publish(topic, json.dumps(payload))
[docs]
def publish_batch_desired_state(
self, zone_updates: Dict[str, Dict[str, Dict[str, Any]]]
) -> Future:
"""Publish batch desired state updates for multiple zones."""
desired_state = {"zones": zone_updates}
return self.publish_desired_state(desired_state)
[docs]
def on_configuration_loaded(self, callback):
"""Register config callback."""
self._callback_registry.register("config", callback)
[docs]
def on_state_loaded(self, callback):
"""Register state callback."""
self._callback_registry.register("state", callback)
[docs]
def on_state_change(self, callback):
"""Register state change callback."""
self._callback_registry.register("state_update", callback)
[docs]
def on_connectivity_change(self, callback):
"""Register connectivity change callback."""
self._callback_registry.register("connectivity", callback)
[docs]
def change_state(self, new_state):
"""Change state (placeholder for interface compliance)."""
notify_callbacks_safely(
self._callback_registry.get_callbacks("state_update"),
new_state
)
# ========================================================================
# Gecko-Specific Logic
# ========================================================================
def _build_topic(self, path: str) -> str:
"""Build AWS IoT topic for this monitor."""
return f"$aws/things/{self._monitor_id}/{path}"
def _refresh_token_before_connect(self) -> None:
"""Refresh token before initial connection attempt."""
if not self._token_refresh_callback:
return
try:
new_broker_url = self._token_refresh_callback(self._monitor_id)
if new_broker_url:
self._broker_url = new_broker_url
self._token_manager.update_broker_url(new_broker_url)
logger.debug("Token refreshed successfully before connection")
else:
logger.error("Token refresh callback returned empty URL")
except Exception as e:
logger.error(f"Failed to refresh expired token before connection: {e}")
def _setup_subscriptions(self):
"""Setup essential AWS IoT subscriptions."""
if self._subscriptions_setup:
logger.debug("Subscriptions already set up")
return
logger.debug(f"Setting up subscriptions for monitor_id: {self._monitor_id}")
topics = [
(self._build_topic("config/get/accepted"), self._on_config_response),
(self._build_topic("config/get/rejected"), self._on_config_rejected),
(self._build_topic("shadow/name/state/get/accepted"), self._on_state_response),
(self._build_topic("shadow/name/state/get/rejected"), self._on_state_rejected),
(self._build_topic("shadow/name/state/update/documents"), self._on_state_document_update),
(self._build_topic("shadow/name/state/update/rejected"), self._on_state_update_rejected),
]
successful_subscriptions = 0
for topic, handler in topics:
try:
logger.debug(f"Subscribing to: {topic}")
self._mqtt_client.subscribe(topic, handler)
successful_subscriptions += 1
except Exception as e:
logger.error(f"Failed to subscribe to {topic}: {e}")
if successful_subscriptions > 0:
self._subscriptions_setup = True
logger.debug(f"Set up {successful_subscriptions}/{len(topics)} subscriptions")
else:
logger.error("Failed to set up any subscriptions")
raise ConnectionError("Failed to establish subscriptions")
# ========================================================================
# Token Refresh and Expiry Monitoring
# ========================================================================
def _start_expiry_monitoring(self):
"""Start monitoring token expiry in background thread."""
if self._monitor_thread and self._monitor_thread.is_alive():
return
self._monitor_stop_event.clear()
self._monitor_thread = threading.Thread(
target=self._expiry_monitor_loop, daemon=True
)
self._monitor_thread.start()
logger.debug("Started token expiry monitoring")
def _stop_expiry_monitoring(self):
"""Stop token expiry monitoring."""
if self._monitor_thread and self._monitor_thread.is_alive():
self._monitor_stop_event.set()
self._monitor_thread.join(timeout=5)
logger.debug("Stopped token expiry monitoring")
def _expiry_monitor_loop(self):
"""Background thread loop to monitor token expiry."""
while not self._monitor_stop_event.is_set():
try:
# Check if token needs refreshing
with self._state_lock:
already_refreshing = self._is_refreshing_token
if not already_refreshing and self._should_refresh_token():
logger.info("Token approaching expiry, initiating refresh...")
self._handle_token_refresh()
# Check every 10 seconds for more responsive refresh
self._monitor_stop_event.wait(10)
except Exception as e:
logger.error(f"Error in expiry monitoring: {e}")
self._monitor_stop_event.wait(60) # Back off on error
def _should_refresh_token(self) -> bool:
"""Check if token needs refreshing."""
return self._token_manager.should_refresh(self._mqtt_client.is_connected())
def _handle_token_refresh(self):
"""Handle token refresh and reconnection."""
if not self._token_refresh_callback:
logger.warning("No token refresh callback configured")
return
with self._state_lock:
self._is_refreshing_token = True
try:
# Log timing information
expiry = self._token_manager.expiry
if expiry:
time_to_expiry = (expiry - datetime.now()).total_seconds()
logger.info(f"Refreshing token ({time_to_expiry:.1f}s until expiry)...")
else:
logger.info("Refreshing token...")
# Get new broker URL with fresh token (track callback duration)
callback_start = datetime.now()
new_broker_url = self._token_refresh_callback(self._monitor_id)
callback_duration = (datetime.now() - callback_start).total_seconds()
logger.debug(f"Token refresh callback completed in {callback_duration:.1f}s")
if not new_broker_url:
logger.error("Token refresh callback returned empty URL")
with self._state_lock:
self._is_refreshing_token = False
self._schedule_reconnect()
return
# Update broker URL and token expiry
old_broker_url = self._broker_url
self._broker_url = new_broker_url
self._token_manager.update_broker_url(new_broker_url)
# Reset reconnection counter - fresh token means fresh start
self._reconnection_handler.on_success()
logger.debug("Token updated, establishing new connection")
# Attempt to connect with new token
# We need to disconnect first, then reconnect with fresh token
# Connectivity events are suppressed via _is_refreshing_token flag
try:
client_id = f"ha-{self._monitor_id}-{uuid.uuid4().hex}"
# Disconnect old connection (connectivity event will be suppressed)
if self._mqtt_client.is_connected():
logger.debug("Disconnecting old connection before token refresh reconnect")
self._mqtt_client.disconnect()
# Establish new connection with fresh token
self._mqtt_client.connect(
broker_url=self._broker_url,
client_id=client_id
)
# Clear subscription state since we need to re-subscribe with new connection
with self._state_lock:
self._subscriptions_setup = False
# Clear intentional disconnect flag
self._mqtt_client.clear_intentional_disconnect_flag()
logger.debug("Token refreshed with minimal downtime")
self._reconnection_handler.on_success()
with self._state_lock:
self._is_refreshing_token = False
except Exception as e:
logger.error(f"Reconnection after token refresh failed: {e}")
# Restore old broker URL on failure
self._broker_url = old_broker_url
with self._state_lock:
self._is_refreshing_token = False
self._schedule_reconnect()
except Exception as e:
logger.error(f"Token refresh failed: {e}")
with self._state_lock:
self._is_refreshing_token = False
self._schedule_reconnect()
def _schedule_reconnect(self):
"""Schedule reconnection with exponential backoff."""
if not self._reconnection_handler.should_attempt():
logger.warning(
"Max reconnection attempts reached, will retry after cooldown period. "
"If this persists, token may be expired - forcing token refresh."
)
# Reset counter and try token refresh if available
self._reconnection_handler.on_success()
if self._token_refresh_callback:
# Force a token refresh after cooldown
def delayed_refresh():
time.sleep(300) # 5 minute cooldown
if not self._monitor_stop_event.is_set():
logger.info("Cooldown period ended, forcing token refresh")
self._handle_token_refresh()
refresh_thread = threading.Thread(target=delayed_refresh, daemon=True)
refresh_thread.start()
return
delay = self._reconnection_handler.get_delay()
attempt_num = self._reconnection_handler.on_attempt()
logger.debug(f"Scheduling reconnection attempt {attempt_num} in {delay}s")
def delayed_reconnect():
time.sleep(delay)
if not self._monitor_stop_event.is_set():
try:
client_id = f"ha-{self._monitor_id}-{uuid.uuid4().hex}"
self._mqtt_client.connect(
broker_url=self._broker_url,
client_id=client_id
)
logger.debug("Reconnection successful")
self._reconnection_handler.on_success()
# Clear subscription state to force re-setup
with self._state_lock:
self._subscriptions_setup = False
except Exception as e:
logger.error(f"Reconnection attempt {attempt_num} failed: {e}")
self._schedule_reconnect()
reconnect_thread = threading.Thread(target=delayed_reconnect, daemon=True)
reconnect_thread.start()
# ========================================================================
# MQTT Client Callbacks
# ========================================================================
def _on_mqtt_connected(self, connected: bool):
"""Handle MQTT connection status changes."""
logger.debug(f"MQTT connection status changed: {connected}")
# Check if we're in the middle of a token refresh
with self._state_lock:
is_refreshing = self._is_refreshing_token
if connected:
# Reset reconnection handler on successful connection
self._reconnection_handler.on_success()
# Schedule subscription setup and state loading in a background thread
# to avoid blocking the lifecycle callback and allow connection to stabilize
def setup_after_connection():
# Brief delay to ensure MQTT client is fully ready for subscriptions
time.sleep(0.5)
# Always setup subscriptions after connection
logger.debug("Setting up subscriptions after connection")
with self._state_lock:
self._subscriptions_setup = False
try:
self._setup_subscriptions()
# Load initial state after subscriptions are ready
logger.debug("Loading initial state after connection")
try:
self.load_state()
except Exception as e:
logger.warning(f"Failed to load initial state after connection: {e}")
except Exception as e:
logger.error(f"Failed to setup subscriptions after connection: {e}")
setup_thread = threading.Thread(target=setup_after_connection, daemon=True)
setup_thread.start()
# Suppress connectivity callbacks during token refresh to prevent
# entities from flickering unavailable during the brief disconnect/reconnect
if is_refreshing:
logger.debug("Suppressing connectivity callback during token refresh")
return
else:
# Disconnection event
# Check if we should attempt reconnection
with self._state_lock:
is_refreshing = self._is_refreshing_token
# Only schedule reconnect if not already refreshing and we have a callback
if not is_refreshing and self._token_refresh_callback and not self._monitor_stop_event.is_set():
# Check if token is expired/expiring
if self._token_manager.is_expired() or self._token_manager.should_refresh(False):
logger.info("Token expired/expiring, will refresh on reconnect")
self._token_manager.force_expiry()
logger.info("Unexpected disconnection, scheduling reconnection...")
self._schedule_reconnect()
# Suppress disconnection callbacks during token refresh
if is_refreshing:
logger.debug("Suppressing disconnection callback during token refresh")
return
# Notify connectivity callbacks
self._callback_registry.notify("connectivity", connected)
# ========================================================================
# Message Handlers for Gecko Topics
# ========================================================================
def _on_config_response(self, topic: str, payload: str):
"""Handle configuration response."""
logger.debug("Configuration response received")
config = parse_json_safely(payload)
if config:
config = config.get("configuration", {}).get("configuration", {})
notify_callbacks_safely(
self._callback_registry.get_callbacks("config"),
config
)
complete_future_safely(self._config_future, config)
else:
logger.error("Failed to parse configuration response")
if self._config_future and not self._config_future.done():
self._config_future.set_exception(
ConfigurationError("Invalid JSON in configuration response")
)
def _on_config_rejected(self, topic: str, payload: str):
"""Handle configuration request rejection."""
logger.warning(f"Configuration request rejected on topic: {topic}")
logger.warning(f"Rejection payload: {payload}")
if self._config_future and not self._config_future.done():
self._config_future.set_exception(
ConfigurationError(f"Configuration rejected: {payload}")
)
def _on_state_response(self, topic: str, payload: str):
"""Handle state response."""
logger.debug("State response received")
state = parse_json_safely(payload)
if state:
notify_callbacks_safely(
self._callback_registry.get_callbacks("state"),
state
)
complete_future_safely(self._state_future, state)
else:
logger.error("Failed to parse state response")
if self._state_future and not self._state_future.done():
self._state_future.set_exception(
ConfigurationError("Invalid JSON in state response")
)
def _on_state_rejected(self, topic: str, payload: str):
"""Handle state request rejection."""
logger.warning(f"State request rejected: {payload}")
if self._state_future and not self._state_future.done():
self._state_future.set_exception(
ConfigurationError(f"State rejected: {payload}")
)
def _on_state_document_update(self, topic: str, payload: str):
"""Handle state document update notifications."""
logger.debug("State document update received")
document = parse_json_safely(payload)
if document:
# Extract current state from document structure
current_state = document.get("current", {}).get("state", {})
logger.debug("Extracted state from document")
notify_callbacks_safely(
self._callback_registry.get_callbacks("state_update"),
{"state": current_state}
)
else:
logger.error("Failed to parse state document update")
def _on_state_update_rejected(self, topic: str, payload: str):
"""Handle state update rejection."""
logger.warning(f"State update rejected: {payload}")