#!/bin/bash
#SBATCH --nodes=#NUMBER_OF_NODES
#SBATCH --account=#YOUR_CLUSTER_ACCOUNT
#SBATCH --constraint=#MACHINE
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --job-name=LIDC_detect
#SBATCH --output=logs/train_detect_%j.log
#SBATCH --time=04:00:00
#SBATCH --exclusive

# ===========================================================================
# Paths — relative to the job's working directory (by default sbatch runs
# the job in the directory it was submitted from).
# ===========================================================================
DATA_DIR=./dataset_3d                   # passed to the trainer as --data-dir
OUTPUT_DIR=./checkpoints_detect         # --output-dir; created before launch
TRAIN_SCRIPT=./2_FinetuneResnet3D.py    # python entry point for training
PRETRAINED=./resnet_18_23dataset.pth    # weights file passed as --pretrained

# ===========================================================================
# Environment
# ===========================================================================
# Abort on the first failing command (including failures inside pipelines),
# so a broken setup step cannot fall through to a long training run.
# `-u` is deliberately not enabled: some virtualenv activate scripts read
# variables that may be unset.
set -eo pipefail

# Check activation explicitly. Otherwise a missing venv leaves the system
# `python` on PATH, and training runs against whatever torch it happens to have.
if ! source your_venv/bin/activate; then
    printf 'ERROR: failed to activate virtualenv "%s"\n' "your_venv" >&2
    exit 1
fi

# NOTE: SLURM opens the `--output=logs/...` file before this script starts
# and does not create missing directories. `logs/` must therefore exist at
# submission time; creating it here only matters when the script is run
# outside SLURM.
mkdir -p logs "$OUTPUT_DIR"

# Report the torch build and visible GPUs using a single interpreter start-up
# (importing torch is slow). Training is launched with `--device cuda`, and no
# `--gres`/`--gpus` directive is visible in the job header, so warn on stderr
# when no GPU is visible. Under `set -e`, a failed `import torch` aborts here.
echo "===== ENV INFO ====="
python - <<'EOF'
import sys
import torch
print('Torch:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())
print('Device count:', torch.cuda.device_count())
if not torch.cuda.is_available():
    print('WARNING: CUDA is not available but training uses --device cuda',
          file=sys.stderr)
EOF
echo "===================="

# Keep OpenMP's thread count in sync with `--cpus-per-task`. Fall back to 8
# (the value requested in the job header) when run outside SLURM.
export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK:-8}"

# ===========================================================================
# Launch
# ===========================================================================
# Default hyper-parameters for the detection fine-tuning run, kept in an array
# so each argument stays a single word regardless of spaces in the paths.
#
# Any extra arguments given to this script are appended after the defaults,
# e.g. `sbatch <this-script> --epochs 60`. They can add flags or override a
# default. Overriding assumes the training script's parser keeps the last
# occurrence of a repeated option, as argparse does; confirm this for
# 2_FinetuneResnet3D.py. With no extra arguments, the command line is exactly
# the defaults below.
train_args=(
    --data-dir        "$DATA_DIR"
    --output-dir      "$OUTPUT_DIR"
    --pretrained      "$PRETRAINED"
    --epochs          40
    --phase1-epochs   5
    --batch-size      8
    --lr-head         3e-4
    --lr-backbone     5e-6
    --label-smoothing 0.15
    --dropout         0.4
    --weight-decay    1e-2
    --freeze-layer3
    --device          cuda
)

python "$TRAIN_SCRIPT" "${train_args[@]}" "$@"