/*
 *  Type-ARQuE - the experimental SPARQL to SQL translator.
 *  Copyright (C) 2010  Sami Kiminki / Aalto University
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <ostream>
#include <sstream>

#include "AQLException.h"
#include "AQLSupport.h"
#include "AQLModel.h"
#include "FormatUtils.h"
#include "ListPrettyPrinter.h"
#include "SQLBackendFunctions.h"

namespace {

   using namespace TypeRQ;

   void rewriteWalk(AQLExpr *&expr, AQLExpressionRewriter &rewriter, bool childrenFirst);

   /*
    * One dispatch step of the expression rewrite walk.
    *
    * Identifies the concrete class of 'expr' with dynamic_cast and then,
    * depending on the two flags:
    *   dorecurse - walks the node's child expressions first, each through
    *               the three-argument rewriteWalk();
    *   dorewrite - hands the node to the matching rewriter.rewrite() overload.
    *
    * If the rewriter returned a replacement, 'expr' is overwritten with it and
    * true is returned; otherwise 'expr' is left unchanged and false is
    * returned. The node being replaced is not deleted here.
    *
    * Throws AQLException when 'expr' belongs to an expression class that is
    * not handled below. 'expr' is dereferenced without a null check (the
    * three-argument overload filters out null expressions before calling).
    */
   bool rewriteWalk(AQLExpr *&expr, AQLExpressionRewriter &rewriter, bool childrenFirst,
                    bool dorewrite, bool dorecurse)
   {
      AQLExpr *replacement=0;

      if (AQLJunctionCriterion *junction=dynamic_cast<AQLJunctionCriterion *>(expr))
      {
         if (dorecurse)
         {
            std::list<AQLExpr *>::iterator term=junction->terms.begin();
            while (term!=junction->terms.end())
            {
               rewriteWalk(*term, rewriter, childrenFirst);
               ++term;
            }
         }
         if (dorewrite) replacement=rewriter.rewrite(*junction);
      }
      else if (AQLComparisonCriterion *comparison=dynamic_cast<AQLComparisonCriterion *>(expr))
      {
         if (dorecurse)
         {
            if (comparison->left) rewriteWalk(comparison->left, rewriter, childrenFirst);
            if (comparison->right) rewriteWalk(comparison->right, rewriter, childrenFirst);
         }
         if (dorewrite) replacement=rewriter.rewrite(*comparison);
      }
      else if (AQLPropertyExpr *property=dynamic_cast<AQLPropertyExpr *>(expr))
      {
         // not recursed into here
         if (dorewrite) replacement=rewriter.rewrite(*property);
      }
      else if (AQLLiteralExpr *literal=dynamic_cast<AQLLiteralExpr *>(expr))
      {
         // not recursed into here
         if (dorewrite) replacement=rewriter.rewrite(*literal);
      }
      else if (AQLNullExpr *nullNode=dynamic_cast<AQLNullExpr *>(expr))
      {
         // not recursed into here
         if (dorewrite) replacement=rewriter.rewrite(*nullNode);
      }
      else if (AQLFunctionExpr *function=dynamic_cast<AQLFunctionExpr *>(expr))
      {
         if (dorecurse)
         {
            AQLFunctionExpr::arg_list_type::iterator arg=function->arguments.begin();
            while (arg!=function->arguments.end())
            {
               rewriteWalk(*arg, rewriter, childrenFirst);
               ++arg;
            }
         }
         if (dorewrite) replacement=rewriter.rewrite(*function);
      }
      else if (AQLNotExpression *negation=dynamic_cast<AQLNotExpression *>(expr))
      {
         if (dorecurse && negation->expr) rewriteWalk(negation->expr, rewriter, childrenFirst);
         if (dorewrite) replacement=rewriter.rewrite(*negation);
      }
      else if (AQLTypecastExpression *typecast=dynamic_cast<AQLTypecastExpression *>(expr))
      {
         if (dorecurse && typecast->expr) rewriteWalk(typecast->expr, rewriter, childrenFirst);
         if (dorewrite) replacement=rewriter.rewrite(*typecast);
      }
      else if (AQLCustomExpr *custom=dynamic_cast<AQLCustomExpr *>(expr))
      {
         // not recursed into here
         if (dorewrite) replacement=rewriter.rewrite(*custom);
      }
      else
      {
         throw AQLException("rewriteWalk: unhandled expr type %s", expr->getTypeName());
      }

      if (!dorewrite || replacement==0) return false;

      expr=replacement;
      return true;
   }

   /*
    * Rewrites the expression tree rooted at 'expr' in place.
    *
    * childrenFirst==true (bottom-up): the children are walked first, then the
    * node itself is offered to the rewriter exactly once.
    *
    * childrenFirst==false (top-down): the node is offered to the rewriter
    * again and again until the rewriter returns no replacement; only then are
    * the children of the resulting node walked.
    * NOTE(review): a rewriter that always returns a replacement makes the
    * top-down loop run forever.
    *
    * A null expression is silently ignored.
    */
   void rewriteWalk(AQLExpr *&expr, AQLExpressionRewriter &rewriter, bool childrenFirst)
   {
      if (expr==0) return; // nothing to rewrite

      if (!childrenFirst)
      {
         bool changed;
         do {
            changed=rewriteWalk(expr, rewriter, childrenFirst, true, false);
         } while (changed);

         rewriteWalk(expr, rewriter, childrenFirst, false, true);
         return;
      }

      rewriteWalk(expr, rewriter, childrenFirst, true, true);
   }

   /*
    * Rewrites the criterion of 'join' and, recursively, the criteria of every
    * join group nested inside it, at any depth.
    */
   void rewriteWalk(AQLJoinGroup &join, AQLExpressionRewriter &rewriter, bool childrenFirst)
   {
      rewriteWalk(join.criterion, rewriter, childrenFirst);

      // Recurse into the whole nested join group, not only its criterion.
      // Otherwise joins nested two or more levels deep would never be
      // rewritten.
      for (AQLJoinGroup::nested_join_list::iterator i=join.nestedJoins.begin();
           i!=join.nestedJoins.end(); ++i)
      {
         AQLJoinGroup *nestedJoin=*i;
         rewriteWalk(*nestedJoin, rewriter, childrenFirst);
      }
   }
}

