"""OAuth2 authentication for WHOOP API with secure token storage.

Supports hybrid storage: macOS Keychain (primary) + file fallback (for cron jobs).

IMPORTANT: WHOOP uses rotating refresh tokens. Each successful refresh returns a NEW
refresh token and invalidates the old one. This means:
- Only one process should refresh at a time (file locking)
- Tokens must be saved immediately after refresh
- If refresh fails with 400, the refresh token is likely invalid and re-auth is needed
"""

import fcntl
import json
import os
import secrets
import stat
import time
import webbrowser
from http.server import HTTPServer, BaseHTTPRequestHandler
from pathlib import Path
from threading import Thread
from typing import Optional
from urllib.parse import urlencode, parse_qs, urlparse

import httpx
import keyring

# Estimated refresh-token lifetime in days. WHOOP doesn't document this; ~7 days
# is typical. Used to estimate whether the stored refresh token has likely expired.
REFRESH_TOKEN_LIFETIME_DAYS = 7

# WHOOP OAuth endpoints (both live under the same OAuth2 base path)
_WHOOP_OAUTH_BASE = "https://api.prod.whoop.com/oauth/oauth2"
AUTH_URL = f"{_WHOOP_OAUTH_BASE}/auth"
TOKEN_URL = f"{_WHOOP_OAUTH_BASE}/token"

# Service name under which tokens are stored in the keychain
SERVICE_NAME = "whoop-mcp"

# File-based token storage (for cron jobs that can't access the keychain),
# located three directory levels above this module.
TOKEN_FILE = Path(__file__).parent.parent.parent / ".tokens.json"

# Space-separated OAuth scopes required for health data access
SCOPES = " ".join([
    "read:recovery",
    "read:sleep",
    "read:cycles",
    "read:workout",
    "offline",
])


class KeychainTokenStorage:
    """Secure token storage in the macOS Keychain (via ``keyring``).

    Keychain access can fail in non-GUI contexts such as cron, so no method
    raises: failures are reported through return values instead.
    """

    @staticmethod
    def store_tokens(access_token: str, refresh_token: str, expires_at: float,
                     refreshed_at: Optional[float] = None) -> bool:
        """Store OAuth tokens in the keychain. Returns True on success.

        Args:
            access_token: Current access token.
            refresh_token: Current (rotating) refresh token.
            expires_at: Epoch seconds at which the access token expires.
            refreshed_at: Epoch seconds of the last refresh; defaults to now.
                Stored (as the file backend also does) so refresh-token age
                checks work on keychain-sourced tokens instead of falling back
                to an estimate derived from ``expires_at``.
        """
        try:
            token_data = json.dumps({
                "access_token": access_token,
                "refresh_token": refresh_token,
                "expires_at": expires_at,
                "refreshed_at": refreshed_at if refreshed_at is not None else time.time(),
            })
            keyring.set_password(SERVICE_NAME, "tokens", token_data)
            return True
        except Exception as e:
            print(f"Keychain storage failed: {e}")
            return False

    @staticmethod
    def get_tokens() -> Optional[dict]:
        """Return the token dict stored in the keychain, or None.

        None covers "nothing stored", an inaccessible keychain (expected in
        non-GUI contexts such as cron) and corrupt or non-object JSON.
        """
        try:
            token_data = keyring.get_password(SERVICE_NAME, "tokens")
            tokens = json.loads(token_data) if token_data else None
        except Exception:
            # Keychain access may fail in non-GUI contexts; treat as "no tokens".
            return None
        # Callers use dict methods on the result, so reject any other JSON shape.
        return tokens if isinstance(tokens, dict) else None

    @staticmethod
    def clear_tokens() -> None:
        """Remove tokens from the keychain (best effort, never raises)."""
        try:
            keyring.delete_password(SERVICE_NAME, "tokens")
        except Exception:
            # PasswordDeleteError when nothing is stored, or keychain unavailable:
            # either way there is nothing further to clear.
            pass


