# SPDX-License-Identifier: Apache-2.0
# Copyright (C) 2025 MINERVA European Support Centre contributors.
# See https://www.apache.org/licenses/LICENSE-2.0 for the full license text.

# ===========================================================================
# Imports and Initializations
# ===========================================================================
from __future__ import annotations
import argparse
import logging
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from monai.networks.nets import resnet18
from monai.transforms import (
    Compose, RandFlip, RandRotate90, RandZoom,
    RandGaussianNoise, RandShiftIntensity, RandScaleIntensity
)
from sklearn.metrics import roc_auc_score
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

# Root logging is configured at import time so every message emitted by this
# script carries a timestamp and a padded level name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s  %(levelname)-8s  %(message)s",
)
log = logging.getLogger(__name__)

# ======================================================================
# Normalisation
# ======================================================================

def normalize_patch(volume: np.ndarray) -> np.ndarray:
    """
    Rescale a stored patch from [0, 255] to [-1, 1] as float32.

    Patches written by 1_ProcessData.py have already been HU-windowed and
    stored with intensities in [0, 255]; the affine map ``x / 127.5 - 1``
    centres them on zero for the network.
    """
    centred = volume / 127.5 - 1.0
    return centred.astype(np.float32)


def safe_auc(y_true, y_prob):
    """
    ROC-AUC that tolerates degenerate label sets.

    ``roc_auc_score`` is undefined when ``y_true`` contains fewer than two
    distinct classes; in that case 0.5 (chance level) is returned instead.
    """
    distinct_labels = set(y_true)
    if len(distinct_labels) < 2:
        return 0.5
    return roc_auc_score(y_true, y_prob)

# ======================================================================
# Dataset
# ======================================================================

