# SPDX-License-Identifier: Apache-2.0
# Copyright (C) 2025 MINERVA European Support Centre contributors.
# See https://www.apache.org/licenses/LICENSE-2.0 for the full license text.

# ===========================================================================
# Imports and Initializations
# ===========================================================================
from __future__ import annotations
import argparse
import csv
import json
import logging
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.nets import resnet18
from torch.utils.data import DataLoader, Dataset
from captum.attr import IntegratedGradients, NoiseTunnel

logging.basicConfig(
    format="%(asctime)s  %(levelname)-8s  %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger(__name__)

# Saliency methods accepted by --xai; "all" expands to every entry here.
XAI_METHODS = [
    "gradcam",
    "gradcam++",
    "integrated-gradients",
    "occlusion",
]
# Output index of the 2-way classifier -> human-readable class name.
LABEL_NAMES = dict(enumerate(("background", "nodule")))


# ===========================================================================
# Model
# ===========================================================================

def load_model(checkpoint: str, device: torch.device) -> nn.Module:
    """Build the 3D ResNet-18 detector and load trained weights into it.

    The architecture is a single-channel, 3D MONAI ResNet-18 whose ``fc``
    head is replaced by ``Dropout(p=0.4) -> Linear(in_features, 2)``.

    Parameters
    ----------
    checkpoint : path to a ``state_dict`` saved with ``torch.save``.
    device     : device the returned model is moved to.

    Returns
    -------
    The model in eval mode on ``device``.
    """
    model = resnet18(
        pretrained=False,
        n_input_channels=1,
        num_classes=2,
        spatial_dims=3,
    )
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(in_features, 2),
    )
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state = torch.load(checkpoint, map_location="cpu", weights_only=True)
    # A state_dict saved from an nn.DataParallel wrapper prefixes every key
    # with "module."; strip it so both wrapped and bare checkpoints load.
    prefix = "module."
    if state and all(k.startswith(prefix) for k in state):
        state = {k[len(prefix):]: v for k, v in state.items()}
    model.load_state_dict(state)
    model.eval()
    return model.to(device)


def wrap_model(model: nn.Module, n_gpus: int, device: torch.device) -> nn.Module:
    """Optionally wrap *model* in ``nn.DataParallel`` and move it to *device*.

    The GPU request is capped at ``torch.cuda.device_count()``.
    ``resolve_device`` warns that it will use the available GPUs when too
    many are requested, so over-asking still gives data parallelism across
    the GPUs that do exist. No wrapping happens when fewer than two GPUs
    would be used.

    Returns
    -------
    The (possibly wrapped) model on ``device``.
    """
    n_use = min(n_gpus, torch.cuda.device_count())
    if n_use > 1:
        log.info("Using DataParallel across %d GPUs.", n_use)
        model = nn.DataParallel(model, device_ids=list(range(n_use)))
    return model.to(device)


def unwrap(model: nn.Module) -> nn.Module:
    """Return the bare network held inside an ``nn.DataParallel`` wrapper.

    Anything that is not a ``DataParallel`` instance is returned unchanged.
    """
    if not isinstance(model, nn.DataParallel):
        return model
    return model.module


# ===========================================================================
# Preprocessing
# ===========================================================================

def preprocess(patch: np.ndarray) -> torch.Tensor:
    """Turn a raw patch into a model-ready float32 tensor.

    Intensities are mapped linearly from [0, 255] to [-1, 1], then two
    singleton axes (batch, channel) are prepended, so a (D, H, W) patch
    becomes (1, 1, D, H, W).
    """
    volume = torch.from_numpy(patch.astype(np.float32))
    scaled = volume / 127.5 - 1.0
    return scaled[None, None]


# ===========================================================================
# Dataset
# ===========================================================================

