"""
Business logic for extended Finance & Accounting features.
"""
from __future__ import annotations

import csv
import io
import json
from datetime import date, timedelta
from decimal import Decimal

from django.db.models import Sum
from django.utils import timezone

from . import financial_reporting as fr
from .models import (
    AccountHistory,
    AccountingPeriod,
    Asset,
    BatchPayment,
    BatchPaymentLine,
    ChartOfAccount,
    CurrencyExchangeRate,
    DepreciationScheduleLine,
    FinanceAuditLog,
    ManualJournalEntry,
    ManualJournalLine,
    MpesaStatementLine,
    Refund,
    ScheduledPaymentReminder,
    TaxLossCarryForward,
    VendorInvoice,
    WithholdingTaxEntry,
)


def log_finance_audit(branch, actor, entity_type, entity_id, action, notes="", before=None, after=None):
    """Record a finance action in FinanceAuditLog.

    ``entity_id`` is stored as a string. ``before`` / ``after`` are optional
    snapshots stored as JSON text (``{}`` when omitted or empty).

    ``Decimal``, ``date``/``datetime`` and ``UUID`` values inside the snapshots
    are written as strings (ISO format for dates). Without this,
    ``json.dumps`` would raise ``TypeError`` and abort the action being
    audited. Any other non-JSON type still raises.
    """
    import uuid

    def _encode(value):
        # Deliberate serialisation for the scalar types typical of finance data only.
        if isinstance(value, date):  # also covers datetime (a date subclass)
            return value.isoformat()
        if isinstance(value, (Decimal, uuid.UUID)):
            return str(value)
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    FinanceAuditLog.objects.create(
        company_branch=branch,
        actor=actor,
        entity_type=entity_type,
        entity_id=str(entity_id),
        action=action,
        notes=notes,
        before_data=json.dumps(before or {}, default=_encode),
        after_data=json.dumps(after or {}, default=_encode),
    )


def get_exchange_rate(branch, from_currency, to_currency, rate_date=None):
    """Return the ``from_currency`` -> ``to_currency`` rate as a ``Decimal``.

    Currency codes are compared lower-case and default to ``"kes"`` when empty.
    Lookup order:

    1. identical currencies -> ``Decimal("1")``;
    2. the most recent stored CurrencyExchangeRate for ``branch`` dated on or
       before ``rate_date`` (today when omitted);
    3. a live quote from ``forex_python``.

    NOTE(review): if the live lookup fails for any reason the function silently
    returns 1 (a 1:1 conversion), and the live lookup does not use
    ``rate_date`` — confirm callers can tolerate both.
    """
    source = (from_currency or "kes").lower()
    target = (to_currency or "kes").lower()
    if source == target:
        return Decimal("1")

    effective_date = rate_date or date.today()
    latest = (
        CurrencyExchangeRate.objects.filter(
            company_branch=branch,
            from_currency=source,
            to_currency=target,
            rate_date__lte=effective_date,
        )
        .order_by("-rate_date")
        .first()
    )
    if latest is not None:
        return latest.exchange_rate

    try:
        from forex_python.converter import CurrencyRates

        live_rate = CurrencyRates().get_rate(source.upper(), target.upper())
        return Decimal(str(live_rate))
    except Exception:
        return Decimal("1")


def _booking_rate(branch, from_currency, to_currency, as_of):
    """Approximate the rate at which a currency pair's balances were booked.

    Uses the *earliest* stored CurrencyExchangeRate for ``branch`` dated on or
    before ``as_of`` as a proxy booking rate. When none is stored, falls back
    to ``get_exchange_rate`` for ``as_of``. Always returns a float.
    """
    earliest = (
        CurrencyExchangeRate.objects.filter(
            company_branch=branch,
            from_currency=(from_currency or "").lower(),
            to_currency=(to_currency or "").lower(),
            rate_date__lte=as_of,
        )
        .order_by("rate_date")
        .first()
    )
    if earliest is None:
        return float(get_exchange_rate(branch, from_currency, to_currency, as_of))
    return float(earliest.exchange_rate)


def build_exchange_loss_report(branch, as_of_date=None):
    """
    Foreign-currency revaluation report for active, non-KES accounts.

    The GL stores book values in KES. For each account:

    * ``book_balance_kes`` is the KES book balance (debits - credits);
    * ``realized_gain_loss`` is the part of that balance made up of
      revaluation adjustments already posted (rel_type ``fx_revaluation``);
    * ``balance_fc`` estimates the foreign-currency balance from the book
      balance *excluding* those adjustments, at the proxy booking rate
      (``_booking_rate``);
    * ``revalued_kes`` is that FC balance at the current rate;
      ``unrealized_gain_loss`` is what still has to be posted to bring the
      book balance to it.

    Posted adjustments are excluded from the FC estimate so that repeated
    runs are stable. Once a revaluation has been posted, re-running at the
    same rate yields ~0. Otherwise the report would revalue the adjustment
    itself, and ``post_fx_revaluation`` would post the same gain again.

    Accounts whose book balance and underlying (pre-adjustment) balance are
    both below 0.01 are left out of ``exposures``. Their posted adjustments
    still count towards ``total_realized_gain_loss``.

    NOTE(review): balances include all history rows regardless of date; they
    are not restricted to transactions on or before ``as_of_date``.
    """
    as_of = as_of_date or date.today()
    if isinstance(as_of, str):
        as_of = fr.parse_finance_date(as_of)

    def _net_kes(history):
        # Debits minus credits of a history queryset, rounded to cents.
        sums = history.aggregate(dr=Sum("debit"), cr=Sum("credit"))
        return round(float(sums["dr"] or 0) - float(sums["cr"] or 0), 2)

    accounts = ChartOfAccount.objects.filter(
        company_branch=branch,
        recycle_bin=False,
        is_active=True,
    ).exclude(account_currency="kes")

    exposures = []
    total_unrealized = 0.0
    total_realized = 0.0

    for account in accounts:
        history = AccountHistory.objects.filter(account=account)
        book_kes = _net_kes(history)
        # Revaluation adjustments already posted against this account.
        realized_kes = _net_kes(history.filter(rel_type="fx_revaluation"))
        total_realized += realized_kes

        # KES value of the underlying FC transactions, before any revaluation.
        base_kes = round(book_kes - realized_kes, 2)
        if abs(book_kes) < 0.01 and abs(base_kes) < 0.01:
            continue

        current_rate = float(get_exchange_rate(branch, account.account_currency, "kes", as_of))
        booking_rate = _booking_rate(branch, account.account_currency, "kes", as_of) or current_rate
        fc_amount = base_kes / booking_rate if booking_rate else 0.0
        balance_fc = round(fc_amount, 2)
        # Revalue the unrounded FC estimate so display rounding does not leak
        # cents into the adjustment.
        revalued_kes = round(fc_amount * current_rate, 2)
        unrealized = round(revalued_kes - book_kes, 2)
        total_unrealized += unrealized

        exposures.append({
            "account_id": str(account.id),
            "account_name": account.account_name,
            "currency": account.account_currency,
            "booking_rate": round(booking_rate, 4),
            "current_rate": round(current_rate, 4),
            "balance_fc": balance_fc,
            "book_balance_kes": book_kes,
            "revalued_kes": revalued_kes,
            "unrealized_gain_loss": unrealized,
            "realized_gain_loss": realized_kes,
        })

    return {
        "as_of_date": as_of.isoformat(),
        "exposures": exposures,
        "total_book_kes": round(sum(r["book_balance_kes"] for r in exposures), 2),
        "total_revalued_kes": round(sum(r["revalued_kes"] for r in exposures), 2),
        "total_unrealized_gain_loss": round(total_unrealized, 2),
        "total_realized_gain_loss": round(total_realized, 2),
    }


