#!/usr/bin/env python3
"""
Kitzu — Oura Ring Data Collector

Standalone collector that pulls data from the Oura Ring API v2
and stores it in the Kitzu data lake. No WhatsApp dependency.

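Usage:
    python3 collect_oura.py                 # Collect today + yesterday
    python3 collect_oura.py --days 14       # Collect the last 14 days (already-stored days older than yesterday are skipped)
    python3 collect_oura.py --backfill 14   # Re-collect the last 14 days, overwriting stored records
    python3 collect_oura.py --status        # Show what's in the data lake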

Writes to: kitzu/data/oura/{date}.json
Updates:   kitzu/data/unified/profile.json (oura section)
"""

import sys
import os
import json
import logging
import argparse
import requests
from datetime import date, datetime, timedelta
from pathlib import Path

# Add parent to path for schema import
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from schema import (
    OuraDailyRecord, store_daily_oura, load_oura_history,
    load_profile, save_profile, OURA_DIR, data_lake_status
)

# Configure root logging at import time: "HH:MM:SS [LEVEL] message" lines at INFO.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
# Named logger shared by every function in this collector.
log = logging.getLogger("kitzu.oura")

# ── Configuration ────────────────────────────────────────────

# Root URL of the Oura REST API; endpoint paths such as "/v2/usercollection/sleep"
# are appended to it verbatim.
BASE_URL = "https://api.ouraring.com"

# Candidate token files, checked in list order by load_token() after the
# OURA_ACCESS_TOKEN environment variable. Each is JSON: {"access_token": "..."}.
TOKEN_PATHS = [
    Path.home().joinpath("clawd", ".oura-token.json"),
    Path.home().joinpath("clawd", "_Organized", "Data", ".oura-token.json"),
    # Two levels above this script's directory, then _Organized/Data.
    Path(__file__).resolve().parent.parent.parent.joinpath("_Organized", "Data", ".oura-token.json"),
]


def load_token() -> str:
    """Return the Oura API access token, or "" if none can be found.

    Lookup order:
      1. The ``OURA_ACCESS_TOKEN`` environment variable (whitespace-stripped).
      2. Each existing file in ``TOKEN_PATHS``, parsed as JSON and read from
         its ``access_token`` key.

    Files that cannot be read or parsed are logged as warnings and skipped.
    An error is logged when every source comes up empty.
    """
    env_token = os.getenv("OURA_ACCESS_TOKEN", "").strip()
    if env_token:
        log.info("Token loaded from OURA_ACCESS_TOKEN env var")
        return env_token

    for candidate in TOKEN_PATHS:
        if not candidate.exists():
            continue
        try:
            payload = json.loads(candidate.read_text())
            file_token = payload.get("access_token", "").strip()
            if file_token:
                log.info(f"Token loaded from {candidate}")
                return file_token
        except Exception as err:
            # Best effort: a broken file must not stop the search.
            log.warning(f"Failed to read {candidate}: {err}")

    log.error("No Oura API token found. Set OURA_ACCESS_TOKEN or create token file.")
    return ""


# ── API Client ───────────────────────────────────────────────

class OuraCollector:
    """Minimal Oura API v2 client focused on daily data collection.

    All requests are synchronous ``requests.get`` calls against ``BASE_URL``,
    authenticated with a bearer token. ``collect_day`` is the main entry point
    and returns one normalized ``OuraDailyRecord`` per calendar day.
    """

    # Max heart rate used for the HR-zone boundaries (220 - age, for age 55).
    # Override on an instance, or pass ``max_hr`` to ``_fetch_hr_zones``.
    zone_max_hr = 165

    # Oura workout/session type -> display label. Unknown types fall back to a
    # title-cased version of the raw value (see _parse_exercise).
    _EXERCISE_TYPES = {
        "cycling": "Cycling", "running": "Running", "walking": "Walking",
        "hiking": "Hiking", "swimming": "Swimming", "yoga": "Yoga",
        "strength_training": "Strength", "hiit": "HIIT",
        "meditation": "Meditation", "breathing": "Breathwork",
        "indoor_cycling": "Spinning", "outdoor_cycling": "Cycling",
        "indoor_running": "Treadmill", "outdoor_running": "Running",
    }

    def __init__(self, token: str):
        """Store the API token and build the auth headers sent with every request."""
        self.token = token
        self.headers = {
            "Authorization": f"Bearer {token}",
            "Accept": "application/json",
        }

    def _get(self, endpoint: str, params: "dict | None" = None) -> dict:
        """Make an authenticated GET request and return the decoded JSON body.

        Args:
            endpoint: Path appended to ``BASE_URL`` (e.g. ``/v2/usercollection/sleep``).
            params: Optional query-string parameters.

        Raises:
            requests.HTTPError: for 4xx/5xx responses (logged before raising).
            requests.RequestException: on network errors or the 15 s timeout.
        """
        url = BASE_URL + endpoint
        resp = requests.get(url, params=params, headers=self.headers, timeout=15)
        if resp.status_code != 200:
            log.error(f"API {resp.status_code}: {endpoint} → {resp.text[:200]}")
            resp.raise_for_status()
        return resp.json()

    def _fetch_day(self, endpoint: str, target_date: date) -> list:
        """Return the ``data`` records of a date-ranged endpoint for one day.

        The query window is ``start_date=target_date`` to
        ``end_date=target_date + 1 day``. If that window also yields records for
        the following day, taking ``entries[-1]`` or the "longest session" could
        store another day's data under ``target_date`` and duplicate workouts
        across adjacent days. To prevent that, records whose ``day`` field names
        a different date are dropped; records without a ``day`` are kept.
        NOTE(review): confirm whether Oura treats ``end_date`` as inclusive.

        Raises whatever ``_get`` raises.
        """
        day = target_date.isoformat()
        data = self._get(endpoint, {
            "start_date": day,
            "end_date": (target_date + timedelta(days=1)).isoformat(),
        })
        return [e for e in (data.get("data") or []) if e.get("day") in (None, day)]

    @staticmethod
    def _parse_iso(value) -> "datetime | None":
        """Parse an ISO-8601 timestamp (a trailing ``Z`` is accepted).

        Returns None for missing, non-string, or unparseable values rather than
        raising, so a single bad field never aborts a whole record.
        """
        if not isinstance(value, str) or not value:
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return None

    def validate(self) -> bool:
        """Return True if the token can read ``personal_info``, else log and return False."""
        try:
            self._get("/v2/usercollection/personal_info")
            return True
        except Exception as e:
            log.error(f"Token validation failed: {e}")
            return False

    def collect_day(self, target_date: date) -> "OuraDailyRecord":
        """Collect all Oura data for a single day and return a normalized record.

        Each section (sleep, readiness, activity, workouts, sessions) is fetched
        independently. A failing section is logged as a warning and leaves its
        fields at the record's defaults, so partial data is still returned.
        """
        record = OuraDailyRecord(
            date=target_date.isoformat(),
            synced_at=datetime.now().isoformat(),
        )

        # ── Sleep ────────────────────────────────────
        try:
            sessions = self._fetch_day("/v2/usercollection/sleep", target_date)
            if sessions:
                # Pick the longest session (main sleep, not naps).
                # `or 0` also covers keys that are present but null.
                main = max(sessions, key=lambda s: s.get("total_sleep_duration") or 0)
                record.sleep_hours = round((main.get("total_sleep_duration") or 0) / 3600, 1)
                record.deep_sleep_min = round((main.get("deep_sleep_duration") or 0) / 60)
                record.rem_sleep_min = round((main.get("rem_sleep_duration") or 0) / 60)
                record.light_sleep_min = round((main.get("light_sleep_duration") or 0) / 60)
                record.awake_min = round((main.get("awake_time") or 0) / 60)

                # HRV from sleep
                avg_hrv = main.get("average_hrv")
                if avg_hrv is not None:
                    record.avg_hrv = round(avg_hrv, 1)

                # Resting HR from sleep (lowest HR during the session)
                rhr = main.get("lowest_heart_rate")
                if rhr is not None:
                    record.resting_hr = rhr

                # Bed/wake times as HH:MM in the timestamp's own UTC offset
                bed = self._parse_iso(main.get("bedtime_start"))
                if bed is not None:
                    record.bedtime = bed.strftime("%H:%M")
                wake = self._parse_iso(main.get("bedtime_end"))
                if wake is not None:
                    record.wake_time = wake.strftime("%H:%M")

                log.info(f"  Sleep: {record.sleep_hours}h | Deep: {record.deep_sleep_min}m | HRV: {record.avg_hrv}")
        except Exception as e:
            log.warning(f"  Sleep fetch failed: {e}")

        # ── Readiness ────────────────────────────────
        try:
            entries = self._fetch_day("/v2/usercollection/daily_readiness", target_date)
            if entries:
                record.readiness_score = entries[-1].get("score")
                log.info(f"  Readiness: {record.readiness_score}")
        except Exception as e:
            log.warning(f"  Readiness fetch failed: {e}")

        # ── Activity (Steps) ─────────────────────────
        try:
            entries = self._fetch_day("/v2/usercollection/daily_activity", target_date)
            if entries:
                day_data = entries[-1]
                # `or 0` keeps a null value from breaking the `:,` format below
                record.steps = day_data.get("steps") or 0
                record.active_calories = day_data.get("active_calories") or 0
                log.info(f"  Steps: {record.steps:,} | Active cal: {record.active_calories}")
        except Exception as e:
            log.warning(f"  Activity fetch failed: {e}")

        # ── Exercise Sessions ────────────────────────
        # Workouts (auto-detected + app-started) and sessions (manually tagged:
        # meditation, breathing, etc.) are parsed the same way. Oura's v2 route
        # for sessions is the singular ".../session".
        exercises = []
        for label, endpoint in (
            ("Workout", "/v2/usercollection/workout"),
            ("Session", "/v2/usercollection/session"),
        ):
            try:
                for entry in self._fetch_day(endpoint, target_date):
                    ex = self._parse_exercise(entry)
                    if ex:
                        exercises.append(ex)
            except Exception as e:
                log.warning(f"  {label} fetch failed: {e}")

        if exercises:
            record.exercises = exercises
            log.info(f"  Exercises: {len(exercises)} sessions")

        return record

    def _fetch_hr_zones(self, start_dt: datetime, end_dt: datetime,
                        max_hr: "int | None" = None) -> dict:
        """
        Fetch HR samples during a workout window and compute zone-minutes.

        Zones are plain percentages of max heart rate (not Karvonen/HRR):
          Z1 < 60% max  (recovery)
          Z2 60-70%     (easy aerobic)
          Z3 70-80%     (tempo / aerobic threshold)
          Z4 80-90%     (lactate threshold)
          Z5 >= 90%     (VO2max / anaerobic)

        Args:
            start_dt, end_dt: Workout window. Also used as the zone-minute duration.
            max_hr: Max HR for the zone boundaries; defaults to ``self.zone_max_hr``
                (165 = 220 - 55).

        Returns a dict with avg_hr, max_hr (observed), zone_minutes, training_load
        and hr_samples. Returns {} when there are no samples or the fetch fails.
        Training load = weighted zone-minutes (Z1×1 + Z2×2 + Z3×3 + Z4×4 + Z5×5).
        """
        hr_max = self.zone_max_hr if max_hr is None else max_hr
        try:
            data = self._get("/v2/usercollection/heartrate", {
                "start_datetime": start_dt.isoformat(),
                "end_datetime": end_dt.isoformat(),
            })
            bpms = [s["bpm"] for s in (data.get("data") or []) if s.get("bpm")]
            if not bpms:
                return {}

            boundaries = [hr_max * f for f in (0.6, 0.7, 0.8, 0.9)]

            # Zone index = number of boundaries at or below the sample, so a
            # sample exactly on a boundary falls into the higher zone.
            zone_counts = [0] * 5
            for bpm in bpms:
                zone_counts[sum(bpm >= edge for edge in boundaries)] += 1

            # Spread the window duration across zones in proportion to sample counts
            total_samples = len(bpms)
            duration_min = (end_dt - start_dt).total_seconds() / 60
            zone_minutes = [round(c / total_samples * duration_min, 1) for c in zone_counts]

            # Z1 barely counts; Z4/Z5 carry most of the recovery cost
            training_load = round(
                sum(mins * weight for weight, mins in enumerate(zone_minutes, start=1)), 1
            )

            return {
                "avg_hr": round(sum(bpms) / total_samples),
                "max_hr": max(bpms),
                "zone_minutes": {f"z{i}": mins for i, mins in enumerate(zone_minutes, start=1)},
                "training_load": training_load,
                "hr_samples": total_samples,
            }

        except Exception as e:
            # Best effort: missing HR data must not drop the exercise itself.
            log.debug(f"  HR zone fetch failed: {e}")
            return {}

    def _parse_exercise(self, entry: dict) -> "dict | None":
        """Parse a workout/session entry into a clean dict, with HR zones when available.

        Returns None when ``start_datetime`` is missing or unparseable. The
        activity label comes from ``type`` (sessions) or ``activity`` (workouts).
        HR zone data is fetched only for entries of at least 10 minutes that
        have a valid end time.
        """
        start_dt = self._parse_iso(entry.get("start_datetime"))
        if start_dt is None:
            return None

        end_dt = self._parse_iso(entry.get("end_datetime"))
        duration_min = 0
        if end_dt is not None:
            try:
                duration_min = round((end_dt - start_dt).total_seconds() / 60)
            except TypeError:  # naive vs. aware timestamps cannot be subtracted
                end_dt = None

        # `or` (not nested .get defaults) so null/empty values fall through too
        activity_type = str(entry.get("type") or entry.get("activity") or "other")
        etype = self._EXERCISE_TYPES.get(activity_type, activity_type.replace("_", " ").title())

        result = {
            "type": etype,
            "time": start_dt.strftime("%H:%M"),
            "duration_min": duration_min,
            "calories": entry.get("calories"),
            "intensity": entry.get("intensity"),  # Oura's own: easy/moderate/hard
        }

        # Fetch HR zone data for this workout window (10+ min activities only)
        if end_dt is not None and duration_min >= 10:
            hr_data = self._fetch_hr_zones(start_dt, end_dt)
            if hr_data:
                for key in ("avg_hr", "max_hr", "zone_minutes", "training_load"):
                    result[key] = hr_data[key]
                zones = hr_data["zone_minutes"]
                hard_min = zones["z3"] + zones["z4"] + zones["z5"]
                log.info(f"    HR zones: avg {hr_data['avg_hr']} bpm, "
                         f"load {hr_data['training_load']}, "
                         f"Z3+: {hard_min:.0f} min")

        return result


# ── Profile Update ───────────────────────────────────────────

def update_profile_oura(records: list):
    """Refresh the ``oura`` section of the unified profile and save it.

    Args:
        records: Daily record dicts, newest first. ``records[0]`` is treated
            as the latest day (``main`` sorts them this way before calling).
            An empty list is a no-op.

    The latest sleep/readiness/steps/exercise blocks come from ``records[0]``.
    The HRV trend and exercise feed come from the last 14 days in the data lake
    (``load_oura_history``), so they also cover days not collected in this run.
    """
    if not records:
        return

    profile = load_profile()
    # setdefault: tolerate a profile with no "oura" section instead of KeyError
    oura = profile.setdefault("oura", {})
    latest = records[0]  # Most recent day

    oura["last_sync"] = datetime.now().isoformat()

    # Latest sleep
    if latest.get("sleep_hours"):
        oura["sleep"] = {
            "date": latest["date"],
            "hours": latest["sleep_hours"],
            "deep_min": latest.get("deep_sleep_min"),
            "rem_min": latest.get("rem_sleep_min"),
            "avg_hrv": latest.get("avg_hrv"),
            "resting_hr": latest.get("resting_hr"),
            "bedtime": latest.get("bedtime"),
            "wake_time": latest.get("wake_time"),
        }

    # Latest readiness
    if latest.get("readiness_score"):
        oura["readiness"] = {
            "date": latest["date"],
            "score": latest["readiness_score"],
        }

    # Latest steps
    if latest.get("steps"):
        oura["steps"] = {
            "date": latest["date"],
            "count": latest["steps"],
            "active_calories": latest.get("active_calories"),
        }

    # Read the 14-day data-lake window once; both the HRV trend and the
    # exercise feed below use it. list() so it can be iterated twice.
    history = list(load_oura_history(days=14))

    # HRV trend (last 14 days)
    hrv_trend = [
        {"date": r["date"], "hrv": r["avg_hrv"]}
        for r in history
        if r.get("avg_hrv")
    ]
    if hrv_trend:
        oura["hrv_trend"] = hrv_trend

    # Latest exercise (most recent day only — for brief engine)
    if latest.get("exercises"):
        oura["exercise"] = {
            "date": latest["date"],
            "sessions": latest["exercises"],
        }

    # Exercise feed (last 14 days — for dashboard feed card + training load engine)
    exercise_feed = []
    for r in history:
        for ex in r.get("exercises") or []:
            entry = {
                "date": r["date"],
                "type": ex.get("type", "Other"),
                "time": ex.get("time"),
                "duration_min": ex.get("duration_min", 0),
                "calories": int(ex.get("calories") or 0),
                "intensity": ex.get("intensity"),
            }
            # Carry through HR zone data if available
            if ex.get("training_load"):
                for key in ("avg_hr", "max_hr", "zone_minutes", "training_load"):
                    entry[key] = ex.get(key)
            exercise_feed.append(entry)
    # Newest first; entries without a time sort last within their day
    exercise_feed.sort(key=lambda x: (x["date"], x["time"] or ""), reverse=True)
    oura["exercise_feed"] = exercise_feed[:30]  # cap at 30 entries

    save_profile(profile)
    log.info("Unified profile updated (oura section)")


# ── Main ─────────────────────────────────────────────────────

def main():
    """Command-line entry point for the Oura collector.

    Modes:
      --status        Print data-lake status and return.
      (default)       Collect the last --days days (default 2). Stored days
                      older than yesterday are reused instead of re-fetched.
      --backfill N    Re-collect the last N days, overwriting stored records.

    Exits with status 1 if no token is found or it fails validation. argparse
    exits with status 2 if --days/--backfill is below 1.
    """
    parser = argparse.ArgumentParser(description="Kitzu Oura Ring Data Collector")
    parser.add_argument("--days", type=int, default=2, help="Number of days to collect (default: 2)")
    parser.add_argument("--status", action="store_true", help="Show data lake status")
    parser.add_argument("--backfill", type=int, help="Backfill N days of historical data")
    args = parser.parse_args()

    if args.status:
        status = data_lake_status()
        print("\nKitzu Data Lake Status")
        print("=" * 50)
        for source, info in status.items():
            if source == "profile":
                print(f"  {'profile':<14} exists={info['exists']}  updated={info['last_updated']}")
            else:
                print(f"  {source:<14} {info['files']:>3} files  latest={info['latest']}")
        print()
        return

    # A non-positive day count would collect nothing over a nonsensical range.
    if args.days < 1:
        parser.error("--days must be at least 1")
    if args.backfill is not None and args.backfill < 1:
        parser.error("--backfill must be at least 1")

    # Load token
    token = load_token()
    if not token:
        print("\nNo Oura API token found.")
        print("Set OURA_ACCESS_TOKEN env var or create ~/clawd/.oura-token.json")
        print('Format: {"access_token": "YOUR_TOKEN_HERE"}')
        sys.exit(1)

    # Initialize collector
    collector = OuraCollector(token)

    # Validate token
    log.info("Validating Oura API token...")
    if not collector.validate():
        print("Token validation failed. Check your access token.")
        sys.exit(1)
    log.info("Token valid ✓")

    # Determine date range (inclusive, ending today)
    days = args.backfill if args.backfill else args.days
    end_date = date.today()
    start_date = end_date - timedelta(days=days - 1)

    log.info(f"Collecting {days} days: {start_date} → {end_date}")
    print()

    # Collect each day, newest first
    collected = []
    for i in range(days):
        target = end_date - timedelta(days=i)
        existing = OURA_DIR / f"{target.isoformat()}.json"

        # Skip if already collected (unless it's today or yesterday — always refresh)
        if existing.exists() and i >= 2 and not args.backfill:
            try:
                cached = json.loads(existing.read_text())
            except (OSError, ValueError) as e:
                # Unreadable/corrupt cache: fall through and re-collect this day
                log.warning(f"[{target}] Stored record unreadable ({e}), re-collecting")
            else:
                log.info(f"[{target}] Already collected, skipping")
                # Still load it for profile update
                collected.append(cached)
                continue

        log.info(f"[{target}] Collecting...")
        try:
            record = collector.collect_day(target)
            store_daily_oura(record)
            collected.append(record.to_dict())
            log.info(f"[{target}] Stored ✓")
        except Exception as e:
            log.error(f"[{target}] Failed: {e}")

    # Update unified profile (expects newest first)
    if collected:
        collected.sort(key=lambda r: r["date"], reverse=True)
        update_profile_oura(collected)

    # Summary
    print()
    print(f"Collected {len(collected)} days of Oura data")
    oura_files = list(OURA_DIR.glob("*.json"))
    print(f"Data lake: {len(oura_files)} total Oura records")

    if collected:
        latest = collected[0]
        print(f"\nLatest ({latest['date']}):")
        if latest.get("sleep_hours"):
            print(f"  Sleep:     {latest['sleep_hours']}h")
        if latest.get("avg_hrv"):
            print(f"  HRV:       {latest['avg_hrv']} ms")
        if latest.get("readiness_score"):
            print(f"  Readiness: {latest['readiness_score']}")
        if latest.get("steps"):
            print(f"  Steps:     {latest['steps']:,}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    # main() returns None, so the exit status is 0 unless it exits itself.
    sys.exit(main())
