import os
import base64
import json
import logging
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception

logger = logging.getLogger(__name__)


# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 1: PORTABLE ENV VAR RESOLUTION + LEGACY CLIENT
# ═══════════════════════════════════════════════════════════════════════════════

def _resolve_env(*keys: str) -> str | None:
    """Try multiple env var names in priority order, return first found."""
    for key in keys:
        val = os.environ.get(key)
        if val:
            return val
    return None


def _resolve_openai_key() -> str | None:
    """Return the OpenAI API key, preferring the platform-scoped env name."""
    return _resolve_env("AI_INTEGRATIONS_OPENAI_API_KEY", "OPENAI_API_KEY")

def _resolve_openai_base() -> str | None:
    """Return the OpenAI base-URL override, or None to use the SDK default."""
    return _resolve_env("AI_INTEGRATIONS_OPENAI_BASE_URL")

def _resolve_gemini_key() -> str | None:
    """Return the Gemini API key, trying platform-scoped then common names."""
    return _resolve_env("AI_INTEGRATIONS_GEMINI_API_KEY", "GEMINI_API_KEY", "GOOGLE_API_KEY")

def _resolve_gemini_base() -> str | None:
    """Return the Gemini base-URL override, or None to use the SDK default."""
    return _resolve_env("AI_INTEGRATIONS_GEMINI_BASE_URL")

def _resolve_perplexity_key() -> str | None:
    """Return the Perplexity API key (legacy PERPLEXITY_API name first)."""
    return _resolve_env("PERPLEXITY_API", "PERPLEXITY_API_KEY")

def _resolve_huggingface_key() -> str | None:
    """Return the Hugging Face API key from either accepted env name."""
    return _resolve_env("HUGGINGFACE_API_KEY", "HF_API_KEY")


# Legacy module-level OpenAI client, kept for the SECTION 2 call functions.
# Stays None when no API key is configured; every caller checks truthiness
# before use rather than raising at import time.
_openai_key = _resolve_openai_key()
_openai_base = _resolve_openai_base()

# base_url may be None, in which case the SDK falls back to its default endpoint.
client = OpenAI(
    api_key=_openai_key,
    base_url=_openai_base,
) if _openai_key else None


def is_rate_limit_error(exception: BaseException) -> bool:
    """Heuristically detect provider rate-limit / quota errors.

    Matches on common substrings of the exception text plus an HTTP 429
    ``status_code`` attribute, since the different SDKs used here do not
    share a common rate-limit exception type. Used as a tenacity predicate.
    """
    text = str(exception)
    lowered = text.lower()
    if "429" in text or "RATELIMIT_EXCEEDED" in text:
        return True
    if "quota" in lowered or "rate limit" in lowered:
        return True
    return getattr(exception, "status_code", None) == 429


# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 2: LEGACY CALL FUNCTIONS (unchanged, backward compatible)
# ═══════════════════════════════════════════════════════════════════════════════

@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=2, max=60),
    retry=retry_if_exception(is_rate_limit_error),
    reraise=True,
)
def call_llm(prompt: str, system_prompt: str | None = None, max_tokens: int = 8192) -> str:
    """Legacy JSON-mode completion against the module-level gpt-4o client.

    Retries up to 5 times with exponential backoff on rate-limit errors.

    Args:
        prompt: user message content.
        system_prompt: optional system message prepended to the conversation.
        max_tokens: cap on completion tokens.

    Returns:
        The model's JSON text, or "" when no client is configured or the
        completion came back empty.
    """
    if not client:
        return ""
    # Consistency: reuse the shared builder instead of duplicating it inline.
    messages = _build_messages(prompt, system_prompt)

    # NOTE(review): OpenAI's json_object mode requires the word "json" to
    # appear somewhere in the messages — confirm callers' prompts comply.
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=messages,
        max_completion_tokens=max_tokens,
        response_format={"type": "json_object"},
    )
    return response.choices[0].message.content or ""


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=2, max=60),
    retry=retry_if_exception(is_rate_limit_error),
    reraise=True,
)
def call_llm_text(prompt: str, system_prompt: str | None = None, max_tokens: int = 16384) -> str:
    """Legacy plain-text completion against the module-level gpt-4o client.

    Retries up to 5 times with exponential backoff on rate-limit errors.

    Args:
        prompt: user message content.
        system_prompt: optional system message prepended to the conversation.
        max_tokens: cap on completion tokens.

    Returns:
        The model's text, or "" when no client is configured or the
        completion came back empty.
    """
    if not client:
        return ""
    # Consistency: reuse the shared builder instead of duplicating it inline.
    messages = _build_messages(prompt, system_prompt)

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=messages,
        max_completion_tokens=max_tokens,
    )
    return response.choices[0].message.content or ""


