#include "hmdcsizescells.h"
#include "hkaldafwire.h"
#include "hkalgeomtools.h"
ClassImp (HKalDafWire)

//  -----------------------------------
//  Constructor.
//
//  All arguments are handed through unchanged to the HKalDafSingleWire
//  base class; this class adds no state of its own here.
//    n        : forwarded as the first base-class argument
//               (presumably the maximum number of sites - TODO confirm)
//    measDim  : dimension of the measurement vector
//    stateDim : dimension of the track state vector
//    fMap     : field map forwarded to the base class
//    fpol     : forwarded to the base class together with fMap
HKalDafWire::HKalDafWire(Int_t n,
                         Int_t measDim,
                         Int_t stateDim,
                         HMdcTrackGField *fMap,
                         Double_t fpol)
    : HKalDafSingleWire(n, measDim, stateDim, fMap, fpol)
{
    // Nothing else to initialise.
}
//  -----------------------------------
//  Build the measurement projection matrix H for site iSite and store it
//  in the site as the projector of the filtered state in layer coordinates.
//
//  H has getMeasDim() rows and getStateDim(Kalman::kLayCoord) columns.
//  Every entry is zero except H(0, kIdxY0) = 1, i.e. the first measurement
//  component is the kIdxY0 entry of the layer-coordinate state.
//
//  Returns kFALSE if the site does not exist, kTRUE otherwise.
Bool_t HKalDafWire::calcProjector(Int_t iSite) const {
    HKalTrackSite *const trackSite = getSite(iSite);
    if(trackSite == 0) return kFALSE;

    const Int_t nRows = getMeasDim();
    const Int_t nCols = getStateDim(Kalman::kLayCoord);

    TMatrixD projector(nRows, nCols);
    projector.Zero();
    projector(0, kIdxY0) = 1.;   // select the y0 component of the state

    trackSite->setStateProjMat(Kalman::kFiltered, projector, Kalman::kLayCoord);
    return kTRUE;
}
//  -----------------------------------
//  Construct one virtual measurement plane per competing wire at site iSite.
//
//  The predicted track is approximated by a straight line through its
//  intersection with the site's measurement layer (two points at
//  posAt -/+ dir). Hits are stored in pairs, so wire iWire owns hit indices
//  2*iWire and 2*iWire+1. For each wire, the points of closest approach
//  (PCA) between the track line and the wire are computed. The virtual
//  plane of both hits of that wire is then defined as follows:
//    - centre : PCA on the wire,
//    - U axis : along the wire (wire1 -> wire2),
//    - V axis : from the wire PCA towards the track PCA, flipped if
//               necessary so that it does not point against the y axis of
//               the rotated layer coordinate system.
//
//  Returns kFALSE if the site does not exist, the MDC geometry object is
//  not available, or the track is parallel to a wire. Otherwise returns
//  kTRUE, including the case where there are no wire hits.
Bool_t HKalDafWire::calcVirtPlane(Int_t iSite) const {
    HKalTrackSite *site = getSite(iSite);
    if(!site) {
        return kFALSE;
    }

    // Straight-line approximation of the predicted track around the layer.
    const TVectorD &predState = site->getStateVec(Kalman::kPredicted);
    TVector3 posAt;
    HKalTrackState::calcPosAtPlane(posAt, site->getHitMeasLayer(), predState);
    TVector3 dir;
    HKalTrackState::calcDir(dir, predState);
    const TVector3 pos1 = posAt - dir;
    const TVector3 pos2 = posAt + dir;

    const Int_t nWires = site->getNcompetitors() / 2;
    if(nWires < 1) {
        return kTRUE; // no wire hits: nothing to construct
    }

    // The layer geometry depends only on the site, not on the wire, so it
    // is looked up once outside the wire loop.
    HMdcSizesCells *fSizesCells = HMdcSizesCells::getExObject();
    if(!fSizesCells) {
        if(getPrintErr()) {
            Error("calcVirtPlane()", "HMdcSizesCells object is not available.");
        }
        return kFALSE;
    }
    HMdcSizesCellsLayer &fSizesCellsLayer =
        (*fSizesCells)[site->getSector()][site->getModule()][site->getLayer()];
    const HGeomTransform &transRotLaySysRSec =
        fSizesCellsLayer.getRotLaySysRSec(site->getCell());

    // y axis of the rotated layer system expressed in sector coordinates.
    // It is a direction, so only the rotation of the transformation is
    // applied. Calling transFrom() on it would either drop the result or
    // add the translation part.
    const HGeomVector ygeom =
        transRotLaySysRSec.getRotMatrix() * HGeomVector(0., 1., 0.);
    const TVector3 y(ygeom.getX(), ygeom.getY(), ygeom.getZ());

    for(Int_t iWire = 0; iWire < nWires; iWire++) {
        const Int_t iHit = 2 * iWire;

        TVector3 wire1, wire2;
        site->getHitWirePts(wire1, wire2, iHit);

        Int_t    iflag   = 0;
        Double_t length  = 0.;
        Double_t mindist = 0.;
        TVector3 pcaTrack, pcaWire;
        HKalGeomTools::Track2ToLine(pcaTrack, pcaWire, iflag, mindist, length,
                                    pos1, pos2, wire1, wire2);
        if(iflag == 1 && getPrintWarn()) {
            Warning("calcVirtPlane()", "PCA is outside of wire");
        }
        if(iflag == 2) {
            if(getPrintErr()) {
                Error("calcVirtPlane()", "Track parallel to wire.");
            }
            return kFALSE;
        }

        const TVector3 u = (wire2 - wire1).Unit();
        TVector3 v = pcaTrack - pcaWire;
        if(v.Mag2() > 0.) {
            v = v.Unit();
        } else {
            // The track passes exactly through the wire, so the PCA
            // connection has zero length and its Unit() would be a null
            // vector. Use the common normal of wire and track instead.
            v = u.Cross(dir).Unit();
        }
        if(y.Dot(v) < 0.) {
            v = -v;
        }

        // Both hits belonging to this wire share the same virtual plane.
        site->setHitVirtPlane(pcaWire, u, v, iHit);
        site->setHitVirtPlane(pcaWire, u, v, iHit + 1);

#if dafDebug > 2
        cout<<"Virtual plane for wire hit "<<iWire<<" of site "<<iSite<<endl;
        cout<<"Centre: "<<endl;
        pcaWire.Print();
        cout<<"Axis U: "<<endl;
        u.Print();
        cout<<"Axis V: "<<endl;
        v.Print();
#endif
    }
    return kTRUE;
}
//  -----------------------------------
//  Compute the measurement predicted by a track state at a site.
//
//  projMeasVec : output; element 0 receives the projected measurement.
//  site        : measurement site whose state is read.
//  stateType   : which state of the site to use (predicted, filtered, ...).
//  sys         : coordinate system of the state.
//
//  In sector coordinates (Kalman::kSecCoord), the measurement is the signed
//  distance of closest approach between the track and the wire currently
//  selected by getWireNr(). The track is approximated as a straight line
//  through its intersection with the measurement layer. The distance is
//  negative if the track lies on the opposite side of the wire with respect
//  to the V axis of that wire's virtual plane.
//  In any other coordinate system, the measurement is the kIdxY0 component
//  of the state.
//
//  Returns kFALSE if site is null, the track position at the layer cannot
//  be computed, or the track is parallel to the wire. Otherwise kTRUE.
Bool_t HKalDafWire::calcMeasVecFromState(TVectorD &projMeasVec,
                                         HKalTrackSite const* const site,
                                         Kalman::kalFilterTypes stateType,
                                         Kalman::coordSys sys) const {
    if(!site) {
        return kFALSE;
    }
#if kalDebug > 0
    Int_t mdim = getMeasDim();
    if(projMeasVec.GetNrows() != mdim) {
        Warning("calcMeasVecFromState()",
                "Dimension of measurement vector (%i) does not match that of function parameter (%i).",
                mdim, projMeasVec.GetNrows());
        projMeasVec.ResizeTo(mdim);
    }
#endif
    const TVectorD &sv = site->getStateVec(stateType, sys);

    if(sys != Kalman::kSecCoord) {
        projMeasVec(0) = sv(kIdxY0);
        return kTRUE;
    }

    TVector3 posAt;
    if(!HKalTrackState::calcPosAtPlane(posAt, site->getHitMeasLayer(), sv)) {
        if(getPrintErr()) {
            Error("calcMeasVecFromState()",
                  "Could not extract position vector from track state.");
        }
        return kFALSE;
    }
    TVector3 dir;
    HKalTrackState::calcDir(dir, sv);
    const TVector3 pos1 = posAt - dir;
    const TVector3 pos2 = posAt + dir;

    // Hits are stored in pairs per wire; the first hit of the pair holds
    // the wire geometry and virtual plane.
    const Int_t iHit = getWireNr() * 2;
    TVector3 wire1, wire2;
    site->getHitWirePts(wire1, wire2, iHit);

    TVector3 pcaTrack, pcaWire;
    Int_t    Iflag  = 0;
    Double_t dist   = 0.;
    Double_t length = 0.;
    HKalGeomTools::Track2ToLine(pcaTrack, pcaWire, Iflag, dist, length,
                                pos1, pos2, wire1, wire2);
    if(Iflag == 2) {
        // No unique point of closest approach: the distance is meaningless.
        if(getPrintErr()) {
            Error("calcMeasVecFromState()", "Track parallel to wire.");
        }
        return kFALSE;
    }

    // Sign the distance according to the V axis of the wire's virtual plane.
    if(site->getHitVirtPlane(iHit).getAxisV().Dot(pcaTrack - pcaWire) < 0.) {
        dist *= -1.;
    }
    projMeasVec(0) = dist;
    return kTRUE;
}
//  -----------------------------------
//  Kalman filter step at site iSite for all competing wires.
//
//  Each wire (hit pair 2*iWire, 2*iWire+1) is filtered separately:
//    1. Transform the predicted state into that wire's virtual layer.
//    2. Compute the effective measurement with calcEffMeasVec().
//    3. Run the configured filter method.
//    4. Transform the filtered result back into sector coordinates.
//  The per-wire filtered states x_i with covariances C_i are then combined
//  into an inverse-covariance weighted mean, which is stored as the
//  filtered state of the site:
//      C = ( sum_i C_i^-1 )^-1 ,   x = C * sum_i C_i^-1 x_i
//
//  Returns kFALSE if the site does not exist, has no wire hits, the
//  effective measurement cannot be computed, or a covariance matrix is
//  singular. Otherwise kTRUE.
Bool_t HKalDafWire::filter(Int_t iSite) {
    HKalTrackSite *site = getSite(iSite);
    if(!site) {
        return kFALSE;
    }

    const Int_t nWires = site->getNcompetitors() / 2;
    if(nWires < 1) {
        // Without at least one wire, the weighted mean below would invert
        // a zero matrix.
        if(getPrintErr()) {
            Error("filter()", "No wire hits to filter at site %i.", iSite);
        }
        return kFALSE;
    }

    const Int_t  sdim  = getStateDim();
    const Bool_t useUD = (getFilterMethod() == Kalman::kKalUD);

    vector<TVectorD> states (nWires, TVectorD(sdim));
    vector<TMatrixD> invCovs(nWires, TMatrixD(sdim, sdim));

    for(Int_t iWire = 0; iWire < nWires; iWire++) {
        setWireNr(iWire);
        const Int_t iHit = iWire * 2;
        site->transSecToVirtLay(Kalman::kPredicted, iHit, useUD);
        if(!calcEffMeasVec(iSite, iWire)) {
            if(getPrintErr()) {
                Error("filter()",
                      "Could not calculate effective measurement for site %i.",
                      iSite);
            }
            return kFALSE;
        }
        switch(getFilterMethod()) {
        case Kalman::kKalConv:
            filterConventional(iSite);
            break;
        case Kalman::kKalJoseph:
            filterJoseph(iSite);
            break;
        case Kalman::kKalUD:
            filterUD(iSite);
            break;
        case Kalman::kKalSeq:
            filterSequential(iSite);
            break;
        case Kalman::kKalSwer:
            filterSwerling(iSite);
            break;
        default:
            filterConventional(iSite);
            break;
        }
        site->transVirtLayToSec(Kalman::kFiltered, iHit, useUD);

        states.at(iWire) = site->getStateVec(Kalman::kFiltered,
                                             Kalman::kSecCoord);

        // Store the inverse covariance (weight matrix) of this wire's result.
        TMatrixD &invCov = invCovs.at(iWire);
        invCov = site->getStateCovMat(Kalman::kFiltered, Kalman::kSecCoord);
        Double_t detWire = 0.;
        invCov.Invert(&detWire);
        if(detWire == 0.) {
            if(getPrintErr()) {
                Error("filter()",
                      "Singular filtered covariance matrix for wire %i at site %i.",
                      iWire, iSite);
            }
            return kFALSE;
        }
    }

    // Combined covariance: inverse of the summed weight matrices.
    TMatrixD cov(sdim, sdim);
    cov.Zero();
    for(UInt_t i = 0; i < invCovs.size(); i++) {
        cov += invCovs[i];
    }
    Double_t det = 0.;
    cov.Invert(&det);
    if(det == 0.) {
        if(getPrintErr()) {
            Error("filter()",
                  "Singular combined weight matrix at site %i.", iSite);
        }
        return kFALSE;
    }
    site->setStateCovMat(Kalman::kFiltered, cov, Kalman::kSecCoord);

    // Combined state: weighted mean of the per-wire filtered states.
    TVectorD weightedSum(sdim);
    weightedSum.Zero();
    for(UInt_t i = 0; i < states.size(); i++) {
        weightedSum += invCovs[i] * states[i];
    }
    TVectorD state = cov * weightedSum;
    site->setStateVec(Kalman::kFiltered, state, Kalman::kSecCoord);
    return kTRUE;
}
//  -----------------------------------
//  Propagate the track from site iFromSite to site iToSite.
//
//  Steps, in order:
//    1. Propagate the track state (propagateTrack).
//    2. Rebuild the virtual planes of the target site from the newly
//       predicted state (calcVirtPlane).
//    3. Propagate the covariance, with the UD routine when the filter
//       method is Kalman::kKalUD and the conventional routine otherwise.
//
//  Returns kFALSE as soon as step 1 or 2 fails, kTRUE otherwise.
Bool_t HKalDafWire::propagate(Int_t iFromSite, Int_t iToSite) {
    // Short-circuit evaluation keeps the original order: the virtual plane
    // is only computed if the state propagation succeeded.
    const Bool_t ok = propagateTrack(iFromSite, iToSite) &&
                      calcVirtPlane(iToSite);
    if(!ok) return kFALSE;

    const Bool_t useUD = (getFilterMethod() == Kalman::kKalUD);
    if(useUD) propagateCovUD(iFromSite, iToSite);
    else      propagateCovConv(iFromSite, iToSite);

    return kTRUE;
}
void HKalDafWire::updateSites(const TObjArray &hits) {
for(Int_t i = 0; i < getNsites(); i++) {
getSite(i)->clearHits();
getSite(i)->setActive(kTRUE);
}
Int_t iHit = 0;
Int_t iSite = 0;
while(iHit < hits.GetEntries() - 1) {
#if kalDebug > 0
if(!hits.At(iHit)->InheritsFrom("HKalMdcHit")) {
Error("updateSites()",
Form("Object at index %i in hits array is of class %s. Expected class is HKalMdcHit.",
iHit, hits.At(iHit)->ClassName()));
exit(1);
}
#endif
HKalMdcHit *hit = (HKalMdcHit*)hits.At(iHit);
HKalMdcHit *nexthit = (HKalMdcHit*)hits.At(iHit + 1);
getSite(iSite)->addHit(hit);
HKalMdcHit *hit2 = new HKalMdcHit(*hit);
hit2->setDriftTime(hit->getDriftTime() * (-1.), hit->getDriftTimeErr());
TVectorD hitVec2 = hit->getHitVec();
hitVec2 *= -1.;
hit2->setHitAndErr(hitVec2, hit->getErrVec());
getSite(iSite)->addHit(hit2);
Double_t w = hit->getWeight();
hit ->setWeight(w / 2.);
hit2->setWeight(w / 2.);
if(!hit->areCompetitors(*hit, *nexthit)) {
iSite++;
}
iHit++;
}
HKalMdcHit *lastHit = (HKalMdcHit*)hits.At(hits.GetEntries()-1);
getSite(iSite)->addHit(lastHit);
HKalMdcHit *hit2 = new HKalMdcHit(*lastHit);
hit2->setDriftTime(lastHit->getDriftTime() * (-1.),
lastHit->getDriftTimeErr());
Double_t w = lastHit->getWeight();
lastHit->setWeight(w / 2.);
hit2 ->setWeight(w / 2.);
getSite(iSite)->addHit(hit2);
iSite++;
setNSites(iSite);
setNHitsInTrack(hits.GetEntries() * 2);
#if kalDebug > 1
cout<<"New track has "<<getNsites()<<" measurement sites."<<endl;
#endif
}