package birnid_gen;

import java.io.*;
import java.util.*;
import java.security.*;

/**
 * Maintains a link table between hashed clinical (patient) IDs and randomly
 * generated BIRN IDs.
 * <p>
 * Clinical IDs are never stored in clear text: each one is reduced to an MD5
 * digest, and the table associates the hex-encoded digest with a BIRN ID of the
 * form {@code prefix + random decimal digits}. The table is persisted as a text
 * file with one {@code hexDigest,birnID} pair per line.
 * <p>
 * NOTE(review): an unsalted MD5 of a short clinical ID can be reversed by brute
 * force, so the link table should be treated as sensitive data. Changing the
 * hash would invalidate existing link tables, so it is left unchanged here.
 * <p>
 * Not thread-safe: all state, including the shared {@link MessageDigest}, is
 * unsynchronized.
 */
public final class BIRNIDManager {
  /** Source of random digits; set by {@link #prepare()} or lazily on first use. */
  protected Random rng;
  /** MD5 engine reused for every clinical ID ({@code digest()} resets it). */
  protected MessageDigest md;
  /** Links in creation/load order; this is the order written to the table file. */
  protected List linksTable;
  /** BIRN ID -> Link; used to avoid issuing the same BIRN ID twice. */
  protected Map linksMap;
  /** Lower-case hex digest -> Link; used to look up a clinical ID. */
  protected Map digestMap;
  /** Number of random digits appended to the prefix of a generated BIRN ID. */
  public final static int NUMBER_LEN = 8;
  /** Most decimal digits whose full range (10^n) still fits in an int bound. */
  private static final int MAX_INT_DIGITS = 9;

  /**
   * Creates an empty manager.
   *
   * @throws NoSuchAlgorithmException if MD5 is not available
   */
  public BIRNIDManager() throws NoSuchAlgorithmException {
    md = MessageDigest.getInstance("MD5");
    linksTable = new LinkedList();
    linksMap = new HashMap();
    digestMap = new HashMap();
  }

  /**
   * Creates and seeds the SHA1PRNG generator used to draw BIRN ID digits.
   *
   * @throws Exception if the SHA1PRNG algorithm is unavailable
   */
  public void prepare() throws Exception {
    SecureRandom secureRandom = SecureRandom.getInstance("SHA1PRNG");
    secureRandom.setSeed(SecureRandom.getSeed(20));
    rng = secureRandom;
  }

  /**
   * Calls {@link #prepare()} if no generator has been set yet, so ID generation
   * does not fail with a NullPointerException when prepare() was skipped.
   */
  private void ensureRandom() {
    if (rng != null) {
      return;
    }
    try {
      prepare();
    } catch (Exception x) {
      throw new IllegalStateException("Unable to initialize random number generator", x);
    }
  }

  /**
   * Creates a link for every clinical ID that does not already have one.
   *
   * @param clinicalIDs list of clinical ID strings
   * @param prefix      prefix for generated BIRN IDs
   * @param numDigits   number of random digits after the prefix
   */
  public void createLinks(List clinicalIDs, String prefix, int numDigits) {
    Iterator iter = clinicalIDs.iterator();
    while (iter.hasNext()) {
      String clinicalID = (String) iter.next();
      if (findLink(clinicalID) == null) {
        addLink(createLink(clinicalID, prefix, numDigits));
      }
    }
  }

  /**
   * Registers a link in the ordered table and both lookup maps, unless a link
   * with the same digest is already registered.
   */
  private void addLink(Link link) {
    // digestMap is keyed by the hex string, not the byte[] (arrays use identity equality).
    String hexDigest = toHex(link.digest);
    if (digestMap.get(hexDigest) != null) {
      return;
    }
    linksTable.add(link);
    linksMap.put(link.birnID, link);
    digestMap.put(hexDigest, link);
  }

