import json
import logging
import datetime
from typing import Optional

logger = logging.getLogger(__name__)

# Character budget for the string returned by assemble_context(); longer
# output is passed through _smart_truncate().
MAX_CONTEXT_CHARS = 12000
# Maximum number of entries kept in a project's event_log; log_context_event()
# drops the oldest entries beyond this count.
MAX_EVENT_LOG_ENTRIES = 50


def assemble_context(
    domain: str,
    global_context: Optional[dict] = None,
    project_context: Optional[dict] = None,
    profile_config: Optional[dict] = None,
    discovery_answers: Optional[dict] = None,
    brandkit_summary: Optional[dict] = None,
    niche_data: Optional[dict] = None,
) -> str:
    """Build the full prompt-context string for ``domain``.

    Each non-empty input is rendered by its dedicated ``_build_*`` helper, in
    this fixed order: global, project, profile, niche, discovery, brand kit.
    Non-empty renderings are joined with a blank line; if the result exceeds
    ``MAX_CONTEXT_CHARS`` it is shortened by ``_smart_truncate``.

    Returns:
        The assembled context, or ``""`` when nothing renders.
    """
    # (payload, renderer) pairs in output order. Falsy payloads are skipped
    # before their renderer is called.
    sources = (
        (global_context, _build_global_context_block),
        (project_context, lambda pc: _build_project_context_block(pc, domain)),
        (profile_config, _build_profile_context_block),
        (niche_data, _build_niche_context_block),
        (discovery_answers, _build_discovery_context_block),
        (brandkit_summary, _build_brandkit_context_block),
    )

    rendered = []
    for payload, render in sources:
        if not payload:
            continue
        text = render(payload)
        if text:
            rendered.append(text)

    combined = "\n\n".join(rendered)
    if len(combined) <= MAX_CONTEXT_CHARS:
        return combined
    return _smart_truncate(combined, MAX_CONTEXT_CHARS)


def _build_global_context_block(gc: dict) -> str:
    parts = ["--- GLOBAL CONTEXT (Master Rules) ---"]
    if gc.get("master_rules"):
        parts.append(f"RULES:\n{gc['master_rules']}")
    if gc.get("style_prefs"):
        prefs = gc["style_prefs"]
        if isinstance(prefs, dict):
            pref_lines = [f"  - {k}: {v}" for k, v in prefs.items() if v]
            if pref_lines:
                parts.append("STYLE PREFERENCES:\n" + "\n".join(pref_lines))
    if gc.get("guardrails"):
        rails = gc["guardrails"]
        if isinstance(rails, dict):
            rail_lines = [f"  - {k}: {v}" for k, v in rails.items() if v]
            if rail_lines:
                parts.append("GUARDRAILS:\n" + "\n".join(rail_lines))
        elif isinstance(rails, list):
            parts.append("GUARDRAILS:\n" + "\n".join(f"  - {r}" for r in rails))
    return "\n".join(parts) if len(parts) > 1 else ""


def _build_project_context_block(pc: dict, domain: str) -> str:
    parts = [f"--- PROJECT CONTEXT ({domain}) ---"]
    state = pc.get("context_state", {})
    if not state and not pc.get("event_log"):
        return ""

    if state:
        if state.get("selected_niche"):
            parts.append(f"CHOSEN NICHE: {state['selected_niche']}")
        if state.get("template_type"):
            parts.append(f"TEMPLATE: {state['template_type']}")
        if state.get("depth"):
            parts.append(f"DEPTH: {state['depth']}")
        if state.get("profile_slug"):
            parts.append(f"PROFILE: {state['profile_slug']}")
        if state.get("brand_personality"):
            parts.append(f"BRAND PERSONALITY: {state['brand_personality']}")
        if state.get("key_decisions"):
            parts.append("KEY DECISIONS:")
            for d in state["key_decisions"]:
                parts.append(f"  - {d}")
        if state.get("content_direction"):
            parts.append(f"CONTENT DIRECTION: {state['content_direction']}")
        if state.get("prior_feedback"):
            parts.append("PRIOR FEEDBACK (user edits/refinements):")
            for fb in state["prior_feedback"][-5:]:
                parts.append(f"  - {fb}")

    event_log = pc.get("event_log", [])
    if event_log:
        recent = event_log[-10:]
        parts.append("RECENT PROJECT EVENTS:")
        for evt in recent:
            ts = evt.get("timestamp", "")
            action = evt.get("action", "unknown")
            detail = evt.get("detail", "")
            parts.append(f"  [{ts}] {action}: {detail}")

    return "\n".join(parts)


