"""WHOOP API v2 client for health data retrieval."""

import logging
from datetime import datetime, timedelta, timezone
from typing import Optional
from zoneinfo import ZoneInfo

import httpx

from .auth import WhoopAuth, get_auth_from_env
from .models import (
    Recovery, Sleep, Cycle, Workout,
    RecoverySummary, SleepSummary, StrainSummary, WorkoutSummary, TodayData
)

# Base URL for the WHOOP developer API (v2 is current; v1 is deprecated).
# WhoopClient's HTTP client uses this as its base_url, so every endpoint path
# it requests (e.g. "/recovery", "/activity/sleep") is relative to it.
API_BASE = "https://api.prod.whoop.com/developer/v2"


class WhoopClient:
    """Client for the WHOOP API v2.

    Provides thin wrappers over the raw collection endpoints plus simplified
    summaries consumed by the MCP tools. User-facing dates and times are
    reported in a configurable local timezone (``Australia/Brisbane`` by
    default). The underlying ``httpx.Client`` is created lazily; call
    :meth:`close` (or use the instance as a context manager) to release it.
    """

    # Timezone used for user-facing dates/times unless overridden in __init__.
    DEFAULT_TIMEZONE = "Australia/Brisbane"
    # Largest page size accepted by the WHOOP collection endpoints; bigger
    # requests are split into pages using the ``nextToken`` cursor.
    MAX_PAGE_SIZE = 25
    # WHOOP reports energy in kilojoules; 1 kcal (dietary "Calorie") = 4.184 kJ.
    KJ_PER_KCAL = 4.184

    # Diagnostics go through logging (stderr by default) instead of print():
    # stdout output could corrupt the protocol stream if this client runs
    # inside a stdio-based MCP server.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        auth: Optional[WhoopAuth] = None,
        tz_name: str = DEFAULT_TIMEZONE,
    ):
        """Create a client.

        Args:
            auth: Credentials provider; defaults to ``get_auth_from_env()``.
            tz_name: IANA timezone name used for user-facing dates and times.
        """
        self.auth = auth or get_auth_from_env()
        self._tz_name = tz_name
        # Built lazily so a missing tz database only fails when dates are needed.
        self._tz: Optional[ZoneInfo] = None
        self._client: Optional[httpx.Client] = None

    @property
    def tz(self) -> ZoneInfo:
        """Local timezone used for user-facing dates and times."""
        if self._tz is None:
            self._tz = ZoneInfo(self._tz_name)
        return self._tz

    @property
    def client(self) -> httpx.Client:
        """Lazy-load the HTTP client (auth headers are added per request)."""
        if self._client is None:
            self._client = httpx.Client(
                base_url=API_BASE,
                timeout=30.0,
            )
        return self._client

    def close(self) -> None:
        """Close the underlying HTTP client, if one has been created."""
        if self._client is not None:
            self._client.close()
            self._client = None

    def __enter__(self) -> "WhoopClient":
        return self

    def __exit__(self, *exc_info) -> None:
        self.close()

    def _get_headers(self) -> dict:
        """Build request headers carrying a currently valid access token.

        Raises:
            ValueError: If no access token is available.
        """
        token = self.auth.get_valid_access_token()
        if not token:
            raise ValueError("Not authorized. Run setup_auth.py first.")
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

    def _get(self, endpoint: str, params: Optional[dict] = None) -> dict:
        """Make an authenticated GET request and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
        """
        response = self.client.get(
            endpoint,
            headers=self._get_headers(),
            params=params,
        )
        response.raise_for_status()
        return response.json()

    # -------------------------------------------------------------------------
    # Raw API methods
    # -------------------------------------------------------------------------

    @staticmethod
    def _format_timestamp(moment: datetime) -> str:
        """Format *moment* as the UTC ``YYYY-MM-DDTHH:MM:SS.mmmZ`` string the API expects.

        Naive datetimes are assumed to already be UTC; aware datetimes in any
        other zone are converted to UTC first (the ``Z`` suffix claims UTC).
        """
        if moment.utcoffset() is None:
            moment = moment.replace(tzinfo=timezone.utc)
        utc = moment.astimezone(timezone.utc)
        return utc.strftime("%Y-%m-%dT%H:%M:%S.") + f"{utc.microsecond // 1000:03d}Z"

    def _get_collection(
        self,
        endpoint: str,
        start: Optional[datetime],
        end: Optional[datetime],
        limit: int,
    ) -> list[dict]:
        """Fetch up to *limit* records from a paginated collection endpoint.

        Requests pages of at most ``MAX_PAGE_SIZE`` records, following the
        response's ``next_token`` until *limit* records are collected or the
        collection is exhausted. Returns an empty list when *limit* <= 0.
        """
        base_params: dict = {}
        if start is not None:
            base_params["start"] = self._format_timestamp(start)
        if end is not None:
            base_params["end"] = self._format_timestamp(end)

        records: list[dict] = []
        next_token: Optional[str] = None
        while len(records) < limit:
            params = dict(base_params, limit=min(limit - len(records), self.MAX_PAGE_SIZE))
            if next_token:
                params["nextToken"] = next_token
            data = self._get(endpoint, params) or {}
            page = data.get("records") or []
            records.extend(page)
            next_token = data.get("next_token")
            if not page or not next_token:
                break
        return records[:limit]

    def get_recovery_collection(
        self,
        start: Optional[datetime] = None,
        end: Optional[datetime] = None,
        limit: int = 10,
    ) -> list[dict]:
        """Get up to *limit* recovery records for a date range."""
        return self._get_collection("/recovery", start, end, limit)

    def get_sleep_collection(
        self,
        start: Optional[datetime] = None,
        end: Optional[datetime] = None,
        limit: int = 10,
    ) -> list[dict]:
        """Get up to *limit* sleep records (including naps) for a date range."""
        return self._get_collection("/activity/sleep", start, end, limit)

    def get_cycle_collection(
        self,
        start: Optional[datetime] = None,
        end: Optional[datetime] = None,
        limit: int = 10,
    ) -> list[dict]:
        """Get up to *limit* physiological cycles for a date range."""
        return self._get_collection("/cycle", start, end, limit)

    def get_workout_collection(
        self,
        start: Optional[datetime] = None,
        end: Optional[datetime] = None,
        limit: int = 10,
    ) -> list[dict]:
        """Get up to *limit* workout records for a date range."""
        return self._get_collection("/activity/workout", start, end, limit)

    def get_user_profile(self) -> dict:
        """Get user profile information."""
        return self._get("/user/profile/basic")

    # -------------------------------------------------------------------------
    # Simplified data methods for MCP tools
    # -------------------------------------------------------------------------

    def get_today_data(self) -> TodayData:
        """Get today's combined health data for the daily check-in.

        Looks at records from local midnight yesterday until now and summarizes
        the first recovery, the first main (non-nap) sleep, and the first cycle
        returned. This assumes the API lists newest records first, as the
        original ``limit=1`` lookups did. Each section is best-effort: a failure
        is logged and that section is reported as ``None``.
        """
        now = datetime.now(self.tz)
        today = now.date()
        start = datetime.combine(today - timedelta(days=1), datetime.min.time(), tzinfo=self.tz)

        recovery_data = None
        sleep_data = None
        strain_data = None

        try:
            recoveries = self.get_recovery_collection(start=start, end=now, limit=1)
            if recoveries:
                recovery_data = self._recovery_fields(recoveries[0])
        except Exception as exc:  # best effort: keep the other sections
            self._logger.warning("Error fetching recovery: %s", exc)

        try:
            # Fetch several records so a recent nap does not hide the main sleep;
            # fall back to the newest record if every one is a nap.
            sleeps = self.get_sleep_collection(start=start, end=now, limit=10)
            main_sleep = next(
                (s for s in sleeps if not s.get("nap")),
                sleeps[0] if sleeps else None,
            )
            if main_sleep is not None:
                sleep_data = self._sleep_fields(main_sleep)
        except Exception as exc:  # best effort: keep the other sections
            self._logger.warning("Error fetching sleep: %s", exc)

        try:
            cycles = self.get_cycle_collection(start=start, end=now, limit=1)
            if cycles:
                strain_data = self._strain_fields(cycles[0])
        except Exception as exc:  # best effort: keep the other sections
            self._logger.warning("Error fetching strain: %s", exc)

        return TodayData(
            date=today.isoformat(),
            recovery=recovery_data,
            sleep=sleep_data,
            strain=strain_data,
        )

    def get_recovery_history(self, days: int = 7) -> list[RecoverySummary]:
        """Get recovery scores for the past *days* days (local dates)."""
        end = datetime.now(timezone.utc)
        start = end - timedelta(days=days)
        recoveries = self.get_recovery_collection(start=start, end=end, limit=days)
        return [
            RecoverySummary(date=self._record_date(r), **self._recovery_fields(r))
            for r in recoveries
        ]

    def get_sleep_history(self, days: int = 7) -> list[SleepSummary]:
        """Get sleep records (naps included) for the past *days* days (local dates)."""
        end = datetime.now(timezone.utc)
        start = end - timedelta(days=days)
        sleeps = self.get_sleep_collection(start=start, end=end, limit=days)
        return [
            SleepSummary(date=self._record_date(s), **self._sleep_fields(s))
            for s in sleeps
        ]

    def get_strain_history(self, days: int = 7) -> list[StrainSummary]:
        """Get day-strain data for the past *days* days (local dates)."""
        end = datetime.now(timezone.utc)
        start = end - timedelta(days=days)
        cycles = self.get_cycle_collection(start=start, end=end, limit=days)
        return [
            StrainSummary(date=self._record_date(c), **self._strain_fields(c))
            for c in cycles
        ]

    def get_workout_history(self, days: int = 7, limit: int = 25) -> list[WorkoutSummary]:
        """Get up to *limit* workouts/activities from the past *days* days.

        Start date/time are reported in the client's local timezone; records
        with a missing or unparseable start get date "unknown" and time "00:00".
        """
        end = datetime.now(timezone.utc)
        start = end - timedelta(days=days)
        workouts = self.get_workout_collection(start=start, end=end, limit=limit)
        local_tz = self.tz
        results = []

        for w in workouts:
            score = w.get("score") or {}
            started = self._parse_timestamp(w.get("start"))
            ended = self._parse_timestamp(w.get("end"))

            # Duration is computed on UTC datetimes so a DST transition in the
            # local zone cannot skew it.
            duration_minutes = 0.0
            if started is not None and ended is not None:
                duration_minutes = (ended - started).total_seconds() / 60

            local_start = started.astimezone(local_tz) if started is not None else None
            raw_id = w.get("id")  # may be a UUID string or an int
            sport_id = w.get("sport_id")

            results.append(WorkoutSummary(
                id="" if raw_id is None else str(raw_id),
                date=local_start.strftime("%Y-%m-%d") if local_start else "unknown",
                time=local_start.strftime("%H:%M") if local_start else "00:00",
                sport_name=w.get("sport_name") or "Unknown",
                sport_id=-1 if sport_id is None else sport_id,
                duration_minutes=round(duration_minutes, 1),
                strain=round(score.get("strain") or 0, 1),
                avg_hr=score.get("average_heart_rate") or 0,
                max_hr=score.get("max_heart_rate") or 0,
                calories=self._kj_to_kcal(score.get("kilojoule")),
            ))

        return results

    # -------------------------------------------------------------------------
    # Helpers
    # -------------------------------------------------------------------------

    @classmethod
    def _recovery_fields(cls, record: dict) -> dict:
        """Extract the summary fields of a recovery record (keys match RecoverySummary)."""
        score = record.get("score") or {}
        recovery_score = score.get("recovery_score")
        return {
            "score": recovery_score,
            "hrv": cls._round_or_none(score.get("hrv_rmssd_milli"), 1),
            "rhr": cls._round_or_none(score.get("resting_heart_rate")),
            "spo2": score.get("spo2_percentage"),
            "skin_temp": cls._round_or_none(score.get("skin_temp_celsius"), 1),
            "zone": cls._get_recovery_zone(recovery_score),
        }

    @classmethod
    def _sleep_fields(cls, record: dict) -> dict:
        """Extract the summary fields of a sleep record (keys match SleepSummary).

        Null or missing stage durations are treated as 0 ms.
        """
        score = record.get("score") or {}
        stages = score.get("stage_summary") or {}

        def hours(key: str) -> float:
            return cls._ms_to_hours(stages.get(key) or 0)

        return {
            # Time asleep = time in bed minus time awake.
            "total_hours": round(hours("total_in_bed_time_milli") - hours("total_awake_time_milli"), 1),
            "efficiency": cls._round_or_none(score.get("sleep_efficiency_percentage"), 1),
            "performance": cls._round_or_none(score.get("sleep_performance_percentage"), 1),
            "light_hours": round(hours("total_light_sleep_time_milli"), 1),
            "deep_hours": round(hours("total_slow_wave_sleep_time_milli"), 1),
            "rem_hours": round(hours("total_rem_sleep_time_milli"), 1),
            "disturbances": stages.get("disturbance_count") or 0,
        }

    @classmethod
    def _strain_fields(cls, record: dict) -> dict:
        """Extract the summary fields of a cycle record (keys match StrainSummary).

        Null or missing values become 0; energy is converted from kJ to kcal.
        """
        score = record.get("score") or {}
        return {
            "day_strain": round(score.get("strain") or 0, 1),
            "calories": cls._kj_to_kcal(score.get("kilojoule")),
            "avg_hr": score.get("average_heart_rate") or 0,
            "max_hr": score.get("max_heart_rate") or 0,
        }

    @staticmethod
    def _round_or_none(value: Optional[float], ndigits: Optional[int] = None):
        """Round *value*, or return None when it is missing, null, or zero."""
        if not value:
            return None
        return round(value, ndigits)

    @classmethod
    def _kj_to_kcal(cls, kilojoules: Optional[float]) -> int:
        """Convert kilojoules (None treated as 0) to whole kilocalories."""
        return round((kilojoules or 0) / cls.KJ_PER_KCAL)

    @staticmethod
    def _parse_timestamp(value: Optional[str]) -> Optional[datetime]:
        """Parse an ISO-8601 timestamp (``Z`` allowed) into an aware UTC datetime.

        Returns None when *value* is empty or unparseable; a timestamp without
        an offset is assumed to be UTC.
        """
        if not value:
            return None
        try:
            parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00"))
        except ValueError:
            return None
        if parsed.utcoffset() is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed.astimezone(timezone.utc)

    def _record_date(self, record: dict, key: str = "created_at") -> str:
        """Return ``record[key]`` as a local ``YYYY-MM-DD`` date, or "unknown"."""
        moment = self._parse_timestamp(record.get(key))
        if moment is None:
            return "unknown"
        return moment.astimezone(self.tz).strftime("%Y-%m-%d")

    @staticmethod
    def _ms_to_hours(ms: int) -> float:
        """Convert milliseconds to hours."""
        return ms / (1000 * 60 * 60)

    @staticmethod
    def _get_recovery_zone(score: Optional[int]) -> str:
        """Map a recovery score to its zone: green >= 67, yellow >= 34, else red."""
        if score is None:
            return "unknown"
        if score >= 67:
            return "green"
        if score >= 34:
            return "yellow"
        return "red"
