"""Kenya statutory payroll deductions — configurable via HrPayrollSettings."""
from __future__ import annotations

from datetime import datetime
from decimal import Decimal, ROUND_HALF_UP

from human_resource.payroll_settings import default_payroll_settings_map, resolve_payroll_settings

# Title of the synthetic deduction line that recovers outstanding employee
# advances. append_advance_recovery_deduction writes it on the line it adds and
# also matches on it to avoid adding a second recovery line.
ADVANCE_DEDUCTION_TITLE = "Employee Advance Recovery"


def _decimal(value) -> Decimal:
    return Decimal(str(value or 0).replace(",", ""))


def _quantize(value: Decimal) -> Decimal:
    return value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)


def _format_amount(value: Decimal | float | int) -> str:
    """Render *value* as a two-decimal string such as ``"1234.50"``."""
    amount = _quantize(_decimal(value))
    return str(amount)


def _settings_map(settings=None) -> dict:
    return settings if settings is not None else default_payroll_settings_map()


def _flag(value) -> bool:
    return str(value or "").lower() in ("true", "1", "yes", "on")


def should_skip_deduction_title(title: str) -> bool:
    """Return True for NHIF lines, which are always dropped from payroll (SHA replaces NHIF).

    Matching ignores case and surrounding whitespace.
    """
    normalized = (str(title) if title else "").strip()
    return normalized.upper() == "NHIF"


def is_sha_deduction(title: str) -> bool:
    """Return True when *title* names the SHA/SHIF contribution (case/whitespace-insensitive)."""
    return str(title or "").strip().upper() in {"SHA", "SHIF"}


def normalize_statutory_label(title: str, settings=None) -> str:
    """Return the display label for a deduction title.

    A title of ``SHIF`` (case/whitespace-insensitive) becomes ``"SHA"`` when the
    ``normalize_shif_to_sha`` setting is truthy (fallback ``"true"``). Every other
    title, including an un-normalized SHIF, is returned as ``str(title or "")``.
    The settings are only read for SHIF titles.
    """
    original = str(title or "")
    if original.strip().upper() != "SHIF":
        return original
    rename_to_sha = _flag(_settings_map(settings).get("normalize_shif_to_sha", "true"))
    return "SHA" if rename_to_sha else original


def calculate_shif_contribution(gross_salary, settings=None) -> Decimal:
    """Return the SHA/SHIF contribution for *gross_salary*, rounded to cents.

    The contribution is gross × ``shif_contribution_rate`` (fallback ``0.0275``
    when the key is missing from the settings map).
    """
    contribution_base = _decimal(gross_salary)
    configured = _settings_map(settings)
    rate = _decimal(configured.get("shif_contribution_rate", "0.0275"))
    return _quantize(contribution_base * rate)


def calculate_nssf_contribution(gross_salary, settings=None) -> Decimal:
    """Return the employee NSSF contribution for *gross_salary*, rounded to cents.

    Tier I is ``nssf_contribution_rate`` × ``nssf_tier1_limit``. Tier II applies
    the same rate to the part of gross above the Tier I limit, capped at
    (``nssf_tier2_limit`` − ``nssf_tier1_limit``).

    NOTE(review): pay at or below the Tier I limit (including zero) is still
    charged the full Tier I amount (rate × limit). Presumably this is an
    intentional minimum; confirm against the NSSF rules in force.
    """
    config = _settings_map(settings)
    gross = _decimal(gross_salary)
    tier1_limit = _decimal(config.get("nssf_tier1_limit", "8000"))
    tier2_limit = _decimal(config.get("nssf_tier2_limit", "72000"))
    rate = _decimal(config.get("nssf_contribution_rate", "0.06"))

    if gross <= tier1_limit:
        return _quantize(rate * tier1_limit)

    tier2_band = tier2_limit - tier1_limit
    tier2_pensionable = min(gross - tier1_limit, tier2_band)
    return _quantize(rate * tier1_limit + rate * tier2_pensionable)


