# SPDX-License-Identifier: Apache-2.0
# Copyright (C) 2025 MINERVA European Support Centre contributors.
# See https://www.apache.org/licenses/LICENSE-2.0 for the full license text.

# ===========================================================================
# Imports and Initializations
# ===========================================================================

from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pydicom
from scipy.ndimage import zoom


# ===========================================================================
# Volume extraction 
# ===========================================================================

def _hu_to_uint8(ds: pydicom.Dataset) -> np.ndarray:
    """Map one DICOM slice to a lung-window intensity image in [0, 255].

    Stored pixel values are converted to Hounsfield units using the slice's
    ``RescaleSlope`` / ``RescaleIntercept``; a missing or falsy tag falls back
    to 1.0 / 0.0 respectively. The HU window [-1000, 400] is then mapped
    linearly onto [0, 255], with values outside the window clipped.

    Despite the name, the result is a float32 array (not uint8); the caller
    resamples it with linear interpolation afterwards.
    """
    slope = float(getattr(ds, "RescaleSlope", 1.0) or 1.0)
    intercept = float(getattr(ds, "RescaleIntercept", 0.0) or 0.0)
    hounsfield = ds.pixel_array.astype(np.float32) * slope + intercept
    windowed = (hounsfield + 1000.0) / 1400.0 * 255.0
    return np.clip(windowed, 0.0, 255.0).astype(np.float32)


def _pick_largest_series(patient_dir: Path) -> Optional[Path]:
    """Return the directory containing the most ``*.dcm`` files under *patient_dir*.

    Files are searched recursively and grouped by their immediate parent
    directory, each directory being treated as one series. Returns ``None``
    when no ``*.dcm`` file is found. On a tie, the directory first seen during
    the recursive scan wins.
    """
    counts: Dict[Path, int] = {}
    for dcm_file in patient_dir.rglob("*.dcm"):
        counts[dcm_file.parent] = counts.get(dcm_file.parent, 0) + 1
    if not counts:
        return None
    return max(counts, key=counts.__getitem__)


def _estimate_z_spacing(
    z_positions: List[Optional[float]],
    slice_thickness,
) -> float:
    """Return the distance in mm between consecutive slices of a stack.

    Preference order:
      1. the median absolute gap between consecutive z-positions, used when
         every slice has a position and the median is positive. This is the
         spacing the stacked volume actually has;
      2. the DICOM SliceThickness. This is only the nominal reconstruction
         thickness and differs from the spacing for overlapping or gapped
         reconstructions, so it is only a fallback;
      3. 1.0 mm.
    """
    if len(z_positions) > 1 and all(z is not None for z in z_positions):
        gaps = [abs(b - a) for a, b in zip(z_positions[:-1], z_positions[1:])]
        median_gap = float(np.median(gaps))
        if median_gap > 0.0:
            return median_gap
    if slice_thickness:
        return float(slice_thickness)
    return 1.0


def extract_full_volume(
    patient_dir: Path,
    target_spacing: float = 1.0,
) -> Tuple[Optional[np.ndarray], Optional[Tuple[float, float, float]]]:
    """Load the largest DICOM series under *patient_dir* as an isotropic volume.

    Slices are windowed with ``_hu_to_uint8`` and ordered by InstanceNumber
    (stable sort). They are stacked along axis 0 and resampled with linear
    interpolation to *target_spacing* mm on every axis. Slices that cannot be
    read or decoded are skipped with a warning.

    Args:
        patient_dir: directory searched recursively for ``*.dcm`` files.
        target_spacing: output voxel size in mm.

    Returns:
        ``(volume, spacing_zyx)``. *volume* is a float32 array (values in
        [0, 255] from the lung window) in ZYX order. *spacing_zyx* is the
        ORIGINAL, pre-resampling voxel spacing in mm. Returns ``(None, None)``
        when no readable slice is found.
    """
    folder = _pick_largest_series(patient_dir)
    if folder is None:
        print(f"  [warn] No DICOM files found under {patient_dir}")
        return None, None

    # Each record: (InstanceNumber, z-position or None, windowed image, dataset).
    records: List[Tuple[int, Optional[float], np.ndarray, pydicom.Dataset]] = []
    for f in folder.glob("*.dcm"):
        try:
            ds = pydicom.dcmread(f)
            instance = int(getattr(ds, "InstanceNumber", 0))
            # Decode pixels inside the try so that one slice without usable
            # pixel data is skipped instead of aborting the whole series.
            image = _hu_to_uint8(ds)
            pos = getattr(ds, "ImagePositionPatient", None)
            z_pos = float(pos[2]) if pos is not None else None
        except Exception as exc:
            print(f"  [warn] Cannot read {f.name}: {exc}")
            continue
        records.append((instance, z_pos, image, ds))

    if not records:
        print(f"  [warn] No readable slices in {folder}")
        return None, None

    # Order by InstanceNumber only; the z-axis index convention of the output
    # volume (and of any centres computed on it) depends on this order.
    records.sort(key=lambda r: r[0])
    volume_raw = np.stack([r[2] for r in records], axis=0)
    ds0 = records[0][3]
    px  = list(getattr(ds0, "PixelSpacing", [1.0, 1.0]))
    z_spacing = _estimate_z_spacing(
        [r[1] for r in records], getattr(ds0, "SliceThickness", None))
    # PixelSpacing is (row, column) spacing, i.e. (y, x).
    spacing_zyx = (z_spacing, float(px[0]), float(px[1]))
    print(f"  Raw shape : {volume_raw.shape}  spacing ZYX: {spacing_zyx}")
    zoom_factors = tuple(s / target_spacing for s in spacing_zyx)
    volume_rs = zoom(volume_raw, zoom_factors, order=1, prefilter=False).astype(np.float32)
    print(f"  Resampled : {volume_rs.shape}  ({target_spacing} mm isotropic)")
    return volume_rs, spacing_zyx


