/*
 *  Type-ARQuE - the experimental SPARQL to SQL translator.
 *  Copyright (C) 2010  Sami Kiminki / Aalto University
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <set>
#include <map>
#include <stdexcept>
#include <cassert>

#include "AQLTypeInference.h"
#include "AQLModel.h"
#include "AQLSupport.h"
#include "SQLBackendFunctions.h"
#include "Messages.h"
#include "AQLException.h"

namespace {
   using namespace TypeRQ;

   struct JoinAndProperty
   {
      const std::string joinName;
      const AQLPropertyExpr::Property property;

      JoinAndProperty(const std::string &_joinName, AQLPropertyExpr::Property _property) :
         joinName(_joinName), property(_property)
      {}

      struct less
      {
         bool operator () (const JoinAndProperty &a, const JoinAndProperty &b) const
         {
            if (a.joinName<b.joinName) return true;
            if (a.joinName>b.joinName) return false;
            // a.joinName == b.joinName
            if (a.property<b.property) return true;
            return false;
         }
      };

   };

   struct TypeConstraints : protected AQLDebugBase
   {
      typedef std::map<JoinAndProperty, AQLTypeSet, JoinAndProperty::less> property_constraint_map;
      property_constraint_map propertyConstraintMap;

      TypeConstraints()
      {
      }

      TypeConstraints(const TypeConstraints &o)
      {
         assign(o);
      }

      TypeConstraints &operator =(const TypeConstraints &o)
      {
         assign(o);
         return *this;
      }

      virtual ~TypeConstraints()
      {
         clear();
      }

      void clear()
      {
         propertyConstraintMap.clear();
      }

      bool clearAllTypeSets()
      {
         bool ret=false;
         for (property_constraint_map::iterator i=propertyConstraintMap.begin();
              i!=propertyConstraintMap.end(); ++i)
         {
            AQLTypeSet orig=i->second;
            i->second=AQLTypeSet(AQLTypeSet::UNSET);
            if (orig!=i->second) ret=true;

         }
         return ret;
      }

      void assign(const TypeConstraints &o)
      {
         propertyConstraintMap=o.propertyConstraintMap;
      }

      bool makeIntersection(const TypeConstraints &o)
      {
         bool ret=false;
         for (property_constraint_map::const_iterator i=o.propertyConstraintMap.begin();
              i!=o.propertyConstraintMap.end(); ++i)
         {
            if (constrain(i->first.joinName, i->first.property, i->second)) ret=true;
         }
         return ret;
      }

      void makeUnion(const TypeConstraints &o)
      {
         // first: find all properties only existing in this
         // note: default is ANY, if not defined
         for (property_constraint_map::iterator i=propertyConstraintMap.begin();
              i!=propertyConstraintMap.end(); ++i)
         {
            if (o.propertyConstraintMap.find(i->first)==o.propertyConstraintMap.end())
            {
               i->second=AQLTypeSet(AQLTypeSet::ANY);
            }
         }
         for (property_constraint_map::const_iterator i=o.propertyConstraintMap.begin();
              i!=o.propertyConstraintMap.end(); ++i)
         {
            const std::string &joinName=i->first.joinName;
            const AQLPropertyExpr::Property property=i->first.property;
            const AQLTypeSet otherTset=i->second;
            AQLTypeSet &tset=propertyConstraintMap[JoinAndProperty(joinName, property)];
            tset=::makeUnion(tset, otherTset);
         }
      }

      AQLTypeSet getConstraint(const std::string &joinName, AQLPropertyExpr::Property property) const
      {
         property_constraint_map::const_iterator i=propertyConstraintMap.find(JoinAndProperty(joinName, property));
         if (i==propertyConstraintMap.end()) return AQLTypeSet(); // ANY
         return i->second;
      }

      bool constrain(const std::string &joinName, AQLPropertyExpr::Property property,
                     AQLTypeSet constraint)
      {
         AQLTypeSet tset=getConstraint(joinName, property);
         AQLTypeSet isec=::makeIntersection(tset, constraint);
         bool const ret=(tset!=isec);
         print(OL_TRACE, "Constraining %s:%s%s with %s, result: %s, changed=%d\n",
               joinName.c_str(), getNameForAQLTripleProperty(property),
               toString(tset).c_str(),
               toString(constraint).c_str(),
               toString(isec).c_str(),
               static_cast<int>(ret));
         propertyConstraintMap[JoinAndProperty(joinName, property)]=isec;
         return ret;
      }

   };

   struct TypeConstraintMap
   {
      typedef std::map<const AQLExpr *, TypeConstraints *> scope_map;
      scope_map scopedConstraints;

      typedef std::map<const AQLFunctionExpr *, SQLFunctionType *> function_constraint_map;
      function_constraint_map functionConstraintMap;

      virtual ~TypeConstraintMap()
      {
         for (scope_map::iterator i=scopedConstraints.begin(); i!=scopedConstraints.end(); ++i)
         {
            delete i->second;
         }

         for (function_constraint_map::iterator i=functionConstraintMap.begin();
              i!=functionConstraintMap.end(); ++i)
         {
            delete i->second;
         }
         functionConstraintMap.clear();
      }

      TypeConstraints &getScope(const AQLExpr &expr)
      {
         scope_map::iterator i=scopedConstraints.find(&expr);
         if (i!=scopedConstraints.end()) return *(i->second);
         TypeConstraints *tc=new TypeConstraints;
         scopedConstraints.insert(std::make_pair(&expr, tc));
         return *tc;
      }

      const SQLFunctionType &getConstraint(const AQLFunctionExpr &fexpr) const
      {
         try {
            return *functionConstraintMap.at(&fexpr);
         }
         catch (std::out_of_range)
         {
            throw AQLException("Internal error: tried to get type constraint for function %s@%p "
                               "but no constrains were defined",
                               fexpr.functionName.c_str(), &fexpr);
         }
      }

      bool constrain(const AQLFunctionExpr &fexpr, SQLFunctionType &ft)
      {
         print(OL_TRACE, "Constraining %s@%p with %s,\n  result=",
               fexpr.functionName.c_str(), &fexpr,
               toString(ft).c_str());
         function_constraint_map::iterator i=functionConstraintMap.find(&fexpr);
         if (i==functionConstraintMap.end())
         {
            functionConstraintMap.insert(std::make_pair(&fexpr, new SQLFunctionType(ft)));
            print(OL_TRACE, "%s, new constraint\n", toString(ft).c_str());
            return true;
         }
         else {
            SQLFunctionType &functionFt=*i->second;
            SQLFunctionType origFt=functionFt;
            functionFt.intersect(ft);
            const bool ret=(origFt != functionFt); // was narrowed?
            print(OL_TRACE, "%s, changed=%d\n", toString(*i->second).c_str(), static_cast<int>(ret));
            return ret;
         }
      }
   };

   class FunctionTypeNarrowerVisitor : public AQLOptionalVisitor
   {
   private:
      TypeConstraintMap &tcm;
      SQLFunctionMap &functionMap;
      bool changes;

      void constrainFunction(const AQLFunctionExpr &fexpr, SQLFunctionType &ft)
      {
         print(OL_TRACE, "constrainFunction()\n");
         if (tcm.constrain(fexpr, ft))
         {
            changes=true;
         }
      }

   public:
      FunctionTypeNarrowerVisitor(TypeConstraintMap &_tcm, SQLFunctionMap &_functionMap) :
         tcm(_tcm), functionMap(_functionMap), changes(false)
      {
      }

      bool wasChanged() const
      {
         return changes;
      }

      void visitAfterChildren(AQLFunctionExpr &expr)
      {
         SQLFunctionMapping &fm=functionMap.getFunctionMapping(expr.functionName);

         SQLFunctionType ft(expr.functionType);
         fm.determineFunctionTypeSet(expr, ft);
         constrainFunction(expr, ft);
      }
   };

   /**
    * Visitor that gathers type constraints for the query and for every join
    * group: default constraints for the joins each group names, followed by
    * the constraints derived from the group's criterion expression.
    * wasChanged() reports whether any constraint changed during the traversal.
    */
   class ConstraintCollectorVisitor : public AQLOptionalVisitor
   {
   private:
      TypeConstraintMap &tcm;
      SQLFunctionMap &functionMap;

      // Sticky flag: it is only ever raised, never lowered, during a traversal.
      bool changes;

      /** Constrains one property in the scope and raises the change flag if needed. */
      void constrainProperty(TypeConstraints &scope, const std::string &joinName,
                             AQLPropertyExpr::Property p, AQLTypeSet tset)
      {
         if (scope.constrain(joinName, p, tset))
         {
            changes=true;
         }
      }

      /**
       * Applies the default constraints for every join named by the group:
       * SUBJECT and PREDICATE are constrained to IRI, OBJECT to ANY.
       */
      void putDefaultConstraintsToScope(TypeConstraints &scope, AQLJoinGroupLike &jgl)
      {
         for (AQLJoinGroupLike::name_list::const_iterator i=jgl.names.begin(); i!=jgl.names.end(); ++i)
         {
            const std::string &joinName=*i;
            constrainProperty(scope, joinName, AQLPropertyExpr::SUBJECT, AQLTypeSet(AQLTypeSet::IRI));
            constrainProperty(scope, joinName, AQLPropertyExpr::PREDICATE, AQLTypeSet(AQLTypeSet::IRI));
            constrainProperty(scope, joinName, AQLPropertyExpr::OBJECT, AQLTypeSet(AQLTypeSet::ANY));
         }
      }

   public:
      ConstraintCollectorVisitor(TypeConstraintMap &_tcm, SQLFunctionMap &_functionMap) :
         tcm(_tcm), functionMap(_functionMap), changes(false)
      {
      }

      /** @return true if any constraint changed so far */
      bool wasChanged() const
      {
         return changes;
      }

      /**
       * Computes the constraints of an expression tree bottom-up and stores
       * them in the scope of each visited expression node.
       *
       * @param globalScope scope of the enclosing join group; it is only read
       * @param expr        expression to process
       * @return the scope registered for expr
       * @throws AQLException if expr is neither a literal, a property nor a
       *         function expression, or if a function expression has no
       *         recorded function type
       */
      TypeConstraints &flowExprConstraints(const TypeConstraints &globalScope, const AQLExpr *expr)
      {
         TypeConstraints &exprScope=tcm.getScope(*expr);
         if (dynamic_cast<const AQLLiteralExpr *>(expr))
         {
             // literals do not provide any constraints
            return exprScope;
         }
         const AQLPropertyExpr *pexpr=dynamic_cast<const AQLPropertyExpr *>(expr);
         if (pexpr)
         {
            // the property's own scope is limited to what the enclosing group
            // and the expression's current type both allow; the result of
            // this call does not affect the change flag
            exprScope.constrain(pexpr->joinName, pexpr->property,
                                makeIntersection(globalScope.getConstraint(pexpr->joinName, pexpr->property),
                                                 pexpr->propertyType));
            return exprScope;
         }
         const AQLFunctionExpr *fexpr=dynamic_cast<const AQLFunctionExpr *>(expr);
         if (fexpr)
         {
            const SQLFunctionType &ft=tcm.getConstraint(*fexpr);
            std::vector<TypeConstraints *> paramScopes(fexpr->arguments.size());
            size_t i=0;
            for (AQLFunctionExpr::arg_list_type::const_iterator j=fexpr->arguments.begin();
                 j!=fexpr->arguments.end(); ++i, ++j)
            {
               const AQLExpr *param=*j;
               TypeConstraints &paramScope=flowExprConstraints(globalScope, param);
               paramScopes.at(i)=&paramScope;

               // a function argument: its return type must fit the parameter type
               const AQLFunctionExpr *fparam=dynamic_cast<const AQLFunctionExpr *>(param);
               if (fparam)
               {
                  SQLFunctionType paramFt=tcm.getConstraint(*fparam);
                  paramFt.returnType=makeIntersection(paramFt.returnType, ft.paramTypes.at(i));
                  if (tcm.constrain(*fparam, paramFt))
                  {
                     changes=true;
                  }
               }
               // a property argument: the property must fit the parameter type
               const AQLPropertyExpr *pparam=dynamic_cast<const AQLPropertyExpr *>(param);
               if (pparam)
               {
                  if (paramScope.constrain(pparam->joinName, pparam->property, ft.paramTypes.at(i)))
                  {
                     changes=true;
                  }
               }
            }

            // transfer constraints upwards
            if (fexpr->functionName=="builtin:or" || fexpr->functionName=="builtin:coalesce")
            {
               if (paramScopes.empty())
               {
                  // Only raise the flag here. Assigning the result would erase
                  // a change recorded earlier in this traversal.
                  if (exprScope.clearAllTypeSets()) changes=true;
               }
               else {
                  // the union of all argument scopes is intersected into this scope
                  TypeConstraints tctmp=*paramScopes[0];
                  for (size_t j=1; j<paramScopes.size(); ++j)
                  {
                     TypeConstraints &paramScope=*paramScopes[j];
                     tctmp.makeUnion(paramScope);
                  }
                  if (exprScope.makeIntersection(tctmp)) changes=true;
               }
            }
            else if (fexpr->functionName!="builtin:not") {
               // every argument scope is intersected into this scope
               for (size_t j=0; j<paramScopes.size(); ++j)
               {
                  TypeConstraints &paramScope=*paramScopes[j];
                  if (exprScope.makeIntersection(paramScope)) changes=true;
               }
            }
            else {
               // with not, we don't transfer constraints
            }
            return exprScope;
         }
         throw AQLException("Internal: flowExprConstraints: unhandled expression node type");
      }

      /**
       * Collects the constraints of one join group into the scope registered
       * for the group's criterion expression.
       */
      void collectConstraints(AQLJoinGroupLike &jgl)
      {
         TypeConstraints &scope=tcm.getScope(*jgl.criterion);
         print(OL_TRACE, "Gathering constraints from scope %p\n", &scope);
         putDefaultConstraintsToScope(scope, jgl);

         flowExprConstraints(scope, jgl.criterion);
      }

      void visitAfterJoins(AQLQuery &jgl)
      {
         collectConstraints(jgl);
      }

      void visitBeforeChildren(AQLJoinGroup &jgl)
      {
         collectConstraints(jgl);
      }
   };

   class TypeWriterVisitor : public AQLOptionalVisitor
   {
   private:
      TypeConstraintMap &tcm;
      TypeConstraints *scope;

      void writeInferenceMap(AQLJoinGroupLike &jgl)
      {
         jgl.inferenceMap.clear();
         for (AQLJoinGroupLike::name_list::iterator i=jgl.names.begin(); i!=jgl.names.end(); ++i)
         {
            const std::string &joinName=*i;
            jgl.inferenceMap[TripleProperty(joinName, AQLPropertyExpr::SUBJECT)]=
               scope->getConstraint(joinName, AQLPropertyExpr::SUBJECT);
            jgl.inferenceMap[TripleProperty(joinName, AQLPropertyExpr::PREDICATE)]=
               scope->getConstraint(joinName, AQLPropertyExpr::PREDICATE);
            jgl.inferenceMap[TripleProperty(joinName, AQLPropertyExpr::OBJECT)]=
               scope->getConstraint(joinName, AQLPropertyExpr::OBJECT);
         }
      }

   public:
      TypeWriterVisitor(TypeConstraintMap &_tcm) : tcm(_tcm), scope(0)
      {
      }

      void visit(AQLPropertyExpr &expr)
      {
         assert(scope);
         TypeConstraints &localScope=tcm.getScope(expr);
         AQLTypeSet globaltset=scope->getConstraint(expr.joinName, expr.property);
         AQLTypeSet localtset=localScope.getConstraint(expr.joinName, expr.property);
         expr.propertyType=makeIntersection(makeIntersection(globaltset, localtset), expr.propertyType);
      }

      void visitBeforeChildren(AQLFunctionExpr &expr)
      {
         const SQLFunctionType &ft=tcm.getConstraint(expr);
         expr.functionType=makeIntersection(ft.returnType, expr.functionType);
      }

      void visitBeforeChildren(AQLQuery &jgl)
      {
         scope=&(tcm.getScope(*jgl.criterion));
         print(OL_TRACE, "Entering scope %p\n", scope);
      }
      void visitBeforeCriterion(AQLQuery &jgl)
      {
         scope=&(tcm.getScope(*jgl.criterion));
         print(OL_TRACE, "Entering scope %p\n", scope);
         writeInferenceMap(jgl);
      }
      void visitAfterChildren(AQLQuery &jgl)
      {
         print(OL_TRACE, "Exit scope %p\n", scope);
         scope=0;
      }
      void visitBeforeChildren(AQLJoinGroup &jgl)
      {
         scope=&tcm.getScope(*jgl.criterion);
         print(OL_TRACE, "Entering scope %p\n", scope);
         writeInferenceMap(jgl);
      }
      void visitAfterChildren(AQLJoinGroup &jgl)
      {
         print(OL_TRACE, "Exit scope %p\n", scope);
         scope=0;
      }
   };

   class ScopedConstraintPropagator : protected AQLOptionalVisitor
   {
   protected:
      TypeConstraintMap &tcm;
      AQLQuery &aql;

      std::set<const AQLJoinGroupLike *> jgls;
      std::map<std::string, const AQLJoinGroupLike *> joinNameToJgl;
      typedef std::map<const AQLJoinGroupLike *, std::set<const AQLJoinGroupLike *> > children_map;
      children_map childrenMap;

      // this is used when visiting join groups to populate structures above
      std::vector<const AQLJoinGroupLike *> parentStack;

      void visitJglBeforeChildren(const AQLJoinGroupLike &jgl)
      {
         jgls.insert(&jgl);
         for (AQLJoinGroupLike::name_list::const_iterator i=jgl.names.begin(); i!=jgl.names.end(); ++i)
         {
            const std::string &joinName=*i;
            joinNameToJgl[joinName]=&jgl;
         }
         for (size_t i=0; i<parentStack.size(); ++i)
         {
            const AQLJoinGroupLike *parent=parentStack.at(i);
            childrenMap.at(parent).insert(&jgl);
         }
         childrenMap[&jgl];

         parentStack.push_back(&jgl);
      }
      void visitJglAfterChildren(const AQLJoinGroupLike &)
      {
         parentStack.pop_back();
      }

      void visitBeforeJoins(AQLQuery &jgl)
      {
         visitJglBeforeChildren(jgl);
      }
      void visitAfterJoins(AQLQuery &jgl)
      {
         visitJglAfterChildren(jgl);
      }
      void visitBeforeChildren(AQLJoinGroup &jgl)
      {
         visitJglBeforeChildren(jgl);
      }
      void visitAfterChildren(AQLJoinGroup &jgl)
      {
         visitJglAfterChildren(jgl);
      }

      bool propagateConstraintsToChildren()
      {
         bool ret=false;
         for (children_map::const_iterator i=childrenMap.begin(); i!=childrenMap.end(); ++i)
         {
            const AQLJoinGroupLike *parent=i->first;
            TypeConstraints &parentScope=tcm.getScope(*parent->criterion);
            const children_map::mapped_type &childSet=i->second;
            for (children_map::mapped_type::const_iterator j=childSet.begin();
                 j!=childSet.end(); ++j)
            {
               const AQLJoinGroupLike *child=*j;
               TypeConstraints &childScope=tcm.getScope(*child->criterion);
               if (childScope.makeIntersection(parentScope)) ret=true;
            }
         }
         return ret;
      }

      bool ownedTripleRecursive(const std::string &joinName, const AQLJoinGroupLike &jgl)
      {
         const AQLJoinGroupLike &definer=*joinNameToJgl.at(joinName);
         if (&definer==&jgl) return true; // defined by jgl

         const children_map::mapped_type &childSet=childrenMap.at(&jgl);
         if (childSet.find(&definer)!=childSet.end())
         {
            // ownder by a child
            return true;
         }
         // not owned by a child
         return false;
      }

      bool propagateOwnedConstraintsToRoot()
      {
         bool ret=false;
         TypeConstraints &rootScope=tcm.getScope(*aql.criterion);
         for (std::set<const AQLJoinGroupLike *>::const_iterator i=jgls.begin(); i!=jgls.end(); ++i)
         {
            const AQLJoinGroupLike *jgl=*i;
            if (jgl==&aql) continue; // not interested in propagating root stuff to root

            TypeConstraints &exprScope=tcm.getScope(*jgl->criterion);

            for (TypeConstraints::property_constraint_map::const_iterator j=
                    exprScope.propertyConstraintMap.begin();
                 j!=exprScope.propertyConstraintMap.end();
                 ++j)
            {
               const JoinAndProperty &key=j->first;
               const AQLTypeSet &tset=j->second;

               if (!ownedTripleRecursive(key.joinName, *jgl)) continue;

               // ok, this triple is owned by jgl, propagate to root
               if (rootScope.constrain(key.joinName, key.property, tset)) ret=true;
            }
         }
         return ret;
      }

   public:
      ScopedConstraintPropagator(TypeConstraintMap &_tcm, AQLQuery &_aql) : tcm(_tcm), aql(_aql)
      {
         aql.accept(*this);
      }

      bool propagateConstraints()
      {
         bool changes=false;

         if (propagateConstraintsToChildren()) changes=true;
         if (propagateOwnedConstraintsToRoot()) changes=true;

         return changes;
      }




   };

}

