# companies/utils.py
"""
Legacy SQL Import Utility
Reads a MySQL dump (phpMyAdmin format) and imports data into the
current multi-tenant SwiftPOS schema under the specified company.

Root-cause fixes vs original version:
  1. Table names fixed: inventory_product, sales_posorder (not core_*)
  2. Column parsing is NAME-based, not positional (positional indexes were fragile)
  3. Order items use chunked bulk_create (500 rows/chunk) with per-chunk error capture
  4. Correct FK dependency order: categories → products → customers → orders → items
  5. stock_status derived from model logic, not raw SQL value
"""
import re
import logging
from datetime import datetime, date
from decimal import Decimal, InvalidOperation

# Module-level logger; the import routine prefixes its messages with "[Import]".
logger = logging.getLogger(__name__)

# Number of parsed rows handled per batch by the chunked import loops.
CHUNK_SIZE = 500


# ─── Low-level SQL parser ────────────────────────────────────────────────────

# Matches either a complete single-quoted SQL string (backslash escapes
# honoured) or a bare semicolon. Scanning with it lets us find the semicolon
# that really terminates a statement while stepping over any ';' that appears
# inside a quoted value.
_SQL_STRING_OR_SEMICOLON = re.compile(r"'(?:[^'\\]|\\.)*'|;", re.DOTALL)


def _extract_inserts(content: str, table_name: str):
    """
    Yield one dict per row for a given table name.

    Works with both single-row and multi-row extended INSERT syntax. Each
    yielded dict maps column name -> parsed value (a string, or None for
    SQL NULL), as produced by ``_split_value_tuples``.

    Args:
        content: Full text of the SQL dump.
        table_name: Table to look for; matched literally between backticks.

    Yields:
        dict: One mapping per row whose field count equals the column count
        of its INSERT statement. Rows with a mismatching count are skipped.
    """
    header = re.compile(
        rf"INSERT INTO `{re.escape(table_name)}` \(([^)]+)\) VALUES\s*"
    )
    position = 0
    while True:
        match = header.search(content, position)
        if match is None:
            return

        # Strip backticks and split
        columns = [name.strip().strip("`") for name in match.group(1).split(",")]

        # Locate the terminating ';'. A plain non-greedy regex would stop at
        # the first semicolon even inside a quoted string and truncate the
        # statement; here quoted strings are consumed as whole tokens.
        values_start = match.end()
        values_end = len(content)  # tolerate a dump truncated before the final ';'
        for token in _SQL_STRING_OR_SEMICOLON.finditer(content, values_start):
            if token.group() == ";":
                values_end = token.start()
                break

        for row_values in _split_value_tuples(content[values_start:values_end]):
            if len(row_values) != len(columns):
                continue  # malformed row
            yield dict(zip(columns, row_values))

        position = values_end + 1


# MySQL backslash escapes that decode to something other than the escaped
# character itself. Any other "\x" decodes to plain "x" (this covers \', \"
# and \\). "\%" and "\_" keep their backslash, as MySQL itself does.
_MYSQL_ESCAPES = {
    '0': '\0', 'b': '\b', 'n': '\n', 'r': '\r', 't': '\t', 'Z': '\x1a',
    '%': '\\%', '_': '\\_',
}
_MYSQL_ESCAPE_RE = re.compile(r"\\(.)", re.DOTALL)


def _unescape_mysql(text: str) -> str:
    """Decode MySQL backslash escapes in one left-to-right pass.

    A single pass is required: chained ``str.replace`` calls mis-decode
    sequences such as an escaped backslash followed by ``n``.
    """
    return _MYSQL_ESCAPE_RE.sub(
        lambda m: _MYSQL_ESCAPES.get(m.group(1), m.group(1)), text
    )


