% Start from a clean session: clear the command window, close all figures
% and wipe the workspace.
clc;
close all;
clear all
% Run the toolbox setup. If the first call throws, it is retried exactly
% once (same behaviour as before), but the reason for the first failure is
% now reported instead of being silently discarded.
% NOTE(review): presumably the first call can fail before the toolbox /
% Java path is fully in place - confirm why a retry is needed.
try
    RWSsetup()
catch firstSetupError
    warning('RWS:setupRetry', ...
        'RWSsetup failed on the first attempt (%s); retrying once.', ...
        firstSetupError.message);
    RWSsetup()
end

%diffusionModel.Geometry.javaObject = SimCompartmentCylinder(PT(.5,.5,.5), .25,.25, 0,2e-3);
% ---- Simulation parameters ------------------------------------------------
% Diffusivities are labelled mm^2/s where D is selected further down in this
% script; lengths are presumably in microns (TODO confirm against the RWS
% helpers).
extraComparmentDiffusivity  = 3e-3;       % diffusivity outside the compartments
intraCompartmentDiffusivity = 1e-3;       % diffusivity inside the compartments
transmissionProbability     = 0;          % passed to every compartment constructor
center                      = [3, 3, 3];  % centre shared by the compartments
radius                      = 1;          % reassigned before the geometry is built
latticeDimensions           = [6, 6, 6];  % lattice size handed to the simulator
numSpinPackets              = 10000;      % number of packets


% use the cylinder method
%cylinderLength = 3.5;
%diffusionModel.Geometry{1} = RWScreateCompartmentCylinder(intraCompartmentDiffusivity,transmissionProbability,center,radius,cylinderLength);
%diffusionModel.Geometry{1} = RWScreateCompartmentSphere(intraCompartmentDiffusivity,transmissionProbability,center,radius);
%diffusionModel.latticeDimensions = latticeDimensions;
%diffusionModel.name = [];
%.javaObject = SimCompartmentSphere(PT(.5,.5,.5), .333, 0,1e-3);

% use the mesh method
%d=load('Model_NerveUnmylenated');
%d = load('Model_jon_big_bulge_sds');
%mesh=d.diffusionModel.Geometry.shape;
%diffusionModel.Geometry{1} = RWScreateCompartmentMesh(intraCompartmentDiffusivity,transmissionProbability,mesh);
%diffusionModel.latticeDimensions = latticeDimensions;
%diffusionModel.name = [];

% Union method: the model consists of a single compartment that is the
% union of a sphere (radius 1.4) and a cylinder (radius 0.5, length 13),
% both built around the same centre.
cylinderLength = 13;

radius = 1.4;
sphereCompartment = RWScreateCompartmentSphere( ...
    intraCompartmentDiffusivity, transmissionProbability, center, radius);

% radius keeps this cylinder value for the rest of the script.
radius = 0.5;
cylinderCompartment = RWScreateCompartmentCylinder( ...
    intraCompartmentDiffusivity, transmissionProbability, center, radius, cylinderLength);

U = RWScreateCompartmentUnion( ...
    intraCompartmentDiffusivity, transmissionProbability, ...
    sphereCompartment, cylinderCompartment);
%U = RWScreateCompartmentIntersection(intraCompartmentDiffusivity,transmissionProbability,sphereCompartment,cylinderCompartment);

% The final model holds only the combined compartment and the lattice size.
diffusionModel = struct();
diffusionModel.Geometry = {U};
diffusionModel.latticeDimensions = latticeDimensions;
%diffusionModel.name = [];



% Read-out times: NSteps equally spaced points in (0, 25] ms.
NSteps = 5;
diffusionTimeMS = linspace(0, 25, NSteps + 1);   % ms, includes t = 0
diffusionTimeMS(1) = [];                         % drop the t = 0 sample
% Convert time arguments to standard units:
% duration of one simulation segment (total time split over NSteps).
diffusionTimeUS = diffusionTimeMS(end) * 1000 / NSteps;   % ms -> us
timeStepUS = 20;   % us

% 1) Create the Java simulator for the given lattice and switch off its
%    intenseDebug flag.
javaSimulator = RWSnewSimulator(extraComparmentDiffusivity, latticeDimensions);
javaSimulator.intenseDebug = 0;

