/*
 *  Type-ARQuE - the experimental SPARQL to SQL translator.
 *  Copyright (C) 2010  Sami Kiminki / Aalto University
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <sstream>

#include <mysql.h>

#include "AQLModel.h"
#include "AQLException.h"
#include "FormatUtils.h"
#include "SQLBackendMySQL.h"
#include "SQLBackendFunctions.h"
#include "Messages.h"
#include "AQLSupport.h"
#include "ExpressionWriter.h"

namespace TypeRQInternal
{
   using namespace TypeRQ;

   struct MySQLContext
   {
      MYSQL *mysql;
      SQLBackendCapabilities caps;


      MySQLContext() : mysql(0)
      {
         caps.createTableSuffix=" ENGINE=InnoDB";
      }

      void checkMysqlError()
      {
         if (mysql_errno(mysql))
         {
            // error
            throw AQLException("MySQL error: %s", mysql_error(mysql));
         }
      }

      void requireConnection()
      {
         if (!mysql)
         {
            throw AQLException("MySQL: must be connected");
         }
      }

      void executeStatement(const std::string &statement)
      {
         requireConnection();

         mysql_real_query(mysql, statement.data(), statement.length());
         checkMysqlError();

         MYSQL_RES *res=mysql_use_result(mysql);
         checkMysqlError();

         if (res!=NULL)
         {
            while (mysql_fetch_row(res)!=NULL) ; // read all rows
            mysql_free_result(res); 
         }
      }
   };

   /**
    * Result iterator over the result set of the most recently executed query
    * on a MySQLContext.
    *
    * The result set is fetched lazily with mysql_store_result() on the first
    * call to getColCount() or nextRow(), and is freed in the destructor.
    * The column accessors (getString/getInt/getDouble/isNull) read the
    * current row and must only be used after nextRow() has returned true;
    * column indices are not range-checked.
    */
   class MySQLResultIteratorImpl : public SQLResultIteratorImpl
   {
   private:
      MySQLContext &ctx;
      bool initialized;             // true once ensureInitialized() has run
      MYSQL_RES *result;            // 0 when the statement produced no result set
      MYSQL_ROW row;                // current row, 0 before the first nextRow()
      unsigned long *row_lengths;   // byte lengths of the columns of 'row'

      std::map<std::string, int> fieldNames;  // column name -> column index

      std::string tmp;              // backing storage for getString()
      int numFields;

      // The object owns 'result'; copying would free it twice.
      // Declared but intentionally not defined.
      MySQLResultIteratorImpl(const MySQLResultIteratorImpl &);
      MySQLResultIteratorImpl &operator=(const MySQLResultIteratorImpl &);

   protected:
      /**
       * Fetches the result set and the column-name index on first use.
       * Throws AQLException (through checkMysqlError) when a result set was
       * expected but could not be retrieved.
       */
      void ensureInitialized()
      {
         if (!initialized)
         {
            result = mysql_store_result(ctx.mysql);
            if (result)  // there are rows
            {
               numFields = 0;

               int pos=0;
               MYSQL_FIELD *field;
               while ((field=mysql_fetch_field(result))!=NULL)
               {
                  ++numFields;
                  // insert() keeps the first index if a column name repeats
                  fieldNames.insert(std::map<std::string, int>::value_type(field->name, pos++));
               }

               // rows are retrieved in nextRow(); the destructor frees the result
            }
            else  // mysql_store_result() returned nothing; should it have?
            {
               if(mysql_field_count(ctx.mysql) != 0)
               {
                  // mysql_store_result() should have returned data
                  ctx.checkMysqlError();
               }
            }
            initialized=true;
         }
      }

   public:
      /**
       * @param _ctx  context whose handle has just executed the query; it is
       *              held by reference and must outlive the iterator
       */
      MySQLResultIteratorImpl(MySQLContext &_ctx) :
         ctx(_ctx), initialized(false), result(0), row(0), row_lengths(0), numFields(0) {}

      ~MySQLResultIteratorImpl()
      {
         if (result)
         {
            mysql_free_result(result);
         }
      }

      /** Number of columns in the result set; 0 if there is no result set. */
      int getColCount()
      {
         ensureInitialized();
         //print(OL_DEBUG, "Col count=%d\n", numFields);
         return numFields;
      }

      /** Advances to the next row. Returns false when there is none. */
      bool nextRow()
      {
         ensureInitialized();
         if (!result) return false; // no result - not a select statement, perhaps

         row=mysql_fetch_row(result);
         if (!row) return false;

         row_lengths=mysql_fetch_lengths(result);
         return true;
      }

      /**
       * Returns the column of the current row as a string; an SQL NULL yields
       * the empty string. The reference points to an internal buffer that is
       * overwritten by the next getString/getInt/getDouble call.
       */
      const std::string &getString(int col)
      {
         char *colData=row[col];
         //unsigned long len=row_lengths[col];
         //print(OL_DEBUG, "getString(%d), colData=%p, length=%lu\n", col, colData, len);
         if (colData!=NULL)
         {
            // use the explicit length: column data may contain NUL bytes
            tmp.assign(colData, colData+row_lengths[col]);
         }
         else {
            tmp.clear();
         }
         return tmp;
      }

      /** As getString(int); std::map::at throws std::out_of_range for an unknown name. */
      const std::string &getString(const std::string &col)
      {
         int pos=fieldNames.at(col);
         return getString(pos);
      }

      /**
       * Parses the column as an integer. Returns 0 when the column is NULL,
       * empty or does not start with a number.
       */
      int64_t getInt(int col)
      {
         getString(col);

         std::stringstream s(tmp);
         // Initialized explicitly: extraction from an empty stream leaves the
         // target untouched, which used to return an indeterminate value.
         int64_t ret=0;
         s >> ret;
         return ret;
      }

      /** As getInt(int); std::map::at throws std::out_of_range for an unknown name. */
      int64_t getInt(const std::string &col)
      {
         int pos=fieldNames.at(col);
         return getInt(pos);
      }

      /**
       * Parses the column as a double. Returns 0 when the column is NULL,
       * empty or does not start with a number.
       */
      double getDouble(int col)
      {
         getString(col);

         std::stringstream s(tmp);
         double ret=0;  // see getInt(int)
         s >> ret;
         return ret;
      }

      /** As getDouble(int); std::map::at throws std::out_of_range for an unknown name. */
      double getDouble(const std::string &col)
      {
         int pos=fieldNames.at(col);
         return getDouble(pos);
      }

      /** True when the column of the current row is SQL NULL. */
      bool isNull(int col)
      {
         return row[col] == NULL;
      }
   };

   /**
    * Transaction implementation that issues plain SQL transaction statements
    * through a shared MySQLContext and tracks whether a transaction is open.
    */
   class MySQLTransactionImpl : public TransactionImpl, protected AQLDebugBase
   {
   private:
      MySQLContext &ctx;   // held by reference; not owned
      bool active;         // true between begin() and commit()/rollback()

   public:
      MySQLTransactionImpl(MySQLContext &_ctx)
         : ctx(_ctx),
           active(false)
      {
      }

      /** Whether begin() has succeeded without a later commit() or rollback(). */
      bool isActive() const { return active; }

      /** Opens a transaction; the flag is set only if the statement succeeded. */
      void begin()
      {
         ctx.executeStatement("START TRANSACTION");
         active = true;
      }

      /** Commits; the flag is cleared only if the statement succeeded. */
      void commit()
      {
         ctx.executeStatement("COMMIT");
         active = false;
      }

      /** Rolls back; the flag is cleared only if the statement succeeded. */
      void rollback()
      {
         ctx.executeStatement("ROLLBACK");
         active = false;
      }
   };
}

