import json
import uuid
import datetime
import logging
import concurrent.futures
import threading

logger = logging.getLogger(__name__)

# Maximum package builds in flight at once. Kept lower than the analysis
# batch concurrency since builds are heavier.
BATCH_BUILD_CONCURRENCY = 3

# Ordered step definitions for a single package build. est_seconds is a rough
# per-step duration hint; not referenced elsewhere in this module, so it is
# presumably consumed by UI/progress code elsewhere -- TODO confirm.
BATCH_BUILD_STEPS = [
    {"key": "init", "label": "Preparing build", "description": "Loading domain analysis and selecting niche", "est_seconds": 2},
    {"key": "brand", "label": "Creating brand identity", "description": "Designing brand names, taglines, and color palettes", "est_seconds": 20},
    {"key": "copy", "label": "Writing website copy", "description": "Generating all site section content", "est_seconds": 15},
    {"key": "sales", "label": "Writing sales letter", "description": "Crafting marketplace-ready sales letter", "est_seconds": 15},
    {"key": "image", "label": "Generating hero image", "description": "Creating custom hero banner with DALL-E", "est_seconds": 20},
    {"key": "calculators", "label": "Generating calculators", "description": "Building 4 interactive niche calculators", "est_seconds": 30},
    {"key": "reference", "label": "Building reference library", "description": "Researching and compiling niche encyclopedia", "est_seconds": 45},
    {"key": "assets", "label": "Generating luxury assets", "description": "Creating icons, diagrams, and visual assets", "est_seconds": 25},
    {"key": "quality", "label": "Quality gate check", "description": "Validating luxury quality standards", "est_seconds": 5},
    {"key": "saving", "label": "Saving package", "description": "Writing all assets to database", "est_seconds": 3},
    {"key": "complete", "label": "Package complete", "description": "Luxury business-in-a-box ready", "est_seconds": 0},
]


def select_best_niche(analysis: dict) -> tuple:
    """Pick the highest-scoring niche from a domain analysis dict.

    Args:
        analysis: Analysis payload expected to hold a "niches" list of dicts
            with "name" and "viability_score" keys.

    Returns:
        A ``(niche_name, niche_data)`` tuple for the top-scoring niche, or
        ``(None, None)`` when the analysis contains no niches. A niche with
        no score is treated as scoring 0; ties keep the earliest entry.
    """
    candidates = analysis.get("niches", [])
    if not candidates:
        return None, None

    # Linear scan for the maximum viability score (first wins on ties).
    top = None
    top_score = None
    for candidate in candidates:
        score = candidate.get("viability_score", 0)
        if top is None or score > top_score:
            top = candidate
            top_score = score

    return top.get("name", ""), top


def get_legendary_blueprint():
    """Return a blueprint dict for Legendary Protocol builds.

    Prefers the "legendary" default blueprint. When that is unavailable,
    falls back to the "comprehensive" blueprint with its depth promoted to
    "legendary". Returns whatever falsy value the lookup yields (e.g. None)
    if neither blueprint exists.
    """
    from app.services.blueprint import get_default_blueprint

    blueprint = get_default_blueprint("legendary")
    if blueprint:
        return blueprint

    # Fallback path: upgrade the comprehensive blueprint in place.
    fallback = get_default_blueprint("comprehensive")
    if fallback:
        fallback["depth"] = "legendary"
    return fallback