namespace TypeRQ {

   /**
    * Runs type inference on the query. Function types are narrowed,
    * constraints are collected and then propagated between scopes, repeatedly,
    * until an iteration changes nothing. After every iteration that changed
    * something the resulting types are written back into the query.
    *
    * @param aql         query to analyse and annotate
    * @param functionMap function mappings handed to the visitors
    * @throws AQLException if no fixed point is reached before the iteration limit
    */
   void performTypeInference(AQLQuery &aql, SQLFunctionMap &functionMap)
   {
      const int iterationLimit=100;

      TypeConstraintMap tcm;
      ScopedConstraintPropagator scp(tcm, aql);

      int round=0;
      while (true)
      {
         print(OL_TRACE, "\n*** Iteration %d***\n\n", round);

         if (round==iterationLimit)
         {
            throw AQLException("Type inference pass did not halt after %d iterations!", round);
         }

         FunctionTypeNarrowerVisitor ftnv(tcm, functionMap);
         print(OL_TRACE, "Narrowing function types...\n");
         aql.accept(ftnv);
         bool changed=ftnv.wasChanged();

         ConstraintCollectorVisitor tccv(tcm, functionMap);
         print(OL_TRACE, "Collecting constraints...\n");
         aql.accept(tccv);
         changed=tccv.wasChanged() || changed;

         print(OL_TRACE, "Propagating constraints...\n");
         // the call comes first so that propagation runs in every iteration
         changed=scp.propagateConstraints() || changed;

         if (!changed)
         {
            print(OL_TRACE, "No new constraints: Fixed point reached\n");
            break;
         }

         TypeWriterVisitor twv(tcm);
         print(OL_TRACE, "Rewriting AQL types per constraints...\n");
         aql.accept(twv);

         ++round;
      }
   }


}