def _build_profile_context_block(config: dict) -> str:
    parts = ["--- ACTIVE SITE PROFILE ---"]
    overrides = config.get("prompt_overrides", {})
    if overrides.get("global_tone"):
        parts.append(f"TONE DIRECTIVE: {overrides['global_tone']}")
    visual = config.get("visual_defaults", {})
    if visual:
        vis_items = [f"{k}: {v}" for k, v in visual.items() if v]
        if vis_items:
            parts.append("VISUAL DEFAULTS: " + ", ".join(vis_items))
    sections = config.get("enabled_sections", config.get("sections", []))
    if sections:
        parts.append(f"ENABLED SECTIONS: {', '.join(sections)}")
    return "\n".join(parts) if len(parts) > 1 else ""


def _build_niche_context_block(niche: dict) -> str:
    parts = ["--- NICHE INTELLIGENCE ---"]
    if niche.get("name"):
        parts.append(f"NICHE: {niche['name']}")
    if niche.get("description"):
        parts.append(f"DESCRIPTION: {niche['description']}")
    if niche.get("target_audience"):
        parts.append(f"TARGET AUDIENCE: {niche['target_audience']}")
    if niche.get("monetization_model"):
        parts.append(f"MONETIZATION: {niche['monetization_model']}")
    if niche.get("affiliate_programs"):
        programs = niche["affiliate_programs"]
        if isinstance(programs, list):
            parts.append(f"AFFILIATE PROGRAMS: {', '.join(programs)}")
    return "\n".join(parts) if len(parts) > 1 else ""


def _build_discovery_context_block(answers: dict) -> str:
    if not answers:
        return ""
    parts = ["--- DISCOVERY ANSWERS (Client-Provided) ---"]
    field_labels = {
        "project_story": "PROJECT VISION",
        "success_vision": "SUCCESS DEFINITION",
        "one_thing": "CORE PURPOSE",
        "ideal_audience": "IDEAL VISITOR",
        "audience_journey": "AUDIENCE JOURNEY",
        "primary_action": "PRIMARY ACTION",
        "brand_personality": "BRAND PERSONALITY",
        "color_palette": "COLOR PALETTE",
        "visual_style": "VISUAL STYLE",
        "core_values": "CORE VALUES",
        "desired_feeling": "DESIRED FEELING",
        "key_features": "KEY FEATURES",
        "sites_loved": "INSPIRATION SITES",
        "future_vision": "FUTURE VISION",
        "additional_notes": "ADDITIONAL NOTES",
    }
    for key, value in answers.items():
        if value and str(value).strip():
            label = field_labels.get(key, key.upper().replace("_", " "))
            parts.append(f"{label}: {value}")
    return "\n".join(parts) if len(parts) > 1 else ""


def _build_brandkit_context_block(summary: dict) -> str:
    if not summary:
        return ""
    parts = ["--- BRAND KIT INTELLIGENCE ---"]
    if summary.get("tone"):
        parts.append(f"DETECTED TONE: {summary['tone']}")
    if summary.get("keywords"):
        kw = summary["keywords"]
        if isinstance(kw, list):
            parts.append(f"BRAND KEYWORDS: {', '.join(kw)}")
    if summary.get("visual_motifs"):
        parts.append(f"VISUAL MOTIFS: {summary['visual_motifs']}")
    if summary.get("brand_personality"):
        parts.append(f"BRAND PERSONALITY: {summary['brand_personality']}")
    if summary.get("content_strength"):
        parts.append(f"CONTENT STRENGTH: {summary['content_strength']}")
    return "\n".join(parts) if len(parts) > 1 else ""


def _smart_truncate(text: str, max_chars: int) -> str:
    if len(text) <= max_chars:
        return text
    blocks = text.split("\n\n")
    priority = []
    normal = []
    low = []
    for block in blocks:
        if "GLOBAL CONTEXT" in block or "DISCOVERY ANSWERS" in block:
            priority.append(block)
        elif "PROJECT CONTEXT" in block or "NICHE INTELLIGENCE" in block:
            normal.append(block)
        else:
            low.append(block)

    result_blocks = priority + normal + low
    result = ""
    for block in result_blocks:
        if len(result) + len(block) + 2 > max_chars:
            remaining = max_chars - len(result) - 20
            if remaining > 100:
                result += "\n\n" + block[:remaining] + "\n[...truncated]"
            break
        result += ("\n\n" if result else "") + block
    return result