def _split_value_tuples(values_block: str):
    """
    Split a VALUES block into individual row tuples.

    Handles multi-line blocks and quoted strings containing commas,
    parentheses, backslash escapes and SQL doubled quotes (``''``).

    Quoted fields are returned exactly as written (escapes decoded, inner
    whitespace preserved), so the quoted string ``'NULL'`` stays a string.
    Unquoted fields go through ``_clean_value``, which maps a bare ``NULL``
    to None.

    Returns:
        list[list]: one inner list of field values per ``(...)`` group. A
        trailing group that is never closed with ``)`` is dropped.
    """
    rows = []
    i = 0
    n = len(values_block)
    while i < n:
        if values_block[i] != '(':
            i += 1
            continue

        i += 1  # skip opening (
        fields = []
        bare = []        # characters seen outside quotes for the current field
        quoted = []      # decoded characters seen inside quotes
        in_quote = False
        was_quoted = False  # distinguishes 'NULL' (string) from NULL (None)
        while i < n:
            c = values_block[i]
            if in_quote:
                if c == '\\' and i + 1 < n:
                    # Backslash escape: decode the pair immediately.
                    escaped = values_block[i + 1]
                    quoted.append(_MYSQL_ESCAPES.get(escaped, escaped))
                    i += 2
                elif c == "'":
                    if i + 1 < n and values_block[i + 1] == "'":
                        # SQL doubled quote inside a string -> literal quote.
                        quoted.append("'")
                        i += 2
                    else:
                        in_quote = False
                        i += 1
                else:
                    quoted.append(c)
                    i += 1
            elif c == "'":
                in_quote = True
                was_quoted = True
                i += 1
            elif c == ',' or c == ')':
                if was_quoted:
                    fields.append(''.join(quoted))
                else:
                    fields.append(_clean_value(''.join(bare)))
                bare, quoted, was_quoted = [], [], False
                i += 1
                if c == ')':
                    rows.append(fields)
                    break
            else:
                bare.append(c)
                i += 1
    return rows


def _clean_value(v: str) -> "str | None":
    """Normalise one raw SQL token.

    Strips surrounding whitespace, maps a bare ``NULL`` (any case) to None,
    removes one pair of surrounding single quotes if present, and decodes
    MySQL backslash escapes.
    """
    v = v.strip()
    if v.upper() == 'NULL':
        return None
    # Remove surrounding quotes if any (the splitter handles quoted fields
    # itself, so this only matters for direct callers).
    if len(v) >= 2 and v.startswith("'") and v.endswith("'"):
        v = v[1:-1]
    return _unescape_mysql(v)


def _dec(val, default=Decimal('0')):
    """Safely convert a value to a finite Decimal.

    Args:
        val: Value to convert; it is stringified first. None yields ``default``.
        default: Returned when ``val`` is None, unparseable, or not a finite
            number.

    Returns:
        Decimal: the parsed value, or ``default``.
    """
    if val is None:
        return default
    try:
        number = Decimal(str(val))
    except InvalidOperation:
        return default
    # "NaN" / "Infinity" parse without error, but NaN raises on ordering
    # comparisons and neither is a usable amount, so treat them as invalid.
    return number if number.is_finite() else default


def _int(val, default=0):
    """Safely convert a value to int.

    Integer-looking text is parsed exactly. Anything else is parsed as a
    float and truncated toward zero (so ``'3.9'`` becomes ``3``).

    Args:
        val: Value to convert; it is stringified first. None yields ``default``.
        default: Returned when ``val`` is None or cannot be converted.

    Returns:
        int: the converted value, or ``default``.
    """
    if val is None:
        return default
    text = str(val).strip()
    try:
        # Exact path: avoids the float round-trip, which corrupts integers
        # larger than 2**53.
        return int(text)
    except ValueError:
        pass
    try:
        return int(float(text))
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float('inf')) and friends.
        return default


