/***************************************************************************
**
**  This file is part of QGpCoreStat.
**
**  QGpCoreStat is free software: you can redistribute it and/or modify
**  it under the terms of the GNU General Public License as published by
**  the Free Software Foundation, either version 3 of the License, or
**  (at your option) any later version.
**
**  QGpCoreStat is distributed in the hope that it will be useful,
**  but WITHOUT ANY WARRANTY; without even the implied warranty of
**  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
**  GNU General Public License for more details.
**
**  You should have received a copy of the GNU General Public License
**  along with QGpCoreStat.  If not, see <http://www.gnu.org/licenses/>
**
**  See http://www.geopsy.org for more information.
**
**  Created: 2017-06-12
**  Copyright: 2017-2019
**    Marc Wathelet (ISTerre, Grenoble, France)
**
***************************************************************************/

#include "GaussianMixtureDistribution.h"

namespace QGpCoreStat {

  /*!
    \class GaussianMixtureDistribution GaussianMixtureDistribution.h
    \brief A Gaussian mixture complex distribution

    GaussianMixtureDistribution is a collection of Gauss distributions associated with a weight.
    It represents a complex statistical distribution with several modes.
  */

  /*!
    Builds a mixture made of a single mode with a single dimension.
    The weight of this mode is set to 1.0 by allocate().
  */
  GaussianMixtureDistribution::GaussianMixtureDistribution()
  {
    TRACE;
    // One mode, one dimension: allocate() always creates the arrays in this case
    allocate(1, 1, _weights, _modes);
    _modeCount=1;
  }

  /*!
    Builds a mixture of \a modeCount modes. Each mode is initialized with
    \a dimensionCount dimensions and a weight of 1.0.

    If \a modeCount or \a dimensionCount is less than 1, allocate() does not create
    any array. In that case the mixture is left empty: the mode count is set to 0 so
    that the loops over modes never access the null arrays.
  */
  GaussianMixtureDistribution::GaussianMixtureDistribution(int modeCount, int dimensionCount)
  {
    TRACE;
    allocate(modeCount, dimensionCount, _weights, _modes);
    // Keep the mode count consistent with what was really allocated
    _modeCount=(_modes && _weights) ? modeCount : 0;
  }

  /*!
    Copy constructor. Starts from an empty mixture (no array, no mode) and
    delegates the copy of \a o to the assignment operator.
  */
  GaussianMixtureDistribution::GaussianMixtureDistribution(const GaussianMixtureDistribution& o)
  {
    TRACE;
    // The assignment operator reads these members to decide whether arrays must be
    // released and reallocated: they must have a defined value before the call.
    _modeCount=0;
    _weights=nullptr;
    _modes=nullptr;
    operator=(o);
  }

  /*!
    Releases the arrays of modes and weights through clear().
  */
  GaussianMixtureDistribution::~GaussianMixtureDistribution()
  {
    TRACE;
    this->clear();
  }

  /*!
    Makes this mixture a copy of \a o: same number of modes, and each mode and weight
    copied element by element.

    The arrays are reallocated only if the number of modes differs or if this object
    has no storage yet. New arrays are allocated before the old ones are released, so
    that a failing allocation does not leave this object with dangling pointers.

    If \a o declares modes but has no storage (arrays are null), there is nothing valid
    to copy and this mixture is left empty (no array, mode count set to 0).

    Assigning an object to itself does nothing.
  */
  void GaussianMixtureDistribution::operator=(const GaussianMixtureDistribution& o)
  {
    TRACE;
    if(this==&o) {
      return;
    }
    if(o._modeCount!=0 && (!o._modes || !o._weights)) {
      // Source without storage: copying its elements would dereference null pointers
      clear();
      _modeCount=0;
      return;
    }
    const bool missingStorage=o._modeCount>0 && (!_modes || !_weights);
    if(_modeCount!=o._modeCount || missingStorage) {
      // Allocate first, release afterwards
      MultivariateNormalDistribution * newModes=new MultivariateNormalDistribution[o._modeCount];
      double * newWeights=new double[o._modeCount];
      delete [] _modes;
      delete [] _weights;
      _modes=newModes;
      _weights=newWeights;
      _modeCount=o._modeCount;
    }
    for(int im=0; im<_modeCount; im++) {
      _modes[im]=o._modes[im];
      _weights[im]=o._weights[im];
    }
  }

