#!/usr/bin/python
"""Detect outlier volumes in a set of BOLD image series.

This module generates files containing the indices of outlier volumes in
each functional imaging run, based on the global intensity of the images
and on the estimated movement parameters.

"""

import os
import sys
import yaml
import nifti as ni
import numpy as np
import copy
import scipy.signal as signal
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.pylab  as plab
import traceback

def artifactdetect(inputfile, outputfile=''):
    """Wrapper routine around the outlier detection algorithm.

    Runs ``detectoutliers`` on every image series listed in a YAML
    specification and writes a YAML summary of the results.

    Parameters
    ----------
    inputfile : str or dict
        Path to a YAML file, or an already-parsed mapping.  In both cases
        the parameters are taken from the ``'mit.artifactdetect'`` entry,
        which must provide ``location``, ``series``, ``mocopar`` and
        ``param``.
    outputfile : str, optional
        Where to write the YAML summary.  When empty, defaults to
        ``<location>/mit.artifactdetect/mit.artifactdetect.yaml``.

    Returns
    -------
    dict
        ``{'mit.artifactdetect': output}``, where ``output`` is a copy of the
        module parameters plus an ``'outliers'`` list holding, per series,
        the path returned by ``detectoutliers``.

    Raises
    ------
    Exception
        Any error while reading or unwrapping the input is reported and
        re-raised, rather than being swallowed and failing later.
    ValueError
        If fewer motion-parameter files than image series are given.
    """
    modulename = '.'.join(('mit', 'artifactdetect'))
    # Byte and unicode path strings are distinct types only on Python 2.
    string_types = (str, type(u''))

    # read YAML input
    try:
        if isinstance(inputfile, string_types):
            # safe_load: the spec is plain configuration data, so arbitrary
            # Python object construction (yaml.load) is neither needed nor safe.
            with open(inputfile, 'r') as fh:
                yaml_in = yaml.safe_load(fh)
        else:
            yaml_in = inputfile
        # Ensure correct parameters are being sent for the correct module
        yaml_in = yaml_in[modulename]
    except Exception:
        print("Cannot find or load input file")
        raise

    # Create directory (and any missing parents) for module output
    outputlocation = os.path.join(yaml_in['location'], modulename)
    if not os.path.isdir(outputlocation):
        os.makedirs(outputlocation)

    series = yaml_in['series']
    mocopar = yaml_in['mocopar']
    # Validate up front so no work is done for a malformed specification.
    if len(mocopar) < len(series):
        raise ValueError(
            "Expected a motion parameter file for each of the %d series, "
            "got %d" % (len(series), len(mocopar)))

    # initialize YAML output and find outliers for every session
    output = copy.deepcopy(yaml_in)
    output['outliers'] = [
        detectoutliers(img, moco, yaml_in['param'], outputlocation)
        for img, moco in zip(series, mocopar)
    ]

    # Generate filename and write yaml output
    if not outputfile:
        outputfile = os.path.join(outputlocation, modulename + '.yaml')

    # NOTE(review): the file receives the unwrapped ``output`` mapping while
    # the return value wraps it under ``modulename`` -- confirm which form
    # downstream consumers expect.
    with open(outputfile, 'wt') as fh:
        yaml.dump(output, fh, default_flow_style=False)

    return {modulename: output}

def _motion_outliers(mc, params):
    """Flag volumes whose motion exceeds the translation/rotation thresholds.

    ``mc`` has one row per volume; columns 0-2 are translations (mm) and
    columns 3-5 rotations (rad).  When ``params['flags']['usenorm']`` is set,
    the Euclidean norm of each triplet is compared with the threshold;
    otherwise a volume is flagged when any single component's magnitude
    exceeds it.

    Returns ``(traval, rotval, tidx, ridx)``: the values that were
    thresholded (kept for plotting) and the 0-based indices of flagged volumes.
    """
    thresholds = params['thresholds']
    traval = mc[:, 0:3]  # translation parameters (mm)
    rotval = mc[:, 3:6]  # rotation parameters (rad)
    if params['flags']['usenorm']:
        # calculate the norm of the motion parameters
        traval = np.sqrt(np.sum(traval * traval, 1))
        rotval = np.sqrt(np.sum(rotval * rotval, 1))
        tidx = np.flatnonzero(traval > thresholds['translation'])
        ridx = np.flatnonzero(rotval > thresholds['rotation'])
    else:
        tidx = np.flatnonzero(
            np.any(np.abs(traval) > thresholds['translation'], axis=1))
        ridx = np.flatnonzero(
            np.any(np.abs(rotval) > thresholds['rotation'], axis=1))
    return traval, rotval, tidx, ridx