def run_batch_build_orchestrator(batch_id: str, domain_ids: list, config: dict = None) -> None:
    """
    Orchestrate batch package building for multiple domains.

    Runs as a long-lived worker: validates and filters the requested domains,
    then fans builds out to `batch_executor` with at most
    BATCH_BUILD_CONCURRENCY in flight, polling for completions with
    `concurrent.futures.wait` and persisting a progress snapshot via
    `update_job` after every state change. Honors the shared BatchState's
    pause/stop controls between scheduling rounds.

    Args:
        batch_id: Unique batch job ID
        domain_ids: List of domain IDs (integers) to build packages for
        config: Optional config dict with keys:
            - niche_override: str - Force specific niche for all domains
            - legendary: bool - Enable Legendary Protocol (default True)
            - skip_existing_packages: bool - Skip domains that already have packages (default False)
            - template_type: str - Template type (default "hero")
    """
    # Imported lazily, presumably to avoid a circular import with app.main
    # -- TODO confirm.
    from app.main import (
        batch_control, batch_executor, BatchState,
        create_job, update_job, run_build_job,
        SessionLocal, BUILD_STEPS
    )
    from app.models import Domain, Package, Job

    config = config or {}
    legendary = config.get("legendary", True)
    niche_override = config.get("niche_override", "")
    skip_existing = config.get("skip_existing_packages", False)
    template_type = config.get("template_type", "hero")

    # Register (or reuse) the shared pause/stop control object for this batch.
    state = batch_control.get(batch_id)
    if not state:
        state = BatchState(batch_id)
        batch_control[batch_id] = state

    batch_start_time = datetime.datetime.utcnow().isoformat()
    domain_times = []  # per-domain elapsed seconds; feeds the ETA estimate

    # --- Phase 1: validate each requested domain and choose its niche ------
    db = SessionLocal()
    try:
        domains_to_build = []
        skipped_domains = []  # entries carry a "reason" key for the UI/result

        domain_records = db.query(Domain).filter(Domain.id.in_(domain_ids)).all()
        domain_map = {d.id: d for d in domain_records}

        # Iterate in the caller's ID order (not query-result order) so the
        # build queue preserves the requested ordering.
        for did in domain_ids:
            drec = domain_map.get(did)
            if not drec:
                skipped_domains.append({"id": did, "reason": "not_found"})
                continue
            if not drec.analysis:
                skipped_domains.append({"id": did, "domain": drec.domain, "reason": "not_analyzed"})
                continue

            if skip_existing:
                existing_pkg = db.query(Package).filter(
                    Package.domain_name == drec.domain
                ).first()
                if existing_pkg:
                    skipped_domains.append({"id": did, "domain": drec.domain, "reason": "has_package"})
                    continue

            # Niche selection: an override is matched case-insensitively
            # against the domain's analyzed niches. If it doesn't match any,
            # fall back to the highest-scoring niche for this domain.
            if niche_override:
                niche_name = niche_override
                niche_data = None
                for n in drec.analysis.get("niches", []):
                    if n.get("name", "").lower() == niche_override.lower():
                        niche_data = n
                        break
                if not niche_data:
                    niche_name, niche_data = select_best_niche(drec.analysis)
            else:
                niche_name, niche_data = select_best_niche(drec.analysis)

            # select_best_niche returns (None, None) when no niches exist.
            if not niche_name:
                skipped_domains.append({"id": did, "domain": drec.domain, "reason": "no_niches"})
                continue

            domains_to_build.append({
                "id": did,
                "domain": drec.domain,
                "niche_name": niche_name,
                "niche_data": niche_data,
            })
    finally:
        db.close()

    total = len(domains_to_build)
    # Nothing buildable: record a completed (empty) batch and exit early.
    if total == 0:
        update_job(batch_id, status="completed",
                   current_step=f"No domains to build. {len(skipped_domains)} skipped.",
                   progress_pct=100,
                   result={
                       "mode": "completed",
                       "total": 0,
                       "completed": 0,
                       "failed": 0,
                       "skipped": skipped_domains,
                       "domains": [],
                   })
        state.active = False
        return

    # Per-domain status records; surfaced verbatim in the batch job result.
    domain_status = []
    for d in domains_to_build:
        domain_status.append({
            "domain": d["domain"],
            "domain_id": d["id"],
            "niche": d["niche_name"],
            "job_id": None,
            "status": "queued",
            "error": None,
            "started_at": None,
            "completed_at": None,
        })

    completed_count = 0
    failed_count = 0

    def save_batch_state(mode, step_msg=None):
        """Recompute aggregate counters and persist a batch progress snapshot.

        Also copies each running child job's current_step out of the DB so
        the batch result can show what every in-flight build is doing
        (best-effort: any DB error in that decoration is swallowed).

        Args:
            mode: Batch-level mode string ("running", "paused", "stopped",
                "completed", "failed") stored in the result payload.
            step_msg: Optional human-readable step text; defaults to a
                "<done>/<total> packages built" summary.
        """
        nonlocal completed_count, failed_count
        completed_count = sum(1 for ds in domain_status if ds["status"] == "completed")
        failed_count = sum(1 for ds in domain_status if ds["status"] == "failed")
        in_flight = sum(1 for ds in domain_status if ds["status"] == "running")
        queued = sum(1 for ds in domain_status if ds["status"] == "queued")
        done_total = completed_count + failed_count
        # max(total, 1) guards division by zero, though total > 0 here.
        pct = round((done_total / max(total, 1)) * 100)

        try:
            # Decorate running entries with their child job's current step.
            running_ds = [ds for ds in domain_status if ds["status"] == "running" and ds.get("job_id")]
            if running_ds:
                step_db = SessionLocal()
                try:
                    job_ids = [ds["job_id"] for ds in running_ds]
                    child_jobs = step_db.query(Job).filter(Job.job_id.in_(job_ids)).all()
                    step_map = {j.job_id: j.current_step for j in child_jobs if j.current_step}
                    for ds in running_ds:
                        cs = step_map.get(ds["job_id"])
                        if cs:
                            ds["current_step"] = cs
                finally:
                    step_db.close()
        except Exception:
            # Progress decoration only -- never let it break the batch.
            pass

        # Timing: ETA = average completed-build duration * builds remaining.
        now = datetime.datetime.utcnow()
        batch_start_dt = datetime.datetime.fromisoformat(batch_start_time)
        elapsed = round((now - batch_start_dt).total_seconds(), 1)
        avg_time = round(sum(domain_times) / len(domain_times), 1) if domain_times else 0
        remaining_count = total - done_total
        est_remaining = round(avg_time * remaining_count, 1) if avg_time > 0 else 0

        update_job(batch_id,
                   status="running" if mode == "running" else mode,
                   current_step=step_msg or f"{done_total}/{total} packages built ({in_flight} active, {queued} queued)",
                   progress_pct=pct if mode != "completed" else 100,
                   result={
                       "mode": mode,
                       "batch_type": "build",
                       "domains": domain_status,
                       "total": total,
                       "completed": completed_count,
                       "failed": failed_count,
                       "in_flight": in_flight,
                       "queued": queued,
                       "skipped": skipped_domains,
                       "legendary": legendary,
                       "concurrency": BATCH_BUILD_CONCURRENCY,
                       "batch_started_at": batch_start_time,
                       "elapsed_seconds": elapsed,
                       "avg_seconds_per_domain": avg_time,
                       "est_remaining_seconds": est_remaining,
                   })

    save_batch_state("running", f"Starting batch build of {total} packages (Legendary Protocol: {'ON' if legendary else 'OFF'})...")

    # One blueprint is shared by every build in the batch.
    blueprint = get_legendary_blueprint() if legendary else None

    idx = 0       # next index in domains_to_build to schedule
    futures = {}  # Future -> domain_status index, for all in-flight builds

    # --- Phase 2: schedule/poll loop ---------------------------------------
    try:
        while idx < total or futures:
            # Stop requested: mark everything not yet started as cancelled
            # and exit. Already-submitted builds are left to run out on the
            # executor; they are not awaited here.
            if state.stop_flag.is_set():
                for ds in domain_status:
                    if ds["status"] == "queued":
                        ds["status"] = "cancelled"
                save_batch_state("stopped", f"Batch stopped. {completed_count + failed_count}/{total} built.")
                break

            # Paused: publish a "paused" snapshot and re-check every 2s.
            if not state.pause_event.is_set():
                save_batch_state("paused", f"Batch paused. {completed_count + failed_count}/{total} built, {len(futures)} in flight.")
                state.pause_event.wait(timeout=2)
                continue

            # Top up the in-flight set until the concurrency cap is reached
            # or the queue is exhausted (re-checking stop/pause each time).
            while len(futures) < BATCH_BUILD_CONCURRENCY and idx < total:
                if state.stop_flag.is_set():
                    break
                if not state.pause_event.is_set():
                    break

                ds = domain_status[idx]
                domain_info = domains_to_build[idx]
                domain_name = domain_info["domain"]
                niche_name = domain_info["niche_name"]

                # Child job IDs embed a batch prefix for traceability:
                # "bb-<first 4 of batch_id>-<6 random hex-ish chars>".
                child_job_id = f"bb-{batch_id[:4]}-{str(uuid.uuid4())[:6]}"
                ds["job_id"] = child_job_id
                ds["status"] = "running"
                ds["started_at"] = datetime.datetime.utcnow().isoformat()

                # retry_params lets a failed child build be re-run with the
                # same inputs -- presumably by retry handling in app.main.
                create_job(child_job_id, "build", domain_name, len(BUILD_STEPS) - 1, BUILD_STEPS,
                           retry_params={
                               "domain": domain_name,
                               "niche": niche_name,
                               "batch_id": batch_id,
                               "legendary": legendary,
                           })

                # NOTE(review): the positional args after template_type look
                # like (analysis?, layout, depth, blueprint, persona) --
                # confirm against run_build_job's signature in app.main.
                future = batch_executor.submit(
                    run_build_job,
                    child_job_id,
                    domain_name,
                    niche_name,
                    template_type,
                    None,
                    "single-scroll",
                    "legendary" if legendary else "balanced",
                    blueprint,
                    "default",
                )
                futures[future] = idx
                idx += 1

            save_batch_state("running")

            if futures:
                # Block briefly for the first completion; the 2s timeout
                # keeps the loop responsive to pause/stop flags.
                done_futures, _ = concurrent.futures.wait(
                    futures.keys(), timeout=2,
                    return_when=concurrent.futures.FIRST_COMPLETED
                )
                for f in done_futures:
                    f_idx = futures.pop(f)
                    ds = domain_status[f_idx]
                    ds["completed_at"] = datetime.datetime.utcnow().isoformat()
                    if ds.get("started_at"):
                        started_dt = datetime.datetime.fromisoformat(ds["started_at"])
                        completed_dt = datetime.datetime.fromisoformat(ds["completed_at"])
                        ds["elapsed_seconds"] = round((completed_dt - started_dt).total_seconds(), 1)
                        domain_times.append(ds["elapsed_seconds"])

                    # The child's final status/error/result live in the Job
                    # row written by run_build_job; read them back here.
                    child_db = SessionLocal()
                    try:
                        child_job = child_db.query(Job).filter(Job.job_id == ds["job_id"]).first()
                        if child_job:
                            ds["status"] = child_job.status
                            if child_job.error:
                                ds["error"] = child_job.error
                            if child_job.result and isinstance(child_job.result, dict):
                                ds["package_id"] = child_job.result.get("id")
                        else:
                            # No row found: optimistically assume success.
                            ds["status"] = "completed"
                    except Exception:
                        # DB read failed: same optimistic assumption.
                        ds["status"] = "completed"
                    finally:
                        child_db.close()

                    save_batch_state("running")
            elif idx >= total:
                # Nothing in flight and nothing left to schedule.
                break

        # Only publish "completed" if we were not stopped mid-batch.
        if not state.stop_flag.is_set():
            save_batch_state("completed", f"Batch build complete! {completed_count} packages created, {failed_count} failed out of {total}.")

    except Exception as e:
        logger.error(f"Batch build orchestrator failed: {e}")
        save_batch_state("failed", f"Orchestrator error: {str(e)}")
    finally:
        # Always release the batch control slot, even on error/stop.
        state.active = False