  void GaussianMixtureDistribution::allocate(int modeCount, int dimensionCount,
                                             double *& weights,
                                             MultivariateNormalDistribution *& modes)
  {
    if(modeCount>=1 && dimensionCount>=1) {
      modes=new MultivariateNormalDistribution[modeCount];
      weights=new double[modeCount];
      for(int im=0; im<modeCount; im++) {
        modes[im].init(dimensionCount);
        weights[im]=1.0;
      }
    } else {
      modes=nullptr;
      weights=nullptr;
    }
  }

  /*!
    Releases the arrays of modes and weights and resets both pointers to null.
    The mode count is not modified by this function.
  */
  void GaussianMixtureDistribution::clear()
  {
    // delete [] accepts a null pointer, no test is required
    delete [] _modes;
    delete [] _weights;
    _modes=nullptr;
    _weights=nullptr;
  }

  /*!
    Calls MultivariateNormalDistribution::clearVariance() for all modes.
  */
  void GaussianMixtureDistribution::clearVariance()
  {
    int im=0;
    while(im<_modeCount) {
      _modes[im].clearVariance();
      ++im;
    }
  }

  double GaussianMixtureDistribution::value(int mode, const Vector<double>& x) const
  {
    return _modes[mode].value(x)*_weights[mode];
  }

  /*!
    Returns the value of the mixture at \a x: the sum over all modes of the value
    of the mode at \a x multiplied by its weight.
  */
  double GaussianMixtureDistribution::value(const Vector<double>& x) const
  {
    double total=0.0;
    for(int im=0; im<_modeCount; ++im) {
      total+=_modes[im].value(x)*_weights[im];
    }
    return total;
  }

  /*!
    Returns the weighted sum over all modes of MultivariateNormalDistribution::value1D()
    evaluated for \a index and \a x.
  */
  double GaussianMixtureDistribution::value1D(int index, double x) const
  {
    double total=0.0;
    for(int im=0; im<_modeCount; ++im) {
      total+=_modes[im].value1D(index, x)*_weights[im];
    }
    return total;
  }

  /*!
    Returns the weighted sum over all modes of
    MultivariateNormalDistribution::cumulativeValue1D() evaluated for \a index and \a x.
  */
  double GaussianMixtureDistribution::cumulativeValue1D(int index, double x) const
  {
    double total=0.0;
    for(int im=0; im<_modeCount; ++im) {
      total+=_modes[im].cumulativeValue1D(index, x)*_weights[im];
    }
    return total;
  }

  /*!
    Prepares the replacement of the current modes: allocates in \a newWeights and
    \a newModes arrays for \a newModeCount modes with the current number of dimensions.

    Returns false if the arrays were not allocated (see allocate()), true otherwise.
    The current modes and weights are not modified.
  */
  bool GaussianMixtureDistribution::beginFilter(double *& newWeights,
                                                MultivariateNormalDistribution *& newModes,
                                                int newModeCount)
  {
    TRACE;
    allocate(newModeCount, dimensionCount(), newWeights, newModes);
    if(!newWeights) {
      return false;
    }
    return newModes!=nullptr;
  }

  /*!
    Replaces the current modes by the first \a newModeCount elements of \a newWeights and
    \a newModes. The previous arrays are released and this object keeps the new pointers,
    which are released later by clear().

    The weights are then scaled so that their sum is 1. If the sum of the weights is null
    (for instance when no mode is left), the weights are kept as they are: dividing by
    zero would turn them into infinite or NaN values.
  */
  void GaussianMixtureDistribution::endFilter(double * newWeights,
                                              MultivariateNormalDistribution * newModes,
                                              int newModeCount)
  {
    TRACE;
    clear();
    _modeCount=newModeCount;
    _weights=newWeights;
    _modes=newModes;
    // Adjust weights so that sum==1
    double sum=0.0;
    for(int i=0; i<_modeCount; i++) {
      sum+=_weights[i];
    }
    if(sum!=0.0) {
      const double factor=1.0/sum;
      for(int i=0; i<_modeCount; i++) {
        _weights[i]*=factor;
      }
    }
  }

  /*!
    Returns one SortIndex per mode, in the current order of the modes. Each SortIndex
    is built from this object and the index of the mode. The returned list is meant to
    be sorted and then passed to commitSort().
  */
  VectorList<GaussianMixtureDistribution::SortIndex> GaussianMixtureDistribution::sortHelper()
  {
    TRACE;
    VectorList<SortIndex> indexes(_modeCount);
    int im=0;
    while(im<_modeCount) {
      indexes[im]=SortIndex(this, im);
      ++im;
    }
    return indexes;
  }

