# SPDX-License-Identifier: Apache-2.0
# Copyright (C) 2025 MINERVA European Support Centre contributors.
# See https://www.apache.org/licenses/LICENSE-2.0 for the full license text.

# ===========================================================================
# Imports and Initializations
# ===========================================================================
from __future__ import annotations
import csv
import gzip
import hashlib
import json
import logging
import pickle
import time
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
import pydicom
from PIL import Image
from scipy.ndimage import zoom
from sklearn.model_selection import train_test_split


# Configure the root logger at import time: INFO level, compact
# "HH:MM:SS  LEVEL     message" lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s  %(levelname)-8s  %(message)s",
    datefmt="%H:%M:%S",
)
# Module-level logger used by the functions and methods below.
log = logging.getLogger(__name__)


# ===========================================================================
# DICOM indexing
# ===========================================================================

def _uid_to_folder_name(uid: str) -> str:
    """Convert a DICOM UID to the folder-name form used on disk (dots -> underscores)."""
    return uid.replace(".", "_")


def index_dicom_tree(
    dicom_root: str | Path,
) -> Tuple[
    Dict[str, Path],
    Dict[Tuple[str, str], Path],
    Dict[str, str],
]:
    """Walk *dicom_root* and build three lookup tables.

    Every ``*.dcm`` path below *dicom_root* is visited.  When a file sits in a
    ``<patient>/<study>/<series>/`` layout whose patient folder name starts
    with ``LIDC-IDRI-`` and whose study and series folder names contain an
    underscore, the study/series UIDs are recovered from the folder names
    (underscores turned back into dots) and only ``SOPInstanceUID`` is read
    from the file.  Otherwise all identifiers are read from the DICOM header.
    Files that cannot be read are skipped and reported at DEBUG level.

    Returns
    -------
    sop_to_path
        SOPInstanceUID -> path of the file carrying it (a later file with the
        same UID replaces an earlier one).
    series_to_folder
        (study UID, series UID) -> series folder.  If the same key shows up in
        several folders, the folder holding strictly more ``*.dcm`` entries
        wins; on a tie the folder seen first is kept.
    study_to_patient
        study UID -> patient identifier (first non-empty pair seen).
    """
    dicom_root = Path(dicom_root)
    sop_to_path:      Dict[str, Path]              = {}
    series_to_folder: Dict[Tuple[str, str], Path]  = {}
    study_to_patient: Dict[str, str]               = {}

    dcm_files = list(dicom_root.rglob("*.dcm"))
    log.info("Indexing %d DICOM files under %s ...", len(dcm_files), dicom_root)

    # Count the *.dcm entries of every folder once, from the listing we already
    # hold.  Re-globbing the folders for each file made indexing quadratic in
    # the number of slices per series.
    files_per_dir: Dict[Path, int] = {}
    for dcm_path in dcm_files:
        parent = dcm_path.parent
        files_per_dir[parent] = files_per_dir.get(parent, 0) + 1

    fallback_tags = ["SOPInstanceUID", "StudyInstanceUID", "SeriesInstanceUID", "PatientID"]

    for dcm_path in dcm_files:
        series_dir  = dcm_path.parent
        study_dir   = series_dir.parent
        patient_dir = study_dir.parent

        series_name  = series_dir.name
        study_name   = study_dir.name
        patient_name = patient_dir.name

        path_looks_valid = (
            patient_name.startswith("LIDC-IDRI-")
            and "_" in series_name
            and "_" in study_name
        )

        if path_looks_valid:
            # Fast path: identifiers come from the folder names, so only the
            # SOP UID has to be read from the header.
            study_uid  = study_name.replace("_", ".")
            series_uid = series_name.replace("_", ".")
            pid        = patient_name
            try:
                ds  = pydicom.dcmread(dcm_path, stop_before_pixels=True,
                                      specific_tags=["SOPInstanceUID"])
                sop = str(ds.SOPInstanceUID).strip()
            except Exception as exc:
                log.debug("Cannot read SOP from %s: %s", dcm_path, exc)
                continue
        else:
            # Fallback: trust the header for every identifier.
            try:
                ds         = pydicom.dcmread(dcm_path, stop_before_pixels=True,
                                             specific_tags=fallback_tags)
                sop        = str(ds.SOPInstanceUID).strip()
                study_uid  = str(getattr(ds, "StudyInstanceUID",  "")).strip()
                series_uid = str(getattr(ds, "SeriesInstanceUID", "")).strip()
                pid        = str(getattr(ds, "PatientID",         "")).strip()
            except Exception as exc:
                log.debug("Skip %s: %s", dcm_path, exc)
                continue

        sop_to_path[sop] = dcm_path

        key = (study_uid, series_uid)
        existing = series_to_folder.get(key)
        if existing is None:
            series_to_folder[key] = series_dir
        elif (existing != series_dir
              and files_per_dir[series_dir] > files_per_dir.get(existing, 0)):
            # Same series found in another folder: keep the fuller one.
            series_to_folder[key] = series_dir

        if study_uid and pid and study_uid not in study_to_patient:
            study_to_patient[study_uid] = pid

    log.info(
        "Index built: %d SOPs | %d series | %d studies.",
        len(sop_to_path), len(series_to_folder), len(study_to_patient),
    )
    return sop_to_path, series_to_folder, study_to_patient


def resolve_dicom_folder(
    xml_study_uid:    str,
    xml_series_uid:   str,
    series_to_folder: Dict[Tuple[str, str], Path],
    sop_to_path:      Dict[str, Path],
    annotations:      list,
) -> Optional[Path]:
    """Locate the DICOM series folder that belongs to an XML annotation.

    Three strategies are tried in order:

    1. exact ``(study, series)`` key in *series_to_folder*;
    2. any folder of the same study -- the one with the most ``*.dcm`` files,
       the earliest entry winning a tie;
    3. the parent folder of the first ROI ``sop`` found in *sop_to_path*.

    Returns ``None`` when none of them matches.
    """
    exact = series_to_folder.get((xml_study_uid, xml_series_uid))
    if exact is not None:
        return exact

    same_study = [
        folder
        for (study, _series), folder in series_to_folder.items()
        if study == xml_study_uid
    ]
    if same_study:
        # max() returns the first maximal element, matching a stable
        # descending sort followed by taking element 0.
        return max(same_study, key=lambda folder: len(list(folder.glob("*.dcm"))))

    for annotation in annotations:
        for roi in annotation.get("rois", []):
            sop_uid = (roi.get("sop") or "").strip()
            if sop_uid and sop_uid in sop_to_path:
                return sop_to_path[sop_uid].parent

    return None


# ===========================================================================
# Malignancy helper functions
# ===========================================================================