def compute_gross_salary(
    basic_salary,
    prorate_factor,
    bonus_list,
    commissions_total,
) -> Decimal:
    """Return gross pay, rounded to cents.

    Gross = basic salary × prorate factor + commissions + the sum of every
    bonus's ``bonus_instance_value``.
    """
    prorated_basic = _decimal(basic_salary) * _decimal(prorate_factor)
    commissions = _decimal(commissions_total)
    bonuses = Decimal("0")
    for bonus in bonus_list or []:
        bonuses += _decimal(bonus.get("bonus_instance_value"))
    return _quantize(prorated_basic + commissions + bonuses)


def calculate_insurance_relief(
    salary_gross: Decimal,
    nhif_value: Decimal,
    deduction_list,
    settings=None,
) -> Decimal:
    """Return the PAYE insurance relief implied by *deduction_list*.

    Premiums are summed over lines whose ``deduction_type`` is ``"insurance"``.
    SHA/SHIF lines are excluded, matched case- and whitespace-insensitively via
    ``is_sha_deduction``. A line with ``deduction_module == "percentage"``
    contributes that percent of *salary_gross*; any other line contributes its
    ``deduction_value`` as an amount.

    The relief is ``paye_insurance_relief_rate`` × (premiums + *nhif_value*),
    capped at ``paye_insurance_relief_cap``.

    NOTE: a line whose ``deduction_value`` is already a computed amount must not
    be passed with module ``"percentage"``, or that amount is re-read as a
    percent of gross.

    Args:
        salary_gross: Gross pay, used as the base for percentage lines.
        nhif_value: Extra premium amount added before applying the rate.
        deduction_list: Deduction line dicts, or ``None``.
        settings: Payroll settings map; ``None`` uses the defaults.

    Returns:
        The (unrounded) relief as a ``Decimal``.
    """
    s = _settings_map(settings)
    premiums = Decimal("0")

    for deduction in deduction_list or []:
        if deduction.get("deduction_type") != "insurance":
            continue
        # SHA/SHIF is not an insurance-relief premium; calculate_paye already
        # deducts it from taxable pay. Strip-aware matching keeps " SHA " out too.
        if is_sha_deduction(deduction.get("deduction_title")):
            continue
        value = _decimal(deduction.get("deduction_value"))
        if deduction.get("deduction_module") == "percentage":
            premiums += (value / Decimal("100")) * salary_gross
        else:
            premiums += value

    relief_rate = _decimal(s.get("paye_insurance_relief_rate", "0.15"))
    relief_cap = _decimal(s.get("paye_insurance_relief_cap", "5000"))
    relief = relief_rate * (premiums + nhif_value)
    return min(relief, relief_cap)


