"""
Health screenshot parser — uses Claude Vision to extract data from
iHealth app screenshots sent via WhatsApp.

Supported screenshot types:
  1. Scale / body composition (iHealth Nexus Pro)
     → weight, body fat, BMI, muscle mass, body water, etc.
  2. Blood pressure history (iHealth BP monitor)
     → systolic, diastolic, pulse, multiple readings

Bill screenshots his iHealth app and sends the images to Luke. This module
extracts the numbers and updates the dashboard automatically.
"""

import base64
import json
import logging
from datetime import date, datetime
from anthropic import Anthropic
from config import Config

# Module-level logger used by every function in this file.
logger = logging.getLogger("health-vision")

# Instruction text sent together with the image by parse_health_screenshot().
# It asks the model to classify the screenshot and reply with a single JSON
# object whose "type" is "scale", "blood_pressure" or "unknown".
#
# The JSON key names below are part of the contract with the rest of this
# file: apply_scale_data() reads weight_lb, body_fat_pct, bmi and the body
# composition keys; apply_bp_data() reads readings[*].systolic / diastolic /
# pulse / date / time. Renaming a key here means updating those functions.
#
# NOTE(review): weight_change, the scale "date"/"time" fields and
# reading_number are requested here but are not read anywhere in this file.
_EXTRACTION_PROMPT = """Look at this image. Determine what type of health screenshot it is and extract the data.

TYPE 1 — SCALE / BODY COMPOSITION (shows weight in large text, may show body fat, BMI, etc.)
Return:
{
  "type": "scale",
  "weight_lb": <number>,
  "weight_change": <number or null>,
  "bmi": <number or null>,
  "body_fat_pct": <number or null>,
  "muscle_mass_lb": <number or null>,
  "body_water_pct": <number or null>,
  "lean_body_mass_lb": <number or null>,
  "bone_mass_lb": <number or null>,
  "protein_pct": <number or null>,
  "visceral_fat": <number or null>,
  "bmr_kcal": <number or null>,
  "metabolic_age": <number or null>,
  "date": "<YYYY-MM-DD if visible, else null>",
  "time": "<HH:MM if visible, else null>"
}

TYPE 2 — BLOOD PRESSURE (shows SYS/DIA/Pulse readings, may show history)
Return:
{
  "type": "blood_pressure",
  "readings": [
    {
      "systolic": <number>,
      "diastolic": <number>,
      "pulse": <number>,
      "date": "<YYYY-MM-DD if visible>",
      "time": "<HH:MM if visible>",
      "reading_number": <1st/2nd/3rd label if visible, else null>
    }
  ]
}
For BP, extract ALL visible readings in the image (there may be multiple).
Sort from most recent to oldest.

NOT A HEALTH SCREENSHOT:
Return: {"type": "unknown"}

Rules:
- All numbers must be numeric, not strings
- weight_change is the small delta (like ↑0.9 means +0.9)
- For BP dates shown as "Feb 23, 2026", convert to "2026-02-23"
- For BP times shown as "7:46 a.m.", convert to "07:46"
- Return ONLY the JSON object, no other text"""


def parse_health_screenshot(image_data: bytes, content_type: str = "image/jpeg") -> dict:
    """
    Send a health app screenshot to Claude Vision for extraction.

    Args:
        image_data: Raw image bytes; they are base64-encoded before sending.
        content_type: MIME type reported for the image. Values containing
            "png" or "webp" (case-insensitive) are sent as such; anything
            else, including None or an empty string, is sent as JPEG.

    Returns:
        A dict that always has a 'type' key:
          * the JSON object returned by the model when the reply parses to
            an object (a missing or empty 'type' is filled in as "unknown");
          * {"type": "unknown"} when the reply is not valid JSON, or is
            valid JSON that is not an object;
          * {"type": "error", "error": <message>} when the API call or the
            response handling raises.
    """
    client = Anthropic(api_key=Config.CLAUDE_API_KEY)

    normalized_type = (content_type or "").lower()
    if "png" in normalized_type:
        media_type = "image/png"
    elif "webp" in normalized_type:
        media_type = "image/webp"
    else:
        media_type = "image/jpeg"

    b64_image = base64.b64encode(image_data).decode("utf-8")

    # Bound before the try so the JSONDecodeError handler can always log it,
    # even if the decode error is raised before a reply has been read.
    raw = ""
    try:
        response = client.messages.create(
            model=Config.CLAUDE_MODEL,
            max_tokens=1024,
            messages=[{
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": b64_image,
                        },
                    },
                    {
                        "type": "text",
                        "text": _EXTRACTION_PROMPT,
                    },
                ],
            }],
        )

        raw = response.content[0].text.strip()
        logger.info(f"Health vision raw response: {raw[:300]}")

        # Parse JSON (handle markdown code blocks if present)
        text = raw
        if text.startswith("```"):
            # Drop the opening fence line (``` or ```json). A reply with the
            # fence and payload on a single line has no newline to split on,
            # so fall back to stripping just the three backticks.
            first_line, newline, remainder = text.partition("\n")
            fenced = remainder if newline else first_line[3:]
            text = fenced.rsplit("```", 1)[0].strip()

        data = json.loads(text)

    except json.JSONDecodeError as e:
        logger.error(f"Health vision: failed to parse JSON: {e}")
        logger.error(f"Raw response: {raw[:300]}")
        return {"type": "unknown"}
    except Exception as e:
        logger.error(f"Health vision API call failed: {e}", exc_info=True)
        return {"type": "error", "error": str(e)}

    # Valid JSON that is not an object (list, number, string, null) carries
    # no usable fields; treat it like an unrecognised screenshot.
    if not isinstance(data, dict):
        logger.error(f"Health vision: expected a JSON object, got {type(data).__name__}")
        return {"type": "unknown"}

    # Guarantee the documented 'type' key even if the model omitted it.
    data_type = data.get("type") or "unknown"
    data["type"] = data_type
    logger.info(f"Health vision detected type: {data_type}")
    return data


