"""
Health Connect ZIP ingestion pipeline.

Parses the Health Connect export SQLite database from a zip file
and updates dashboard-data.json with:
  - Weight (latest morning reading)
  - Body fat (latest reading)
  - Blood pressure (latest reading)
  - Sleep (last night: duration, stages, HRV)
  - Exercise (recent sessions: type, duration, calories)
  - Steps (daily total)
  - Resting heart rate

Bill drops a Health Connect zip into Google Drive each morning between 7-8 AM.
The scheduler watches for new files and auto-ingests.
"""

import os
import re
import json
import sqlite3
import zipfile
import shutil
import tempfile
import logging
from datetime import datetime, date, timedelta
from pathlib import Path

# Module-level logger; the name must match the scheduler's logging config.
logger = logging.getLogger("health-connect")

# Health Connect exercise type codes → human labels
# NOTE(review): codes appear to follow the androidx.health.connect
# ExerciseSessionRecord constants — confirm against the library; gaps in
# the numbering are types this dashboard does not track.
EXERCISE_TYPES = {
    0: "Other", 2: "Badminton", 4: "Baseball", 5: "Biking",
    8: "Boxing", 10: "Cricket", 11: "Dancing", 13: "Fencing",
    14: "Football (American)", 15: "Football (Australian)",
    16: "Frisbee Disc", 17: "Golf", 18: "Guided Breathing",
    19: "Gymnastics", 20: "Handball", 21: "Walking",
    22: "HIIT", 23: "Hiking", 25: "Ice Hockey",
    26: "Ice Skating", 29: "Martial Arts",
    31: "Paddling", 32: "Paragliding",
    33: "Pilates", 34: "Racquetball", 35: "Rock Climbing",
    36: "Roller Hockey", 37: "Rowing", 38: "Rugby",
    39: "Running", 40: "Sailing", 41: "Scuba Diving",
    42: "Skating", 43: "Skiing", 44: "Snowboarding",
    45: "Strength Training", 46: "Stretching",
    47: "Surfing", 48: "Swimming (Open Water)",
    49: "Swimming (Pool)", 50: "Table Tennis",
    51: "Tennis", 52: "Volleyball", 53: "Elliptical",
    56: "Wheelchair", 57: "Water Polo", 58: "Yoga",
    59: "Stair Climbing",
}

# Sleep stage codes (values stored in sleep_stages_table.stage_type).
# Only awake/light/deep/rem are aggregated by get_last_sleep.
SLEEP_STAGES = {
    0: "unknown", 1: "awake", 2: "sleeping",
    3: "out_of_bed", 4: "light", 5: "deep", 6: "rem"
}