  /**
   * Builds (but does not register) a link for a clinical ID, drawing a BIRN ID
   * that is not yet used in {@link #linksMap}.
   */
  public Link createLink(String clinicalID, String prefix, int numDigits) {
    byte[] digest = createMessageDigest(clinicalID);
    String birnID = createBIRNId(prefix, numDigits);
    return new Link(digest, birnID);
  }

  /**
   * Draws random BIRN IDs until one is found that is not already in use.
   *
   * @return {@code prefix} followed by {@code numDigits} random decimal digits
   */
  public String createBIRNId(String prefix, int numDigits) {
    String birnID = prefix + createRandomId(numDigits);
    while (linksMap.get(birnID) != null) {
      birnID = prefix + createRandomId(numDigits);
    }
    return birnID;
  }

  /**
   * Returns a uniformly random string of {@code length} decimal digits
   * (zero-padded). A non-positive length yields {@code "0"}.
   */
  protected String createRandomId(int length) {
    ensureRandom();
    if (length <= MAX_INT_DIGITS) {
      int maxNumber = 1;
      for (int i = 0; i < length; ++i) {
        maxNumber *= 10;
      }
      return formatNumber(length, rng.nextInt(maxNumber));
    }
    // 10^length does not fit in an int bound; draw each digit independently so
    // every digit position stays uniformly distributed.
    StringBuffer buf = new StringBuffer(length);
    for (int i = 0; i < length; ++i) {
      buf.append((char) ('0' + rng.nextInt(10)));
    }
    return buf.toString();
  }

  /**
   * Left-pads the decimal representation of {@code number} with zeros to
   * {@code noDigits} characters; longer representations are returned as is.
   */
  protected String formatNumber(int noDigits, int number) {
    String s = Integer.toString(number);
    if (noDigits <= s.length()) {
      return s;
    }
    StringBuffer buf = new StringBuffer(noDigits);
    for (int i = s.length(); i < noDigits; ++i) {
      buf.append('0');
    }
    buf.append(s);
    return buf.toString();
  }

  /**
   * Returns the MD5 digest of a clinical ID.
   * <p>
   * NOTE(review): {@code getBytes()} uses the platform default charset, so
   * non-ASCII IDs may hash differently on different machines. Switching to a
   * fixed charset would change digests already stored in existing link tables.
   */
  protected byte[] createMessageDigest(String clinicalID) {
    md.update(clinicalID.getBytes());
    return md.digest();
  }

  /** Encodes bytes as a lower-case hex string, two characters per byte. */
  protected static String toHex(byte[] bytes) {
    StringBuffer buf = new StringBuffer(bytes.length * 2);
    for (int i = 0; i < bytes.length; ++i) {
      int byteVal = bytes[i] & 0xff;
      if (byteVal < 0x10) {
        buf.append('0');
      }
      buf.append(Integer.toHexString(byteVal));
    }
    return buf.toString();
  }

  /**
   * Decodes a hex string (either case) produced by {@link #toHex(byte[])}.
   * A trailing odd character is ignored.
   *
   * @throws NumberFormatException if a character pair is not valid hex
   */
  protected static byte[] hexToBytes(String s) {
    String lower = s.toLowerCase();
    byte[] bytes = new byte[lower.length() / 2];
    for (int i = 0; i < bytes.length; ++i) {
      bytes[i] = (byte) Integer.parseInt(lower.substring(2 * i, 2 * i + 2), 16);
    }
    return bytes;
  }

  /**
   * Writes all links as {@code hexDigest,birnID} lines, overwriting the file.
   *
   * @throws IOException if the file cannot be opened or any write fails
   */
  protected void writeLinkTable(String filename) throws IOException {
    PrintWriter out = null;
    try {
      out = new PrintWriter(new BufferedWriter(new FileWriter(filename)));
      Iterator iter = linksTable.iterator();
      while (iter.hasNext()) {
        Link link = (Link) iter.next();
        out.println(toHex(link.digest) + "," + link.birnID);
      }
      // PrintWriter swallows IOExceptions; checkError() flushes and reports them,
      // so a failed write is not mistaken for a saved link table.
      if (out.checkError()) {
        throw new IOException("Error writing link table " + filename);
      }
    } finally {
      if (out != null) {
        out.close();
      }
    }
  }