# Legacy alias for backward compatibility
def parse_scale_screenshot(image_data, content_type="image/jpeg"):
    """Old entry-point name; forwards unchanged to parse_health_screenshot.

    The result is whatever the general parser returns, so its 'type' is not
    limited to "scale".
    """
    parsed = parse_health_screenshot(image_data, content_type)
    return parsed


# ── Scale Data Application ───────────────────────────────────

def apply_scale_data(data: dict, dashboard_manager) -> str:
    """Store one extracted scale reading in the dashboard.

    Args:
        data: Parsed scale payload. 'weight_lb' is required; 'body_fat_pct',
            'bmi' and the body-composition keys are applied when present.
        dashboard_manager: Object whose ``_read_and_write(fn)`` hands the
            dashboard dict to ``fn`` for in-place modification.

    Returns:
        A bullet-list summary of what was stored, or "" when ``data`` is
        empty or has no usable weight.
    """
    if not (data and data.get("weight_lb")):
        return ""

    stamp = date.today().isoformat()
    lines = []

    def _write_reading(board):
        pounds = data["weight_lb"]

        # --- weight section -------------------------------------------
        section = board.setdefault("weight", {})
        section.update({
            "current": pounds,
            "date": stamp,
            "source": "iHealth Nexus Pro",
        })

        trend = section.get("series30d", [])
        trend.append(pounds)
        # Only the 30 most recent entries are kept.
        section["series30d"] = trend[-30:] if len(trend) > 30 else trend
        window = section["series30d"]

        target = section.get("goal", 215)
        if len(window) >= 7:
            delta = window[-1] - window[-7]
            section["pace"] = round(delta, 1)
            if delta < 0:
                # Losing weight: project weeks to goal at this pace, minimum 1.
                section["etaWeeks"] = max(1, round((pounds - target) / abs(delta)))
            else:
                # Not losing: keep any earlier estimate, else show "?".
                section.setdefault("etaWeeks", "?")

        section["lost"] = round(section.get("start", pounds) - pounds, 1)
        lines.append(f"Weight: {pounds} lb")

        # --- body fat -------------------------------------------------
        fat_pct = data.get("body_fat_pct")
        if fat_pct:
            fat = board.setdefault("bodyFat", {})
            previous = fat.get("current", fat_pct)
            fat["current"] = fat_pct
            fat["date"] = stamp
            fat["deltaVsLast"] = round(fat_pct - previous, 1)
            lines.append(f"Body Fat: {fat_pct}%")

        # --- BMI ------------------------------------------------------
        bmi = data.get("bmi")
        if bmi:
            board["bmi"] = {"value": bmi, "date": stamp}
            lines.append(f"BMI: {bmi}")

        # --- remaining body-composition metrics -----------------------
        composition_keys = (
            "muscle_mass_lb", "body_water_pct", "lean_body_mass_lb",
            "bone_mass_lb", "protein_pct", "visceral_fat", "bmr_kcal",
            "metabolic_age",
        )
        composition = {
            name: data[name]
            for name in composition_keys
            if data.get(name) is not None
        }

        if composition:
            composition["date"] = stamp
            board["bodyComposition"] = composition

            highlights = []
            muscle = data.get("muscle_mass_lb")
            if muscle:
                highlights.append(f"Muscle: {muscle} lb")
            visceral = data.get("visceral_fat")
            if visceral is not None:
                highlights.append(f"Visceral Fat: {visceral}")
            metabolic_age = data.get("metabolic_age")
            if metabolic_age:
                highlights.append(f"Metabolic Age: {metabolic_age}")
            if highlights:
                lines.append(" | ".join(highlights))

        board["date"] = stamp

    dashboard_manager._read_and_write(_write_reading)

    bullets = "\n".join(f"  • {entry}" for entry in lines)
    summary = "Scale data captured:\n" + bullets
    logger.info(summary)
    return summary