def _global_intensity(nim, maskparams):
    """Return the per-volume global intensity of ``nim`` as a 1-D array.

    ``maskparams['masktype']`` selects which voxels are averaged:

    * ``'spmglobal'``: voxels above 1/8 of that volume's mean (spm_global-like)
    * ``'file'``: nonzero voxels of the image ``maskparams['maskfile']``
    * ``'thresh'``: voxels above the fixed value ``maskparams['thresh']``
    * ``'all'``: every voxel

    Raises ValueError for an unknown mask type or a mask whose voxel count
    does not match the image.
    """
    ntime = nim.timepoints
    data = np.reshape(nim.data, [ntime, np.prod(nim.volextent)])
    masktype = maskparams['masktype']
    if masktype == 'spmglobal':
        cutoff = np.mean(data, 1) / 8
        return np.array([np.mean(data[t, data[t] > cutoff[t]])
                         for t in range(ntime)], dtype=float)
    if masktype == 'file':
        mask = ni.NiftiImage(maskparams['maskfile'])
        # NOTE(review): uses the first entry along mask.data's leading axis,
        # as the original did -- presumably the first volume of a 4D mask;
        # confirm behaviour for 3D mask files.
        maskvals = np.reshape(mask.data[0], [np.prod(mask.volextent)])
        if maskvals.size != data.shape[1]:
            raise ValueError(
                "Mask %s has %d voxels but the image has %d"
                % (maskparams['maskfile'], maskvals.size, data.shape[1]))
        return np.mean(data[:, np.flatnonzero(maskvals)], 1)
    if masktype == 'thresh':
        thresh = maskparams['thresh']
        return np.array([np.mean(data[t, data[t] > thresh])
                         for t in range(ntime)], dtype=float)
    if masktype == 'all':
        return np.mean(data, 1)
    raise ValueError("Unknown mask type: %r" % (masktype,))


def detectoutliers(imgfile, motionfile, params, outputlocation):
    """Core routine for detecting outliers in one BOLD series.

    Parameters
    ----------
    imgfile : str
        Path to the functional image, read with ``nifti.NiftiImage``.
    motionfile : str
        Whitespace-delimited text file with one row per volume; columns 0-2
        are used as translations (mm) and 3-5 as rotations (rad).
    params : dict
        ``params['flags']`` with ``usediff``, ``usenorm`` and ``plot``;
        ``params['thresholds']`` with ``translation``, ``rotation`` and
        ``intensity``; ``params['mask']`` with ``masktype`` (one of
        'spmglobal', 'file', 'thresh', 'all') plus ``maskfile`` or
        ``thresh`` where the mask type needs it.
    outputlocation : str
        Directory in which the outlier file is written.

    Returns
    -------
    str
        Path of ``art.<motionfile basename>_outliers.txt``, which holds the
        sorted, unique, 0-based indices of volumes flagged by translation,
        rotation or intensity.

    Raises
    ------
    ValueError
        If the mask type is unknown or the mask does not match the image.
    """
    # ndmin=2 keeps a single-volume motion file as a (1, ncols) matrix.
    mc = np.loadtxt(motionfile, ndmin=2)

    # if using differences recompute mc; the first volume gets zero motion
    if params['flags']['usediff']:
        mc = np.concatenate(
            (np.zeros((1, mc.shape[1])), np.diff(mc, n=1, axis=0)), axis=0)

    traval, rotval, tidx, ridx = _motion_outliers(mc, params)

    # Display outliers on a plot if available
    plot = params['flags']['plot'] == 1
    if plot:
        plt.figure()
        plt.subplot(311)
        plt.plot(traval)
        plt.ylabel('Translation [mm]')
        plt.plot(tidx, np.zeros(len(tidx)), 'o')

        plt.subplot(312)
        plt.plot(rotval)
        plt.ylabel('Rotation [rad]')
        plt.plot(ridx, np.zeros(len(ridx)), 'o')

    # compute global intensity signal
    g = _global_intensity(ni.NiftiImage(imgfile), params['mask'])

    # detrend, then z-score; a constant signal has no intensity outliers
    gz = signal.detrend(g, axis=0)
    sd = np.std(gz)
    if sd > 0:
        gz = (gz - np.mean(gz)) / sd
    else:
        gz = np.zeros_like(gz)
    iidx = np.flatnonzero(np.abs(gz) > params['thresholds']['intensity'])

    if plot:
        plt.subplot(313)
        plt.plot(gz)
        plt.plot(iidx, gz[iidx], 'o')
        # label before show(): show() may block until the window is closed
        plt.ylabel('Intensity [Z]')
        plt.show()

    # union1d already returns sorted unique values
    outliers = np.union1d(iidx, np.union1d(tidx, ridx))

    # write output to outputfile
    filename = os.path.basename(motionfile)
    outputfile = os.path.join(outputlocation,
                              'art.' + filename + '_outliers.txt')
    np.savetxt(outputfile, outliers, fmt='%d', delimiter=' ')

    return outputfile

if __name__ == "__main__":
    # Usage: <script> <input.yaml> [output.yaml]
    if len(sys.argv) < 2:
        print("usage: %s <input.yaml> [output.yaml]" % sys.argv[0])
        sys.exit(2)
    try:
        # The output path is optional; artifactdetect picks a default when ''.
        artifactdetect(sys.argv[1], sys.argv[2] if len(sys.argv) > 2 else '')
    except Exception:
        print("An unhandled exception occured, here's the traceback!")
        traceback.print_exc()
        # Report failure to the calling shell instead of exiting 0.
        sys.exit(1)