namespace TypeRQ {


   /*
    * The visitor prints through a ListPrettyPrinter wrapping '_os'. The
    * printer is heap-allocated here and deleted in the destructor.
    * NOTE(review): ownership is held through a reference member, so copying an
    * AQLPrinterVisitor (unless its header forbids copying) would delete the
    * printer twice.
    */
   AQLPrinterVisitor::AQLPrinterVisitor(std::ostream &_os)
      : os(*(new ListPrettyPrinter(_os)))
   {
   }

   AQLPrinterVisitor::~AQLPrinterVisitor()
   {
      // release the printer allocated by the constructor
      delete &os;
   }

   /*
    * Prints 's' as a double-quoted string literal.
    *
    * Newline, carriage return, double quote and backslash are written as the
    * escapes \n, \r, \" and \\. Any other byte below 0x20 is written as \xHH.
    * All remaining bytes, including bytes >= 0x80, are written verbatim.
    * Escaping the backslash itself keeps the output unambiguous: without it, a
    * literal backslash followed by 'n' would print exactly like an escaped
    * newline.
    */
   void AQLPrinterVisitor::printString(const std::string &s) const
   {
      os << '"';
      for (std::string::const_iterator i=s.begin(); i!=s.end(); ++i)
      {
         // go through unsigned char so bytes >= 0x80 compare as large values
         const unsigned int c=static_cast<unsigned char>(*i);

         switch (c)
         {
            case '\n':
               os << "\\n";
               break;
            case '\r':
               os << "\\r";
               break;
            case '"':
               os << "\\\"";
               break;
            case '\\':
               os << "\\\\";
               break;
            default:
               if (c>=0x20)
               {
                  os << char(c);
               }
               else {
                  os << "\\x";
                  os << intToHexString(c, 2);
               }
               break;
         }
      }
      os << '"';
   }