def post_fx_revaluation(branch, staff, as_of_date, gain_loss_account_id):
    """Post the unrealised FX revaluation to the GL against a gain/loss account.

    Builds ``build_exchange_loss_report`` for ``as_of_date``. Then, in one
    database transaction, it creates a ManualJournalEntry (source_type
    ``fx_revaluation``) containing:

    * one line per exposure whose unrealised adjustment is at least 0.01 — a
      positive adjustment debits the account, a negative one credits it;
    * one offsetting line on ``gain_loss_account_id`` for the net. This is a
      credit for a net gain and a debit for a net loss, and it is omitted when
      the net is exactly 0.

    Every journal line is mirrored into AccountHistory with rel_type
    ``fx_revaluation``, and the run is written to the finance audit log.

    Returns ``{"posted": 0, "message": ...}`` when no account needs adjusting.
    Otherwise returns ``{"posted", "net_adjustment", "entry_number"}``.
    Raises ``ChartOfAccount.DoesNotExist`` for an unknown gain/loss account.
    """
    from django.db import transaction

    report = build_exchange_loss_report(branch, as_of_date)
    as_of = fr.parse_finance_date(as_of_date) if isinstance(as_of_date, str) else (as_of_date or date.today())
    # Looked up first so an invalid account id fails even when nothing is posted.
    gl_account = ChartOfAccount.objects.get(id=gain_loss_account_id)

    # Decide what to post before writing anything. This avoids creating a
    # journal header only to delete it when there is nothing to post.
    adjustments = []
    for row in report["exposures"]:
        adj = Decimal(str(row["unrealized_gain_loss"]))
        if abs(adj) >= Decimal("0.01"):
            adjustments.append((row["account_id"], adj))
    if not adjustments:
        return {"posted": 0, "message": "No FX adjustment required"}

    total = sum((adj for _, adj in adjustments), Decimal("0"))

    # All-or-nothing: a failure part-way must not leave a half-posted journal.
    with transaction.atomic():
        entry = ManualJournalEntry.objects.create(
            company_branch=branch,
            entry_date=as_of,
            description=f"FX revaluation — {as_of.isoformat()}",
            reference=f"FX-REVAL-{as_of.strftime('%Y%m%d')}",
            source_type="fx_revaluation",
            created_by=staff, last_updated_by=staff,
        )

        def _post(account, debit, credit, description):
            # Journal line plus its mirror row in the account history.
            ManualJournalLine.objects.create(
                journal_entry=entry, account=account, debit=debit, credit=credit, description=description)
            AccountHistory.objects.create(
                company_branch=branch, account=account, debit=debit, credit=credit,
                transaction_date=as_of, rel_type="fx_revaluation", rel_id=str(entry.id),
                description=description, created_by=staff)

        for account_id, adj in adjustments:
            account = ChartOfAccount.objects.get(id=account_id)
            if adj > 0:
                _post(account, adj, Decimal("0"), "FX revaluation gain")
            else:
                _post(account, Decimal("0"), -adj, "FX revaluation loss")

        # Offset the net to the gain/loss account. When gains and losses cancel
        # exactly, the journal is already balanced and no zero-value line is added.
        if total > 0:
            _post(gl_account, Decimal("0"), total, "Unrealised FX gain")
        elif total < 0:
            _post(gl_account, -total, Decimal("0"), "Unrealised FX loss")

        log_finance_audit(branch, staff, "fx_revaluation", entry.id, "posted",
                          f"Revalued {len(adjustments)} FX accounts, net {float(total)}")

    return {"posted": len(adjustments), "net_adjustment": float(total), "entry_number": entry.entry_number}


def build_pl_month_on_month(branch, date_from, date_to):
    """Profit & loss per calendar month between ``date_from`` and ``date_to``.

    Returns ``{"periods": {"YYYY-MM": {...}}}``. For each month it gives the
    window actually covered, plus net income, total income and total
    expenses from ``fr.build_profit_and_loss`` over ``fr.branch_ledger`` for
    that window.

    The first and last months are clipped to the requested range, so a
    mid-month ``date_from`` does not pull in the earlier days of its month.
    An inverted range (``date_from`` after ``date_to``) yields no periods.
    """
    start = fr.parse_finance_date(date_from)
    end = fr.parse_finance_date(date_to)
    periods = {}
    if start > end:
        return {"periods": periods}

    cursor = date(start.year, start.month, 1)
    while cursor <= end:
        if cursor.month == 12:
            next_month = date(cursor.year + 1, 1, 1)
        else:
            next_month = date(cursor.year, cursor.month + 1, 1)
        # Clip the calendar month to the requested [start, end] range on both sides.
        month_start = max(cursor, start)
        month_end = min(next_month - timedelta(days=1), end)

        qs = fr.branch_ledger(branch, month_start, month_end)
        pl = fr.build_profit_and_loss(qs)
        periods[cursor.strftime("%Y-%m")] = {
            "date_from": month_start.isoformat(),
            "date_to": month_end.isoformat(),
            "net_income": pl.get("net_income", 0),
            "total_income": pl.get("total_income", 0),
            "total_expenses": pl.get("total_expenses", 0),
        }
        cursor = next_month
    return {"periods": periods}


