import json
import logging
import re

from app.providers.provider_manager import run_cascade_interpretation
from app.utils.prompts import (
    INTERPRETATION_RETRY_TEMPLATE,
    INTERPRETATION_USER_TEMPLATE,
    INTERPRETATION_VARIETY_RETRY_TEMPLATE,
    build_interpretation_system,
)
logger = logging.getLogger("lab_analyzer")

MIN_EXPLANATION_LEN_NORMAL = 25
MIN_EXPLANATION_LEN_ABNORMAL = 30
MAX_EXPLANATION_LEN = 280
MAX_SENTENCES = 1
MAX_EXPLANATION_WORDS = 32

# Zero-tolerance banned phrases (anywhere in explanation)
BANNED_PHRASES = (
    "هذا التحليل",
    "هذا اختبار",
    "ونتيجة هذا التحليل",
    "مما يشير إلى",
    "هذا اختبار لقياس",
    "ونتيجة الإنزيم",
    "بخصوص",
    "النتيجة دي بتقول",
    "لو بصينا على",
)

BANNED_OPENERS = BANNED_PHRASES

# Generic copy-paste advice — reject if reused across tests
GENERIC_ADVICE_PHRASES = (
    "تناول نظام غذائي صحي وممارسة الرياضة",
    "نظام غذائي صحي وممارسة الرياضة",
    "ممارسة الرياضة بانتظام",
    "اتباع نظام غذائي صحي",
    "حافظ على نمط حياتك الصحي",
    "متابعة روتينية مع الدكتور",
    "منع المقليات والدهون تماماً",
    "تقليل المقليات والدهون",
)

# Wrong medical meaning for HbA1c
HBA1C_WRONG_PHRASES = (
    "نقل الأكسجين",
    "توصيل الأكسجين",
    "هيموجلوبين",
)

# overall_summary must not claim "all normal" when abnormal tests exist
ALL_NORMAL_SUMMARY_PHRASES = (
    "كل التحاليل طبيعية",
    "جميع التحاليل طبيعية",
    "كل النتائج طبيعية",
    "كلها طبيعية",
    "كل شيء طبيعي",
    "كل شي طبيعي",
    "مفيش أي مشكلة",
    "لا توجد مشاكل",
)

TEMPLATE_PHRASES = BANNED_OPENERS + (
    "نتيجتك طبيعية وده مطمئن",
    "مفيش إجراء عاجل",
    "متابعة روتينية مع الدكتور",
    "متابعة روتينية",
    "متابعة عادية",
    "ومفيش حاجة مقلقة",
)

BODY_HINTS: list[tuple[str, str, str]] = [
    ("haemoglobin", "مادة الدم اللي بتوصل الأكسجين", "الدم"),
    ("hemoglobin", "مادة الدم اللي بتوصل الأكسجين", "الدم"),
    ("rbc", "عدد كرات الدم الحمراء", "الدم"),
    ("mcv", "حجم كرية الدم الحمراء", "الدم"),
    ("mch", "كمية الأكسجين في كرية الدم الحمراء", "الدم"),
    ("mchc", "تركيز الأكسجين في كرية الدم", "الدم"),
    ("platelet", "الصفائح الدموية", "الدم"),
    ("haematocrit", "نسبة كرات الدم الحمراء", "الدم"),
    ("pcv", "نسبة كرات الدم الحمراء", "الدم"),
    ("rdw", "تفاوت أحجام كرات الدم الحمراء", "الدم"),
    ("mpv", "حجم الصفائح", "الدم"),
    ("pct", "نسبة الصفائح في الدم", "الدم"),
    ("pdw", "تفاوت أحجام الصفائح", "الدم"),
    ("wbc", "خلايا الدم البيضاء", "المناعة"),
    ("leucocytic", "خلايا الدم البيضاء", "المناعة"),
    ("neutrophil", "خلايا دفاع الجسم ضد البكتيريا", "المناعة"),
    ("lymphocyte", "خلايا دفاع الجسم ضد الفيروسات", "المناعة"),
    ("monocyte", "خلايا التنظيف في الدم", "المناعة"),
    ("eosinophil", "خلايا الحساسية", "المناعة"),
    ("basophil", "خلايا التحسس", "المناعة"),
    ("sgpt", "إنزيم الكبد ALT", "الكبد"),
    ("alt", "إنزيم الكبد ALT", "الكبد"),
    ("alkaline", "إنزيم الفوسفاتيز القلوي", "الكبد"),
    ("phosphatase", "إنزيم الفوسفاتيز القلوي", "الكبد"),
    ("asot", "أجسام مضادة للميكروب السبحي", "العدوى"),
    ("creatinine", "وظيفة الكلى", "الكلى"),
    ("glucose", "سكر الدم", "البنكرياس"),
    ("hba1c", "معدل السكر التراكمي في آخر 3 شهور", "البنكرياس"),
    ("a1c", "معدل السكر التراكمي في آخر 3 شهور", "البنكرياس"),
    ("hemoglobin a1c", "معدل السكر التراكمي في آخر 3 شهور", "البنكرياس"),
    ("haemoglobin a1c", "معدل السكر التراكمي في آخر 3 شهور", "البنكرياس"),
    ("tsh", "الغدة الدرقية", "الغدة الدرقية"),
    ("esr", "معدل ترسيب كرات الدم", "الالتهاب"),
    ("crp", "بروتين التهاب", "الالتهاب"),
    ("calcium", "معدن الكالسيوم", "العظام"),
    ("phosphor", "معدن الفوسفور", "العظام"),
    ("phosphorus", "معدن الفوسفور", "العظام"),
]