class PatchDataset(Dataset):
    """``.npy`` patches found under ``<data_dir>/images/<split>/``.

    The label is parsed from the file stem: ``_pos_`` -> 1 (nodule),
    ``_neg_`` -> 0 (background), anything else -> -1 (unknown).
    Each item is ``(tensor, label, stem)``. ``tensor`` is the output of
    ``preprocess`` without its batch axis, so a (D, H, W) patch yields a
    (1, D, H, W) tensor and the DataLoader adds the batch axis.
    """

    def __init__(self, data_dir: str, split: str = "test"):
        self.data_dir = Path(data_dir)
        self.items: List[Tuple[Path, int, str]] = []
        split_dir = self.data_dir / "images" / split
        for p in sorted(split_dir.glob("*.npy")):
            label = 1 if "_pos_" in p.stem else (0 if "_neg_" in p.stem else -1)
            self.items.append((p, label, p.stem))
        log.info("PatchDataset: %d samples (split=%s)", len(self.items), split)
        if not self.items:
            # glob() on a missing directory silently yields nothing; surface
            # the likely path mistake instead of quietly running on 0 samples.
            log.warning("No .npy patches found in %s", split_dir)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        path, label, stem = self.items[idx]
        # Reuse the single-sample preprocessing so the batch and single-sample
        # paths can never drift apart; [0] drops the batch axis.
        tensor = preprocess(np.load(path))[0]
        return tensor, label, stem


# ===========================================================================
# Inference
# ===========================================================================

@torch.no_grad()
def predict(model: nn.Module, x: torch.Tensor,
            device: torch.device) -> Tuple[int, float]:
    """Classify one preprocessed patch without tracking gradients.

    Expects a batch holding exactly one sample, because ``.item()`` on the
    argmax requires a single element.

    Returns
    -------
    (predicted class index, softmax probability of class 1).
    """
    scores = F.softmax(model(x.to(device)), dim=1)
    label_idx = scores.argmax(1).item()
    nodule_p = scores[0, 1].item()
    return int(label_idx), float(nodule_p)


# ===========================================================================
# XAI helpers
# ===========================================================================

def _normalise(arr: np.ndarray, percentile: float = 99.0) -> np.ndarray:
    lo = arr.min()
    hi = float(np.percentile(arr, percentile))
    if hi - lo < 1e-8:
        hi = arr.max()
    return np.clip((arr - lo) / (hi - lo + 1e-8), 0.0, 1.0)
def run_gradcam(
    model:     nn.Module,
    x:         torch.Tensor,
    class_idx: int,
    variant:   str = "gradcam",
) -> np.ndarray:
    """Grad-CAM / Grad-CAM++ saliency taken at the last block of ``layer4``.

    Parameters
    ----------
    model     : classifier, optionally wrapped in ``nn.DataParallel``.
    x         : preprocessed input; only batch element 0 is explained.
    class_idx : logit index whose evidence is visualised.
    variant   : ``"gradcam"`` or ``"gradcam++"``.

    Returns
    -------
    Saliency upsampled (trilinear) to the spatial size of ``x``, squeezed,
    and normalised to [0, 1].

    Raises
    ------
    ValueError : if ``variant`` is not one of the two supported names.
    """
    if variant not in ("gradcam", "gradcam++"):
        raise ValueError(f"Unknown Grad-CAM variant: {variant!r}")

    base         = unwrap(model).to(x.device)
    target_layer = base.layer4[-1]
    activations, gradients = {}, {}

    def fwd_hook(module, inp, out):
        activations["v"] = out

    def bwd_hook(module, gin, gout):
        gradients["v"] = gout[0]

    fh = target_layer.register_forward_hook(fwd_hook)
    bh = target_layer.register_full_backward_hook(bwd_hook)
    try:
        base.zero_grad()
        logits = base(x)
        logits[0, class_idx].backward()
    finally:
        # Detach the hooks even if forward/backward raises, so they do not
        # linger on the shared model and fire during later calls.
        fh.remove()
        bh.remove()

    acts  = activations["v"].detach()
    grads = gradients["v"].detach()
    spatial = (2, 3, 4)

    if variant == "gradcam":
        weights = grads.mean(dim=spatial, keepdim=True)
    else:
        # Grad-CAM++ (Chattopadhay et al., 2018):
        #   alpha = g^2 / (2 g^2 + sum_spatial(A) * g^3)
        # The activation is summed over space; g^3 stays per-voxel.
        g2    = grads ** 2
        g3    = g2 * grads
        denom = 2 * g2 + acts.sum(dim=spatial, keepdim=True) * g3
        denom = torch.where(denom != 0, denom, torch.ones_like(denom))
        alpha = g2 / denom
        # The paper's exp(score) factor is a positive scalar that the final
        # [0, 1] normalisation cancels (up to its 1e-8 epsilon). Omitting it
        # avoids inf/NaN when exp() overflows float32 for large logits.
        weights = (alpha * F.relu(grads)).mean(dim=spatial, keepdim=True)

    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=x.shape[2:], mode="trilinear",
                        align_corners=False)
    return _normalise(cam.squeeze().detach().cpu().numpy())