def run_year_end_close(branch, staff, period_id, retained_earnings_account_id):
    """
    Close a period's P&L into Retained Earnings and lock the period.

    Posts one balanced journal (source_type ``year_end_close``). It has a
    reversing line for each income, COGS and expense account reported by
    ``fr.build_profit_and_loss`` for the period: income accounts are debited
    by their net, and COGS/expense accounts are credited. One balancing line
    goes to the retained-earnings account (a credit for a profit, a debit for
    a loss). Each line is mirrored into AccountHistory (rel_type
    ``year_end_close``, rel_id = period id), and the period is then locked.

    Everything runs in one transaction with the period row locked via
    ``select_for_update``. A failure part-way leaves nothing posted, so the
    "already posted" guard cannot block a retry after a half-written close.
    Two concurrent closes also cannot both post.

    Returns ``{"error": ...}`` when the period is locked or already closed,
    or when the P&L reports a non-zero net with no account lines to close.
    Otherwise returns a summary with ``net_income``, ``period_locked``,
    ``accounts_closed`` and, when a journal was posted, ``entry_number``.
    """
    from django.db import transaction

    with transaction.atomic():
        period = AccountingPeriod.objects.select_for_update().get(id=period_id, company_branch=branch)
        if period.period_locked:
            return {"error": "Period already locked"}
        if AccountHistory.objects.filter(
            company_branch=branch, rel_type="year_end_close", rel_id=str(period.id),
        ).exists():
            return {"error": "Year-end close already posted for this period"}

        qs = fr.branch_ledger(branch, period.period_start_date, period.period_end_date)
        pl = fr.build_profit_and_loss(qs)
        net_income = float(pl.get("net_income", 0) or 0)

        # (account_id, debit, credit) needed to bring each P&L account to zero.
        # Income accounts (normal credit) are DEBITED by their net; COGS and
        # expense accounts are CREDITED.
        closing_lines = []
        for row in pl.get("income_lines", []):
            amt = round(float(row["amount"]), 2)
            if amt > 0:
                closing_lines.append((row["account_id"], amt, 0.0))
            elif amt < 0:
                closing_lines.append((row["account_id"], 0.0, -amt))
        for row in pl.get("cogs_lines", []) + pl.get("expense_lines", []):
            amt = round(float(row["amount"]), 2)
            if amt > 0:
                closing_lines.append((row["account_id"], 0.0, amt))
            elif amt < 0:
                closing_lines.append((row["account_id"], -amt, 0.0))

        if not closing_lines:
            if abs(net_income) >= 0.01:
                # A retained-earnings line with nothing to offset it would be a
                # one-sided (unbalanced) journal; refuse rather than post it.
                return {"error": "P&L reports a net result but has no account lines to close"}
            period.period_locked = True
            period.save()
            log_finance_audit(branch, staff, "accounting_period", period.id, "closed", "Zero net income close")
            return {"net_income": 0, "period_locked": True, "accounts_closed": 0}

        re_account = ChartOfAccount.objects.get(id=retained_earnings_account_id)
        entry = ManualJournalEntry.objects.create(
            company_branch=branch,
            entry_date=period.period_end_date,
            description=f"Year-end close — {period.period_name}",
            reference=f"YE-CLOSE-{period.id}",
            source_type="year_end_close",
            source_id=str(period.id),
            created_by=staff,
            last_updated_by=staff,
        )

        account_cache = {}

        def _account(account_id):
            # One ChartOfAccount fetch per distinct account id.
            if account_id not in account_cache:
                account_cache[account_id] = ChartOfAccount.objects.get(id=account_id)
            return account_cache[account_id]

        def _post(account, debit, credit, description):
            # Journal line plus its mirror row in the account history.
            ManualJournalLine.objects.create(
                journal_entry=entry, account=account,
                debit=debit, credit=credit, description=description)
            AccountHistory.objects.create(
                company_branch=branch, account=account,
                debit=debit, credit=credit, transaction_date=period.period_end_date,
                rel_type="year_end_close", rel_id=str(period.id),
                description=description, created_by=staff)

        for account_id, debit, credit in closing_lines:
            _post(_account(account_id), round(debit, 2), round(credit, 2),
                  f"Close to retained earnings — {period.period_name}")

        # Carry the net of the closing lines themselves to retained earnings.
        # This keeps the journal balanced to the cent even when per-account
        # rounding differs slightly from the rounded P&L net income.
        to_retained = round(sum(dr - cr for _, dr, cr in closing_lines), 2)
        if to_retained > 0:
            _post(re_account, 0.0, to_retained, "Net income transferred to retained earnings")
        elif to_retained < 0:
            _post(re_account, -to_retained, 0.0, "Net loss transferred to retained earnings")

        period.period_locked = True
        period.save()
        log_finance_audit(branch, staff, "accounting_period", period.id, "closed",
                          f"Net income {net_income} closed across {len(closing_lines)} P&L accounts")
        return {
            "net_income": net_income,
            "period_locked": True,
            "entry_number": entry.entry_number,
            "accounts_closed": len(closing_lines),
        }


def match_vendor_invoice_3way(vendor_invoice):
    """Three-way match a vendor invoice against its PO and the goods received.

    The PO amount is the invoice's stored ``po_amount`` if set, otherwise the
    PO's ``purchase_value_overall``. The received value is the sum of
    ``quantity_delivered * purchase_value_per_unit`` over the PO's
    non-recycled product instances. Comparisons use a 0.01 tolerance:

    * ``matched``   – invoice, PO and received amounts all agree;
    * ``partial``   – invoice equals the received value but the PO differs;
    * ``exception`` – anything else, or no purchase order linked.

    Stores the PO and received amounts as strings, sets ``match_status`` and
    ``match_notes``, saves, and returns the invoice. Amounts held as text with
    thousands separators (e.g. ``"1,250.00"``) are accepted.
    """

    def _amount(value):
        # Text amounts may carry thousands separators; blank/None counts as 0.
        return float(str(value or 0).replace(",", "") or 0)

    po = vendor_invoice.purchase_order
    if not po:
        vendor_invoice.match_status = "exception"
        vendor_invoice.match_notes = "No purchase order linked"
        vendor_invoice.save()
        return vendor_invoice

    po_amount = _amount(vendor_invoice.po_amount or po.purchase_value_overall)
    received = sum(
        _amount(inst.quantity_delivered) * _amount(inst.purchase_value_per_unit)
        for inst in po.purchase_order_product_instances.filter(recycle_bin=False)
    )
    received = float(received)

    invoice_amt = _amount(vendor_invoice.invoice_amount)
    vendor_invoice.po_amount = str(round(po_amount, 2))
    vendor_invoice.received_amount = str(round(received, 2))

    if abs(invoice_amt - po_amount) < 0.01 and abs(received - po_amount) < 0.01:
        vendor_invoice.match_status = "matched"
        vendor_invoice.match_notes = "PO, receipt, and invoice amounts align"
    elif abs(invoice_amt - received) < 0.01:
        vendor_invoice.match_status = "partial"
        vendor_invoice.match_notes = "Invoice matches receipt; PO variance"
    else:
        vendor_invoice.match_status = "exception"
        vendor_invoice.match_notes = f"Variance: PO={po_amount}, Received={received}, Invoice={invoice_amt}"

    vendor_invoice.save()
    return vendor_invoice


