import pytest
from aura_core.module_contract import (
    ModuleResult, ModuleDiagnostics, ModuleSpec,
    aura_module, get_module, get_module_spec, list_modules,
    get_mcp_tool_definitions, swap_module, Pipeline,
    _MODULE_REGISTRY,
)
from pydantic import BaseModel


# Minimal pydantic model with a single integer field; passed as ``input_model``
# to ``aura_module`` in the decorator and MCP tool-definition tests below.
# (Kept as a ``#`` comment rather than a docstring so the model itself is unchanged.)
class SampleInput(BaseModel):
    value: int  # the number the sample modules double

# Minimal pydantic model with a single integer field; passed as ``output_model``
# to ``aura_module`` and used as a ``ModuleResult`` output in the tests below.
# (Kept as a ``#`` comment rather than a docstring so the model itself is unchanged.)
class SampleOutput(BaseModel):
    doubled: int  # result produced by the sample modules (value * 2)


class TestModuleDiagnostics:
    """Checks the clean flag, score and repair log of ``ModuleDiagnostics``."""

    def test_clean_by_default(self):
        """A freshly built diagnostics object is clean and scores 100."""
        fresh = ModuleDiagnostics()
        assert fresh.is_clean is True
        assert fresh.score == 100

    def test_error_marks_dirty(self):
        """Recording one error clears the clean flag and drops the score to 85."""
        diag = ModuleDiagnostics()
        diag.error("field", "broken")
        assert diag.is_clean is False
        assert diag.score == 85

    def test_warning_reduces_score(self):
        """Recording one warning keeps the clean flag but drops the score to 97."""
        diag = ModuleDiagnostics()
        diag.warn("field", "meh")
        assert diag.is_clean is True
        assert diag.score == 97

    def test_repair_tracked(self):
        """Recording one repair keeps the clean flag and adds one repair entry."""
        diag = ModuleDiagnostics()
        diag.repair("field", "fixed")
        assert diag.is_clean is True
        assert len(diag.repairs) == 1


class TestModuleResult:
    """Checks construction defaults and MCP serialisation of ``ModuleResult``."""

    def test_default_ok(self):
        """A result built with only an output is ok and keeps that output."""
        outcome = ModuleResult(output=42)
        assert outcome.ok is True
        assert outcome.output == 42

    def test_to_mcp_response(self):
        """``to_mcp_response`` exposes ok, the output fields and the module name."""
        payload = SampleOutput(doubled=10)
        outcome = ModuleResult(
            output=payload,
            module_name="test",
            module_version="0.1.0",
        )
        response = outcome.to_mcp_response()
        assert response["ok"] is True
        assert response["output"]["doubled"] == 10
        assert response["module"] == "test"

    def test_failed_result(self):
        """A result built with ok=False and an errored diagnostics scores below 100."""
        diag = ModuleDiagnostics()
        diag.error("exec", "boom")
        outcome = ModuleResult(ok=False, diagnostics=diag)
        assert outcome.ok is False
        assert outcome.diagnostics.score < 100


class TestAuraModuleDecorator:
    """Checks what the ``aura_module`` decorator does to the wrapped callable."""

    def test_registration(self):
        """Decorating registers the module and wraps its return in a ModuleResult."""
        @aura_module("test_doubler", input_model=SampleInput, output_model=SampleOutput)
        def doubler(input_data: SampleInput) -> SampleOutput:
            """Doubles a value"""
            return SampleOutput(doubled=input_data.value * 2)

        assert "test_doubler" in _MODULE_REGISTRY
        result = doubler(SampleInput(value=5))
        assert isinstance(result, ModuleResult)
        assert result.ok is True
        assert result.output.doubled == 10
        # A trivial callable can complete within a single tick of a coarse
        # clock, in which case the measured time is exactly 0; requiring a
        # strictly positive value would make this test platform-dependent.
        # NOTE(review): tighten again if the timer is known to be high-resolution.
        assert result.elapsed_ms >= 0
        assert result.module_name == "test_doubler"

    def test_error_handling(self):
        """An exception raised by the module is reported as a failed result."""
        @aura_module("test_exploder")
        def exploder(x):
            raise ValueError("boom")

        result = exploder("anything")
        assert result.ok is False
        assert len(result.diagnostics.errors) == 1
        assert "boom" in result.diagnostics.errors[0]["message"]

    def test_passthrough_module_result(self):
        """A module returning a ModuleResult keeps its ok, output and metadata."""
        @aura_module("test_passthrough")
        def passthrough(x):
            return ModuleResult(ok=True, output=x, metadata={"custom": True})

        result = passthrough(42)
        assert result.ok is True
        assert result.output == 42
        assert result.metadata["custom"] is True