namespace {
   using namespace TypeRQ;

   class SQLConvertMapping : public SQLFunctionSelectionMapping
   {
   protected:
      const std::string sqlTypeName;

      void startSQLExp(const AQLFunctionExpr &expr, SQLExpressionWriter &ew, const SQLFunctionType &functionType)
      {
         ew.startFunction("CONVERT");
      }

      void endSQLExp(const AQLFunctionExpr &expr, SQLExpressionWriter &ew, const SQLFunctionType &functionType)
      {
         ew.nextTerm();
         ew << sqlTypeName;
         ew.endOp();
      }


   public:
      SQLConvertMapping(const std::string &_sparqlFn, AQLTypeSet::ExprType _returnType, const char *_sqlTypeName) :
         SQLFunctionSelectionMapping(_sparqlFn),
         sqlTypeName(_sqlTypeName)
      {
         SQLFunctionType *functionType=new SQLFunctionType(_returnType);
         (*functionType) << AQLTypeSet::ANY;

         SQLFunctionSelectionMapping::operator << (functionType);
      }    
   };

   /**
    * Maps a function to str_to_date(<arg>, <format>), appending a fixed
    * format string as the last term of the call.
    */
   class MySQLToDatetimeMapping : public SQLFixedFunctionMapping
   {
   protected:
      /** Adds the format string term, then lets the base class finish the call. */
      void endSQLExp(const AQLFunctionExpr &expr, SQLExpressionWriter &ew, const SQLFunctionType &ft)
      {
         ew.nextTerm();
         ew << "'%Y-%m-%dT%H:%i:%s.%fZ'";

         SQLFixedFunctionMapping::endSQLExp(expr, ew, ft);
      }