class LIDCDataset(Dataset):
    """
    Binary nodule-patch dataset read from ``<data_dir>/images/<split>/*.npy``.

    Each file's label comes from its stem: ``_pos_`` -> 1, ``_neg_`` -> 0;
    files matching neither are skipped with a warning. Items are kept in
    sorted path order. When ``augment`` is True a MONAI pipeline of random
    flips/rotation/zoom and intensity noise/shift/scale is applied per sample.
    """

    def __init__(self, data_dir: str, split: str, augment: bool = False):
        self.data_dir = Path(data_dir)
        self.items = []

        split_dir = self.data_dir / "images" / split
        if not split_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {split_dir}")

        for patch_path in sorted(split_dir.glob("*.npy")):
            stem = patch_path.stem
            if "_pos_" in stem:
                self.items.append((patch_path, 1))
            elif "_neg_" in stem:
                self.items.append((patch_path, 0))
            else:
                log.warning("Skipping patch with unrecognised stem: %s", stem)

        positives = sum(lbl == 1 for _, lbl in self.items)
        negatives = sum(lbl == 0 for _, lbl in self.items)
        log.info("LIDC [%s]: %d pos + %d neg = %d samples",
                 split, positives, negatives, len(self.items))

        if augment:
            self.aug = Compose([
                RandFlip(spatial_axis=0, prob=0.5),
                RandFlip(spatial_axis=1, prob=0.5),
                RandFlip(spatial_axis=2, prob=0.5),
                RandRotate90(prob=0.5),
                RandZoom(prob=0.3, min_zoom=0.9, max_zoom=1.1),
                RandGaussianNoise(prob=0.2, std=0.01),
                RandShiftIntensity(prob=0.3, offsets=0.1),
                RandScaleIntensity(prob=0.3, factors=0.1),
            ])
        else:
            self.aug = None

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(patch tensor with a leading channel axis, label tensor)``."""
        path, label = self.items[idx]
        patch = normalize_patch(np.load(path).astype(np.float32))
        patch = patch[np.newaxis, ...]  # prepend the channel axis

        if self.aug:
            patch = self.aug(patch)

        # MONAI transforms hand back tensors; convert back to NumPy so the
        # final torch.from_numpy call sees a uniform type.
        if hasattr(patch, 'numpy'):
            patch = patch.numpy()

        return torch.from_numpy(patch.copy()), torch.tensor(label)

    def class_weights(self):
        """Per-sample weights equal to 1 / (count of that sample's class)."""
        labels = [lbl for _, lbl in self.items]
        per_class = {cls: labels.count(cls) for cls in (0, 1)}
        return torch.DoubleTensor([1.0 / per_class[lbl] for lbl in labels])

# ======================================================================
# Model
# ======================================================================

def build_model(device, pretrained_path=None, dropout=0.4):
    """
    Build a 3-D ResNet-18 binary classifier for single-channel patches.

    Args:
        device: Target device (string or ``torch.device``) for the model.
        pretrained_path: Optional checkpoint with backbone weights. Both a
            bare state dict and one wrapped as ``{"state_dict": ...}`` are
            accepted, and a ``module.`` prefix (left by ``nn.DataParallel``)
            is stripped. Keys containing ``fc`` are ignored, and a 3-channel
            ``conv1`` kernel is averaged down to one input channel.
        dropout: Dropout probability applied before the final linear layer.

    Returns:
        The model on ``device`` with ``fc`` replaced by
        ``Dropout -> Linear(in_features, 2)``.
    """
    model = resnet18(
        pretrained=False,
        n_input_channels=1,
        num_classes=2,
        spatial_dims=3,
    )

    if pretrained_path:
        checkpoint = torch.load(pretrained_path, map_location="cpu")
        # Med3D/MedicalNet checkpoints nest the weights under "state_dict";
        # iterating the outer dict directly would match no model key at all
        # and, with strict=False, silently load nothing.
        if isinstance(checkpoint, dict) and isinstance(checkpoint.get("state_dict"), dict):
            checkpoint = checkpoint["state_dict"]

        new_state = {}
        for key, value in checkpoint.items():
            # Weights saved from nn.DataParallel carry a "module." prefix.
            if key.startswith("module."):
                key = key[len("module."):]
            if "fc" in key:
                continue  # the classification head is rebuilt below
            if key == "conv1.weight" and value.shape[1] == 3:
                # Collapse an RGB stem to the single input channel used here.
                value = value.mean(1, keepdim=True)
            new_state[key] = value

        missing, unexpected = model.load_state_dict(new_state, strict=False)
        n_loaded = len(new_state) - len(unexpected)
        if n_loaded == 0:
            log.warning("No tensors in %s matched the model; the backbone "
                        "remains randomly initialised", pretrained_path)
        else:
            log.info("Loaded Med3D weights: %d tensors (%d model keys missing, "
                     "%d checkpoint keys ignored)",
                     n_loaded, len(missing), len(unexpected))

    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(in_features, 2),
    )

    return model.to(device)

# ======================================================================
# Training
# ======================================================================

def run_epoch(model, loader, criterion, optimizer, scaler, device, amp, train):
    """
    Run one pass over ``loader``, updating ``model`` when ``train`` is True.

    Args:
        model: Network producing 2-class logits.
        loader: Iterable of ``(inputs, integer labels)`` batches.
        criterion: Loss taking ``(logits, labels)``.
        optimizer, scaler: Used only when ``train`` is True (gradient step
            with GradScaler, unscale, and grad-norm clipping at 1.0).
        device: Device (string or ``torch.device``) batches are moved to.
        amp: Whether autocast mixed precision is enabled.
        train: Training mode (grads on, optimizer step) vs. evaluation.

    Returns:
        ``(mean loss, accuracy, ROC-AUC)`` over every sample seen.

    Raises:
        ValueError: If ``loader`` yields no batches (metrics are undefined).
    """
    model.train(train)
    # Autocast must target the device the tensors live on; hard-coding
    # "cuda" emits warnings and is wrong on CPU-only hosts.
    device_type = torch.device(device).type
    total_loss, correct, n = 0.0, 0, 0
    all_probs, all_labels = [], []

    with torch.set_grad_enabled(train):
        for x, y in loader:
            x, y = x.to(device), y.to(device)

            with torch.autocast(device_type=device_type, enabled=amp):
                logits = model(x)
                loss = criterion(logits, y)

            if train:
                optimizer.zero_grad(set_to_none=True)
                scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient norms.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()

            # Softmax in float32: fp16 probabilities from autocast collapse
            # near 0/1 and create artificial ties in the AUC ranking.
            probs = torch.softmax(logits.detach().float(), dim=1)[:, 1]
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(y.cpu().numpy())

            batch_size = len(y)
            total_loss += loss.item() * batch_size
            correct += (logits.argmax(1) == y).sum().item()
            n += batch_size

    if n == 0:
        raise ValueError("run_epoch: loader produced no batches")

    auc = safe_auc(all_labels, all_probs)
    return total_loss / n, correct / n, auc

