/***************************************************************************
**
**  This file is part of ArrayCore.
**
**  ArrayCore is free software: you can redistribute it and/or modify
**  it under the terms of the GNU General Public License as published by
**  the Free Software Foundation, either version 3 of the License, or
**  (at your option) any later version.
**
**  ArrayCore is distributed in the hope that it will be useful,
**  but WITHOUT ANY WARRANTY; without even the implied warranty of
**  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
**  GNU General Public License for more details.
**
**  You should have received a copy of the GNU General Public License
**  along with ArrayCore.  If not, see <http://www.gnu.org/licenses/>
**
**  See http://www.geopsy.org for more information.
**
**  Created: 2022-08-29
**  Copyright: 2022
**    Marc Wathelet (ISTerre, Grenoble, France)
**
***************************************************************************/

#include "Wavefield.h"

namespace ArrayCore {

  /*!
    \class Wavefield Wavefield.h
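    \brief Simulated wavefield of plane waves recorded on a 2D sensor array

    Each wave is defined by its amplitude, wavenumber, direction angle theta
    and ellipticity xi, plus one phase per block (zero for the first block,
    random for the others, see addWave()). Field vectors (east, north and
    vertical components of all sensors) are built per block and averaged into
    a cross-spectral matrix, to which incoherent noise etaH (horizontal) and
    etaZ (vertical) is added on the diagonal (crossSpectrum()). Amplitudes,
    ellipticities and noise parameters (R, sigma) can be adjusted with
    setParameters() to fit observations (misfit(), misfitDerivative()).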
  */

  // Phase generator shared by all Wavefield instances. The first constructor
  // call creates it and registers it with CoreApplication as a global object
  // (presumably CoreApplication releases it at exit — confirm).
  GlobalRandom * Wavefield::_phaseGenerator(nullptr);

  /*!
    Exact member-wise comparison of two waves. It covers the per-block phases
    and the precomputed wavenumber vector and trigonometric values.
  */
  bool Wavefield::Wave::operator==(const Wave& o) const
  {
    // Wave parameters
    if(!(a==o.a && k==o.k && theta==o.theta && xi==o.xi)) {
      return false;
    }
    // Per-block phases and wavenumber vector
    if(!(phi==o.phi && kvec==o.kvec)) {
      return false;
    }
    // Precomputed cos/sin values
    return ctheta==o.ctheta && stheta==o.stheta &&
           cxi==o.cxi && sxi==o.sxi;
  }

  /*!
    Creates a wavefield with a single block, no incoherent noise, zero signal
    power and no field vectors allocated. The first instance also creates the
    phase generator shared by all instances and registers it with
    CoreApplication.
  */
  Wavefield::Wavefield()
  {
    _etaH=0.0;
    _etaZ=0.0;
    // setIncoherentNoise(), getParameters() and incoherentNoiseRation() read
    // _signalPower. Give it a defined value until setSignalPower() or
    // setParameters() is called.
    _signalPower=0.0;
    _blockCount=1;
    if(!_phaseGenerator) {
      _phaseGenerator=new GlobalRandom;
      CoreApplication::instance()->addGlobalObject(_phaseGenerator);
    }
    _fields=nullptr;
  }

  /*!
    Releases the per-block field vectors allocated by allocateFieldVectors().
    NOTE(review): _fields is a raw owning array. Unless the header defines a
    copy constructor/assignment that deep-copies it (not visible here),
    copying a Wavefield would lead to a double delete.
  */
  Wavefield::~Wavefield()
  {
    delete[] _fields; // no-op when _fields is nullptr
  }

  /*!
    Exact comparison of two wavefields: block count, per-block field vectors,
    noise levels, waves, sensors and exponential cache.
    _fields stays null until allocateFieldVectors() runs with sensors and a
    positive block count. With a positive block count, field vectors are
    compared only when both sides have them. If exactly one side has them,
    the wavefields are not equal.
  */
  bool Wavefield::operator==(const Wavefield& o) const
  {
    if(_blockCount!=o._blockCount) {
      return false;
    }
    if(_blockCount>0) {
      if(_fields && o._fields) {
        for(int ib=0; ib<_blockCount; ib++) {
          if(_fields[ib]!=o._fields[ib]) {
            return false;
          }
        }
      } else if(_fields || o._fields) {
        // Only one side has allocated field vectors
        return false;
      }
      // Both null: no field vectors to compare
    }
    return _etaH==o._etaH &&
        _etaZ==o._etaZ &&
        _waves==o._waves &&
        _sensors==o._sensors &&
        _expCache==o._expCache;
  }

  /*!
    Returns the position of sensor \a index. Positions are stored as complex
    numbers (x in the real part, y in the imaginary part).
  */
  Point2D Wavefield::sensor(int index) const
  {
    const Complex& pos=_sensors.at(index);
    double x=pos.re();
    double y=pos.im();
    return Point2D(x, y);
  }

  /*!
    Sets the number of blocks and reallocates the field vectors.
    allocateFieldVectors() does nothing while no sensor is set.
    NOTE(review): waves already added keep the number of phases they got in
    addWave(). Call this before adding waves.
  */
  void Wavefield::setBlockCount(int blockCount)
  {
    _blockCount=blockCount;
    allocateFieldVectors(); // field vectors depend on the block count
  }

  void Wavefield::setSensors(const VectorList<Point2D>& pos)
  {
    int n=pos.count();
    _sensors.resize(n);
    for(int i=0; i<n; i++) {
      const Point2D& p=pos.at(i);
      Complex& r=_sensors[i];
      r.setRe(p.x());
      r.setIm(p.y());
    }
    allocateFieldVectors();
  }

  void Wavefield::allocateFieldVectors()
  {
    int ns=_sensors.count();
    if(_blockCount>0 && ns>0) {
      delete [] _fields;
      _fields=new ComplexMatrix[_blockCount];
      for(int i=_blockCount-1; i>=0; i--) {
        _fields[i].resize(3*ns, 1);
      }
    }
  }

  /*!
    Appends a plane wave with \a amplitude, \a wavenumber, angle \a theta and
    ellipticity \a xi. One phase is stored per block. Block 0 is the phase
    reference (phase 0). The other blocks get a phase drawn uniformly between
    0 and 2*pi from the shared phase generator.
    NOTE(review): the number of phases is the block count at the time of the
    call. setBlockCount() does not resize the phases of existing waves.
  */
  void Wavefield::addWave(double amplitude, double wavenumber,
                          double theta, double xi)
  {
    _waves.append(Wave());
    Wave& w=_waves.last();
    w.a=amplitude;
    w.k=wavenumber;
    w.theta=theta;
    w.xi=xi;
    //Random pg(qRound(Angle::radiansToDegrees(Angle::canonicalRadians(w.theta+1))));
    w.phi.resize(_blockCount);
    // Block 0 is the phase reference. Set it explicitly rather than relying
    // on resize() to value-initialize the new element.
    if(_blockCount>0) {
      w.phi[0]=0.0;
    }
    for(int i=1; i<_blockCount; i++) {
      w.phi[i]=_phaseGenerator->uniform(0.0, 2.0*M_PI);
      //w.phi[i]=pg.uniform(0.0, 2.0*M_PI);
      //printf("Wavefield random phase block[%i]=%lf\n", i, w.phi[i]);
    }
    // precomputed values, minus in front of theta to transform a product
    // of complex numbers into a scalar product (the real part)
    // (a+jb)*(c-jd)=(ac+bd)+j(bc-ad)
    w.kvec.setExp(w.k, -w.theta);
    w.ctheta=cos(w.theta);
    w.stheta=sin(w.theta);
    w.cxi=cos(w.xi);
    w.sxi=sin(w.xi);
  }

  /*!
    Sets the amplitude of wave \a i. The wave list grows if \a i is beyond
    its end.
    NOTE(review): waves created by this growth are default-constructed. Their
    phases, wavenumber and cached values are not set here.
  */
  void Wavefield::setAmplitude(int i, double amplitude)
  {
    if(_waves.count()<=i) {
      _waves.resize(i+1);
    }
    _waves[i].a=amplitude;
  }

  void Wavefield::setEllipticity(int i, double xi)
  {
    if(i>=_waves.count()) {
      _waves.resize(i+1);
    }
    Wave& w=_waves[i];
    w.xi=xi;
    // precomputed values
    w.cxi=cos(w.xi);
    w.sxi=sin(w.xi);
  }

  void Wavefield::setIncoherentNoise(double R, double sigma)
  {
    // Noise parameters are R (total noise ratio) and sigma the ratio etaH/etaZ
    double eta=R*_signalPower;
    double f=1.0/(2.0*sigma+1.0);
    _etaH=eta*sigma*f;
    _etaZ=eta*f;
  }

  /*!
    Computes the signal power as the sum of squared wave amplitudes. It then
    resets the noise to R=1 and sigma=1, i.e. etaH=etaZ=signalPower/3.
  */
  void Wavefield::setSignalPower()
  {
    _signalPower=0.0;
    int waveCount=_waves.count();
    for(int iw=0; iw<waveCount; iw++) {
      double amp=_waves.at(iw).a;
      _signalPower+=amp*amp;
    }
    _etaH=_signalPower/3.0;
    _etaZ=_etaH;
  }

  void Wavefield::setParameters(const Vector<double>& p)
  {
    int nw=_waves.count();
    int nw2=2*nw;
    ASSERT(p.count()==nw2+2);
    _signalPower=0.0;
    for(int iw=0; iw<nw; iw++) {
      Wave& w=_waves[iw];
      int ip=2*iw;
      w.a=p[ip];
      _signalPower+=w.a*w.a;
      w.xi=p[ip+1];
      // precomputed values
      w.cxi=cos(w.xi);
      w.sxi=sin(w.xi);
    }
    setIncoherentNoise(p[nw2], p[nw2+1]);
  }

  /*!
    Writes the current inversion parameters into \a p, using the same layout
    as setParameters().
  */
  void Wavefield::getParameters(Vector<double>& p) const
  {
    int nw=_waves.count();
    int nw2=2*nw;
    ASSERT(p.count()==nw2+2);
    // Amplitude then ellipticity for each wave
    for(int iw=0, ip=0; iw<nw; iw++, ip+=2) {
      const Wave& wave=_waves.at(iw);
      p[ip]=wave.a;
      p[ip+1]=wave.xi;
    }
    // Noise parameters: total noise ratio R and sigma=etaH/etaZ
    p[nw2]=incoherentNoiseRation();
    p[nw2+1]=_etaH/_etaZ;
  }

  /*!
    Returns the total incoherent noise power (2*etaH+etaZ) relative to the
    signal power, i.e. the parameter R of setIncoherentNoise().
  */
  double Wavefield::incoherentNoiseRation() const
  {
    double totalNoise=2.0*_etaH+_etaZ;
    return totalNoise/_signalPower;
  }

  /*!
    Sets the maximum step allowed for each inversion parameter (same layout
    as setParameters()). Currently every parameter is unbounded (infinity).
  */
  void Wavefield::setMaximumShift(Vector<double>& maxShift) const
  {
    int nw2=2*_waves.count();
    ASSERT(maxShift.count()==nw2+2);
    for(int i=nw2+1; i>=0; i--) {
      maxShift[i]=std::numeric_limits<double>::infinity();
    }
    // Finite limits, disabled. Previously unreachable code after an early
    // return; kept here for reference:
    //   for each wave iw:
    //     maxShift[2*iw]=w.a*0.5;       // amplitude
    //     maxShift[2*iw+1]=M_PI/18.0;   // ellipticity (10 degrees)
    //   maxShift[nw2]=5.0;              // R
    //   maxShift[nw2+1]=5.0;            // sigma
  }

  /*!
    Sets the precision for each inversion parameter (same layout as
    setParameters()): 0.001 for amplitudes, pi/18000 (0.01 degree) for
    ellipticities, and 0.001 for R and sigma.
  */
  void Wavefield::setPrecision(Vector<double>& prec) const
  {
    int nw2=2*_waves.count();
    ASSERT(prec.count()==nw2+2);
    for(int ip=0; ip<nw2; ip+=2) {
      prec[ip]=0.001;
      prec[ip+1]=M_PI/18000.0;
    }
    prec[nw2]=0.001;
    prec[nw2+1]=0.001;
  }

  /*!
    Returns the east contribution of wave \a iw in block \a ib:
    A*sin(xi)*cos(theta)*exp(j*phase), then multiplied via
    imaginaryMultiply(-1.0). The phase is phi_ib-Re(kvec*r).
    NOTE(review): the position r is read with the wave index
    (_sensors.at(iw)). There is no sensor argument, which looks like a bug.
    setFieldVectors() loops over sensors instead.
  */
  Complex Wavefield::east(int iw, int ib) const
  {
    const Wave& wave=_waves.at(iw);
    const Complex& pos=_sensors.at(iw);
    double phase=wave.phi[ib]-(wave.kvec*pos).re();
    Complex result;
    result.setExp(wave.a*wave.sxi*wave.ctheta, phase);
    result.imaginaryMultiply(-1.0);
    return result;
  }

  /*!
    Returns the north contribution of wave \a iw in block \a ib:
    A*sin(xi)*sin(theta)*exp(j*phase), then multiplied via
    imaginaryMultiply(-1.0). The phase is phi_ib-Re(kvec*r).
    NOTE(review): the position r is read with the wave index
    (_sensors.at(iw)). There is no sensor argument, which looks like a bug.
  */
  Complex Wavefield::north(int iw, int ib) const
  {
    const Wave& wave=_waves.at(iw);
    const Complex& pos=_sensors.at(iw);
    double phase=wave.phi[ib]-(wave.kvec*pos).re();
    Complex result;
    result.setExp(wave.a*wave.sxi*wave.stheta, phase);
    result.imaginaryMultiply(-1.0);
    return result;
  }

  /*!
    Returns the vertical contribution of wave \a iw in block \a ib:
    A*cos(xi)*exp(j*phase), where phase is phi_ib-Re(kvec*r).
    NOTE(review): the position r is read with the wave index
    (_sensors.at(iw)). There is no sensor argument, which looks like a bug.
  */
  Complex Wavefield::vertical(int iw, int ib) const
  {
    const Wave& wave=_waves.at(iw);
    const Complex& pos=_sensors.at(iw);
    double phase=wave.phi[ib]-(wave.kvec*pos).re();
    Complex result;
    result.setExp(wave.a*wave.cxi, phase);
    return result;
  }

  /*!
    Builds, for each block, the field vector of length 3*ns summed over all
    waves. Rows [0,ns) hold east, [ns,2ns) north and [2ns,3ns) vertical
    components.
    The unit phase terms (set with setUnitExp(phi_b-Re(kvec*r_s))) are cached
    in _expCache in [block][sensor][wave] order; crossSpectrumDerivative()
    reads them.
    If the field vectors are not allocated (no sensor was ever set, see
    allocateFieldVectors()), only the cache is rebuilt.
  */
  void Wavefield::setFieldVectors()
  {
    int nw=_waves.count();
    int ns=_sensors.count();
    int ns2=2*ns;
    // Set exponential cache
    _expCache.resize(_blockCount*ns*nw);
    int cacheIndex=0;
    for(int ib=0; ib<_blockCount; ib++) {
      for(int is=0; is<ns; is++) {
        const Complex& r=_sensors.at(is);
        for(int iw=0; iw<nw; iw++) {
          const Wave& w=_waves.at(iw);
          _expCache[cacheIndex++].setUnitExp(w.phi[ib]-(w.kvec*r).re());
        }
      }
    }
    // Field vectors
    if(!_fields) {
      return;
    }
    Complex c;
    cacheIndex=0;
    for(int ib=0; ib<_blockCount; ib++) {
      ComplexMatrix& field=_fields[ib];
      field.zero();
      for(int is=0; is<ns; is++) {
        for(int iw=0; iw<nw; iw++) {
          const Wave& w=_waves.at(iw);
          Complex exp=_expCache[cacheIndex++];
          exp*=w.a;
          // East
          c=exp;
          c.imaginaryMultiply(-w.sxi*w.ctheta);
          field.at(is, 0)+=c;
          // North
          c=exp;
          c.imaginaryMultiply(-w.sxi*w.stheta);
          field.at(is+ns, 0)+=c;
          // Vertical
          c=exp;
          c*=w.cxi;
          field.at(is+ns2, 0)+=c;
        }
      }
    }
  }

  /*!
    Returns the cross-spectral matrix (built as ComplexMatrix(3*ns)). It is
    the block average of the outer products field*field^H, plus incoherent
    noise etaH on the east and north diagonal entries and etaZ on the
    vertical ones.
    If the field vectors are not allocated or the block count is not
    positive, only the noise term is returned. This avoids dereferencing a
    null array or scaling by 1/0.
    You must call setFieldVectors() before this function.
  */
  ComplexMatrix Wavefield::crossSpectrum() const
  {
    int ns=_sensors.count();
    int ns2=2*ns;
    ComplexMatrix covmat(3*ns);
    covmat.zero();
    if(_fields && _blockCount>0) {
      for(int ib=0; ib<_blockCount; ib++) {
        ComplexMatrix& field=_fields[ib];
        covmat+=field*field.conjugateTransposedVector();
      }
      covmat*=1.0/static_cast<double>(_blockCount);
    }
    // Add noise
    for(int i=0; i<ns; i++) {
      int i1=i+ns;
      int i2=i+ns2;
      covmat.at(i, i)+=_etaH;
      covmat.at(i1, i1)+=_etaH;
      covmat.at(i2, i2)+=_etaZ;
    }
    return covmat;
  }

  /*!
    Returns the derivative of crossSpectrum() with respect to parameter
    \a parameterIndex, using the layout of setParameters():
      - 2*iw: amplitude A_w of wave iw
      - 2*iw+1: ellipticity xi_w of wave iw
      - 2*nw: noise ratio R
      - 2*nw+1 (and any larger index): sigma=etaH/etaZ

    Noise terms follow setIncoherentNoise(). With P the signal power and
    eta=R*P: etaZ=eta/(2*sigma+1) and etaH=sigma*etaZ. Hence:
      - d/dR:     etaH'=P*sigma/(2*sigma+1),   etaZ'=P/(2*sigma+1)
      - d/dsigma: etaH'=eta/(2*sigma+1)^2,     etaZ'=-2*eta/(2*sigma+1)^2
      - d/dA_w:   etaH'=2*R*A_w*sigma/(2*sigma+1), etaZ'=2*R*A_w/(2*sigma+1)
    For wave parameters, the signal term is the block average of
    f_b*(df_b)^H + df_b*f_b^H, where f_b is the field vector of block b.

    You must call setFieldVectors() before this function. It reads _fields
    and _expCache, whose layout is [block][sensor][wave].
  */
  ComplexMatrix Wavefield::crossSpectrumDerivative(int parameterIndex) const
  {
    int ns=_sensors.count();
    int ns2=2*ns;
    int ns3=3*ns;
    ComplexMatrix covmatder(ns3);
    covmatder.zero();
    int nw=_waves.count();
    // Parameter selection
    // Indexes below 2*nw are wave parameters, the others are noise parameters
    int noiseParam=parameterIndex-2*nw;
    if(noiseParam>=0) {
      /*if(noiseParam==0) {  // eta_H
        int i1;
        for(int i=0; i<ns; i++) {
          i1=i+ns;
          covmatder.at(i, i)=1.0;
          covmatder.at(i1, i1)=1.0;
        }
      } else {             // eta_Z
        for(int i=ns2; i<ns3; i++) {
          covmatder.at(i, i)=1.0;
        }
      }*/
      // Noise parameters are eta (total noise) and sigma the ratio etaH/etaZ
      /*double eta=2.0*_etaH+_etaZ;
      double f=_etaZ/eta;
      int i1, i2;
      if(noiseParam==0) {  // eta
        double sigma=_etaH/_etaZ;
        for(int i=0; i<ns; i++) {
          i1=i+ns;
          i2=i+ns2;
          Complex& c=covmatder.at(i, i);
          c=sigma*f;
          covmatder.at(i1, i1)=c;
          covmatder.at(i2, i2)=f;
        }
      } else {             // sigma
        double f2=f*f;
        for(int i=0; i<ns; i++) {
          i1=i+ns;
          i2=i+ns2;
          Complex& c=covmatder.at(i, i);
          c=eta*f2;
          covmatder.at(i1, i1)=c;
          covmatder.at(i2, i2)=-2.0*eta*f2;
        }
      }*/
      // Noise parameters are R (total noise ratio) and sigma the ratio etaH/etaZ
      // Only the diagonal noise terms depend on them.
      double eta=2.0*_etaH+_etaZ;
      // f=etaZ/eta=1/(2*sigma+1)
      double f=_etaZ/eta;
      int i1, i2;
      if(noiseParam==0) {  // R
        double sigma=_etaH/_etaZ;
        // d(etaZ)/dR=P/(2*sigma+1), d(etaH)/dR=sigma*d(etaZ)/dR
        f*=_signalPower;
        for(int i=0; i<ns; i++) {
          i1=i+ns;
          i2=i+ns2;
          Complex& c=covmatder.at(i, i);
          c=sigma*f;
          covmatder.at(i1, i1)=c;
          covmatder.at(i2, i2)=f;
        }
      } else {             // sigma
        // d(etaH)/dsigma=eta*f^2, d(etaZ)/dsigma=-2*eta*f^2
        double f2=f*f;
        for(int i=0; i<ns; i++) {
          i1=i+ns;
          i2=i+ns2;
          Complex& c=covmatder.at(i, i);
          c=eta*f2;
          covmatder.at(i1, i1)=c;
          covmatder.at(i2, i2)=-2.0*eta*f2;
        }
      }
    } else {
      int iw=parameterIndex/2;
      const Wave& w=_waves.at(iw);
      // Derivative of conjugate complex is computed
      // derField is a row holding (df_b/dp)^H, so field*derField is the
      // ns3 x ns3 product f_b*(df_b/dp)^H. This assumes imaginaryMultiply(x)
      // multiplies by j*x, as in setFieldVectors().
      ComplexMatrix derField(1, ns3);
      ComplexMatrix halfder(ns3);
      // _expCache index of (block 0, sensor 0, wave iw); adding nw moves to the
      // next sensor, and past the last sensor to the next block.
      int cacheIndex=iw;
      if(parameterIndex-2*iw==0) { // A_w
        for(int ib=0; ib<_blockCount; ib++) {
          ComplexMatrix& field=_fields[ib];
          for(int is=0; is<ns; is++) {
            Complex exp=_expCache[cacheIndex];
            cacheIndex+=nw;
            exp.conjugate();
            // East
            Complex& e=derField.at(0, is);
            e=exp;
            e.imaginaryMultiply(w.sxi*w.ctheta);
            // North
            Complex& n=derField.at(0, is+ns);
            n=exp;
            n.imaginaryMultiply(w.sxi*w.stheta);
            // Vertical
            Complex& z=derField.at(0, is+ns2);
            z=exp;
            z*=w.cxi;
          }
          halfder=field*derField;
          // Adds the Hermitian transpose df_b*f_b^H.
          // NOTE(review): correct only if ComplexMatrix::conjugate() returns a
          // conjugated copy. If it conjugates halfder in place, both operands
          // of '+' would see the conjugated matrix.
          covmatder+=halfder+halfder.conjugate().transposed();
        }
        covmatder*=1.0/static_cast<double>(_blockCount);
        // Dependency of eta_H and eta_Z to A_w
        // (signal power P=sum of A_w^2, so dP/dA_w=2*A_w at constant R)
        double eta=2.0*_etaH+_etaZ;
        double R=eta/_signalPower;
        // f=d(etaZ)/dA_w=2*R*A_w/(2*sigma+1)
        double f=2.0*_etaZ/eta*w.a*R;
        double sigma=_etaH/_etaZ;
        double sigmaf=sigma*f;
        int i1, i2;
        for(int i=0; i<ns; i++) {
          i1=i+ns;
          i2=i+ns2;
          covmatder.at(i, i)+=sigmaf;
          covmatder.at(i1, i1)+=sigmaf;
          covmatder.at(i2, i2)+=f;
        }
      } else {                     // xi_w
        // The noise does not depend on xi_w: only the signal term is derived
        for(int ib=0; ib<_blockCount; ib++) {
          ComplexMatrix& field=_fields[ib];
          for(int is=0; is<ns; is++) {
            Complex exp=_expCache[cacheIndex];
            cacheIndex+=nw;
            exp.conjugate();
            exp*=w.a;
            // East
            Complex& e=derField.at(0, is);
            e=exp;
            e.imaginaryMultiply(w.cxi*w.ctheta);
            // North
            Complex& n=derField.at(0, is+ns);
            n=exp;
            n.imaginaryMultiply(w.cxi*w.stheta);
            // Vertical
            Complex& z=derField.at(0, is+ns2);
            z=exp;
            z*=-w.sxi;
          }
          halfder=field*derField;
          // Same NOTE(review) as above about conjugate() semantics
          covmatder+=halfder+halfder.conjugate().transposed();
        }
        covmatder*=1.0/static_cast<double>(_blockCount);
      }
    }
    return covmatder;
  }

  void Wavefield::initMisfit(const VectorList<WavefieldValues::Observations>& obs, const FKSteering * steering)
  {
    int nw=_waves.count();
    ASSERT(obs.count()==nw);
    Q_UNUSED(obs);
    _simValues.resize(nw);
    for(int iw=0; iw<nw; iw++) {
      _simValues[iw].setSteering(steering);
    }
  }

  /*!
    Returns the steering set on all simulated values by initMisfit(). Returns
    nullptr if initMisfit() has not been called with at least one wave
    (previously this indexed an empty container).
  */
  const FKSteering * Wavefield::steering() const
  {
    if(_simValues.count()==0) {
      return nullptr;
    }
    return _simValues[0].steering();
  }

  /*!
    Returns the positions of all sensors, in the order given to setSensors().
  */
  VectorList<Point2D> Wavefield::sensors() const
  {
    int sensorCount=_sensors.count();
    VectorList<Point2D> positions;
    positions.resize(sensorCount);
    for(int is=0; is<sensorCount; is++) {
      positions[is]=sensor(is);
    }
    return positions;
  }

  /*!
    Quadratic penalty for negative values: param^2 if \a param is negative,
    0 otherwise.
  */
  double Wavefield::positivePenalty(double param) const
  {
    return param<0.0 ? param*param : 0.0;
  }

  /*!
    Returns \a param if it is negative, 0 otherwise. This is half of the
    derivative of positivePenalty(). NOTE(review): presumably the caller
    applies the factor 2, as misfitDerivative() does for its sum — confirm.
  */
  double Wavefield::positivePenaltyDerivative(double param) const
  {
    return param<0.0 ? param : 0.0;
  }

  /*!
    Quadratic penalty for an ellipticity \a xi outside the range delimited by
    \a xih and \a xiz. When \a xih is positive, \a xih is tested as the lower
    limit and \a xiz as the upper one; otherwise the roles are swapped.
    Returns the squared distance to the violated limit, or 0 inside the
    range.
  */
  double Wavefield::ellipticityPenalty(double xi, double xih, double xiz) const
  {
    bool positive=xih>0.0;
    double lower=positive ? xih : xiz;
    double upper=positive ? xiz : xih;
    double d;
    if(xi<lower) {
      d=xi-lower;
    } else if(xi>upper) {
      d=xi-upper;
    } else {
      return 0.0;
    }
    return d*d;
  }

  /*!
    Returns the signed distance from \a xi to the violated limit (0 inside the
    range). This is half of the derivative of ellipticityPenalty(), and the
    limits are chosen the same way.
  */
  double Wavefield::ellipticityPenaltyDerivative(double xi, double xih, double xiz) const
  {
    bool positive=xih>0.0;
    double lower=positive ? xih : xiz;
    double upper=positive ? xiz : xih;
    if(xi<lower) {
      return xi-lower;
    }
    if(xi>upper) {
      return xi-upper;
    }
    return 0.0;
  }

  double Wavefield::misfit(const VectorList<WavefieldValues::Observations>& obs)
  {
    double m=0.0;
    // Permanent member pointed by cross-spectra in WavefieldValues
    _invFsim=crossSpectrum();
    _invFsim.invert();
    int nw=_waves.count();
    for(int iw=0; iw<nw; iw++) {
      const Wave& w=_waves.at(iw);
      WavefieldValues& val=_simValues[iw];
      const WavefieldValues::Observations& obsW=obs.at(iw);
      val.setCrossSpectrum(&_invFsim);
      val.setWaveNumber(w.kvec);
      val.setSteering();
      val.setValues(obsW);
      m+=val.misfit(obsW);
    }
    return m;
  }

  double Wavefield::misfitDerivative(int parameterIndex, const VectorList<WavefieldValues::Observations>& obs)
  {
    double m=0.0;
    int nw=_waves.count();
    //int nw2=2*nw;
    ComplexMatrix csDer=crossSpectrumDerivative(parameterIndex);
    for(int iw=0; iw<nw; iw++) {
      WavefieldValues& val=_simValues[iw];
      val.setCrossSpectrumDerivative(csDer);
      val.setDerivativeValues();
      m+=val.misfitDerivative(obs.at(iw));
    }
    return 2.0*m;
  }

  /*!
    Returns a per-wave textual report of the misfit against \a obs. The
    simulated values are prepared as in misfit().
  */
  QString Wavefield::detailedMisfit(const VectorList<WavefieldValues::Observations>& obs)
  {
    _invFsim=crossSpectrum();
    _invFsim.invert();
    QString report;
    int waveCount=_waves.count();
    for(int iw=0; iw<waveCount; iw++) {
      report+=tr("wave %1: ").arg(iw);
      const WavefieldValues::Observations& observed=obs.at(iw);
      WavefieldValues& values=_simValues[iw];
      values.setCrossSpectrum(&_invFsim);
      values.setWaveNumber(_waves.at(iw).kvec);
      values.setSteering();
      values.setValues(observed);
      report+=values.detailedMisfit(observed);
    }
    return report;
  }

} // namespace ArrayCore

