Source code for gecko_iot_client.transporters.token_manager

"""
Token manager for handling AWS IoT token refresh and expiration.
"""

import logging
import threading
from datetime import datetime, timedelta
from typing import Callable, Optional

from .exceptions import TokenRefreshError

# Logger for this module; TokenManager reports its refresh activity here.
logger = logging.getLogger(name=__name__)


class TokenManager:
    """Manages AWS IoT tokens with automatic refresh capabilities.

    A background daemon thread checks the current token's expiry every 30
    seconds. Once the token is within ``refresh_threshold_minutes`` of expiring,
    the thread calls the refresh callback. After every successful refresh, the
    listeners registered with ``add_refresh_listener`` are called with the new
    token data.
    """

    def __init__(
        self,
        token_refresh_callback: Callable[[], dict],
        refresh_threshold_minutes: int = 15,
    ):
        """
        Initialize the token manager.

        Args:
            token_refresh_callback: Function that returns new token data
            refresh_threshold_minutes: How many minutes before expiry to refresh
        """
        self._token_refresh_callback = token_refresh_callback
        self._refresh_threshold = timedelta(minutes=refresh_threshold_minutes)
        self._current_token = None
        self._token_expiry = None
        self._refresh_thread = None
        # Stop signal for the currently running monitor thread. A fresh Event
        # is created each time a new monitor thread is started.
        self._stop_refresh = threading.Event()
        self._token_lock = threading.Lock()
        self._refresh_listeners = []

    def add_refresh_listener(self, callback: Callable[[dict], None]):
        """Add a callback to be notified when token is refreshed."""
        self._refresh_listeners.append(callback)

    def remove_refresh_listener(self, callback: Callable[[dict], None]):
        """Remove a token refresh listener; unknown callbacks are ignored."""
        try:
            self._refresh_listeners.remove(callback)
        except ValueError:
            pass

    def set_token(self, token_data: dict):
        """
        Set the current token and start refresh monitoring.

        Args:
            token_data: Dictionary containing token and expiry information.
                Expected keys: 'access_token', plus either 'expires_in'
                (seconds from now) or 'expires_at' (ISO-8601 string or
                datetime). With no expiry key the token is assumed to be
                valid for 1 hour.
        """
        # Parse first so a malformed expiry leaves the previous token intact.
        expiry = self._compute_expiry(token_data)
        with self._token_lock:
            # Keep a private copy so later mutation of the caller's dict does
            # not silently change the managed token.
            self._current_token = dict(token_data)
            self._token_expiry = expiry
            # Safe while holding the lock: this only starts a thread when none
            # is running and never joins one.
            self._start_refresh_monitoring()
        logger.info("Token set with expiry: %s", expiry)

    @staticmethod
    def _compute_expiry(token_data: dict) -> datetime:
        """Return the token expiry as a naive local datetime.

        The rest of the class compares the expiry with naive ``datetime.now()``.
        Timezone-aware ``expires_at`` values are therefore converted to local
        time and stripped of tzinfo, which avoids a naive-vs-aware TypeError.
        """
        if "expires_at" in token_data:
            raw = token_data["expires_at"]
            if isinstance(raw, datetime):
                expiry = raw
            else:
                if isinstance(raw, str) and raw.endswith("Z"):
                    # datetime.fromisoformat() only accepts "Z" on Python 3.11+.
                    raw = raw[:-1] + "+00:00"
                expiry = datetime.fromisoformat(raw)
            if expiry.tzinfo is not None:
                expiry = expiry.astimezone().replace(tzinfo=None)
            return expiry
        if "expires_in" in token_data:
            return datetime.now() + timedelta(
                seconds=float(token_data["expires_in"])
            )
        # Default to 1 hour if no expiry info
        logger.warning("No expiry information in token, defaulting to 1 hour")
        return datetime.now() + timedelta(hours=1)

    def get_current_token(self) -> Optional[dict]:
        """Get a copy of the current token data, or None if no token is set."""
        with self._token_lock:
            return self._current_token.copy() if self._current_token else None

    def is_token_valid(self) -> bool:
        """Check if the current token is valid and not expired."""
        with self._token_lock:
            if not self._current_token or not self._token_expiry:
                return False
            return datetime.now() < self._token_expiry

    def needs_refresh(self) -> bool:
        """Check if token needs to be refreshed based on threshold."""
        with self._token_lock:
            if not self._token_expiry:
                return True
            return datetime.now() >= (self._token_expiry - self._refresh_threshold)

    def refresh_token(self) -> dict:
        """
        Manually refresh the token.

        Calls the refresh callback, stores the result via set_token() and then
        notifies every registered listener. A listener that raises is logged
        and does not stop the remaining listeners from being notified.

        Returns:
            New token data

        Raises:
            TokenRefreshError: If the callback raises, returns a falsy value,
                or the returned token cannot be stored.
        """
        try:
            logger.info("Refreshing token...")
            new_token = self._token_refresh_callback()
            if not new_token:
                raise TokenRefreshError("Token refresh callback returned None")
            self.set_token(new_token)
        except Exception as e:
            logger.error("Token refresh failed: %s", e)
            raise TokenRefreshError(f"Failed to refresh token: {e}") from e

        # Iterate over a snapshot: a listener may remove itself (or another
        # listener) while being notified, which would otherwise skip entries.
        for listener in list(self._refresh_listeners):
            try:
                listener(new_token)
            except Exception as e:
                logger.error("Error notifying token refresh listener: %s", e)

        logger.info("Token refreshed successfully")
        return new_token

    def _start_refresh_monitoring(self):
        """Start the background refresh thread unless one is already running.

        The monitor loop re-reads the expiry on every iteration, so a running
        thread picks up a newly set token without being restarted. This method
        never joins a thread. That makes it safe to call from the monitor
        thread itself (refresh_token -> set_token) and while holding
        _token_lock.
        """
        thread = self._refresh_thread
        if thread is threading.current_thread():
            # Called from inside the monitor loop: it is already running (or
            # has been asked to stop, in which case it must not be revived).
            return
        if (
            thread is not None
            and thread.is_alive()
            and not self._stop_refresh.is_set()
        ):
            return
        # A new Event per thread: a previous thread that is still winding down
        # after stop_refresh_monitoring() keeps its own (set) event and exits,
        # instead of being revived by clearing a shared event.
        stop_event = threading.Event()
        self._stop_refresh = stop_event
        self._refresh_thread = threading.Thread(
            target=self._refresh_monitor_loop, args=(stop_event,), daemon=True
        )
        self._refresh_thread.start()
        logger.info("Token refresh monitoring started")

    def _refresh_monitor_loop(self, stop_event: Optional[threading.Event] = None):
        """Background loop that refreshes the token when it nears expiry.

        Args:
            stop_event: Event that ends this loop. Defaults to the manager's
                current stop event.
        """
        if stop_event is None:
            stop_event = self._stop_refresh
        while not stop_event.is_set():
            try:
                if self.needs_refresh():
                    logger.info("Token needs refresh, attempting refresh...")
                    self.refresh_token()
                # Check every 30 seconds
                stop_event.wait(30)
            except TokenRefreshError as e:
                logger.error("Token refresh failed in monitor loop: %s", e)
                # Wait longer before retrying on failure
                stop_event.wait(300)  # 5 minutes
            except Exception as e:
                logger.exception("Unexpected error in token refresh monitor: %s", e)
                stop_event.wait(60)

    def stop_refresh_monitoring(self):
        """Stop the background token refresh monitoring.

        Waits up to 5 seconds for the thread to exit. When called from the
        monitor thread itself (e.g. by a refresh listener), it only signals the
        stop, because a thread cannot join itself.
        """
        thread = self._refresh_thread
        if thread is None:
            return
        logger.info("Stopping token refresh monitoring...")
        self._stop_refresh.set()
        if thread is threading.current_thread():
            return
        thread.join(timeout=5)
        if thread.is_alive():
            logger.warning("Token refresh thread did not stop gracefully")

    def cleanup(self):
        """Stop monitoring, drop the current token, and remove all listeners."""
        self.stop_refresh_monitoring()
        with self._token_lock:
            self._current_token = None
            self._token_expiry = None
        self._refresh_listeners.clear()
        logger.info("Token manager cleaned up")