% 3) Register each compartment of the model with the simulator. The
%    geometry handle returned by one call is fed into the next call.
currentGeometry = [];
geometryCount = length(diffusionModel.Geometry);
for geomIdx = 1:geometryCount
    [compartmentIndex(geomIdx), currentGeometry] = RWSaddCompartment( ...
        javaSimulator, diffusionModel.Geometry{geomIdx}, currentGeometry);
end

% 4) Place the spin packets; the placement mode is reused later to pick
%    the reference diffusivity.
SpinLocations = 'RandomInside';
RWSaddSpins(javaSimulator, numSpinPackets, SpinLocations);


% Gradient-encoding settings do not depend on the time point, so they are
% computed once here instead of on every loop iteration.
encodingDir = [1 0 0; 0 1 0; 0 0 1];   % one encoding direction per row (x, y, z)
delta = 16; %ms
G = 30; %mT/m
G = G/100; % mT/cm
GradLobeArea = delta*G; %(mT/cm)*ms

% One simulation segment per read-out time.
% NOTE(review): the plots below compare the ii-th result against
% diffusionTimeMS(ii), so each RWSsimulate call is presumably expected to
% continue from the state left by the previous one - confirm.
for ii = 1:numel(diffusionTimeMS)

    % 5) simulate the random walks (non-interacting)
    MotionResults = RWSsimulate(javaSimulator,timeStepUS,diffusionTimeUS);

    % 6) read out results
    % Per-packet displacement magnitude and its x/y/z components.
    displacement = sqrt(sum(MotionResults.brownianMotion.^2,2));
    displacementx(:,ii) =  MotionResults.brownianMotion(:,1);
    displacementy(:,ii) =  MotionResults.brownianMotion(:,2);
    displacementz(:,ii) =  MotionResults.brownianMotion(:,3);

    % Total RMS displacement and the per-axis RMS displacements (rows x, y, z).
    rmsd(ii) = sqrt(mean(displacement.^2)); %um
    rmsd_vec(:,ii) = sqrt([mean(displacementx(:,ii).^2) mean(displacementy(:,ii).^2) mean(displacementz(:,ii).^2)]);

    % ADC and signal for this time point along the three encoding directions.
    motionInfo = MotionResults;
    [adc(ii,:),signal(ii,:,:)] = simulateADC(motionInfo,encodingDir,GradLobeArea);

end

% Convert the total RMS displacement from um to mm.
adj_rmsd = rmsd/1000; %mm

% Pick the reference diffusivity that matches where the spins were placed.
if strcmpi(SpinLocations,'RandomOutside')
    D = extraComparmentDiffusivity; %mm^2/s
elseif strcmpi(SpinLocations,'RandomInside')
    D = intraCompartmentDiffusivity; % mm^2/s
else
    % Without this branch D would stay undefined and the script would fail
    % on the next line with an unrelated "undefined variable" message.
    error('RWS:unknownSpinLocations', ...
        'No reference diffusivity is defined for this SpinLocations mode; expected ''RandomOutside'' or ''RandomInside''.');
end

% Left unsuppressed on purpose so the value is still echoed to the command window.
D = D/1000 % mm^2/ms

%At long diffusion times, the rmsd should be the rmsd of the
%autocorrelation function (i.e. a fixed value). The rmsd should approach
%this asymptote as a function of tdif.

% Figures 1 and 2: total RMS displacement of the random walk against the
% sqrt(6*D*t) curve computed from the known D. Figure 1 uses the diffusion
% time as x axis, figure 2 its square root.
rmsdFromKnownD = sqrt(6*D*diffusionTimeMS);
totalTimeAxes   = {diffusionTimeMS, sqrt(diffusionTimeMS)};
totalTimeLabels = {'diffusionTime in ms', 'sqrt of diffusionTime in ms'};

for figIdx = 1:2
    figure(figIdx)
    plot(totalTimeAxes{figIdx}, adj_rmsd, 'ro')
    hold on
    plot(totalTimeAxes{figIdx}, rmsdFromKnownD, 'k*-')
    hold on
    xlabel(totalTimeLabels{figIdx})
    ylabel('rmsd in mm')
    legend('rand walk mean','sim from known D')
end