def run_integrated_gradients(
    model:      nn.Module,
    x:          torch.Tensor,
    class_idx:  int,
    n_steps:    int = 100,
    nt_samples: int = 10,
    noise_std:  float = 0.05,
    internal_batch_size: Optional[int] = None,
) -> np.ndarray:
    """
    SmoothGrad-IG: average Integrated Gradients over nt_samples noisy inputs.

    The IG baseline is a constant -1 volume, i.e. the darkest value of the
    normalised [-1, 1] input range.

    Parameters
    ----------
    n_steps    : Riemann integration steps per IG evaluation.
    nt_samples : number of noisy copies averaged by SmoothGrad. Copies are
                 processed one at a time (nt_samples_batch_size=1) to bound
                 memory, so runtime grows with this value.
    noise_std  : Std of Gaussian noise in normalised [-1, 1] units.
    internal_batch_size : if given, Captum evaluates the n_steps interpolated
                 inputs in chunks of at most this size rather than all at
                 once, trading speed for lower peak memory. None keeps
                 Captum's default behaviour.

    Returns
    -------
    Absolute attributions, squeezed and normalised to [0, 1].
    """
    base     = unwrap(model).to(x.device)
    baseline = torch.full_like(x, -1.0)

    ig = IntegratedGradients(base)
    nt = NoiseTunnel(ig)

    attrs = nt.attribute(
        x,
        nt_type               = "smoothgrad",
        nt_samples            = nt_samples,
        nt_samples_batch_size = 1,
        stdevs                = noise_std,
        target                = class_idx,
        n_steps               = n_steps,
        baselines             = baseline,
        internal_batch_size   = internal_batch_size,
    )
    sal = attrs.squeeze().abs().detach().cpu().numpy()
    return _normalise(sal)


# ---------------------------------------------------------------------------
# Occlusion Sensitivity
# ---------------------------------------------------------------------------