   /*
    * Opens an "and" level for a conjunction or an "or" level for a
    * disjunction. Any other junction type throws AQLException, as the
    * comparison printer does for unknown operators. Skipping silently would
    * leave visitAfterChildren() closing a level that was never opened.
    */
   void AQLPrinterVisitor::visitBeforeChildren(AQLJunctionCriterion &junction)
   {
      switch (junction.junctionType)
      {
         case AQLJunctionCriterion::CONJUNCTION:
            os.openLevel("and");
            break;
         case AQLJunctionCriterion::DISJUNCTION:
            os.openLevel("or");
            break;
         default:
            throw AQLException("Unknown junction type %d", int(junction.junctionType));
      }
   }
   void AQLPrinterVisitor::visitBetweenChildren(AQLJunctionCriterion &, int)
   {
      // nothing is printed between the terms of a junction
   }
   void AQLPrinterVisitor::visitAfterChildren(AQLJunctionCriterion &)
   {
      // closes the "and"/"or" level opened in visitBeforeChildren()
      os.closeLevel();
   }
   /*
    * Returns the printable name of a triple property: "subject", "predicate"
    * or "object". Throws AQLException for any other value.
    */
   const char *getNameForAQLTripleProperty(AQLPropertyExpr::Property p)
   {
      if (p==AQLPropertyExpr::SUBJECT)   return "subject";
      if (p==AQLPropertyExpr::PREDICATE) return "predicate";
      if (p==AQLPropertyExpr::OBJECT)    return "object";

      throw AQLException("Unknown property part %d", int(p));
   }
   /*
    * Returns the printable name of a single expression type. UNSET yields
    * "!!UNSET!!" and values outside the known set yield "!!BAD TYPE!!".
    * Never throws.
    */
   const char *getNameForExprType(AQLTypeSet::ExprType type) {
      const char *name="!!BAD TYPE!!";
      switch (type)
      {
         case AQLTypeSet::UNSET:     name="!!UNSET!!"; break;
         case AQLTypeSet::IRI:       name="IRI";       break;
         case AQLTypeSet::STRING:    name="string";    break;
         case AQLTypeSet::INTEGER:   name="integer";   break;
         case AQLTypeSet::DOUBLE:    name="double";    break;
         case AQLTypeSet::BOOLEAN:   name="boolean";   break;
         case AQLTypeSet::DATETIME:  name="datetime";  break;
         case AQLTypeSet::REFERENCE: name="reference"; break;
         case AQLTypeSet::ANY:       name="any";       break;
         default:                    break;
      }
      return name;
   }

   std::ostream &operator << (std::ostream &os, AQLTypeSet typeSet)
   {
      os << '(';
      bool first=true;
      for (AQLTypeSet::iterator i=typeSet.begin(); i!=typeSet.end(); ++i)
      {
         AQLTypeSet::ExprType exprType=*i;
         if (first)
         {
            first=false;
         }
         else {
            os << ' ';
         }
         os << getNameForExprType(exprType);
      }

      os << ')';
      return os;
   }

   /*
    * Writes a type set to a ListPrettyPrinter as an unnamed level containing
    * one term per type name.
    */
   ListPrettyPrinter &operator << (ListPrettyPrinter &os, AQLTypeSet typeSet)
   {
      os.openLevel("");

      AQLTypeSet::iterator it=typeSet.begin();
      while (it!=typeSet.end())
      {
         const AQLTypeSet::ExprType type=*it;
         os.newTerm();
         os << getNameForExprType(type);
         ++it;
      }

      os.closeLevel();
      return os;
   }

