from typing import Optional

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.models import Augment


def find_duplicate_augments(db: Session, domain_name: Optional[str] = None):
    """Find groups of augments that share the same ``(domain_name, title)``.

    Args:
        db: Active SQLAlchemy session used for the lookups.
        domain_name: If truthy, only augments in this domain are examined;
            ``None`` (or an empty string) examines every domain.

    Returns:
        A list with one dict per duplicate group, containing:

        - ``domain_name`` / ``title``: the shared key of the group.
        - ``count``: number of rows in the group.
        - ``keep_id``: the smallest id in the group (the row to retain).
        - ``ids``: every id in the group, in ascending order.
        - ``remove_ids``: every id except ``keep_id``, in ascending order.
    """
    query = (
        db.query(Augment.domain_name, Augment.title)
        .group_by(Augment.domain_name, Augment.title)
        .having(func.count(Augment.id) > 1)
        # Deterministic group order for reports and repeated runs.
        .order_by(Augment.domain_name, Augment.title)
    )

    if domain_name:
        query = query.filter(Augment.domain_name == domain_name)

    groups = []
    # NOTE: this issues one extra query per duplicate group (N+1); fine for
    # occasional maintenance runs, worth batching if groups become numerous.
    for row in query.all():
        # ``column == None`` compiles to ``IS NULL`` in SQLAlchemy, so groups
        # whose domain_name/title is NULL still match here.
        all_ids = [
            r.id
            for r in db.query(Augment.id)
            .filter(Augment.domain_name == row.domain_name, Augment.title == row.title)
            .order_by(Augment.id)
            .all()
        ]

        # Derive everything from this single result set instead of mixing in
        # the aggregate query's min(id): if rows changed between the two
        # queries, that min id might be gone, and excluding it would then mark
        # *every* row of the group for removal.
        if len(all_ids) < 2:
            continue
        keep_id, *remove_ids = all_ids

        groups.append({
            "domain_name": row.domain_name,
            "title": row.title,
            "count": len(all_ids),
            "keep_id": keep_id,
            "ids": all_ids,
            "remove_ids": remove_ids,
        })
    return groups


def remove_duplicate_augments(
    db: Session,
    dry_run: bool = True,
    domain_name: Optional[str] = None,
    batch_size: int = 500,
):
    """Delete duplicate augments, keeping the lowest-id row of each group.

    Duplicate groups are those reported by ``find_duplicate_augments``, i.e.
    rows sharing the same ``domain_name`` and ``title``.

    Args:
        db: Active SQLAlchemy session. On a real run this function commits
            it, or rolls it back (and re-raises) if the delete/commit fails.
        dry_run: If true (the default), nothing is deleted and the planned
            removals are returned in ``details``.
        domain_name: Optional domain filter forwarded to
            ``find_duplicate_augments``.
        batch_size: Maximum number of ids per DELETE statement, so the
            ``IN (...)`` list stays under database bound-parameter limits.
            All batches run in one transaction with a single commit.

    Returns:
        A dict with:

        - ``removed``: on a dry run, the number of rows that would be
          deleted; otherwise the number of rows the database reports as
          deleted.
        - ``groups``: number of duplicate groups found.
        - ``details``: the group dicts on a dry run, otherwise ``[]``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    groups = find_duplicate_augments(db, domain_name=domain_name)
    remove_ids = [i for g in groups for i in g["remove_ids"]]
    removed = len(remove_ids)

    if not dry_run and remove_ids:
        removed = 0
        try:
            for start in range(0, len(remove_ids), batch_size):
                chunk = remove_ids[start:start + batch_size]
                # Query.delete() returns the number of rows matched.
                removed += (
                    db.query(Augment)
                    .filter(Augment.id.in_(chunk))
                    .delete(synchronize_session=False)
                )
            db.commit()
        except Exception:
            # Leave the session usable for the caller instead of in a
            # failed-transaction state; propagate the original error.
            db.rollback()
            raise

    return {
        "removed": removed,
        "groups": len(groups),
        "details": groups if dry_run else [],
    }
