#!/usr/bin/env python3
"""
Kitzu AI — Personal Health Intelligence Chatbot

A FastAPI server that powers the Kitzu dashboard chat widget.
Injects the user's full health profile (Oura, blood, BP, body comp, genetics,
microbiome) as context so Claude can reason across all data sources.

Runs on the same machine as the dashboard, accessible via Tailscale.

Usage:
    python3 ai_server.py                    # Start on port 8081
    python3 ai_server.py --port 8082        # Custom port
"""

import json
import argparse
import logging
from datetime import datetime, date
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
import anthropic

# ── Config ───────────────────────────────────────────────────
# Module logger plus a root basicConfig so uvicorn-adjacent output shares
# one "time | level | message" format.
log = logging.getLogger("kitzu.ai")
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")

# Data lake layout: everything lives under ./data next to this script.
KITZU_ROOT = Path(__file__).resolve().parent
DATA_ROOT = KITZU_ROOT / "data"
UNIFIED_DIR = DATA_ROOT / "unified"  # merged cross-source profile.json lives here
BRIEFS_DIR = UNIFIED_DIR / "briefs"  # daily briefs, one <ISO-date>.json per day

# Load API key from .env (shared config directory outside this repo).
ENV_PATH = Path.home() / "clawd" / "_Organized" / "Config" / ".env"

def load_api_key() -> str:
    """Return the Anthropic API key.

    Lookup order: the ``CLAUDE_API_KEY=`` entry in the .env file at
    ``ENV_PATH``, then the ``ANTHROPIC_API_KEY`` / ``CLAUDE_API_KEY``
    environment variables.

    Raises:
        ValueError: if no key is found in any of those locations.
    """
    import os

    prefix = "CLAUDE_API_KEY="
    if ENV_PATH.exists():
        for raw_line in ENV_PATH.read_text().splitlines():
            entry = raw_line.strip()
            if entry.startswith(prefix):
                # Everything after the first '=' is the key (keys may
                # themselves contain '=' characters).
                return entry[len(prefix):].strip()

    # Fall back to the process environment.
    env_key = os.environ.get("ANTHROPIC_API_KEY") or os.environ.get("CLAUDE_API_KEY")
    if env_key:
        return env_key

    raise ValueError("No Anthropic API key found. Set CLAUDE_API_KEY in .env or ANTHROPIC_API_KEY env var.")


def load_model() -> str:
    """Load the preferred Claude model name from .env.

    Reads the ``CLAUDE_MODEL=`` entry from the .env file at ``ENV_PATH`` if
    present; otherwise falls back to the default ``"claude-sonnet-4-6"``.
    (The previous docstring said "default to opus", which did not match the
    code — the actual default is a Sonnet model.)
    """
    if ENV_PATH.exists():
        for line in ENV_PATH.read_text().splitlines():
            line = line.strip()
            if line.startswith("CLAUDE_MODEL="):
                return line.split("=", 1)[1].strip()
    return "claude-sonnet-4-6"


# ── Health Context Builder ───────────────────────────────────