def build_vat_registers(branch, date_from, date_to):
    """Build the output (sales) and input (purchases) VAT registers for a window.

    Output rows: one per item of each approved, non-recycled CustomerOrder of
    ``branch`` created in [date_from, date_to]. VAT is ``net * rate / 100``
    for ``standard``-category products and 0 for any other category.
    Products without a ProductVAT row count as standard at 0 %.

    Input rows: one per non-recycled ProductPurchaseInstance of the branch's
    company created in the window. Each uses the stored ``input_vat_amount``
    and a nominal 16 % rate.

    Returns both registers plus rounded VAT and net totals. Amounts held as
    text with thousands separators (e.g. ``"1,250.00"``) are accepted.
    """
    from sales_and_marketing.models import CustomerOrder, ProductVAT
    from procurement.models import ProductPurchaseInstance

    start = fr.parse_finance_date(date_from)
    end = fr.parse_finance_date(date_to)

    def _amount(value):
        # Text amounts may carry thousands separators; blank/None counts as 0.
        return float(str(value or 0).replace(",", "") or 0)

    vat_by_product = {}

    def _vat_for(product):
        # (rate %, category) for a product, fetched once per distinct product
        # instead of once per order item.
        if product not in vat_by_product:
            try:
                pv = ProductVAT.objects.get(product=product)
            except ProductVAT.DoesNotExist:
                vat_by_product[product] = (0.0, "standard")
            else:
                vat_by_product[product] = (_amount(pv.vat_percentage_value), pv.vat_category)
        return vat_by_product[product]

    output_rows = []
    orders = CustomerOrder.objects.filter(
        company_branch=branch,
        customer_order_approved=True,
        recycle_bin=False,
        created_on__date__gte=start,
        created_on__date__lte=end,
    ).prefetch_related("customer_order_items")

    for order in orders:
        customer = getattr(order, "customer_profile", None)
        for item in order.customer_order_items.all():
            vat_pct, category = _vat_for(item.product)
            net = _amount(item.net_subtotal)
            vat_amt = round(net * vat_pct / 100, 2) if category == "standard" else 0
            output_rows.append({
                "type": "output",
                "date": order.created_on.date().isoformat(),
                "reference": order.customer_order_number,
                "customer_name": str(customer) if customer else "",
                "customer_pin": getattr(customer, "kra_pin", "") or getattr(customer, "tax_pin", "") or "",
                "description": str(item.product),
                "net_amount": net,
                "vat_category": category,
                "vat_rate": vat_pct,
                "vat_amount": vat_amt,
            })

    input_rows = []
    purchases = ProductPurchaseInstance.objects.filter(
        purchase_order__company_profile=branch.company_profile,
        recycle_bin=False,
        created_on__date__gte=start,
        created_on__date__lte=end,
    ).select_related("supplier")

    for line in purchases:
        input_rows.append({
            "type": "input",
            "date": line.created_on.date().isoformat(),
            "reference": str(line.purchase_order),
            "supplier_name": str(line.supplier) if line.supplier else "",
            "supplier_pin": "",
            "description": str(line.purchase_product),
            "net_amount": _amount(line.purchase_value_overall),
            "vat_category": "standard",
            "vat_rate": 16,
            "vat_amount": _amount(line.input_vat_amount),
        })

    return {
        "output_register": output_rows,
        "input_register": input_rows,
        "output_vat_total": round(sum(r["vat_amount"] for r in output_rows), 2),
        "input_vat_total": round(sum(r["vat_amount"] for r in input_rows), 2),
        "output_net_total": round(sum(r["net_amount"] for r in output_rows), 2),
        "input_net_total": round(sum(r["net_amount"] for r in input_rows), 2),
    }


def export_kra_vat_csv(registers):
    """Render VAT registers as a KRA iTax VAT3-style CSV string.

    Layout:

    * a title row;
    * Section B (sales / output tax) and Section F (purchases / input tax).
      Each is preceded by a blank row and contains a heading, a column
      header, one row per register entry and a totals row;
    * a blank row and the net VAT line (output VAT total minus input VAT
      total, rounded to 2 dp).

    Missing values are written as blanks, except an input row without
    ``vat_rate``, which is written as ``16``.
    """
    shared_columns = [
        "Invoice Date", "Invoice No", "Description",
        "Taxable Value (Excl VAT)", "Rate (%)", "Amount of VAT", "VAT Category",
    ]
    # (heading, PIN column, name column, row PIN key, row name key,
    #  register key, default rate, totals label, net-total key, VAT-total key)
    sections = (
        ("SECTION B: SALES / OUTPUT TAX", "PIN of Customer", "Customer Name",
         "customer_pin", "customer_name", "output_register", "",
         "TOTAL OUTPUT VAT", "output_net_total", "output_vat_total"),
        ("SECTION F: PURCHASES / INPUT TAX", "PIN of Supplier", "Supplier Name",
         "supplier_pin", "supplier_name", "input_register", "16",
         "TOTAL INPUT VAT", "input_net_total", "input_vat_total"),
    )

    out = io.StringIO()
    w = csv.writer(out)
    w.writerow(["KRA VAT RETURN (VAT3) — GENERATED EXPORT"])

    for (heading, pin_col, name_col, pin_key, name_key, register_key,
         default_rate, total_label, net_key, vat_key) in sections:
        w.writerow([])
        w.writerow([heading])
        w.writerow([pin_col, name_col, *shared_columns])
        for entry in registers.get(register_key, []):
            w.writerow([
                entry.get(pin_key, ""),
                entry.get(name_key, ""),
                entry.get("date"),
                entry.get("reference"),
                entry.get("description"),
                entry.get("net_amount"),
                entry.get("vat_rate", default_rate),
                entry.get("vat_amount"),
                entry.get("vat_category"),
            ])
        w.writerow(["", "", "", "", total_label, registers.get(net_key, ""), "", registers.get(vat_key, "")])

    output_vat = float(registers.get("output_vat_total") or 0)
    input_vat = float(registers.get("input_vat_total") or 0)
    w.writerow([])
    w.writerow(["NET VAT PAYABLE / (CREDIT)", round(output_vat - input_vat, 2)])
    return out.getvalue()


# ---------------------------------------------------------------------------
# WITHHOLDING TAX (WHVAT 2% & WHT) — KRA filing register
# ---------------------------------------------------------------------------

def generate_withholding_entries(branch, date_from, date_to, default_rate=2.0):
    """Create WHVAT (``vat_withholding``) entries for purchases in a window.

    Purchase lines (non-recycled ProductPurchaseInstance rows of the branch's
    company created in [date_from, date_to]) are grouped by purchase order.
    Each PO without an existing ``vat_withholding`` entry for ``branch`` gets
    one entry. Its gross amount is the sum of the PO's positive line values in
    the window, and its WHT is ``gross * default_rate / 100`` rounded to 0.01.
    Supplier and period date are taken from the PO's first line.

    NOTE(review): no approval filter is applied to the purchases — confirm
    whether only approved purchase orders should be included.

    Returns the number of entries created.
    """
    from procurement.models import ProductPurchaseInstance

    start = fr.parse_finance_date(date_from)
    end = fr.parse_finance_date(date_to)
    rate = Decimal(str(default_rate))

    purchases = ProductPurchaseInstance.objects.filter(
        purchase_order__company_profile=branch.company_profile,
        recycle_bin=False,
        created_on__date__gte=start,
        created_on__date__lte=end,
    ).select_related("purchase_order", "supplier")

    # Accumulate per PO first. Creating the entry on a PO's first line and then
    # skipping its other lines (because an entry now exists) would base the
    # withholding on only one line of a multi-line order.
    by_po = {}  # po.id -> {"po": PO, "line": first line, "gross": Decimal}
    for line in purchases:
        po = line.purchase_order
        if not po:
            continue
        gross = Decimal(str(line.purchase_value_overall or "0").replace(",", "") or "0")
        if gross <= 0:
            continue
        bucket = by_po.get(po.id)
        if bucket is None:
            by_po[po.id] = {"po": po, "line": line, "gross": gross}
        else:
            bucket["gross"] += gross

    created = 0
    for bucket in by_po.values():
        po = bucket["po"]
        if WithholdingTaxEntry.objects.filter(
            company_branch=branch, purchase_order=po, tax_type="vat_withholding",
        ).exists():
            continue
        first_line = bucket["line"]
        gross = bucket["gross"]
        wht = (gross * rate / Decimal("100")).quantize(Decimal("0.01"))
        WithholdingTaxEntry.objects.create(
            company_branch=branch,
            supplier=first_line.supplier,
            purchase_order=po,
            tax_type="vat_withholding",
            supplier_name=str(first_line.supplier) if first_line.supplier else "",
            invoice_number=po.purchase_order_number,
            gross_amount=gross,
            wht_rate=rate,
            wht_amount=wht,
            period_date=first_line.created_on.date(),
        )
        created += 1
    return created