   /*
    * Returns the same text the std::ostream operator<< writes for 'typeSet',
    * e.g. "(integer double)".
    */
   std::string toString(AQLTypeSet typeSet)
   {
      std::ostringstream out;
      out << typeSet;
      return out.str();
   }

   /*
    * Prints a "property" level holding the expression's type set, the quoted
    * join name and the triple-part name (subject/predicate/object).
    */
   void AQLPrinterVisitor::visit(AQLPropertyExpr &expr)
   {
      os.openLevel("property");
      os << expr.getExprTypeSet();

      os.newTerm();
      printString(expr.joinName);

      os.newTerm();
      os << getNameForAQLTripleProperty(expr.property);

      os.closeLevel();
   }
   /*
    * Prints a "literal" level holding the literal's type name followed by
    * its value. IRI, string and datetime values are quoted with
    * printString(). Throws AQLException for any other literal type.
    */
   void AQLPrinterVisitor::visit(AQLLiteralExpr &expr)
   {
      const AQLTypeSet::ExprType type=expr.literalType;

      os.openLevel("literal");
      os.newTerm();
      os << getNameForExprType(type);
      os.newTerm();

      if (type==AQLTypeSet::IRI || type==AQLTypeSet::STRING || type==AQLTypeSet::DATETIME)
      {
         printString(expr.stringLiteral);
      }
      else if (type==AQLTypeSet::BOOLEAN)
      {
         os << (expr.booleanLiteral?"true":"false");
      }
      else if (type==AQLTypeSet::INTEGER)
      {
         os << expr.intLiteral;
      }
      else if (type==AQLTypeSet::DOUBLE)
      {
         os << expr.doubleLiteral;
      }
      else
      {
         throw AQLException("AQLPrinterVisitor::visit(AQLLiteralExpr &): unhandled literal type %s",
                            toString(expr.getExprTypeSet()).c_str());
      }

      os.closeLevel();
   }

   /*
    * Prints an empty "null" level.
    */
   void AQLPrinterVisitor::visit(AQLNullExpr &)
   {
      os.openLevel("null");
      os.closeLevel();
   }