def save_full_volume(
    patient_dir: Path,
    output_dir:  Path,
    target_spacing: float = 1.0,
) -> Optional[Path]:
    """Extract one patient's isotropic CT volume and write it as ``.npy``.

    The file is named ``<patient_dir.name>_full.npy`` inside *output_dir*,
    which is created if needed. Returns the written path, or ``None`` (nothing
    written) when ``extract_full_volume`` could not build a volume.
    """
    volume = extract_full_volume(patient_dir, target_spacing)[0]
    if volume is not None:
        output_dir.mkdir(parents=True, exist_ok=True)
        destination = output_dir / f"{patient_dir.name}_full.npy"
        np.save(destination, volume)
        print(f"  Saved -> {destination}")
        return destination
    return None


# ===========================================================================
# Array helpers
# ===========================================================================

def _normalise(arr: np.ndarray, percentile: float = 99.0) -> np.ndarray:
    """Rescale *arr* to [0, 1] using its minimum and a high percentile.

    Values are mapped linearly from ``[min, P_percentile]`` to ``[0, 1]`` and
    clipped, so values above the percentile saturate at 1. If the percentile
    is numerically equal to the minimum, the maximum is used as the upper
    bound instead. A small epsilon keeps a constant array from dividing by
    zero (it maps to all zeros).
    """
    floor = arr.min()
    ceiling = float(np.percentile(arr, percentile))
    ceiling = arr.max() if ceiling - floor < 1e-8 else ceiling
    scaled = (arr - floor) / (ceiling - floor + 1e-8)
    return np.clip(scaled, 0.0, 1.0)


def load_volume_norm(path: str) -> np.ndarray:
    """Load a ``.npy`` volume as float32 and rescale it to [0, 1] with ``_normalise``."""
    raw = np.load(path)
    return _normalise(raw.astype(np.float32))


def load_saliency(path: str) -> np.ndarray:
    """Load a saliency map from a ``.npy`` file as float32, without rescaling."""
    saliency = np.load(path)
    return saliency.astype(np.float32)