def call_llm_stream(messages: list, max_tokens: int = 4096):
    """Open a streaming gpt-4o chat completion on the legacy client.

    Returns the OpenAI streaming iterator, or None when no client is
    configured (callers must handle both cases).
    """
    if client is None:
        return None
    request = {
        "model": "gpt-4o",
        "messages": messages,
        "max_completion_tokens": max_tokens,
        "stream": True,
    }
    return client.chat.completions.create(**request)


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=30),
    retry=retry_if_exception(is_rate_limit_error),
    reraise=True,
)
def call_llm_with_image(img_b64: str, mime_type: str, domain: str, niche: str, classifications: list, filename: str = "", filename_hint: str = "") -> dict:
    """Classify an uploaded image for website use via gpt-4o vision.

    Sends the image (base64 data URL) plus classification instructions and
    re-asks once if the model's answer is incomplete, then backfills safe
    defaults so callers always receive a fully-populated dict.

    Args:
        img_b64: base64-encoded image bytes.
        mime_type: image MIME type for the data URL.
        domain / niche: site context injected into the prompt.
        classifications: allowed category labels.
        filename: original upload filename, quoted to the model as context.
        filename_hint: pre-derived label from the filename, weighted heavily.

    Returns:
        Dict with keys classification, tags, description, suggested_sections
        (plus whatever else the model emitted), or {} when no client is set.
    """
    if not client:
        return {}
    class_list = ", ".join(classifications)

    filename_context = ""
    if filename:
        # BUG FIX: this f-string previously contained no placeholder, so the
        # actual filename never reached the model.
        filename_context += f'\nThe original filename is: "{filename}".'
    if filename_hint:
        filename_context += f'\nBased on the filename, this image is likely a "{filename_hint}". Give strong weight to this hint unless the image clearly contradicts it.'

    messages = [
        {"role": "system", "content": "You are an image analyst for website design. Classify uploaded images and suggest where they should be used on a website. Pay close attention to both the visual content AND the filename when provided — filenames often contain reliable human-assigned labels. Always respond with valid JSON."},
        {"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{img_b64}", "detail": "auto"}},
            {"type": "text", "text": f"""Analyze this image for a website about "{niche}" (domain: {domain}).{filename_context}

Classify it into ONE of these categories: {class_list}

Return JSON:
{{
  "classification": "one_of_the_categories_above",
  "tags": ["tag1", "tag2", "tag3"],
  "description": "Brief description of what the image shows",
  "suggested_sections": ["hero", "about", "team", "gallery", "testimonials", "features"],
  "quality_score": 1-10,
  "is_logo": true/false
}}"""}
        ]}
    ]

    # Up to two model calls: accept a complete answer immediately; a bare
    # "other" with no description triggers one re-ask before falling through.
    for attempt in range(2):
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            max_completion_tokens=500,
            response_format={"type": "json_object"},
        )
        content = response.choices[0].message.content or "{}"
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            parsed = {}

        if parsed.get("classification") and parsed.get("classification") != "other":
            return parsed
        if parsed.get("classification") == "other" and parsed.get("description"):
            return parsed
        if attempt == 0:
            logger.warning(f"Image classification returned incomplete response (attempt 1), retrying: {content[:200]}")
            continue

    # Both attempts were incomplete: backfill defaults so the contract holds.
    if not parsed.get("classification"):
        parsed["classification"] = "other"
    if not parsed.get("tags"):
        parsed["tags"] = []
    if not parsed.get("description"):
        parsed["description"] = "Image could not be fully analyzed"
    if not parsed.get("suggested_sections"):
        parsed["suggested_sections"] = []
    return parsed


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=30),
    retry=retry_if_exception(is_rate_limit_error),
    reraise=True,
)
def generate_image(prompt: str, size: str = "1536x1024") -> bytes:
    """Generate an image with DALL-E 3 and return the raw image bytes.

    Args:
        prompt: image description sent to the model.
        size: requested dimensions. NOTE(review): DALL-E 3 accepts
            1024x1024 / 1792x1024 / 1024x1792; the 1536x1024 default looks
            like a gpt-image-1 size — confirm against the API in use.

    Raises:
        RuntimeError: when no OpenAI client is configured.
        ValueError: when the API returns no base64 payload.
    """
    if not client:
        raise RuntimeError("OpenAI client not configured for image generation")
    response = client.images.generate(
        model="dall-e-3",
        prompt=prompt,
        size=size,
        # BUG FIX: dall-e-3 defaults to response_format="url", which left
        # b64_json empty and made this function silently return b"".
        response_format="b64_json",
    )
    image_base64 = response.data[0].b64_json or ""
    if not image_base64:
        raise ValueError("Image generation returned no base64 data")
    return base64.b64decode(image_base64)


# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 3: PIPELINE STAGE REGISTRY
# ═══════════════════════════════════════════════════════════════════════════════

# Registry of pipeline stages, keyed by stage id. Per-stage fields:
#   name / description    — human-readable labels (used by the admin dashboard)
#   mode                  — call style the stage needs: "json", "text",
#                           "stream", "vision", or "image"; set_route() rejects
#                           models that do not list this mode
#   quality_tier          — tier the default model was chosen for
#   default_provider/model — route used unless overridden via set_route()
PIPELINE_STAGES = {
    "domain_analysis": {
        "name": "Domain Analysis",
        "description": "Decompose domain keywords, generate niche ideas",
        "mode": "json",
        "quality_tier": "legendary",
        "default_provider": "gemini",
        "default_model": "gemini-2.0-pro-exp-02-05",
    },
    "brand_identity": {
        "name": "Brand Identity",
        "description": "Generate brand names, taglines, color palettes, typography",
        "mode": "json",
        "quality_tier": "legendary",
        "default_provider": "gemini",
        "default_model": "gemini-2.0-pro-exp-02-05",
    },
    "site_copy": {
        "name": "Site Copy Generation",
        "description": "Generate 16 website sections with full content",
        "mode": "json",
        "quality_tier": "legendary",
        "default_provider": "gemini",
        "default_model": "gemini-2.0-pro-exp-02-05",
    },
    "sales_letter": {
        "name": "Sales Letter",
        "description": "Persuasive long-form sales copy",
        "mode": "text",
        "quality_tier": "legendary",
        "default_provider": "gemini",
        "default_model": "gemini-2.0-pro-exp-02-05",
    },
    "hero_image": {
        "name": "Hero Image Generation",
        "description": "AI-generated hero/banner images",
        "mode": "image",
        "quality_tier": "legendary",
        "default_provider": "openai",
        "default_model": "dall-e-3",
    },
    "brand_kit_classify": {
        "name": "Brand Kit Classification",
        "description": "Classify uploaded images by type and usage",
        "mode": "vision",
        "quality_tier": "premium",
        "default_provider": "openai",
        "default_model": "gpt-4o",
    },
    "advisor_chat": {
        "name": "AI Advisor Chat",
        "description": "Context-aware business strategy conversation",
        "mode": "stream",
        "quality_tier": "legendary",
        "default_provider": "gemini",
        "default_model": "gemini-2.0-pro-exp-02-05",
    },
    "force_multiplier": {
        "name": "Force Multiplier Docs",
        "description": "Generate 33 business document types across 6 tiers",
        "mode": "text",
        "quality_tier": "legendary",
        "default_provider": "gemini",
        "default_model": "gemini-2.0-pro-exp-02-05",
    },
    "content_refine": {
        "name": "Content Refinement",
        "description": "Refine and improve existing content",
        "mode": "text",
        "quality_tier": "legendary",
        "default_provider": "gemini",
        "default_model": "gemini-2.0-pro-exp-02-05",
    },
    "graphics_prompt": {
        "name": "Graphics Pack Prompts",
        "description": "Generate image prompts for graphics pack",
        "mode": "text",
        "quality_tier": "premium",
        "default_provider": "gemini",
        "default_model": "gemini-2.0-flash",
    },
    "market_research": {
        "name": "Market Research",
        "description": "Live marketplace intelligence via Perplexity Sonar",
        "mode": "text",
        "quality_tier": "premium",
        "default_provider": "perplexity",
        "default_model": "sonar-pro",
    },
}


# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 4: PROVIDERS & MODELS REGISTRY
# ═══════════════════════════════════════════════════════════════════════════════

# Registry of providers, keyed by provider id. Per-provider fields:
#   name           — display name
#   env_key        — canonical API-key env var (shown in the dashboard)
#   env_base       — optional base-URL override env var
#   portable_keys  — API-key env names tried in priority order by _get_client()
#   sdk            — "openai" (OpenAI-compatible wire protocol) or "google_genai"
#   color          — dashboard accent color
#   models         — model id -> {name, tier, input_cost/output_cost (USD per
#                    1M tokens where present), modes supported}
PROVIDERS = {
    "openai": {
        "name": "OpenAI",
        "env_key": "AI_INTEGRATIONS_OPENAI_API_KEY",
        "env_base": "AI_INTEGRATIONS_OPENAI_BASE_URL",
        "portable_keys": ["AI_INTEGRATIONS_OPENAI_API_KEY", "OPENAI_API_KEY"],
        "sdk": "openai",
        "color": "#10B981",
        "models": {
            "gpt-4o": {
                "name": "GPT-4o",
                "tier": "premium",
                "input_cost": 2.50,
                "output_cost": 10.0,
                "modes": ["json", "text", "stream", "vision"],
            },
            "gpt-4o-mini": {
                "name": "GPT-4o Mini",
                "tier": "economy",
                "input_cost": 0.15,
                "output_cost": 0.60,
                "modes": ["json", "text", "stream", "vision"],
            },
            "dall-e-3": {
                "name": "DALL-E 3",
                "tier": "legendary",
                "modes": ["image"],
            },
        },
    },
    "gemini": {
        "name": "Google Gemini",
        "env_key": "AI_INTEGRATIONS_GEMINI_API_KEY",
        "env_base": "AI_INTEGRATIONS_GEMINI_BASE_URL",
        "portable_keys": ["AI_INTEGRATIONS_GEMINI_API_KEY", "GEMINI_API_KEY", "GOOGLE_API_KEY"],
        "sdk": "google_genai",
        "color": "#3B82F6",
        "models": {
            "gemini-2.0-flash": {
                "name": "Gemini 2.0 Flash",
                "tier": "premium",
                "input_cost": 0.10,
                "output_cost": 0.40,
                "modes": ["json", "text", "stream"],
            },
            "gemini-2.0-pro-exp-02-05": {
                "name": "Gemini 2.0 Pro",
                "tier": "legendary",
                "input_cost": 1.25,
                "output_cost": 5.00,
                "modes": ["json", "text", "stream"],
            },
        },
    },
    "perplexity": {
        "name": "Perplexity",
        "env_key": "PERPLEXITY_API",
        "portable_keys": ["PERPLEXITY_API", "PERPLEXITY_API_KEY"],
        "sdk": "openai",
        "color": "#22D3EE",
        "models": {
            "sonar-pro": {
                "name": "Sonar Pro",
                "tier": "premium",
                "input_cost": 3.0,
                "output_cost": 15.0,
                "modes": ["text", "stream"],
            },
            "sonar": {
                "name": "Sonar",
                "tier": "economy",
                "input_cost": 1.0,
                "output_cost": 1.0,
                "modes": ["text", "stream"],
            },
        },
    },
}


# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 5: ROUTING CONFIG (in-memory, overridable via API)
# ═══════════════════════════════════════════════════════════════════════════════

_routing_overrides = {}


def get_route(stage: str) -> dict:
    """Get the current provider+model for a pipeline stage.

    Resolves the stage's override if one is set, otherwise its defaults,
    and returns a dict describing the effective route for dashboards/logs.

    Raises:
        ValueError: if the stage id is unknown.
    """
    if stage not in PIPELINE_STAGES:
        raise ValueError(f"Unknown pipeline stage: {stage}")

    stage_config = PIPELINE_STAGES[stage]
    override = _routing_overrides.get(stage)
    if override is None:
        provider_id = stage_config["default_provider"]
        model_id = stage_config["default_model"]
    else:
        provider_id = override["provider"]
        model_id = override["model"]

    provider_info = PROVIDERS.get(provider_id, {})
    model_info = provider_info.get("models", {}).get(model_id, {})

    return {
        "stage": stage,
        "stage_name": stage_config["name"],
        "mode": stage_config["mode"],
        "provider": provider_id,
        "provider_name": provider_info.get("name", provider_id),
        "model": model_id,
        "model_name": model_info.get("name", model_id),
        "quality_tier": model_info.get("tier", "unknown"),
        "default_quality_tier": stage_config["quality_tier"],
        "is_override": override is not None,
    }