def calculate_paye(gross_salary, deduction_list, settings=None) -> Decimal:
    """Return the PAYE due on *gross_salary*, floored at zero.

    Taxable pay is gross minus:
    - the NSSF contribution;
    - the SHIF/SHA contribution;
    - when ``include_housing_levy_in_paye`` is enabled, the housing levy
      (``housing_levy_rate`` × gross).

    Taxable pay runs through four consecutive bands. ``paye_bandN_limit`` is
    the *width* of band N, taxed at ``paye_bandN_rate``. Everything beyond the
    four widths is taxed at ``paye_band5_rate``. Personal relief and the
    insurance relief derived from *deduction_list* are then subtracted.

    ``paye_band5_threshold`` is not consulted. The top band starts where the
    four band widths end (24,000 + 8,333 + 467,667 + 300,000 = 800,000 with the
    fallback values).

    Args:
        gross_salary: Gross pay; anything accepted by ``_decimal``.
        deduction_list: Deduction lines; only used to derive insurance relief.
        settings: Payroll settings map; ``None`` uses the defaults.

    Returns:
        PAYE rounded to cents, or ``Decimal("0")`` when reliefs exceed the tax.
    """
    s = _settings_map(settings)
    personal_relief = _decimal(s.get("paye_personal_relief", "2400"))
    salary_gross = _decimal(gross_salary)

    housing_levy_value = Decimal("0")
    if _flag(s.get("include_housing_levy_in_paye", "true")):
        housing_levy_value = salary_gross * _decimal(s.get("housing_levy_rate", "0.015"))

    shif_value = calculate_shif_contribution(salary_gross, s)
    pension_contribution = calculate_nssf_contribution(salary_gross, s)
    taxable_pay = salary_gross - pension_contribution - shif_value - housing_levy_value

    # (band width, marginal rate) for the bounded bands, lowest first.
    bounded_bands = [
        (_decimal(s.get("paye_band1_limit", "24000")), _decimal(s.get("paye_band1_rate", "0.10"))),
        (_decimal(s.get("paye_band2_limit", "8333")), _decimal(s.get("paye_band2_rate", "0.25"))),
        (_decimal(s.get("paye_band3_limit", "467667")), _decimal(s.get("paye_band3_rate", "0.30"))),
        (_decimal(s.get("paye_band4_limit", "300000")), _decimal(s.get("paye_band4_rate", "0.325"))),
    ]
    top_rate = _decimal(s.get("paye_band5_rate", "0.35"))

    paye_value = Decimal("0")
    remaining = taxable_pay
    for width, rate in bounded_bands:
        if remaining <= width:
            # Taxable pay ends inside this band (zero or negative pay ends in
            # band 1): tax only what is left.
            paye_value += remaining * rate
            break
        paye_value += width * rate
        remaining -= width
    else:
        # All bounded bands are full: the entire excess above them is taxed at
        # the top rate, not only the part beyond some further threshold.
        paye_value += remaining * top_rate

    insurance_relief = calculate_insurance_relief(salary_gross, Decimal("0"), deduction_list, s)
    net_paye = paye_value - personal_relief - insurance_relief
    if net_paye < 0:
        return Decimal("0")
    return _quantize(net_paye)


def get_paye_value(deduction_list) -> Decimal:
    """Return the value of the first line titled ``PAYE`` (case-insensitive), or 0 if none."""
    paye_lines = (
        line
        for line in deduction_list or []
        if str(line.get("deduction_title") or "").upper() == "PAYE"
    )
    first = next(paye_lines, None)
    if first is None:
        return Decimal("0")
    return _decimal(first.get("deduction_value"))


def get_deductions_total_ex_paye(deduction_list) -> Decimal:
    """Return the sum of every deduction value except lines titled ``PAYE`` (case-insensitive)."""
    return sum(
        (
            _decimal(line.get("deduction_value"))
            for line in deduction_list or []
            if str(line.get("deduction_title") or "").upper() != "PAYE"
        ),
        Decimal("0"),
    )