class FileTokenStorage:
    """File-based token storage for non-GUI contexts (cron jobs).

    Stores tokens in a JSON file with 600 permissions (owner read/write only).
    This is a fallback when the keychain is not accessible. Reads and writes
    take an advisory ``fcntl`` lock on a sibling ``.lock`` file, and writes go
    through a temp file plus atomic rename, so readers never see a partially
    written file. This matters with WHOOP's rotating refresh tokens, where
    losing the newest refresh token forces re-authentication.
    """

    @staticmethod
    def _lock_path() -> Path:
        """Path of the advisory lock file that serializes access to TOKEN_FILE."""
        return Path(str(TOKEN_FILE) + ".lock")

    @staticmethod
    def store_tokens(access_token: str, refresh_token: str, expires_at: float,
                     refreshed_at: Optional[float] = None) -> bool:
        """Atomically write tokens to TOKEN_FILE. Returns True on success.

        Args:
            access_token: Current access token.
            refresh_token: Current (rotating) refresh token.
            expires_at: Epoch seconds at which the access token expires.
            refreshed_at: Epoch seconds of the last refresh (tracks when we last
                refreshed); defaults to now.
        """
        try:
            token_data = {
                "access_token": access_token,
                "refresh_token": refresh_token,
                "expires_at": expires_at,
                "refreshed_at": refreshed_at if refreshed_at is not None else time.time(),
            }
            payload = json.dumps(token_data, indent=2)
            owner_rw = stat.S_IRUSR | stat.S_IWUSR  # 600
            temp_file = Path(str(TOKEN_FILE) + ".tmp")

            with open(FileTokenStorage._lock_path(), "w") as lf:
                fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
                try:
                    # Create the temp file as 600 up front. Writing first and
                    # chmod-ing afterwards would briefly leave the tokens readable
                    # according to the process umask.
                    fd = os.open(temp_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_rw)
                    with os.fdopen(fd, "w", encoding="utf-8") as tf:
                        # A temp file left over from a crash keeps its old mode.
                        os.fchmod(tf.fileno(), owner_rw)
                        tf.write(payload)
                        # Make the bytes durable before the rename publishes them.
                        tf.flush()
                        os.fsync(tf.fileno())
                    # Atomic on POSIX: readers see either the old or the new file.
                    os.replace(temp_file, TOKEN_FILE)
                finally:
                    fcntl.flock(lf.fileno(), fcntl.LOCK_UN)

            return True
        except Exception as e:
            print(f"File storage failed: {e}")
            return False

    @staticmethod
    def get_tokens() -> Optional[dict]:
        """Return the stored token dict, or None if absent, unreadable or invalid."""
        try:
            if not TOKEN_FILE.exists():
                return None
            with open(FileTokenStorage._lock_path(), "w") as lf:
                fcntl.flock(lf.fileno(), fcntl.LOCK_SH)  # shared: concurrent readers OK
                try:
                    tokens = json.loads(TOKEN_FILE.read_text(encoding="utf-8"))
                finally:
                    fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
        except (OSError, ValueError):
            # Unreadable file or corrupt JSON: behave as if nothing is stored.
            return None
        # Callers use dict methods on the result, so reject any other JSON shape.
        return tokens if isinstance(tokens, dict) else None

    @staticmethod
    def clear_tokens() -> None:
        """Delete the token file if present (best effort, never raises)."""
        try:
            TOKEN_FILE.unlink()
        except FileNotFoundError:
            pass  # nothing stored
        except OSError as e:
            print(f"Could not remove token file {TOKEN_FILE}: {e}")