  /*!
    Reorders modes and weights according to \a s: the mode at position i after this call
    is the mode that was at position s[i]._i before. \a s is expected to contain one
    entry per mode, as returned by sortHelper().

    Nothing is changed if beginFilter() cannot allocate the new arrays.
  */
  void GaussianMixtureDistribution::commitSort(VectorList<SortIndex>& s)
  {
    TRACE;
    double * sortedWeights;
    MultivariateNormalDistribution * sortedModes;
    if(!beginFilter(sortedWeights, sortedModes, _modeCount)) {
      return;
    }
    for(int rank=0; rank<_modeCount; ++rank) {
      int source=s[rank]._i;
      sortedWeights[rank]=_weights[source];
      sortedModes[rank]=_modes[source];
    }
    endFilter(sortedWeights, sortedModes, _modeCount);
  }

  /*!
    Sort modes by decreasing weight
  */
  void GaussianMixtureDistribution::sortWeight()
  {
    TRACE;
    VectorList<SortIndex> s=sortHelper();
    std::sort(s.begin(), s.end(), lessThanWeight);
    commitSort(s);
  }

  /*!
    Sort modes by decreasing mean
  */
  void GaussianMixtureDistribution::sortMean()
  {
    TRACE;
    VectorList<SortIndex> s=sortHelper();
    std::sort(s.begin(), s.end(), lessThanMean);
    commitSort(s);
  }

  /*!
    Comparison function for std::sort(): returns true if the weight of the mode referenced
    by \a i1 is strictly greater than the weight of the mode referenced by \a i2.
    Used as a "less than" predicate, it produces a decreasing order of weights.
  */
  bool GaussianMixtureDistribution::lessThanWeight(const SortIndex& i1, const SortIndex& i2)
  {
    const double w1=i1._parent->weight(i1._i);
    const double w2=i2._parent->weight(i2._i);
    return w1>w2;
  }

  bool GaussianMixtureDistribution::lessThanMean(const SortIndex& i1, const SortIndex& i2)
  {
    return i1._parent->mode(i1._i).mean()>i2._parent->mode(i2._i).mean();
  }

  /*!
    Remove modes with a weight lower than \a min
  */
  void GaussianMixtureDistribution::filterWeight(double min)
  {
    TRACE;
    double * newWeights;
    MultivariateNormalDistribution * newModes;
    int newModeCount=0;
    if(beginFilter(newWeights, newModes, _modeCount)) {
      for(int i=0; i<_modeCount; i++) {
        if(_weights[i]>min) {
          newWeights[newModeCount]=_weights[i];
          newModes[newModeCount]=_modes[i];
          newModeCount++;
        }
      }
      endFilter(newWeights, newModes, newModeCount);
    }
  }

  /*!
    Over the interval +/- \a stddevFactor*stddev the values cannot exceed the value
    at the mean of the considered mode.

    In practice, a mode is removed as soon as another mode fulfils both conditions:
    \li the standard deviations of the considered mode multiplied by \a stddevFactor
        satisfy hasElementGreaterThan() against the absolute difference between the
        two means,
    \li the value of the mixture at the mean of the other mode is larger than the value
        of the mixture at the mean of the considered mode.

    The order of the remaining modes is kept and their weights are normalized by
    endFilter(). Nothing is changed if beginFilter() cannot allocate the new arrays.
  */
  void GaussianMixtureDistribution::filterDominantModes(double stddevFactor)
  {
    TRACE;
    double * newWeights;
    MultivariateNormalDistribution * newModes;
    int newModeCount=0;
    if(beginFilter(newWeights, newModes, _modeCount)) {
      for(int im1=0; im1<_modeCount; im1++) {
        const MultivariateNormalDistribution& d1=_modes[im1];
        double m1=value(d1.mean());
        // The scaled standard deviations depend only on the considered mode:
        // computed once here instead of once for each pair of modes.
        PrivateVector<double> s(d1.count());
        d1.stddev(s);
        s*=stddevFactor;
        bool dominated=false;
        for(int im2=0; im2<_modeCount && !dominated; im2++) {
          if(im1==im2) {
            continue;
          }
          const MultivariateNormalDistribution& d2=_modes[im2];
          // Absolute distance between the two means, for each dimension
          PrivateVector<double> v(d2.mean());
          v-=d1.mean();
          v.abs();
          dominated=s.hasElementGreaterThan(v) && m1<value(d2.mean());
        }
        if(!dominated) {
          newWeights[newModeCount]=_weights[im1];
          newModes[newModeCount]=_modes[im1];
          newModeCount++;
        }
      }
      endFilter(newWeights, newModes, newModeCount);
    }
  }