  /**
   * Replaces the in-memory table with the contents of a link table file.
   * Blank lines are skipped.
   *
   * @throws IOException if the file cannot be read or a line has fewer than
   *                     two comma-separated fields
   */
  protected void loadLinkTable(String filename) throws IOException {
    BufferedReader in = null;
    try {
      in = new BufferedReader(new FileReader(filename));
      // Reset all three structures together so they stay consistent.
      linksTable.clear();
      linksMap.clear();
      digestMap.clear();
      String line;
      int lineNo = 0;
      while ((line = in.readLine()) != null) {
        ++lineNo;
        line = line.trim();
        if (line.length() == 0) {
          continue; // tolerate blank lines, e.g. a trailing empty line
        }
        StringTokenizer stok = new StringTokenizer(line, ",");
        if (stok.countTokens() < 2) {
          throw new IOException("Malformed line " + lineNo + " in link table " + filename);
        }
        String digest = stok.nextToken().trim();
        String birnID = stok.nextToken().trim();
        byte[] digestBytes = hexToBytes(digest);
        Link link = new Link(digestBytes, birnID);
        linksTable.add(link);
        linksMap.put(birnID, link);
        // Key by the canonical lower-case form so findLink() matches regardless
        // of the hex case used in the file.
        digestMap.put(toHex(digestBytes), link);
      }
    } finally {
      if (in != null) {
        try {
          in.close();
        } catch (IOException ignored) {
          // best-effort close of a reader; all data has already been consumed
        }
      }
    }
  }

  /** Reads one clinical ID per line, trimming whitespace and skipping blank lines. */
  public List getClinicalIds(String filename) throws IOException {
    List clinicalIds = new LinkedList();
    BufferedReader in = null;
    try {
      in = new BufferedReader(new FileReader(filename));
      String line;
      while ((line = in.readLine()) != null) {
        line = line.trim();
        if (line.length() > 0) {
          clinicalIds.add(line);
        }
      }
    } finally {
      if (in != null) {
        try {
          in.close();
        } catch (IOException ignored) {
          // best-effort close of a reader; all data has already been consumed
        }
      }
    }
    return clinicalIds;
  }

  /** Returns the link for a clinical ID, or {@code null} if none exists. */
  public Link findLink(String clinicalId) {
    byte[] myDigest = createMessageDigest(clinicalId);
    return (Link) digestMap.get(toHex(myDigest));
  }

  /** A (digest of clinical ID, BIRN ID) pair. */
  public static class Link {
    byte[] digest;
    String birnID;

    public Link(byte[] digest, String birnID) {
      this.digest = digest;
      this.birnID = birnID;
    }
  }

  /** Returns true if both (non-null) arrays have the same length and contents. */
  public boolean isSame(byte[] arr1, byte[] arr2) {
    if (arr1.length != arr2.length) {
      return false;
    }
    for (int i = 0; i < arr1.length; ++i) {
      if (arr1[i] != arr2[i]) {
        return false;
      }
    }
    return true;
  }

  /**
   * Runs the command described by {@code args}.
   * <p>
   * FIND prints the matching BIRN ID. CREATE adds links for a single clinical
   * ID or a file of IDs, then rewrites the link table.
   */
  public void process(Arguments args) throws Exception {
    if (args.cmd == Arguments.CREATE) {
      prepare();
    }

    File f = new File(args.linkTableFilename);
    if (f.exists()) {
      loadLinkTable(args.linkTableFilename);
    }

    if (args.cmd == Arguments.FIND) {
      Link link = findLink(args.clinicalId);
      if (link != null) {
        System.out.println("Birn ID="+ link.birnID+", hashed clinical ID="+ toHex(link.digest));
      } else {
        System.out.println("No match for "+ args.clinicalId);
      }
      return;
    }

    if (args.clinicalIdsFilename != null) {
      List clinicalIds = getClinicalIds(args.clinicalIdsFilename);
      createLinks(clinicalIds, args.prefix, NUMBER_LEN);
    } else if (args.clinicalId != null) {
      if (findLink(args.clinicalId) == null) {
        addLink(createLink(args.clinicalId, args.prefix, NUMBER_LEN));
      }
    }

    writeLinkTable(args.linkTableFilename);
  }