   public:
      MySQLToDatetimeMapping(const char *_sparqlFn)
         : SQLFixedFunctionMapping(_sparqlFn, "str_to_date")
      {
         // NOTE(review): SQLConvertMapping constructs SQLFunctionType from its
         // return type and appends ANY; here the type is constructed from ANY
         // and DATETIME is appended. The two may be swapped - confirm against
         // SQLFunctionType before changing.
         (*this) << &(*new SQLFunctionType(AQLTypeSet(AQLTypeSet::ANY))
                      << AQLTypeSet(AQLTypeSet::DATETIME));
      }
   };


   /**
    * Builds the function map of the MySQL backend: the MySQL-specific
    * mappings listed here, combined with the standard SQL function map.
    *
    * @param backend  currently unused
    * @return newly allocated map; SQLBackendMySQL deletes it in its destructor
    */
   SQLFunctionMap *instantiateMysqlFunctionMap(SQLBackend &backend)
   {
      // Zero-terminated list of the MySQL-specific mappings.
      SQLFunctionMapping *backendSpecific[]={
         new SQLConvertMapping("builtin:to-integer", AQLTypeSet::INTEGER, "SIGNED INTEGER"),
         new MySQLToDatetimeMapping("builtin:to-datetime"),
         0,
      };

      SQLFunctionMap *combined=new SQLDelegatingFunctionMap(getStandardSQLFunctionMap(), backendSpecific);
      return combined;
   }

}

namespace TypeRQ {
   using namespace TypeRQInternal;

   /**
    * Creates a disconnected backend with its function map and a fresh
    * MySQL context. Both are released in the destructor.
    */
   SQLBackendMySQL::SQLBackendMySQL()
      : functionMap(instantiateMysqlFunctionMap(*this)),
        mysqlContext(new MySQLContext)
   {
   }

   /** Closes the connection, if any, and releases the owned objects. */
   SQLBackendMySQL::~SQLBackendMySQL()
   {
      // Close the handle first: the context that holds it is deleted below.
      disconnect();

      delete functionMap;
      delete mysqlContext;
   }

   /** Returns the backend identifier "mysql". */
   const std::string &SQLBackendMySQL::getName()
   {
      // Function-local static so the returned reference stays valid.
      static const std::string backendName("mysql");
      return backendName;
   }

   /** Returns the function map created in the constructor. */
   SQLFunctionMap &SQLBackendMySQL::getFunctionMap()
   {
      SQLFunctionMap &map = *functionMap;
      return map;
   }

