#!/bin/bash
# Slurm batch script: runs ./3_InferAndEvaluate.py on the "test" split of the
# dataset configured below (job name: LIDC_detect_infer), with all XAI methods
# enabled and a report written to the output directory.
# Submit with: sbatch <this_script>
#
# The "#..." values on the --nodes / --account / --constraint directives
# appear to be placeholders that must be filled in before submitting.
#
# NOTE(review): python is launched directly (not through srun), so a single
# process runs in the batch step. --ntasks=8 and any --nodes value above 1 are
# therefore presumably unused by it. Confirm that the Python script handles
# multi-GPU itself via its --gpus option.
# NOTE(review): no GPU resource (e.g. --gres=gpu:8 or --gpus) is requested
# below, yet "--gpus 8" is passed to the script. On many clusters
# --constraint/--exclusive alone does not allocate GPUs. Verify for your site.
# NOTE(review): Slurm opens the --output file before the script body runs, so
# logs/ must already exist when the job starts. The later `mkdir -p logs` only
# helps subsequent submissions.
#SBATCH --nodes=#NUMBER_OF_NODES
#SBATCH --account=#YOUR_CLUSTER_ACCOUNT
#SBATCH --constraint=#MACHINE
#SBATCH --ntasks=8
#SBATCH --cpus-per-task=8
#SBATCH --job-name=LIDC_detect_infer
#SBATCH --output=logs/inference_detect_%j.log
#SBATCH --time=04:00:00
#SBATCH --exclusive

# ===========================================================================
# Shell options
# ===========================================================================
# Stop on the first failing command, on any use of an unset variable, and on
# a failure anywhere in a pipeline. Without this, a broken venv or a failed
# environment probe would not stop the job.
# This must stay below the #SBATCH lines, because sbatch stops reading
# directives at the first executable command.
set -euo pipefail

# ===========================================================================
# Paths
# ===========================================================================
# Relative paths resolve against the job's working directory. By default,
# sbatch runs the job in the directory it was submitted from.
readonly DATA_DIR="./dataset_3d"
readonly CHECKPOINT="./checkpoints_detect/best.pth"
readonly INFER_SCRIPT="./3_InferAndEvaluate.py"
readonly OUTPUT_DIR="./inference_detect_output"

# ===========================================================================
# Environment
# ===========================================================================
# Abort if the virtualenv cannot be activated. Otherwise the job would
# silently fall back to whatever `python` is first on PATH.
# "your_venv" is a placeholder for the real virtualenv path.
# shellcheck source=/dev/null
source your_venv/bin/activate || {
  echo "ERROR: could not activate virtualenv 'your_venv'" >&2
  exit 1
}

# NOTE(review): Slurm opens the --output log file (logs/...) before this line
# runs, so creating logs/ here only helps later submissions. Create it before
# calling sbatch.
mkdir -p logs "$OUTPUT_DIR"

# Probe the environment in a single interpreter run, rather than starting
# Python and importing torch three times. A failed probe (e.g. torch missing
# from the venv) aborts the job before the long inference run starts.
echo "===== ENV INFO ====="
python - <<'EOF' || { echo "ERROR: environment probe failed (is torch installed in the venv?)" >&2; exit 1; }
import sys
import torch
print('Python:', sys.executable)
print('Torch:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())
print('Device count:', torch.cuda.device_count())
EOF
echo "===================="

# ===========================================================================
# Launch
# ===========================================================================
# Check the inputs up front, so a typo in a path produces a clear message
# instead of a Python traceback after the allocation has started.
for required_file in "$INFER_SCRIPT" "$CHECKPOINT"; do
  if [[ ! -f "$required_file" ]]; then
    echo "ERROR: required file not found: $required_file" >&2
    exit 1
  fi
done
if [[ ! -d "$DATA_DIR" ]]; then
  echo "ERROR: data directory not found: $DATA_DIR" >&2
  exit 1
fi

# Arguments for the inference script, kept in an array so that paths
# containing spaces survive intact and options are easy to toggle.
# --ig-steps and --occlusion-* presumably configure the Integrated Gradients
# and occlusion attributions selected by "--xai all". Confirm in the script.
infer_args=(
  --input            "$DATA_DIR"
  --checkpoint       "$CHECKPOINT"
  --batch
  --split            test
  --batch-size       4
  --num-workers      8
  --gpus             8
  --xai              all
  --ig-steps         100
  --occlusion-window 8
  --occlusion-stride 4
  --report
  --output-dir       "$OUTPUT_DIR"
)

# Record the exact (shell-quoted) command line in the job log.
printf 'Launching:'
printf ' %q' python "$INFER_SCRIPT" "${infer_args[@]}"
printf '\n'

# Last command: the job's exit status is python's exit status.
python "$INFER_SCRIPT" "${infer_args[@]}"