def build_health_context() -> str:
    """Build the full health context string from the unified profile and latest brief.

    Reads ``data/unified/profile.json`` plus the newest daily brief and
    renders each available data source (Oura, blood biomarkers, BP/vitals,
    body composition, genetics, microbiome, brief highlights) as a markdown
    "## ..." section. Sources missing from the profile are silently skipped.

    Returns:
        A markdown document (sections joined by blank lines), or a single
        placeholder sentence when no profile file exists yet.
    """
    profile_path = UNIFIED_DIR / "profile.json"
    if not profile_path.exists():
        return "No health data available yet. The Kitzu data lake is empty."

    profile = json.loads(profile_path.read_text())

    # Load latest brief: prefer today's file, else fall back to the newest.
    brief = None
    today = date.today().isoformat()
    brief_path = BRIEFS_DIR / f"{today}.json"
    if brief_path.exists():
        brief = json.loads(brief_path.read_text())
    else:
        # Find most recent brief. Filenames are ISO dates, so lexicographic
        # sort order is also chronological order.
        briefs = sorted(BRIEFS_DIR.glob("*.json"))
        if briefs:
            brief = json.loads(briefs[-1].read_text())

    # Build structured context, one markdown section per data source.
    # Missing values render as '?' rather than raising.
    sections = []

    # Oura / Sleep / Recovery
    oura = profile.get("oura", {})
    if oura.get("sleep"):
        s = oura["sleep"]
        sections.append(f"""## Sleep & Recovery (Oura Ring)
Last sync: {oura.get('last_sync', 'unknown')}
Latest sleep: {s.get('hours', '?')}h total, {s.get('deep_min', '?')}min deep, {s.get('rem_min', '?')}min REM
HRV: {s.get('avg_hrv', '?')}ms, Resting HR: {s.get('resting_hr', '?')}bpm
Bedtime: {s.get('bedtime', '?')}, Wake: {s.get('wake_time', '?')}
Readiness score: {oura.get('readiness', {}).get('score', '?')}
Steps yesterday: {oura.get('steps', '?')}, Active calories: {oura.get('exercise', {}).get('active_calories', '?') if isinstance(oura.get('exercise'), dict) else '?'}""")

        # NOTE(review): [:7] assumes hrv_trend is ordered newest-first so the
        # slice is the most recent 7 days — confirm against the profile writer.
        hrv_trend = oura.get("hrv_trend", [])
        if hrv_trend:
            last7 = hrv_trend[:7]
            avg7 = round(sum(h["hrv"] for h in last7) / len(last7), 1) if last7 else 0
            sections.append(f"HRV 7-day avg: {avg7}ms, trend data: {json.dumps(hrv_trend[:14])}")

    # Blood biomarkers — one bullet per marker, flagged against the
    # per-marker optimal_range when the profile provides one.
    blood = profile.get("blood", {})
    if blood.get("markers"):
        markers_text = []
        for name, m in blood["markers"].items():
            status = ""
            if m.get("optimal_range"):
                val = m.get("value", 0)
                low, high = m["optimal_range"]
                if val < low:
                    status = " ⚠️ BELOW optimal"
                elif val > high:
                    status = " ⚠️ ABOVE optimal"
                else:
                    status = " ✅ optimal"
            trend_str = ""
            if m.get("trend"):
                # Show only the last 4 trend points to keep the line short.
                trend_str = f" (trend: {' → '.join(str(t) for t in m['trend'][-4:])})"
            markers_text.append(f"- {name}: {m.get('value', '?')} {m.get('unit', '')}{status}{trend_str}")

        # chr(10) is '\n' — backslashes were not allowed inside f-string
        # expressions before Python 3.12, hence this workaround.
        sections.append(f"""## Blood Biomarkers (InsideTracker)
Last test: {blood.get('latest_test', '?')}
InnerAge: {blood.get('inner_age', '?')} (chronological age likely ~50)
Tests in history: {len(blood.get('test_history', []))}

Markers:
{chr(10).join(markers_text)}""")

    # Blood pressure & vitals
    vitals = profile.get("vitals", {})
    bp = vitals.get("blood_pressure", {})
    if bp:
        bp_stats = vitals.get("bp_stats_30d", {})
        bp_hist = vitals.get("bp_history", [])
        sections.append(f"""## Blood Pressure & Vitals (iHealth)
Latest BP: {bp.get('systolic', '?')}/{bp.get('diastolic', '?')} mmHg, pulse {bp.get('pulse', '?')}bpm
Date: {bp.get('date', '?')} {bp.get('time', '')}
30-day avg: {bp_stats.get('avg_systolic', '?')}/{bp_stats.get('avg_diastolic', '?')} ({bp_stats.get('reading_count', 0)} readings)
Total BP readings in history: {len(bp_hist)}
Recent readings: {json.dumps(bp_hist[-8:]) if bp_hist else 'none'}""")

    # Body composition (smart scale)
    weight = vitals.get("weight", {})
    if weight:
        wt = vitals.get("weight_trend_7d", {})
        sections.append(f"""## Body Composition
Weight: {weight.get('weight_lb', '?')} lbs
Body fat: {weight.get('body_fat_pct', '?')}%
Muscle mass: {weight.get('muscle_mass', '?')} lbs
Lean mass: {weight.get('lean_mass', '?')} lbs
BMI: {weight.get('bmi', '?')}, BMR: {weight.get('bmr', '?')}
Body water: {weight.get('body_water', '?')}%, Visceral fat: {weight.get('visceral_fat', '?')}
7-day trend: {wt.get('delta_lb', '?')} lbs ({wt.get('direction', '?')})""")

    # Genetics — cap at 20 SNPs to bound prompt size.
    genetics = profile.get("genetics", {})
    if genetics.get("health_snps"):
        snps = genetics["health_snps"]
        snp_lines = [f"- {rsid}: {info.get('gene', '?')} — {info.get('interpretation', '?')}" for rsid, info in list(snps.items())[:20]]
        sections.append(f"""## Genetics ({genetics.get('source', 'unknown')})
SNPs analyzed: {genetics.get('snp_count', '?')}
Key health SNPs:
{chr(10).join(snp_lines)}""")

    # Microbiome — first 10 foods per category, to bound prompt size.
    micro = profile.get("microbiome", {})
    if micro.get("food_recs"):
        recs = micro["food_recs"]
        sections.append(f"""## Microbiome (Viome)
Superfoods: {', '.join(recs.get('superfoods', [])[:10])}
Enjoy: {', '.join(recs.get('enjoy', [])[:10])}
Minimize: {', '.join(recs.get('minimize', [])[:10])}
Avoid: {', '.join(recs.get('avoid', [])[:10])}""")

    # Today's brief highlights (or the most recent brief found above)
    if brief:
        priority_actions = brief.get("priority_actions", [])
        if priority_actions:
            pa_text = "\n".join(f"- {a.get('icon','')} {a.get('action','')} ({a.get('reason','')})" for a in priority_actions)
            sections.append(f"""## Today's Priority Actions
{pa_text}""")

        alerts = brief.get("alerts", [])
        if alerts:
            alert_text = "\n".join(f"- [{a.get('level','')}] {a.get('message','')}" for a in alerts)
            sections.append(f"""## Active Alerts
{alert_text}""")

        xrefs = brief.get("cross_references", [])
        if xrefs:
            xref_text = "\n".join(f"- {xr.get('finding', '')} → {xr.get('action', '')}" for xr in xrefs)
            sections.append(f"""## Cross-Reference Insights
{xref_text}""")

    return "\n\n".join(sections)