def get_mid_slices(vol: np.ndarray) -> Tuple[int, int, int]:
    """Return the central ``(depth, height, width)`` indices of a 3-D volume.

    Each index is ``size // 2``. A non-3-D input raises ``ValueError`` from the
    shape unpacking.
    """
    depth, height, width = vol.shape
    return tuple(size // 2 for size in (depth, height, width))


# ===========================================================================
# Prediction & Ground truth helper functions
# ===========================================================================

def ground_truth_from_stem(stem: str) -> str:
    """Infer the ground-truth label encoded in a sample stem.

    ``"_pos_"`` anywhere in the stem means ``"nodule"`` (checked first), and
    ``"_neg_"`` means ``"background"``. Anything else is ``"unknown"``.
    """
    for marker, label in (("_pos_", "nodule"), ("_neg_", "background")):
        if marker in stem:
            return label
    return "unknown"


def load_prediction_from_report(
    report_path: str,
    stem: str,
) -> Tuple[Optional[str], Optional[float]]:
    """Return ``(prediction_label, nodule_probability)`` for *stem*.

    The report is read as a JSON list of per-sample dicts. The first entry
    whose ``"stem"`` equals *stem* provides its ``"prediction"`` and
    ``"nodule_prob"`` values (``None`` for a missing key). Any error while
    opening or parsing the file is swallowed silently, and both that case and
    "no matching entry" yield ``(None, None)``.
    """
    try:
        with open(report_path) as handle:
            entries = json.load(handle)
    except Exception:
        return None, None

    match = next((e for e in entries if e.get("stem") == stem), None)
    if match is None:
        return None, None
    return match.get("prediction"), match.get("nodule_prob")


def build_title(
    stem:       str,
    report_json: Optional[str],
    suffix:     str = "",
) -> str:
    """Build the figure suptitle, with ground truth and, if available, the prediction.

    The first line is ``XAI comparison[ <suffix>] — <stem>``. The second line
    starts with the ground truth inferred from the stem. When *report_json*
    has an entry for *stem*, the predicted label and the nodule probability
    (if present) are appended, followed by a ✓/✗ correctness mark. The mark is
    omitted when the ground truth is ``"unknown"``, because such a prediction
    cannot be judged right or wrong.
    """
    gt = ground_truth_from_stem(stem)
    heading = f"XAI comparison {suffix}" if suffix else "XAI comparison"
    title = f"{heading} — {stem}\nGT: {gt}"

    if not report_json:
        return title
    pred, prob = load_prediction_from_report(report_json, stem)
    if pred is None:
        return title

    prob_str = f"  (p={prob:.3f})" if prob is not None else ""
    title += f"  |  Pred: {pred}{prob_str}"
    if gt != "unknown":
        title += "  ✓" if pred == gt else "  ✗"
    return title


# ===========================================================================
# Metrics helpers
# ===========================================================================

def load_metrics_from_report(
    report_path: str,
    stem: str,
) -> Dict[str, dict]:
    """Return per-method XAI benchmark metrics for *stem* from an inference report.

    The report is read as a JSON list of entries.
    * If an entry's ``"stem"`` matches, its ``"xai"`` mapping is returned
      without the ``"xai_summary"`` key.
    * Otherwise, if the first entry's ``"xai"`` holds an ``"xai_summary"``,
      that summary is returned. It is not specific to *stem*.
    * Otherwise a warning is printed and ``{}`` is returned. The same happens
      when the file cannot be read.
    """
    try:
        with open(report_path) as handle:
            entries = json.load(handle)
    except Exception as exc:
        print(f"[warn] Cannot read report {report_path}: {exc}")
        return {}

    for record in entries:
        if record.get("stem") != stem:
            continue
        xai = record.get("xai", {})
        return {name: value for name, value in xai.items() if name != "xai_summary"}

    first_xai = entries[0].get("xai", {}) if entries else {}
    if "xai_summary" in first_xai:
        return first_xai["xai_summary"]

    print(f"[warn] Stem {stem!r} not found in {report_path}")
    return {}


def _metric_label(method: str, metrics: Dict[str, dict]) -> str:
    """Return the row label for *method*, with any available benchmark metrics.

    The optional keys ``latency_ms``, ``gpu_mem_mb`` and ``stability`` are
    each skipped when missing or ``None``. Present values are rendered, in
    that order, as ``"<n> ms"``, ``"<n> MB"`` and ``"stab=<x.xxx>"`` and joined
    to the method name with ``"  |  "``.
    """
    entry = metrics.get(method)
    if not entry:
        return method
    formatted = [
        template.format(entry[key])
        for key, template in (
            ("latency_ms", "{:.0f} ms"),
            ("gpu_mem_mb", "{:.0f} MB"),
            ("stability", "stab={:.3f}"),
        )
        if entry.get(key) is not None
    ]
    return "  |  ".join([method, *formatted])


# ===========================================================================
# Other helper functions
# ===========================================================================

def embed_saliency_in_volume(
    full_vol_shape: Tuple[int, int, int],
    sal_patch:      np.ndarray,
    centre_zyx:     Tuple[int, int, int],
    patch_size:     int,
) -> np.ndarray:
    """Place a cubic saliency patch into an otherwise zero volume.

    Along each axis the patch covers
    ``[c - patch_size // 2, c - patch_size // 2 + patch_size)`` around the
    corresponding coordinate of *centre_zyx*. Parts outside the volume are
    cropped. A patch with no overlap on some axis yields an all-zero volume;
    clamping both bounds prevents a negative stop index from wrapping around.

    Args:
        full_vol_shape: ``(D, H, W)`` of the target volume.
        sal_patch: saliency values, expected to be ``patch_size`` voxels along
            each axis.
        centre_zyx: patch centre in volume voxel coordinates.
        patch_size: edge length of the cubic patch.

    Returns:
        float32 array of shape ``(D, H, W)``.
    """
    D, H, W = full_vol_shape
    sal_vol = np.zeros((D, H, W), dtype=np.float32)
    half = patch_size // 2
    dst, src = [], []
    for centre, size in zip(centre_zyx, (D, H, W)):
        start = centre - half
        lo = max(start, 0)
        hi = min(start + patch_size, size)
        if hi <= lo:
            # No overlap along this axis, so nothing of the patch is visible.
            return sal_vol
        dst.append(slice(lo, hi))
        src.append(slice(lo - start, hi - start))
    sal_vol[tuple(dst)] = sal_patch[tuple(src)]
    return sal_vol


def _patch_mask(
    full_vol_shape: Tuple[int, int, int],
    centre_zyx:     Tuple[int, int, int],
    patch_size:     int,
) -> np.ndarray:
    """Return a boolean mask of *full_vol_shape* that is True inside the patch.

    Along each axis the patch covers
    ``[c - patch_size // 2, c - patch_size // 2 + patch_size)``. This is
    exactly ``patch_size`` voxels, for odd sizes too, and matches the region
    ``embed_saliency_in_volume`` fills. Both bounds are clamped to
    ``[0, dim]``, so a patch outside the volume gives an empty mask instead of
    a negative stop index wrapping around the array.
    """
    D, H, W = full_vol_shape
    mask = np.zeros((D, H, W), dtype=bool)
    half = patch_size // 2
    region = []
    for centre, size in zip(centre_zyx, (D, H, W)):
        start = centre - half
        lo = min(max(start, 0), size)
        hi = max(min(start + patch_size, size), 0)
        region.append(slice(lo, hi))
    mask[tuple(region)] = True
    return mask


def patch_bbox_on_slice(
    centre_zyx: Tuple[int, int, int],
    patch_size: int,
    vol_shape:  Tuple[int, int, int],
    plane:      str,
) -> Optional[Tuple[int, int, int, int]]:
    """Return the patch footprint on a 2-D view as ``(row, col, height, width)``.

    Along each axis the patch covers
    ``[c - patch_size // 2, c - patch_size // 2 + patch_size)``, the same
    region ``embed_saliency_in_volume`` fills, clamped to the volume.
    Width and height are never negative (0 when the patch is off-volume).

    The views use these axes as (rows, columns):
      * ``"axial"``: (y, x)
      * ``"coronal"``: (z, x)
      * ``"sagittal"``: (z, y)

    Any other *plane* returns ``None``.
    """
    D, H, W = vol_shape
    cz, cy, cx = centre_zyx
    half = patch_size // 2

    def _span(centre: int, size: int) -> Tuple[int, int]:
        # Half-open patch extent along one axis, clamped to [0, size].
        lo = min(max(centre - half, 0), size)
        hi = max(min(centre - half + patch_size, size), lo)
        return lo, hi

    z0, z1 = _span(cz, D)
    y0, y1 = _span(cy, H)
    x0, x1 = _span(cx, W)
    if plane == "axial":
        return y0, x0, y1 - y0, x1 - x0
    if plane == "coronal":
        return z0, x0, z1 - z0, x1 - x0
    if plane == "sagittal":
        return z0, y0, z1 - z0, y1 - y0
    return None


def load_centre_from_metadata(
    metadata_path: str,
    stem: str,
) -> Optional[Tuple[int, int, int]]:
    """Look up the patch centre ``(z, y, x)`` for *stem* in a metadata JSON file.

    The file is read as a JSON object keyed by stem. The entry's
    ``"centre_zyx"`` values are converted with ``int()``, which truncates any
    fractional part. If the stem or its centre is missing, a warning is
    printed and ``None`` is returned. File and JSON errors propagate to the
    caller.
    """
    with open(metadata_path) as handle:
        entry = json.load(handle).get(stem)
    if entry is None:
        print(f"[warn] Stem {stem!r} not found in {metadata_path}")
        return None
    centre = entry.get("centre_zyx")
    if centre is None:
        print(f"[warn] No centre_zyx for {stem!r} in {metadata_path}")
        return None
    return tuple(map(int, centre))


def _draw_bbox(ax, centre_zyx, patch_size, vol_shape, plane_key):
    """Outline the patch footprint on *ax* with a dashed white rectangle.

    Does nothing when ``patch_bbox_on_slice`` returns ``None`` (unknown plane).
    """
    bbox = patch_bbox_on_slice(centre_zyx, patch_size, vol_shape, plane_key)
    if bbox is None:
        return
    row, col, height, width = bbox
    outline = mpatches.Rectangle(
        (col, row), width, height,
        linewidth=1.5, edgecolor="white", facecolor="none", linestyle="--",
    )
    ax.add_patch(outline)


def _draw_centroid(
    ax,
    centre_zyx:  Tuple[int, int, int],
    plane:       str,
    vol_shape:   Tuple[int, int, int],
    has_nodule:  bool = True,
) -> None:
    """Mark the nodule centroid on a 2-D view with a yellow ``+`` inside a ring.

    *centre_zyx* is projected onto *plane* as (column, row):
      * ``"axial"``: (x, y)
      * ``"coronal"``: (x, z)
      * ``"sagittal"``: (y, z)

    Nothing is drawn when *has_nodule* is false, when the plane is unknown,
    or when the projected point lies outside the slice.
    """
    if not has_nodule:
        return
    cz, cy, cx = centre_zyx
    if plane not in ("axial", "coronal", "sagittal"):
        return

    depth, height, width = vol_shape
    # plane -> (column, row, number of columns, number of rows)
    projections = {
        "axial":    (cx, cy, width, height),
        "coronal":  (cx, cz, width, depth),
        "sagittal": (cy, cz, height, depth),
    }
    col, row, n_cols, n_rows = projections[plane]
    if col < 0 or col >= n_cols or row < 0 or row >= n_rows:
        return

    marker_size = 6
    ax.plot(col, row, "+", color="yellow", markersize=marker_size,
            markeredgewidth=1.5, zorder=5)
    ax.plot(col, row, "o", color="yellow", markersize=marker_size * 0.6,
            fillstyle="none", markeredgewidth=1.2, zorder=5)


def _row_label(ax, text: str) -> None:
    """Write *text* as a bold, vertical row heading just left of *ax*.

    The position is given in axes coordinates, and clipping is disabled so
    the label can be drawn outside the axes area.
    """
    style = dict(fontsize=8, fontweight="bold", ha="right", va="center",
                 rotation=90, clip_on=False)
    ax.text(-0.04, 0.5, text, transform=ax.transAxes, **style)


# ===========================================================================
# Saliency file fetcher
# ===========================================================================

def _common_stem(names: List[str]) -> str:
    """Return the longest ``_``-token prefix shared by all *names*.

    The prefix is kept at least one token shorter than the shortest name, so
    every name keeps a non-empty method part. Returns ``""`` when *names* is
    empty or no leading token is shared.
    """
    if not names:
        return ""
    token_lists = [name.split("_") for name in names]
    limit = min(len(tokens) for tokens in token_lists) - 1
    shared: List[str] = []
    for column in list(zip(*token_lists))[:limit]:
        if any(token != column[0] for token in column):
            break
        shared.append(column[0])
    return "_".join(shared)


def _method_from_name(file_stem: str, stem: str) -> str:
    """Strip one trailing ``_saliency`` and one leading ``<stem>_`` from *file_stem*.

    Unlike ``str.replace``, this never removes matching text from inside the
    method name itself.
    """
    suffix = "_saliency"
    name = file_stem[:-len(suffix)] if file_stem.endswith(suffix) else file_stem
    prefix = f"{stem}_"
    return name[len(prefix):] if name.startswith(prefix) else name


def resolve_saliency_files(
    stem:    Optional[str],
    patch:   Optional[str],
    sal_dir: str,
) -> Tuple[str, Dict[str, Path]]:
    """Locate saliency maps named ``<stem>_<method>_saliency.npy``.

    Returns ``(stem, {method: path})``.

    * Patch mode (*patch* given): only that file is used. The stem is the
      longest ``_``-token prefix shared by every ``*_saliency.npy`` in the same
      directory, kept shorter than each name. If nothing is shared (unrelated
      names, no files), the token before ``_saliency`` is taken as the method.
      NOTE: when the directory holds maps for several samples of the same
      patient, the shared prefix can be shorter than the real stem; use
      ``stem`` mode in that case.
    * Stem mode: every ``<stem>_*_saliency.npy`` in *sal_dir* is used.

    Raises:
        FileNotFoundError: in stem mode when no file matches.
    """
    sal_dir_path = Path(sal_dir)
    suffix = "_saliency"

    if patch:
        p    = Path(patch)
        bare = p.stem
        names = [c.stem[:-len(suffix)]
                 for c in sorted(p.parent.glob(f"*{suffix}.npy"))]
        # Token-wise prefix: the old character-wise loop never terminated when
        # two names shared no "_"-delimited prefix.
        inferred_stem = _common_stem(names) or bare.rsplit("_", 2)[0]
        return inferred_stem, {_method_from_name(bare, inferred_stem): p}

    files = sorted(sal_dir_path.glob(f"{stem}_*_saliency.npy"))
    if not files:
        raise FileNotFoundError(
            f"No saliency files found matching: {sal_dir_path / stem}_*_saliency.npy")
    sal_files: Dict[str, Path] = {}
    for f in files:
        method = _method_from_name(f.stem, stem)
        sal_files[method] = f
        print(f"Found saliency: {method} -> {f.name}")
    return stem, sal_files


# ===========================================================================
# Comparison grid
# ===========================================================================

def plot_patch_comparison(
    vol:         np.ndarray,
    sals:        Dict[str, np.ndarray],
    stem:        str,
    out:         Path,
    metrics:     Dict[str, dict],
    report_json: Optional[str] = None,
    alpha:       float = 0.45,
    cmap:        str   = "jet",
) -> None:
    """Save a grid comparing saliency maps on the mid-slices of a CT patch.

    Row 0 shows the axial, coronal and sagittal mid-slices of *vol* in grey.
    For nodule samples (stem contains ``_pos_``) the patch centre is marked.
    Each further row overlays one method's saliency on the same slices, with
    a fixed colour range [0, 1], and is labelled with its metrics from
    *metrics*. A shared colourbar occupies the narrow right-hand column.

    Args:
        vol: 3-D CT patch. Presumably already scaled to [0, 1] by the caller.
        sals: method name -> saliency volume with the same shape as *vol*.
            Values outside [0, 1] saturate the colormap.
        stem: sample identifier, used for the title and the GT label.
        out: output image path.
        metrics: per-method benchmark metrics for the row labels.
        report_json: optional inference report for the prediction in the title.
        alpha: saliency overlay opacity.
        cmap: matplotlib colormap name for the overlay.

    Raises:
        ValueError: if a saliency map's shape differs from ``vol.shape``.
            The same mid-slice indices are used for both, so a mismatch would
            otherwise give misaligned overlays or an IndexError.
    """
    for method, sal in sals.items():
        if sal.shape != vol.shape:
            raise ValueError(
                f"Saliency map {method!r} has shape {sal.shape}, expected "
                f"{vol.shape} to match the CT patch")

    methods    = list(sals.keys())
    n_rows     = len(methods) + 1
    d, h, w    = get_mid_slices(vol)
    has_nodule = ground_truth_from_stem(stem) == "nodule"

    D, H, W = vol.shape
    # The centroid marker sits at the patch centre voxel; this assumes the
    # patch was cut centred on the nodule.
    centroid_patch = (D // 2, H // 2, W // 2)

    ct_planes = [
        ("Axial",    vol[d,:,:],  "axial",    d),
        ("Coronal",  vol[:,h,:],  "coronal",  h),
        ("Sagittal", vol[:,:,w],  "sagittal", w),
    ]

    title = build_title(stem, report_json)
    fig = plt.figure(figsize=(4*3, 3.5*n_rows))
    fig.suptitle(title, fontsize=11, fontweight="bold", y=0.98)
    # Three image columns plus a narrow fourth column for the colourbar.
    gs = gridspec.GridSpec(n_rows, 4, width_ratios=[1,1,1,0.05],
                           hspace=0.45, wspace=0.08)

    # Row 0: plain CT + centroid marker (nodule samples only)
    for col, (plane_name, ct_slice, plane_key, _) in enumerate(ct_planes):
        ax = fig.add_subplot(gs[0, col])
        ax.imshow(ct_slice, cmap="gray", interpolation="bilinear")
        _draw_centroid(ax, centroid_patch, plane_key, (D, H, W), has_nodule)
        ax.set_title(plane_name, fontsize=9)
        ax.axis("off")
        if col == 0:
            _row_label(ax, "CT (input)")

    # Rows 1+: saliency overlays. The centroid marker appears on row 0 only,
    # so it does not hide the saliency underneath.
    last_im = None
    for row, method in enumerate(methods, start=1):
        sal = sals[method]
        sal_planes = [sal[d,:,:], sal[:,h,:], sal[:,:,w]]
        for col, ((plane_name, ct_slice, plane_key, _), sal_slice) in \
                enumerate(zip(ct_planes, sal_planes)):
            ax = fig.add_subplot(gs[row, col])
            ax.imshow(ct_slice, cmap="gray", interpolation="bilinear")
            last_im = ax.imshow(sal_slice, cmap=cmap, alpha=alpha,
                                interpolation="bilinear", vmin=0, vmax=1)
            ax.axis("off")
            if col == 0:
                _row_label(ax, _metric_label(method, metrics))

    if last_im is not None:
        cbar_ax = fig.add_subplot(gs[1:, -1])
        fig.colorbar(last_im, cax=cbar_ax, label="Saliency")

    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved -> {out}")


# ===========================================================================
# Context-mode comparison grid
# ===========================================================================

def _build_context_planes(
    full_vol:   np.ndarray,
    sal_patch:  np.ndarray,
    centre_zyx: Tuple[int, int, int],
    patch_size: int,
) -> list:
    """Return the three orthogonal views through *centre_zyx* for one method.

    The saliency patch is embedded into a volume shaped like *full_vol*.
    Everything outside the patch region is set to NaN, presumably so the
    overlay leaves the surrounding CT visible.

    Each item is ``(plane_name, slice_index, ct_slice, saliency_slice,
    plane_key)``, in the order axial, coronal, sagittal.
    """
    shape = full_vol.shape
    embedded = embed_saliency_in_volume(shape, sal_patch, centre_zyx, patch_size)
    inside = _patch_mask(shape, centre_zyx, patch_size)
    overlay = np.where(inside, embedded, np.nan)
    z, y, x = centre_zyx
    everything = slice(None)
    views = (
        ("Axial",    z, (z, everything, everything), "axial"),
        ("Coronal",  y, (everything, y, everything), "coronal"),
        ("Sagittal", x, (everything, everything, x), "sagittal"),
    )
    return [(name, idx, full_vol[sel], overlay[sel], key)
            for name, idx, sel, key in views]


def _plot_context_plane(ax, ct_slice, sal_slice, plane_name, slice_idx,
                        centre_zyx, patch_size, vol_shape, plane_key, alpha, cmap,
                        has_nodule=True):
    """Draw one context panel: CT in grey, saliency overlay and patch outline.

    The overlay uses a fixed [0, 1] colour range. *has_nodule* is accepted
    but currently unused. Returns the overlay image artist, which the caller
    uses for the colourbar.
    """
    ax.imshow(ct_slice, cmap="gray", interpolation="bilinear")
    overlay = ax.imshow(sal_slice, cmap=cmap, alpha=alpha, vmin=0, vmax=1,
                        interpolation="bilinear")
    _draw_bbox(ax, centre_zyx, patch_size, vol_shape, plane_key)
    ax.set_title(f"{plane_name}  (slice {slice_idx})", fontsize=10)
    ax.axis("off")
    return overlay


def plot_context_comparison(
    full_vol:    np.ndarray,
    sals:        Dict[str, np.ndarray],
    centre_zyx:  Tuple[int, int, int],
    patch_size:  int,
    stem:        str,
    out:         Path,
    metrics:     Dict[str, dict],
    report_json: Optional[str] = None,
    alpha:       float = 0.55,
    cmap:        str   = "jet",
) -> None:
    """Save a grid of saliency maps shown in the context of the full CT volume.

    Row 0 shows the three orthogonal slices through *centre_zyx*, with the
    patch outline, the centroid marker (nodule samples only) and an inline
    legend. Each further row overlays one method's saliency. The saliency is
    embedded back into the full volume and blanked outside the patch, and
    the row is labelled with the method's metrics. A shared colourbar sits in
    the right-hand column.

    Args:
        full_vol: full CT volume (ZYX). Presumably scaled to [0, 1] by the caller.
        sals: method name -> saliency patch of edge length *patch_size*.
        centre_zyx: patch centre in *full_vol* voxel coordinates.
        patch_size: edge length of the cubic patch.
        stem: sample identifier, used for the title and the GT label.
        out: output image path.
        metrics: per-method benchmark metrics for the row labels.
        report_json: optional inference report for the prediction in the title.
        alpha: saliency overlay opacity.
        cmap: matplotlib colormap name for the overlay.

    Raises:
        ValueError: if *centre_zyx* lies outside *full_vol*.
    """
    cz, cy, cx = centre_zyx
    # Reject out-of-range centres explicitly. A negative index would silently
    # wrap to the far side of the volume, and a too-large one would surface
    # as a bare IndexError.
    for axis_name, coord, size in zip("zyx", (cz, cy, cx), full_vol.shape):
        if not 0 <= coord < size:
            raise ValueError(
                f"centre_zyx {tuple(centre_zyx)} lies outside the full volume "
                f"of shape {full_vol.shape} (axis {axis_name}: {coord} not in "
                f"[0, {size}))")

    has_nodule    = ground_truth_from_stem(stem) == "nodule"
    ct_plane_specs = [
        ("Axial",    cz, full_vol[cz,:,:],  "axial"),
        ("Coronal",  cy, full_vol[:,cy,:],  "coronal"),
        ("Sagittal", cx, full_vol[:,:,cx],  "sagittal"),
    ]

    methods = list(sals.keys())
    n_rows  = len(methods) + 1

    title = build_title(stem, report_json, suffix="(context)")
    fig = plt.figure(figsize=(5*3, 4*n_rows))
    fig.suptitle(title, fontsize=11, fontweight="bold", y=0.98)
    # Three image columns plus a narrow fourth column for the colourbar.
    gs = gridspec.GridSpec(n_rows, 4, width_ratios=[1,1,1,0.05],
                           hspace=0.45, wspace=0.08)

    # Row 0: plain CT + centroid marker + patch bbox
    for col, (plane_name, idx, ct_slice, plane_key) in enumerate(ct_plane_specs):
        ax = fig.add_subplot(gs[0, col])
        ax.imshow(ct_slice, cmap="gray", interpolation="bilinear")
        _draw_bbox(ax, centre_zyx, patch_size, full_vol.shape, plane_key)
        _draw_centroid(ax, centre_zyx, plane_key, full_vol.shape, has_nodule)
        ax.set_title(f"{plane_name}  (slice {idx})", fontsize=9)
        ax.axis("off")
        if col == 0:
            _row_label(ax, "CT (input)")
            # Inline legend on the CT axial panel, avoiding extra whitespace
            # below the grid. The patch outline is drawn for every sample, so
            # it is always listed; the centroid entry only for nodule samples.
            legend_handles = []
            if has_nodule:
                legend_handles.append(
                    plt.Line2D([0],[0], marker="+", color="yellow",
                               linestyle="none", markersize=7,
                               markeredgewidth=1.5, label="nodule centroid"))
            legend_handles.append(
                mpatches.Patch(facecolor="none", edgecolor="white",
                               linestyle="--", linewidth=1.2,
                               label="patch region"))
            ax.legend(
                handles=legend_handles,
                loc="lower left",
                fontsize=6.5,
                framealpha=0.55,
                facecolor="black",
                labelcolor="white",
                edgecolor="none",
                borderpad=0.4,
            )

    # Rows 1+: saliency overlays
    last_im = None
    for row, method in enumerate(methods, start=1):
        planes = _build_context_planes(full_vol, sals[method], centre_zyx, patch_size)
        for col, (plane_name, idx, ct_slice, sal_slice, plane_key) in enumerate(planes):
            ax = fig.add_subplot(gs[row, col])
            last_im = _plot_context_plane(
                ax, ct_slice, sal_slice, plane_name, idx,
                centre_zyx, patch_size, full_vol.shape, plane_key, alpha, cmap,
                has_nodule=has_nodule,
            )
            if col == 0:
                _row_label(ax, _metric_label(method, metrics))

    if last_im is not None:
        cbar_ax = fig.add_subplot(gs[1:, -1])
        fig.colorbar(last_im, cax=cbar_ax, label="Saliency")

    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved -> {out}")


# ===========================================================================
# CLI
# ===========================================================================

def main():
    """Command-line entry point: extract CT volumes or render saliency grids.

    Modes:
      * ``--extract-only``: resample the DICOM series of every patient folder
        under ``--dicom-root`` (or only ``--patient``) and save them as
        ``.npy`` files in ``--output-dir``.
      * patch mode (default): overlay saliency maps on the CT patch given by
        ``--volume``.
      * ``--context``: overlay saliency maps on the full CT volume around the
        patch centre read from ``--metadata``.

    Invalid argument combinations are rejected before any data is loaded.
    Runtime failures exit with status 1.
    """
    parser = argparse.ArgumentParser(
        description="LIDC-IDRI saliency visualisation (patch + context modes)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("--extract-only", action="store_true")
    parser.add_argument("--dicom-root",   default=None)
    parser.add_argument("--patient",      default=None)

    group = parser.add_mutually_exclusive_group()
    group.add_argument("--patch", default=None)
    group.add_argument("--stem",  default=None)

    parser.add_argument("--dir",    default="inference_detect_output")
    parser.add_argument("--volume", default=None,
        help="[patch mode] Path to the 64^3 CT patch .npy.")

    parser.add_argument("--context",        action="store_true")
    parser.add_argument("--dicom-dir",      default=None)
    parser.add_argument("--full-volume",    default=None)
    parser.add_argument("--metadata",       default=None)
    parser.add_argument("--patch-size",     type=int,   default=64)
    parser.add_argument("--target-spacing", type=float, default=1.0)

    parser.add_argument("--report-json",      default=None,
        help="Path to inference_report.json. Used for prediction annotation "
             "and XAI benchmark metrics in row labels.")
    parser.add_argument("--output-dir",       default="inference_detect_output/figures")
    parser.add_argument("--save-full-volume", default=None)

    parser.add_argument("--alpha", type=float, default=0.45)
    parser.add_argument("--cmap",  default="jet")

    args    = parser.parse_args()
    out_dir = Path(args.output_dir)

    if args.extract_only:
        if not args.dicom_root:
            parser.error("--extract-only requires --dicom-root")
        root = Path(args.dicom_root)
        if not root.is_dir():
            parser.error(f"--dicom-root {root} is not a directory")
        out_dir.mkdir(parents=True, exist_ok=True)
        patients = ([root / args.patient] if args.patient
                    else [d for d in sorted(root.iterdir()) if d.is_dir()])
        for p in patients:
            print(f"\nProcessing {p.name} ...")
            save_full_volume(p, out_dir, args.target_spacing)
        return

    # Validate the complete argument combination before loading any data.
    if args.stem is None and args.patch is None:
        parser.error("Provide --stem or --patch (or --extract-only).")
    if args.context:
        if not args.metadata:
            parser.error("--context requires --metadata")
        if not (args.full_volume or args.dicom_dir):
            parser.error("--context requires either --full-volume or --dicom-dir")
    elif args.volume is None:
        parser.error("--volume is required in patch mode.")
    out_dir.mkdir(parents=True, exist_ok=True)

    stem, sal_files = resolve_saliency_files(args.stem, args.patch, args.dir)
    sals = {method: load_saliency(str(p)) for method, p in sal_files.items()}

    metrics: Dict[str, dict] = {}
    if args.report_json:
        metrics = load_metrics_from_report(args.report_json, stem)
        if metrics:
            print(f"Loaded metrics for {len(metrics)} methods from {args.report_json}")

    if args.context:
        centre_zyx = load_centre_from_metadata(args.metadata, stem)
        if centre_zyx is None:
            parser.exit(1, "Cannot proceed without centre_zyx.\n")

        # --full-volume takes precedence over --dicom-dir.
        if args.full_volume:
            full_vol = load_volume_norm(args.full_volume)
        else:
            print(f"Extracting full volume from {args.dicom_dir} ...")
            full_vol, _ = extract_full_volume(Path(args.dicom_dir), args.target_spacing)
            if full_vol is None:
                parser.exit(1, "Volume extraction failed.\n")
            full_vol = _normalise(full_vol)
            if args.save_full_volume:
                np.save(args.save_full_volume, full_vol)
                print(f"Full volume saved -> {args.save_full_volume}")

        print(f"Full volume: {full_vol.shape} | Centre ZYX: {centre_zyx} | "
              f"Patch: {args.patch_size}")

        plot_context_comparison(
            full_vol    = full_vol,
            sals        = sals,
            centre_zyx  = centre_zyx,
            patch_size  = args.patch_size,
            stem        = stem,
            out         = out_dir / f"{stem}_comparison_context.png",
            metrics     = metrics,
            report_json = args.report_json,
            alpha       = args.alpha,
            cmap        = args.cmap,
        )

    else:
        patch_vol = load_volume_norm(args.volume)
        plot_patch_comparison(
            vol         = patch_vol,
            sals        = sals,
            stem        = stem,
            out         = out_dir / f"{stem}_comparison.png",
            metrics     = metrics,
            report_json = args.report_json,
            alpha       = args.alpha,
            cmap        = args.cmap,
        )


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()