def _english_test_name(test: dict) -> str:
    """Keep original report name in English."""
    return str(test.get("test_name") or test.get("canonical_name") or "").strip()


def _body_hints(test: dict) -> tuple[str, str]:
    name = _english_test_name(test).lower()
    for token, plain, body in BODY_HINTS:
        if token in name:
            return plain, body
    return "", "الجسم"


def _primary_status(test: dict) -> str:
    if test.get("status") in ("low", "high", "normal"):
        return test["status"]
    if test.get("absolute_value"):
        return test.get("status_absolute", "unknown")
    return test.get("status_percentage", "unknown")


def explanation_style(test: dict) -> str:
    return "abnormal" if _primary_status(test) in ("low", "high") else "normal"


def _status_label_ar(status: str) -> str:
    return {"normal": "طبيعي", "high": "مرتفع", "low": "منخفض"}.get(status, "غير محدد")


def _test_is_abnormal(test: dict) -> bool:
    for key in ("status", "status_percentage", "status_absolute"):
        if str(test.get(key, "")).lower() in ("high", "low"):
            return True
    return _primary_status(test) in ("high", "low")


def _has_abnormal_tests(tests: list[dict]) -> bool:
    return any(_test_is_abnormal(t) for t in tests)


def _summary_claims_all_normal(summary: str) -> bool:
    lower = summary.strip().lower()
    if not lower:
        return False
    return any(phrase in lower for phrase in ALL_NORMAL_SUMMARY_PHRASES)


def _fix_overall_summary(summary: str, validated_tests: list[dict], language: str) -> str:
    """Ensure summary does not deny abnormal results when any exist."""
    if language != "ar" or not _has_abnormal_tests(validated_tests):
        return summary
    if not _summary_claims_all_normal(summary):
        return summary

    abnormal_names = [
        _english_test_name(t)
        for t in validated_tests
        if _test_is_abnormal(t)
    ]
    names_text = "، ".join(abnormal_names[:6])
    logger.warning(
        "[INTERPRETATION] overall_summary claimed all normal but %d abnormal tests — using fallback",
        len(abnormal_names),
    )
    return (
        f"التقرير يظهر نتائج تحتاج متابعة طبية، منها: {names_text}. "
        "يُنصح بتنظيم الغذاء وتقليل الدهون والتوتر مع عرض التقرير على طبيب باطنة أو التخصص المناسب."
    )


def _sentence_count(text: str) -> int:
    parts = re.split(r"[.!?؟]\s*", text.strip())
    return len([p for p in parts if p.strip()])


def _arabic_word_count(text: str) -> int:
    return len(re.findall(r"[\w\u0600-\u06FF]+", text))


def _contains_banned_phrase(text: str) -> bool:
    t = text.strip()
    if not t:
        return True
    for phrase in BANNED_PHRASES:
        if phrase in t:
            return True
    return False


def _has_banned_opener(text: str) -> bool:
    return _contains_banned_phrase(text)


def _uses_generic_advice(text: str) -> bool:
    return any(phrase in text for phrase in GENERIC_ADVICE_PHRASES)


def _hba1c_wrong_explanation(test: dict, explanation: str) -> bool:
    name = _english_test_name(test).lower()
    if "a1c" not in name and "hba1c" not in name:
        return False
    return any(phrase in explanation for phrase in HBA1C_WRONG_PHRASES)