   void SQLBackendMySQL::connect(const std::string &connectionString)
   {
      disconnect();

      mysqlContext->mysql=mysql_init(NULL);
      mysql_options(mysqlContext->mysql, MYSQL_SET_CHARSET_NAME, "utf8");

      const char *host=NULL;
      const char *user=NULL;
      const char *pass=NULL;
      const char *db=NULL;
      unsigned int port=0;

      std::map<std::string, std::string> params;
      parseConnectionString(connectionString, params);

      std::map<std::string, std::string>::const_iterator param_i;
      param_i=params.find("host");
      if (param_i!=params.end()) host=param_i->second.c_str();

      param_i=params.find("user");
      if (param_i!=params.end()) user=param_i->second.c_str();

      param_i=params.find("pass");
      if (param_i!=params.end()) pass=param_i->second.c_str();

      param_i=params.find("db");
      if (param_i!=params.end()) db=param_i->second.c_str();

      param_i=params.find("port");
      if (param_i!=params.end())
      {
         std::istringstream iss(param_i->second);
         iss >> port;
      }


      mysql_real_connect(mysqlContext->mysql,
                         host, user, pass, db, port, NULL,
                         CLIENT_IGNORE_SIGPIPE);
      try {
         mysqlContext->checkMysqlError();

         mysql_autocommit(mysqlContext->mysql, 0);
         mysqlContext->checkMysqlError();
      }
      catch (...) {
         disconnect();
         throw;
      }                             
   }

   /** Closes the connection handle if there is one; otherwise does nothing. */
   void SQLBackendMySQL::disconnect()
   {
      if (mysqlContext->mysql)
      {
         mysql_close(mysqlContext->mysql);
         // Reset so requireConnection() reports the disconnected state.
         mysqlContext->mysql=0;
      }
   }

   void SQLBackendMySQL::requireConnection()
   {
      mysqlContext->requireConnection();
   }

   void SQLBackendMySQL::executeStatement(const std::string &statement)
   {
      mysqlContext->executeStatement(statement);
   }

   /**
    * Sends a query to the server and returns an iterator over its result.
    * The result set itself is fetched lazily by the iterator implementation.
    *
    * @throws AQLException if not connected or if the query fails
    */
   SQLResultIterator SQLBackendMySQL::executeQuery(const std::string &statement)
   {
      requireConnection();

      MySQLContext &context = *mysqlContext;
      mysql_real_query(context.mysql, statement.data(), statement.length());
      context.checkMysqlError();

      // The iterator implementation is handed to the returned SQLResultIterator.
      SQLResultIteratorImpl *iteratorImpl = new MySQLResultIteratorImpl(context);
      return SQLResultIterator(iteratorImpl);
   }

   std::string SQLBackendMySQL::escapeString(const std::string &stringLiteral)
   {
      requireConnection();
      char buffer[stringLiteral.size()*2+1];

      mysql_real_escape_string(mysqlContext->mysql,
                               buffer,
                               stringLiteral.data(),
                               stringLiteral.size());
      return std::string(buffer);
   }

   /**
    * Starts a new transaction on the current connection.
    *
    * @return a Transaction wrapping the started transaction implementation
    * @throws AQLException if not connected or if the transaction cannot be
    *         started
    */
   Transaction SQLBackendMySQL::newTransaction()
   {
      MySQLTransactionImpl *impl=new MySQLTransactionImpl(*mysqlContext);
      try {
         impl->begin();
      }
      catch (...) {
         // begin() throws when there is no connection or the statement
         // fails; nothing has taken the implementation over yet, so it
         // has to be released here.
         delete impl;
         throw;
      }
      return Transaction(impl);
   }

   /** Returns the capabilities stored in the MySQL context. */
   const SQLBackendCapabilities &SQLBackendMySQL::getCapabilities()
   {
      const SQLBackendCapabilities &capabilities = mysqlContext->caps;
      return capabilities;
   }

   /**
    * Returns the SQL column type name for an expression type.
    * DATETIME and INTEGER are named here; every other type is delegated to
    * the SQLBackend base class together with sizeLimit.
    */
   std::string SQLBackendMySQL::getSQLNameForType(AQLTypeSet::ExprType exprType, size_t sizeLimit)
   {
      if (exprType == AQLTypeSet::DATETIME)
      {
         return "DATETIME";
      }

      if (exprType == AQLTypeSet::INTEGER)
      {
         return "INTEGER";
      }

      return SQLBackend::getSQLNameForType(exprType, sizeLimit);
   }
}