def build_deduction_instances(schemes, gross_salary, settings=None) -> list[dict]:
    """Turn deduction schemes into concrete deduction lines for *gross_salary*.

    How each scheme is handled:
    - NHIF schemes are dropped.
    - NSSF and SHA/SHIF lines are computed from the payroll settings.
    - ``"percentage"`` schemes become that percent of gross.
    - Every other scheme uses its ``deduction_value`` as a fixed amount.
    - PAYE lines are filled in last, from the other computed lines.

    Each returned line keeps the scheme's ``deduction_module``, but its
    ``deduction_value`` is always a two-decimal amount string.

    Args:
        schemes: Scheme dicts (``deduction_id``, ``deduction_title``,
            ``deduction_type``, ``deduction_module``, ``deduction_value``), or ``None``.
        gross_salary: Gross pay used as the base for every computation.
        settings: Payroll settings map; ``None`` uses the defaults.

    Returns:
        A new list of deduction-line dicts.
    """
    s = _settings_map(settings)
    gross = _decimal(gross_salary)
    drafts: list[dict] = []

    for scheme in schemes or []:
        title = str(scheme.get("deduction_title") or "")
        if should_skip_deduction_title(title):
            continue

        module = scheme.get("deduction_module")
        # Strip like is_sha_deduction does, so a padded " NSSF " title is still
        # recognised as the statutory NSSF line.
        if title.strip().upper() == "NSSF":
            amount = calculate_nssf_contribution(gross, s)
        elif is_sha_deduction(title):
            amount = calculate_shif_contribution(gross, s)
        elif module == "percentage":
            amount = (_decimal(scheme.get("deduction_value")) / Decimal("100")) * gross
        else:
            amount = scheme.get("deduction_value")

        drafts.append({
            "deduction_id": str(scheme.get("deduction_id") or ""),
            "deduction_title": normalize_statutory_label(title, s),
            "deduction_type": scheme.get("deduction_type") or "",
            "deduction_module": module or "",
            "deduction_value": _format_amount(amount),
        })

    paye_lines = [
        draft
        for draft in drafts
        if str(draft.get("deduction_title") or "").upper() == "PAYE"
    ]
    if paye_lines:
        # Draft values are already amounts. Relabel every line as "fixed" for
        # the PAYE pass, so calculate_insurance_relief does not read a computed
        # percentage-insurance amount as a percentage of gross a second time.
        relief_basis = [{**draft, "deduction_module": "fixed"} for draft in drafts]
        paye_amount = _format_amount(calculate_paye(gross, relief_basis, s))
        for draft in paye_lines:
            draft["deduction_value"] = paye_amount

    return drafts


def compute_advance_installment(advance: dict) -> Decimal:
    """Return this period's recovery installment for a single advance.

    Returns 0 when ``balance_outstanding`` is not positive. Otherwise the result
    depends on ``recovery_module`` (lower-cased; missing means ``"fixed"``):
    - ``"percentage"``: ``recovery_value`` is clamped to [0, 100] and that
      percent of the balance is taken, rounded to cents;
    - ``"fixed"`` / ``"amount"``: ``min(balance, recovery_value)``, or 0 when
      the value is not positive;
    - anything else: the whole outstanding balance.
    """
    outstanding = _decimal(advance.get("balance_outstanding"))
    if outstanding <= 0:
        return Decimal("0")

    recovery_module = str(advance.get("recovery_module") or "fixed").lower()
    recovery_value = _decimal(advance.get("recovery_value"))

    if recovery_module == "percentage":
        clamped_pct = min(max(recovery_value, Decimal("0")), Decimal("100"))
        return _quantize(outstanding * clamped_pct / Decimal("100"))

    if recovery_module not in ("fixed", "amount"):
        return outstanding

    if recovery_value <= 0:
        return Decimal("0")
    return min(outstanding, recovery_value)


def compute_staff_advance_recovery(staff_map: dict) -> Decimal:
    """Return this period's total advance recovery for a staff member.

    With a non-empty ``employee_advances_list``, sum each advance's installment
    (``compute_advance_installment``) and round to cents. Otherwise fall back to
    ``employee_advance_recovery_suggested`` as-is.
    """
    advances = staff_map.get("employee_advances_list") or []
    if not advances:
        return _decimal(staff_map.get("employee_advance_recovery_suggested"))
    installments = (compute_advance_installment(advance) for advance in advances)
    return _quantize(sum(installments, Decimal("0")))