class HybridTokenStorage:
    """Hybrid token storage: keychain plus file, kept in sync.

    - GUI contexts (MCP server in terminal): keychain works, file is a backup.
    - Non-GUI contexts (cron jobs): keychain fails, file is used.

    Writes go to BOTH backends. Reads consult both and return the most recently
    issued token set. A process that cannot reach the keychain (e.g. cron) only
    updates the file. Because WHOOP rotates refresh tokens, the keychain copy is
    then stale and its refresh token already invalidated. Blindly preferring it
    would cause a 400 on refresh and wipe the good file copy.
    """

    @staticmethod
    def store_tokens(access_token: str, refresh_token: str, expires_at: float,
                     refreshed_at: Optional[float] = None) -> None:
        """Store tokens in both keychain and file.

        Raises:
            RuntimeError: if neither backend accepted the tokens.
        """
        if refreshed_at is None:
            refreshed_at = time.time()
        keychain_ok = KeychainTokenStorage.store_tokens(access_token, refresh_token, expires_at)
        file_ok = FileTokenStorage.store_tokens(access_token, refresh_token, expires_at, refreshed_at)

        if not keychain_ok and not file_ok:
            raise RuntimeError("Failed to store tokens in both keychain and file")

    @staticmethod
    def _pick_newest(primary: Optional[dict], fallback: Optional[dict]) -> Optional[dict]:
        """Return whichever token set was issued most recently.

        Compares ``expires_at`` (written by both backends). A later expiry means a
        later refresh, assuming WHOOP issues tokens with a consistent lifetime.
        Ties, and a missing fallback, favor ``primary``.
        """
        if not primary:
            return fallback
        if not fallback:
            return primary
        if fallback.get("expires_at", 0) > primary.get("expires_at", 0):
            return fallback
        return primary

    @staticmethod
    def get_tokens() -> Optional[dict]:
        """Return the newest token set from keychain or file (keychain wins ties)."""
        keychain_tokens = KeychainTokenStorage.get_tokens()
        file_tokens = FileTokenStorage.get_tokens()
        return HybridTokenStorage._pick_newest(keychain_tokens, file_tokens)

    @staticmethod
    def clear_tokens() -> None:
        """Clear tokens from both storage locations."""
        KeychainTokenStorage.clear_tokens()
        FileTokenStorage.clear_tokens()

    @staticmethod
    def is_token_expired(tokens: dict, buffer_seconds: int = 300) -> bool:
        """Return True if the access token expires within ``buffer_seconds`` (default 5 min)."""
        expires_at = tokens.get("expires_at", 0)
        return time.time() > (expires_at - buffer_seconds)

    @staticmethod
    def is_refresh_token_likely_expired(tokens: dict) -> bool:
        """Check if the refresh token is likely expired (based on when we last refreshed).

        WHOOP refresh tokens have a limited lifetime (~7 days, an estimate). If we
        haven't refreshed in that time, the refresh token is probably invalid.
        Without a recorded ``refreshed_at``, ``expires_at - 3600`` is used as an
        approximation.
        """
        refreshed_at = tokens.get("refreshed_at", tokens.get("expires_at", 0) - 3600)
        days_since_refresh = (time.time() - refreshed_at) / (24 * 3600)
        return days_since_refresh > REFRESH_TOKEN_LIFETIME_DAYS

    @staticmethod
    def should_proactively_refresh(tokens: dict, buffer_seconds: int = 600) -> bool:
        """Return True once within ``buffer_seconds`` (default 10 min) of expiry."""
        expires_at = tokens.get("expires_at", 0)
        return time.time() > (expires_at - buffer_seconds)

    @staticmethod
    def sync_keychain_to_file() -> bool:
        """Copy tokens from keychain to file (for setting up cron access).

        Returns False when the keychain holds no complete token set. If the file
        already holds a NEWER token set, it is left untouched and True is
        returned. This happens, for example, when a cron job rotated the refresh
        token since the keychain was last written. Overwriting it would replace a
        live refresh token with an invalidated one.
        """
        tokens = KeychainTokenStorage.get_tokens()
        required = ("access_token", "refresh_token", "expires_at")
        if not tokens or any(key not in tokens for key in required):
            return False

        file_tokens = FileTokenStorage.get_tokens()
        if HybridTokenStorage._pick_newest(tokens, file_tokens) is not tokens:
            return True

        return FileTokenStorage.store_tokens(
            tokens["access_token"],
            tokens["refresh_token"],
            tokens["expires_at"],
            # Preserve the original refresh time so the age heuristic stays accurate.
            tokens.get("refreshed_at"),
        )

    @staticmethod
    def get_storage_status() -> dict:
        """Get status of both storage backends (for diagnostics)."""
        keychain_tokens = KeychainTokenStorage.get_tokens()
        file_tokens = FileTokenStorage.get_tokens()

        return {
            "keychain": {
                "available": keychain_tokens is not None,
                "expires_at": keychain_tokens.get("expires_at") if keychain_tokens else None
            },
            "file": {
                "available": file_tokens is not None,
                "path": str(TOKEN_FILE),
                "expires_at": file_tokens.get("expires_at") if file_tokens else None
            }
        }


# Backwards compatibility alias: code written against the older name
# ``TokenStorage`` now gets the hybrid keychain + file backend. WhoopAuth below
# also goes through this name for all token reads and writes.
TokenStorage = HybridTokenStorage