# ======================================================================
# Main
# ======================================================================

def _fit_phase(tag, epoch_range, model, train_loader, val_loader, criterion,
               optimizer, scheduler, scaler, device, amp, ckpt_path, best_auc):
    """Train/validate for each epoch in ``epoch_range``; save on val-AUC gain and return the best AUC."""
    for epoch in epoch_range:
        tr = run_epoch(model, train_loader, criterion, optimizer, scaler,
                       device, amp, True)
        va = run_epoch(model, val_loader, criterion, optimizer, scaler,
                       device, amp, False)
        scheduler.step()
        log.info("%s ep%02d  train loss=%.4f acc=%.3f AUC=%.3f  |  "
                 "val loss=%.4f acc=%.3f AUC=%.3f",
                 tag, epoch, tr[0], tr[1], tr[2], va[0], va[1], va[2])
        if va[2] > best_auc:
            best_auc = va[2]
            torch.save(model.state_dict(), ckpt_path)
            log.info("  -> New best val AUC=%.4f  (saved)", best_auc)
    return best_auc


def main(args):
    """
    Train, validate and test the LIDC nodule classifier.

    Phase 1 trains only the new head on a frozen backbone for
    ``args.phase1_epochs`` epochs. Phase 2 fine-tunes layer4 (plus layer3
    unless ``args.freeze_layer3``) and the head for the remaining epochs.
    The weights with the best validation AUC across both phases are saved to
    ``<output_dir>/best.pth`` and then evaluated on the test split.
    """
    device = args.device

    train_ds = LIDCDataset(args.data_dir, "train", True)
    val_ds   = LIDCDataset(args.data_dir, "val", False)
    test_ds  = LIDCDataset(args.data_dir, "test", False)

    # Class imbalance is handled here: inverse-frequency sample weights make
    # positive and negative draws equally likely in expectation.
    sampler = WeightedRandomSampler(
        train_ds.class_weights(),
        num_samples=len(train_ds),
        replacement=True,
    )

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=sampler)
    val_loader   = DataLoader(val_ds,   batch_size=args.batch_size)
    test_loader  = DataLoader(test_ds,  batch_size=args.batch_size)

    model = build_model(device, args.pretrained, args.dropout)

    counts = np.bincount([label for _, label in train_ds.items], minlength=2)
    log.info("Class counts: %s  (balanced by WeightedRandomSampler)", counts)

    # The loss is deliberately unweighted: the sampler already balances the
    # classes, and stacking inverse-frequency loss weights on top would
    # correct the imbalance twice and bias predictions toward the minority.
    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    amp = args.amp and device.startswith("cuda")
    scaler = torch.amp.GradScaler('cuda', enabled=amp)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = out_dir / "best.pth"
    # Start below any attainable AUC so the first validated epoch always
    # writes a checkpoint (never falling back to a stale one on disk).
    best_auc = float("-inf")

    # Phase 1: head warm-up with the backbone frozen.
    for p in model.parameters():
        p.requires_grad = False
    for p in model.fc.parameters():
        p.requires_grad = True

    optimizer = AdamW(model.fc.parameters(), lr=args.lr_head,
                      weight_decay=args.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=max(args.phase1_epochs, 1))

    log.info("=== Phase 1: head warmup (%d epochs) ===", args.phase1_epochs)
    best_auc = _fit_phase("P1", range(1, args.phase1_epochs + 1), model,
                          train_loader, val_loader, criterion, optimizer,
                          scheduler, scaler, device, amp, ckpt_path, best_auc)

    # Phase 2: fine-tune the upper backbone stages together with the head.
    phase2_epochs = max(args.epochs - args.phase1_epochs, 0)
    for name, p in model.named_parameters():
        p.requires_grad = (
            (not args.freeze_layer3 and "layer3" in name)
            or "layer4" in name
            or "fc" in name
        )

    param_groups = [
        {"params": model.layer4.parameters(), "lr": args.lr_backbone},
        {"params": model.fc.parameters(),     "lr": args.lr_head},
    ]
    if not args.freeze_layer3:
        param_groups.insert(0, {
            "params": model.layer3.parameters(), "lr": args.lr_backbone})

    optimizer = AdamW(param_groups, weight_decay=args.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=max(phase2_epochs, 1))

    log.info("=== Phase 2: fine-tune (%d epochs) ===", phase2_epochs)
    best_auc = _fit_phase("P2", range(args.phase1_epochs + 1, args.epochs + 1),
                          model, train_loader, val_loader, criterion, optimizer,
                          scheduler, scaler, device, amp, ckpt_path, best_auc)

    # Test phase: evaluate the best validated weights.
    if best_auc == float("-inf"):
        log.warning("No epochs were run; evaluating the untrained model")
    else:
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    te = run_epoch(model, test_loader, criterion, optimizer, scaler,
                   device, amp, False)
    log.info("=== TEST  loss=%.4f  acc=%.3f  AUC=%.4f ===", te[0], te[1], te[2])

