package clinical.test;

import java.io.FileInputStream;
import java.io.IOException;
import java.sql.Connection;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.dbunit.DBTestCase;
import org.dbunit.database.DatabaseConnection;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.FilteredDataSet;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.ITableIterator;
import org.dbunit.dataset.ITableMetaData;
import org.dbunit.dataset.filter.ExcludeTableFilter;
import org.dbunit.dataset.filter.IncludeTableFilter;
import org.dbunit.dataset.stream.IDataSetProducer;
import org.dbunit.dataset.stream.StreamingDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSet;
import org.dbunit.dataset.xml.FlatXmlProducer;
import org.dbunit.operation.DatabaseOperation;
import org.xml.sax.InputSource;

import clinical.event.SequenceCreateEvent;
import clinical.event.SequenceCreateEventListener;
import clinical.test.framework.DBTestDataExtractionUtils;
import clinical.test.framework.DeleteIfExistsOperation;
import clinical.test.framework.InsertIfMissingOperation;
import clinical.web.ConnectionSupportMixin;
import clinical.web.ISequenceHelper;
import clinical.web.MinimalServiceFactory;
import clinical.web.ServiceFactory;
import clinical.web.common.UserInfo;
import clinical.web.exception.DBPoolServiceException;
import clinical.web.services.DBPoolService;

/**
 * Base class for DBUnit backed tests that obtain their JDBC connection through
 * a {@link ConnectionSupportMixin}.
 * <p>
 * Provides helpers to
 * <ul>
 * <li>snapshot a set of tables (or the full schema) to flat XML files,</li>
 * <li>diff a "before" and "after" snapshot and delete the rows contained in
 * the diff (restoring the tables the test touched),</li>
 * <li>load a test-state flat XML file into the database.</li>
 * </ul>
 * DBUnit's own set-up/tear-down operations are disabled ({@code NONE}), so all
 * data preparation is done explicitly by subclasses via these helpers.
 * <p>
 * The class also listens for {@link SequenceCreateEvent}s and records the
 * generated sequence numbers in {@link #seqMap}, keyed by
 * {@code tableName + "_" + columnName}.
 * <p>
 * Note: the "Effected" spelling in method names is historical (meaning
 * "affected") and kept for compatibility with existing subclasses.
 *
 * @author I. Burak Ozyurt
 * @version $Id: MyDBTestCase.java 365 2011-05-05 20:04:18Z bozyurt $
 */