def set_route(stage: str, provider: str, model: str) -> dict:
    """Override the routing for a pipeline stage.

    Validates stage, provider, model, and that the model supports the
    stage's required mode before recording the override.

    Raises:
        ValueError: on any unknown id or unsupported mode.
    """
    if stage not in PIPELINE_STAGES:
        raise ValueError(f"Unknown pipeline stage: {stage}")
    if provider not in PROVIDERS:
        raise ValueError(f"Unknown provider: {provider}")
    provider_models = PROVIDERS[provider]["models"]
    if model not in provider_models:
        raise ValueError(f"Unknown model '{model}' for provider '{provider}'")

    stage_mode = PIPELINE_STAGES[stage]["mode"]
    supported_modes = provider_models[model].get("modes", [])
    if stage_mode not in supported_modes:
        raise ValueError(
            f"Model '{model}' does not support mode '{stage_mode}' "
            f"required by stage '{stage}'. Supported modes: {supported_modes}"
        )

    _routing_overrides[stage] = {"provider": provider, "model": model}
    logger.info(f"Route override set: {stage} -> {provider}/{model}")
    return get_route(stage)


def reset_route(stage: str):
    """Reset a stage back to its default (quality-first) route.

    A no-op (beyond returning the current route) when no override exists.
    """
    if stage not in PIPELINE_STAGES:
        raise ValueError(f"Unknown pipeline stage: {stage}")
    if _routing_overrides.pop(stage, None) is not None:
        logger.info(f"Route override cleared: {stage} (back to default)")
    return get_route(stage)


def get_all_routes() -> dict:
    """Return all stages with their current routing (defaults + overrides)."""
    routes = {}
    for stage_id in PIPELINE_STAGES:
        routes[stage_id] = get_route(stage_id)
    return routes


# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 6: PROVIDER CLIENT FACTORY (lazy init)
# ═══════════════════════════════════════════════════════════════════════════════

# Cache of initialized SDK clients keyed by provider id. Populated lazily by
# _get_client() and evicted via invalidate_provider_cache().
_provider_clients = {}

# Default base URLs for providers that speak the OpenAI wire protocol but live
# elsewhere; used only when no env-var base-URL override is present.
_OPENAI_COMPAT_BASE_URLS = {
    "perplexity": "https://api.perplexity.ai",
}


def _get_client(provider: str):
    """Lazy-initialize and cache provider clients.

    Resolves the API key from the provider's portable env-var names, then
    builds either a native google.genai client or an OpenAI-compatible
    client and stores it in _provider_clients.

    Raises:
        ValueError: if the provider id is not in PROVIDERS.
        RuntimeError: if none of the provider's key env vars is set.
    """
    if provider in _provider_clients:
        return _provider_clients[provider]

    if provider not in PROVIDERS:
        raise ValueError(f"Unknown provider: {provider}")

    prov_cfg = PROVIDERS[provider]
    sdk_type = prov_cfg.get("sdk", "openai")
    portable_keys = prov_cfg.get("portable_keys", [prov_cfg["env_key"]])
    api_key = _resolve_env(*portable_keys)

    if not api_key:
        tried = " / ".join(portable_keys)
        raise RuntimeError(f"{prov_cfg['name']} API key not configured. Set one of: {tried}")

    if sdk_type == "google_genai":
        from google import genai
        # env_base may be absent; _resolve_env("") is safe since os.environ
        # has no "" key and so this simply yields None.
        base_url = _resolve_env(prov_cfg.get("env_base", ""))
        http_opts = {}
        if base_url:
            # NOTE(review): api_version '' presumably makes the SDK use the
            # custom base_url without appending a version path segment —
            # confirm against google-genai HttpOptions behavior.
            http_opts = {'api_version': '', 'base_url': base_url}
        c = genai.Client(api_key=api_key, http_options=http_opts)

    else:
        # OpenAI-compatible path: env override wins, then the known hosted
        # endpoint for this provider (e.g. Perplexity), else the SDK default.
        base_url = None
        if prov_cfg.get("env_base"):
            base_url = _resolve_env(prov_cfg["env_base"])
        if not base_url:
            base_url = _OPENAI_COMPAT_BASE_URLS.get(provider)
        c = OpenAI(api_key=api_key, base_url=base_url)

    _provider_clients[provider] = c
    return c