def log_context_event(domain: str, action: str, detail: str, db_session=None):
    """Append a timestamped event to a domain's project event log and commit.

    Does nothing when ``db_session`` is not provided. Creates the
    ``ProjectContext`` row for ``domain`` if it does not exist yet.

    Args:
        domain: Project domain whose log is appended to.
        action: Short event name.
        detail: Free-form description, stored truncated to 500 characters.
            ``None`` is stored as ``""``.
        db_session: Session used to query, add and commit the row.
    """
    if not db_session:
        return
    from app.models import ProjectContext

    ctx = db_session.query(ProjectContext).filter(ProjectContext.domain == domain).first()
    if not ctx:
        ctx = ProjectContext(domain=domain, context_state={}, event_log=[])
        db_session.add(ctx)

    # Work on a copy rather than mutating the loaded list in place.
    # flag_modified below still forces the column to be written.
    event_log = list(ctx.event_log or [])
    # datetime.utcnow() is deprecated since Python 3.12. Derive the same
    # "YYYY-MM-DDTHH:MM:SS[.ffffff]Z" string from an aware UTC timestamp.
    timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat().replace("+00:00", "Z")
    event_log.append({
        "timestamp": timestamp,
        "action": action,
        "detail": ("" if detail is None else str(detail))[:500],
    })
    # Keep only the most recent MAX_EVENT_LOG_ENTRIES events.
    ctx.event_log = event_log[-MAX_EVENT_LOG_ENTRIES:]

    from sqlalchemy.orm.attributes import flag_modified
    flag_modified(ctx, "event_log")
    db_session.commit()


def update_project_state(domain: str, updates: dict, db_session=None):
    """Merge ``updates`` into a domain's ``context_state`` and commit.

    Does nothing when ``db_session`` is not provided. Creates the
    ``ProjectContext`` row for ``domain`` if it does not exist yet. Keys in
    ``updates`` overwrite existing keys (``dict.update`` semantics).
    """
    if not db_session:
        return
    from app.models import ProjectContext

    record = (
        db_session.query(ProjectContext)
        .filter(ProjectContext.domain == domain)
        .first()
    )
    if not record:
        record = ProjectContext(domain=domain, context_state={}, event_log=[])
        db_session.add(record)

    merged = record.context_state or {}
    merged.update(updates)
    record.context_state = merged

    # Explicitly mark the column dirty so the in-place change is persisted.
    from sqlalchemy.orm.attributes import flag_modified
    flag_modified(record, "context_state")
    db_session.commit()


def get_full_context_for_domain(domain: str, db_session) -> str:
    """Load every context source for ``domain`` and assemble it.

    It queries, in this order:

    1. the first ``GlobalContext`` row;
    2. the domain's ``ProjectContext``;
    3. the ``SiteProfile`` named by the project's ``profile_slug``, if any;
    4. a ``BrandKit`` for the domain with status ``"ready"``.

    Discovery answers and niche data are read from the project's
    ``context_state``. Everything is passed to ``assemble_context``.
    """
    from app.models import ProjectContext, GlobalContext, BrandKit, SiteProfile

    global_row = db_session.query(GlobalContext).first()
    global_ctx = (
        {
            "master_rules": global_row.master_rules,
            "style_prefs": global_row.style_prefs,
            "guardrails": global_row.guardrails,
        }
        if global_row
        else None
    )

    project_row = db_session.query(ProjectContext).filter(ProjectContext.domain == domain).first()
    project_ctx = None
    state = {}
    if project_row:
        state = project_row.context_state or {}
        project_ctx = {
            "context_state": state,
            "event_log": project_row.event_log or [],
        }

    profile_config = None
    slug = state.get("profile_slug")
    if slug:
        profile = db_session.query(SiteProfile).filter(SiteProfile.slug == slug).first()
        if profile:
            profile_config = profile.config

    kit = db_session.query(BrandKit).filter(BrandKit.domain == domain, BrandKit.status == "ready").first()
    brandkit_summary = kit.summary if kit and kit.summary else None

    return assemble_context(
        domain=domain,
        global_context=global_ctx,
        project_context=project_ctx,
        profile_config=profile_config,
        discovery_answers=state.get("discovery_answers") or None,
        brandkit_summary=brandkit_summary,
        niche_data=state.get("niche_data") or None,
    )