   /*
    * Opens a "comp" level and prints the comparison operator as a flag:
    * -eq, -ne, -lt, -le, -ge or -gt. Throws AQLException for an unknown
    * comparison type (after the level has been opened).
    */
   void AQLPrinterVisitor::visitBeforeChildren(AQLComparisonCriterion &c)
   {
      os.openLevel("comp");

      const char *flag=0;
      switch (c.comparisonType)
      {
         case AQLComparisonCriterion::EQUAL:            flag="-eq"; break;
         case AQLComparisonCriterion::NOT_EQUAL:        flag="-ne"; break;
         case AQLComparisonCriterion::LESS:             flag="-lt"; break;
         case AQLComparisonCriterion::LESS_OR_EQUAL:    flag="-le"; break;
         case AQLComparisonCriterion::GREATER_OR_EQUAL: flag="-ge"; break;
         case AQLComparisonCriterion::GREATER:          flag="-gt"; break;
         default:
            throw AQLException("Unknown comparison type %d", int(c.comparisonType));
      }
      os << flag;
   }
   void AQLPrinterVisitor::visitBetweenChildren(AQLComparisonCriterion &)
   {
      // nothing is printed between the children of a comparison
   }
   void AQLPrinterVisitor::visitAfterChildren(AQLComparisonCriterion &)
   {
      // closes the "comp" level
      os.closeLevel();
   }
   // A negation is printed as a "not" level wrapping its operand.
   void AQLPrinterVisitor::visitBeforeChildren(AQLNotExpression &)
   {
      os.openLevel("not");
   }
   void AQLPrinterVisitor::visitAfterChildren(AQLNotExpression &)
   {
      // closes the "not" level
      os.closeLevel();
   }
   /*
    * Opens a "function" level with the quoted function name and the
    * expression's type set. When a variant has been chosen, a
    * " ### variant ..." annotation is appended, ended with a newline, and
    * followed by a printIndent() call.
    */
   void AQLPrinterVisitor::visitBeforeChildren(AQLFunctionExpr &fexpr)
   {
      os.openLevel("function");
      printString(fexpr.functionName);
      os << fexpr.getExprTypeSet();

      if (!fexpr.chosenVariant) return;

      os << " ### variant " << toString(*fexpr.chosenVariant) << "\n";
      os.printIndent();
   }
   void AQLPrinterVisitor::visitBetweenChildren(AQLFunctionExpr &, int)
   {
      // nothing is printed between function arguments
   }
   void AQLPrinterVisitor::visitAfterChildren(AQLFunctionExpr &)
   {
      // closes the "function" level
      os.closeLevel();
   }
   /*
    * Opens a "typecast" level whose first term is the name of the target
    * type.
    */
   void AQLPrinterVisitor::visitBeforeChildren(AQLTypecastExpression &expr)
   {
      const char *targetName;

      os.openLevel("typecast");
      os.newTerm();
      targetName=getNameForExprType(expr.toType);
      os << targetName;
   }
   void AQLPrinterVisitor::visitAfterChildren(AQLTypecastExpression &)
   {
      // closes the "typecast" level
      os.closeLevel();
   }
   /*
    * Opens a "join" level holding the join type ("inner" or "left") and an
    * unnamed level with the join's names. The names are quoted and escaped
    * with printString(), the same way visit(AQLPropertyExpr &) prints join
    * names. Previously they were wrapped in raw quotes, so embedded quotes
    * or control characters came out unescaped.
    */
   void AQLPrinterVisitor::visitBeforeChildren(AQLJoinGroup &j)
   {
      os.openLevel("join");
      os.newTerm();
      switch (j.joinType)
      {
         case AQLJoinGroup::INNER:
            os << "inner";
            break;
         case AQLJoinGroup::LEFT_OUTER:
            os << "left";
            break;
      }
      os.openLevel("");
      for (std::list<std::string>::iterator i=j.names.begin();
           i!=j.names.end(); ++i)
      {
         os.newTerm();
         printString(*i);
      }
      os.closeLevel();
   }
   void AQLPrinterVisitor::visitBeforeNestedJoins(AQLJoinGroup &)
   {
      // nothing is printed before nested joins
   }
   void AQLPrinterVisitor::visitAfterChildren(AQLJoinGroup &jg)
   {
      // the inference map (if any) is printed inside the "join" level
      printTypeInferenceMap(jg);
      os.closeLevel();
   }
   void AQLPrinterVisitor::visitBeforeChildren(AQLSelect &s)
   {
      os.openLevel("select");
      os.newTerm();
      printString(s.label);
   }
   void AQLPrinterVisitor::visitBetweenChildren(AQLSelect &, int)
   {
   }
   void AQLPrinterVisitor::visitAfterChildren(AQLSelect &s)
   {
      os.closeLevel();
   }
   void AQLPrinterVisitor::visitBeforeChildren(AQLSort &s)
   {
      os.openLevel("sort");
      os.newTerm();
      if (s.ascending)
         os << "ascending";
      else
         os << "descending";
   }
   void AQLPrinterVisitor::visitAfterChildren(AQLSort &s)
   {
      os.closeLevel();
   }
   /*
    * Opens the top-level "aql-query" level and prints:
    *  - an unnamed level with the query's names, quoted and escaped with
    *    printString() like every other printed string (previously they were
    *    wrapped in raw quotes without escaping);
    *  - an empty "distinct" level when the query is DISTINCT.
    */
   void AQLPrinterVisitor::visitBeforeChildren(AQLQuery &q)
   {
      os.openLevel("aql-query");

      os.openLevel("");
      for (AQLQuery::name_list::iterator i=q.names.begin();
           i!=q.names.end(); ++i)
      {
         os.newTerm();
         printString(*i);
      }
      os.closeLevel();

      if (q.distinct) {
         os.openLevel("distinct");
         os.closeLevel();
      }
   }