# ── Blood Pressure Data Application ──────────────────────────

def _classify_bp(systolic, diastolic):
    """Return the category key (normal / elevated / stage1 / stage2) for one reading."""
    if systolic < 120 and diastolic < 80:
        return "normal"
    if systolic < 130 and diastolic < 80:
        return "elevated"
    # Stage 1 requires BOTH numbers to stay under the stage 2 cut-offs; as
    # soon as either reaches 140 systolic or 90 diastolic it is stage 2.
    if systolic < 140 and diastolic < 90:
        return "stage1"
    return "stage2"


def _is_usable_bp_reading(reading):
    """Return True when *reading* is a dict with non-zero numeric systolic and diastolic."""
    if not isinstance(reading, dict):
        return False
    return all(
        isinstance(reading.get(key), (int, float)) and reading.get(key)
        for key in ("systolic", "diastolic")
    )


def apply_bp_data(data: dict, dashboard_manager) -> str:
    """Apply extracted blood pressure data to the dashboard. Returns summary string.

    Args:
        data: Parsed payload with a 'readings' list, expected newest first.
            Readings without numeric systolic and diastolic values are
            ignored.
        dashboard_manager: Object whose ``_read_and_write(fn)`` hands the
            dashboard dict to ``fn`` for in-place modification.

    Returns:
        A bullet-list summary (latest reading, pulse if present, and the
        batch average when more than one usable reading was found), or ""
        when there is no usable reading.

    The first usable reading becomes the dashboard's current value. Every
    usable reading is merged into ``bloodPressure.history``, which is kept
    de-duplicated, sorted newest first and capped at 30 entries.
    """
    candidates = (data or {}).get("readings") or []
    readings = [r for r in candidates if _is_usable_bp_reading(r)]
    if not readings:
        return ""

    # Use the most recent reading (first in list)
    latest = readings[0]
    sys_val = latest["systolic"]
    dia_val = latest["diastolic"]
    pulse = latest.get("pulse")

    today = date.today().isoformat()
    # `or` rather than a .get default: the model emits an explicit null when
    # the field is not visible, and a None date would break the history sort.
    reading_date = latest.get("date") or today
    reading_time = latest.get("time") or ""

    status = _classify_bp(sys_val, dia_val)
    status_labels = {
        "normal": "Normal",
        "elevated": "Elevated",
        "stage1": "Stage 1 Hypertension",
        "stage2": "Stage 2 Hypertension",
    }

    # The summary depends only on the extracted readings, so build it here
    # instead of inside the updater; it then cannot be duplicated if the
    # manager invokes the callback more than once.
    changes = [f"BP: {sys_val}/{dia_val} mmHg ({status_labels.get(status, status)})"]
    if pulse:
        changes.append(f"Pulse: {pulse} bpm")
    if len(readings) > 1:
        avg_sys = round(sum(r["systolic"] for r in readings) / len(readings))
        avg_dia = round(sum(r["diastolic"] for r in readings) / len(readings))
        changes.append(f"Avg ({len(readings)} readings): {avg_sys}/{avg_dia}")

    # History entries for every usable reading; older readings that lack a
    # date inherit the latest reading's date.
    entries = [
        {
            "systolic": r["systolic"],
            "diastolic": r["diastolic"],
            "pulse": r.get("pulse"),
            "date": r.get("date") or reading_date,
            "time": r.get("time") or "",
        }
        for r in readings
    ]

    def updater(dashboard):
        bp = dashboard.get("bloodPressure") or {}

        # Store current reading
        bp["systolic"] = sys_val
        bp["diastolic"] = dia_val
        bp["pulse"] = pulse
        bp["date"] = reading_date
        bp["time"] = reading_time
        bp["status"] = status
        bp["source"] = "iHealth BP Monitor"

        # Merge into history, skipping entries already recorded anywhere in
        # it (not just at the head).
        history = list(bp.get("history") or [])
        for position, entry in enumerate(entries):
            if entry in history:
                continue
            if position == 0:
                # The sort is stable, so inserting at the front keeps the
                # latest reading ahead of entries with the same date/time.
                history.insert(0, entry)
            else:
                history.append(entry)

        # Keep history sorted by date/time (newest first) and trimmed to 30.
        # `or ""` guards against None values in previously stored entries.
        history.sort(
            key=lambda x: (x.get("date") or "", x.get("time") or ""),
            reverse=True,
        )
        bp["history"] = history[:30]

        dashboard["bloodPressure"] = bp
        dashboard["date"] = today

    dashboard_manager._read_and_write(updater)

    summary = "Blood pressure captured:\n" + "\n".join(f"  • {c}" for c in changes)
    logger.info(summary)
    return summary