def run_occlusion(
    model:      nn.Module,
    x:          torch.Tensor,
    class_idx:  int,
    window:     int   = 4,
    stride:     int   = 2,
    batch_size: int   = 32,
) -> np.ndarray:
    """
    Occlusion Sensitivity: slide a window^3 cube across the volume, replace
    it with the minimum voxel value of ``x``, and record the drop in the
    softmax probability of ``class_idx``.

    Parameters
    ----------
    window : side length of the occlusion cube in voxels, clamped to the
             volume size per axis. Smaller windows give a finer map.
    stride : step between consecutive cube positions. Along an axis of size
             S the starts are 0, stride, 2*stride, ... plus, when needed, a
             final start at S - window so the far edge is always covered.
             E.g. S=64, window=8, stride=4 -> 15 starts per axis, 15^3 = 3375
             occluded volumes.
    batch_size : occluded volumes evaluated per forward pass
             (3375 volumes / 32 -> 106 forward calls in the example above).

    Memory
    ------
    Each forward pass holds ``batch_size`` occluded copies of ``x`` plus a
    (batch_size, D, H, W) boolean mask.

    Returns
    -------
    Per-voxel mean score drop over all windows covering the voxel, taken in
    absolute value (so score increases also count) and normalised to [0, 1].

    Raises
    ------
    ValueError : if window, stride or batch_size is smaller than 1.
    """
    if window < 1 or stride < 1 or batch_size < 1:
        raise ValueError(
            "window, stride and batch_size must be >= 1 "
            f"(got {window}, {stride}, {batch_size})"
        )

    D, H, W = x.shape[2], x.shape[3], x.shape[4]
    device  = x.device

    def _starts(size: int, width: int) -> List[int]:
        # Regular grid 0, stride, 2*stride, ... plus a final start flush with
        # the far edge when the grid alone would leave trailing voxels
        # uncovered (their count would stay 0 and their saliency 0).
        last = size - width
        starts = list(range(0, last + 1, stride))
        if starts[-1] != last:
            starts.append(last)
        return starts

    # Clamp the cube per axis so every axis yields at least one start.
    wz, wy, wx = min(window, D), min(window, H), min(window, W)
    positions = [
        (z, y, xp)
        for z  in _starts(D, wz)
        for y  in _starts(H, wy)
        for xp in _starts(W, wx)
    ]

    base = unwrap(model).to(device)
    base.eval()
    # Occluded voxels take the smallest value present in this input.
    fill_value = float(x.min().item())

    # Voxel coordinate grids, broadcastable against (bs, 1, 1, 1) starts.
    # Loop-invariant, so built once rather than per batch.
    gz = torch.arange(D, device=device).view(1, D, 1, 1)
    gy = torch.arange(H, device=device).view(1, 1, H, 1)
    gx = torch.arange(W, device=device).view(1, 1, 1, W)

    sal = np.zeros((D, H, W), dtype=np.float64)
    cnt = np.zeros((D, H, W), dtype=np.float64)

    with torch.no_grad():
        # Reference probability on the unmasked input.
        baseline_score = float(
            F.softmax(base(x), dim=1)[0, class_idx].item()
        )
        for batch_start in range(0, len(positions), batch_size):
            batch_pos = positions[batch_start:batch_start + batch_size]
            bs = len(batch_pos)
            zs, ys, xs = (
                torch.tensor(coord, device=device).view(bs, 1, 1, 1)
                for coord in zip(*batch_pos)
            )

            # True where the cube of sample b covers voxel (d, h, w).
            in_window = (
                (gz >= zs) & (gz < zs + wz) &
                (gy >= ys) & (gy < ys + wy) &
                (gx >= xs) & (gx < xs + wx)
            )  # (bs, D, H, W)

            x_tiled = x.expand(bs, -1, -1, -1, -1)  # (bs, 1, D, H, W)
            batch = torch.where(
                in_window.unsqueeze(1),
                torch.full_like(x_tiled, fill_value),
                x_tiled,
            )
            scores = F.softmax(base(batch), dim=1)[:, class_idx].cpu().numpy()

            drops = baseline_score - scores  # (bs,)
            m_np  = in_window.cpu().numpy()  # (bs, D, H, W)
            sal  += (drops[:, None, None, None] * m_np).sum(axis=0)
            cnt  += m_np.sum(axis=0)

    # Edge-aligned starts guarantee cnt >= 1 everywhere; the epsilon only
    # guards the division.
    raw = np.abs(sal / (cnt + 1e-8))
    return _normalise(raw)


def run_xai(
    method:           str,
    model:            nn.Module,
    x:                torch.Tensor,
    class_idx:        int,
    ig_steps:         int = 100,
    occlusion_window: int = 4,
    occlusion_stride: int = 2,
) -> Tuple[np.ndarray, float, float]:
    """Run one saliency method by name and measure its cost.

    Returns
    -------
    (saliency, wall-clock latency in ms, peak CUDA memory in MB).
    The peak counter is reset first, then read with
    ``torch.cuda.max_memory_allocated``. The figure therefore covers every
    tensor allocated on the current CUDA device during the call, including
    resident model weights. It is 0.0 when CUDA is unavailable.

    Raises
    ------
    ValueError : for a method name that has no runner.
    """
    cuda = torch.cuda.is_available()
    if cuda:
        torch.cuda.reset_peak_memory_stats()

    started = time.perf_counter()

    # Lambdas defer the call so only the selected method actually runs.
    runners = {
        "gradcam": lambda: run_gradcam(model, x, class_idx, variant="gradcam"),
        "gradcam++": lambda: run_gradcam(model, x, class_idx, variant="gradcam++"),
        "integrated-gradients": lambda: run_integrated_gradients(
            model, x, class_idx, n_steps=ig_steps),
        "occlusion": lambda: run_occlusion(
            model, x, class_idx,
            window=occlusion_window, stride=occlusion_stride),
    }
    runner = runners.get(method)
    if runner is None:
        raise ValueError(f"Unknown XAI method: {method!r}")
    saliency = runner()

    elapsed_ms = (time.perf_counter() - started) * 1000.0
    peak_mb = torch.cuda.max_memory_allocated() / 1e6 if cuda else 0.0
    return saliency, elapsed_ms, peak_mb