  /*!
    Keeps modes only with a mean in dimension \a dim between \a min and \a max.
  */
  void GaussianMixtureDistribution::filterRange(int dim, double min, double max)
  {
    TRACE;
    double * newWeights;
    MultivariateNormalDistribution * newModes;
    int newModeCount=0;
    if(beginFilter(newWeights, newModes, _modeCount)) {
      for(int i=0; i<_modeCount; i++) {
        double m=_modes[i].mean(dim);
        if(m>min && m<max) {
          newWeights[newModeCount]=_weights[i];
          newModes[newModeCount]=_modes[i];
          newModeCount++;
        }
      }
      endFilter(newWeights, newModes, newModeCount);
    }
  }

  /*!
    Keeps only the \a maxCount best weights.
    Modes must be sorted by decreasing weight.
  */
  void GaussianMixtureDistribution::bestWeights(int maxCount)
  {
    TRACE;
    if(maxCount>=_modeCount) {
      return;
    }
    double * newWeights;
    MultivariateNormalDistribution * newModes;
    if(beginFilter(newWeights, newModes, maxCount)) {
      for(int i=0; i<maxCount; i++) {
        newWeights[i]=_weights[i];
        newModes[i]=_modes[i];
      }
      endFilter(newWeights, newModes, maxCount);
    }
  }

  /*!
    If identical modes are found, their weights are summed.
    Modes must be sorted by decreasing mean.

    A mode is considered as identical to the previous one when isSimilar() called with
    a parameter of 0.01 returns true. In that case only the first mode of the series is
    kept and the weights of the following ones are added to it. The weights are finally
    normalized by endFilter().

    Nothing is changed if beginFilter() cannot allocate the new arrays.
  */
  void GaussianMixtureDistribution::unique()
  {
    TRACE;
    double * newWeights;
    MultivariateNormalDistribution * newModes;
    if(beginFilter(newWeights, newModes, _modeCount)) {
      // beginFilter() succeeded: there is at least one mode and the first one is always kept
      newWeights[0]=_weights[0];
      newModes[0]=_modes[0];
      int newModeCount=1;
      for(int i=1; i<_modeCount; i++) {
        if(_modes[i].isSimilar(_modes[i-1], 0.01)) {
          // Accumulate in the last kept mode, not in the next free slot
          newWeights[newModeCount-1]+=_weights[i];
        } else {
          newWeights[newModeCount]=_weights[i];
          newModes[newModeCount]=_modes[i];
          newModeCount++;
        }
      }
      endFilter(newWeights, newModes, newModeCount);
    }
  }

  /*!
    Removes modes that overlap.

    Modes are scanned in their current order. For each mode that still has a strictly
    positive weight when it is reached, all other modes with a strictly positive weight
    are tested with overlap(): when it returns true, the weights of both modes are set
    to zero. Modes without a strictly positive weight are finally removed by filterWeight().
  */
  void GaussianMixtureDistribution::excludeMerge()
  {
    TRACE;
    for(int im=0; im<_modeCount; ++im) {
      // Tested only once per mode: the scan continues even if this weight is
      // set to zero by an overlap found below.
      if(!(_weights[im]>0.0)) {
        continue;
      }
      for(int jm=0; jm<_modeCount; ++jm) {
        if(im!=jm && _weights[jm]>0.0 && _modes[im].overlap(_modes[jm])) {
          _weights[im]=0.0;
          _weights[jm]=0.0;
        }
      }
    }
    filterWeight(0.0);
  }

  /*!
    Returns a text description of the mixture. For each mode:
    \li one line with the index of the mode and its weight,
    \li one line per dimension with the mean and the standard deviation
        (format 'g' with a precision of 6),
    \li if \a varianceMatrix is true, the variance matrix of the mode as returned
        by toUserString(), once per mode.
  */
  QString GaussianMixtureDistribution::toString(bool varianceMatrix)
  {
    QString t;
    const QString formatMode("%1: %2\n");
    const QString formatDim("  %1 +/- %2\n");
    for(int im=0; im<_modeCount; im++) {
      t+=formatMode.arg(im).arg(_weights[im]);
      const MultivariateNormalDistribution& d=_modes[im];
      for(int id=0; id<d.count(); id++) {
        t+=formatDim.arg(d.mean(id), 0, 'g', 6).arg(d.stddev(id), 0, 'g', 6);
      }
      // The variance matrix describes the whole mode: output it after the list of
      // dimensions and not after each of them.
      if(varianceMatrix) {
        t+=d.variance().toUserString();
      }
    }
    return t;
  }

} // namespace QGpCoreStat