public abstract class MyDBTestCase extends DBTestCase implements
      SequenceCreateEventListener {
   /** Supplies pooled connections, properties and service startup/shutdown. */
   protected ConnectionSupportMixin mixin;
   /** Connection acquired in {@link #setUp()} and released in {@link #tearDown()}. */
   protected Connection con = null;
   /** Sequence numbers created during the test, keyed by table_column. */
   protected Map<String, List<String>> seqMap = new HashMap<String, List<String>>();
   /** Tables exported by the before/after snapshot helpers. */
   protected List<String> tableNames = new ArrayList<String>(3);
   protected String rootDir;
   protected String beforeFlatSetFile;
   protected String afterFlatSetFile;
   protected String diffFlatSetFile;

   protected String schemaName;
   protected String snapshotDir;
   protected boolean verbose = false;

   /**
    * Creates the test case and resolves all snapshot file locations from the
    * mixin's properties.
    *
    * @param testName
    *           JUnit test (method) name
    * @param propsFile
    *           properties file handed to {@link ConnectionSupportMixin}
    * @param skipQueryProcessorCache
    *           flag handed to {@link ConnectionSupportMixin}
    * @throws IOException
    *            declared for compatibility with existing subclasses
    */
   public MyDBTestCase(String testName, String propsFile,
         boolean skipQueryProcessorCache) throws IOException {
      super(testName);
      mixin = new ConnectionSupportMixin(propsFile, skipQueryProcessorCache);
      rootDir = getProperty("test.data.rootdir");
      snapshotDir = getProperty("test.data.snapshot.rootdir");
      beforeFlatSetFile = rootDir + "/" + getProperty("test.beforeflatsetfile");
      afterFlatSetFile = rootDir + "/" + getProperty("test.afterflatsetfile");
      diffFlatSetFile = rootDir + "/"
            + getProperty("test.differenceflatsetfile");
      schemaName = getProperty("schema.name");
   }

   /**
    * Exports the full schema to {@code <snapshotDir>/<testName>_before.xml}.
    *
    * @return the path of the written snapshot file
    */
   protected String prepareSnapshotFile() throws Exception {
      String beforeFSFile = snapshotFile("_before.xml");
      DBTestDataExtractionUtils.exportFullDB2FlatSet(getConnection()
            .getConnection(), beforeFSFile, schemaName);
      return beforeFSFile;
   }

   /**
    * Exports the full schema to {@code <snapshotDir>/<testName>_after.xml}
    * and writes the differences against the earlier
    * {@code _before.xml} snapshot (see {@link #prepareSnapshotFile()}) to
    * {@code <snapshotDir>/<testName>_diff.xml}.
    *
    * @return the path of the written diff file
    */
   protected String prepareFullSetDiffSetFile() throws Exception {
      String beforeFSFile = snapshotFile("_before.xml");
      String afterFSFile = snapshotFile("_after.xml");
      String diffFSFile = snapshotFile("_diff.xml");

      DBTestDataExtractionUtils.exportFullDB2FlatSet(getConnection()
            .getConnection(), afterFSFile, schemaName);

      DBTestDataExtractionUtils.findFlatSetDiffs(beforeFSFile, afterFSFile,
            diffFSFile, verbose);
      return diffFSFile;
   }

   /**
    * Exports the tables listed in {@link #tableNames} to
    * {@link #beforeFlatSetFile}.
    */
   protected void takeEffectedTablesSnapshot() throws Exception {
      DBTestDataExtractionUtils.exportDB2FlatSet(getConnection()
            .getConnection(), beforeFlatSetFile, tableNames, schemaName);
   }

   /**
    * Re-exports {@link #tableNames} to {@link #afterFlatSetFile}, diffs it
    * against {@link #beforeFlatSetFile} and deletes the rows contained in the
    * diff using {@link DatabaseOperation#DELETE}.
    */
   protected void restoreEffectedTables() throws Exception {
      IDataSet dataSet = exportAndLoadDiffSet();
      DatabaseOperation.DELETE.execute(getConnection(), dataSet);
   }

   /**
    * Like {@link #restoreEffectedTables()}, but handles tables without a
    * primary key separately: rows of tables in {@code nonPKtables} are removed
    * with {@link DeleteIfExistsOperation}, all other tables with
    * {@link DatabaseOperation#DELETE}.
    *
    * @param nonPKtables
    *           names of tables without a primary key; {@code null} or empty
    *           means every table is handled by {@code DELETE}
    */
   protected void restoreEffectedTables(List<String> nonPKtables)
         throws Exception {
      IDataSet dataSet = exportAndLoadDiffSet();
      if (nonPKtables == null || nonPKtables.isEmpty()) {
         DatabaseOperation.DELETE.execute(getConnection(), dataSet);
         return;
      }

      String[] tnArr = nonPKtables.toArray(new String[nonPKtables.size()]);
      IDataSet pkSet = new FilteredDataSet(new ExcludeTableFilter(tnArr),
            dataSet);

      if (verbose) {
         System.out.println("restoreEffectedTables: filteredSet");
         showDataSetTableOrder(pkSet);
      }

      DatabaseOperation.DELETE.execute(getConnection(), pkSet);
      // now cleanup the tables without any primary key also
      IDataSet nonPkSet = new FilteredDataSet(new IncludeTableFilter(tnArr),
            dataSet);
      new DeleteIfExistsOperation().execute(getConnection(), nonPkSet);
   }

   /**
    * Returns the distinct table names in a flat XML file, in the order they
    * first appear.
    *
    * @param testStateFlatSetFile
    *           flat XML file (passed to {@link InputSource} as a system id)
    */
   protected List<String> getEffectedTables(String testStateFlatSetFile)
         throws Exception {
      IDataSetProducer producer = new FlatXmlProducer(new InputSource(
            testStateFlatSetFile));
      IDataSet dataSet = new StreamingDataSet(producer);
      // LinkedHashSet de-duplicates while preserving first-seen order
      Set<String> uniqTables = new LinkedHashSet<String>();
      for (ITableIterator it = dataSet.iterator(); it.next();) {
         uniqTables.add(it.getTableMetaData().getTableName());
      }
      return new ArrayList<String>(uniqTables);
   }

   /**
    * Loads the data of a flat XML file into the database.
    * <p>
    * Without {@code nonPKtables} the file is streamed through
    * {@link DatabaseOperation#REFRESH}. Otherwise the tables that are not in
    * {@code nonPKtables} are refreshed and then inserted-if-missing, and the
    * tables in {@code nonPKtables} are only inserted-if-missing.
    *
    * @param testStateFlatSetFile
    *           flat XML file with the desired state
    * @param nonPKtables
    *           names of tables without a primary key; may be {@code null}
    */
   protected void prepareState(String testStateFlatSetFile,
         List<String> nonPKtables) throws Exception {
      if (nonPKtables == null || nonPKtables.isEmpty()) {
         IDataSetProducer producer = new FlatXmlProducer(new InputSource(
               testStateFlatSetFile));
         IDataSet dataSet = new StreamingDataSet(producer);
         DatabaseOperation.REFRESH.execute(getConnection(), dataSet);
         return;
      }

      // the data set is fully cached in memory, so it can be filtered twice
      IDataSet fullSet = loadFlatXmlDataSet(testStateFlatSetFile);
      if (verbose) {
         showDataSetTableOrder(fullSet);
      }

      String[] tnArr = nonPKtables.toArray(new String[nonPKtables.size()]);
      IDataSet pkSet = new FilteredDataSet(new ExcludeTableFilter(tnArr),
            fullSet);
      DatabaseOperation.REFRESH.execute(getConnection(), pkSet);
      // NOTE(review): kept from the original; possibly redundant after REFRESH
      new InsertIfMissingOperation().execute(getConnection(), pkSet);

      IDataSet nonPkSet = new FilteredDataSet(new IncludeTableFilter(tnArr),
            fullSet);
      new InsertIfMissingOperation().execute(getConnection(), nonPkSet);
   }

   /** Prints every table name of the data set together with its row count. */
   private void showDataSetTableOrder(IDataSet dataSet) throws DataSetException {
      ITableIterator it = dataSet.iterator();
      while (it.next()) {
         ITableMetaData tmd = it.getTableMetaData();
         System.out.println(tmd.getTableName());
         System.out.println("# rows:" + it.getTable().getRowCount());
      }
   }

   /** Builds {@code <snapshotDir>/<testName><suffix>}. */
   private String snapshotFile(String suffix) {
      return snapshotDir + "/" + getName() + suffix;
   }

   /**
    * Exports {@link #tableNames} to {@link #afterFlatSetFile}, writes the diff
    * against {@link #beforeFlatSetFile} to {@link #diffFlatSetFile} and
    * returns the diff as an in-memory data set.
    */
   private IDataSet exportAndLoadDiffSet() throws Exception {
      DBTestDataExtractionUtils.exportDB2FlatSet(getConnection()
            .getConnection(), afterFlatSetFile, tableNames, schemaName);

      DBTestDataExtractionUtils.findFlatSetDiffs(beforeFlatSetFile,
            afterFlatSetFile, diffFlatSetFile, verbose);

      return loadFlatXmlDataSet(diffFlatSetFile);
   }

   /**
    * Parses a flat XML file into a (cached) {@link FlatXmlDataSet} and always
    * closes the underlying stream. Safe because {@code FlatXmlDataSet} reads
    * the whole stream during construction.
    */
   private static IDataSet loadFlatXmlDataSet(String flatSetFile)
         throws IOException, DataSetException {
      FileInputStream in = new FileInputStream(flatSetFile);
      try {
         return new FlatXmlDataSet(in);
      } finally {
         in.close();
      }
   }

   /**
    * Switches the service factories to minimal mode, starts the mixin,
    * acquires {@link #con} and registers this test as a sequence listener
    * before delegating to DBUnit's set-up.
    */
   @Override
   protected void setUp() throws Exception {
      MinimalServiceFactory.setMimimalOpMode(true);
      ServiceFactory.setMimimalOpMode(true);
      mixin.startup();
      con = mixin.getConnection();
      ISequenceHelper seqHelper = ServiceFactory.getSequenceHelper(mixin
            .getDbID());
      seqHelper.addListener(this);
      super.setUp();
   }

   /**
    * Runs DBUnit's tear-down, then releases {@link #con}, clears
    * {@link #seqMap} and shuts the mixin down. The cleanup runs even if an
    * earlier step throws.
    */
   @Override
   protected void tearDown() throws Exception {
      try {
         super.tearDown();
      } finally {
         try {
            if (con != null) {
               mixin.releaseConnection(con);
            }
         } finally {
            seqMap.clear();
            mixin.shutdown();
         }
      }
   }

   /** DBUnit set-up is disabled; subclasses prepare data explicitly. */
   @Override
   protected DatabaseOperation getSetUpOperation() throws Exception {
      return DatabaseOperation.NONE;
   }

   /** DBUnit tear-down is disabled; subclasses restore data explicitly. */
   @Override
   protected DatabaseOperation getTearDownOperation() throws Exception {
      return DatabaseOperation.NONE;
   }

   /**
    * Wraps {@link #con} in a new DBUnit connection, scoped to the
    * {@code schema.name} property when it is set.
    */
   @Override
   protected IDatabaseConnection getConnection() throws Exception {
      String dbSchema = mixin.getProperty("schema.name");
      if (dbSchema != null) {
         return new DatabaseConnection(con, dbSchema);
      }
      return new DatabaseConnection(con);
   }

   /** No default data set; set-up/tear-down operations are {@code NONE}. */
   @Override
   protected IDataSet getDataSet() throws Exception {
      return null;
   }

   public DBPoolService getDbPoolService() {
      return mixin.getDbPoolService();
   }

   public String getDbID() {
      return mixin.getDbID();
   }

   public String[] getDbIDs() {
      return mixin.getDbIDs();
   }

   public UserInfo getUi() {
      return mixin.getUi();
   }

   public boolean isVerbose() {
      return verbose;
   }

   public void setVerbose(boolean verbose) {
      this.verbose = verbose;
   }

   /** Obtains an additional pooled connection; release it with
    * {@link #releasePooledConnection(Connection)}. */
   public Connection getPooledConnection() throws DBPoolServiceException {
      return mixin.getConnection();
   }

   public void releasePooledConnection(Connection con)
         throws DBPoolServiceException {
      mixin.releaseConnection(con);
   }

   /** Looks up a property through the mixin. */
   public String getProperty(String propName) {
      return mixin.getProperty(propName);
   }

   // interface methods

   /**
    * Records the created sequence number under
    * {@code tableName + "_" + columnName} in {@link #seqMap}.
    */
   public void sequenceNoCreated(SequenceCreateEvent event) {
      String key = event.getTableName() + "_" + event.getColumnName();
      List<String> uidList = seqMap.get(key);
      if (uidList == null) {
         uidList = new ArrayList<String>(1);
         seqMap.put(key, uidList);
      }
      uidList.add(event.getSequenceNumber());
   }
}
