import logging
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import Optional
from sqlalchemy.orm import Session

from app.database import get_db

import aura_core.valuation.engine
import aura_core.validator.engine
import aura_core.theme.engine
import aura_core.blueprint.engine
import aura_core.orchestrator.engine
from aura_core.module_contract import get_module, get_mcp_tool_definitions
from aura_core.types import (
    ValuationInput,
    ValidatorInput,
    ThemeInput,
    BlueprintInput,
    OrchestratorInput,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/aura-core", tags=["aura_core"])


@router.get("/tools")
def list_tools():
    """List the tool definitions returned by ``get_mcp_tool_definitions``.

    The JSON body has two keys: ``tools`` (the definitions as returned) and
    ``count`` (how many definitions there are).
    """
    definitions = get_mcp_tool_definitions()
    body = {"tools": definitions, "count": len(definitions)}
    return JSONResponse(content=body)


@router.post("/blueprint")
def api_blueprint(payload: BlueprintInput):
    """Run the module registered as "blueprint" on the request payload.

    The payload is dumped to a plain dict, passed to the module callable,
    and the result's MCP response is returned as the JSON body.
    """
    run_blueprint = get_module("blueprint")
    module_result = run_blueprint(payload.model_dump())
    mcp_body = module_result.to_mcp_response()
    return JSONResponse(content=mcp_body)


@router.post("/valuation")
def api_valuation(payload: ValuationInput):
    """Run the module registered as "valuation" on the request payload.

    The payload is dumped to a plain dict, passed to the module callable,
    and the result's MCP response is returned as the JSON body.
    """
    run_valuation = get_module("valuation")
    module_result = run_valuation(payload.model_dump())
    mcp_body = module_result.to_mcp_response()
    return JSONResponse(content=mcp_body)


@router.post("/theme")
def api_theme(payload: ThemeInput):
    """Run the module registered as "theme" on the request payload.

    The payload is dumped to a plain dict, passed to the module callable,
    and the result's MCP response is returned as the JSON body.
    """
    run_theme = get_module("theme")
    module_result = run_theme(payload.model_dump())
    mcp_body = module_result.to_mcp_response()
    return JSONResponse(content=mcp_body)


@router.post("/validator")
def api_validator(payload: ValidatorInput):
    """Run the module registered as "validator" on the request payload.

    The payload is dumped to a plain dict, passed to the module callable,
    and the result's MCP response is returned as the JSON body.
    """
    run_validator = get_module("validator")
    module_result = run_validator(payload.model_dump())
    mcp_body = module_result.to_mcp_response()
    return JSONResponse(content=mcp_body)


class OrchestrateRequest(BaseModel):
    """Request body for the ``/orchestrate`` endpoint."""

    # Orchestrator settings; a default-constructed OrchestratorInput is used
    # when the client omits this field.
    config: OrchestratorInput = Field(
        default_factory=OrchestratorInput,
    )
    # Opt-in flag for saving the assembled package.
    persist: bool = Field(
        default=False,
        description="Save assembled package to database",
    )
    # Target domain for the saved package; optional at the schema level.
    domain_id: Optional[int] = Field(
        default=None,
        description="Domain ID for persistence (required if persist=True)",
    )


@router.post("/orchestrate")
def api_orchestrate(payload: OrchestrateRequest, db: Session = Depends(get_db)):
    """Run the "orchestrator" module and optionally persist its output.

    The request config is flattened into the dict handed to the orchestrator
    module:

    * stages missing from ``config.stages`` are sent as ``skip_stages``;
    * ``primary`` / ``secondary`` / ``accent`` are grouped under
      ``brand_colors``;
    * ``niches`` is ``config.analysis["niches"]`` when an analysis is
      present, otherwise an empty list;
    * ``niche`` is the ``name`` of the first niche when that entry is a
      dict, otherwise an empty string.

    When ``persist`` is true and the run reports ``ok``, the output is saved
    through ``save_package_from_orchestrator`` and the response gains
    ``persisted`` plus either ``package_id`` (success) or ``persist_error``
    (failure). A persist request without ``domain_id`` is reported as
    ``persisted: False`` with an explanatory ``persist_error`` instead of
    being skipped silently.

    Args:
        payload: Orchestrator config plus persistence options.
        db: Database session injected by ``get_db``.

    Returns:
        A ``JSONResponse`` carrying the orchestrator's MCP response, extended
        with the persistence keys described above when applicable.
    """
    fn = get_module("orchestrator")
    config = payload.config

    all_stages = ["blueprint", "valuation", "context", "theme", "validator"]
    skip_stages = [s for s in all_stages if s not in config.stages]

    niches = config.analysis.get("niches", []) if config.analysis else []
    niche = ""
    if niches:
        first_niche = niches[0]
        if isinstance(first_niche, dict):
            niche = first_niche.get("name", "")

    orch_config = {
        "domain": config.domain,
        "depth": config.depth,
        "mood": config.mood,
        "brand_colors": {
            "primary": config.primary,
            "secondary": config.secondary,
            "accent": config.accent,
        },
        "skip_stages": skip_stages,
        "niches": niches,
        "niche": niche,
    }

    result = fn(orch_config)
    response = result.to_mcp_response()

    if payload.persist and result.ok:
        if payload.domain_id is None:
            # The request model documents domain_id as required when
            # persist=True; tell the caller nothing was saved rather than
            # dropping the request silently.
            response["persisted"] = False
            response["persist_error"] = "domain_id is required when persist=True"
        else:
            # Imported at call time, as in the original code.
            # NOTE(review): presumably deferred to avoid an import cycle -
            # confirm before hoisting to module level.
            from app.services.aura_core_bridge import save_package_from_orchestrator
            try:
                pkg_id = save_package_from_orchestrator(
                    db=db,
                    domain_id=payload.domain_id,
                    domain_name=config.domain,
                    orchestrator_output=result.output,
                )
            except Exception as e:
                # Best-effort persistence: the orchestrator result is still
                # returned, with the failure reported alongside it.
                # NOTE(review): str(e) exposes the raw exception text to the
                # client - confirm this is acceptable.
                logger.exception("Failed to persist orchestrator output")
                response["persisted"] = False
                response["persist_error"] = str(e)
            else:
                response["persisted"] = True
                response["package_id"] = pkg_id

    return JSONResponse(content=response)


@router.post("/run/{tool_name}")
def api_run_tool(tool_name: str, payload: dict):
    """Run the module registered under ``tool_name`` with a raw dict payload.

    Responds with HTTP 404 and an ``{"ok": False, "error": ...}`` body when
    ``get_module`` returns ``None`` for the name; otherwise the module is
    called with the payload and its MCP response is returned as JSON.
    """
    tool = get_module(tool_name)
    if tool is not None:
        outcome = tool(payload)
        return JSONResponse(content=outcome.to_mcp_response())

    not_found_body = {"ok": False, "error": f"Tool '{tool_name}' not found"}
    return JSONResponse(status_code=404, content=not_found_body)