# ===========================================================================
# Stability score
# ===========================================================================

def stability_score(
    method:           str,
    model:            nn.Module,
    x:                torch.Tensor,
    class_idx:        int,
    n_trials:         int   = 5,
    noise_std:        float = 0.02,
    ig_steps:         int   = 100,
    occlusion_window: int   = 4,
    occlusion_stride: int   = 2,
) -> float:
    """Mean cosine similarity between the clean saliency map of ``x`` and the
    maps computed on ``n_trials`` noisy copies of it.

    The noise is i.i.d. Gaussian with std ``noise_std``, added in the
    normalised input space. A higher value means the explanation changes
    less under small perturbations. ``n_trials == 0`` yields the NumPy mean
    of an empty list, which is NaN.
    """
    reference, _, _ = run_xai(method, model, x, class_idx, ig_steps,
                              occlusion_window, occlusion_stride)
    ref_vec = reference.flatten()
    ref_norm = np.linalg.norm(ref_vec)

    def _cosine_with_noisy_copy() -> float:
        perturbed = x + torch.randn_like(x) * noise_std
        noisy_map, _, _ = run_xai(method, model, perturbed, class_idx, ig_steps,
                                  occlusion_window, occlusion_stride)
        vec = noisy_map.flatten()
        return float(np.dot(ref_vec, vec) /
                     (ref_norm * np.linalg.norm(vec) + 1e-8))

    similarities = [_cosine_with_noisy_copy() for _ in range(n_trials)]
    return float(np.mean(similarities))


# ===========================================================================
# Single-sample pipeline
# ===========================================================================

def run_single(
    patch_path:       str,
    model:            nn.Module,
    device:           torch.device,
    xai_methods:      List[str],
    report:           bool,
    out_dir:          Path,
    ig_steps:         int = 100,
    occlusion_window: int = 4,
    occlusion_stride: int = 2,
) -> dict:
    """Classify one ``.npy`` patch and optionally explain the prediction.

    For every requested XAI method the saliency volume is written to
    ``<out_dir>/<stem>_<method>_saliency.npy``. When ``report`` is set, a
    stability score is also computed per method.

    Returns
    -------
    dict with keys ``stem``, ``prediction``, ``nodule_prob``,
    ``inference_ms`` and ``xai``. ``xai`` maps each method to its latency,
    GPU memory and, with ``report``, its stability.
    """
    stem = Path(patch_path).stem
    x = preprocess(np.load(patch_path).astype(np.float32)).to(device)

    tic = time.perf_counter()
    pred, nod_prob = predict(model, x, device)
    infer_ms = (time.perf_counter() - tic) * 1000.0

    label = LABEL_NAMES[pred]
    result = {
        "stem":         stem,
        "prediction":   label,
        "nodule_prob":  round(nod_prob, 4),
        "inference_ms": round(infer_ms, 2),
        "xai":          {},
    }
    log.info("Prediction: %s  (p_nodule=%.3f)  [%.1f ms]",
             label, nod_prob, infer_ms)

    xai_kwargs = dict(
        ig_steps=ig_steps,
        occlusion_window=occlusion_window,
        occlusion_stride=occlusion_stride,
    )
    for method in xai_methods:
        log.info("Running XAI: %s ...", method)
        sal, lat_ms, gpu_mb = run_xai(method, model, x, pred, **xai_kwargs)
        entry = {"latency_ms": round(lat_ms, 2),
                 "gpu_mem_mb": round(gpu_mb, 2)}
        if not report:
            log.info("  %s: %.1f ms | %.1f MB", method, lat_ms, gpu_mb)
        else:
            stab = stability_score(method, model, x, pred, **xai_kwargs)
            entry["stability"] = round(stab, 4)
            log.info("  %s: %.1f ms | %.1f MB | stability=%.3f",
                     method, lat_ms, gpu_mb, stab)

        result["xai"][method] = entry
        sal_path = out_dir / f"{stem}_{method}_saliency.npy"
        np.save(sal_path, sal)
        log.info("  Saliency saved -> %s", sal_path)

    return result