def _find_generic_advice_indexes(explanations: list[str]) -> set[int]:
    """Flag explanations sharing the same generic diet/exercise template."""
    phrase_to_indexes: dict[str, list[int]] = {}
    for i, expl in enumerate(explanations):
        for phrase in GENERIC_ADVICE_PHRASES:
            if phrase in expl:
                phrase_to_indexes.setdefault(phrase, []).append(i)
    bad: set[int] = set()
    for indexes in phrase_to_indexes.values():
        if len(indexes) >= 2:
            bad.update(indexes)
    return bad


def _test_for_interpretation(test: dict, _index: int = 0) -> dict:
    plain, body = _body_hints(test)
    status = _primary_status(test)
    row: dict = {
        "test_name": _english_test_name(test),
        "primary_status": status,
        "status_label_ar": _status_label_ar(status),
        "body_area": body,
    }
    if plain:
        row["simple_description_ar"] = plain
        row["medical_class_hint"] = plain
    for key in ("result", "unit", "normal_range", "percentage_value", "percentage_unit", "percentage_range"):
        val = test.get(key)
        if val is not None and str(val).strip():
            row[key] = str(val).strip()
    return row


def _looks_like_template(text: str) -> bool:
    t = text.strip()
    if not t:
        return True
    hits = sum(1 for p in TEMPLATE_PHRASES if p in t)
    if hits >= 2:
        return True
    if t.startswith("التحليل ده بيخص") and "مطمئن" in t:
        return True
    if t.startswith("التحليل ده بيخص") and "متابعة روتينية" in t:
        return True
    return False


def _word_overlap_ratio(a: str, b: str) -> float:
    wa = set(re.findall(r"\w+", a.lower()))
    wb = set(re.findall(r"\w+", b.lower()))
    if not wa or not wb:
        return 0.0
    return len(wa & wb) / len(wa | wb)


def _find_repetitive_indexes(explanations: list[str], threshold: float = 0.52) -> set[int]:
    """Flag explanations too similar to another (template-like batch)."""
    repetitive: set[int] = set()
    for i, expl_i in enumerate(explanations):
        if _looks_like_template(expl_i):
            repetitive.add(i)
        for j in range(i + 1, len(explanations)):
            if _word_overlap_ratio(expl_i, explanations[j]) >= threshold:
                repetitive.add(i)
                repetitive.add(j)
    return repetitive


def _explanation_acceptable(
    explanation: str,
    style: str,
    test: dict | None = None,
) -> bool:
    text = (explanation or "").strip()
    min_len = MIN_EXPLANATION_LEN_ABNORMAL if style == "abnormal" else MIN_EXPLANATION_LEN_NORMAL
    if len(text) < min_len or len(text) > MAX_EXPLANATION_LEN:
        return False
    if _contains_banned_phrase(text) or _looks_like_template(text):
        return False
    if _uses_generic_advice(text):
        return False
    if _sentence_count(text) > MAX_SENTENCES:
        return False
    if _arabic_word_count(text) > MAX_EXPLANATION_WORDS:
        return False
    if text.count("،") + text.count(".") < 1:
        return False
    if test and _hba1c_wrong_explanation(test, text):
        return False
    return True


def _summary_acceptable(summary: str) -> bool:
    text = (summary or "").strip()
    if not text or len(text) > 400:
        return False
    if _contains_banned_phrase(text):
        return False
    return _sentence_count(text) <= 2


def _build_interp_lookup(ai_tests: list) -> dict[str, str]:
    """Map lowered test_name -> explanation from AI."""
    lookup: dict[str, str] = {}
    for item in ai_tests:
        if not isinstance(item, dict):
            continue
        key = str(item.get("test_name", "")).strip().lower()
        expl = str(item.get("explanation", "")).strip()
        if key and expl:
            lookup[key] = expl
    return lookup


def _match_explanation(vt: dict, lookup: dict[str, str]) -> str:
    keys = [
        _english_test_name(vt).lower(),
        str(vt.get("canonical_name", "")).lower(),
    ]
    for key in keys:
        if key and key in lookup:
            return lookup[key]
    # partial
    base = keys[0]
    for stored, expl in lookup.items():
        if base and (base in stored or stored in base):
            return expl
    return ""