class HealthConnectParser:
    """Parse a Health Connect export zip and extract health metrics.

    Use as a context manager: the zip is unpacked to a temp directory on
    __enter__ and the directory is removed on __exit__.  Each get_* method
    opens its own short-lived SQLite connection, so corruption in one table
    only breaks its own metric (see extract_all).

    All Health Connect timestamps are epoch milliseconds; weights are
    stored in grams and converted to pounds here.
    """

    def __init__(self, zip_path: str):
        self.zip_path = zip_path
        self.db_path = None   # set by __enter__ (or by DB recovery)
        self.temp_dir = None  # temp extraction dir, removed in __exit__

    def __enter__(self):
        """Extract the SQLite DB from the zip to a temp directory."""
        self.temp_dir = tempfile.mkdtemp()
        try:
            with zipfile.ZipFile(self.zip_path, 'r') as zf:
                zf.extractall(self.temp_dir)
        except zipfile.BadZipFile as e:
            if "Bad CRC-32" in str(e):
                # Health Connect exports sometimes have CRC mismatches.
                # Re-extract with CRC validation disabled by temporarily
                # no-op'ing zipfile's internal CRC updater; the original
                # is always restored in the finally block.
                logger.warning(f"CRC mismatch in zip — retrying without CRC check: {e}")
                _orig_update_crc = zipfile.ZipExtFile._update_crc
                try:
                    zipfile.ZipExtFile._update_crc = lambda self, data: None
                    with zipfile.ZipFile(self.zip_path, 'r') as zf:
                        zf.extractall(self.temp_dir)
                finally:
                    zipfile.ZipExtFile._update_crc = _orig_update_crc
            else:
                raise
        # Find the .db file anywhere under the extraction root — some
        # exports nest it in a subdirectory, so walk recursively.
        for root, _dirs, files in os.walk(self.temp_dir):
            for f in files:
                if f.endswith('.db'):
                    self.db_path = os.path.join(root, f)
                    break
            if self.db_path:
                break
        if not self.db_path:
            raise FileNotFoundError("No .db file found in Health Connect zip")
        return self

    def __exit__(self, *args):
        """Cleanup temp directory."""
        if self.temp_dir:
            shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _connect(self):
        """Open a connection, verifying (and if needed recovering) the DB once."""
        # If we already verified or recovered (or gave up), use cached path.
        if hasattr(self, '_db_verified'):
            return sqlite3.connect(self.db_path)

        conn = sqlite3.connect(self.db_path)
        try:
            # Cheap probe: page-level corruption typically surfaces on the
            # first real query rather than on connect.
            conn.execute("SELECT 1 FROM weight_record_table LIMIT 1")
            self._db_verified = True
            return conn
        except sqlite3.DatabaseError:
            conn.close()
            logger.warning("DB appears malformed — attempting recovery")
            recovered = self._recover_db()
            # Mark verified either way so we don't re-run recovery per metric.
            self._db_verified = True
            return recovered

    def _recover_db(self):
        """
        Recover a malformed SQLite DB.  Returns an open connection.
        Strategy 1: sqlite3 CLI .recover (most robust for page-level corruption)
        Strategy 2: sqlite3 CLI .dump + rebuild
        Strategy 3: connect to the original anyway — some queries may still work
        """
        recovered_path = self.db_path + ".recovered"
        original_path = self.db_path

        # Strategies 1 and 2 differ only in the dot-command used to dump.
        for dot_command in (".recover", ".dump"):
            conn = self._cli_rebuild(original_path, dot_command, recovered_path)
            if conn is not None:
                return conn

        # Strategy 3: Just connect — some queries may still work
        logger.warning("DB recovery failed — proceeding with original (some queries may fail)")
        self.db_path = original_path
        return sqlite3.connect(self.db_path)

    def _cli_rebuild(self, source_path, dot_command, recovered_path):
        """
        Pipe `sqlite3 <source> <dot_command>` output into a fresh DB at
        recovered_path and verify it.  Returns an open connection on
        success, None on failure (removing any partial recovered file).
        Requires the sqlite3 CLI on PATH; .recover needs sqlite3 >= 3.29.0.
        """
        import subprocess
        try:
            dumped = subprocess.run(
                ["sqlite3", source_path, dot_command],
                capture_output=True, text=True, timeout=60
            )
            # A tiny dump means the CLI failed or the DB had no real data.
            if dumped.stdout and len(dumped.stdout) > 1000:
                subprocess.run(
                    ["sqlite3", recovered_path],
                    input=dumped.stdout, capture_output=True, text=True, timeout=60
                )
                if os.path.exists(recovered_path) and os.path.getsize(recovered_path) > 100000:
                    self.db_path = recovered_path
                    # Verify recovered DB before handing it out.
                    test_conn = sqlite3.connect(self.db_path)
                    test_conn.execute("SELECT 1 FROM weight_record_table LIMIT 1")
                    test_conn.close()
                    logger.info(f"DB recovered via {dot_command} ({os.path.getsize(recovered_path):,} bytes)")
                    return sqlite3.connect(self.db_path)
        except (subprocess.TimeoutExpired, FileNotFoundError, sqlite3.DatabaseError) as e:
            logger.debug(f"Recovery via {dot_command} failed: {e}")
            if os.path.exists(recovered_path):
                os.remove(recovered_path)
        return None

    # ── Weight ──────────────────────────────────────────────

    def get_latest_weight(self):
        """Get the most recent weight reading. Returns dict or None."""
        conn = self._connect()
        c = conn.cursor()
        c.execute("""
            SELECT time, weight FROM weight_record_table
            ORDER BY time DESC LIMIT 1
        """)
        row = c.fetchone()
        conn.close()
        if row:
            dt = datetime.fromtimestamp(row[0] / 1000)
            # weight column is grams; 453.592 g per lb
            lbs = round(row[1] / 453.592, 1)
            return {"timestamp": dt.isoformat(), "date": dt.strftime("%Y-%m-%d"), "lbs": lbs}
        return None

    def get_weight_series(self, days=30):
        """Get daily weight readings (first reading of each day) for the last N days."""
        conn = self._connect()
        c = conn.cursor()
        cutoff = int((datetime.now() - timedelta(days=days)).timestamp() * 1000)
        c.execute("""
            SELECT time, weight FROM weight_record_table
            WHERE time > ? ORDER BY time ASC
        """, (cutoff,))
        rows = c.fetchall()
        conn.close()

        # Group by date, take first reading of each day — that's the
        # morning weigh-in, before food/water skew the number.
        daily = {}
        for ts, grams in rows:
            dt = datetime.fromtimestamp(ts / 1000)
            day_key = dt.strftime("%Y-%m-%d")
            if day_key not in daily:
                daily[day_key] = round(grams / 453.592, 1)

        return daily

    # ── Body Fat ────────────────────────────────────────────

    def get_latest_body_fat(self):
        """Get the most recent body fat percentage. Returns dict or None."""
        conn = self._connect()
        c = conn.cursor()
        c.execute("""
            SELECT time, percentage FROM body_fat_record_table
            ORDER BY time DESC LIMIT 1
        """)
        row = c.fetchone()
        conn.close()
        if row:
            dt = datetime.fromtimestamp(row[0] / 1000)
            return {"timestamp": dt.isoformat(), "percentage": round(row[1], 1)}
        return None

    def get_body_fat_series(self, days=30):
        """Get daily body fat readings (first of each day) for the last N days."""
        conn = self._connect()
        c = conn.cursor()
        cutoff = int((datetime.now() - timedelta(days=days)).timestamp() * 1000)
        c.execute("""
            SELECT time, percentage FROM body_fat_record_table
            WHERE time > ? ORDER BY time ASC
        """, (cutoff,))
        rows = c.fetchall()
        conn.close()

        daily = {}
        for ts, pct in rows:
            dt = datetime.fromtimestamp(ts / 1000)
            day_key = dt.strftime("%Y-%m-%d")
            if day_key not in daily:
                daily[day_key] = round(pct, 1)
        return daily

    # ── Blood Pressure ──────────────────────────────────────

    def get_latest_blood_pressure(self):
        """Get the most recent BP reading, classified per ACC/AHA categories."""
        conn = self._connect()
        c = conn.cursor()
        c.execute("""
            SELECT time, systolic, diastolic
            FROM blood_pressure_record_table
            ORDER BY time DESC LIMIT 1
        """)
        row = c.fetchone()
        conn.close()
        if row:
            dt = datetime.fromtimestamp(row[0] / 1000)
            sys_val = round(row[1])
            dia_val = round(row[2])
            # Classify.  Stage 2 is sys >= 140 OR dia >= 90, so stage 1
            # requires BOTH below those cutoffs — the previous `or` here
            # misclassified e.g. 150/85 as stage 1.
            if sys_val < 120 and dia_val < 80:
                status = "normal"
            elif sys_val < 130 and dia_val < 80:
                status = "elevated"
            elif sys_val < 140 and dia_val < 90:
                status = "high-stage1"
            else:
                status = "high-stage2"
            return {
                "timestamp": dt.isoformat(),
                "date": dt.strftime("%Y-%m-%d"),
                "systolic": sys_val,
                "diastolic": dia_val,
                "status": status,
            }
        return None

    # ── Sleep ───────────────────────────────────────────────

    def get_last_sleep(self):
        """Get last night's sleep session with stage breakdown and avg HRV."""
        conn = self._connect()
        c = conn.cursor()

        # Get the most recent sleep session.
        # Use local_date_time columns for accurate local bedtime/wake times.
        c.execute("""
            SELECT row_id, start_time, end_time, title,
                   local_date_time_start_time, local_date_time_end_time
            FROM sleep_session_record_table
            ORDER BY start_time DESC LIMIT 1
        """)
        session = c.fetchone()
        if not session:
            conn.close()
            return None

        row_id, start_ms, end_ms, title, local_start_ms, local_end_ms = session
        # Local timestamps for display (bedtime/wake); fall back to UTC
        # columns when local ones are NULL.  Duration uses UTC columns.
        start = datetime.fromtimestamp((local_start_ms or start_ms) / 1000)
        end = datetime.fromtimestamp((local_end_ms or end_ms) / 1000)
        total_hours = (end_ms - start_ms) / 3600000

        # Sum per-stage minutes for this session.
        c.execute("""
            SELECT stage_start_time, stage_end_time, stage_type
            FROM sleep_stages_table
            WHERE parent_key = ?
        """, (row_id,))
        stages = c.fetchall()

        stage_minutes = {"awake": 0, "light": 0, "deep": 0, "rem": 0}
        for s_start, s_end, s_type in stages:
            mins = (s_end - s_start) / 60000
            stage_name = SLEEP_STAGES.get(s_type, "unknown")
            if stage_name in stage_minutes:
                stage_minutes[stage_name] += mins

        # Average HRV over the session window (UTC bounds).
        c.execute("""
            SELECT AVG(heart_rate_variability_millis)
            FROM heart_rate_variability_rmssd_record_table
            WHERE time BETWEEN ? AND ?
        """, (start_ms, end_ms))
        avg_hrv = c.fetchone()[0]

        conn.close()

        return {
            "date": start.strftime("%Y-%m-%d"),
            "bedtime": start.strftime("%H:%M"),
            "wake_time": end.strftime("%H:%M"),
            "total_hours": round(total_hours, 1),
            "source": title or "Unknown",
            "stages": {
                "deep_min": round(stage_minutes["deep"]),
                "rem_min": round(stage_minutes["rem"]),
                "light_min": round(stage_minutes["light"]),
                "awake_min": round(stage_minutes["awake"]),
            },
            # `is not None` (not truthiness) so a legitimate 0.0 average
            # is reported rather than silently dropped.
            "avg_hrv": round(avg_hrv, 1) if avg_hrv is not None else None,
        }

    def get_sleep_series(self, days=14):
        """Get sleep duration (hours) per night for the last N nights."""
        conn = self._connect()
        c = conn.cursor()
        cutoff = int((datetime.now() - timedelta(days=days)).timestamp() * 1000)
        c.execute("""
            SELECT start_time, end_time
            FROM sleep_session_record_table
            WHERE start_time > ?
            ORDER BY start_time ASC
        """, (cutoff,))
        rows = c.fetchall()
        conn.close()

        # Group by night (date of start_time).
        nightly = {}
        for start_ms, end_ms in rows:
            dt = datetime.fromtimestamp(start_ms / 1000)
            night = dt.strftime("%Y-%m-%d")
            hours = (end_ms - start_ms) / 3600000
            # Keep only the longest session per night (filters out naps).
            if night not in nightly or hours > nightly[night]:
                nightly[night] = round(hours, 1)

        return nightly

    # ── Exercise ────────────────────────────────────────────

    def get_recent_exercises(self, days=7):
        """Get exercise sessions from the last N days, newest first."""
        conn = self._connect()
        c = conn.cursor()
        cutoff = int((datetime.now() - timedelta(days=days)).timestamp() * 1000)
        c.execute("""
            SELECT start_time, end_time, exercise_type, title, notes
            FROM exercise_session_record_table
            WHERE start_time > ?
            ORDER BY start_time DESC
        """, (cutoff,))
        rows = c.fetchall()
        conn.close()

        sessions = []
        for start_ms, end_ms, ex_type, title, notes in rows:
            dt = datetime.fromtimestamp(start_ms / 1000)
            duration_min = round((end_ms - start_ms) / 60000)
            type_label = EXERCISE_TYPES.get(ex_type, f"Type {ex_type}")

            # Use title if available (Peloton class names), otherwise type label
            name = title if title else type_label

            # Skip very short sessions (< 5 min) that are likely just walking around
            if duration_min < 5 and ex_type == 21:  # Walking
                continue

            sessions.append({
                "date": dt.strftime("%Y-%m-%d"),
                "time": dt.strftime("%H:%M"),
                "type": type_label,
                "name": name,
                "duration_min": duration_min,
            })

        return sessions

    # ── Steps ───────────────────────────────────────────────

    def get_daily_steps(self, days=7):
        """Get daily step totals (summing all records per day) for the last N days."""
        conn = self._connect()
        c = conn.cursor()
        cutoff = int((datetime.now() - timedelta(days=days)).timestamp() * 1000)
        c.execute("""
            SELECT start_time, count FROM steps_record_table
            WHERE start_time > ?
        """, (cutoff,))
        rows = c.fetchall()
        conn.close()

        daily = {}
        for ts, count in rows:
            day_key = datetime.fromtimestamp(ts / 1000).strftime("%Y-%m-%d")
            daily[day_key] = daily.get(day_key, 0) + count

        return daily

    # ── Full Extraction ─────────────────────────────────────

    def extract_all(self):
        """
        Extract all health metrics into a single dict.
        Each metric is extracted independently — if one table is corrupt,
        the rest still get pulled.
        """
        result = {
            "exported_at": datetime.now().isoformat(),
            "source_file": os.path.basename(self.zip_path),
        }

        extractors = [
            ("weight",         lambda: self.get_latest_weight()),
            ("weight_series",  lambda: self.get_weight_series(30)),
            ("body_fat",       lambda: self.get_latest_body_fat()),
            ("body_fat_series",lambda: self.get_body_fat_series(30)),
            ("blood_pressure", lambda: self.get_latest_blood_pressure()),
            ("sleep",          lambda: self.get_last_sleep()),
            ("sleep_series",   lambda: self.get_sleep_series(14)),
            ("exercises",      lambda: self.get_recent_exercises(7)),
            ("steps",          lambda: self.get_daily_steps(7)),
        ]

        for key, fn in extractors:
            try:
                val = fn()
                # Truthiness check: None and empty dicts/lists are dropped
                # deliberately so the dashboard updater skips those sections.
                if val:
                    result[key] = val
            except Exception as e:
                logger.warning(f"Failed to extract {key}: {e}")

        extracted = [k for k in result if k not in ("exported_at", "source_file")]
        logger.info(f"Extracted {len(extracted)} metrics: {', '.join(extracted)}")

        return result