   /*
    * Prints the optional row limit and offset (only when >= 0) and the type
    * inference map, then closes the "aql-query" level.
    */
   void AQLPrinterVisitor::visitAfterChildren(AQLQuery &query)
   {
      if (query.maxRows>=0)
      {
         os.openLevel("result-max-rows");
         os.newTerm();
         os << query.maxRows;
         os.closeLevel();
      }
      if (query.rowOffset>=0)
      {
         os.openLevel("result-row-offset");
         os.newTerm();
         os << query.rowOffset;
         os.closeLevel();
      }

      printTypeInferenceMap(query);

      os.closeLevel();
   }
   void AQLPrinterVisitor::visitBeforeChildren(AQLCustomExpr &expr)
   {
      os.openLevel("custom");
      os << expr.getExprTypeSet();
      os.newTerm();
      os << expr.toString();
   }
   void AQLPrinterVisitor::visitBetweenChildren(AQLCustomExpr &expr, int pos)
   {
   }
   void AQLPrinterVisitor::visitAfterChildren(AQLCustomExpr &expr)
   {
      os.closeLevel();
   }
   // Query-section hooks: none of them prints anything, except the criterion
   // section, which is wrapped in its own "criterion" level.
   void AQLPrinterVisitor::visitBeforeSelects(AQLQuery &) {}
   void AQLPrinterVisitor::visitAfterSelects(AQLQuery &) {}

   void AQLPrinterVisitor::visitBeforeJoins(AQLQuery &) {}
   void AQLPrinterVisitor::visitAfterJoins(AQLQuery &) {}

   void AQLPrinterVisitor::visitBeforeCriterion(AQLQuery &)
   {
      os.openLevel("criterion");
   }
   void AQLPrinterVisitor::visitAfterCriterion(AQLQuery &)
   {
      os.closeLevel();
   }

   void AQLPrinterVisitor::visitBeforeSorts(AQLQuery &) {}
   void AQLPrinterVisitor::visitBetweenSorts(AQLQuery &, int) {}
   void AQLPrinterVisitor::visitAfterSorts(AQLQuery &) {}

   void AQLPrinterVisitor::printTypeInferenceMap(const AQLJoinGroupLike &jgl) const
   {
      if (jgl.inferenceMap.empty()) return; // print nothing
      os << "\n# Inference map = {\n";
      for (AQLJoinGroupLike::property_type_map::const_iterator i=jgl.inferenceMap.begin();
           i!=jgl.inferenceMap.end(); ++i)
      {
         const TripleProperty &tp=i->first;
         const AQLTypeSet &types=i->second;

         os << "#  " << tp.joinName << '.' << getNameForAQLTripleProperty(tp.property) << ": ";
         os << toString(types) << '\n';
      }
      os.printIndent();
   }


