/***************************************************************************
**
**  This file is part of QGpCoreStat.
**
**  QGpCoreStat is free software: you can redistribute it and/or modify
**  it under the terms of the GNU General Public License as published by
**  the Free Software Foundation, either version 3 of the License, or
**  (at your option) any later version.
**
**  QGpCoreStat is distributed in the hope that it will be useful,
**  but WITHOUT ANY WARRANTY; without even the implied warranty of
**  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
**  GNU General Public License for more details.
**
**  You should have received a copy of the GNU General Public License
**  along with QGpCoreStat.  If not, see <http://www.gnu.org/licenses/>
**
**  See http://www.geopsy.org for more information.
**
**  Created: 2023-06-16
**  Copyright: 2023
**    Marc Wathelet (ISTerre, Grenoble, France)
**
***************************************************************************/

#include "GaussianMixtureEM.h"

namespace QGpCoreStat {

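  /*!
    \class GaussianMixtureEM GaussianMixtureEM.h
    \brief Fits a Gaussian mixture distribution to a multivariate histogram
           with the expectation-maximization (EM) algorithm.

    Typical use: call init() with the histogram and the number of modes,
    set a starting point with uniform(), then call iterate() repeatedly.
    Each iterate() call performs one EM step and replaces the current
    mixture estimate. Currently only mode weights and means are
    re-estimated; mode variances are carried over from the previous estimate.
  */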

  /*!
    Creates an estimator with neither a histogram nor a mixture attached.
    init() must be called before uniform() or iterate(), which both
    dereference these pointers.
  */
  GaussianMixtureEM::GaussianMixtureEM()
    : _histogram(nullptr),
      _distribution(nullptr)
  {
  }

  /*!
    Releases the mixture estimate owned by this object. The histogram given
    to init() is only referenced, never owned, and is left untouched.
  */
  GaussianMixtureEM::~GaussianMixtureEM()
  {
    GaussianMixtureDistribution * owned=_distribution;
    _distribution=nullptr;
    delete owned;
  }

  void GaussianMixtureEM::init(const MultivariateHistogram * h, int modeCount)
  {
    _histogram=h;
    delete _distribution;
    _distribution=new GaussianMixtureDistribution(modeCount, _histogram->dimensionCount());
  }

  /*!
    Sets an evenly spread starting point for EM.

    Variances are first cleared with clearVariance(). Every mode then gets
    the same weight 1/modeCount. Along each dimension iDim, mode iMode is
    placed at min[iDim]+(iMode+1)*step with step=(max[iDim]-min[iDim])/(modeCount+1),
    so the modes lie strictly inside [min, max], and receives the standard
    deviation stddev[iDim]. Finally setFactor() is called for every mode once
    all of its components are set.

    \a min, \a max and \a stddev must have at least as many components as the
    histogram has dimensions. init() must have been called first.
  */
  void GaussianMixtureEM::uniform(const Vector<double>& min,
                                  const Vector<double>& max,
                                  const Vector<double>& stddev)
  {
    _distribution->clearVariance();
    int nModes=_distribution->modeCount();
    int nDims=_histogram->dimensionCount();
    // The weight does not depend on the dimension: set it once per mode
    // rather than once per (dimension, mode) pair.
    double weight=1.0/nModes;
    for(int iMode=0; iMode<nModes; iMode++) {
      _distribution->setWeight(iMode, weight);
    }
    for(int iDim=nDims-1; iDim>=0; iDim--) {
      double step=(max[iDim]-min[iDim])/(nModes+1.0);
      for(int iMode=0; iMode<nModes; iMode++) {
        _distribution->setMode(iMode, iDim, min[iDim]+(iMode+1)*step, stddev[iDim]);
      }
    }
    for(int iMode=nModes-1; iMode>=0; iMode--) {
      _distribution->setFactor(iMode);
    }
  }

  void GaussianMixtureEM::iterate()
  {
    int nDims=_histogram->dimensionCount();
    int nBuckets=_histogram->bucketCount();
    PrivateVector<double> pos(nDims);
    double * h=new double[nBuckets];
    double * hn=new double[nBuckets];
    double sampleNormFactor=0;
    for(int iBucket=0; iBucket<nBuckets; iBucket++) {
      _histogram->bucketPosition(iBucket, pos);
      hn[iBucket]=1.0/_distribution->value(pos);
      sampleNormFactor+=_histogram->bucketValueAt(iBucket);
    }
    sampleNormFactor=1.0/sampleNormFactor;
    GaussianMixtureDistribution * newDistribution;
    newDistribution=new GaussianMixtureDistribution(_distribution->modeCount(),
                                                    _distribution->dimensionCount());
    for(int iMode=_distribution->modeCount()-1; iMode>=0; iMode--) {
      double wSum=0.0;
      PrivateVector<double> muSum(nDims, 0.0);
      for(int iBucket=0; iBucket<nBuckets; iBucket++) {
        _histogram->bucketPosition(iBucket, pos);
        double sampleCount=_histogram->bucketValueAt(iBucket);
        h[iBucket]=_distribution->value(iMode, pos)*hn[iBucket];
        wSum+=h[iBucket]*sampleCount;
        for(int iDim=0; iDim<nDims; iDim++) {
          muSum[iDim]+=h[iBucket]*pos[iDim]*sampleCount;
        }
      }
      double muNormFactor=1.0/wSum;
      for(int iDim=0; iDim<nDims; iDim++) {
        muSum[iDim]*=muNormFactor;
      }
      wSum*=sampleNormFactor;
      DoubleMatrix sigmaSum(nDims);
      sigmaSum.zero();
      for(int iBucket=0; iBucket<nBuckets; iBucket++) {
        _histogram->bucketPosition(iBucket, pos);
        double sampleCount=_histogram->bucketValueAt(iBucket);
        DoubleMatrix x0(nDims, 1);
        for(int iDim=0; iDim<nDims; iDim++) {
          x0.at(iDim, 0)=pos[iDim]-muSum[iDim];
        }
        x0*=x0.transposed();
        x0*=h[iBucket]*sampleCount;
        sigmaSum+=x0;
      }
      sigmaSum*=muNormFactor;
      MultivariateNormalDistribution& d=newDistribution->mode(iMode);
      newDistribution->setWeight(iMode, wSum);
      d.setMean(muSum);
      d.setVariance(_distribution->mode(iMode).variance());
      //d.setVariance(sigmaSum);
    }
    delete [] h;
    delete [] hn;
    delete _distribution;
    _distribution=newDistribution;
  }


} // namespace QGpCoreStat