def append_advance_recovery_deduction(
    staff_map: dict,
    deduction_list: list[dict],
    gross_salary: Decimal,
) -> list[dict]:
    """Return *deduction_list* with an advance-recovery line appended when one is due.

    The input list itself is returned unchanged in these cases:
    - there is no positive ``employee_advance_balance``;
    - there is no ``employee_advance_deduction_id``;
    - a line with that id, or titled ``ADVANCE_DEDUCTION_TITLE``, is already present;
    - the computed recovery is not positive.

    The recovery is the smallest of three amounts:
    - the suggested recovery (``compute_staff_advance_recovery``);
    - the outstanding balance;
    - the headroom left in gross pay after PAYE and all other non-recovery deductions.
    When added, it goes on a new list.
    """
    outstanding = _decimal(staff_map.get("employee_advance_balance"))
    advance_id = str(staff_map.get("employee_advance_deduction_id") or "")
    if outstanding <= 0 or not advance_id:
        return deduction_list

    already_present = any(
        line.get("deduction_id") == advance_id
        or line.get("deduction_title") == ADVANCE_DEDUCTION_TITLE
        for line in deduction_list
    )
    if already_present:
        return deduction_list

    excluded_titles = ("PAYE", ADVANCE_DEDUCTION_TITLE.upper())
    other_deductions = sum(
        (
            _decimal(line.get("deduction_value"))
            for line in deduction_list
            if str(line.get("deduction_title") or "").upper() not in excluded_titles
        ),
        Decimal("0"),
    )
    headroom = max(Decimal("0"), gross_salary - other_deductions - get_paye_value(deduction_list))
    recovery = min(compute_staff_advance_recovery(staff_map), outstanding, headroom)
    if recovery <= 0:
        return deduction_list

    recovery_line = {
        "deduction_id": advance_id,
        "deduction_title": ADVANCE_DEDUCTION_TITLE,
        "deduction_value": _format_amount(recovery),
        "deduction_type": "loan_repayment",
        "deduction_module": "fixed",
    }
    return [*deduction_list, recovery_line]


def _active_company_deduction_queryset(company_profile, current_date=None):
    """Return the company's deductions that are in effect on *current_date*.

    Despite the name, this returns a plain list. It contains the
    non-recycled ``company_deductions`` (newest ``id`` first), excluding:
    - deductions whose ``date_effective_from`` is after *current_date*
      (not yet in force);
    - deductions whose ``date_effective_to`` is before *current_date* (expired);
    - NHIF deductions.

    Args:
        company_profile: Company whose deductions are read; ``None`` yields ``[]``.
        current_date: Reference date; defaults to today's local date.
    """
    if company_profile is None:
        return []
    current_date = current_date or datetime.now().date()
    active: list = []
    for deduction in company_profile.company_deductions.filter(recycle_bin=False).order_by("-id"):
        starts = deduction.date_effective_from
        ends = deduction.date_effective_to
        if starts and starts > current_date:
            continue  # scheduled for the future, not yet active
        if ends and ends < current_date:
            continue  # expired
        if should_skip_deduction_title(deduction.deduction_title):
            continue
        active.append(deduction)
    return active


def company_deduction_schemes_list(company_profile, current_date=None) -> list[dict]:
    """Return the company's active deductions as staff-style scheme dicts.

    Used as the fallback scheme list when a staff member has no schemes of their own.
    """
    return [
        {
            "deduction_id": str(deduction.id),
            "deduction_title": deduction.deduction_title,
            "deduction_description": deduction.deduction_description,
            "deduction_type": deduction.deduction_type,
            "deduction_module": deduction.deduction_module,
            "deduction_value": deduction.deduction_value,
        }
        for deduction in _active_company_deduction_queryset(company_profile, current_date)
    ]


def company_deduction_catalog_list(company_profile, current_date=None) -> list[dict]:
    """Return the full active company deduction catalog for payroll UI/API payloads.

    Each entry carries the scheme fields plus the statutory formula and formula
    config (``""`` when empty). It also has the effective-date window as
    ``isoformat()`` strings, or ``None`` when a date is unset.
    """
    def _iso_or_none(day):
        return day.isoformat() if day else None

    return [
        {
            "deduction_id": str(deduction.id),
            "deduction_title": deduction.deduction_title,
            "deduction_description": deduction.deduction_description,
            "deduction_type": deduction.deduction_type,
            "deduction_module": deduction.deduction_module,
            "deduction_value": deduction.deduction_value,
            "statutory_formula": deduction.statutory_formula or "",
            "formula_config": deduction.formula_config or "",
            "date_effective_from": _iso_or_none(deduction.date_effective_from),
            "date_effective_to": _iso_or_none(deduction.date_effective_to),
        }
        for deduction in _active_company_deduction_queryset(company_profile, current_date)
    ]