   void AQLOptionalVisitor::visitBeforeChildren(AQLJunctionCriterion &) {}
   void AQLOptionalVisitor::visitBetweenChildren(AQLJunctionCriterion &, int) {}
   void AQLOptionalVisitor::visitAfterChildren(AQLJunctionCriterion &) {}
   void AQLOptionalVisitor::visit(AQLPropertyExpr &) {}
   void AQLOptionalVisitor::visit(AQLLiteralExpr &) {}
   void AQLOptionalVisitor::visit(AQLNullExpr &) {}
   void AQLOptionalVisitor::visitBeforeChildren(AQLFunctionExpr &) {}
   void AQLOptionalVisitor::visitBetweenChildren(AQLFunctionExpr &, int) {}
   void AQLOptionalVisitor::visitAfterChildren(AQLFunctionExpr &) {}
   void AQLOptionalVisitor::visitBeforeChildren(AQLTypecastExpression &) {}
   void AQLOptionalVisitor::visitAfterChildren(AQLTypecastExpression &) {}
   void AQLOptionalVisitor::visitBeforeChildren(AQLComparisonCriterion &) {}
   void AQLOptionalVisitor::visitBetweenChildren(AQLComparisonCriterion &) {}
   void AQLOptionalVisitor::visitAfterChildren(AQLComparisonCriterion &) {}
   void AQLOptionalVisitor::visitBeforeChildren(AQLNotExpression &) {}
   void AQLOptionalVisitor::visitAfterChildren(AQLNotExpression &) {}
   void AQLOptionalVisitor::visitBeforeChildren(AQLJoinGroup &) {}
   void AQLOptionalVisitor::visitBeforeNestedJoins(AQLJoinGroup &) {}
   void AQLOptionalVisitor::visitAfterChildren(AQLJoinGroup &) {}
   void AQLOptionalVisitor::visitBeforeChildren(AQLSelect &) {}
   void AQLOptionalVisitor::visitBetweenChildren(AQLSelect &, int) {}
   void AQLOptionalVisitor::visitAfterChildren(AQLSort &) {}
   void AQLOptionalVisitor::visitBeforeChildren(AQLSort &) {}
   void AQLOptionalVisitor::visitAfterChildren(AQLQuery &) {}
   void AQLOptionalVisitor::visitBeforeChildren(AQLQuery &) {}
   void AQLOptionalVisitor::visitAfterChildren(AQLSelect &) {}
   void AQLOptionalVisitor::visitBeforeSelects(AQLQuery &) {}
   void AQLOptionalVisitor::visitAfterSelects(AQLQuery &) {}
   void AQLOptionalVisitor::visitBeforeJoins(AQLQuery &) {}
   void AQLOptionalVisitor::visitAfterJoins(AQLQuery &) {}
   void AQLOptionalVisitor::visitBeforeCriterion(AQLQuery &) {}
   void AQLOptionalVisitor::visitAfterCriterion(AQLQuery &) {}
   void AQLOptionalVisitor::visitBeforeSorts(AQLQuery &) {}
   void AQLOptionalVisitor::visitBetweenSorts(AQLQuery &, int) {}
   void AQLOptionalVisitor::visitAfterSorts(AQLQuery &) {}
   void AQLOptionalVisitor::visitBeforeChildren(AQLCustomExpr &) {}
   void AQLOptionalVisitor::visitBetweenChildren(AQLCustomExpr &, int) {}
   void AQLOptionalVisitor::visitAfterChildren(AQLCustomExpr &) {}


   /*
    * Runs 'rewriter' over every expression tree of 'query', in this order:
    * the join groups (through the AQLJoinGroup overload of rewriteWalk), the
    * base criterion, the select expressions and finally the sort expressions.
    * 'childrenFirst' chooses bottom-up (true) or top-down (false) rewriting.
    */
   void AQLExpressionRewriter::walk(AQLQuery &query, AQLExpressionRewriter &rewriter, bool childrenFirst)
   {
      AQLQuery::nested_join_list::iterator joinIt=query.nestedJoins.begin();
      while (joinIt!=query.nestedJoins.end())
      {
         AQLJoinGroup *joinGroup=*joinIt;
         rewriteWalk(*joinGroup, rewriter, childrenFirst);
         ++joinIt;
      }

      rewriteWalk(query.criterion, rewriter, childrenFirst);

      AQLQuery::select_list_type::iterator selectIt=query.selects.begin();
      while (selectIt!=query.selects.end())
      {
         AQLSelect *selectItem=*selectIt;
         rewriteWalk(selectItem->expr, rewriter, childrenFirst);
         ++selectIt;
      }

      AQLQuery::sort_list_type::iterator sortIt=query.sorts.begin();
      while (sortIt!=query.sorts.end())
      {
         AQLSort *sortItem=*sortIt;
         rewriteWalk(sortItem->expr, rewriter, childrenFirst);
         ++sortIt;
      }
   }

