/*
 *  Type-ARQuE - the experimental SPARQL to SQL translator.
 *  Copyright (C) 2010  Sami Kiminki / Aalto University
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <algorithm>
#include <limits>

#include "AQLModel.h"
#include "AQLException.h"
#include "GeneralUtils.h"
#include "SQLBackendFunctions.h"

namespace {

   using namespace TypeRQ;

   class AcceptVisitor
   {
   protected:
      AQLVisitor &visitor;
   public:
      AcceptVisitor(AQLVisitor &_visitor) : visitor(_visitor) {}
      void operator() (AQLVisitable *pvisitable)
      {
         pvisitable->accept(visitor);
      }
   };

   template <class TVisitable> class AcceptBetweenVisitor
   {
   protected:
      AQLVisitor &visitor;
      TVisitable &visitable;
      int pos;
   public:
      AcceptBetweenVisitor(AQLVisitor &_visitor, TVisitable &_visitable) : visitor(_visitor), visitable(_visitable), pos(0) {}
      void operator() (AQLVisitable *pvisitable)
      {
         if (pos>=1) visitor.visitBetweenChildren(visitable, pos);
         ++pos;
         pvisitable->accept(visitor);
      }
   };

}

namespace TypeRQ
{
   // A default-constructed set holds the ANY mask.
   AQLTypeSet::AQLTypeSet()
      : bitField(getBitMaskForExprType(ANY))
   {}

   // Copies the raw bit field of another set.
   AQLTypeSet::AQLTypeSet(const AQLTypeSet &other)
      : bitField(other.getBitField())
   {}

   // Adopts a raw bit field verbatim; individual bits are not validated.
   AQLTypeSet::AQLTypeSet(bitfield_type rawBits)
      : bitField(rawBits)
   {}

   // Builds the set whose bits are getBitMaskForExprType(exprType).
   AQLTypeSet::AQLTypeSet(ExprType exprType)
      : bitField(getBitMaskForExprType(exprType))
   {}

   // Assignment; takes the source by value as declared.
   AQLTypeSet &AQLTypeSet::operator =(AQLTypeSet other)
   {
      setBitField(other.getBitField());
      return *this;
   }

   AQLTypeSet::bitfield_type AQLTypeSet::getBitField() const
   {
      return bitField;
   }

   void AQLTypeSet::setBitField(bitfield_type field)
   {
      bitField=field;
   }

   bool AQLTypeSet::isType(ExprType exprType) const
   {
      const bitfield_type mask=getBitMaskForExprType(exprType);
      return (bitField & mask)==mask;
   }

   void AQLTypeSet::clear()
   {
      bitField=0;
   }

   void AQLTypeSet::setAny()
   {
      bitField=getBitMaskForExprType(ANY);
   }

   void AQLTypeSet::removeType(ExprType exprType)
   {
      const bitfield_type mask=getBitMaskForExprType(exprType);
      bitField&=(~mask);
   }

   void AQLTypeSet::removeTypes(AQLTypeSet typeSet)
   {
      bitField&=(~typeSet.bitField);
   }

   void AQLTypeSet::setType(ExprType exprType)
   {
      const bitfield_type mask=getBitMaskForExprType(exprType);
      bitField|=mask;
   }

   void AQLTypeSet::setTypes(AQLTypeSet typeSet)
   {
      bitField|=typeSet.bitField;
   }

   void AQLTypeSet::addSuperTypes()
   {
      if (isType(AQLTypeSet::INTEGER))
      {
         setType(AQLTypeSet::DOUBLE);
      }
   }

   void AQLTypeSet::addSubTypes()
   {
      if (isType(AQLTypeSet::DOUBLE))
      {
         setType(AQLTypeSet::INTEGER);
      }
   }

   // Sets are equal exactly when their raw bit fields are equal.
   bool AQLTypeSet::operator ==(AQLTypeSet other) const
   {
      return other.bitField==bitField;
   }

   // Negation of operator==.
   bool AQLTypeSet::operator !=(AQLTypeSet other) const
   {
      return !(*this==other);
   }

   // Maps this set to a single ExprType. Both this set and each candidate
   // enumerator's set are closed under addSubTypes() before comparing, so
   // {DOUBLE} and {DOUBLE, INTEGER} compare the same. Enumerators are scanned
   // from UNSET up to (excluding) EXPRTYPE_SIZE and the first match wins.
   // Returns ANY when nothing matches.
   AQLTypeSet::ExprType AQLTypeSet::singularType() const
   {
      AQLTypeSet normalized(*this);
      normalized.addSubTypes();

      int candidate=UNSET;
      while (candidate<EXPRTYPE_SIZE)
      {
         const ExprType exprType=static_cast<ExprType>(candidate);
         AQLTypeSet closure(exprType);
         closure.addSubTypes();

         if (closure==normalized) return exprType;
         ++candidate;
      }
      return ANY;
   }

   size_t AQLTypeSet::size() const
   {
      size_t ret=0;

      bitfield_type tmp=bitField;
      while (tmp)
      {
         if (tmp & 1) ++ret;
         tmp>>=1;
      }

      return ret;
   }

   // Returns the bit mask representing exprType:
   //  - UNSET -> empty mask
   //  - ANY   -> every bit below ANY's enumerator value, except REFERENCE's bit
   //  - other -> the single bit at the enumerator's value
   // Shifts are done in bitfield_type (as iterator::skipZeroBits does) rather
   // than int. An int shift would overflow (undefined behaviour) for
   // enumerator values >= the width of int, even when bitfield_type is wider.
   AQLTypeSet::bitfield_type AQLTypeSet::getBitMaskForExprType(ExprType exprType)
   {
      const bitfield_type one=1;

      if (exprType==UNSET) return 0;
      if (exprType==ANY)
      {
         // any is anything but reference
         return ((one << unsigned(ANY)) - one) & ~(one << unsigned(REFERENCE));
      }
      return one << unsigned(exprType);
   }

   AQLTypeSet::iterator AQLTypeSet::begin() const
   {
      return iterator(bitField, false);
   }

   AQLTypeSet::iterator AQLTypeSet::end() const
   {
      return iterator(bitField, true);
   }

   AQLTypeSet::iterator::iterator(bitfield_type field, bool endIterator) : bits(field)
   { 
      if (endIterator)
      {
         shift=-1;
      }
      else {
         shift=0;
         skipZeroBits();
      }
   }

   AQLTypeSet::iterator::iterator(const iterator &other) : bits(other.bits), shift(other.shift)
   {
   }

   void AQLTypeSet::iterator::skipZeroBits()
   {
      while (true)
      {
         if (shift>=std::numeric_limits<bitfield_type>::digits)
         {
            // end of bits
            shift=-1;
            break;
         }

         bitfield_type mask=bitfield_type(1) << unsigned(shift);
         if ((bits & mask)!=0) break; // non-zero bit

         ++shift;
      }
   }

   AQLTypeSet::iterator &AQLTypeSet::iterator::operator ++()
   {
      ++shift;
      skipZeroBits();
      return *this;
   }

   AQLTypeSet::iterator &AQLTypeSet::iterator::operator ++(int)
   {
      return this->operator ++();
   }

   bool AQLTypeSet::iterator::operator ==(const iterator &other) const
   {
      return other.shift==shift;
   }

   bool AQLTypeSet::iterator::operator !=(const iterator &other) const
   {
      return !(this->operator ==(other));
   }

   AQLTypeSet::ExprType AQLTypeSet::iterator::operator *() const
   {
      return ExprType(shift);
   }

   // Set union: types admitted by either operand.
   AQLTypeSet makeUnion(AQLTypeSet a, AQLTypeSet b)
   {
      const AQLTypeSet::bitfield_type either=a.getBitField() | b.getBitField();
      return AQLTypeSet(either);
   }

   // Set intersection: types admitted by both operands.
   AQLTypeSet makeIntersection(AQLTypeSet a, AQLTypeSet b)
   {
      const AQLTypeSet::bitfield_type both=a.getBitField() & b.getBitField();
      return AQLTypeSet(both);
   }

   // Releases every term still held in terms through deleteObject<AQLExpr>.
   AQLJunctionCriterion::~AQLJunctionCriterion()
   {
      std::for_each(terms.begin(), terms.end(), deleteObject<AQLExpr>);
   }

   const char *AQLJunctionCriterion::getTypeName() const
   {
      return "junction";
   }

   // Visits the terms in order, with visitBetweenChildren(*this, index) before
   // every term except the first.
   void AQLJunctionCriterion::accept(AQLVisitor &v)
   {
      AcceptBetweenVisitor<AQLJunctionCriterion> termVisitor(v, *this);
      v.visitBeforeChildren(*this);
      std::for_each(terms.begin(), terms.end(), termVisitor);
      v.visitAfterChildren(*this);
   }

   // Forgets the terms without deleting them, so the destructor no longer
   // frees them.
   void AQLJunctionCriterion::unlinkChildren()
   {
      terms.clear();
   }

   // Structural comparison is not implemented for junctions.
   bool AQLJunctionCriterion::equals(const AQLExpr &) const
   {
      throw AQLException("Unsupported: AQLJunctionCriterion::equals");
   }

   // Logical expressions always evaluate to a boolean.
   AQLTypeSet AQLLogicalExpr::getExprTypeSet() const
   {
      const AQLTypeSet booleanOnly(AQLTypeSet::BOOLEAN);
      return booleanOnly;
   }

   // Default: empty join name, default property, type set initialised to ANY.
   AQLPropertyExpr::AQLPropertyExpr()
      : joinName(), property(), propertyType()
   {}

   AQLPropertyExpr::AQLPropertyExpr(const std::string &_joinName, Property _p,
                                    AQLTypeSet _typeSet)
      : joinName(_joinName),
        property(_p),
        propertyType(_typeSet)
   {}

   AQLPropertyExpr::AQLPropertyExpr(const AQLPropertyExpr &e)
      : joinName(e.joinName),
        property(e.property),
        propertyType(e.propertyType)
   {}

   // The expression's type set is the one supplied at construction.
   AQLTypeSet AQLPropertyExpr::getExprTypeSet() const
   {
      return propertyType;
   }

   // Leaf node: a single visit() call.
   void AQLPropertyExpr::accept(AQLVisitor &v)
   {
      v.visit(*this);
   }

   const char *AQLPropertyExpr::getTypeName() const
   {
      return "property";
   }

   void AQLPropertyExpr::unlinkChildren()
   {
      // leaf node: nothing to unlink
   }

   // Equal when o is also a property expression with the same join name,
   // property and type set.
   bool AQLPropertyExpr::equals(const AQLExpr &o) const
   {
      const AQLPropertyExpr *other=dynamic_cast<const AQLPropertyExpr *>(&o);
      if (other==0) return false;
      if (joinName!=other->joinName) return false;
      if (!(property==other->property)) return false;
      return propertyType==other->propertyType;
   }


   // Textual literal; exprType tells which textual type (e.g. STRING, IRI).
   AQLLiteralExpr::AQLLiteralExpr(const std::string &_stringLiteral, AQLTypeSet::ExprType exprType)
      : stringLiteral(_stringLiteral),
        booleanLiteral(), intLiteral(), doubleLiteral(),
        literalType(exprType)
   {}

   AQLLiteralExpr::AQLLiteralExpr(bool _booleanLiteral)
      : stringLiteral(),
        booleanLiteral(_booleanLiteral), intLiteral(), doubleLiteral(),
        literalType(AQLTypeSet::BOOLEAN)
   {}

   AQLLiteralExpr::AQLLiteralExpr(int64_t _intLiteral)
      : stringLiteral(),
        booleanLiteral(), intLiteral(_intLiteral), doubleLiteral(),
        literalType(AQLTypeSet::INTEGER)
   {}

   AQLLiteralExpr::AQLLiteralExpr(double _doubleLiteral)
      : stringLiteral(),
        booleanLiteral(), intLiteral(), doubleLiteral(_doubleLiteral),
        literalType(AQLTypeSet::DOUBLE)
   {}

   AQLLiteralExpr::AQLLiteralExpr(const AQLLiteralExpr &o)
      : stringLiteral(o.stringLiteral),
        booleanLiteral(o.booleanLiteral),
        intLiteral(o.intLiteral),
        doubleLiteral(o.doubleLiteral),
        literalType(o.literalType)
   {}

   // Leaf node: a single visit() call.
   void AQLLiteralExpr::accept(AQLVisitor &v)
   {
      v.visit(*this);
   }

   // The literal's type as a one-type set.
   AQLTypeSet AQLLiteralExpr::getExprTypeSet() const
   {
      return literalType;
   }

   const char *AQLLiteralExpr::getTypeName() const
   {
      return "literal";
   }

   void AQLLiteralExpr::unlinkChildren()
   {
      // leaf node: nothing to unlink
   }

   // Equal when o is a literal of the same type whose value field for that type
   // matches. Throws AQLException for literal types without a value field here.
   bool AQLLiteralExpr::equals(const AQLExpr &o) const
   {
      const AQLLiteralExpr *other=dynamic_cast<const AQLLiteralExpr *>(&o);
      if (other==0 || literalType!=other->literalType) return false;

      switch (literalType)
      {
         case AQLTypeSet::BOOLEAN:
            return other->booleanLiteral==booleanLiteral;

         case AQLTypeSet::INTEGER:
            return other->intLiteral==intLiteral;

         case AQLTypeSet::DOUBLE:
            return other->doubleLiteral==doubleLiteral;

         case AQLTypeSet::STRING:
         case AQLTypeSet::IRI:
         case AQLTypeSet::DATETIME:
            return other->stringLiteral==stringLiteral;

         default:
            break;
      }
      throw AQLException("AQLLiteralExpr::equals: unhandled literal type %d",
                         static_cast<int>(literalType));
   }

   AQLNullExpr::AQLNullExpr()
   {
   }

   const char *AQLNullExpr::getTypeName() const
   {
      return "null";
   }

   // NULL carries no type: the empty (UNSET) set.
   AQLTypeSet AQLNullExpr::getExprTypeSet() const
   {
      return AQLTypeSet(AQLTypeSet::UNSET);
   }

   // Leaf node: a single visit() call.
   void AQLNullExpr::accept(AQLVisitor &v)
   {
      v.visit(*this);
   }

   void AQLNullExpr::unlinkChildren()
   {
      // leaf node: nothing to unlink
   }

   // Any two null expressions are equal.
   bool AQLNullExpr::equals(const AQLExpr &o) const
   {
      return dynamic_cast<const AQLNullExpr *>(&o)!=0;
   }

   AQLFunctionExpr::AQLFunctionExpr() : functionType(), chosenVariant(0) {}
   AQLFunctionExpr::~AQLFunctionExpr()
   {
      std::for_each(arguments.begin(), arguments.end(), deleteObject<AQLExpr>);
      delete chosenVariant;
   }
   const char *AQLFunctionExpr::getTypeName() const
   {
      return "function";
   }
   AQLTypeSet AQLFunctionExpr::getExprTypeSet() const
   {
      return functionType;
   }

   void AQLFunctionExpr::accept(AQLVisitor &v)
   {
      v.visitBeforeChildren(*this);
      std::for_each(arguments.begin(), arguments.end(), AcceptBetweenVisitor<AQLFunctionExpr>(v, *this));
      v.visitAfterChildren(*this);
   }

   void AQLFunctionExpr::unlinkChildren()
   {
      arguments.clear();
   }

   bool AQLFunctionExpr::equals(const AQLExpr &o) const
   {
      const AQLFunctionExpr *other=dynamic_cast<const AQLFunctionExpr *>(&o);
      if (!other) return false;

      if (functionName!=other->functionName) return false;
      if (arguments.size()!=other->arguments.size()) return false;
      if (functionType!=other->functionType) return false;

      arg_list_type::const_iterator i=arguments.begin();
      arg_list_type::const_iterator j=other->arguments.begin();
      for (; i!=arguments.end(); ++i, ++j)
      {
         const AQLExpr &arg1=**i;
         const AQLExpr &arg2=**j;
         if (!arg1.equals(arg2)) return false;
      }
      return true;
   }

   AQLComparisonCriterion::AQLComparisonCriterion()
      : left(0), right(0), comparisonType(EQUAL)
   {}

   // Takes ownership of both operands; they are deleted by the destructor.
   AQLComparisonCriterion::AQLComparisonCriterion(AQLExpr *_left, AQLExpr *_right,
                                                  ComparisonType _comparisonType)
      : left(_left), right(_right), comparisonType(_comparisonType)
   {}

   AQLComparisonCriterion::~AQLComparisonCriterion()
   {
      delete left;
      delete right;
   }

   const char *AQLComparisonCriterion::getTypeName() const
   {
      return "comparison";
   }

   // before -> left (if any) -> between -> right (if any) -> after.
   void AQLComparisonCriterion::accept(AQLVisitor &v)
   {
      v.visitBeforeChildren(*this);
      if (left!=0)
      {
         left->accept(v);
      }
      v.visitBetweenChildren(*this);
      if (right!=0)
      {
         right->accept(v);
      }
      v.visitAfterChildren(*this);
   }

   // Forgets both operands without deleting them.
   void AQLComparisonCriterion::unlinkChildren()
   {
      left=0;
      right=0;
   }

   // Structural comparison is not implemented for comparisons.
   bool AQLComparisonCriterion::equals(const AQLExpr &) const
   {
      throw AQLException("Unsupported: AQLComparisonCriterion::equals");
   }


   // Negation of a single (owned) operand; expr starts out null.
   AQLNotExpression::AQLNotExpression() : expr(0)
   {
   }

   AQLNotExpression::~AQLNotExpression()
   {
      delete expr;
   }

   const char *AQLNotExpression::getTypeName() const
   {
      return "not";
   }

   // before -> operand (if any) -> after.
   void AQLNotExpression::accept(AQLVisitor &v)
   {
      v.visitBeforeChildren(*this);
      if (expr) expr->accept(v);
      v.visitAfterChildren(*this);
   }

   // Forgets the operand without deleting it.
   void AQLNotExpression::unlinkChildren()
   {
      expr=0;
   }

   // Equal when o is also a NOT expression with an equal operand. The operand
   // may legitimately be null (default construction, or after
   // unlinkChildren()). Two null operands compare equal; a null and a non-null
   // operand compare unequal. Previously a null operand was dereferenced.
   bool AQLNotExpression::equals(const AQLExpr &o) const
   {
      const AQLNotExpression *other=dynamic_cast<const AQLNotExpression *>(&o);
      if (!other) return false;
      if (!expr || !other->expr) return expr==other->expr;
      return expr->equals(*other->expr);
   }

   // Default: no operand, target type UNSET.
   AQLTypecastExpression::AQLTypecastExpression()
      : expr(0), toType(AQLTypeSet::UNSET)
   {
   }

   // Takes ownership of _expr; it is deleted by the destructor.
   AQLTypecastExpression::AQLTypecastExpression(AQLExpr *_expr, AQLTypeSet::ExprType _toType)
      : expr(_expr), toType(_toType)
   {
   }

   AQLTypecastExpression::~AQLTypecastExpression()
   {
      delete expr;
   }

   const char *AQLTypecastExpression::getTypeName() const
   {
      return "typecast";
   }

   // before -> operand (if any) -> after.
   void AQLTypecastExpression::accept(AQLVisitor &v)
   {
      v.visitBeforeChildren(*this);
      if (expr!=0)
      {
         expr->accept(v);
      }
      v.visitAfterChildren(*this);
   }

   // The result type is the cast's target type.
   AQLTypeSet AQLTypecastExpression::getExprTypeSet() const
   {
      const AQLTypeSet target(toType);
      return target;
   }

   // Forgets the operand without deleting it.
   void AQLTypecastExpression::unlinkChildren()
   {
      expr=0;
   }

   // Structural comparison is not implemented for typecasts.
   bool AQLTypecastExpression::equals(const AQLExpr &) const
   {
      throw AQLException("Unsupported: AQLTypecastExpression::equals");
   }

   AQLJoinGroupLike::AQLJoinGroupLike() : criterion(0) {}

   // Deletes the criterion and releases every nested join through
   // deleteObject<AQLJoinGroup>.
   AQLJoinGroupLike::~AQLJoinGroupLike()
   {
      delete criterion;
      std::for_each(nestedJoins.begin(), nestedJoins.end(), deleteObject<AQLJoinGroup>);
   }

   // ANDs expr into this group's criterion and takes ownership of it:
   //  - no criterion yet: expr becomes the criterion;
   //  - the criterion is already a conjunction: expr is appended to its terms;
   //  - otherwise: the old criterion and expr are wrapped in a new conjunction.
   // A null expr is ignored. Appending it would leave a null term in the
   // conjunction, which AQLJunctionCriterion::accept() later dereferences.
   void AQLJoinGroupLike::addCriterion(AQLExpr *expr) {
      if (!expr) return;

      if (!criterion) {
         criterion=expr;
         return;
      }

      AQLJunctionCriterion *junction=dynamic_cast<AQLJunctionCriterion *>(criterion);
      if (junction && junction->junctionType==AQLJunctionCriterion::CONJUNCTION) {
         junction->terms.push_back(expr);
      } else {
         junction=new AQLJunctionCriterion;
         junction->junctionType=AQLJunctionCriterion::CONJUNCTION;
         junction->terms.push_back(criterion);
         junction->terms.push_back(expr);
         criterion=junction;
      }
   }

   // New join groups default to an INNER join.
   AQLJoinGroup::AQLJoinGroup()
      : joinType(INNER)
   {
   }

   AQLJoinGroup::~AQLJoinGroup()
   {
   }

   // before -> criterion (if any) -> before-nested-joins -> each nested join
   // -> after.
   void AQLJoinGroup::accept(AQLVisitor &v)
   {
      AcceptVisitor joinVisitor(v);
      v.visitBeforeChildren(*this);
      if (criterion!=0)
      {
         criterion->accept(v);
      }
      v.visitBeforeNestedJoins(*this);
      std::for_each(nestedJoins.begin(), nestedJoins.end(), joinVisitor);
      v.visitAfterChildren(*this);
   }

   // Compares join name and property only. NOTE(review): unlike
   // AQLPropertyExpr::equals, propertyType is not compared here — confirm this
   // is intentional.
   bool operator == (const AQLPropertyExpr &left, const AQLPropertyExpr &right) {
      if (left.joinName!=right.joinName) return false;
      return left.property==right.property;
   }

   // Negation of operator== above.
   bool operator != (const AQLPropertyExpr &left, const AQLPropertyExpr &right) {
      const bool same=(left==right);
      return !same;
   }



   // A select item owning one expression (initially none).
   AQLSelect::AQLSelect()
      : expr(0)
   {
   }

   AQLSelect::~AQLSelect()
   {
      delete expr;
   }

   // before -> expression (if any) -> after.
   void AQLSelect::accept(AQLVisitor &v)
   {
      v.visitBeforeChildren(*this);
      if (expr!=0)
      {
         expr->accept(v);
      }
      v.visitAfterChildren(*this);
   }

   // A sort key owning one expression (initially none); ascending by default.
   AQLSort::AQLSort()
      : ascending(true), expr(0)
   {}

   AQLSort::~AQLSort()
   {
      delete expr;
   }

   // before -> expression (if any) -> after.
   void AQLSort::accept(AQLVisitor &v)
   {
      v.visitBeforeChildren(*this);
      if (expr!=0)
      {
         expr->accept(v);
      }
      v.visitAfterChildren(*this);
   }

   // Non-distinct query; maxRows and rowOffset start at -1 (presumably "not
   // set" — confirm against consumers).
   AQLQuery::AQLQuery()
      : distinct(false), maxRows(-1), rowOffset(-1)
   {
   }

   // Releases the selects and sorts (via deleteObject); the criterion and nested
   // joins are not touched here.
   AQLQuery::~AQLQuery()
   {
      std::for_each(selects.begin(), selects.end(), deleteObject<AQLSelect>);
      std::for_each(sorts.begin(), sorts.end(), deleteObject<AQLSort>);
   }

   // Visit order: selects, joins, criterion (if any), sorts. Each section is
   // bracketed by its before/after callbacks, and visitBetweenSorts(*this,
   // index) precedes every sort except the first.
   void AQLQuery::accept(AQLVisitor &v)
   {
      AcceptVisitor childVisitor(v);

      v.visitBeforeChildren(*this);

      v.visitBeforeSelects(*this);
      std::for_each(selects.begin(), selects.end(), childVisitor);
      v.visitAfterSelects(*this);

      v.visitBeforeJoins(*this);
      std::for_each(nestedJoins.begin(), nestedJoins.end(), childVisitor);
      v.visitAfterJoins(*this);

      v.visitBeforeCriterion(*this);
      if (criterion!=0)
      {
         criterion->accept(v);
      }
      v.visitAfterCriterion(*this);

      v.visitBeforeSorts(*this);
      int index=0;
      for (std::list<AQLSort *>::iterator it=sorts.begin(); it!=sorts.end(); ++it, ++index)
      {
         if (index>0) v.visitBetweenSorts(*this, index);
         (*it)->accept(v);
      }
      v.visitAfterSorts(*this);

      v.visitAfterChildren(*this);
   }

   // Nothing to release explicitly.
   AQLResult::~AQLResult()
   {
   }

}