# ===========================================================================
# Batch pipeline
# ===========================================================================

def run_batch(
    data_dir:         str,
    model:            nn.Module,
    device:           torch.device,
    xai_methods:      List[str],
    report:           bool,
    out_dir:          Path,
    batch_size:       int = 16,
    split:            str = "test",
    n_workers:        int = 4,
    ig_steps:         int = 100,
    occlusion_window: int = 4,
    occlusion_stride: int = 2,
) -> List[dict]:
    """Batch inference over one split of ``PatchDataset``, plus optional XAI.

    Predictions are computed through a DataLoader in batches of
    ``batch_size``. XAI, if requested, then runs one sample at a time, and
    each saliency volume is saved to ``<out_dir>/<stem>_<method>_saliency.npy``.
    With ``report`` set, per-sample stability scores are added and a
    per-method summary is attached to the first result under
    ``"xai_summary"``. The summary is skipped when the split is empty.

    Returns
    -------
    One result dict per sample, in dataset order.
    """
    dataset = PatchDataset(data_dir, split=split)
    loader  = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                         num_workers=n_workers,
                         pin_memory=(device.type == "cuda"))

    def _sync() -> None:
        # CUDA kernels launch asynchronously. Without a barrier the timer
        # would stop after the launch, before the GPU finishes the forward.
        if device.type == "cuda":
            torch.cuda.synchronize()

    log.info("Batch inference over %d samples ...", len(dataset))
    all_preds, all_probs, all_stems, all_labels = [], [], [], []
    total_infer_ms = 0.0

    model.eval()
    with torch.no_grad():
        for patches, labels, stems in loader:
            patches = patches.to(device)
            _sync()
            t0      = time.perf_counter()
            logits  = model(patches)
            _sync()
            total_infer_ms += (time.perf_counter() - t0) * 1000.0
            probs   = F.softmax(logits, dim=1)[:, 1].cpu().numpy()
            preds   = logits.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds.tolist())
            all_probs.extend(probs.tolist())
            all_stems.extend(list(stems))
            all_labels.extend(labels.tolist())

    avg_infer_ms = total_infer_ms / max(len(dataset), 1)
    log.info("Avg inference: %.2f ms/sample.", avg_infer_ms)

    all_results = [
        {
            "stem":         stem,
            "prediction":   LABEL_NAMES[pred],
            "nodule_prob":  round(float(prob), 4),
            "true_label":   LABEL_NAMES.get(label, "unknown"),
            "correct":      (pred == label) if label != -1 else None,
            "inference_ms": round(avg_infer_ms, 2),
            "xai":          {},
        }
        for stem, pred, prob, label in zip(
            all_stems, all_preds, all_probs, all_labels)
    ]

    if xai_methods:
        log.info("Running XAI on %d samples ...", len(dataset))
        xai_accum: Dict[str, Dict[str, list]] = {
            m: {"latency_ms": [], "gpu_mem_mb": [], "stability": []}
            for m in xai_methods
        }

        # shuffle=False keeps dataset.items aligned with all_preds/all_results.
        for (path, _label, stem), pred, res in zip(
                dataset.items, all_preds, all_results):
            x = preprocess(np.load(path).astype(np.float32)).to(device)

            for method in xai_methods:
                sal, lat_ms, gpu_mb = run_xai(method, model, x, pred, ig_steps,
                                               occlusion_window, occlusion_stride)
                xai_entry = {"latency_ms": round(lat_ms, 2),
                             "gpu_mem_mb": round(gpu_mb, 2)}
                if report:
                    stab = stability_score(method, model, x, pred,
                                           ig_steps=ig_steps,
                                           occlusion_window=occlusion_window,
                                           occlusion_stride=occlusion_stride)
                    xai_entry["stability"] = round(stab, 4)
                    xai_accum[method]["stability"].append(stab)

                xai_accum[method]["latency_ms"].append(lat_ms)
                xai_accum[method]["gpu_mem_mb"].append(gpu_mb)
                res["xai"][method] = xai_entry
                np.save(out_dir / f"{stem}_{method}_saliency.npy", sal)

        # With an empty split there is no first result to attach the summary
        # to, and np.mean([]) would only produce NaN plus warnings.
        if report and all_results:
            summary = {}
            for method, vals in xai_accum.items():
                stabs = vals["stability"]
                mean_stab = round(float(np.mean(stabs)), 4) if stabs else None
                summary[method] = {
                    "mean_latency_ms": round(float(np.mean(vals["latency_ms"])), 2),
                    "mean_gpu_mem_mb": round(float(np.mean(vals["gpu_mem_mb"])), 2),
                    "mean_stability":  mean_stab,
                }
                log.info(
                    "XAI summary [%s]: latency=%.1f ms | mem=%.1f MB | stability=%.3f",
                    method,
                    summary[method]["mean_latency_ms"],
                    summary[method]["mean_gpu_mem_mb"],
                    # Explicit None test: a mean stability of 0.0 is a real
                    # value and must not be shown as NaN.
                    float("nan") if mean_stab is None else mean_stab,
                )
            all_results[0]["xai_summary"] = summary

    return all_results