def invalidate_provider_cache(provider: str | None = None):
    """Clear cached provider clients.

    Args:
        provider: specific provider id to evict; when None (or any falsy
            value, matching the original truthiness check) every cached
            client is dropped, forcing re-initialization on next use.
    """
    if provider:
        _provider_clients.pop(provider, None)
    else:
        _provider_clients.clear()


# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 7: PROVIDER STATUS CHECK
# ═══════════════════════════════════════════════════════════════════════════════

def get_provider_status() -> dict:
    """Check which providers have valid API keys configured.

    Returns a dict keyed by provider id with configuration state and the
    name of the env var that actually supplied the key (priority order).
    """
    status = {}
    for pid, prov in PROVIDERS.items():
        key_names = prov.get("portable_keys", [prov["env_key"]])
        configured = _resolve_env(*key_names) is not None
        active_key_name = None
        if configured:
            active_key_name = next(
                (name for name in key_names if os.environ.get(name)), None
            )
        status[pid] = {
            "name": prov["name"],
            "configured": configured,
            "env_key": prov["env_key"],
            "active_key": active_key_name,
            "portable_keys": key_names,
            "sdk": prov.get("sdk", "openai"),
            "color": prov["color"],
            "model_count": len(prov["models"]),
        }
    return status


# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 8: ROUTED CALL FUNCTIONS
# ═══════════════════════════════════════════════════════════════════════════════

def _build_messages(prompt: str, system_prompt: str = None) -> list:
    """Build a standard messages array."""
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})
    return messages


def _call_openai_compatible(client_instance, model: str, messages: list, max_tokens: int, json_mode: bool = False, stream: bool = False):
    """Dispatch a call to any OpenAI-compatible API."""
    kwargs = {
        "model": model,
        "messages": messages,
        "max_completion_tokens": max_tokens,
    }
    if json_mode:
        kwargs["response_format"] = {"type": "json_object"}
    if stream:
        kwargs["stream"] = True
        return client_instance.chat.completions.create(**kwargs)

    response = client_instance.chat.completions.create(**kwargs)
    return response.choices[0].message.content or ""


def _call_gemini_native(client_instance, model: str, messages: list, max_tokens: int, json_mode: bool = False, stream: bool = False):
    """Dispatch a call to Gemini via native google.genai SDK.

    Flattens OpenAI-style messages: system messages collapse into a single
    system_instruction, everything else concatenates into one prompt string.
    Streaming calls return the SDK's stream iterator; otherwise the response
    text ("" when empty).
    """
    from google.genai import types as genai_types

    # Partition by role — Gemini takes system text separately from contents.
    system_chunks = [m["content"] for m in messages if m["role"] == "system"]
    user_chunks = [m["content"] for m in messages if m["role"] != "system"]

    config_kwargs = {"max_output_tokens": max_tokens}
    if json_mode:
        config_kwargs["response_mime_type"] = "application/json"

    system_instruction = "\n\n".join(system_chunks)
    if system_instruction:
        config_kwargs["system_instruction"] = system_instruction

    config = genai_types.GenerateContentConfig(**config_kwargs)
    contents = "\n\n".join(user_chunks)

    if stream:
        return client_instance.models.generate_content_stream(
            model=model, contents=contents, config=config
        )

    response = client_instance.models.generate_content(
        model=model, contents=contents, config=config
    )
    return response.text or ""