def build_system_prompt() -> str:
    """Build the full system prompt with health context.

    Wraps the rendered health data from build_health_context() in the
    Kitzu AI persona/guardrail instructions, stamped with today's date.
    Callers invoke this per-request so the prompt always reflects the
    current state of the data lake.
    """
    health_context = build_health_context()
    # Human-readable date, e.g. "Monday, January 06, 2025".
    today = datetime.now().strftime("%A, %B %d, %Y")

    return f"""You are Kitzu AI — a personal health intelligence assistant for a wellness-focused individual. You have deep access to their real health data from multiple sources: Oura Ring (sleep, HRV, readiness, activity), InsideTracker blood biomarkers, iHealth blood pressure monitor, body composition scales, and potentially genetics and microbiome data.

Today is {today}.

Your role:
- Answer health questions using their ACTUAL data, not generic advice
- Cross-reference data across sources (e.g., "your HRV is declining AND your cortisol is elevated — these are connected")
- Give direct, actionable recommendations they can do TODAY
- Be conversational and warm but precise with numbers
- When you see concerning patterns, flag them clearly but without being alarmist
- Always ground your responses in their specific numbers
- You can suggest things to discuss with their doctor, but prioritize what they can control right now
- Use plain language, not medical jargon — explain like a knowledgeable friend

What you should NOT do:
- Don't diagnose medical conditions
- Don't recommend prescription medications
- Don't override their doctor's advice
- Don't be vague or generic — they want specificity based on their data

Style notes:
- Keep responses focused and concise (2-4 paragraphs unless they ask for detail)
- Use their actual numbers when referencing data
- Bold key insights or action items
- If they ask something outside your data, say so honestly

─── HEALTH DATA ───

{health_context}

─── END HEALTH DATA ───

Remember: This person is actively optimizing their health. They want real talk, real data, real actions. Not disclaimers and hedging."""