   // AQLOptionalExpressionRewriter: every rewrite() overload returns a null
   // pointer, which rewriteWalk() treats as "no replacement".
   AQLExpr *AQLOptionalExpressionRewriter::rewrite(AQLTypecastExpression &) { return 0; }
   AQLExpr *AQLOptionalExpressionRewriter::rewrite(AQLJunctionCriterion &)  { return 0; }
   AQLExpr *AQLOptionalExpressionRewriter::rewrite(AQLPropertyExpr &)       { return 0; }
   AQLExpr *AQLOptionalExpressionRewriter::rewrite(AQLLiteralExpr &)        { return 0; }
   AQLExpr *AQLOptionalExpressionRewriter::rewrite(AQLNullExpr &)           { return 0; }
   AQLExpr *AQLOptionalExpressionRewriter::rewrite(AQLFunctionExpr &)       { return 0; }
   AQLExpr *AQLOptionalExpressionRewriter::rewrite(AQLNotExpression &)      { return 0; }
   AQLExpr *AQLOptionalExpressionRewriter::rewrite(AQLComparisonCriterion &) { return 0; }
   AQLExpr *AQLOptionalExpressionRewriter::rewrite(AQLCustomExpr &)         { return 0; }

   /*
    * Returns a simplified copy of 'literal':
    *  - a DOUBLE holding an integral value that fits in int64_t is converted
    *    to an INTEGER literal, which is then simplified further;
    *  - INTEGER 0 and 1 become BOOLEAN false and true respectively;
    *  - anything else is returned unchanged.
    */
   AQLLiteralExpr simplifyLiteral(const AQLLiteralExpr &literal)
   {
      switch (literal.literalType)
      {
         case AQLTypeSet::DOUBLE:
         {
            const double d=literal.doubleLiteral;
            // Converting NaN or a value outside int64_t's range to int64_t is
            // undefined behaviour. On saturating targets 2^63 would even
            // compare equal to INT64_MAX. So range-check before casting. Both
            // bounds (-2^63 and 2^63) are exactly representable as doubles,
            // and NaN fails both comparisons.
            if (d>=-9223372036854775808.0 && d<9223372036854775808.0 &&
                static_cast<int64_t>(d)==d)
            {
               return simplifyLiteral(AQLLiteralExpr(static_cast<int64_t>(d)));
            }
            break;
         }

         case AQLTypeSet::INTEGER:
            if (literal.intLiteral==0)
            {
               return AQLLiteralExpr(false);
            }
            else if (literal.intLiteral==1)
            {
               return AQLLiteralExpr(true);
            }
            break;

         default:
            // it's ok, if we don't handle all literal types
            break;
      }
      return AQLLiteralExpr(literal);
   }

   /*
    * Returns the type that 'exprType' promotes to next, or UNSET when no
    * further promotion exists. The only promotion defined is INTEGER ->
    * DOUBLE.
    */
   AQLTypeSet::ExprType getNextTypePromotion(AQLTypeSet::ExprType exprType)
   {
      return (exprType==AQLTypeSet::INTEGER) ? AQLTypeSet::DOUBLE : AQLTypeSet::UNSET;
   }

   /*
    * Returns the first type on a's promotion chain (a, then its successive
    * promotions) that also occurs on b's promotion chain, or UNSET if the two
    * chains share no type. Equal inputs are returned as-is.
    */
   AQLTypeSet::ExprType getClosestCommonSupertype(AQLTypeSet::ExprType a, AQLTypeSet::ExprType b)
   {
      if (a==b) return a;

      // brute force is fine: promotion chains are very short
      AQLTypeSet::ExprType fromA=a;
      while (fromA!=AQLTypeSet::UNSET)
      {
         AQLTypeSet::ExprType fromB=b;
         while (fromB!=AQLTypeSet::UNSET)
         {
            if (fromA==fromB) return fromA;
            fromB=getNextTypePromotion(fromB);
         }
         fromA=getNextTypePromotion(fromA);
      }

      return AQLTypeSet::UNSET;
   }

}
