#include "TMatrixD.h"
#include "TMath.h"
#include "TVector3.h"
#include "hkalplane.h"
#include "hkaltrackstate.h"
#include <iostream>
using namespace std;
// ROOT ClassImp macro: provides the class implementation hook for HKalTrackState in ROOT's type system.
ClassImp(HKalTrackState)
HKalTrackState::HKalTrackState(Kalman::kalFilterTypes stateType, Int_t measDim, Int_t stateDim)
: TObject(),
  fPropagator(stateDim, stateDim),
  fProjector(measDim, stateDim),
  fCovariance(stateDim, stateDim),
  fProcessNoise(stateDim, stateDim),
  stateVec(stateDim) {
    // Creates a track state with a stateDim-dimensional state vector, used with
    // measDim-dimensional measurements.
    //
    // stateType: which Kalman filter step this state belongs to (reported by print()).
    // measDim:   number of rows of the projector matrix H.
    // stateDim:  dimension of the state vector and of the square matrices F, C and Q.
    //
    // F, H and C start out as (possibly rectangular) unit matrices. The state vector and
    // the process noise Q keep the contents given by their sized ROOT constructors.
    fCovariance.UnitMatrix();
    fProjector.UnitMatrix();
    fPropagator.UnitMatrix();
    type = stateType;
}
HKalTrackState::HKalTrackState(Kalman::kalFilterTypes stateType, const TVectorD &sv, Int_t measDim)
: TObject(),
  fPropagator(sv.GetNrows(), sv.GetNrows()),
  fProjector(measDim, sv.GetNrows()),
  fCovariance(sv.GetNrows(), sv.GetNrows()),
  fProcessNoise(sv.GetNrows(), sv.GetNrows()),
  stateVec(sv) {
    // Creates a track state initialised with a copy of the state vector sv.
    //
    // stateType: which Kalman filter step this state belongs to (reported by print()).
    // sv:        initial state vector; its length fixes the dimension of F, C and Q.
    // measDim:   number of rows of the projector matrix H.
    //
    // F, H and C start out as (possibly rectangular) unit matrices. Q keeps the contents
    // given by its sized ROOT constructor.
    fCovariance.UnitMatrix();
    fProjector.UnitMatrix();
    fPropagator.UnitMatrix();
    type = stateType;
}
HKalTrackState::~HKalTrackState() {
    // Nothing to release explicitly: every member is a value object that cleans up itself.
}
void HKalTrackState::calcDir(TVector3 &dir) const {
    // Writes the unit direction of this track state into dir.
    // Delegates to the overload taking an explicit state vector.
    const TVectorD &currentState = stateVec;
    calcDir(dir, currentState);
}
void HKalTrackState::calcDir(TVector3 &dir, const TVectorD &sv) {
    // Writes the unit direction vector of the track described by state vector sv into dir.
    //
    // dir: output, unit vector (tx, ty, 1) / sqrt(1 + tx^2 + ty^2), where tx = sv(kIdxTanPhi)
    //      and ty = sv(kIdxTanTheta). Its z-component is always positive.
    // sv:  state vector; only the two slope entries are read.
    //
    // q/p is intentionally not used. Scaling by 1/|q/p| before normalising does not change
    // the direction, and for q/p == 0 it would produce infinities and a NaN unit vector.
    const Double_t tanx = sv(kIdxTanPhi);
    const Double_t tany = sv(kIdxTanTheta);
    const Double_t norm = TMath::Sqrt(tanx*tanx + tany*tany + 1.);
    dir.SetXYZ(tanx / norm, tany / norm, 1. / norm);
}
void HKalTrackState::calcMomVec(TVector3 &dir) const {
    // Writes the momentum vector of this track state into dir.
    // Delegates to the overload taking an explicit state vector.
    const TVectorD &currentState = stateVec;
    calcMomVec(dir, currentState);
}
void HKalTrackState::calcMomVec(TVector3 &dir, const TVectorD &sv) {
    // Writes the momentum vector of the track described by state vector sv into dir.
    //
    // The magnitude is 1/|q/p|. The z-component is that magnitude divided by
    // sqrt(1 + tx^2 + ty^2), and x/y are the z-component times the slopes tx and ty.
    // The z-component is never negative. A q/p of 0 yields infinite components.
    const Double_t slopeX = sv(kIdxTanPhi);
    const Double_t slopeY = sv(kIdxTanTheta);
    const Double_t qOverP = sv(kIdxQP);
    const Double_t pz     = 1./(TMath::Abs(qOverP) * TMath::Sqrt(slopeX*slopeX + slopeY*slopeY + 1));
    dir.SetXYZ(slopeX * pz, slopeY * pz, pz);
}
Bool_t HKalTrackState::calcPosAtPlane(TVector3 &pos, const HKalPlane &plane) const {
    // Computes the position of this track state on plane and stores it in pos.
    // Returns the intersection result of the overload taking an explicit state vector.
    const Bool_t intersected = calcPosAtPlane(pos, plane, stateVec);
    return intersected;
}
Bool_t HKalTrackState::calcPosAtPlane(TVector3 &pos, const HKalPlane &plane, const TVectorD &sv) {
    // Recovers the 3D position belonging to state vector sv on plane.
    //
    // Only the x and y entries of sv are used. A line parallel to the z-axis is drawn
    // through (x, y, 0) and intersected with the plane, so the resulting point keeps x and y
    // and takes its z-coordinate from the plane. The slopes in sv are deliberately not used.
    //
    // Returns the value of HKalPlane::findIntersection (presumably kFALSE when the line
    // does not hit the plane, e.g. for a plane parallel to z; confirm against HKalPlane).
    TVector3 zAxis(0., 0., 1.);
    TVector3 start(sv(kIdxX0), sv(kIdxY0), 0.);
    return plane.findIntersection(pos, start, zAxis);
}
Bool_t HKalTrackState::calcPosAndDirAtPlane(TVector3 &pos, TVector3 &dir, const HKalPlane &plane) const {
    // Fills dir with the unit track direction and pos with the track position on plane.
    // dir is always written. The return value is the plane intersection result for pos.
    calcDir(dir);
    const Bool_t intersected = calcPosAtPlane(pos, plane);
    return intersected;
}
void HKalTrackState::calcStateVec(TVectorD &sv, Double_t qp, const TVector3 &pos, const TVector3 &dir) {
    // Fills the state vector sv from a position, a direction and q/p.
    //
    // sv:  output. Receives x, y, dx/dz, dy/dz and q/p; if it has more than 5 rows, z is
    //      also stored at kIdxZ0.
    // qp:  charge over momentum, copied unchanged.
    // pos: track position.
    // dir: track direction; only the ratios x/z and y/z are used (dir.z() == 0 gives inf/NaN).
    //
    // With kalDebug > 0, a too-short sv is resized to 5 rows and a warning is printed.
#if kalDebug > 0
    const Int_t dim    = sv.GetNrows();
    const Int_t minDim = 5;
    if(dim < minDim) {
        sv.ResizeTo(minDim);
        // Pass format and arguments directly: a Form()-preformatted text would be
        // re-interpreted as a printf format string by Warning().
        ::Warning("calcStateVec", "Dimension of output function parameter is %i, but must be at least %i.", dim, minDim);
    }
#endif
    sv(kIdxX0)       = pos.x();
    sv(kIdxY0)       = pos.y();
    sv(kIdxTanPhi)   = dir.x() / dir.z();
    sv(kIdxTanTheta) = dir.y() / dir.z();
    sv(kIdxQP)       = qp;
    if(sv.GetNrows() > 5) {
        sv(kIdxZ0) = pos.z();
    }
}
void HKalTrackState::clear() {
    // Resets the track state: the state vector and process noise Q are zeroed,
    // while propagator F, projector H and covariance C become unit matrices.
    fProcessNoise.Zero();
    fCovariance.UnitMatrix();
    fProjector.UnitMatrix();
    fPropagator.UnitMatrix();
    stateVec.Zero();
}
void HKalTrackState::print(const Option_t *opt) const {
    
    
    
    
    
    
    
    
    
    
    TString stropt(opt);
    switch (type) {
    case kPredicted:
        cout<<"**** Predicted state: ****"<<endl;
        break;
    case kFiltered:
        cout<<"**** Filtered state: ****"<<endl;
        break;
    case kSmoothed:
        cout<<"**** Smoothed state: ****"<<endl;
        break;
    case kInvFiltered:
        cout<<"**** Inverse filtered vector: ****"<<endl;
        break;
    }
    if(stropt.Contains("S", TString::kIgnoreCase) || stropt.IsNull()) {
        cout<<"State vector:"<<endl;
        stateVec.Print();
    }
    if(stropt.Contains("F", TString::kIgnoreCase) || stropt.IsNull()) {
        cout<<"Propagator matrix:"<<endl;
        fPropagator.Print();
    }
    if(stropt.Contains("C", TString::kIgnoreCase) || stropt.IsNull()) {
        cout<<"Covariance matrix:"<<endl;
        fCovariance.Print();
    }
    if(stropt.Contains("H", TString::kIgnoreCase) || stropt.IsNull()) {
        cout<<"Projector matrix:"<<endl;
        fProjector.Print();
    }
    if(stropt.Contains("Q", TString::kIgnoreCase) || stropt.IsNull()) {
        cout<<"Process noise matrix:"<<endl;
        fProcessNoise.Print();
    }
    cout<<"**** End of track state print. ****"<<endl;
}
void HKalTrackState::transform(const TRotation &transMat, const HKalPlane &plane) {
    
    
    
    
    TVector3 pos; 
    TVector3 dir; 
    calcPosAndDirAtPlane(pos, dir, plane);
    pos.Transform(transMat);
    dir.Transform(transMat);
    Double_t qp = getStateParam(kIdxQP);
    setStateVec(qp, pos, dir);
}
void HKalTrackState::transformFromLayerToSector(TVectorD &svSec, const TVectorD &svLay, const HKalPlane &plane) {
    // Converts a state vector from the local layer coordinate system of plane into
    // sector coordinates.
    //
    // svSec: output, 5 or 6 rows: x, y, dx/dz, dy/dz, q/p and, for 6 rows, z.
    // svLay: input, 5 rows: u, v, du/dw, dv/dw, q/p. These are measured in the system
    //        spanned by the plane axes u and v and normal w, with origin at the plane center
    //        (axes presumably orthonormal; confirm in HKalPlane).
    // plane: layer plane that supplies origin and axes.
    //
    // Position:  origin + u_coord * u + v_coord * v.
    // Direction: w + tu * u + tv * v, converted to sector slopes by dividing by its
    //            z-component (a direction without z-component gives inf/NaN slopes).
    // q/p is copied unchanged.
    //
    // With kalDebug > 0, wrong input/output dimensions abort the transformation with an error.
#if kalDebug > 0
    if(svLay.GetNrows() != 5) {
        ::Error("transformFromLayerToSector()", "Wrong dimension in input state vector svLay. No coordinate transformation done.");
        return;
    }
    if(svSec.GetNrows() != 5 && svSec.GetNrows() != 6) {
        ::Error("transformFromLayerToSector()", "Wrong dimension in output state vector svSec.");
        return;
    }
#endif
    const TVector3 &origin = plane.getCenter();
    const TVector3 &u      = plane.getAxisU();
    const TVector3 &v      = plane.getAxisV();
    const TVector3 &w      = plane.getNormal();

    // Position in sector coordinates.
    TVector3 pos = origin + svLay(kIdxX0) * u + svLay(kIdxY0) * v;
    svSec(kIdxX0) = pos.X();
    svSec(kIdxY0) = pos.Y();
    if(svSec.GetNrows() > 5) {
        svSec(kIdxZ0) = pos.Z();
    }

    // Direction: the layer slopes are relative to the plane normal w.
    const Double_t tu  = svLay(kIdxTanPhi);
    const Double_t tv  = svLay(kIdxTanTheta);
    TVector3 dir = w + tu*u + tv*v;
    svSec(kIdxTanPhi)   = dir.X() / dir.Z();
    svSec(kIdxTanTheta) = dir.Y() / dir.Z();
    svSec(kIdxQP)       = svLay(kIdxQP);
}
void HKalTrackState::transformFromSectorToLayer(TVectorD &svLay, const TVectorD &svSec, const HKalPlane &plane) {
    // Converts a state vector from sector coordinates into the local layer coordinate
    // system of plane.
    //
    // svLay: output, resized to 5 rows if necessary (with a warning): u, v, du/dw, dv/dw, q/p.
    //        These are relative to the plane center and the plane axes u, v and normal w.
    // svSec: input, 5 or 6 rows in sector coordinates (x, y, dx/dz, dy/dz, q/p[, z]).
    // plane: layer plane that supplies origin and axes.
    //
    // The sector position is first placed on the plane (see calcPosAtPlane) and then projected
    // onto u and v. The direction is projected onto u and v and divided by its w-component
    // (a direction parallel to the plane gives inf/NaN slopes). q/p is copied unchanged.
    //
    // With kalDebug > 0, a wrong svSec dimension aborts the transformation with an error.
#if kalDebug > 0
    if(svSec.GetNrows() != 5 && svSec.GetNrows() != 6) {
        ::Error("transformFromSectorToLayer()", "Wrong dimension in input state vector svSec. No coordinate transformation done.");
        return;
    }
#endif
    const Int_t sdim = 5;
    if(svLay.GetNrows() != sdim) {
        // svLay is the output argument: it is adapted rather than rejected.
        svLay.ResizeTo(sdim);
        ::Warning("transformFromSectorToLayer()", "Wrong dimension in output state vector svLay. Resized to %ix1", sdim);
    }
    const TVector3 &origin = plane.getCenter();
    const TVector3 &u      = plane.getAxisU();
    const TVector3 &v      = plane.getAxisV();
    const TVector3 &w      = plane.getNormal();

    TVector3 pos;
    TVector3 dir;
    if(!calcPosAtPlane(pos, plane, svSec)) {
        // Report the failure instead of silently continuing. The computation still proceeds
        // with whatever findIntersection left in pos, so existing callers get the same values.
        ::Warning("transformFromSectorToLayer()", "No intersection of state vector svSec with plane. Layer position is not reliable.");
    }
    calcDir(dir, svSec);

    TVector3 diffSecOrg = pos - origin;
    const Double_t aw   = dir * w;
    svLay(kIdxX0)       = diffSecOrg * u;
    svLay(kIdxY0)       = diffSecOrg * v;
    svLay(kIdxTanPhi)   = dir * u / aw;
    svLay(kIdxTanTheta) = dir * v / aw;
    svLay(kIdxQP)       = svSec(kIdxQP);
}
void HKalTrackState::setCovMat(const TMatrixD &fCov) {
    // Replaces the covariance matrix with fCov.
    // With kalDebug > 0 the dimensions are checked first, and a mismatch terminates the
    // program with exit(1). Otherwise the rules of TMatrixD assignment apply.
#if kalDebug > 0
    const Int_t nRowsOld = getCovMat().GetNrows();
    const Int_t nColsOld = getCovMat().GetNcols();
    const Int_t nRowsNew = fCov.GetNrows();
    const Int_t nColsNew = fCov.GetNcols();
    if(nRowsOld != nRowsNew || nColsOld != nColsNew) {
        // Format and arguments go to Error() directly; a Form()-preformatted message would
        // be re-interpreted as a printf format string.
        Error("setCovMat()", "Matrices not compatible. Cannot replace %ix%i matrix with %ix%i matrix.", nRowsOld, nColsOld, nRowsNew, nColsNew);
        exit(1);
    }
#endif
    fCovariance = fCov;
}
void HKalTrackState::setPropMat(const TMatrixD &fProp) {
    // Replaces the propagator matrix with fProp.
    // With kalDebug > 0 the dimensions are checked first, and a mismatch terminates the
    // program with exit(1). Otherwise the rules of TMatrixD assignment apply.
#if kalDebug > 0
    const Int_t nRowsOld = getPropMat().GetNrows();
    const Int_t nColsOld = getPropMat().GetNcols();
    const Int_t nRowsNew = fProp.GetNrows();
    const Int_t nColsNew = fProp.GetNcols();
    if(nRowsOld != nRowsNew || nColsOld != nColsNew) {
        // Format and arguments go to Error() directly; a Form()-preformatted message would
        // be re-interpreted as a printf format string.
        Error("setPropMat()", "Matrices not compatible. Cannot replace %ix%i matrix with %ix%i matrix.", nRowsOld, nColsOld, nRowsNew, nColsNew);
        exit(1);
    }
#endif
    fPropagator = fProp;
}
void HKalTrackState::setProjMat(const TMatrixD &fProj) {
    // Replaces the projector matrix with fProj.
    // With kalDebug > 0 the dimensions are checked first, and a mismatch terminates the
    // program with exit(1). Otherwise the rules of TMatrixD assignment apply.
#if kalDebug > 0
    const Int_t nRowsOld = getProjMat().GetNrows();
    const Int_t nColsOld = getProjMat().GetNcols();
    const Int_t nRowsNew = fProj.GetNrows();
    const Int_t nColsNew = fProj.GetNcols();
    if(nRowsOld != nRowsNew || nColsOld != nColsNew) {
        // Format and arguments go to Error() directly; a Form()-preformatted message would
        // be re-interpreted as a printf format string.
        Error("setProjMat()", "Matrices not compatible. Cannot replace %ix%i matrix with %ix%i matrix.", nRowsOld, nColsOld, nRowsNew, nColsNew);
        exit(1);
    }
#endif
    fProjector = fProj;
}
void HKalTrackState::setProcNoiseMat(const TMatrixD &fProc) {
    
    fProcessNoise.ResizeTo(fProc.GetNrows(), fProc.GetNcols());
    fProcessNoise = fProc;
}
void HKalTrackState::setStateVec(const TVectorD &sv) {
    // Replaces the state vector with sv.
    // With kalDebug > 0 the dimensions are checked first, and a mismatch terminates the
    // program with exit(1). Otherwise the rules of TVectorD assignment apply.
#if kalDebug > 0
    const Int_t oldStateDim = getStateVec().GetNrows();
    const Int_t newStateDim = sv.GetNrows();
    if(oldStateDim != newStateDim) {
        // Format and arguments go to Error() directly; a Form()-preformatted message would
        // be re-interpreted as a printf format string.
        Error("setStateVec()", "Dimension of new state vector (%i) not the same as dimension of current state vector (%i).", newStateDim, oldStateDim);
        exit(1);
    }
#endif
    stateVec = sv;
}