def ingest_health_connect(zip_path: str, dashboard_manager) -> str:
    """
    Ingest a Health Connect zip file and update the dashboard.

    Parses the export with HealthConnectParser, applies the extracted
    metrics to the dashboard via dashboard_manager._read_and_write, and
    writes a dated JSON snapshot of the raw extraction next to the
    dashboard file.  Returns a human-readable summary string.
    """
    logger.info(f"Ingesting Health Connect export: {zip_path}")

    with HealthConnectParser(zip_path) as parser:
        data = parser.extract_all()

    # Update dashboard; `changes` collects one summary line per section.
    changes = []

    def updater(dashboard):
        nonlocal changes

        # Weight
        if "weight" in data:
            w = data["weight"]
            old = dashboard["weight"]["current"]
            new_val = w["lbs"]

            dashboard["weight"]["deltaVsLast"] = round(new_val - old, 1)
            dashboard["weight"]["current"] = new_val
            # f-string instead of strftime("%B %-d"): the %-d (no-pad day)
            # directive is a glibc extension and raises on Windows.
            weigh_dt = datetime.fromisoformat(w["timestamp"])
            dashboard["weight"]["lastWeighIn"] = f"{weigh_dt.strftime('%B')} {weigh_dt.day}"
            dashboard["date"] = date.today().isoformat()

            # Refresh series only with a meaningful amount of history
            # (>= 7 daily readings) so a sparse export can't wipe it.
            if "weight_series" in data:
                series_values = list(data["weight_series"].values())
                if len(series_values) >= 7:
                    dashboard["weight"]["series30d"] = series_values[-30:]
                    last7 = series_values[-7:]
                    dashboard["weight"]["avg7d"] = round(sum(last7) / len(last7), 1)

            # Recalculate weekly loss pace and ETA to goal.  The baseline
            # date is configurable on the dashboard; the default preserves
            # the previously hard-coded program start date.
            baseline = dashboard["weight"].get("baseline", 245.0)
            goal = dashboard["weight"].get("goal", 220)
            baseline_date = datetime.fromisoformat(
                dashboard["weight"].get("baselineDate", "2026-02-01"))
            weeks = max((datetime.now() - baseline_date).days / 7, 1)
            lost = baseline - new_val
            pace = round(lost / weeks, 1)
            dashboard["weight"]["pace"] = pace
            remaining = new_val - goal
            # 999 is a sentinel for "no measurable progress" (pace <= 0).
            dashboard["weight"]["etaWeeks"] = round(remaining / pace) if pace > 0 else 999

            changes.append(f"Weight: {new_val} lb ({dashboard['weight']['deltaVsLast']:+.1f})")

        # Body fat
        if "body_fat" in data:
            bf = data["body_fat"]
            old = dashboard["bodyFat"]["current"]
            new_val = bf["percentage"]
            dashboard["bodyFat"]["deltaVsLast"] = round(new_val - old, 1)
            dashboard["bodyFat"]["current"] = new_val

            if "body_fat_series" in data:
                series_values = list(data["body_fat_series"].values())
                if series_values:
                    dashboard["bodyFat"]["series30d"] = series_values[-30:]

            changes.append(f"Body fat: {new_val}%")

        # Blood pressure
        if "blood_pressure" in data:
            bp = data["blood_pressure"]
            dashboard["bloodPressure"]["systolic"] = bp["systolic"]
            dashboard["bloodPressure"]["diastolic"] = bp["diastolic"]
            dashboard["bloodPressure"]["status"] = bp["status"]
            dashboard["bloodPressure"]["date"] = bp["date"]
            changes.append(f"BP: {bp['systolic']}/{bp['diastolic']}")

        # Sleep (section is replaced wholesale each import)
        if "sleep" in data:
            dashboard["sleep"] = data["sleep"]
            if "sleep_series" in data:
                dashboard["sleep"]["series14d"] = data["sleep_series"]
            changes.append(f"Sleep: {data['sleep']['total_hours']}h ({data['sleep']['bedtime']}-{data['sleep']['wake_time']})")

        # Exercise (section is replaced wholesale each import)
        if "exercises" in data:
            dashboard["exercise"] = {
                "recent": data["exercises"][:10],
                "updated": date.today().isoformat(),
            }
            # Sessions whose name differs from the type label carry a real
            # title (e.g. Peloton class names); surface the latest one.
            named = [e for e in data["exercises"] if e["name"] != e["type"]]
            if named:
                changes.append(f"Exercise: {len(data['exercises'])} sessions (latest: {named[0]['name']})")
            else:
                changes.append(f"Exercise: {len(data['exercises'])} sessions")

        # Steps
        if "steps" in data:
            dashboard["steps"] = {
                "daily": data["steps"],
                "updated": date.today().isoformat(),
            }
            today_steps = data["steps"].get(date.today().isoformat(), 0)
            if today_steps:
                changes.append(f"Steps: {today_steps:,} today")

    dashboard_manager._read_and_write(updater)

    # Save the full extraction as a snapshot for reference.
    # parents=True so a missing snapshot tree doesn't abort the ingest.
    snapshot_dir = Path(dashboard_manager.file_path).parent / "health-snapshots"
    snapshot_dir.mkdir(parents=True, exist_ok=True)
    snapshot_path = snapshot_dir / f"health-{date.today().isoformat()}.json"
    with open(snapshot_path, "w") as f:
        # default=str is deliberate: dates/Paths in the extraction are
        # acceptable as their string forms in the snapshot.
        json.dump(data, f, indent=2, default=str)

    summary = "Health Connect import complete:\n" + "\n".join(f"  • {c}" for c in changes)
    logger.info(summary)
    return summary
