#!/usr/bin/env python3
# -*- coding: utf-8 -*-

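# Example usage (the threshold and the minimum cluster size are optional):
#  python cluster_label.py ./c.nii 0.5
#  python cluster_label.py ./c.nii 0.5 10
# Based on
#  https://github.com/nilearn/nilearn/blob/main/nilearn/reporting/_get_clusters_table.py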

import nibabel as nib
from scipy import ndimage
import numpy as np
import scipy.stats as st
import os
import sys
import math
from scipy.ndimage import (
    maximum_filter,
    minimum_filter,
    label,
    center_of_mass,
    generate_binary_structure,
)

def _drop_small_clusters(label_map, cluster_threshold):
    """Zero out clusters smaller than ``cluster_threshold`` and renumber the rest.

    Surviving clusters keep their relative order and are renumbered
    consecutively from 1; removed clusters become 0 (background).

    Returns
    -------
    (new_label_map, n_kept) : tuple
        The relabelled array (same shape as ``label_map``) and the number of
        clusters that survived.
    """
    # A single pass over the volume yields the voxel count of every label,
    # instead of rescanning the whole volume once per cluster.
    sizes = np.bincount(label_map.ravel(), minlength=1)
    keep = sizes >= cluster_threshold
    keep[0] = False  # label 0 is background, never a cluster
    n_kept = int(np.count_nonzero(keep))
    # Lookup table: old label -> new consecutive label (0 for dropped labels).
    lut = np.zeros(sizes.shape, dtype=label_map.dtype)
    lut[keep] = np.arange(1, n_kept + 1)
    return lut[label_map], n_kept


def process_nifti(fnm, stat_threshold=0.5, cluster_threshold=0):
    """Label the supra-threshold clusters of a NIfTI image and save the label map.

    The result is written next to the input file, with the letter ``l``
    prefixed to the input file name.

    Parameters
    ----------
    fnm : str
        NIfTI image to convert.
    stat_threshold : float, optional
        Voxels above this value are clustered. If the threshold is negative,
        voxels *below* it are clustered instead. Default=0.5.
    cluster_threshold : int, optional
        Minimum cluster size in voxels. If 0, no cluster size threshold is
        applied. Default=0.

    Raises
    ------
    SystemExit
        With status 1 (after printing an error message) when no cluster
        survives the thresholds.
    """
    src = nib.load(fnm)
    data = src.get_fdata()
    # NOTE(review): the header dtype is set to float32 although the label
    # array below is int32 -- presumably the header governs the on-disk
    # type; confirm against nibabel before changing.
    src.header.set_data_dtype(np.float32)
    # connectivity=1 in 3D: voxels are neighbours only when they share a face.
    bin_struct = generate_binary_structure(rank=3, connectivity=1)
    if stat_threshold < 0:
        binarized = data < stat_threshold
    else:
        binarized = data > stat_threshold
    # label() already reports the cluster count; no need to rescan for the max.
    label_map, n_label = label(binarized, bin_struct)
    if cluster_threshold > 0:
        label_map, n_label = _drop_small_clusters(label_map, cluster_threshold)
    if n_label < 1:
        # Negative thresholds select values below the threshold, so word the
        # message accordingly.
        direction = 'lower' if stat_threshold < 0 else 'higher'
        print(
            'Error: No clusters with intensity {0} than {1}'.format(
                direction,
                stat_threshold,
            )
        )
        # Non-zero status so calling scripts can detect the failure.
        sys.exit(1)
    labels = label_map.astype(np.int32)

    out_img = nib.Nifti1Image(labels, src.affine, src.header)
    # Clear the intent code and reset scaling / display range, since the
    # volume now holds cluster indices rather than the original values.
    out_img.header['intent_code'] = 0
    out_img.header['scl_slope'] = 1.0
    out_img.header['scl_inter'] = 0.0
    out_img.header['cal_max'] = 0.0
    out_img.header['cal_min'] = 0.0
    pth, nm = os.path.split(fnm)
    # An empty directory part means the current directory.
    outnm = os.path.join(pth or '.', 'l' + nm)
    nib.save(out_img, outnm)

if __name__ == '__main__':
    # Command-line entry point: label the clusters of a NIfTI image.
    #
    # Positional arguments
    # --------------------
    # fnm : str
    #     NIfTI image to convert.
    # stat_threshold : float, optional
    #     Cluster more extreme values. Default=0.5.
    # cluster_threshold : int, optional
    #     Cluster size threshold, in voxels.
    #     If 0, then no cluster size threshold will be applied. Default=0.
    if len(sys.argv) < 2:
        print('No filename provided: I do not know which image to convert!')
        # Usage error: exit with a non-zero status so callers can detect it.
        sys.exit(1)
    fnm = sys.argv[1]
    stat_threshold = 0.5
    if len(sys.argv) > 2:
        stat_threshold = float(sys.argv[2])
    cluster_threshold = 0
    if len(sys.argv) > 3:
        cluster_threshold = int(sys.argv[3])
    process_nifti(fnm, stat_threshold, cluster_threshold)