def build_withholding_register(branch, date_from=None, date_to=None, tax_type=None):
    """Return the branch's non-recycled withholding-tax entries as a register.

    Optional filters are ``date_from`` / ``date_to`` (inclusive bounds on
    ``period_date``, parsed with ``fr.parse_finance_date``) and ``tax_type``.
    Each row is a JSON-friendly dict, with amounts as floats and the period
    date as an ISO string. ``total_gross`` and ``total_wht`` are summed as
    Decimals and returned as floats.
    """
    criteria = {"company_branch": branch, "recycle_bin": False}
    if date_from:
        criteria["period_date__gte"] = fr.parse_finance_date(date_from)
    if date_to:
        criteria["period_date__lte"] = fr.parse_finance_date(date_to)
    if tax_type:
        criteria["tax_type"] = tax_type

    rows = []
    gross_sum = Decimal("0")
    wht_sum = Decimal("0")
    for entry in WithholdingTaxEntry.objects.filter(**criteria).select_related("supplier"):
        if entry.supplier_name:
            supplier_name = entry.supplier_name
        elif entry.supplier:
            supplier_name = str(entry.supplier)
        else:
            supplier_name = ""
        rows.append({
            "id": str(entry.id),
            "tax_type": entry.tax_type,
            "supplier_name": supplier_name,
            "supplier_pin": entry.supplier_pin,
            "invoice_number": entry.invoice_number,
            "gross_amount": float(entry.gross_amount),
            "wht_rate": float(entry.wht_rate),
            "wht_amount": float(entry.wht_amount),
            "certificate_number": entry.certificate_number,
            "period_date": entry.period_date.isoformat(),
            "status": entry.status,
        })
        gross_sum += entry.gross_amount
        wht_sum += entry.wht_amount

    return {"rows": rows, "total_gross": float(gross_sum), "total_wht": float(wht_sum)}


def export_kra_whvat_csv(register):
    """Render a withholding register as a KRA-style CSV string.

    Writes, in order:

    * a title row, then a blank row;
    * the column header;
    * one row per entry in ``register["rows"]`` (missing fields are blank);
    * a blank row, then the gross and WHT total rows.
    """
    header = [
        "PIN of Supplier", "Supplier Name", "Invoice No", "Tax Type",
        "Gross Amount", "WHT Rate (%)", "WHT Amount", "Certificate No", "Period", "Status",
    ]
    # Row fields after the supplier PIN, in column order.
    value_keys = (
        "supplier_name", "invoice_number", "tax_type", "gross_amount", "wht_rate",
        "wht_amount", "certificate_number", "period_date", "status",
    )

    out = io.StringIO()
    w = csv.writer(out)
    w.writerows([
        ["KRA WITHHOLDING VAT/TAX RETURN — GENERATED EXPORT"],
        [],
        header,
    ])
    w.writerows(
        [entry.get("supplier_pin", ""), *(entry.get(key) for key in value_keys)]
        for entry in register.get("rows", [])
    )
    w.writerows([
        [],
        ["TOTAL GROSS", register.get("total_gross", "")],
        ["TOTAL WHT", register.get("total_wht", "")],
    ])
    return out.getvalue()


def import_mpesa_csv(branch, account_id, csv_text, staff):
    """Import an M-Pesa statement CSV into MpesaStatementLine and auto-match it.

    Accepted headers (the first non-empty value wins):

    * reference: ``Reference`` / ``reference_code`` / ``Receipt No``;
    * payer: ``Name`` / ``payer_name``;
    * amount: ``Amount`` / ``amount`` (thousands separators allowed);
    * date: ``Date`` / ``transaction_date`` (today when missing).

    Each row becomes a statement line on the branch's ``account_id``. A row
    with a reference is then matched to the first unreconciled AccountHistory
    row of that account whose description contains the reference's first 20
    characters. Failing that, it is matched to the first one whose debit
    equals the amount. Rows without a reference are not matched. Matched
    history rows are marked reconciled and linked to the statement line.

    Runs in one transaction, so a bad row rolls back the whole import instead
    of leaving it half-applied. Returns imported/matched/unmatched counts.
    """
    from django.db import transaction

    account = ChartOfAccount.objects.get(id=account_id, company_branch=branch)
    # Spreadsheet exports often start with a UTF-8 BOM. Left in place, it would
    # be glued onto the first header name and hide e.g. the "Reference" column.
    reader = csv.DictReader(io.StringIO((csv_text or "").lstrip("\ufeff")))
    created = 0
    matched = 0

    with transaction.atomic():
        for row in reader:
            # Strip once, and use the same value for storage and matching: an
            # unstripped whitespace-only reference is truthy and would match
            # almost any description via icontains.
            ref = (row.get("Reference") or row.get("reference_code") or row.get("Receipt No") or "").strip()
            name = (row.get("Name") or row.get("payer_name") or "").strip()
            raw_amount = row.get("Amount") or row.get("amount") or 0
            amount = float(str(raw_amount).replace(",", "") or 0)
            date_str = row.get("Date") or row.get("transaction_date") or date.today().isoformat()
            tx_date = fr.parse_finance_date(date_str)

            line = MpesaStatementLine.objects.create(
                company_branch=branch,
                mpesa_account=account,
                transaction_date=tx_date,
                reference_code=ref,
                payer_name=name,
                amount=amount,
            )
            created += 1

            history = None
            if ref:
                history = AccountHistory.objects.filter(
                    account=account,
                    reconciled=False,
                    description__icontains=ref[:20],
                ).first()
                if not history:
                    history = AccountHistory.objects.filter(
                        account=account,
                        debit=amount,
                        reconciled=False,
                    ).first()
            if history:
                history.reconciled = True
                history.save()
                line.matched = True
                line.matched_history = history
                line.save()
                matched += 1

        log_finance_audit(branch, staff, "mpesa_import", account_id, "imported",
                          f"{created} lines, {matched} matched")
    return {"imported": created, "matched": matched, "unmatched": created - matched}