class OAuthCallbackHandler(BaseHTTPRequestHandler):
    """One-shot HTTP handler that captures the OAuth redirect.

    Results are kept on the class rather than the instance, because the HTTP
    server creates a fresh handler object per request. The caller resets these
    attributes before serving and reads them afterwards.
    """

    authorization_code: Optional[str] = None
    state: Optional[str] = None
    error: Optional[str] = None

    def do_GET(self):
        """Record ``error`` or ``code``/``state`` from the query string and reply."""
        query = parse_qs(urlparse(self.path).query)
        cls = OAuthCallbackHandler

        if "error" in query:
            cls.error = query["error"][0]
            message = "Authorization failed. You can close this window."
        elif "code" in query:
            cls.authorization_code = query["code"][0]
            cls.state = query.get("state", [None])[0]
            message = "Authorization successful! You can close this window."
        else:
            message = "Invalid callback. Missing authorization code."

        self._send_response(message)

    def _send_response(self, message: str):
        """Reply 200 with a minimal HTML page showing *message*."""
        page = f"""
        <!DOCTYPE html>
        <html>
        <head><title>WHOOP Authorization</title></head>
        <body style="font-family: system-ui; text-align: center; padding: 50px;">
            <h1>{message}</h1>
        </body>
        </html>
        """
        self.send_response(200)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(page.encode())

    def log_message(self, format, *args):
        """Silence the default per-request logging of BaseHTTPRequestHandler."""
        return None