def _dispatch_call(provider: str, client_instance, model: str, messages: list, max_tokens: int, json_mode: bool = False, stream: bool = False):
    """Universal dispatcher: pick the SDK-appropriate call helper."""
    sdk_type = PROVIDERS.get(provider, {}).get("sdk", "openai")
    handler = _call_gemini_native if sdk_type == "google_genai" else _call_openai_compatible
    return handler(client_instance, model, messages, max_tokens, json_mode=json_mode, stream=stream)


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=2, max=60),
    retry=retry_if_exception(is_rate_limit_error),
    reraise=True,
)
def call_llm_routed(stage: str, prompt: str, system_prompt: str = None, max_tokens: int = 8192) -> str:
    """Route a JSON-mode LLM call through the stage's configured provider."""
    route = get_route(stage)
    provider, model = route["provider"], route["model"]

    logger.info(f"[ROUTED] {stage} -> {provider}/{model} (tier: {route['quality_tier']}, mode: json)")

    return _dispatch_call(
        provider,
        _get_client(provider),
        model,
        _build_messages(prompt, system_prompt),
        max_tokens,
        json_mode=True,
    )


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=2, max=60),
    retry=retry_if_exception(is_rate_limit_error),
    reraise=True,
)
def call_llm_text_routed(stage: str, prompt: str, system_prompt: str = None, max_tokens: int = 16384) -> str:
    """Route a text-mode LLM call through the stage's configured provider."""
    route = get_route(stage)
    provider, model = route["provider"], route["model"]

    logger.info(f"[ROUTED] {stage} -> {provider}/{model} (tier: {route['quality_tier']}, mode: text)")

    return _dispatch_call(
        provider,
        _get_client(provider),
        model,
        _build_messages(prompt, system_prompt),
        max_tokens,
        json_mode=False,
    )


def call_llm_stream_routed(stage: str, messages: list, max_tokens: int = 4096):
    """Route a streaming LLM call through the stage's configured provider."""
    route = get_route(stage)
    provider, model = route["provider"], route["model"]

    logger.info(f"[ROUTED] {stage} -> {provider}/{model} (tier: {route['quality_tier']}, mode: stream)")

    return _dispatch_call(provider, _get_client(provider), model, messages, max_tokens, stream=True)


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=30),
    retry=retry_if_exception(is_rate_limit_error),
    reraise=True,
)
def generate_image_routed(stage: str, prompt: str, size: str = "1536x1024") -> bytes:
    """Route an image generation call through the stage's configured provider.

    Gemini image output is extracted from inline_data parts; OpenAI-compatible
    providers use the Images API.

    Returns:
        Raw image bytes.

    Raises:
        ValueError: when no image data is returned, or the provider's SDK
            does not support image generation.
    """
    route = get_route(stage)
    provider = route["provider"]
    model = route["model"]
    sdk_type = PROVIDERS.get(provider, {}).get("sdk", "openai")

    logger.info(f"[ROUTED] {stage} -> {provider}/{model} (tier: {route['quality_tier']}, mode: image)")

    client_instance = _get_client(provider)

    if sdk_type == "google_genai":
        from google.genai import types as genai_types
        response = client_instance.models.generate_content(
            model=model,
            contents=prompt,
            config=genai_types.GenerateContentConfig(
                response_modalities=["TEXT", "IMAGE"]
            ),
        )
        if response.candidates:
            for part in response.candidates[0].content.parts:
                # inline_data carries the generated image; data may arrive as
                # base64 text or raw bytes depending on transport.
                if hasattr(part, 'inline_data') and part.inline_data:
                    img_data = part.inline_data.data
                    if isinstance(img_data, str):
                        return base64.b64decode(img_data)
                    return img_data
        raise ValueError("Gemini image generation returned no image data")

    elif sdk_type == "openai":
        # BUG FIX: dall-e-3 defaults to response_format="url", which left
        # b64_json empty and made this branch silently return b"".
        # NOTE(review): gpt-image-1 rejects this parameter (always base64) —
        # revisit if such a model is added to the registry.
        response = client_instance.images.generate(
            model=model,
            prompt=prompt,
            size=size,
            response_format="b64_json",
        )
        image_base64 = response.data[0].b64_json or ""
        if not image_base64:
            raise ValueError("Image generation returned no base64 data")
        return base64.b64decode(image_base64)

    else:
        raise ValueError(f"Image generation not supported for provider: {provider} (sdk: {sdk_type})")


# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 9: ADMIN DASHBOARD API HELPER
# ═══════════════════════════════════════════════════════════════════════════════

def get_routing_dashboard_data() -> dict:
    """Return everything the admin routing page needs: stages, routes, providers, status."""
    dashboard = {"stages": PIPELINE_STAGES}
    dashboard["routes"] = get_all_routes()
    dashboard["providers"] = PROVIDERS
    dashboard["provider_status"] = get_provider_status()
    return dashboard