def run_depreciation_for_period(branch, staff, period_month, expense_account_id, accum_account_id):
    """Post one month's depreciation for every active asset.

    For each active, non-recycled Asset, the monthly charge is:

    * ``declining_balance`` – ``depreciation_rate`` % of the current net book
      value, divided by 12. When the rate is not positive it defaults to
      ``100 / useful_life_years``;
    * otherwise straight-line – ``(cost - residual) / (life_years * 12)``.

    Net book value is cost minus depreciation recorded for periods before
    ``period_month``. The charge is capped so net book value never falls
    below the residual, and it is quantised to 0.01.

    The charge is recorded on the asset's DepreciationScheduleLine for
    ``period_month`` and posted (Dr expense / Cr accumulated depreciation) on
    one consolidated ManualJournalEntry, mirrored in AccountHistory. Lines
    already marked posted are skipped, so re-running the same
    ``period_month`` value does not double-post. Everything runs in one
    transaction, so a failure cannot leave a partly posted journal.

    NOTE(review): assets are not filtered by ``branch`` — confirm whether
    Asset is branch-scoped; if so, other branches' assets are posted here.
    NOTE(review): schedule lines are keyed on the exact ``period_month``
    date, so two different days in the same month create two lines.

    Returns the number of assets posted, the total charge and the entry number.
    """
    from django.db import transaction

    period_date = fr.parse_finance_date(period_month)
    expense_acct = ChartOfAccount.objects.get(id=expense_account_id)
    accum_acct = ChartOfAccount.objects.get(id=accum_account_id)
    posted = 0
    total = Decimal("0")
    entry = None

    with transaction.atomic():
        for asset in Asset.objects.filter(recycle_bin=False, status="active"):
            cost = Decimal(str(asset.purchase_cost or "0").replace(",", ""))
            residual = Decimal(str(asset.residual_value or "0").replace(",", ""))
            life_years = Decimal(str(asset.useful_life_years or "1").replace(",", ""))

            # Only depreciation from *earlier* periods reduces the base. This
            # period's own line (e.g. one left unposted and about to be
            # recalculated), or any later one, must not shrink the amount it
            # is computed from.
            accumulated = DepreciationScheduleLine.objects.filter(
                asset=asset,
                period_month__lt=period_date,
            ).aggregate(total=Sum("depreciation_amount"))["total"] or Decimal("0")
            net_book_value = cost - accumulated

            if asset.depreciation_method == "declining_balance":
                # Reducing balance: annual rate % of the current net book value, per month.
                rate = Decimal(str(asset.depreciation_rate or "0").replace(",", ""))
                if rate <= 0 and life_years > 0:
                    rate = Decimal("100") / life_years
                monthly = (net_book_value * rate / Decimal("100")) / Decimal("12")
            else:
                if life_years <= 0:
                    continue
                monthly = (cost - residual) / (life_years * 12)

            # Never depreciate below the residual value.
            depreciable_remaining = net_book_value - residual
            if depreciable_remaining <= 0:
                continue
            if monthly > depreciable_remaining:
                monthly = depreciable_remaining
            monthly = monthly.quantize(Decimal("0.01"))
            if monthly <= 0:
                continue

            line, created = DepreciationScheduleLine.objects.get_or_create(
                asset=asset,
                period_month=period_date,
                defaults={"depreciation_amount": monthly},
            )
            if line.posted:
                continue
            if not created:
                line.depreciation_amount = monthly

            # One consolidated depreciation journal for the whole period run.
            if entry is None:
                entry = ManualJournalEntry.objects.create(
                    company_branch=branch,
                    entry_date=period_date,
                    description=f"Depreciation — {period_date.strftime('%b %Y')}",
                    reference=f"DEP-{period_date.strftime('%Y%m')}",
                    source_type="depreciation",
                    created_by=staff,
                    last_updated_by=staff,
                )

            for acct, dr, cr in [(expense_acct, monthly, 0), (accum_acct, 0, monthly)]:
                ManualJournalLine.objects.create(
                    journal_entry=entry, account=acct,
                    debit=dr, credit=cr, description=f"Depreciation {asset.asset_name}",
                )
                AccountHistory.objects.create(
                    company_branch=branch, account=acct,
                    debit=dr, credit=cr, transaction_date=period_date,
                    rel_type="depreciation", rel_id=str(asset.id),
                    description=f"Depreciation {asset.asset_name}",
                    created_by=staff,
                )
            line.posted = True
            line.journal_entry = entry
            line.save()
            posted += 1
            total += monthly

    return {
        "assets_posted": posted,
        "total_depreciation": float(total),
        "entry_number": entry.entry_number if entry else "",
    }


def update_tax_loss_carry_forward(branch, tax_year, net_income):
    """Apply ``net_income`` to the branch's TaxLossCarryForward row for ``tax_year``.

    The row is created with all amounts at 0 on first use.

    * A loss (negative ``net_income``) increases both ``assessed_loss`` and
      ``balance``.
    * A profit moves up to the row's outstanding ``balance`` into
      ``utilized``.
    * Zero changes nothing.

    The row is always saved and returned.

    NOTE(review): a profit only consumes the balance of the *same*
    ``tax_year`` row (earlier years' losses are not carried in), and repeated
    calls for one year accumulate — confirm this is the intended contract.
    """
    record, _created = TaxLossCarryForward.objects.get_or_create(
        company_branch=branch,
        tax_year=tax_year,
        defaults={"assessed_loss": 0, "utilized": 0, "balance": 0},
    )
    is_loss = net_income < 0
    if is_loss:
        loss = Decimal(str(abs(net_income)))
        record.assessed_loss += loss
        record.balance += loss
    elif net_income > 0 and record.balance > 0:
        applied = min(Decimal(str(net_income)), record.balance)
        record.utilized += applied
        record.balance -= applied
    record.save()
    return record


