/***************************************************************************
**
**  This file is part of QGpCoreMath.
**
**  QGpCoreMath is free software: you can redistribute it and/or modify
**  it under the terms of the GNU General Public License as published by
**  the Free Software Foundation, either version 3 of the License, or
**  (at your option) any later version.
**
**  QGpCoreMath is distributed in the hope that it will be useful,
**  but WITHOUT ANY WARRANTY; without even the implied warranty of
**  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
**  GNU General Public License for more details.
**
**  You should have received a copy of the GNU General Public License
**  along with QGpCoreMath.  If not, see <http://www.gnu.org/licenses/>
**
**  See http://www.geopsy.org for more information.
**
**  Created: 2022-12-05
**  Copyright: 2022
**    Marc Wathelet (ISTerre, Grenoble, France)
**
***************************************************************************/


#include "MultivariateNormalDistribution.h"
#include "MultivariateStatistics.h"
#include "NormalDistribution.h"
#include "StatisticalValue.h"

namespace QGpCoreMath {

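  /*!
    \class MultivariateNormalDistribution MultivariateNormalDistribution.h
    \brief Multivariate normal distribution for a diagonal covariance matrix

    The distribution is stored as a mean vector, the inverse of its covariance
    matrix and a normalization factor applied by value(). The per-dimension
    helpers (value1D(), cumulativeValue1D()) rely on the covariance being diagonal.
  */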

  /*!
    Creates a distribution with \a dimensionCount dimensions. The mean is null,
    the inverse covariance matrix is all zeros and the normalization factor is 1.
  */
  MultivariateNormalDistribution::MultivariateNormalDistribution(int dimensionCount)
    : _mean(dimensionCount, 0.0),
      _invVariance(dimensionCount, dimensionCount)
  {
    _factor=1.0;
    _invVariance.zero();
  }

  /*!
    Builds the distribution from its \a mean vector and its \a covariance matrix.
    \a covariance must be square, with as many rows as \a mean has elements.
    The normalization factor is computed from the determinant of \a covariance.
    The matrix is then inverted in place and kept as the inverse covariance.

    NOTE(review): the factor is 1/sqrt(2*pi*det(covariance)). The standard
    k-dimensional normal density uses (2*pi)^k instead of 2*pi. Confirm this
    is intended for more than one dimension (setFactor() uses the same form).
  */
  MultivariateNormalDistribution::MultivariateNormalDistribution(const Vector<double>& mean,
                                                                 const DoubleMatrix& covariance)
    : _mean(mean),
      _invVariance(covariance)
  {
    ASSERT(mean.count()==_invVariance.rowCount());
    ASSERT(_invVariance.rowCount()==_invVariance.columnCount());
    // Determinant taken before inversion, i.e. on the covariance itself.
    const double det=_invVariance.determinant();
    _factor=1.0/sqrt(2.0*M_PI*det);
    _invVariance.invert();
  }

  /*!
    Builds the distribution from the mean and the variance estimated by \a stat.
  */
  MultivariateNormalDistribution::MultivariateNormalDistribution(const MultivariateStatistics& stat)
    : _mean(stat.dimensionCount())
  {
    stat.mean(_mean);
    // setFactor() expects the covariance in _invVariance and inverts it in place.
    _invVariance=stat.variance();
    setFactor();
  }

  /*!
    Copies the mean, the inverse covariance and the normalization factor of \a o.
  */
  void MultivariateNormalDistribution::operator=(const MultivariateNormalDistribution& o)
  {
    _factor=o._factor;
    _mean=o._mean;
    _invVariance=o._invVariance;
  }

  /*!
    Returns false as soon as, in some dimension, either the standard deviations
    or the means of this distribution and \a o differ by more than \a precision
    times this distribution's standard deviation. Returns true otherwise.

    The tolerance is derived from this object only, so a.isSimilar(b) and
    b.isSimilar(a) may disagree.
  */
  bool MultivariateNormalDistribution::isSimilar(const MultivariateNormalDistribution& o, double precision)
  {
    const int dimCount=_mean.count();
    for(int d=0; d<dimCount; d++) {
      const double sigma=stddev(d);
      const double tolerance=sigma*precision;
      if(fabs(sigma-o.stddev(d))>tolerance) {
        return false;
      }
      if(fabs(_mean[d]-o._mean[d])>tolerance) {
        return false;
      }
    }
    return true;
  }

  /*!
    Resizes the inverse covariance matrix and the mean vector to \a dimension.
    The normalization factor is left untouched. Whether existing values are
    preserved depends on Vector/DoubleMatrix::resize().
  */
  void MultivariateNormalDistribution::init(int dimension)
  {
    _invVariance.resize(dimension);
    _mean.resize(dimension);
  }

  /*!
    Inverts, in place, the covariance matrix currently held in _invVariance.
    It then sets the normalization factor to sqrt(det(inverse)/(2*pi)), which
    equals 1/sqrt(2*pi*det(covariance)).

    NOTE(review): the standard k-dimensional normal density uses (2*pi)^k
    rather than 2*pi here. Confirm this is intended when the dimension is
    larger than one.
  */
  void MultivariateNormalDistribution::setFactor()
  {
    _invVariance.invert();
    const double invDet=_invVariance.determinant();
    _factor=sqrt(invDet/(2.0*M_PI));
  }

  void MultivariateNormalDistribution::clearVariance()
  {
    _invVariance.zero();
  }

  void MultivariateNormalDistribution::setStddev(int index, double s)
  {
    _invVariance.at(index, index)=s*s;
  }

  void MultivariateNormalDistribution::setVariance(const DoubleMatrix& v)
  {
    _invVariance=v;
    setFactor();
  }

  double MultivariateNormalDistribution::value(const Vector<double>& x) const
  {
    PrivateVector<double> x0(x), xs(_mean.count());
    x0-=_mean;
    xs.multiply(_invVariance, x0);
    return _factor*exp(-0.5*xs.scalarProductConjugate(x0));
  }

  double MultivariateNormalDistribution::value1D(int index, double x) const
  {
    NormalDistribution d(_mean.at(index), 1.0/sqrt(_invVariance.at(index, index)));
    return d.value(x);
  }

  double MultivariateNormalDistribution::cumulativeValue1D(int index, double x) const
  {
    NormalDistribution d(_mean.at(index), 1.0/sqrt(_invVariance.at(index, index)));
    return d.cumulativeValue(x);
  }

  /*!
    Returns the covariance matrix, obtained by inverting a copy of the stored
    inverse covariance. The object itself is not modified.
  */
  DoubleMatrix MultivariateNormalDistribution::variance() const
  {
    DoubleMatrix covariance(_invVariance);
    covariance.invert();
    return covariance;
  }

  /*!
    Returns false if there is at least one dimension where the difference between
    the means is larger than the sum of the standard deviations of both
    distributions. In that dimension, the one-sigma intervals
    [mean-stddev, mean+stddev] are disjoint.
    Returns true otherwise, i.e. when the one-sigma intervals overlap (or touch)
    in every dimension.
  */
  bool MultivariateNormalDistribution::overlap(const MultivariateNormalDistribution& o) const
  {
    const int n=_mean.count();
    for(int i=0; i<n; i++) {
      // One-sigma intervals are disjoint along axis i: the distributions are separated.
      if(fabs(_mean[i]-o._mean[i])>stddev(i)+o.stddev(i)) {
        return false;
      }
    }
    return true;
  }

  /*!
    Writes the standard deviation of every dimension into \a s.
    \a s is not resized here. It must already hold at least as many elements
    as there are dimensions.
  */
  void MultivariateNormalDistribution::stddev(Vector<double>& s) const
  {
    const int dimCount=_mean.count();
    for(int d=0; d<dimCount; ++d) {
      s[d]=stddev(d);
    }
  }

} // namespace QGpCoreMath