def _is_statutory_catalog_title(title: str) -> bool:
    key = str(title or "").strip().upper()
    if key in ("NSSF", "SHA", "SHIF", "PAYE"):
        return True
    return "HOUSING" in key and "LEVY" in key


def _has_statutory_scheme(schemes: list[dict], title: str) -> bool:
    key = str(title or "").strip().upper()
    for scheme in schemes or []:
        scheme_key = str(scheme.get("deduction_title") or "").strip().upper()
        if key == scheme_key:
            return True
        if key in ("SHA", "SHIF") and scheme_key in ("SHA", "SHIF"):
            return True
    return False


def resolve_staff_deduction_schemes(staff_map: dict, company_profile=None) -> list[dict]:
    """Return the staff member's deduction schemes, topped up with company statutory ones.

    - With no staff schemes, the company's active schemes are returned.
    - Otherwise a copy of the staff schemes is returned, plus a copy of each
      statutory company scheme (``_is_statutory_catalog_title``) the staff list
      does not already cover (``_has_statutory_scheme``).
    """
    own_schemes = list(staff_map.get("staff_deduction_schemes_list") or [])
    if company_profile is None:
        from_company = []
    else:
        from_company = company_deduction_schemes_list(company_profile)

    if not own_schemes:
        return from_company

    merged = own_schemes
    for candidate in from_company:
        candidate_title = candidate.get("deduction_title") or ""
        if not _is_statutory_catalog_title(candidate_title):
            continue
        if _has_statutory_scheme(merged, candidate_title):
            continue
        merged.append(dict(candidate))
    return merged


