import pytest
from unittest.mock import patch, MagicMock
from app.services.llm import (
    PIPELINE_STAGES, PROVIDERS,
    get_route, set_route, reset_route, get_all_routes,
    get_provider_status, get_routing_dashboard_data,
    _routing_overrides, _build_messages,
    _call_openai_compatible,
)


@pytest.fixture(autouse=True)
def clean_overrides():
    """Isolate tests from each other: clear the in-memory routing override
    map before and after every test in this module (autouse)."""
    _routing_overrides.clear()
    yield
    # Teardown: remove any overrides installed by the test body.
    _routing_overrides.clear()


class TestPipelineStagesRegistry:
    """Structural checks on the static PIPELINE_STAGES routing registry."""

    def test_all_stages_defined(self):
        # The pipeline currently exposes at least 11 distinct stages.
        assert len(PIPELINE_STAGES) >= 11

    def test_required_stages_present(self):
        expected = {
            "domain_analysis", "brand_identity", "site_copy",
            "sales_letter", "hero_image", "brand_kit_classify",
            "advisor_chat", "force_multiplier", "content_refine",
            "graphics_prompt", "market_research",
        }
        # Set comparison against the keys view: every expected stage exists.
        assert expected <= PIPELINE_STAGES.keys()

    def test_every_stage_has_required_fields(self):
        required_fields = {"name", "description", "mode", "quality_tier", "default_provider", "default_model"}
        for stage_id, config in PIPELINE_STAGES.items():
            missing = required_fields.difference(config)
            assert not missing, f"Stage '{stage_id}' missing fields: {missing}"

    def test_default_models_exist_in_provider_registry(self):
        for stage_id, config in PIPELINE_STAGES.items():
            provider, model = config["default_provider"], config["default_model"]
            assert provider in PROVIDERS, f"Stage '{stage_id}' references unknown provider '{provider}'"
            assert model in PROVIDERS[provider]["models"], (
                f"Stage '{stage_id}' references unknown model '{model}' in provider '{provider}'"
            )

    def test_default_model_supports_stage_mode(self):
        for stage_id, config in PIPELINE_STAGES.items():
            provider, model = config["default_provider"], config["default_model"]
            mode = config["mode"]
            supported_modes = PROVIDERS[provider]["models"][model]["modes"]
            assert mode in supported_modes, (
                f"Stage '{stage_id}' requires mode '{mode}' but default model '{model}' "
                f"only supports {supported_modes}"
            )

    def test_quality_first_defaults(self):
        legendary = [s for s in PIPELINE_STAGES.values() if s["quality_tier"] == "legendary"]
        legendary_count = len(legendary)
        assert legendary_count >= 7, f"Expected at least 7 Legendary-tier defaults, got {legendary_count}"

    def test_valid_modes_only(self):
        valid_modes = {"json", "text", "stream", "image", "vision"}
        for stage_id, config in PIPELINE_STAGES.items():
            assert config["mode"] in valid_modes, f"Stage '{stage_id}' has invalid mode: {config['mode']}"


class TestProviderRegistry:
    """Structural checks on the PROVIDERS registry."""

    def test_four_providers(self):
        assert len(PROVIDERS) == 4

    def test_required_providers(self):
        # A dict keys view compares equal to a set of the same keys.
        assert PROVIDERS.keys() == {"openai", "gemini", "perplexity", "huggingface"}

    def test_every_provider_has_env_key(self):
        for pid in PROVIDERS:
            assert "env_key" in PROVIDERS[pid], f"Provider '{pid}' missing env_key"

    def test_every_model_has_tier(self):
        valid_tiers = {"legendary", "premium", "economy"}
        for pid, p in PROVIDERS.items():
            for mid, m in p["models"].items():
                assert m.get("tier") in valid_tiers, (
                    f"Provider '{pid}' model '{mid}' has invalid tier: {m.get('tier')}"
                )

    def test_every_model_has_modes(self):
        for pid, p in PROVIDERS.items():
            for mid, m in p["models"].items():
                # An empty (or missing) modes list is falsy.
                assert m.get("modes", []), f"Provider '{pid}' model '{mid}' has no modes"

    def test_openai_has_image_capable_model(self):
        has_image_model = any(
            "image" in m.get("modes", [])
            for m in PROVIDERS["openai"]["models"].values()
        )
        assert has_image_model, "OpenAI must have at least one image-capable model"


class TestRouteResolution:
    """Resolution of effective routes when no overrides are installed."""

    def test_all_routes_resolve(self):
        routes = get_all_routes()
        assert len(routes) >= 11
        for route in routes.values():
            assert route["provider"]
            assert route["model"]
            assert route["quality_tier"]
            assert route["is_override"] is False

    def test_get_route_returns_required_fields(self):
        route = get_route("domain_analysis")
        required = {
            "stage", "stage_name", "mode", "provider", "provider_name",
            "model", "model_name", "quality_tier", "default_quality_tier", "is_override",
        }
        # Subset check against the keys view.
        assert required <= route.keys()

    def test_unknown_stage_raises(self):
        with pytest.raises(ValueError, match="Unknown pipeline stage"):
            get_route("nonexistent_stage")