def _parse_datetime(val):
    """Parse a datetime string taken from SQL, returning an aware datetime.

    Returns None for falsy input, for the literal text ``NULL`` (any case),
    and whenever importing the Django helpers, parsing, or making the value
    timezone-aware raises.
    """
    is_missing = not val or str(val).upper() == 'NULL'
    if is_missing:
        return None
    try:
        from django.utils.dateparse import parse_datetime
        from django.utils.timezone import make_aware, is_aware
        parsed = parse_datetime(val)
        # Nothing parsed, or already timezone-aware: hand back unchanged.
        if not parsed or is_aware(parsed):
            return parsed
        return make_aware(parsed)
    except Exception:
        return None


# ─── Main import function ────────────────────────────────────────────────────

def parse_legacy_sql_and_import(file_path, company):
    """
    Parse a MySQL dump and import it into the multi-tenant schema.

    Tables are processed in FK dependency order: branches → users →
    categories → products → customers → orders → order items → settings.
    Every record is created under ``company``. Row-level failures are
    collected in ``stats['errors']`` instead of aborting the run.

    The import is safe to re-run: existing branches, users, categories,
    products, customers and orders are matched by their natural key and
    reused, and order items are not re-imported for orders that already
    have items.

    Args:
        file_path: Path of the SQL dump. It is read as UTF-8 and undecodable
            bytes are dropped.
        company: The company (tenant) object that will own the imported data.

    Returns:
        tuple: ``(summary, stats)``. ``summary`` is a one-line human-readable
        report. ``stats`` holds per-entity counters plus the ``errors`` list.
        ``stats['skipped']`` is reset after the product and order phases, so
        the returned value only reflects the order-item phase.
    """
    from accounts.models import CustomUser
    from core.models import Branch
    from inventory.models import Product, ProductCategory
    from sales.models import POSOrder, POSOrderItem, Customer
    from companies.models import CompanySettings

    def restore_timestamps(model, pk, raw_value):
        """Replace the auto-set timestamps with the legacy one, if parseable."""
        dt = _parse_datetime(raw_value)
        if dt:
            # A queryset update bypasses auto_now_add on the model field.
            model.objects.filter(id=pk).update(created_at=dt, updated_at=dt)

    logger.info(f"[Import] Starting import for company: {company.name}")

    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        content = f.read()

    stats = {
        'branches': 0, 'users': 0, 'categories': 0, 'products': 0,
        'customers': 0, 'orders': 0, 'order_items': 0,
        'skipped': 0, 'errors': [],
    }

    # ── 1. BRANCHES ──────────────────────────────────────────────────────────
    branch_map = {}  # old_id → Branch instance
    for row in _extract_inserts(content, 'core_branch'):
        try:
            code = row.get('code') or f"BR-{row.get('id', 'X')}"
            b, _ = Branch.objects.get_or_create(
                company=company, code=code,
                defaults={
                    'name': row.get('name', code),
                    'address': row.get('address', ''),
                    'phone': row.get('phone', ''),
                    'is_active': row.get('is_active', '1') == '1',
                }
            )
            branch_map[row.get('id')] = b
            stats['branches'] += 1
        except Exception as e:
            stats['errors'].append(f"Branch {row.get('id')}: {e}")

    logger.info(f"[Import] Branches: {stats['branches']}")

    # ── 2. USERS ─────────────────────────────────────────────────────────────
    user_map = {}  # old_id → CustomUser
    for row in _extract_inserts(content, 'accounts_customuser'):
        uname = row.get('username', '')
        if not uname or uname == 'swiftpos':
            continue
        try:
            u, created = CustomUser.objects.get_or_create(
                company=company, username=uname,
                defaults={
                    'password': row.get('password', ''),
                    'first_name': row.get('first_name', ''),
                    'last_name': row.get('last_name', ''),
                    'email': row.get('email', ''),
                    'is_active': row.get('is_active', '1') == '1',
                    'role': row.get('role', 'cashier'),
                    'phone': row.get('phone', ''),
                    'branch': branch_map.get(row.get('branch_id')),
                }
            )
            if not created:
                # Refresh the password hash, but never replace an existing
                # hash with a NULL/empty value from the dump.
                u.password = row.get('password') or u.password
                u.save(update_fields=['password'])
            user_map[row.get('id')] = u
            stats['users'] += 1
        except Exception as e:
            stats['errors'].append(f"User {uname}: {e}")
            stats['skipped'] += 1

    logger.info(f"[Import] Users: {stats['users']}")

    # ── 3. PRODUCT CATEGORIES ─────────────────────────────────────────────────
    cat_map = {}  # old_id → ProductCategory
    # (table, error label, whether a colour default is supplied); the second
    # entry is the old-schema fallback table.
    category_sources = (
        ('inventory_productcategory', 'Category', True),
        ('core_productcategory', 'OldCategory', False),
    )
    for table, label, has_color in category_sources:
        for row in _extract_inserts(content, table):
            try:
                # `or ''` guards against a NULL name, which would otherwise
                # crash on .strip().
                name = (row.get('name', 'Uncategorized') or '').strip() or 'Imported'
                defaults = {
                    'description': row.get('description', ''),
                    'icon': row.get('icon', ''),
                }
                if has_color:
                    defaults['color_code'] = row.get('color_code', '#667eea')
                c, _ = ProductCategory.objects.get_or_create(
                    company=company, name=name, defaults=defaults,
                )
                cat_map[row.get('id')] = c
                stats['categories'] += 1
            except Exception as e:
                stats['errors'].append(f"{label} {row.get('id')}: {e}")

    logger.info(f"[Import] Categories: {stats['categories']}")

    # ── 4. PRODUCTS ───────────────────────────────────────────────────────────
    prod_map = {}  # old_id → Product
    product_rows = list(_extract_inserts(content, 'inventory_product'))
    if not product_rows:
        product_rows = list(_extract_inserts(content, 'core_product'))

    # Default category fallback, created lazily on first use
    default_cat = None

    for row in product_rows:
        try:
            sku = (row.get('sku') or '').strip()
            if not sku:
                sku = f"IMPORT-{row.get('id', 'X')}"

            cat = cat_map.get(row.get('category_id'))
            if not cat:
                if default_cat is None:
                    default_cat, _ = ProductCategory.objects.get_or_create(
                        company=company, name='Imported',
                        defaults={'description': 'Auto-created during import'}
                    )
                cat = default_cat

            stock_qty = _int(row.get('stock_quantity', 0))
            min_stock = _int(row.get('minimum_stock', 10))

            # Derive stock status (mirrors Product.save() logic)
            if stock_qty == 0:
                stock_status = 'out_of_stock'
            elif stock_qty <= min_stock:
                stock_status = 'low_stock'
            else:
                stock_status = 'in_stock'

            p, created = Product.objects.get_or_create(
                company=company, sku=sku,
                defaults={
                    'name': row.get('name', sku),
                    'barcode': row.get('barcode', ''),
                    'description': row.get('description', ''),
                    'price': _dec(row.get('price')),
                    'cost_price': _dec(row.get('cost_price')) or None,
                    'stock_quantity': stock_qty,
                    'minimum_stock': min_stock,
                    'stock_status': stock_status,
                    'unit_of_measure': row.get('unit_of_measure', 'pcs'),
                    'status': row.get('status', 'active'),
                    'supplier': row.get('supplier', ''),
                    'category': cat,
                    'wholesale_price': _dec(row.get('wholesale_price')) or None,
                    'wholesale_min_quantity': _int(row.get('wholesale_min_quantity', 0)),
                }
            )
            prod_map[row.get('id')] = p
            if created:
                stats['products'] += 1
                restore_timestamps(Product, p.id, row.get('created_at'))
            else:
                stats['skipped'] += 1
        except Exception as e:
            stats['errors'].append(f"Product {row.get('id')} SKU={row.get('sku')}: {e}")
            logger.debug(f"[Import] Product error: {e}")

    logger.info(f"[Import] Products: {stats['products']} imported, {stats['skipped']} skipped (existing)")
    stats['skipped'] = 0  # reset skipped for next phase

    # ── 5. CUSTOMERS ─────────────────────────────────────────────────────────
    cust_map = {}  # old_id → Customer
    # (table, error label, whether the table carries a notes column); the
    # second entry is the old-schema fallback table.
    customer_sources = (
        ('sales_customer', 'Customer', True),
        ('core_customer', 'OldCustomer', False),
    )
    for table, label, has_notes in customer_sources:
        for row in _extract_inserts(content, table):
            try:
                phone = (row.get('phone') or '').strip()
                name = (row.get('name') or 'Unknown').strip()
                defaults = {
                    'name': name,
                    'email': row.get('email', ''),
                    'address': row.get('address', ''),
                }
                if has_notes:
                    defaults['notes'] = row.get('notes', '')
                c, created = Customer.objects.get_or_create(
                    company=company,
                    phone=phone if phone else name,  # use name as fallback key if no phone
                    defaults=defaults,
                )
                cust_map[row.get('id')] = c
                if created:
                    restore_timestamps(Customer, c.id, row.get('created_at'))
                stats['customers'] += 1
            except Exception as e:
                stats['errors'].append(f"{label} {row.get('id')}: {e}")

    logger.info(f"[Import] Customers: {stats['customers']}")

    # ── 6. ORDERS ─────────────────────────────────────────────────────────────
    order_map = {}  # old_id → POSOrder
    # pks of pre-existing orders that already have line items; their items
    # must not be imported again when the import is re-run.
    orders_with_items = set()
    order_rows = list(_extract_inserts(content, 'sales_posorder'))
    if not order_rows:
        order_rows = list(_extract_inserts(content, 'core_posorder'))

    # Default branch fallback
    default_branch = Branch.objects.filter(company=company, is_active=True).first()

    valid_methods = {'pos', 'transfer', 'cash', 'mobile_money'}
    valid_statuses = {'pending', 'completed', 'cancelled'}
    valid_pstatus = {'full', 'partial', 'credit'}

    for row in order_rows:
        try:
            order_num = (row.get('order_number') or '').strip()
            if not order_num:
                order_num = f"IMP-{row.get('id', 'X')}"

            # Skip if already exists for this company (single lookup)
            existing = POSOrder.objects.filter(
                company=company, order_number=order_num
            ).first()
            if existing is not None:
                stats['skipped'] += 1
                order_map[row.get('id')] = existing
                if POSOrderItem.objects.filter(order=existing).exists():
                    orders_with_items.add(existing.pk)
                continue

            # Resolve FK references
            cust_obj = cust_map.get(row.get('customer_id'))
            branch_obj = branch_map.get(row.get('branch_id')) or default_branch

            cashier = row.get('cashier') or 'imported'

            payment_method = row.get('payment_method', 'cash')
            if payment_method not in valid_methods:
                payment_method = 'cash'

            status = row.get('status', 'completed')
            if status not in valid_statuses:
                status = 'completed'

            payment_status = row.get('payment_status', 'full')
            if payment_status not in valid_pstatus:
                payment_status = 'full'

            final = _dec(row.get('final_amount'))

            o = POSOrder(
                company=company,
                order_number=order_num,
                customer=cust_obj,
                customer_name=row.get('customer_name', ''),
                customer_phone=row.get('customer_phone', ''),
                total_amount=_dec(row.get('total_amount')),
                discount_amount=_dec(row.get('discount_amount')),
                tax_amount=_dec(row.get('tax_amount')),
                final_amount=final,
                # When the dump has no amount_paid column, assume paid in full.
                amount_paid=_dec(row.get('amount_paid', final)),
                balance_amount=_dec(row.get('balance_amount')),
                payment_method=payment_method,
                payment_status=payment_status,
                status=status,
                cashier=cashier,
                branch=branch_obj,
            )
            o.save()

            restore_timestamps(POSOrder, o.id, row.get('created_at'))

            order_map[row.get('id')] = o

            # NOTE(review): the previous version issued an empty
            # Customer.update() here for credit orders, which changed
            # nothing. No credit-balance adjustment is performed; add one
            # explicitly if Customer gains a credit_balance field.

            stats['orders'] += 1
        except Exception as e:
            stats['errors'].append(f"Order {row.get('id')} #{row.get('order_number')}: {e}")
            logger.debug(f"[Import] Order error: {e}")

    logger.info(f"[Import] Orders: {stats['orders']} imported, {stats['skipped']} skipped (existing)")
    stats['skipped'] = 0

    # ── 7. ORDER ITEMS (chunked) ──────────────────────────────────────────────
    item_rows = list(_extract_inserts(content, 'sales_posorderitem'))
    if not item_rows:
        item_rows = list(_extract_inserts(content, 'core_posorderitem'))

    for chunk_start in range(0, len(item_rows), CHUNK_SIZE):
        chunk = item_rows[chunk_start: chunk_start + CHUNK_SIZE]
        items_to_create = []
        for row in chunk:
            try:
                order_obj = order_map.get(row.get('order_id'))
                # Skip items of unknown orders, and of orders that already
                # had their items before this run (avoids duplicates).
                if not order_obj or order_obj.pk in orders_with_items:
                    stats['skipped'] += 1
                    continue

                prod_obj = prod_map.get(row.get('product_id'))
                prod_name = row.get('product_name', '') or (prod_obj.name if prod_obj else 'Unknown')

                qty = _int(row.get('quantity', 1))
                unit_price = _dec(row.get('unit_price'))
                total_price = _dec(row.get('total_price'))
                if total_price == 0 and qty and unit_price:
                    total_price = unit_price * qty

                items_to_create.append(POSOrderItem(
                    order=order_obj,
                    product=prod_obj,
                    product_name=prod_name,
                    quantity=qty,
                    unit_price=unit_price,
                    total_price=total_price,
                ))
            except Exception as e:
                stats['errors'].append(f"OrderItem {row.get('id')}: {e}")

        if not items_to_create:
            continue
        try:
            POSOrderItem.objects.bulk_create(items_to_create, ignore_conflicts=True)
            stats['order_items'] += len(items_to_create)
        except Exception as e:
            stats['errors'].append(f"Bulk create items chunk {chunk_start}: {e}")
            # Fallback to one-by-one so one bad row does not lose the chunk
            for item in items_to_create:
                try:
                    item.save()
                    stats['order_items'] += 1
                except Exception as item_error:
                    stats['skipped'] += 1
                    logger.debug(f"[Import] Order item error: {item_error}")

    logger.info(f"[Import] Order Items: {stats['order_items']}")

    # ── 8. COMPANY SETTINGS ───────────────────────────────────────────────────
    settings_fields = (
        'business_name', 'business_address', 'business_phone',
        'business_email', 'currency', 'currency_symbol',
    )
    for row in _extract_inserts(content, 'core_systemsettings'):
        try:
            s, _ = CompanySettings.objects.get_or_create(company=company)
            # Only overwrite a setting when the dump has a non-empty value.
            for field in settings_fields:
                value = row.get(field)
                if value:
                    setattr(s, field, value)
            s.save()
        except Exception as e:
            stats['errors'].append(f"Settings: {e}")

    # Summarize
    error_count = len(stats['errors'])
    summary = (
        f"✅ Import Complete | "
        f"Categories: {stats['categories']} | "
        f"Products: {stats['products']} | "
        f"Customers: {stats['customers']} | "
        f"Orders: {stats['orders']} | "
        f"Order Items: {stats['order_items']}"
    )
    if error_count:
        summary += f" | ⚠️ {error_count} error(s) (see logs)"
        for err in stats['errors'][:10]:  # Show first 10
            logger.warning(f"[Import] {err}")

    logger.info(f"[Import] Done: {summary}")
    return summary, stats