# ── FastAPI App ──────────────────────────────────────────────

app = FastAPI(title="Kitzu AI", version="1.0")

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# disallowed by the CORS spec for credentialed browser requests (a wildcard
# origin cannot be combined with credentials). Likely acceptable on a private
# Tailscale network, but pin concrete origins if this is ever exposed wider.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Conversation history (in-memory per session, keyed by session_id).
# Lost on restart; entries are only removed via /reset, so memory grows with
# the number of distinct session_ids — fine for a single-user dashboard.
conversations: dict[str, list] = {}
MAX_HISTORY = 20  # Keep last N messages per session

# Anthropic client state, populated lazily by get_client() on first use.
api_key = None
model = None
client = None


def get_client():
    """Return the shared Anthropic client, creating it on first use.

    Lazily resolves the API key and model name and caches everything in
    module globals so all requests reuse one client instance.
    """
    global api_key, model, client
    if client is not None:
        return client
    api_key = load_api_key()
    model = load_model()
    client = anthropic.Anthropic(api_key=api_key)
    log.info(f"Anthropic client initialized (model: {model})")
    return client


@app.get("/health")
async def health_check():
    """Liveness probe: reports model name and whether a profile exists."""
    return {
        "status": "ok",
        "model": load_model(),
        "has_profile": (UNIFIED_DIR / "profile.json").exists(),
        "timestamp": datetime.now().isoformat(),
    }


@app.post("/chat")
async def chat(request: Request):
    """Chat endpoint — accepts a message, returns an SSE-streamed Claude reply.

    Request JSON: {"message": str, "session_id": str (optional, defaults to
    "default")}. Emits server-sent events with payloads of shape
    {"type": "text"|"done"|"error", ...}; returns 400 on an empty message.
    """
    body = await request.json()
    user_message = body.get("message", "").strip()
    session_id = body.get("session_id", "default")

    if not user_message:
        return JSONResponse({"error": "Empty message"}, status_code=400)

    # Get or create conversation history for this session.
    history = conversations.setdefault(session_id, [])

    # Add user message to history.
    history.append({"role": "user", "content": user_message})

    # Trim history if too long. BUGFIX: after trimming, also drop any leading
    # non-user turns — slicing an even-sized window (MAX_HISTORY=20) off an
    # odd-length list (full exchanges + the new user message) would otherwise
    # put an assistant message first, which the Anthropic Messages API rejects
    # (the first message must use the "user" role).
    if len(history) > MAX_HISTORY:
        history = history[-MAX_HISTORY:]
        while history and history[0]["role"] != "user":
            history.pop(0)
        conversations[session_id] = history

    log.info(f"Chat [{session_id}]: {user_message[:80]}...")

    # Build system prompt fresh each time (data may have changed).
    system_prompt = build_system_prompt()

    async def generate():
        """Stream the response token by token as SSE frames."""
        # NOTE(review): this uses the synchronous SDK stream inside an async
        # generator, which blocks the event loop while tokens arrive —
        # tolerable for a single-user server; consider AsyncAnthropic.
        c = get_client()
        full_response = ""
        try:
            with c.messages.stream(
                model=model,
                max_tokens=2048,
                system=system_prompt,
                messages=history,
            ) as stream:
                for text in stream.text_stream:
                    full_response += text
                    yield f"data: {json.dumps({'type': 'text', 'text': text})}\n\n"

            # Record the assistant reply so the next turn has full context.
            history.append({"role": "assistant", "content": full_response})
            yield f"data: {json.dumps({'type': 'done'})}\n\n"

        except Exception as e:
            log.error(f"Stream error: {e}")
            yield f"data: {json.dumps({'type': 'error', 'text': str(e)})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable nginx/proxy buffering so SSE flushes promptly
        },
    )