# ===========================================================================
# Report writing
# ===========================================================================

def write_report(results: List[dict], out_dir: Path) -> None:
    """Write ``inference_report.json`` and ``inference_report.csv`` to *out_dir*.

    The JSON file is a dump of ``results``, with non-JSON values stringified.
    The CSV has one row per (sample, XAI method), or a single row with
    ``xai_method="none"`` for samples without XAI results. Its columns are
    the union of all row keys in first-seen order, and cells a row lacks are
    left empty. Accuracy over samples with a known label is logged.
    """
    json_path = out_dir / "inference_report.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=str)
    log.info("JSON report -> %s", json_path)

    csv_path = out_dir / "inference_report.csv"
    rows = []
    for res in results:
        base = {k: res[k] for k in
                ("stem", "prediction", "nodule_prob", "inference_ms")}
        base["true_label"] = res.get("true_label", "")
        base["correct"]    = res.get("correct", "")
        if res["xai"]:
            for method, vals in res["xai"].items():
                if method == "xai_summary":
                    continue
                rows.append({**base, "xai_method": method, **vals})
        else:
            rows.append({**base, "xai_method": "none"})

    if rows:
        # Rows can differ in keys (e.g. "stability" only with --report, or
        # "none" rows without metrics). Taking fieldnames from rows[0] alone
        # would make DictWriter raise on any row with extra keys.
        fieldnames = list(dict.fromkeys(k for row in rows for k in row))
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames, restval="")
            writer.writeheader()
            writer.writerows(rows)
        log.info("CSV report -> %s", csv_path)

    labelled = [r for r in results if r.get("correct") is not None]
    if labelled:
        acc = sum(r["correct"] for r in labelled) / len(labelled)
        log.info("Accuracy on %d labelled samples: %.3f", len(labelled), acc)


# ===========================================================================
# CLI
# ===========================================================================

def resolve_device(n_gpus: int) -> torch.device:
    """Choose the compute device for ``--gpus n_gpus``.

    ``n_gpus == 0``, or no usable CUDA runtime, selects the CPU. Otherwise
    the default CUDA device is returned. When more GPUs are requested than
    are visible, a warning is logged and the reported count is capped at the
    number available.
    """
    if n_gpus == 0 or not torch.cuda.is_available():
        log.info("Device: CPU")
        return torch.device("cpu")
    available = torch.cuda.device_count()
    if n_gpus > available:
        log.warning("Requested %d GPUs but only %d available. Using %d.",
                    n_gpus, available, available)
    # Pluralise on the count actually used, not the one requested.
    n_use = min(n_gpus, available)
    log.info("Device: CUDA (%d GPU%s)", n_use, "s" if n_use > 1 else "")
    return torch.device("cuda")