class WhoopAuth:
    """WHOOP OAuth2 authentication handler.

    Runs the authorization-code flow (browser plus a local callback server) and
    keeps the stored access token fresh via rotating refresh tokens, using
    ``TokenStorage`` for persistence.

    NOTE(review): diagnostics are printed to stdout. If this runs inside an MCP
    server using the stdio transport, stdout is the protocol channel. Confirm
    the transport and consider stderr/logging.
    """

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str = "http://localhost:8765/callback"):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self._http_client: Optional[httpx.Client] = None

    @property
    def http_client(self) -> httpx.Client:
        """Lazily create (and cache) the HTTP client used for token requests."""
        if self._http_client is None:
            self._http_client = httpx.Client(timeout=30.0)
        return self._http_client

    def close(self) -> None:
        """Close the cached HTTP client, if one was created."""
        if self._http_client is not None:
            self._http_client.close()
            self._http_client = None

    def get_valid_access_token(self) -> Optional[str]:
        """Return a usable access token, refreshing it if necessary.

        Refreshes proactively (within 10 minutes of expiry) to reduce the chance
        of handing out a token that expires mid-request.

        Returns:
            The access token, or None when nothing is stored or no valid token
            can be obtained.
        """
        tokens = TokenStorage.get_tokens()

        if not tokens:
            return None

        # Check if refresh token is likely expired (needs re-auth)
        if TokenStorage.is_refresh_token_likely_expired(tokens):
            print(f"WARNING: Refresh token likely expired (>{REFRESH_TOKEN_LIFETIME_DAYS} days old).")
            print("Run setup_auth.py to re-authenticate.")
            # Still use the access token while it is valid.
            if not TokenStorage.is_token_expired(tokens, buffer_seconds=0):
                return tokens["access_token"]
            # The lifetime is only an estimate (WHOOP doesn't document it), so try
            # one refresh rather than giving up on a guess. A genuinely dead
            # refresh token is reported and cleared by _refresh_tokens.
            new_tokens = self._refresh_tokens(tokens["refresh_token"])
            return new_tokens["access_token"] if new_tokens else None

        # Proactively refresh if close to expiry (or already expired)
        if TokenStorage.should_proactively_refresh(tokens) or TokenStorage.is_token_expired(tokens):
            new_tokens = self._refresh_tokens(tokens["refresh_token"])
            if new_tokens:
                return new_tokens["access_token"]
            # If refresh failed but token isn't fully expired yet, use it
            if not TokenStorage.is_token_expired(tokens, buffer_seconds=0):
                print("Refresh failed but current token still valid, using it.")
                return tokens["access_token"]
            return None

        return tokens["access_token"]

    def _refresh_tokens(self, refresh_token: str, retry_count: int = 0) -> Optional[dict]:
        """Refresh the access token, coordinating with other processes.

        IMPORTANT: WHOOP uses rotating refresh tokens. Each successful refresh
        returns a NEW refresh token and invalidates the old one. We therefore:
        1. Take an exclusive, non-blocking file lock for the refresh. If another
           process holds it, wait 2s (up to 3 retries) and reuse its result.
        2. Save the new tokens immediately.
        3. Treat a 400 mentioning "invalid" as a dead refresh token: clear the
           stored tokens and prompt re-auth.

        Args:
            refresh_token: The refresh token the caller read from storage.
            retry_count: Internal counter for lock-wait retries.

        Returns:
            A token dict (containing at least ``access_token``) or None.
        """
        lock_file = Path(str(TOKEN_FILE) + ".refresh.lock")

        try:
            with open(lock_file, "w") as lf:
                try:
                    fcntl.flock(lf.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                except BlockingIOError:
                    pass  # another process is refreshing; handled below
                else:
                    try:
                        return self._refresh_while_locked(refresh_token)
                    finally:
                        fcntl.flock(lf.fileno(), fcntl.LOCK_UN)

        except httpx.HTTPStatusError as e:
            print(f"Token refresh HTTP error: {e.response.status_code}")
            if e.response.text:
                print(f"Response: {e.response.text}")
            return None

        except Exception as e:
            print(f"Token refresh failed: {e}")
            return None

        # Only reached when the lock was busy: wait for the other process, then
        # either reuse what it stored or try again.
        if retry_count >= 3:
            print("Timeout waiting for token refresh lock")
            return None
        print("Another process is refreshing tokens, waiting...")
        time.sleep(2)
        tokens = TokenStorage.get_tokens()
        if tokens and not TokenStorage.is_token_expired(tokens, buffer_seconds=60):
            return tokens
        return self._refresh_tokens(refresh_token, retry_count + 1)

    def _refresh_while_locked(self, refresh_token: str) -> Optional[dict]:
        """Perform the refresh request. The caller must hold the refresh lock.

        httpx errors propagate to ``_refresh_tokens``, which reports them.
        """
        # Re-read storage: another process may have refreshed while we waited.
        current_tokens = TokenStorage.get_tokens()
        if current_tokens and not TokenStorage.is_token_expired(current_tokens, buffer_seconds=60):
            # Someone else already refreshed if the refresh token rotated or the
            # access token is no longer inside the proactive-refresh window.
            # (A plain "not expired" test would skip every proactive refresh.)
            rotated = current_tokens.get("refresh_token") != refresh_token
            if rotated or not TokenStorage.should_proactively_refresh(current_tokens):
                return current_tokens

        # Use the refresh token from current storage (might have been updated)
        current_refresh = current_tokens.get("refresh_token", refresh_token) if current_tokens else refresh_token

        response = self.http_client.post(
            TOKEN_URL,
            data={
                "grant_type": "refresh_token",
                "refresh_token": current_refresh,
                "client_id": self.client_id,
                "client_secret": self.client_secret,
            },
            headers={"Content-Type": "application/x-www-form-urlencoded"}
        )

        if response.status_code == 400:
            self._report_refresh_rejection(response)
            return None

        response.raise_for_status()

        data = response.json()
        now = time.time()
        new_tokens = {
            "access_token": data["access_token"],
            "refresh_token": data.get("refresh_token", current_refresh),
            "expires_at": now + data.get("expires_in", 3600),
            "refreshed_at": now,
        }

        # CRITICAL: Save new tokens immediately (rotating refresh token)
        try:
            TokenStorage.store_tokens(**new_tokens)
        except RuntimeError as e:
            # The old refresh token is already invalidated, so this access token
            # is still worth returning even though it could not be persisted.
            print(f"CRITICAL: refreshed tokens could not be saved ({e}). "
                  "Re-authentication will be needed once this access token expires.")
        return new_tokens

    @staticmethod
    def _report_refresh_rejection(response: httpx.Response) -> None:
        """Explain a 400 from the token endpoint; clear tokens if they are invalid."""
        try:
            error_data = response.json() if response.text else {}
        except ValueError:
            # Non-JSON body: report it, but don't discard tokens on a guess.
            print(f"\nToken refresh failed (400): {response.text}")
            return

        if isinstance(error_data, dict):
            description = error_data.get("error_description", "Unknown error")
            error_hint = error_data.get("error_hint", "")
        else:
            description, error_hint = "Unknown error", ""
        print(f"\nToken refresh failed (400): {description}")
        if error_hint:
            print(f"Hint: {error_hint}")

        if "invalid" in str(error_data).lower():
            print("\nThe refresh token is invalid. This happens when:")
            print("  1. Another process already used this refresh token")
            print("  2. The refresh token expired (~7 days)")
            print("  3. WHOOP revoked the token")
            print("\nRun setup_auth.py to re-authenticate.\n")
            # Clear invalid tokens so we don't keep retrying
            TokenStorage.clear_tokens()

    def start_authorization_flow(self) -> bool:
        """Run the interactive authorization-code flow (opens a browser).

        Starts a one-shot local HTTP server on the redirect URI's port, opens the
        WHOOP consent page, waits up to two minutes for the redirect, verifies
        ``state`` (CSRF protection) and exchanges the code for tokens.

        Returns:
            True if tokens were obtained and stored, False otherwise.
        """
        callback_timeout = 120  # seconds to wait for the user to approve
        state = secrets.token_urlsafe(32)

        # Build authorization URL
        auth_params = {
            "response_type": "code",
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "scope": SCOPES,
            "state": state,
        }
        auth_url = f"{AUTH_URL}?{urlencode(auth_params)}"

        # Parse port from redirect URI
        parsed = urlparse(self.redirect_uri)
        port = parsed.port or 8765

        # Reset handler state
        OAuthCallbackHandler.authorization_code = None
        OAuthCallbackHandler.state = None
        OAuthCallbackHandler.error = None

        # Start local server for callback
        try:
            server = HTTPServer(("localhost", port), OAuthCallbackHandler)
        except OSError as e:
            print(f"Could not start callback server on port {port}: {e}")
            return False
        # handle_request() returns after server.timeout seconds without a request.
        # Without it, an abandoned flow leaves the thread blocked forever. The
        # thread is also a daemon so it can never keep the process alive.
        server.timeout = callback_timeout
        server_thread = Thread(target=server.handle_request, daemon=True)
        try:
            server_thread.start()

            # Open browser for authorization
            print("\nOpening browser for WHOOP authorization...")
            print(f"If browser doesn't open, visit: {auth_url}\n")
            webbrowser.open(auth_url)

            # Wait for callback; the small grace period lets a request that arrives
            # right at the deadline finish writing its response.
            server_thread.join(timeout=callback_timeout + 5)
        finally:
            server.server_close()

        if OAuthCallbackHandler.error:
            print(f"Authorization error: {OAuthCallbackHandler.error}")
            return False

        if not OAuthCallbackHandler.authorization_code:
            print("No authorization code received")
            return False

        if OAuthCallbackHandler.state != state:
            print("State mismatch - possible CSRF attack")
            return False

        # Exchange code for tokens
        return self._exchange_code_for_tokens(OAuthCallbackHandler.authorization_code)

    def _exchange_code_for_tokens(self, code: str) -> bool:
        """Exchange an authorization code for tokens and store them.

        Returns:
            True on success, False if the request or storage failed.
        """
        try:
            response = self.http_client.post(
                TOKEN_URL,
                data={
                    "grant_type": "authorization_code",
                    "code": code,
                    "redirect_uri": self.redirect_uri,
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                },
                headers={"Content-Type": "application/x-www-form-urlencoded"}
            )
            response.raise_for_status()

            data = response.json()
            expires_at = time.time() + data.get("expires_in", 3600)

            TokenStorage.store_tokens(
                access_token=data["access_token"],
                refresh_token=data["refresh_token"],
                expires_at=expires_at
            )

            print("Authorization successful! Tokens stored in keychain and file.")
            return True
        except Exception as e:
            print(f"Token exchange failed: {e}")
            return False

    def is_authorized(self) -> bool:
        """Check if we have valid authorization (may trigger a token refresh)."""
        return self.get_valid_access_token() is not None

    def revoke(self) -> None:
        """Clear locally stored tokens (does not contact WHOOP)."""
        TokenStorage.clear_tokens()
        print("Authorization revoked. Tokens cleared.")


def get_auth_from_env() -> WhoopAuth:
    """Create a WhoopAuth instance from environment variables.

    Loads a ``.env`` file located two directories above this module's directory
    (if it exists). It then reads ``WHOOP_CLIENT_ID``, ``WHOOP_CLIENT_SECRET``
    and the optional ``WHOOP_REDIRECT_URI``.

    Raises:
        ValueError: if the client ID or client secret is missing or empty.
    """
    from dotenv import load_dotenv

    # Load .env file if it exists
    env_path = os.path.join(os.path.dirname(__file__), "..", "..", ".env")
    load_dotenv(env_path)

    client_id = os.getenv("WHOOP_CLIENT_ID")
    client_secret = os.getenv("WHOOP_CLIENT_SECRET")
    # An empty value counts as unset (as the credential check below already
    # assumes) rather than sending a blank redirect_uri to WHOOP.
    redirect_uri = os.getenv("WHOOP_REDIRECT_URI") or "http://localhost:8765/callback"

    if not client_id or not client_secret:
        raise ValueError(
            "Missing WHOOP credentials. Set WHOOP_CLIENT_ID and WHOOP_CLIENT_SECRET "
            "in environment or .env file."
        )

    return WhoopAuth(client_id, client_secret, redirect_uri)