@app.post("/chat/sync")
async def chat_sync(request: Request):
    """Non-streaming chat endpoint for simpler clients.

    Request JSON: {"message": str, "session_id": str (optional)}.
    Returns {"response": str, "session_id": str}, 400 on an empty message,
    or 500 with {"error": str} if the API call fails.
    """
    body = await request.json()
    user_message = body.get("message", "").strip()
    session_id = body.get("session_id", "default")

    if not user_message:
        return JSONResponse({"error": "Empty message"}, status_code=400)

    history = conversations.setdefault(session_id, [])
    history.append({"role": "user", "content": user_message})

    # Trim history if too long. BUGFIX: after trimming, also drop any leading
    # non-user turns — slicing an even-sized window (MAX_HISTORY=20) off an
    # odd-length list (full exchanges + the new user message) would otherwise
    # put an assistant message first, which the Anthropic Messages API rejects
    # (the first message must use the "user" role).
    if len(history) > MAX_HISTORY:
        history = history[-MAX_HISTORY:]
        while history and history[0]["role"] != "user":
            history.pop(0)
        conversations[session_id] = history

    # Rebuild the system prompt per request so it reflects fresh data.
    system_prompt = build_system_prompt()
    c = get_client()

    try:
        response = c.messages.create(
            model=model,
            max_tokens=2048,
            system=system_prompt,
            messages=history,
        )
        assistant_text = response.content[0].text
        # Record the assistant reply so the next turn has full context.
        history.append({"role": "assistant", "content": assistant_text})
        return {"response": assistant_text, "session_id": session_id}
    except Exception as e:
        log.error(f"Chat error: {e}")
        return JSONResponse({"error": str(e)}, status_code=500)


@app.post("/reset")
async def reset_conversation(request: Request):
    """Drop the stored conversation history for the given session (no-op if absent)."""
    payload = await request.json()
    sid = payload.get("session_id", "default")
    conversations.pop(sid, None)
    return {"status": "ok", "message": "Conversation reset"}


@app.get("/context")
async def get_context():
    """Debug endpoint — returns the exact health context injected for Claude."""
    health_context = build_health_context()
    return {"context": health_context}


# ── Entry Point ──────────────────────────────────────────────

def main():
    """Parse CLI flags (--host/--port) and launch the uvicorn server."""
    arg_parser = argparse.ArgumentParser(description="Kitzu AI Server")
    arg_parser.add_argument("--port", type=int, default=8081, help="Port to run on")
    arg_parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    opts = arg_parser.parse_args()

    # Startup banner with the URLs relevant on the Tailscale network.
    print(f"\n🧠 Kitzu AI Server starting on {opts.host}:{opts.port}")
    print(f"   Model: {load_model()}")
    print(f"   Dashboard: http://100.91.208.3:8080/kitzu/dashboard.html")
    print(f"   Chat API:  http://100.91.208.3:{opts.port}/chat")
    print(f"   Health:    http://localhost:{opts.port}/health\n")

    uvicorn.run(app, host=opts.host, port=opts.port, log_level="info")


if __name__ == "__main__":
    main()