def build_payroll_preview_from_staff_data(
    staff_map: dict,
    *,
    prorate_factor: str = "1",
    commissions_total: str = "0.00",
    is_prorated: bool = False,
    settings=None,
    company_profile=None,
) -> dict:
    """Build a payroll preview (gross, net and deduction lines) for one staff member.

    Settings are taken from the first available source:
    1. *settings*, if truthy;
    2. ``staff_map["_payroll_settings"]``;
    3. ``resolve_payroll_settings(company_profile)``, when a profile is given;
    4. the default settings map.
    Deduction schemes come from ``resolve_staff_deduction_schemes``.

    Args:
        staff_map: Staff payload (``basic_salary``, scheme, bonus and advance fields).
        prorate_factor: Multiplier applied to basic salary.
        commissions_total: Commission amount added to gross.
        is_prorated: Only echoed back in the result.
        settings: Explicit payroll settings map.
        company_profile: Company used for settings and fallback/statutory schemes.

    Returns:
        ``{"payroll_preview_available": "false"}`` when basic salary is not
        positive or no schemes are available. Otherwise a dict of
        ``payroll_preview_*`` fields (amounts as two-decimal strings) plus
        ``staff_deduction_schemes_list``, where each scheme has a
        ``computed_deduction_value``.
    """
    s = settings or staff_map.get("_payroll_settings")
    if s is None and company_profile is not None:
        s = resolve_payroll_settings(company_profile)
    if s is None:
        s = default_payroll_settings_map()

    basic_salary = _decimal(staff_map.get("basic_salary"))
    schemes = resolve_staff_deduction_schemes(staff_map, company_profile)

    if basic_salary <= 0 or not schemes:
        return {"payroll_preview_available": "false"}

    bonus_list = [
        {
            "bonus_id": str(b.get("bonus_id") or ""),
            "bonus_instance_value": _format_amount(b.get("bonus_amount")),
        }
        for b in (staff_map.get("staff_bonus_schemes_list") or [])
    ]

    gross_salary = compute_gross_salary(
        staff_map.get("basic_salary"),
        prorate_factor,
        bonus_list,
        commissions_total,
    )

    deduction_list = build_deduction_instances(schemes, gross_salary, s)
    # Re-derive PAYE from the computed lines. Their values are already amounts,
    # so present them as "fixed" to the insurance-relief calculation. Otherwise
    # a percentage insurance line's amount would be re-read as a percent of gross.
    relief_basis = [{**item, "deduction_module": "fixed"} for item in deduction_list]
    paye_amount = None
    for item in deduction_list:
        if str(item.get("deduction_title") or "").upper() == "PAYE":
            if paye_amount is None:
                paye_amount = _format_amount(calculate_paye(gross_salary, relief_basis, s))
            item["deduction_value"] = paye_amount

    deduction_list = append_advance_recovery_deduction(staff_map, deduction_list, gross_salary)
    net_salary = gross_salary - get_paye_value(deduction_list) - get_deductions_total_ex_paye(deduction_list)

    # Upper-cased title -> computed value; with duplicate titles the last line wins.
    computed_by_title = {
        str(d.get("deduction_title") or "").upper(): d.get("deduction_value")
        for d in deduction_list
    }

    enriched_schemes = []
    for scheme in schemes:
        scheme_copy = dict(scheme)
        title_key = normalize_statutory_label(scheme.get("deduction_title") or "", s).upper()
        if should_skip_deduction_title(scheme.get("deduction_title") or ""):
            scheme_copy["computed_deduction_value"] = "0.00"
        elif title_key in computed_by_title:
            scheme_copy["computed_deduction_value"] = computed_by_title[title_key]
        elif str(scheme.get("deduction_title") or "").upper() in computed_by_title:
            scheme_copy["computed_deduction_value"] = computed_by_title[
                str(scheme.get("deduction_title") or "").upper()
            ]
        else:
            scheme_copy["computed_deduction_value"] = "0.00"
        enriched_schemes.append(scheme_copy)

    return {
        "payroll_preview_available": "true",
        "payroll_preview_gross_salary": _format_amount(gross_salary),
        "payroll_preview_net_salary": _format_amount(net_salary),
        "payroll_preview_prorate_factor": str(prorate_factor),
        "payroll_preview_commissions_total": _format_amount(commissions_total),
        "payroll_preview_is_prorated": "true" if is_prorated else "false",
        "payroll_preview_bonus_instance_list": bonus_list,
        "payroll_preview_deduction_instance_list": deduction_list,
        "staff_deduction_schemes_list": enriched_schemes,
    }


def preview_sample_payroll(gross_salary, settings=None) -> dict:
    """Compute a statutory-only sample payslip for the HR settings screen.

    NSSF, SHA/SHIF and PAYE are derived from *gross_salary* and the settings.
    The sample net is gross minus those three amounts. All values are returned
    as two-decimal strings.
    """
    config = _settings_map(settings)
    sample_gross = _decimal(gross_salary)
    nssf_amount = calculate_nssf_contribution(sample_gross, config)
    sha_amount = calculate_shif_contribution(sample_gross, config)

    def _statutory_line(title, value):
        return {
            "deduction_title": title,
            "deduction_type": "statutory",
            "deduction_module": "fixed",
            "deduction_value": value,
        }

    statutory_lines = [
        _statutory_line("NSSF", str(nssf_amount)),
        _statutory_line("SHA", str(sha_amount)),
        _statutory_line("PAYE", "0.00"),
    ]
    paye_amount = calculate_paye(sample_gross, statutory_lines, config)
    net_amount = sample_gross - (nssf_amount + sha_amount + paye_amount)

    return {
        "sample_gross_salary": _format_amount(sample_gross),
        "sample_nssf": _format_amount(nssf_amount),
        "sample_shif": _format_amount(sha_amount),
        "sample_paye": _format_amount(paye_amount),
        "sample_net_salary": _format_amount(net_amount),
    }