  /** Ad-hoc self check: hex round trip and a sample BIRN ID. */
  public void test() throws Exception {
    byte[] digest = createMessageDigest("bn4444");
    String hexStr = toHex(digest);
    System.out.println("digest="+ hexStr);
    byte[] back = hexToBytes(hexStr);
    boolean same = isSame(digest, back);
    System.out.println("hexToBytes test result: "+ same);

    System.out.println("birnID="+ createBIRNId("UCSD",8) );
  }

  /** Parsed command-line options. */
  public static class Arguments {
    String linkTableFilename;
    String clinicalIdsFilename;
    String prefix;
    String clinicalId;
    int cmd;
    public final static int CREATE = 1;
    public final static int FIND = 2;

    public Arguments() {
      cmd = CREATE;
    }
  }

  /** Prints usage to stderr and terminates the JVM with status 1. */
  public static void usage() {
    System.err.println("Usage:");
    System.err.println("\tjava birnid_gen.BIRNIDManager <-create|find> <args>");
    System.err.println("for BIRN ID creation");
    System.err.println("\tjava birnid_gen.BIRNIDManager -create -p <prefix> -l <linkTableFilename> \\");
    System.err.println("\t\t-cf <clinicalIdsFilename> | -c <clinicalId>]");
    System.err.println("to find the matching message digest for the patient (clinical) ID");
    System.err.println("\tjava birnid_gen.BIRNIDManager -find -c <clinicalId> -l <linkTableFilename> ");

    System.exit(1);
  }

  /**
   * Parses command-line options. Calls {@link #usage()} (which exits) on an
   * unknown option, a missing option value, or missing required options.
   */
  public static Arguments parseArguments(String[] args) {
    if (args.length != 7 && args.length != 5) {
      usage();
    }

    Arguments argObj = new Arguments();
    int i = 0;
    while (i < args.length) {
      String opt = args[i];
      if (opt.equalsIgnoreCase("-create")) {
        argObj.cmd = Arguments.CREATE;
        i++;
      } else if (opt.equalsIgnoreCase("-find")) {
        argObj.cmd = Arguments.FIND;
        i++;
      } else {
        // Every remaining option takes exactly one value.
        if (i + 1 >= args.length) {
          usage();
        }
        String value = args[i + 1];
        if (opt.equalsIgnoreCase("-l")) {
          argObj.linkTableFilename = value;
        } else if (opt.equalsIgnoreCase("-cf")) {
          argObj.clinicalIdsFilename = value;
        } else if (opt.equalsIgnoreCase("-p")) {
          argObj.prefix = value;
        } else if (opt.equalsIgnoreCase("-c")) {
          argObj.clinicalId = value;
        } else {
          // Previously an unknown option never advanced i and looped forever.
          usage();
        }
        i += 2;
      }
    }

    // Both commands read/write the link table, so -l is always required.
    if (argObj.linkTableFilename == null) {
      usage();
    }
    if (argObj.cmd == Arguments.CREATE
        && ((argObj.clinicalIdsFilename == null && argObj.clinicalId == null)
            || argObj.prefix == null)) {
      usage();
    } else if (argObj.cmd == Arguments.FIND && argObj.clinicalId == null) {
      usage();
    }

    return argObj;
  }

  /** Command-line entry point; exits with status 1 on any failure. */
  public static void main(String[] args) {
    try {
      Arguments argObj = parseArguments(args);
      BIRNIDManager man = new BIRNIDManager();
      man.process(argObj);
    } catch (Exception x) {
      x.printStackTrace();
      // Signal failure to calling scripts instead of exiting with status 0.
      System.exit(1);
    }
  }

}