class TestRegistry:
    """Checks the registry lookup helpers exposed next to ``aura_module``.

    Each test registers the module it needs itself, so any single test can be
    selected or reordered without depending on another test having run first.
    """

    def test_get_module(self):
        """A registered module can be fetched by name and called."""
        @aura_module("test_getter")
        def getter(x):
            return x

        fn = get_module("test_getter")
        assert fn is not None
        result = fn(99)
        assert result.output == 99

    def test_get_missing_module(self):
        """Looking up an unregistered name yields None."""
        assert get_module("nonexistent_xyzzy") is None

    def test_get_spec(self):
        """The spec records the name, version and tags given to the decorator."""
        @aura_module("test_spec", version="1.0.0", description="A test", tags=["test"])
        def spec_fn(x):
            return x

        spec = get_module_spec("test_spec")
        # Fail with a clear assertion instead of an AttributeError on None.
        assert spec is not None
        assert spec.name == "test_spec"
        assert spec.version == "1.0.0"
        assert "test" in spec.tags

    def test_list_modules(self):
        """``list_modules`` includes a module registered in this test."""
        # Register locally: relying on "test_getter" from test_get_module made
        # this test fail whenever it ran alone or before that test.
        @aura_module("test_list_entry")
        def listed(x):
            return x

        names = [m.name for m in list_modules()]
        assert "test_list_entry" in names

    def test_mcp_tool_definitions(self):
        """The MCP definition carries both schemas and the given description."""
        @aura_module("test_mcp_tool", input_model=SampleInput, output_model=SampleOutput, description="MCP test")
        def mcp_fn(x):
            return SampleOutput(doubled=x.value * 2)

        tools = get_mcp_tool_definitions()
        # Use a default so a missing entry is an assertion failure, not StopIteration.
        tool = next((t for t in tools if t["name"] == "test_mcp_tool"), None)
        assert tool is not None, "test_mcp_tool missing from MCP tool definitions"
        assert "properties" in tool["inputSchema"]
        assert "properties" in tool["outputSchema"]
        assert tool["description"] == "MCP test"


class TestSwapModule:
    """Checks replacing a registered module's implementation via ``swap_module``."""

    def test_swap(self):
        """After a swap, lookup returns the new behaviour, version and a 'swapped' tag."""
        @aura_module("test_swappable")
        def original(x):
            return x * 2

        # Baseline: the decorated original doubles its input.
        before = original(5)
        assert before.output == 10

        def replacement(x):
            return ModuleResult(ok=True, output=x * 100)

        swap_module("test_swappable", replacement, version="2.0.0")

        # The registry entry now resolves to the replacement implementation.
        swapped = get_module("test_swappable")
        after = swapped(5)
        assert after.output == 500

        swapped_spec = get_module_spec("test_swappable")
        assert swapped_spec.version == "2.0.0"
        assert "swapped" in swapped_spec.tags


class TestPipeline:
    """Checks chaining registered modules through ``Pipeline``."""

    def test_simple_chain(self):
        """Each step receives the previous step's output; one result per step."""
        @aura_module("test_pipe_add")
        def adder(x):
            return x + 10

        @aura_module("test_pipe_mul")
        def multiplier(x):
            return x * 3

        pipeline = Pipeline("test_pipe")
        pipeline.add("test_pipe_add").add("test_pipe_mul")
        outcomes = pipeline.run(5)
        assert len(outcomes) == 2
        # 5 + 10 = 15, then 15 * 3 = 45.
        assert outcomes[0].output == 15
        assert outcomes[1].output == 45

    def test_pipeline_halts_on_failure(self):
        """A failing first step yields a single, not-ok result; later steps are absent."""
        @aura_module("test_pipe_fail")
        def failer(x):
            raise RuntimeError("stop")

        @aura_module("test_pipe_after")
        def after(x):
            return x

        pipeline = Pipeline()
        pipeline.add("test_pipe_fail").add("test_pipe_after")
        outcomes = pipeline.run(1)
        assert len(outcomes) == 1
        assert outcomes[0].ok is False

    def test_pipeline_with_adapter(self):
        """An adapter reshapes the previous output before the next step runs."""
        @aura_module("test_pipe_dict")
        def dict_maker(x):
            return {"value": x}

        @aura_module("test_pipe_extract")
        def extractor(x):
            return x * 2

        pipeline = Pipeline()
        pipeline.add("test_pipe_dict")
        # The adapter pulls the integer back out of the dict made by step one.
        pipeline.add("test_pipe_extract", adapter=lambda out: out["value"])
        outcomes = pipeline.run(7)
        assert outcomes[1].output == 14

    def test_run_last(self):
        """``run_last`` returns a single result carrying the final output."""
        @aura_module("test_pipe_last")
        def last_fn(x):
            return x + 1

        pipeline = Pipeline()
        pipeline.add("test_pipe_last")
        final = pipeline.run_last(10)
        assert final.output == 11