% Figures 3 and 4: per-axis RMS displacement (mean of x and y, and z on its
% own, converted from um to mm) against the sqrt(2*D*t) curve computed from
% the known D. Figure 3 uses the diffusion time as x axis, figure 4 its
% square root.
xyRmsdMM = mean(rmsd_vec(1:2,:),1)/1000;
zRmsdMM  = mean(rmsd_vec(3,:),1)/1000;
axisRmsdFromKnownD = sqrt(2*D*diffusionTimeMS);
axisTimeAxes   = {diffusionTimeMS, sqrt(diffusionTimeMS)};
axisTimeLabels = {'diffusionTime in ms', 'sqrt of diffusionTime in ms'};

for panelIdx = 1:2
    figure(panelIdx + 2)
    plot(axisTimeAxes{panelIdx}, xyRmsdMM, 'ro-')
    hold on
    plot(axisTimeAxes{panelIdx}, zRmsdMM, 'bo-')
    hold on
    plot(axisTimeAxes{panelIdx}, axisRmsdFromKnownD, 'k*')
    hold on
    xlabel(axisTimeLabels{panelIdx})
    ylabel('rmsd in mm')
    legend('rand walk xy','rand walk z','sim from known D')
end
 
% Figure 5: model geometry followed by the original spin positions, on a
% fixed +/-20 axis box.
% NOTE(review): there is no explicit "hold on" between RWSshowGeometry and
% plot3. Unless RWSshowGeometry enables hold itself, plot3 will replace the
% drawn geometry - confirm.
figure(5)
%showCompartment(diffusionModel), hold on
RWSshowGeometry(diffusionModel.Geometry, diffusionModel.latticeDimensions)
startPositions = MotionResults.origPositions;
plot3(startPositions(:,1), startPositions(:,2), startPositions(:,3), '.')
a = 20;
axis([-a, a, -a, a, -a, a])

% Figure 6: model geometry followed by the positions from the last
% simulation segment (default axis limits).
figure(6)
%showCompartment(diffusionModel), hold on
RWSshowGeometry(diffusionModel.Geometry, diffusionModel.latticeDimensions)
endPositions = MotionResults.resultPositions;
plot3(endPositions(:,1), endPositions(:,2), endPositions(:,3), '.')
 
 
% Figure 7, top panel: histogram of the x and y displacements at the last
% time point, overlaid with the two curves returned by
% simAxonsQspace_convolutions, each scaled by the histogram area.
% Statement order is kept as-is because simAxonsQspace_convolutions is not
% visible here and may touch the current axes.
figure(7)
ii = length(diffusionTimeMS);              % index of the last time point
a = 3*mean(rmsd_vec(1:2,end));             % half-width of the histogram range
subplot(2,1,1)
vec = linspace(-a, a, 20);                 % bin edges
xyDisplacements = [displacementx(:,ii); displacementy(:,ii)];
N = histc(xyDisplacements, vec);
% radius still holds the value last assigned when the geometry was built.
[x,xvec,V2Dproj,C2Dproj,C] = simAxonsQspace_convolutions(radius,'circle');
binArea = trapz(vec, N);
subplot(2,1,1)
bar(vec, N, 'histc')
hold on
plot(x, binArea*V2Dproj, 'b*-');
hold on
plot(xvec, binArea*C2Dproj, 'ro-');
hold on
axis([-a, a, 0, numSpinPackets/2])
xlabel('xy displacement in microns')
ylabel('Spin Count')

%% Doesn't work for me
% a = 3*mean(rmsd_vec(3,end));
% vec = linspace(-a,a,20);
% subplot(2,1,2)
% N = histc([displacementz(:,ii)],vec);
% binArea = trapz(vec,N);
% sigma = sqrt(2*D*1e6*diffusionTimeMS(ii));
% bar(vec,N,'histc'), hold on
% plot(vec,binArea*normpdf(vec,0,sigma),'ro-'); hold on 
% axis([-a a 0 numSpinPackets/4])
% xlabel('z displacement in microns')
% ylabel('Spin Count')


 
   
      


    
% Count how many final spin positions are reported as NOT contained in the
% first compartment's Java object.
import bl.diffusion.*;
sp = diffusionModel.Geometry{1}.javaObject;

% Preallocate the flags so the array is not grown one element at a time and
% so CC is defined (empty) even when there are no positions to test.
numPositions = size(MotionResults.resultPositions,1);
CC = false(1,numPositions);
for j = 1:numPositions
    CC(j) = sp.contains(PT(MotionResults.resultPositions(j,:)));
end

% Left unsuppressed on purpose so the count is still echoed to the command window.
sum(CC==0)