def _collect_malignancy_scores(annotations: Dict[str, dict]) -> List[float]:
    """
    Gather all non-None malignancy scores from every SOP entry in annotations.
    Each entry may carry a score from one radiologist reader.
    """
    scores = []
    for ann in annotations.values():
        mal = ann.get("characteristics", {}).get("malignancy")
        if mal is not None:
            try:
                scores.append(float(mal))
            except (ValueError, TypeError):
                pass
    return scores


def malignancy_label_from_scores(scores: List[float]) -> Optional[int]:
    """
    Collapse per-reader malignancy scores (1–5) into a binary class.

    1     malignant  (mean above 3)
    0     benign     (mean below 3)
    None  no scores given, or mean equal to 3 (within 1e-6) => discarded
    """
    if not scores:
        return None

    average = float(np.mean(scores))
    if abs(average - 3.0) < 1e-6:
        # Sitting exactly on the decision boundary: treated as ambiguous.
        return None

    if average > 3.0:
        return 1
    return 0


# ===========================================================================
# Converter class
# ===========================================================================

class NotCTError(Exception):
    """Raised when an XML file is not a CT annotation (e.g. CXR read).

    A dedicated exception type, so callers can catch this case separately
    from other parsing errors.
    """


class LIDCConverter:
    """Convert LIDC-IDRI DICOM + XML annotations to 2D or 3D segmentation /
    classification training data."""

    _NS     = {"lidc": "http://www.nih.gov"}
    _CXR_NS = "http://www.nih.gov/idri"

    def __init__(
        self,
        output_dir: str | Path = "lung_nodule_dataset",
        cache_dir:  str | Path = ".dicom_cache",
        mode:       str        = "2d",
        task:       str        = "detection",   # "detection" | "malignancy"
    ):
        if mode not in ("2d", "3d"):
            raise ValueError(f"mode must be '2d' or '3d', got {mode!r}")
        if task not in ("detection", "malignancy"):
            raise ValueError(f"task must be 'detection' or 'malignancy', got {task!r}")
        if task == "malignancy" and mode != "3d":
            raise ValueError("task='malignancy' requires mode='3d'")

        self.output_dir = Path(output_dir)
        self.cache_dir  = Path(cache_dir)
        self.mode       = mode
        self.task       = task

        self.dicom_data:  Dict[str, pydicom.Dataset] = {}
        self.annotations: Dict[str, dict]            = {}
        self._consensus_threshold: int               = 1

        self._label_rows: List[dict] = []

        self._setup_directories()

    # ------------------------------------------------------------------
    # Directories
    # ------------------------------------------------------------------

    def _setup_directories(self) -> None:
        for split in ("train", "val", "test"):
            for kind in ("images", "masks"):
                (self.output_dir / kind / split).mkdir(parents=True, exist_ok=True)
        (self.output_dir / "metadata").mkdir(parents=True, exist_ok=True)
        log.info("Output directory ready: %s  (mode=%s, task=%s)",
                 self.output_dir, self.mode, self.task)

    # ------------------------------------------------------------------
    # Index cache
    # ------------------------------------------------------------------

    def _cache_path(self, dicom_root: str | Path) -> Path:
        h = hashlib.sha1(str(Path(dicom_root).resolve()).encode()).hexdigest()[:12]
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        return self.cache_dir / f"index_{h}.pkl.gz"

    def _save_index(self, dicom_root, sop_to_path, series_to_folder, study_to_patient):
        payload = {
            "meta": {"dicom_root": str(Path(dicom_root).resolve()), "created_at": time.time()},
            "sop_to_path": {k: str(v) for k, v in sop_to_path.items()},
            "series_to_folder": {f"{k[0]}||{k[1]}": str(v)
                                 for k, v in series_to_folder.items()},
            "study_to_patient": study_to_patient,
        }
        with gzip.open(self._cache_path(dicom_root), "wb") as fh:
            pickle.dump(payload, fh, protocol=pickle.HIGHEST_PROTOCOL)
        log.info("Saved index cache -> %s", self._cache_path(dicom_root))

    def _load_index(self, dicom_root):
        cp = self._cache_path(dicom_root)
        if not cp.exists():
            return None, None, None
        with gzip.open(cp, "rb") as fh:
            data = pickle.load(fh)
        if data["meta"]["dicom_root"] != str(Path(dicom_root).resolve()):
            log.warning("Cache root mismatch -- ignoring stale cache.")
            return None, None, None
        sop_to_path = {k: Path(v) for k, v in data["sop_to_path"].items()}
        series_to_folder = {}
        for key_str, folder_str in data["series_to_folder"].items():
            study, series = key_str.split("||", 1)
            series_to_folder[(study, series)] = Path(folder_str)
        return sop_to_path, series_to_folder, data.get("study_to_patient", {})

    def load_or_build_index(self, dicom_root, force=False):
        if not force:
            s, f, p = self._load_index(dicom_root)
            if s is not None:
                log.info("Loaded index cache: %d SOPs | %d series | %d studies.",
                         len(s), len(f), len(p))
                return s, f, p
        log.info("Building DICOM index (force=%s) ...", force)
        s, f, p = index_dicom_tree(dicom_root)
        self._save_index(dicom_root, s, f, p)
        return s, f, p

    # ------------------------------------------------------------------
    # DICOM loading
    # ------------------------------------------------------------------

    def load_dicom_series(self, folder: str | Path) -> None:
        """Load all *.dcm files from *folder*, sorted by InstanceNumber."""
        self.dicom_data.clear()
        files: List[Tuple[int, pydicom.Dataset]] = []
        for dcm_path in Path(folder).glob("*.dcm"):
            try:
                ds = pydicom.dcmread(dcm_path)
                try:
                    order = int(ds.InstanceNumber)
                except Exception:
                    stem = dcm_path.stem
                    order = int(stem) if stem.isdigit() else 0
                files.append((order, ds))
            except Exception as exc:
                log.warning("Cannot read %s: %s", dcm_path, exc)
        files.sort(key=lambda x: x[0])
        self.dicom_data = {str(ds.SOPInstanceUID).strip(): ds for _, ds in files}
        log.info("Loaded %d slices from %s.", len(self.dicom_data), folder)

    # ------------------------------------------------------------------
    # XML parsing
    # ------------------------------------------------------------------

    def parse_lidc_xml(self, xml_path: str | Path) -> Tuple[str, str, List[dict]]:
        ns = self._NS

        def _txt(elem, tag):
            t = elem.find(f"lidc:{tag}", ns)
            return t.text.strip() if (t is not None and t.text) else None

        root = ET.parse(xml_path).getroot()
        root_ns = root.tag.partition("}")[0].lstrip("{") if root.tag.startswith("{") else ""
        if root_ns == self._CXR_NS or "IdriReadMessage" in root.tag:
            raise NotCTError(f"CXR read file: {xml_path}")

        header     = root.find("lidc:ResponseHeader", ns)
        series_uid = (_txt(header, "SeriesInstanceUid") or "") if header is not None else ""
        study_uid  = (_txt(header, "StudyInstanceUID")  or "") if header is not None else ""

        if not study_uid:
            log.warning("No StudyInstanceUID in %s", xml_path)

        nodules: List[dict] = []
        for session in root.findall("lidc:readingSession", ns):
            for n_elem in session.findall("lidc:unblindedReadNodule", ns):
                nid = _txt(n_elem, "noduleID")
                c   = n_elem.find("lidc:characteristics", ns)
                chars = {}
                if c is not None:
                    for field in ("subtlety", "internalStructure", "calcification",
                                  "sphericity", "margin", "lobulation", "spiculation",
                                  "texture", "malignancy"):
                        chars[field] = _txt(c, field)
                rois = self._parse_rois(n_elem, ns)
                if rois:
                    nodules.append({"StudyInstanceUID": study_uid,
                                    "SeriesInstanceUid": series_uid,
                                    "noduleID": nid, "rois": rois,
                                    "characteristics": chars, "nodule_type": "nodule"})
            for sn in session.findall("lidc:blindedReadNodule", ns):
                rois = self._parse_rois(sn, ns)
                if rois:
                    nodules.append({"StudyInstanceUID": study_uid,
                                    "SeriesInstanceUid": series_uid,
                                    "noduleID": _txt(sn, "noduleID"), "rois": rois,
                                    "characteristics": {}, "nodule_type": "small_nodule"})

        log.info("  XML %s  study=...%s  %d nodule entries",
                 Path(xml_path).name, study_uid[-20:], len(nodules))
        return study_uid, series_uid, nodules

    @staticmethod
    def _parse_rois(nodule_elem, ns):
        rois = []
        for roi in nodule_elem.findall("lidc:roi", ns):
            def _t(tag):
                e = roi.find(f"lidc:{tag}", ns)
                return e.text.strip() if (e is not None and e.text) else None
            z_txt = _t("imageZposition")
            sop   = (_t("imageSOP_UID") or _t("imageSOP_UID ") or "").strip()
            inc   = (_t("inclusion") or "").upper() == "TRUE"
            pts   = []
            for em in roi.findall("lidc:edgeMap", ns):
                xe, ye = em.find("lidc:xCoord", ns), em.find("lidc:yCoord", ns)
                if xe is not None and ye is not None:
                    pts.append((int(xe.text), int(ye.text)))
            rois.append({"z": float(z_txt) if z_txt else None,
                         "sop": sop, "inclusion": inc, "points": pts})
        return rois

    # ------------------------------------------------------------------
    # Annotation index
    # ------------------------------------------------------------------

    def build_annotation_index(self, nodules: List[dict], consensus_threshold: int = 1) -> None:
        """
        Merge all readers' annotations keyed by SOP UID.

        consensus_threshold : min number of readers that must annotate a slice
            1 = any reader (union)
            2 = at least 2 readers agreed  (recommended for segmentation)
            3 = majority agreed
            4 = all 4 readers agreed
        """
        self.annotations.clear()
        self._consensus_threshold = consensus_threshold
        for nodule in nodules:
            chars = nodule.get("characteristics", {})
            for roi in nodule["rois"]:
                sop = roi["sop"]
                if not sop:
                    continue
                entry = self.annotations.setdefault(sop, {"nodules": [], "characteristics": {}})
                entry["nodules"].append({
                    "id": nodule["noduleID"], "points": roi["points"],
                    "z": roi["z"], "inclusion": roi["inclusion"],
                    "nodule_type": nodule["nodule_type"], "characteristics": chars,
                })
                for k, v in chars.items():
                    if v is not None and entry["characteristics"].get(k) is None:
                        entry["characteristics"][k] = v

    # ------------------------------------------------------------------
    # Preprocessing
    # ------------------------------------------------------------------

    @staticmethod
    def _first_scalar(v, default=None):
        if v is None:
            return default
        try:
            if hasattr(v, "__iter__") and not isinstance(v, (str, bytes)):
                return float(v[0])
            return float(v)
        except Exception:
            return default

    def preprocess(self, ds: pydicom.Dataset, method: str = "hu_window") -> np.ndarray:
        """Return a float32 slice in [0, 255] after HU conversion and windowing."""
        img = ds.pixel_array.astype(np.float32)
        slope     = float(getattr(ds, "RescaleSlope",     1.0) or 1.0)
        intercept = float(getattr(ds, "RescaleIntercept", 0.0) or 0.0)
        img = img * slope + intercept

        if method == "hu_window":
            lo, hi = -1000.0, 400.0
        elif method == "dicom_window":
            wc = self._first_scalar(getattr(ds, "WindowCenter", None))
            ww = self._first_scalar(getattr(ds, "WindowWidth",  None))
            lo, hi = (wc - ww / 2.0, wc + ww / 2.0) \
                     if (wc is not None and ww is not None and ww > 1e-6) \
                     else (img.min(), img.max())
        elif method == "zscore":
            sd = img.std() or 1.0
            img = (img - img.mean()) / sd
            lo, hi = img.min(), img.max()
        else:
            lo, hi = img.min(), img.max()

        rng = (hi - lo) or 1.0
        return np.clip((img - lo) / rng * 255.0, 0, 255).astype(np.float32)

    @staticmethod
    def _mask_from_points(points, shape) -> np.ndarray:
        mask = np.zeros(shape, dtype=np.uint8)
        if not points or len(points) < 3:
            return mask
        pts = np.array(points, dtype=np.int32).reshape(-1, 1, 2)
        cv2.fillPoly(mask, [pts], 255)
        return mask

    def _build_slice_mask(self, sop_id: str, h: int, w: int) -> np.ndarray:
        """Build a consensus mask for one slice using stored threshold."""
        ann          = self.annotations.get(sop_id, {"nodules": []})
        reader_count = np.zeros((h, w), dtype=np.int16)
        for nodule in ann["nodules"]:
            if nodule["inclusion"] and len(nodule["points"]) >= 3:
                nm = self._mask_from_points(nodule["points"], (h, w))
                reader_count += (nm > 0).astype(np.int16)
        return (reader_count >= self._consensus_threshold).astype(np.uint8) * 255

    # ------------------------------------------------------------------
    # 2D: per-slice extraction
    # ------------------------------------------------------------------

    def extract_slice(self, sop_id: str, normalize: str = "hu_window"):
        ds = self.dicom_data.get(sop_id)
        if ds is None:
            return None, None, None
        image = self.preprocess(ds, normalize)
        h, w  = image.shape
        mask  = self._build_slice_mask(sop_id, h, w)
        ann   = self.annotations.get(sop_id, {"nodules": [], "characteristics": {}})
        meta  = {
            "sop_id": sop_id, "has_nodules": bool(ann["nodules"]),
            "nodule_count": sum(1 for n in ann["nodules"] if n["inclusion"]),
            "characteristics": ann["characteristics"], "image_shape": [h, w],
            "pixel_spacing":   list(getattr(ds, "PixelSpacing",         None) or []),
            "slice_thickness": float(getattr(ds, "SliceThickness",      0) or 0),
            "image_position":  list(getattr(ds, "ImagePositionPatient", None) or []),
            "patient_id":      str(getattr(ds, "PatientID", "")),
            "study_uid":       str(getattr(ds, "StudyInstanceUID", "")),
        }
        return image, mask, meta

    # ------------------------------------------------------------------
    # 3D helpers: volume assemble + resample
    # ------------------------------------------------------------------

    def build_volume(
        self, normalize: str = "hu_window"
    ) -> Tuple[np.ndarray, np.ndarray, dict]:
        """
        Stack all loaded slices into an ordered (D, H, W) volume.

        Returns raw (non-resampled) float32 image volume, uint8 mask volume,
        and a metadata dict containing original spacing.
        """
        sop_order = list(self.dicom_data.keys())
        if not sop_order:
            raise ValueError("dicom_data is empty.")

        ds0 = next(iter(self.dicom_data.values()))
        h   = int(ds0.Rows)
        w   = int(ds0.Columns)

        slices_2d, masks_2d, positions = [], [], []
        for idx, sop_id in enumerate(sop_order):
            ds = self.dicom_data[sop_id]
            slices_2d.append(self.preprocess(ds, normalize))
            masks_2d.append(self._build_slice_mask(sop_id, h, w) // 255)  # 0/1

            pos = getattr(ds, "ImagePositionPatient", None)
            positions.append(float(pos[2]) if pos is not None else float(idx))

        image_vol = np.stack(slices_2d, axis=0).astype(np.float32)   # (D, H, W)
        mask_vol  = np.stack(masks_2d,  axis=0).astype(np.uint8)     # (D, H, W)

        px = list(getattr(ds0, "PixelSpacing", [1.0, 1.0]))
        st = getattr(ds0, "SliceThickness", None)
        if st:
            z_spacing = float(st)
        elif len(positions) > 1:
            gaps = [abs(positions[i+1] - positions[i]) for i in range(len(positions)-1)]
            z_spacing = float(np.median(gaps))
        else:
            z_spacing = 1.0

        meta = {
            "patient_id":   str(getattr(ds0, "PatientID", "")),
            "study_uid":    str(getattr(ds0, "StudyInstanceUID", "")),
            "volume_shape": list(image_vol.shape),
            "spacing_zyx":  [z_spacing, float(px[0]), float(px[1])],
            "n_slices":     len(sop_order),
        }
        return image_vol, mask_vol, meta

    @staticmethod
    def resample_volume(
        image: np.ndarray,
        mask:  np.ndarray,
        spacing_zyx: Tuple[float, float, float],
        target_spacing: float = 1.0,
    ) -> Tuple[np.ndarray, np.ndarray, Tuple[float, float, float]]:
        """
        Resample to isotropic *target_spacing* mm voxels.
        Image: trilinear (order=1).  Mask: nearest-neighbour (order=0).
        """
        zoom_factors = tuple(s / target_spacing for s in spacing_zyx)
        image_rs = zoom(image, zoom_factors, order=1, prefilter=False)
        mask_rs  = zoom(mask,  zoom_factors, order=0, prefilter=False)
        new_spacing = (target_spacing, target_spacing, target_spacing)
        return image_rs.astype(np.float32), mask_rs.astype(np.uint8), new_spacing

    @staticmethod
    def _nodule_centroids(mask_vol: np.ndarray) -> List[Tuple[int, int, int]]:
        """
        Return (z, y, x) centroid of each connected nodule region.
        Uses scipy connected-component labelling.
        """
        from scipy.ndimage import label, center_of_mass
        labelled, n = label(mask_vol > 0)
        return [
            tuple(int(round(c)) for c in center_of_mass(labelled == cid))
            for cid in range(1, n + 1)
        ]

    @staticmethod
    def _extract_patch(
        volume: np.ndarray,
        centre: Tuple[int, int, int],
        patch_size: int,
    ) -> np.ndarray:
        """
        Extract a (patch_size)^3 cube centred at *centre*.
        Zero-pads if the cube extends beyond volume boundaries.
        """
        D, H, W = volume.shape
        half    = patch_size // 2
        cz, cy, cx = centre

        z0, y0, x0 = cz - half, cy - half, cx - half
        z1, y1, x1 = z0 + patch_size, y0 + patch_size, x0 + patch_size

        vz0, vz1 = max(z0, 0), min(z1, D)
        vy0, vy1 = max(y0, 0), min(y1, H)
        vx0, vx1 = max(x0, 0), min(x1, W)

        patch = np.zeros((patch_size, patch_size, patch_size), dtype=volume.dtype)
        patch[vz0-z0:vz0-z0+(vz1-vz0),
              vy0-y0:vy0-y0+(vy1-vy0),
              vx0-x0:vx0-x0+(vx1-vx0)] = volume[vz0:vz1, vy0:vy1, vx0:vx1]
        return patch

    def _sample_negative_centres(
        self,
        mask_vol:   np.ndarray,
        n_samples:  int,
        patch_size: int,
        rng:        np.random.Generator,
    ) -> List[Tuple[int, int, int]]:
        """Sample random patch centres from regions with no nodule annotation."""
        D, H, W  = mask_vol.shape
        half     = patch_size // 2
        centres, max_tries = [], n_samples * 50

        for _ in range(max_tries):
            if len(centres) >= n_samples:
                break
            z = int(rng.integers(half, max(half + 1, D - half)))
            y = int(rng.integers(half, max(half + 1, H - half)))
            x = int(rng.integers(half, max(half + 1, W - half)))
            patch_mask = mask_vol[z-half:z-half+patch_size,
                                  y-half:y-half+patch_size,
                                  x-half:x-half+patch_size]
            if patch_mask.sum() == 0:
                centres.append((z, y, x))

        if len(centres) < n_samples:
            log.debug("Sampled %d / %d negative patches.", len(centres), n_samples)
        return centres

    # ------------------------------------------------------------------
    # 2D: save patient slices
    # ------------------------------------------------------------------

    def save_patient_2d(
        self,
        patient_id:      str,
        split:           str,
        fmt:             str   = "png",
        normalize:       str   = "hu_window",
        min_mask_area:   int   = 10,
        negative_ratio:  Optional[float] = None,
        require_nodules: bool  = True,
        dry_run:         bool  = False,
    ) -> Optional[dict]:
        """Extract every slice and save as 2D image/mask files into *split* folder."""
        all_slices = []
        for sop_id in self.dicom_data:
            img, msk, meta = self.extract_slice(sop_id, normalize)
            if img is None:
                continue
            area = int(np.sum(msk > 0))
            if 0 < area < min_mask_area:
                continue
            all_slices.append({"sop_id": sop_id, "image": img, "mask": msk,
                                "metadata": meta, "is_positive": area > 0})

        positives = [s for s in all_slices if     s["is_positive"]]
        negatives = [s for s in all_slices if not s["is_positive"]]

        if not positives and require_nodules:
            log.info("  %s  ->  0 pos slices. Skipping.", patient_id)
            return None

        if negative_ratio is not None:
            cap = int(len(positives) * negative_ratio)
            if cap < len(negatives):
                idx = np.random.default_rng(42).choice(len(negatives), cap, replace=False)
                negatives = [negatives[i] for i in sorted(idx)]

        all_slices = positives + negatives
        log.info("  %s  ->  %d pos + %d neg = %d slices",
                 patient_id, len(positives), len(negatives), len(all_slices))

        if dry_run:
            log.info("  [dry-run] No files written.")
            return {}

        split_meta = {}
        for idx, s in enumerate(all_slices):
            stem = f"{patient_id}_s{idx:04d}_{s['sop_id'][:8]}"
            img8 = np.clip(s["image"], 0, 255).astype(np.uint8)
            msk8 = (s["mask"] > 0).astype(np.uint8) * 255
            if fmt in ("png", "both"):
                Image.fromarray(img8, "L").save(
                    self.output_dir / "images" / split / f"{stem}.png")
                Image.fromarray(msk8, "L").save(
                    self.output_dir / "masks"  / split / f"{stem}.png")
            if fmt in ("npy", "both"):
                np.save(self.output_dir / "images" / split / f"{stem}.npy", s["image"])
                np.save(self.output_dir / "masks"  / split / f"{stem}.npy", s["mask"])
            split_meta[stem] = s["metadata"]

        log.info("    %s  %d slices saved", split, len(all_slices))
        return split_meta

    # ------------------------------------------------------------------
    # 3D: save patient patches
    # ------------------------------------------------------------------

    def save_patient_3d(
        self,
        patient_id:      str,
        split:           str,
        normalize:       str   = "hu_window",
        patch_size:      int   = 64,
        target_spacing:  float = 1.0,
        negative_ratio:  float = 1.0,
        require_nodules: bool  = True,
        dry_run:         bool  = False,
    ) -> Optional[dict]:
        """
        Build the full 3D volume, resample to isotropic spacing, extract
        nodule-centred positive patches and random negative patches.

        Output folder structure:

        images/<split>/<patient_id>_pos_NNNN.npy  -- float32 (D, H, W) patch
        masks/<split>/<patient_id>_pos_NNNN.npy   -- uint8  (D, H, W) patch
        (same for neg patches)

        Returns a metadata dict keyed by patch stem, or None if no nodules
        and require_nodules=True.
        """
        image_vol, mask_vol, vol_meta = self.build_volume(normalize)

        # Isotropic target spacing
        spacing_zyx = tuple(vol_meta["spacing_zyx"])
        image_rs, mask_rs, _ = self.resample_volume(
            image_vol, mask_vol, spacing_zyx, target_spacing)
        log.info("  %s  volume %s -> resampled %s  (%.2f mm isotropic)",
                 patient_id, image_vol.shape, image_rs.shape, target_spacing)

        # Find nodule centroids in resampled space
        centroids = self._nodule_centroids(mask_rs)
        if not centroids and require_nodules:
            log.info("  %s  ->  0 nodule centroids after resampling. Skipping.", patient_id)
            return None

        # ------------------------------------------------------------------
        # Collect malignancy scores once per patient.
        # ------------------------------------------------------------------
        mal_scores  = _collect_malignancy_scores(self.annotations)
        mal_mean    = float(np.mean(mal_scores)) if mal_scores else None
        mal_label   = malignancy_label_from_scores(mal_scores)

        if self.task == "malignancy":
            if mal_label is None:
                log.info("  %s  ->  malignancy ambiguous or missing (mean=%.2f). "
                         "Skipping positive patches.",
                         patient_id, mal_mean if mal_mean is not None else -1)
                return None
            log.info("  %s  ->  malignancy_label=%d  (mean=%.2f  scores=%s)",
                     patient_id, mal_label, mal_mean, mal_scores)

        n_pos = len(centroids)
        n_neg = max(1, int(n_pos * negative_ratio)) if n_pos > 0 else 0
        log.info("  %s  ->  %d pos + %d neg patches", patient_id, n_pos, n_neg)

        if dry_run:
            log.info("  [dry-run] No files written.")
            return {}

        patch_meta: dict = {}
        common = {
            "patient_id":     patient_id,
            "patch_size":     patch_size,
            "target_spacing": target_spacing,
            "spacing_zyx":    vol_meta["spacing_zyx"],
            "volume_shape":   vol_meta["volume_shape"],
            "split":          split,
            # Malignancy fields (None for detection task)
            "malignancy_scores": mal_scores  if self.task == "malignancy" else None,
            "malignancy_mean":   mal_mean    if self.task == "malignancy" else None,
            "malignancy_label":  mal_label   if self.task == "malignancy" else None,
        }

        # Positive patches
        for n, centre in enumerate(centroids):
            stem  = f"{patient_id}_pos_{n:04d}"
            img_p = self._extract_patch(image_rs, centre, patch_size)
            msk_p = self._extract_patch(mask_rs,  centre, patch_size)
            np.save(self.output_dir / "images" / split / f"{stem}.npy", img_p)
            np.save(self.output_dir / "masks"  / split / f"{stem}.npy", msk_p)
            patch_meta[stem] = {**common, "detection_label": "positive",
                                 "centre_zyx": list(centre)}

            if self.task == "malignancy" and mal_label is not None:
                self._label_rows.append({
                    "stem":            stem,
                    "label":           mal_label,
                    "malignancy_mean": round(mal_mean, 4),
                    "split":           split,
                    "patient_id":      patient_id,
                })

        # Negative patches
        if n_neg > 0:
            rng         = np.random.default_rng(42)
            neg_centres = self._sample_negative_centres(mask_rs, n_neg, patch_size, rng)
            for n, centre in enumerate(neg_centres):
                stem  = f"{patient_id}_neg_{n:04d}"
                img_p = self._extract_patch(image_rs, centre, patch_size)
                msk_p = self._extract_patch(mask_rs,  centre, patch_size)
                np.save(self.output_dir / "images" / split / f"{stem}.npy", img_p)
                np.save(self.output_dir / "masks"  / split / f"{stem}.npy", msk_p)
                patch_meta[stem] = {**common, "detection_label": "negative",
                                     "centre_zyx": list(centre)}
                # Negatives are not added to labels.csv since they have no malignancy label
            log.info("    %d neg patches saved", len(neg_centres))

        return patch_meta

    # ------------------------------------------------------------------
    # Unified dispatcher
    # ------------------------------------------------------------------

    def save_patient(
        self,
        patient_id:      str,
        split:           str,
        fmt:             str   = "png",
        normalize:       str   = "hu_window",
        min_mask_area:   int   = 10,
        negative_ratio:  Optional[float] = None,
        require_nodules: bool  = True,
        dry_run:         bool  = False,
        patch_size:      int   = 64,
        target_spacing:  float = 1.0,
    ) -> Optional[dict]:
        """
        Route one patient to the writer that matches ``self.mode``.

        When ``self.mode`` is ``"3d"`` the call is forwarded to
        ``save_patient_3d``; every other mode is forwarded to
        ``save_patient_2d``.  Each writer only receives the arguments it
        accepts: ``fmt`` and ``min_mask_area`` go to the 2D writer,
        ``patch_size`` and ``target_spacing`` go to the 3D writer.

        Returns whatever the selected writer returns.
        """
        # Keyword arguments that both writers take unchanged.
        shared = {
            "patient_id":      patient_id,
            "split":           split,
            "normalize":       normalize,
            "require_nodules": require_nodules,
            "dry_run":         dry_run,
        }

        if self.mode != "3d":
            # The 2D writer receives negative_ratio as given (None included).
            return self.save_patient_2d(
                fmt=fmt,
                min_mask_area=min_mask_area,
                negative_ratio=negative_ratio,
                **shared,
            )

        # The 3D writer always gets a concrete ratio: unset means 1.0.
        ratio_3d = 1.0 if negative_ratio is None else negative_ratio
        return self.save_patient_3d(
            patch_size=patch_size,
            target_spacing=target_spacing,
            negative_ratio=ratio_3d,
            **shared,
        )

    # ------------------------------------------------------------------
    # Patient-level split assignment
    # ------------------------------------------------------------------

    @staticmethod
    def assign_split(
        patient_ids: List[str],
        ratios: Tuple[float, float, float] = (0.7, 0.2, 0.1),
        seed: int = 42,
    ) -> Dict[str, str]:
        """
        Assign each patient to train/val/test BEFORE any data is written.

        3D mode MUST use patient-level splits (all patches from one patient
        go to the same split) to prevent data leakage.
        """
        train_r, val_r, test_r = ratios
        tt_r = val_r + test_r
        if len(patient_ids) <= 1 or tt_r <= 0:
            return {pid: "train" for pid in patient_ids}

        train_ids, temp_ids = train_test_split(
            patient_ids, test_size=tt_r, random_state=seed)
        if not temp_ids:
            return {pid: "train" for pid in patient_ids}

        val_share = val_r / tt_r
        if len(temp_ids) == 1:
            val_ids, test_ids = temp_ids, []
        else:
            val_ids, test_ids = train_test_split(
                temp_ids, test_size=(1 - val_share), random_state=seed)

        assignment: Dict[str, str] = {}
        for pid in train_ids: assignment[pid] = "train"
        for pid in val_ids:   assignment[pid] = "val"
        for pid in test_ids:  assignment[pid] = "test"
        return assignment

    # ------------------------------------------------------------------
    # Dataset summary + labels csv
    # ------------------------------------------------------------------

    def write_dataset_info(self) -> dict:
        """
        Summarise the written dataset and persist the summary.

        Counts the samples found under ``<output_dir>/images/<split>`` for
        the train, val and test splits.  For the malignancy task, when label
        rows were collected, it also writes ``labels.csv`` and adds the
        benign/malignant counts to the summary.  The summary is then saved as
        ``<output_dir>/dataset_info.json`` and returned.

        The JSON file is written last so that it contains the same keys as
        the returned dict, including ``malignancy_labels`` when present.
        """
        info: dict = {"mode": self.mode, "task": self.task, "splits": {},
                      "total_samples": 0}
        for split in ("train", "val", "test"):
            split_dir = self.output_dir / "images" / split
            if not split_dir.exists():
                continue
            # Count unique stems rather than files, so a sample present as
            # both .npy and .png is counted once.
            stems = {p.stem for p in split_dir.iterdir()
                     if p.suffix in (".npy", ".png")}
            info["splits"][split] = len(stems)
            info["total_samples"] += len(stems)

        # Write labels.csv for malignancy task
        if self.task == "malignancy" and self._label_rows:
            csv_path = self.output_dir / "labels.csv"
            fieldnames = ["stem", "label", "malignancy_mean", "split", "patient_id"]
            with open(csv_path, "w", newline="", encoding="utf-8") as fh:
                writer = csv.DictWriter(fh, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(self._label_rows)
            n_benign    = sum(1 for r in self._label_rows if r["label"] == 0)
            n_malignant = sum(1 for r in self._label_rows if r["label"] == 1)
            log.info("labels.csv written: %d benign / %d malignant -> %s",
                     n_benign, n_malignant, csv_path)
            info["malignancy_labels"] = {"benign": n_benign, "malignant": n_malignant}

        with open(self.output_dir / "dataset_info.json", "w", encoding="utf-8") as fh:
            json.dump(info, fh, indent=2)
        log.info("Dataset summary -> %s", info)

        return info


# ===========================================================================
# Full pipeline
# ===========================================================================

def run(
    xml_root:            str   = "./justinkirby/the-cancer-imaging-archive-lidcidri/versions/1/LIDC-XML-only/tcia-lidc-xml",
    dicom_root:          str   = "./data",
    output_dir:          str   = "lung_nodule_dataset",
    cache_dir:           str   = ".dicom_cache",
    mode:                str   = "2d",
    task:                str   = "detection",
    fmt:                 str   = "both",
    normalize:           str   = "hu_window",
    force_reindex:       bool  = False,
    skip_done:           bool  = True,
    dry_run:             bool  = False,
    negative_ratio:      Optional[float] = None,
    require_nodules:     bool  = True,
    consensus_threshold: int   = 1,
    patch_size:          int   = 64,
    target_spacing:      float = 1.0,
    split_seed:          int   = 42,
) -> None:
    """
    Process every XML and produce training data.

    2D / detection : per-slice files; split is per-slice.
    3D / detection : nodule-centred patches; patient-level split.
    3D / malignancy: nodule-centred patches labelled benign/malignant;
                     ambiguous nodules discarded; labels.csv written.

    Arguments
    ---------
    xml_root            : directory searched recursively for ``*.xml``.
    dicom_root          : root of the DICOM tree; handed to the index loader
                          and used for the direct
                          ``<root>/<patient>/<study>/<series>`` folder lookup.
    output_dir, cache_dir, mode, task
                        : forwarded to ``LIDCConverter``.
    fmt, normalize, negative_ratio, require_nodules, patch_size, target_spacing
                        : forwarded to ``LIDCConverter.save_patient``.
    force_reindex       : forwarded as ``force`` to ``load_or_build_index``.
    skip_done           : skip a patient whose ``metadata/<patient_id>.json``
                          already exists (not applied in dry-run).
    dry_run             : resolve everything but write no metadata or dataset
                          summary; the 3D split pre-scan is skipped as well.
    consensus_threshold : forwarded to ``build_annotation_index``.
    split_seed          : seed for the patient-level split in 3D mode.

    Per-file errors are logged and counted under "failed"; they do not abort
    the run.  A summary of all counters is logged at the end.
    """
    converter = LIDCConverter(output_dir=output_dir, cache_dir=cache_dir,
                               mode=mode, task=task)
    sop_to_path, series_to_folder, study_to_patient = converter.load_or_build_index(
        dicom_root, force=force_reindex)

    xml_files = sorted(Path(xml_root).rglob("*.xml"))
    log.info("Found %d XML files under %s.", len(xml_files), xml_root)

    patient_split: Dict[str, str] = {}
    if mode == "3d" and not dry_run:
        log.info("3D mode: pre-scanning to assign patient-level splits ...")
        # A dict keeps first-seen order (the seeded split depends on it) and
        # de-duplicates in O(1).
        candidate_ids: Dict[str, None] = {}
        for xml_path in xml_files:
            try:
                root_elem = ET.parse(xml_path).getroot()
                root_ns = root_elem.tag.partition("}")[0].lstrip("{") \
                          if root_elem.tag.startswith("{") else ""
                if root_ns == LIDCConverter._CXR_NS or "IdriReadMessage" in root_elem.tag:
                    continue
                header = root_elem.find("lidc:ResponseHeader", LIDCConverter._NS)
                if header is None:
                    continue
                t = header.find("lidc:StudyInstanceUID", LIDCConverter._NS)
                if t is None or not t.text:
                    continue
                pid = study_to_patient.get(t.text.strip())
                if pid:
                    # Normalise exactly like the main loop below, otherwise
                    # the later patient_split lookup misses IDs with spaces
                    # and silently falls back to "train".
                    candidate_ids.setdefault(pid.replace(" ", "_"), None)
            except Exception as exc:
                # Best effort: the main loop reports unreadable files in
                # full, so only leave a trace here instead of hiding it.
                log.debug("  Pre-scan could not read %s: %s", xml_path, exc)
        patient_split = converter.assign_split(list(candidate_ids), seed=split_seed)
        log.info("Patient split: %d total -> train/val/test assigned.", len(candidate_ids))

    stats = {
        "processed": 0, "skipped_done": 0, "skipped_cxr": 0,
        "skipped_no_dicom": 0, "skipped_no_study_uid": 0,
        "skipped_no_nodules": 0, "skipped_ambiguous": 0,
        "duplicate_patients": 0, "failed": 0, "patient_id_fallbacks": 0,
    }
    seen_patient_ids: Dict[str, Path] = {}
    all_meta: Dict[str, dict] = {}

    for i, xml_path in enumerate(xml_files, 1):
        log.info("[%d/%d] %s", i, len(xml_files), xml_path)

        try:
            study_uid, series_uid, nodules = converter.parse_lidc_xml(xml_path)

            if not study_uid:
                log.error("  No StudyInstanceUID -- skipping.")
                stats["skipped_no_study_uid"] += 1
                continue

            patient_id = study_to_patient.get(study_uid)
            if patient_id:
                patient_id = patient_id.replace(" ", "_")
                log.info("  PatientID resolved: %s", patient_id)
            else:
                patient_id = f"unknown_{xml_path.parent.name}_{xml_path.stem}"
                log.warning("  StudyUID ...%s not in index. Fallback: %s",
                            study_uid[-25:], patient_id)
                stats["patient_id_fallbacks"] += 1

            # Only the first XML seen for a patient is processed.
            if patient_id in seen_patient_ids:
                log.warning("  Duplicate %s (first from %s). Skipping.",
                            patient_id, seen_patient_ids[patient_id].name)
                stats["duplicate_patients"] += 1
                continue
            seen_patient_ids[patient_id] = xml_path

            # The per-patient metadata JSON doubles as the "done" marker.
            if skip_done and not dry_run:
                done = converter.output_dir / "metadata" / f"{patient_id}.json"
                if done.exists():
                    log.info("  Already done -- skipping.")
                    stats["skipped_done"] += 1
                    continue

            # Resolve DICOM folder
            if not patient_id.startswith("unknown_"):
                direct = (Path(dicom_root) / patient_id
                          / _uid_to_folder_name(study_uid)
                          / _uid_to_folder_name(series_uid))
                if direct.is_dir() and any(direct.glob("*.dcm")):
                    folder = direct
                else:
                    study_dir = (Path(dicom_root) / patient_id
                                 / _uid_to_folder_name(study_uid))
                    if study_dir.is_dir():
                        # Fall back to the series folder holding the most
                        # .dcm files (first one wins on ties).
                        cands = [d for d in study_dir.iterdir()
                                 if d.is_dir() and any(d.glob("*.dcm"))]
                        folder = max(
                            cands,
                            key=lambda d: len(list(d.glob("*.dcm"))),
                            default=None)
                    else:
                        folder = resolve_dicom_folder(
                            study_uid, series_uid, series_to_folder,
                            sop_to_path, nodules)
            else:
                folder = resolve_dicom_folder(
                    study_uid, series_uid, series_to_folder, sop_to_path, nodules)

            if folder is None:
                log.warning("  No DICOM folder -- skipping.")
                stats["skipped_no_dicom"] += 1
                continue
            log.info("  DICOM folder: %s", folder)

            converter.load_dicom_series(folder)
            converter.build_annotation_index(
                nodules, consensus_threshold=consensus_threshold)

            # Sanity check: how many annotated SOP UIDs exist in the loaded series.
            xml_sops   = {roi["sop"] for n in nodules for roi in n["rois"] if roi["sop"]}
            dicom_sops = set(converter.dicom_data.keys())
            overlap    = xml_sops & dicom_sops
            if xml_sops:
                pct = 100 * len(overlap) / len(xml_sops)
                log.info("  SOP overlap: %d / %d (%.0f%%)",
                         len(overlap), len(xml_sops), pct)
                if not overlap:
                    log.warning("  ZERO SOP overlap -- masks will be empty!")

            # Patients missing from the pre-scan (e.g. "unknown_" fallbacks)
            # default to the training split.
            split = patient_split.get(patient_id, "train") if mode == "3d" else "train"

            result = converter.save_patient(
                patient_id      = patient_id,
                split           = split,
                fmt             = fmt,
                normalize       = normalize,
                min_mask_area   = 10,
                negative_ratio  = negative_ratio,
                require_nodules = require_nodules,
                dry_run         = dry_run,
                patch_size      = patch_size,
                target_spacing  = target_spacing,
            )

            if result is None:
                if task == "malignancy":
                    stats["skipped_ambiguous"] += 1
                else:
                    stats["skipped_no_nodules"] += 1
            else:
                stats["processed"] += 1
                if result and not dry_run:
                    all_meta[patient_id] = result

        except NotCTError:
            log.debug("  Skipping CXR file: %s", xml_path.name)
            stats["skipped_cxr"] += 1

        except Exception as exc:
            log.exception("  Unhandled error: %s", exc)
            stats["failed"] += 1

        finally:
            # Drop per-patient state so nothing leaks into the next file.
            converter.dicom_data.clear()
            converter.annotations.clear()

    # Persist metadata and summary
    if not dry_run:
        for pid, meta in all_meta.items():
            with open(converter.output_dir / "metadata" / f"{pid}.json", "w") as fh:
                json.dump(meta, fh, indent=2, default=str)
        converter.write_dataset_info()

    log.info(
        "\n=== Run complete (mode=%s  task=%s) ===\n"
        "  Processed             : %d\n"
        "  Skipped (done)        : %d\n"
        "  Skipped (CXR files)   : %d\n"
        "  Skipped (no DICOM)    : %d\n"
        "  Skipped (no StudyUID) : %d\n"
        "  Skipped (no nodules)  : %d\n"
        "  Skipped (ambiguous)   : %d\n"
        "  Duplicate patients    : %d\n"
        "  PatientID fallbacks   : %d\n"
        "  Failed                : %d",
        mode, task,
        stats["processed"], stats["skipped_done"], stats["skipped_cxr"],
        stats["skipped_no_dicom"], stats["skipped_no_study_uid"],
        stats["skipped_no_nodules"], stats["skipped_ambiguous"],
        stats["duplicate_patients"], stats["patient_id_fallbacks"],
        stats["failed"],
    )


# ===========================================================================
# main function
# ===========================================================================

if __name__ == "__main__":
    import argparse

    # Command-line front end: each option maps onto one keyword argument of
    # run(); --no-skip and --keep-all-patients are passed on inverted.
    parser = argparse.ArgumentParser(
        description="LIDC-IDRI DICOM + XML -> 2D slices or 3D patches for "
                    "segmentation or malignancy classification",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--xml-root", default=(
        "./justinkirby/the-cancer-imaging-archive-lidcidri"
        "/versions/1/LIDC-XML-only/tcia-lidc-xml"))
    parser.add_argument("--dicom-root",  default="./data")
    parser.add_argument("--output-dir",  default="lung_nodule_dataset")
    parser.add_argument("--cache-dir",   default=".dicom_cache")

    parser.add_argument("--mode", choices=["2d", "3d"], default="2d",
        help=(
            "2d: one image/mask file per CT slice (U-Net / 2D ResNet). "
            "3d: nodule-centred cubic patches after isotropic resampling "
            "(3D ResNet / nnU-Net 3D)."))

    parser.add_argument("--task", choices=["detection", "malignancy"],
        default="detection",
        help=(
            "detection: binary nodule/background patches (default). "
            "malignancy: positive patches labelled benign(0)/malignant(1) "
            "from radiologist consensus; ambiguous nodules discarded; "
            "labels.csv written. Requires --mode 3d."))

    # 2D-specific
    parser.add_argument("--fmt", choices=["png", "npy", "both"], default="both",
        help="[2D only] Output format.")

    # 3D-specific
    parser.add_argument("--patch-size", type=int, default=64,
        help="[3D only] Cubic patch side in voxels post-resampling.")
    parser.add_argument("--target-spacing", type=float, default=1.0,
        help="[3D only] Isotropic target voxel spacing in mm.")
    parser.add_argument("--split-seed", type=int, default=42,
        help="[3D only] Seed for patient-level train/val/test split.")

    # Shared
    parser.add_argument("--normalize",
        choices=["hu_window", "dicom_window", "minmax", "zscore"], default="hu_window")
    parser.add_argument("--force-reindex",   action="store_true")
    parser.add_argument("--no-skip",         action="store_true",
        help="Reprocess patients whose output already exists.")
    parser.add_argument("--dry-run",         action="store_true",
        help="Parse and resolve everything, report stats, write nothing.")
    parser.add_argument("--negative-ratio",  type=float, default=None,
        help="2D: cap neg at N x pos (None=keep all). 3D: neg patches per pos patch.")
    parser.add_argument("--consensus-threshold", type=int, default=1, choices=[1,2,3,4],
        help="Min readers that must agree for a slice to be positive.")
    parser.add_argument("--keep-all-patients", action="store_true",
        help="Include patients with zero positive slices/patches.")

    args = parser.parse_args()

    # The --task help documents malignancy as 3D-only; reject the unsupported
    # combination up front instead of starting a run that cannot honour it.
    if args.task == "malignancy" and args.mode != "3d":
        parser.error("--task malignancy requires --mode 3d")

    run(
        xml_root            = args.xml_root,
        dicom_root          = args.dicom_root,
        output_dir          = args.output_dir,
        cache_dir           = args.cache_dir,
        mode                = args.mode,
        task                = args.task,
        fmt                 = args.fmt,
        normalize           = args.normalize,
        force_reindex       = args.force_reindex,
        skip_done           = not args.no_skip,
        dry_run             = args.dry_run,
        negative_ratio      = args.negative_ratio,
        require_nodules     = not args.keep_all_patients,
        consensus_threshold = args.consensus_threshold,
        patch_size          = args.patch_size,
        target_spacing      = args.target_spacing,
        split_seed          = args.split_seed,
    )