"""Universal Module Contract for aura_core.

Every module follows the same pattern:
1. Declare input/output types (Pydantic models)
2. Implement run(input) -> ModuleResult[OutputType]
3. Register via @aura_module decorator
4. Get MCP schema, swapability, and pipeline composition for free

Design principles:
- Protocol over ABC (structural typing, no inheritance chains)
- Pydantic for I/O (auto JSON schema = MCP tool schema for free)
- Registry for discovery/swap (decorator-based self-registration)
- ModuleResult wraps everything (output + diagnostics + timing + metadata)
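- Pipeline = ordered list of modules; each step's output feeds the next step's input (optionally through an adapter), and each result's ok flag is checked between steps so the run halts at the first failure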
"""

import functools
import logging
import time
from typing import TypeVar, Generic, Optional, Any, Type, Callable

from pydantic import BaseModel, Field, ConfigDict

# Shared logger for decorator, registry and pipeline messages.
logger: logging.Logger = logging.getLogger(name=__name__)

# Type variable parameterizing ModuleResult (Generic[T]).
T = TypeVar("T")


class ModuleDiagnostics(BaseModel):
    # Structured findings collected while a module runs. Error and warning
    # entries are {"field": ..., "message": ...}; repair entries are
    # {"field": ..., "action": ...}.
    errors: list[dict] = Field(default_factory=list)
    warnings: list[dict] = Field(default_factory=list)
    repairs: list[dict] = Field(default_factory=list)

    @property
    def is_clean(self) -> bool:
        """True when no errors were recorded (warnings do not count)."""
        return not self.errors

    @property
    def score(self) -> int:
        """Score clamped to [0, 100]: 100 minus 15 per error and 3 per warning.

        Repairs do not affect the score.
        """
        penalty = 15 * len(self.errors) + 3 * len(self.warnings)
        return min(100, max(0, 100 - penalty))

    def error(self, field: str, message: str):
        """Record an error about ``field``."""
        entry = {"field": field, "message": message}
        self.errors.append(entry)

    def warn(self, field: str, message: str):
        """Record a warning about ``field``."""
        entry = {"field": field, "message": message}
        self.warnings.append(entry)

    def repair(self, field: str, action: str):
        """Record a repair ``action`` applied to ``field``."""
        entry = {"field": field, "action": action}
        self.repairs.append(entry)


class ModuleResult(BaseModel, Generic[T]):
    # Envelope returned by every module call: success flag, payload,
    # diagnostics, wall-clock time in milliseconds, module identity and free-form
    # metadata. ``output`` is typed Any; T only parameterizes the class.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    ok: bool = True
    output: Optional[Any] = None
    diagnostics: ModuleDiagnostics = Field(default_factory=ModuleDiagnostics)
    elapsed_ms: float = 0.0
    module_name: str = ""
    module_version: str = ""
    metadata: dict = Field(default_factory=dict)

    def to_mcp_response(self) -> dict:
        """Return this result as the plain dict sent back to MCP clients.

        A pydantic ``output`` is converted with ``model_dump()``; any other
        output is passed through as-is. ``metadata`` is not included. The
        diagnostics lists are the same list objects held by ``diagnostics``.
        """
        payload = self.output
        if isinstance(payload, BaseModel):
            payload = payload.model_dump()

        diag = self.diagnostics
        diagnostics_view = {
            "score": diag.score,
            "errors": diag.errors,
            "warnings": diag.warnings,
            "repairs": diag.repairs,
        }

        return {
            "ok": self.ok,
            "output": payload,
            "diagnostics": diagnostics_view,
            "elapsed_ms": self.elapsed_ms,
            "module": self.module_name,
            "version": self.module_version,
        }


class ModuleSpec(BaseModel):
    # Static description of a registered module: identity, human-readable
    # description, JSON schemas of its input/output models ({} when no model
    # was declared), tags, and the ``pure`` flag passed to @aura_module
    # (presumably "side-effect free"; not enforced anywhere in this file).
    name: str = Field(...)
    version: str = Field(default="0.1.0")
    description: str = Field(default="")
    input_schema: dict = Field(default_factory=dict)
    output_schema: dict = Field(default_factory=dict)
    tags: list[str] = Field(default_factory=list)
    pure: bool = Field(default=True)


# Process-wide module registry: name -> {"fn": contract-wrapped callable,
# "raw_fn": undecorated callable, "spec": ModuleSpec}.
_MODULE_REGISTRY: dict[str, dict] = dict()


def aura_module(
    name: str,
    version: str = "0.1.0",
    description: str = "",
    input_model: Optional[Type[BaseModel]] = None,
    output_model: Optional[Type[BaseModel]] = None,
    tags: Optional[list[str]] = None,
    pure: bool = True,
):
    """Decorator that wraps a function in the module contract and registers it.

    The returned wrapper times each call and normalizes its outcome:

    * plain return value -> ``ModuleResult(ok=True, output=value)``;
    * returned ``ModuleResult`` -> the same object, with ``elapsed_ms``,
      ``module_name`` and ``module_version`` overwritten;
    * raised ``Exception`` -> ``ModuleResult(ok=False)`` carrying an
      ``"execution"`` diagnostic error and ``metadata["exception"]`` set to
      the exception's class name (the traceback is logged).

    The wrapper is stored in the registry under ``name`` (replacing any
    previous entry of that name) together with the undecorated function and
    a ``ModuleSpec`` built from these arguments.

    Args:
        name: Registry key and reported module name.
        version: Reported module version.
        description: Description for the spec; falls back to the function's
            docstring, then to ``""``.
        input_model: Pydantic model whose JSON schema becomes
            ``spec.input_schema`` (``{}`` when omitted).
        output_model: Pydantic model whose JSON schema becomes
            ``spec.output_schema`` (``{}`` when omitted).
        tags: Tags stored on the spec (``[]`` when omitted).
        pure: Flag stored on the spec.

    Returns:
        A decorator returning the registered wrapper, which exposes the
        original function as ``__wrapped__`` and the spec as
        ``__module_spec__``.
    """

    def decorator(fn: Callable) -> Callable:
        spec = ModuleSpec(
            name=name,
            version=version,
            description=description or fn.__doc__ or "",
            input_schema=input_model.model_json_schema() if input_model else {},
            output_schema=output_model.model_json_schema() if output_model else {},
            tags=tags or [],
            pure=pure,
        )

        # functools.wraps copies __name__, __qualname__, __doc__, __module__
        # and __dict__ and sets __wrapped__ = fn; attributes fn lacks are
        # skipped, so callables without __name__ (e.g. functools.partial)
        # can be decorated too.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs) -> ModuleResult:
            start = time.perf_counter()
            try:
                result = fn(*args, **kwargs)
            except Exception as e:
                # Module boundary: report failures through the result rather
                # than raising, so callers such as Pipeline can halt cleanly.
                elapsed = (time.perf_counter() - start) * 1000
                diagnostics = ModuleDiagnostics()
                diagnostics.error("execution", str(e))
                logger.exception("[%s] Module execution failed", name)
                return ModuleResult(
                    ok=False,
                    output=None,
                    diagnostics=diagnostics,
                    elapsed_ms=elapsed,
                    module_name=name,
                    module_version=version,
                    metadata={"exception": type(e).__name__},
                )

            elapsed = (time.perf_counter() - start) * 1000

            if isinstance(result, ModuleResult):
                # The module built its own result (e.g. with diagnostics);
                # only stamp timing and identity onto it.
                result.elapsed_ms = elapsed
                result.module_name = name
                result.module_version = version
                return result

            return ModuleResult(
                ok=True,
                output=result,
                diagnostics=ModuleDiagnostics(),
                elapsed_ms=elapsed,
                module_name=name,
                module_version=version,
            )

        wrapper.__module_spec__ = spec

        _MODULE_REGISTRY[name] = {
            "fn": wrapper,
            "raw_fn": fn,
            "spec": spec,
        }

        return wrapper

    return decorator


def get_module(name: str) -> Optional[Callable]:
    """Return the registered, contract-wrapped callable for ``name``.

    Returns None when nothing is registered under that name.
    """
    entry = _MODULE_REGISTRY.get(name)
    if not entry:
        return None
    return entry["fn"]


def get_module_spec(name: str) -> Optional[ModuleSpec]:
    """Return the ``ModuleSpec`` registered for ``name``, or None if unknown."""
    found = _MODULE_REGISTRY.get(name)
    return None if not found else found["spec"]


def list_modules() -> list[ModuleSpec]:
    """Return the spec of every registered module, in registry iteration order."""
    return [registered["spec"] for registered in _MODULE_REGISTRY.values()]


def get_mcp_tool_definitions() -> list[dict]:
    """Build one MCP tool definition per registered module.

    Each definition has ``name``, ``description`` and ``inputSchema``. MCP
    requires ``inputSchema`` to be an object schema, so modules registered
    without an input model get ``{"type": "object", "properties": {}}``
    rather than an empty dict. ``outputSchema`` is optional in MCP and must be
    an object schema when present, so it is emitted only when the module
    declared an output model (i.e. its schema is non-empty).
    """
    tools = []
    for entry in _MODULE_REGISTRY.values():
        spec = entry["spec"]
        tool = {
            "name": spec.name,
            "description": spec.description,
            "inputSchema": spec.input_schema or {"type": "object", "properties": {}},
        }
        if spec.output_schema:
            tool["outputSchema"] = spec.output_schema
        tools.append(tool)
    return tools


def swap_module(name: str, new_fn: Callable, version: str = "0.1.0-swap"):
    """Replace the implementation registered under ``name`` with ``new_fn``.

    The replacement keeps the original module's description, input/output
    schemas and ``pure`` flag, takes ``version`` as its version and is tagged
    ``"swapped"`` (at most once, even across repeated swaps).

    ``new_fn`` is wrapped through ``aura_module``, so it gets the same
    contract as decorated modules: timing, exception capture and
    ``ModuleResult`` wrapping. As a result ``get_module(name)`` keeps
    returning ``ModuleResult`` objects after a swap. The unwrapped ``new_fn``
    is stored as the entry's ``raw_fn``.

    Pipeline steps that were already resolved through ``Pipeline.add`` keep
    the callable they captured at that time.

    Raises:
        KeyError: if no module is registered under ``name``.
    """
    if name not in _MODULE_REGISTRY:
        raise KeyError(f"Module '{name}' not registered")
    old_spec = _MODULE_REGISTRY[name]["spec"]

    # Repeated swaps must not accumulate duplicate "swapped" tags.
    tags = list(old_spec.tags)
    if "swapped" not in tags:
        tags.append("swapped")

    new_spec = ModuleSpec(
        name=name,
        version=version,
        description=old_spec.description,
        input_schema=old_spec.input_schema,
        output_schema=old_spec.output_schema,
        tags=tags,
        pure=old_spec.pure,
    )

    # Delegate through a plain function so aura_module can always read
    # __name__/__doc__, even when new_fn is an arbitrary callable object.
    def _swapped_impl(*args, **kwargs):
        return new_fn(*args, **kwargs)

    _swapped_impl.__name__ = getattr(new_fn, "__name__", _swapped_impl.__name__)
    _swapped_impl.__doc__ = getattr(new_fn, "__doc__", None)

    wrapped = aura_module(
        name=name,
        version=version,
        description=old_spec.description,
        tags=tags,
        pure=old_spec.pure,
    )(_swapped_impl)
    wrapped.__wrapped__ = new_fn
    wrapped.__module_spec__ = new_spec

    # aura_module registered a provisional entry whose spec has no schemas
    # (it takes models, not schemas); replace it with the carried-over spec.
    _MODULE_REGISTRY[name] = {
        "fn": wrapped,
        "raw_fn": new_fn,
        "spec": new_spec,
    }
    logger.info("[registry] Swapped module '%s' -> v%s", name, version)


class Pipeline:
    """Ordered chain of steps where each successful step's output feeds the next.

    Steps are registered modules (``add``) or arbitrary callables (``add_fn``).
    Each step can have an optional adapter that reshapes the previous step's
    output before it is passed in. ``run`` stops at the first step whose
    result has ``ok=False``. A step that raises, whether in its callable or
    its adapter, is recorded as such a failed result; the exception does not
    propagate.
    """

    def __init__(self, name: str = "pipeline"):
        self.name = name
        # Execution order: (step name, callable, optional adapter).
        self.steps: list[tuple[str, Callable, Optional[Callable]]] = []

    def add(self, module_name: str, adapter: Optional[Callable] = None) -> "Pipeline":
        """Append the registered module ``module_name`` and return ``self``.

        The callable is resolved now, so a later ``swap_module`` does not
        change steps that were already added.

        Raises:
            KeyError: if ``module_name`` is not registered.
        """
        fn = get_module(module_name)
        if fn is None:
            raise KeyError(f"Module '{module_name}' not found in registry")
        self.steps.append((module_name, fn, adapter))
        return self

    def add_fn(self, name: str, fn: Callable, adapter: Optional[Callable] = None) -> "Pipeline":
        """Append an arbitrary callable as a step named ``name``; return ``self``.

        ``fn`` may return a ``ModuleResult`` or a plain value. A plain value is
        wrapped in a successful result.
        """
        self.steps.append((name, fn, adapter))
        return self

    def run(self, initial_input: Any) -> list[ModuleResult]:
        """Execute the steps in order and return one result per executed step.

        The first step receives ``initial_input`` unchanged; its adapter, if
        any, is not applied. Each later step receives the previous step's
        ``output``, passed through that step's adapter when one is set. The
        list ends with the first failed step, and is empty when the pipeline
        has no steps.
        """
        results: list[ModuleResult] = []
        current_input = initial_input

        for step_name, fn, adapter in self.steps:
            # Adapters map a previous output to this step's input, so the
            # first step has nothing to adapt.
            step_adapter = adapter if (adapter and results) else None
            result = self._execute_step(step_name, fn, step_adapter, current_input)
            results.append(result)

            if not result.ok:
                logger.warning("[pipeline:%s] Step '%s' failed, halting", self.name, step_name)
                break

            current_input = result.output

        return results

    def run_last(self, initial_input: Any) -> ModuleResult:
        """Run the pipeline and return only the last produced result.

        An empty pipeline yields ``ok=False`` with a ``"pipeline"`` diagnostic
        error saying there were no steps.
        """
        results = self.run(initial_input)
        if results:
            return results[-1]
        empty = ModuleResult(ok=False, module_name=self.name)
        empty.diagnostics.error("pipeline", "pipeline has no steps")
        return empty

    def _execute_step(
        self,
        step_name: str,
        fn: Callable,
        adapter: Optional[Callable],
        step_input: Any,
    ) -> ModuleResult:
        """Run one step (adapter first, then callable) and always return a ModuleResult."""
        if adapter is not None:
            try:
                step_input = adapter(step_input)
            except Exception as e:
                return self._failed_step(step_name, "adapter", e, 0.0)

        start = time.perf_counter()
        try:
            result = fn(step_input)
        except Exception as e:
            elapsed = (time.perf_counter() - start) * 1000
            return self._failed_step(step_name, "execution", e, elapsed)
        elapsed = (time.perf_counter() - start) * 1000

        if isinstance(result, ModuleResult):
            return result
        # Plain callables from add_fn return raw values; wrap and time them.
        return ModuleResult(ok=True, output=result, elapsed_ms=elapsed, module_name=step_name)

    def _failed_step(
        self,
        step_name: str,
        stage: str,
        exc: Exception,
        elapsed_ms: float,
    ) -> ModuleResult:
        """Build the failed result for a step whose ``stage`` raised ``exc``.

        Call only from inside an ``except`` block so the traceback is logged.
        """
        logger.exception("[pipeline:%s] Step '%s' raised in %s", self.name, step_name, stage)
        diagnostics = ModuleDiagnostics()
        diagnostics.error(stage, str(exc))
        return ModuleResult(
            ok=False,
            output=None,
            diagnostics=diagnostics,
            elapsed_ms=elapsed_ms,
            module_name=step_name,
            metadata={"exception": type(exc).__name__},
        )