def generate_due_reminders(branch):
    """Create overdue-payment reminders for customer orders and purchase orders.

    Customer side: this covers each non-recycled customer of the branch's
    company. Every approved order with an unpaid balance whose due date
    (order date + the customer's payment terms, default 30 days) is before
    today gets a ``credit_follow_up`` reminder for the unpaid balance.

    Supplier side: every approved PO of the company more than 30 days old
    that still has an outstanding balance gets an ``ap_due`` reminder for
    that balance. The outstanding balance is the PO value minus credits
    recorded in AccountHistory against the PO (rel_type ``purchase_order``).
    This is the same basis ``build_batch_payment_candidates`` uses, so settled
    POs do not keep producing reminders.

    Either kind is skipped while an unsent reminder already exists for the
    same order. Amounts held as text with thousands separators are accepted.
    Returns ``{"reminders_created": n}``.
    """
    from sales_and_marketing.models import CustomerOrder, CustomerProfile
    from procurement.models import PurchaseOrder

    def _amount(value):
        # Text amounts may carry thousands separators; blank/None counts as 0.
        return float(str(value or 0).replace(",", "") or 0)

    created = 0
    today = date.today()

    customers = CustomerProfile.objects.filter(company_profile=branch.company_profile, recycle_bin=False)
    for customer in customers:
        terms = int(customer.customer_payment_terms_days or 30)
        orders = CustomerOrder.objects.filter(
            customer_profile=customer,
            customer_order_approved=True,
            recycle_bin=False,
        )
        for order in orders:
            net = _amount(order.customer_order_total_net_value)
            paid = _amount(order.customer_order_total_amount_paid)
            if net - paid <= 0:
                continue
            due = order.created_on.date() + timedelta(days=terms)
            if due >= today:
                continue
            exists = ScheduledPaymentReminder.objects.filter(
                company_branch=branch,
                entity_type="customer_order",
                entity_id=str(order.id),
                sent=False,
            ).exists()
            if exists:
                continue
            ScheduledPaymentReminder.objects.create(
                company_branch=branch,
                reminder_type="credit_follow_up",
                entity_type="customer_order",
                entity_id=str(order.id),
                recipient_email=customer.email_address,
                due_date=due,
                amount_due=str(round(net - paid, 2)),
            )
            created += 1

    # Credits already recorded against each PO, keyed by the PO id as a string.
    paid_by_po = {}
    for h in AccountHistory.objects.filter(
        company_branch=branch, rel_type="purchase_order",
    ).values("rel_id", "credit"):
        key = str(h["rel_id"])
        paid_by_po[key] = paid_by_po.get(key, 0.0) + float(h["credit"] or 0)

    pos = PurchaseOrder.objects.filter(
        company_profile=branch.company_profile,
        purchase_order_approved=True,
        recycle_bin=False,
    )
    for po in pos:
        total = _amount(po.purchase_value_overall)
        outstanding = round(total - paid_by_po.get(str(po.id), 0.0), 2)
        if outstanding <= 0:
            continue
        due = po.created_on.date() + timedelta(days=30)
        if due >= today:
            continue
        exists = ScheduledPaymentReminder.objects.filter(
            company_branch=branch,
            entity_type="purchase_order",
            entity_id=str(po.id),
            sent=False,
        ).exists()
        if exists:
            continue
        ScheduledPaymentReminder.objects.create(
            company_branch=branch,
            reminder_type="ap_due",
            entity_type="purchase_order",
            entity_id=str(po.id),
            due_date=due,
            amount_due=str(outstanding),
        )
        created += 1

    return {"reminders_created": created}


def reconcile_mpesa_report(branch):
    """
    Summarise the branch's M-Pesa statement lines for reconciliation.

    Returns a dict with:
      * ``lines`` -- up to 500 of the most recent lines (ordered by
        ``transaction_date`` descending), each flattened to plain JSON-able
        values;
      * ``unmatched_count`` -- how many of the branch's lines (not just the
        500 shown) still have ``matched=False``.
    """
    statement = MpesaStatementLine.objects.filter(company_branch=branch)

    def _as_row(line):
        # Flatten one statement line into primitives for the API response.
        return {
            "id": str(line.id),
            "date": line.transaction_date.isoformat(),
            "reference_code": line.reference_code,
            "payer_name": line.payer_name,
            "amount": float(line.amount),
            "matched": line.matched,
        }

    newest_first = statement.order_by("-transaction_date")[:500]
    rows = [_as_row(line) for line in newest_first]
    unmatched = statement.filter(matched=False).count()
    return {"lines": rows, "unmatched_count": unmatched}


# ---------------------------------------------------------------------------
# SCHEDULED / BATCH SUPPLIER PAYMENTS
# ---------------------------------------------------------------------------

def build_batch_payment_candidates(branch):
    """
    List approved, non-deleted purchase orders that still carry a balance.

    The amount already paid on a PO is the sum of ``credit`` across this
    branch's AccountHistory rows with ``rel_type="purchase_order"`` whose
    ``rel_id`` matches the PO id. POs whose outstanding amount (rounded to two
    decimals) is zero or negative are left out. Results follow ``created_on``
    order. The supplier name is taken from the PO's first non-deleted product
    instance, or ``""`` when there is none.
    """
    from procurement.models import PurchaseOrder

    # Running total of credits already booked against each PO id.
    paid_by_po = {}
    history_rows = AccountHistory.objects.filter(
        company_branch=branch, rel_type="purchase_order",
    ).values("rel_id", "credit")
    for row in history_rows:
        key = str(row["rel_id"])
        paid_by_po[key] = paid_by_po.get(key, 0.0) + float(row["credit"] or 0)

    def _supplier_label(order):
        # Supplier shown for the PO: taken from its first live product instance.
        instance = order.purchase_order_product_instances.filter(recycle_bin=False).first()
        supplier = instance.supplier if instance else None
        return str(supplier) if supplier else ""

    approved_orders = PurchaseOrder.objects.filter(
        company_profile=branch.company_profile,
        purchase_order_approved=True,
        recycle_bin=False,
    ).order_by("created_on")

    result = []
    for order in approved_orders:
        order_total = float(order.purchase_value_overall or 0)
        balance = round(order_total - paid_by_po.get(str(order.id), 0.0), 2)
        if balance > 0:
            result.append({
                "purchase_order_id": order.id,
                "purchase_order_number": order.purchase_order_number,
                "supplier_name": _supplier_label(order),
                "outstanding": balance,
            })
    return result