def ensure_complete_interpretations(
    parsed: dict,
    validated_tests: list[dict],
) -> tuple[list[dict], list[dict]]:
    ai_tests = parsed.get("tests", [])
    if not isinstance(ai_tests, list):
        ai_tests = []

    lookup = _build_interp_lookup(ai_tests)
    ordered: list[dict] = []
    need_retry: list[dict] = []

    for vt in validated_tests:
        eng_name = _english_test_name(vt)
        style = explanation_style(vt)
        expl = _match_explanation(vt, lookup)

        ordered.append({
            "test_name": eng_name,
            "explanation": expl,
            "_style": style,
        })

    expl_texts = [t["explanation"] for t in ordered]
    bad_indexes = _find_repetitive_indexes(expl_texts)
    bad_indexes |= _find_generic_advice_indexes(expl_texts)

    for i, item in enumerate(ordered):
        style = item.pop("_style", "normal")
        expl = item["explanation"]
        source_test = validated_tests[i] if i < len(validated_tests) else {}
        if (
            not expl
            or not _explanation_acceptable(expl, style, source_test)
            or i in bad_indexes
        ):
            need_retry.append({"test_name": item["test_name"], "_style": style, "_index": i})

    parsed["tests"] = [{"test_name": t["test_name"], "explanation": t["explanation"]} for t in ordered]
    return parsed, need_retry


async def _retry_tests(
    to_retry: list[dict],
    validated_tests: list[dict],
    language: str,
    variety: bool = False,
) -> dict[str, str]:
    if not to_retry:
        return {}

    full_payload = {
        "language": language,
        "tests": [_test_for_interpretation(t, i) for i, t in enumerate(validated_tests)],
    }
    retry_list = json.dumps(
        [{"test_name": r["test_name"], "primary_status": r.get("_style", "normal")} for r in to_retry],
        ensure_ascii=False,
        separators=(",", ":"),
    )

    system_prompt = build_interpretation_system(language)
    if variety:
        compact_ctx = json.dumps(full_payload, ensure_ascii=False, separators=(",", ":"))
        user_prompt = INTERPRETATION_VARIETY_RETRY_TEMPLATE.format(
            tests_to_rewrite=retry_list,
            json_payload=compact_ctx,
        )
    else:
        user_prompt = INTERPRETATION_RETRY_TEMPLATE.format(
            missing_tests=retry_list,
            json_payload=json.dumps(full_payload, ensure_ascii=False, separators=(",", ":")),
        )

    try:
        _, parsed = await run_cascade_interpretation(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
        )
        out: dict[str, str] = {}
        for item in parsed.get("tests", []) or []:
            if isinstance(item, dict):
                key = str(item.get("test_name", "")).strip().lower()
                expl = str(item.get("explanation", "")).strip()
                if key and expl:
                    out[key] = expl
        return out
    except Exception as exc:
        logger.warning("[INTERPRETATION] Retry failed: %s", exc)
        return {}


async def interpret_results(
    validated_tests: list[dict],
    language: str,
) -> tuple[list[dict], str]:
    if not validated_tests:
        return [], "No test results were available to interpret."

    payload = {
        "language": language,
        "tests": [_test_for_interpretation(t, i) for i, t in enumerate(validated_tests)],
    }
    compact_payload = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))

    system_prompt = build_interpretation_system(language)
    user_prompt = INTERPRETATION_USER_TEMPLATE.format(json_payload=compact_payload)

    try:
        _, parsed = await run_cascade_interpretation(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
        )
    except Exception as exc:
        logger.exception("[INTERPRETATION] Failed: %s", exc)
        return [], "Interpretation could not be generated."

    parsed, need_retry = ensure_complete_interpretations(parsed, validated_tests)

    if need_retry:
        logger.info("[INTERPRETATION] Retry for %d tests (banned/generic/invalid)", len(need_retry))
        fixes = await _retry_tests(need_retry, validated_tests, language, variety=True)
        for item in parsed["tests"]:
            key = str(item.get("test_name", "")).lower()
            if key in fixes and fixes[key]:
                item["explanation"] = fixes[key]

    tests_out = parsed.get("tests", [])
    summary = str(parsed.get("overall_summary", "")).strip()
    if not summary or not _summary_acceptable(summary):
        if language == "ar" and _has_abnormal_tests(validated_tests):
            abnormal_names = [
                _english_test_name(t) for t in validated_tests if _test_is_abnormal(t)
            ]
            names = "، ".join(abnormal_names[:5])
            summary = (
                f"التقرير يظهر نتائج غير طبيعية تحتاج متابعة، منها: {names}؛ "
                "يُنصح بزيارة طبيب باطنة وتقليل الدهون والراحة."
            )
        elif not summary:
            summary = (
                "التقرير ده بيلخص نتائج التحاليل. ناقشها مع دكتورك."
                if language == "ar"
                else "Discuss with your doctor."
            )
    else:
        summary = _fix_overall_summary(summary, validated_tests, language)

    return tests_out, summary