# ======================================================================
# Command-line entry point
# ======================================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Two-phase fine-tuning of a 3-D ResNet-18 on LIDC patches.")

    # Data / output
    parser.add_argument("--data-dir",    required=True,
                        help="Root directory containing images/<split>/*.npy.")
    parser.add_argument("--output-dir",  default="./checkpoints",
                        help="Directory where best.pth is written.")
    parser.add_argument("--pretrained",  type=str, default=None,
                        help="Optional Med3D checkpoint for the backbone.")

    # Schedule
    parser.add_argument("--epochs",        type=int,   default=30,
                        help="Total epochs across both phases.")
    parser.add_argument("--phase1-epochs", type=int,   default=5,
                        help="Head-only warm-up epochs (backbone frozen).")
    parser.add_argument("--batch-size",    type=int,   default=8)

    # Learning rates
    parser.add_argument("--lr-head",     type=float, default=3e-4)
    parser.add_argument("--lr-backbone", type=float, default=1e-5)

    # Mixed precision is on by default (it only takes effect on CUDA). A
    # store_true flag whose default is already True cannot turn anything off,
    # so --no-amp is the switch that disables it; --amp remains accepted.
    parser.add_argument("--amp",    dest="amp", action="store_true", default=True,
                        help="Enable CUDA mixed precision (default).")
    parser.add_argument("--no-amp", dest="amp", action="store_false",
                        help="Disable mixed precision.")
    parser.add_argument("--device",        default="cuda" if torch.cuda.is_available() else "cpu")

    # Regularisation
    parser.add_argument("--label-smoothing", type=float, default=0.05)
    parser.add_argument("--dropout",         type=float, default=0.0)
    parser.add_argument("--weight-decay",    type=float, default=1e-2)
    parser.add_argument("--freeze-layer3",   action="store_true", default=False,
                        help="Keep layer3 frozen during phase 2.")

    cli_args = parser.parse_args()
    if cli_args.epochs < 0 or cli_args.phase1_epochs < 0:
        parser.error("--epochs and --phase1-epochs must be non-negative")
    if cli_args.phase1_epochs > cli_args.epochs:
        parser.error("--phase1-epochs cannot exceed --epochs")

    main(cli_args)