def execute_batch_payment(branch, staff, payment_date, source_account_id, lines, payable_account_id=None, notes=""):
    """
    Create and post a batch supplier payment as one balanced journal entry.

    For each payable line, the gross ``amount`` is debited to the payable
    account. The bank (``source_account``) is credited with the total net paid
    out (gross less withholding). The total withholding is credited to the
    account found by ``_account_by_number(branch, "2100")``. Total debits
    therefore equal total credits for the batch. All database writes happen
    inside a single transaction, so a failure leaves no partial batch or
    journal behind.

    Args:
        branch: company branch the batch belongs to. The source/payable
            accounts given by id must belong to it. Purchase orders given by id
            must belong to ``branch.company_profile``; otherwise the line is
            treated as a non-PO line.
        staff: user recorded as creator/executor of the batch and journal.
        payment_date: a ``date``, or a string parsed by
            ``fr.parse_finance_date``. Falsy non-string values default to
            today.
        source_account_id: id of the ChartOfAccount the money is paid from.
        lines: iterable of dicts with ``amount`` and optional ``wht_amount``,
            ``purchase_order_id`` and ``supplier_name``. Commas in amounts are
            stripped before parsing. Lines with a non-positive net amount or a
            negative withholding amount are skipped.
        payable_account_id: optional ChartOfAccount id to debit. When omitted,
            the account from ``_account_by_number(branch, "2000")`` is used.
        notes: free text stored on the batch.

    Returns:
        On success, a dict with ``batch_number``, ``lines``, ``total_net``,
        ``total_wht`` and ``entry_number``. When no line qualifies, returns
        ``{"error": ...}``.

    Raises:
        ChartOfAccount.DoesNotExist: if an account id is not found for ``branch``.
        decimal.InvalidOperation: if an amount cannot be parsed as a number.
    """
    from django.db import transaction
    from procurement.models import PurchaseOrder
    from .posting_engine import _account_by_number

    pay_date = fr.parse_finance_date(payment_date) if isinstance(payment_date, str) else (payment_date or date.today())
    source_account = ChartOfAccount.objects.get(id=source_account_id, company_branch=branch)
    # Scope the payable account to the branch, like the source account, so a
    # caller cannot post against another branch's chart of accounts.
    payable_account = (
        ChartOfAccount.objects.get(id=payable_account_id, company_branch=branch)
        if payable_account_id else _account_by_number(branch, "2000")
    )

    clean_lines = []
    total_net = Decimal("0")
    total_wht = Decimal("0")
    for raw in lines:
        # Strip thousands separators; empty/None values count as zero.
        amount = Decimal(str(raw.get("amount") or "0").replace(",", "") or "0")
        wht = Decimal(str(raw.get("wht_amount") or "0").replace(",", "") or "0")
        net = amount - wht
        # Skip lines that pay nothing out. Also skip negative withholding,
        # which would push net above gross and post a negative WHT credit.
        if net <= 0 or wht < 0:
            continue
        po = None
        if raw.get("purchase_order_id"):
            # Only settle POs of this company (same scope as
            # build_batch_payment_candidates).
            po = PurchaseOrder.objects.filter(
                id=int(raw["purchase_order_id"]),
                company_profile=branch.company_profile,
            ).first()
        clean_lines.append({"po": po, "amount": amount, "wht": wht, "net": net,
                            "supplier_name": raw.get("supplier_name", "")})
        total_net += net
        total_wht += wht

    if not clean_lines:
        return {"error": "No payable lines with a positive net amount"}

    # Resolve the WHT account before writing anything, so a lookup failure
    # cannot happen half-way through posting.
    wht_payable = _account_by_number(branch, "2100") if total_wht > 0 else None

    with transaction.atomic():
        batch = BatchPayment.objects.create(
            company_branch=branch,
            payment_date=pay_date,
            source_account=source_account,
            payable_account=payable_account,
            total_amount=total_net,
            notes=notes,
            created_by=staff,
        )

        entry = ManualJournalEntry.objects.create(
            company_branch=branch,
            entry_date=pay_date,
            description=f"Batch supplier payment {batch.batch_number}",
            reference=batch.batch_number,
            source_type="batch_payment",
            source_id=str(batch.id),
            created_by=staff, last_updated_by=staff,
        )

        def _post(account, debit, credit, description, rel_id):
            # Write the journal line and its matching account-history row.
            ManualJournalLine.objects.create(
                journal_entry=entry, account=account, debit=debit, credit=credit, description=description)
            AccountHistory.objects.create(
                company_branch=branch, account=account, debit=debit, credit=credit,
                transaction_date=pay_date, rel_type="batch_payment", rel_id=str(rel_id),
                description=description, created_by=staff)

        for cl in clean_lines:
            ref = cl["po"].purchase_order_number if cl["po"] else cl["supplier_name"]
            BatchPaymentLine.objects.create(
                batch_payment=batch, purchase_order=cl["po"], supplier_name=cl["supplier_name"],
                amount=cl["amount"], wht_amount=cl["wht"], net_amount=cl["net"], reference=ref)
            # DR Accounts Payable for the gross owed, recorded against the PO for ageing.
            _post(payable_account, cl["amount"], Decimal("0"),
                  f"Batch settlement {ref}", cl["po"].id if cl["po"] else batch.id)
            if cl["po"]:
                # Mirror to the purchase_order rel so AP ageing recognises the
                # settlement (build_batch_payment_candidates sums these credits
                # as "paid").
                # NOTE(review): this credit is booked on payable_account itself,
                # so per-account AccountHistory totals see DR and CR of the same
                # amount for this line. Confirm that is intended.
                AccountHistory.objects.create(
                    company_branch=branch, account=payable_account,
                    debit=Decimal("0"), credit=cl["amount"],
                    transaction_date=pay_date, rel_type="purchase_order", rel_id=str(cl["po"].id),
                    description=f"Batch settlement {ref}", created_by=staff)

        # CR bank for total net actually paid out.
        _post(source_account, Decimal("0"), total_net, f"Batch payout {batch.batch_number}", batch.id)

        # CR WHT payable for total withheld.
        if total_wht > 0:
            _post(wht_payable, Decimal("0"), total_wht, f"Withholding tax {batch.batch_number}", batch.id)

        batch.status = "executed"
        batch.journal_entry = entry
        batch.executed_by = staff
        batch.executed_on = timezone.now()
        batch.save()

        log_finance_audit(branch, staff, "batch_payment", batch.id, "executed",
                          f"{len(clean_lines)} lines, net {float(total_net)}")

    return {
        "batch_number": batch.batch_number,
        "lines": len(clean_lines),
        "total_net": float(total_net),
        "total_wht": float(total_wht),
        "entry_number": entry.entry_number,
    }


# ---------------------------------------------------------------------------
# EMAIL DISPATCH FOR PAYMENT REMINDERS
# ---------------------------------------------------------------------------

def dispatch_due_reminders(branch, limit=200):
    """
    Email queued, unsent payment reminders for ``branch``.

    Reminders are processed in ``due_date`` order, at most ``limit`` per call,
    and only when they have a recipient address (neither empty nor NULL).

    * A successful send sets ``sent=True``, stamps ``sent_on`` and clears
      ``last_error``.
    * A failed send leaves the reminder unsent, so a later call retries it,
      and stores the first 300 characters of the error in ``last_error``.

    Returns:
        dict with the ``sent`` and ``failed`` counts for this call, and
        ``queued_remaining``: every still-unsent reminder for the branch,
        including those that have no recipient address.
    """
    from django.conf import settings
    from django.core.mail import send_mail

    # Subject line per reminder type; unknown types get a generic subject.
    subjects = {
        "ar_due": "Payment Due — Outstanding Invoice",
        "credit_follow_up": "Friendly Reminder — Account Balance Due",
        "ap_due": "Supplier Payment Due",
    }

    # exclude(recipient_email="") alone keeps NULL addresses (SQL three-valued
    # logic), and those rows would fail on every run. Reminders may be created
    # from a customer email that is missing, so exclude NULL explicitly.
    pending = (
        ScheduledPaymentReminder.objects.filter(company_branch=branch, sent=False)
        .exclude(recipient_email="")
        .exclude(recipient_email__isnull=True)
        .order_by("due_date")[:limit]
    )

    sent = 0
    failed = 0
    from_email = getattr(settings, "EMAIL_HOST_USER", None) or "no-reply@megawatt.local"
    for reminder in pending:
        subject = subjects.get(reminder.reminder_type, "Payment Reminder")
        body = (
            f"Dear Customer,\n\nThis is a reminder that an amount of "
            f"KES {reminder.amount_due} was due on {reminder.due_date}.\n"
            f"Kindly arrange settlement at your earliest convenience.\n\n"
            f"Regards,\nFinance Department"
        )
        try:
            send_mail(subject, body, from_email, [reminder.recipient_email], fail_silently=False)
        except Exception as exc:  # noqa: BLE001 - best-effort; record and continue
            reminder.last_error = str(exc)[:300]
            reminder.save()
            failed += 1
            continue
        reminder.sent = True
        reminder.sent_on = timezone.now()
        reminder.last_error = ""
        reminder.save()
        sent += 1

    return {"sent": sent, "failed": failed, "queued_remaining": ScheduledPaymentReminder.objects.filter(
        company_branch=branch, sent=False,
    ).count()}