class TestRouteOverrides:
    """Setting, replacing, resetting, and validating per-stage overrides."""

    def test_set_override(self):
        result = set_route("domain_analysis", "gemini", "gemini-2.5-pro")
        assert result["is_override"] is True
        assert (result["provider"], result["model"]) == ("gemini", "gemini-2.5-pro")

    def test_override_persists(self):
        set_route("domain_analysis", "gemini", "gemini-2.5-pro")
        route = get_route("domain_analysis")
        assert route["is_override"] is True
        assert route["provider"] == "gemini"

    def test_double_override_replaces(self):
        # The second set_route for the same stage must win.
        for model in ("gemini-2.5-pro", "gemini-2.5-flash"):
            set_route("domain_analysis", "gemini", model)
        assert get_route("domain_analysis")["model"] == "gemini-2.5-flash"

    def test_reset_clears_override(self):
        set_route("domain_analysis", "gemini", "gemini-2.5-pro")
        result = reset_route("domain_analysis")
        assert result["is_override"] is False
        expected_provider = PIPELINE_STAGES["domain_analysis"]["default_provider"]
        assert result["provider"] == expected_provider

    def test_reset_nonexistent_override_no_error(self):
        # Resetting a stage that was never overridden must be a harmless no-op.
        assert reset_route("domain_analysis")["is_override"] is False

    def test_set_unknown_stage_raises(self):
        with pytest.raises(ValueError, match="Unknown pipeline stage"):
            set_route("fake_stage", "openai", "gpt-5")

    def test_set_unknown_provider_raises(self):
        with pytest.raises(ValueError, match="Unknown provider"):
            set_route("domain_analysis", "anthropic", "claude-4")

    def test_set_unknown_model_raises(self):
        with pytest.raises(ValueError, match="Unknown model"):
            set_route("domain_analysis", "openai", "gpt-99")

    def test_mode_incompatible_rejects(self):
        # Per the registry, gemini-2.5-pro does not support hero_image's mode.
        with pytest.raises(ValueError, match="does not support mode"):
            set_route("hero_image", "gemini", "gemini-2.5-pro")

    def test_reset_unknown_stage_raises(self):
        with pytest.raises(ValueError, match="Unknown pipeline stage"):
            reset_route("nonexistent")


class TestProviderStatus:
    """Shape of the per-provider status report."""

    def test_returns_all_providers(self):
        assert len(get_provider_status()) == 4

    def test_status_fields(self):
        for s in get_provider_status().values():
            assert "name" in s
            assert "configured" in s
            assert isinstance(s["configured"], bool)
            assert "model_count" in s
            assert s["model_count"] > 0


class TestDashboardData:
    """Aggregate payload consumed by the routing dashboard."""

    def test_returns_all_sections(self):
        data = get_routing_dashboard_data()
        for section in ("stages", "routes", "providers", "provider_status"):
            assert section in data

    def test_stages_match_routes(self):
        data = get_routing_dashboard_data()
        # Keys views compare equal iff the key sets are identical.
        assert data["stages"].keys() == data["routes"].keys()


class TestBuildMessages:
    """Message-list construction for chat-completion calls."""

    def test_user_only(self):
        messages = _build_messages("hello")
        assert len(messages) == 1
        first = messages[0]
        assert first["role"] == "user"
        assert first["content"] == "hello"

    def test_with_system(self):
        messages = _build_messages("hello", "you are helpful")
        assert len(messages) == 2
        # System prompt must come first, user prompt second.
        assert [m["role"] for m in messages] == ["system", "user"]

    def test_none_system_excluded(self):
        assert len(_build_messages("hello", None)) == 1


class TestCallOpenAICompatible:
    """Unit tests for the OpenAI-compatible chat-completions wrapper.

    Relies on the ``mock_openai_client`` fixture (defined in conftest — not
    visible here) that mimics the OpenAI SDK client surface.
    """

    def test_json_mode(self, mock_openai_client):
        """json_mode=True must request a JSON-object response format."""
        result = _call_openai_compatible(
            mock_openai_client, "gpt-5",
            [{"role": "user", "content": "test"}],
            max_tokens=100, json_mode=True
        )
        call_args = mock_openai_client.chat.completions.create.call_args
        assert call_args.kwargs.get("response_format") == {"type": "json_object"}
        assert isinstance(result, str)

    def test_text_mode(self, mock_openai_client):
        """json_mode=False must not send any response_format at all."""
        _call_openai_compatible(
            mock_openai_client, "gpt-5",
            [{"role": "user", "content": "test"}],
            max_tokens=100, json_mode=False
        )
        call_args = mock_openai_client.chat.completions.create.call_args
        assert "response_format" not in call_args.kwargs

    def test_stream_mode_returns_generator(self, mock_openai_client):
        """stream=True must forward the flag and return a lazy iterable."""
        mock_openai_client.chat.completions.create.return_value = iter(["chunk1"])
        result = _call_openai_compatible(
            mock_openai_client, "gpt-5",
            [{"role": "user", "content": "test"}],
            max_tokens=100, stream=True
        )
        call_args = mock_openai_client.chat.completions.create.call_args
        assert call_args.kwargs.get("stream") is True
        # Fix: the original assigned `result` but never inspected it, so the
        # "returns generator" claim in the test name went unverified. Assert
        # an iterable (not the fully-joined text string) comes back.
        assert hasattr(result, "__iter__")
        assert not isinstance(result, str)