def main():
    """Command-line entry point.

    Parses arguments, loads the model, runs single-sample or batch inference
    (with optional XAI), then writes a report (``--report``) or prints the
    results. Argument problems are reported through ``parser.error``. All
    argument checks run before the checkpoint is loaded, so typos fail fast.
    """
    parser = argparse.ArgumentParser(
        description="LIDC-IDRI 3D ResNet -- inference + XAI benchmarking (detection)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--input",      required=True)
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--batch",      action="store_true")
    parser.add_argument("--split",      default="test",
                        choices=["train", "val", "test"])
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--gpus",       type=int, default=0)
    parser.add_argument("--xai",        default=None,
        help="gradcam | gradcam++ | integrated-gradients | occlusion | all")
    parser.add_argument("--ig-steps",   type=int, default=100,
        help="[IG] Riemann integration steps per noisy sample.")
    parser.add_argument("--occlusion-window", type=int, default=4,
        help="[Occlusion] Side length of the occlusion cube in voxels.")
    parser.add_argument("--occlusion-stride", type=int, default=2,
        help="[Occlusion] Stride between consecutive occlusion windows.")
    parser.add_argument("--report",     action="store_true")
    parser.add_argument("--output-dir", default="inference_detect_output")

    args = parser.parse_args()

    # --- Validate everything cheap before the (slow) checkpoint load. ---
    for name in ("batch_size", "ig_steps", "occlusion_window", "occlusion_stride"):
        if getattr(args, name) < 1:
            parser.error(f"--{name.replace('_', '-')} must be >= 1.")
    if args.num_workers < 0:
        parser.error("--num-workers must be >= 0.")

    if args.xai is None:
        xai_methods = []
    elif args.xai == "all":
        xai_methods = list(XAI_METHODS)
    else:
        if args.xai not in XAI_METHODS:
            parser.error(f"Unknown XAI method {args.xai!r}.")
        xai_methods = [args.xai]

    input_path = Path(args.input)
    if args.batch and not input_path.is_dir():
        parser.error("--input must be a directory in batch mode.")
    if not args.batch and not input_path.is_file():
        parser.error("--input must be a .npy file in single-sample mode.")

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    device = resolve_device(args.gpus)
    model  = load_model(args.checkpoint, device)
    model  = wrap_model(model, args.gpus, device)

    log.info("XAI methods: %s", xai_methods or "none")

    if args.batch:
        results = run_batch(
            data_dir         = args.input,
            model            = model,
            device           = device,
            xai_methods      = xai_methods,
            report           = args.report,
            out_dir          = out_dir,
            batch_size       = args.batch_size,
            split            = args.split,
            n_workers        = args.num_workers,
            ig_steps         = args.ig_steps,
            occlusion_window = args.occlusion_window,
            occlusion_stride = args.occlusion_stride,
        )
    else:
        results = [run_single(
            patch_path       = args.input,
            model            = model,
            device           = device,
            xai_methods      = xai_methods,
            report           = args.report,
            out_dir          = out_dir,
            ig_steps         = args.ig_steps,
            occlusion_window = args.occlusion_window,
            occlusion_stride = args.occlusion_stride,
        )]

    if args.report:
        write_report(results, out_dir)
    elif args.batch:
        # Without --report nothing else surfaces the batch predictions, so
        # print a compact per-sample table instead of discarding them.
        print(f"\n{'stem':40s}  {'prediction':10s}  P(nodule)")
        for r in results:
            print(f"{r['stem']:40s}  {r['prediction']:10s}  {r['nodule_prob']:.4f}")
    else:
        r = results[0]
        print(f"\nPrediction : {r['prediction']}")
        print(f"P(nodule)  : {r['nodule_prob']:.4f}")
        print(f"Inference  : {r['inference_ms']:.1f} ms")
        if r["xai"]:
            print("\nXAI results:")
            for method, vals in r["xai"].items():
                print(f"  {method:25s}  latency={vals['latency_ms']:.1f} ms"
                      f"  gpu={vals['gpu_mem_mb']:.1f} MB"
                      + (f"  stability={vals['stability']:.3f}"
                         if "stability" in vals else ""))


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()