00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028 #include <iostream>
00029 #include <iomanip>
00030
00031 #include "TVectorF.h"
00032 #include "TVectorD.h"
00033 #include "TMatrixD.h"
00034 #include "TMatrixDBase.h"
00035
00036 #ifndef ROOT_TMVA_MsgLogger
00037 #include "TMVA/MsgLogger.h"
00038 #endif
00039 #ifndef ROOT_TMVA_VariableNormalizeTransform
00040 #include "TMVA/VariableNormalizeTransform.h"
00041 #endif
00042 #ifndef ROOT_TMVA_Tools
00043 #include "TMVA/Tools.h"
00044 #endif
00045 #ifndef ROOT_TMVA_DataSet
00046 #include "TMVA/DataSet.h"
00047 #endif
00048
// ROOT ClassImp macro invocation for TMVA::VariableNormalizeTransform.
// NOTE(review): presumably hooks the class into ROOT's class-info/dictionary
// machinery -- confirm against the ROOT documentation.
ClassImp(TMVA::VariableNormalizeTransform)
00050
00051
00052 TMVA::VariableNormalizeTransform::VariableNormalizeTransform( DataSetInfo& dsi )
00053 : VariableTransformBase( dsi, Types::kNormalized, "Norm" )
00054 {
00055
00056 }
00057
00058
00059 TMVA::VariableNormalizeTransform::~VariableNormalizeTransform() {
00060 }
00061
00062
00063 void TMVA::VariableNormalizeTransform::Initialize()
00064 {
00065
00066
00067 UInt_t nvar = Variables().size();
00068 UInt_t ntgts = Targets().size();
00069 Int_t numC = GetNClasses()+1;
00070 if (GetNClasses() <= 1 ) numC = 1;
00071
00072 fMin.resize( numC );
00073 fMax.resize( numC );
00074 for (Int_t i=0; i<numC; i++) {
00075 fMin.at(i).resize(nvar+ntgts);
00076 fMax.at(i).resize(nvar+ntgts);
00077 fMin.at(i).assign(nvar+ntgts, 0);
00078 fMax.at(i).assign(nvar+ntgts, 0);
00079 }
00080 }
00081
00082
00083 Bool_t TMVA::VariableNormalizeTransform::PrepareTransformation( const std::vector<Event*>& events )
00084 {
00085
00086 if (!IsEnabled() || IsCreated()) return kTRUE;
00087
00088 Log() << kINFO << "Preparing the transformation." << Endl;
00089
00090 Initialize();
00091
00092 CalcNormalizationParams( events );
00093
00094 SetCreated( kTRUE );
00095
00096 return kTRUE;
00097 }
00098
00099
00100 const TMVA::Event* TMVA::VariableNormalizeTransform::Transform( const TMVA::Event* const ev, Int_t cls ) const
00101 {
00102
00103
00104 if (!IsCreated()) Log() << kFATAL << "Transformation not yet created" << Endl;
00105
00106
00107
00108
00109
00110
00111
00112
00113 if (cls < 0 || cls >= (int) fMin.size()) cls = fMin.size()-1;
00114
00115
00116 const UInt_t nvars = GetNVariables();
00117 const UInt_t ntgts = ev->GetNTargets();
00118 if (nvars != ev->GetNVariables()) {
00119 Log() << kFATAL << "Transformation defined for a different number of variables (defined for: " << GetNVariables()
00120 << ", event contains: " << ev->GetNVariables() << ")" << Endl;
00121 }
00122 if (ntgts != ev->GetNTargets()) {
00123 Log() << kFATAL << "Transformation defined for a different number of targets (defined for: " << GetNTargets()
00124 << ", event contains: " << ev->GetNTargets() << ")" << Endl;
00125 }
00126
00127 if (fTransformedEvent==0) fTransformedEvent = new Event();
00128
00129 Float_t min,max;
00130 for (Int_t ivar=nvars-1; ivar>=0; ivar--) {
00131 min = fMin.at(cls).at(ivar);
00132 max = fMax.at(cls).at(ivar);
00133 Float_t offset = min;
00134 Float_t scale = 1.0/(max-min);
00135
00136 Float_t valnorm = (ev->GetValue(ivar)-offset)*scale * 2 - 1;
00137 fTransformedEvent->SetVal(ivar,valnorm);
00138 }
00139 for (Int_t itgt=ntgts-1; itgt>=0; itgt--) {
00140 min = fMin.at(cls).at(nvars+itgt);
00141 max = fMax.at(cls).at(nvars+itgt);
00142 Float_t offset = min;
00143 Float_t scale = 1.0/(max-min);
00144
00145 Float_t original = ev->GetTarget(itgt);
00146 Float_t valnorm = (original-offset)*scale * 2 - 1;
00147 fTransformedEvent->SetTarget(itgt,valnorm);
00148 }
00149
00150 fTransformedEvent->SetWeight ( ev->GetWeight() );
00151 fTransformedEvent->SetBoostWeight( ev->GetBoostWeight() );
00152 fTransformedEvent->SetClass ( ev->GetClass() );
00153 return fTransformedEvent;
00154 }
00155
00156
00157 const TMVA::Event* TMVA::VariableNormalizeTransform::InverseTransform( const TMVA::Event* const ev, Int_t cls ) const
00158 {
00159
00160 if (!IsCreated()) Log() << kFATAL << "Transformation not yet created" << Endl;
00161
00162
00163
00164 if (cls < 0 || cls > GetNClasses()) {
00165 if (GetNClasses() > 1 ) cls = GetNClasses();
00166 else cls = 0;
00167 }
00168
00169 const UInt_t nvars = GetNVariables();
00170 const UInt_t ntgts = GetNTargets();
00171 if (nvars != ev->GetNVariables()) {
00172 Log() << kFATAL << "Transformation defined for a different number of variables " << GetNVariables() << " " << ev->GetNVariables()
00173 << Endl;
00174 }
00175
00176 if (fBackTransformedEvent==0) fBackTransformedEvent = new Event( *ev );
00177
00178 Float_t min,max;
00179 for (Int_t ivar=nvars-1; ivar>=0; ivar--) {
00180 min = fMin.at(cls).at(ivar);
00181 max = fMax.at(cls).at(ivar);
00182 Float_t offset = min;
00183 Float_t scale = 1.0/(max-min);
00184
00185 Float_t valnorm = offset+((ev->GetValue(ivar)+1)/(scale * 2));
00186 fBackTransformedEvent->SetVal(ivar,valnorm);
00187 }
00188
00189 for (Int_t itgt=ntgts-1; itgt>=0; itgt--) {
00190 min = fMin.at(cls).at(nvars+itgt);
00191 max = fMax.at(cls).at(nvars+itgt);
00192 Float_t offset = min;
00193 Float_t scale = 1.0/(max-min);
00194
00195 Float_t original = ev->GetTarget(itgt);
00196 Float_t valnorm = offset+((original+1.0)/(scale * 2));
00197 fBackTransformedEvent->SetTarget(itgt,valnorm);
00198 }
00199
00200 return fBackTransformedEvent;
00201 }
00202
00203
00204 void TMVA::VariableNormalizeTransform::CalcNormalizationParams( const std::vector<Event*>& events )
00205 {
00206
00207 if (events.size() <= 1)
00208 Log() << kFATAL << "Not enough events (found " << events.size() << ") to calculate the normalization" << Endl;
00209
00210 UInt_t nvars = GetNVariables();
00211 UInt_t ntgts = GetNTargets();
00212
00213 Int_t numC = GetNClasses()+1;
00214 if (GetNClasses() <= 1 ) numC = 1;
00215
00216 for (UInt_t ivar=0; ivar<nvars+ntgts; ivar++) {
00217 for (Int_t ic = 0; ic < numC; ic++) {
00218 fMin.at(ic).at(ivar) = FLT_MAX;
00219 fMax.at(ic).at(ivar) = -FLT_MAX;
00220 }
00221 }
00222
00223 const Int_t all = GetNClasses();
00224 std::vector<Event*>::const_iterator evIt = events.begin();
00225 for (;evIt!=events.end();evIt++) {
00226 for (UInt_t ivar=0; ivar<nvars; ivar++) {
00227 Float_t val = (*evIt)->GetValue(ivar);
00228 UInt_t cls = (*evIt)->GetClass();
00229
00230 if (fMin.at(cls).at(ivar) > val) fMin.at(cls).at(ivar) = val;
00231 if (fMax.at(cls).at(ivar) < val) fMax.at(cls).at(ivar) = val;
00232
00233 if (GetNClasses() != 1) {
00234 if (fMin.at(all).at(ivar) > val) fMin.at(all).at(ivar) = val;
00235 if (fMax.at(all).at(ivar) < val) fMax.at(all).at(ivar) = val;
00236 }
00237 }
00238 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
00239 Float_t val = (*evIt)->GetTarget(itgt);
00240 UInt_t cls = (*evIt)->GetClass();
00241
00242 if (fMin.at(cls).at(nvars+itgt) > val) fMin.at(cls).at(nvars+itgt) = val;
00243 if (fMax.at(cls).at(nvars+itgt) < val) fMax.at(cls).at(nvars+itgt) = val;
00244
00245 if (GetNClasses() != 1) {
00246 if (fMin.at(all).at(nvars+itgt) > val) fMin.at(all).at(nvars+itgt) = val;
00247 if (fMax.at(all).at(nvars+itgt) < val) fMax.at(all).at(nvars+itgt) = val;
00248 }
00249 }
00250 }
00251
00252 return;
00253 }
00254
00255
00256 std::vector<TString>* TMVA::VariableNormalizeTransform::GetTransformationStrings( Int_t cls ) const
00257 {
00258
00259
00260
00261
00262 if (cls < 0 || cls > GetNClasses()) cls = GetNClasses();
00263
00264 const UInt_t nvar = GetNVariables();
00265 std::vector<TString>* strVec = new std::vector<TString>(nvar);
00266
00267 Float_t min, max;
00268 for (Int_t ivar=nvar-1; ivar>=0; ivar--) {
00269 min = fMin.at(cls).at(ivar);
00270 max = fMax.at(cls).at(ivar);
00271 Float_t offset = min;
00272 Float_t scale = 1.0/(max-min);
00273 TString str("");
00274 if (offset < 0) str = Form( "2*%g*([%s] + %g) - 1", scale, Variables()[ivar].GetLabel().Data(), -offset );
00275 else str = Form( "2*%g*([%s] - %g) - 1", scale, Variables()[ivar].GetLabel().Data(), offset );
00276 (*strVec)[ivar] = str;
00277 }
00278
00279 return strVec;
00280 }
00281
00282
00283 void TMVA::VariableNormalizeTransform::WriteTransformationToStream( std::ostream& o ) const
00284 {
00285
00286 o << "# min max for all variables for all classes one after the other and as a last entry for all classes together" << std::endl;
00287
00288 Int_t numC = GetNClasses()+1;
00289 if (GetNClasses() <= 1 ) numC = 1;
00290
00291 UInt_t nvars = GetNVariables();
00292 UInt_t ntgts = GetNTargets();
00293
00294 for (Int_t icls = 0; icls < numC; icls++ ) {
00295 o << icls << std::endl;
00296 for (UInt_t ivar=0; ivar<nvars; ivar++)
00297 o << std::setprecision(12) << std::setw(20) << fMin.at(icls).at(ivar) << " "
00298 << std::setprecision(12) << std::setw(20) << fMax.at(icls).at(ivar) << std::endl;
00299 for (UInt_t itgt=0; itgt<ntgts; itgt++)
00300 o << std::setprecision(12) << std::setw(20) << fMin.at(icls).at(nvars+itgt) << " "
00301 << std::setprecision(12) << std::setw(20) << fMax.at(icls).at(nvars+itgt) << std::endl;
00302 }
00303 o << "##" << std::endl;
00304 }
00305
00306
00307 void TMVA::VariableNormalizeTransform::AttachXMLTo(void* parent)
00308 {
00309
00310 Int_t numC = (GetNClasses()<= 1)?1:GetNClasses()+1;
00311 UInt_t nvars = GetNVariables();
00312 UInt_t ntgts = GetNTargets();
00313
00314 void* trfxml = gTools().AddChild(parent, "Transform");
00315 gTools().AddAttr(trfxml, "Name", "Normalize");
00316 gTools().AddAttr(trfxml, "NVariables", nvars);
00317 gTools().AddAttr(trfxml, "NTargets", ntgts);
00318
00319 for( Int_t icls=0; icls<numC; icls++ ) {
00320 void* clsxml = gTools().AddChild(trfxml, "Class");
00321 gTools().AddAttr(clsxml, "ClassIndex", icls);
00322 void* varsxml = gTools().AddChild(clsxml, "Variables");
00323 for (UInt_t ivar=0; ivar<nvars; ivar++) {
00324 void* varxml = gTools().AddChild(varsxml, "Variable");
00325 gTools().AddAttr(varxml, "VarIndex", ivar);
00326 gTools().AddAttr(varxml, "Min", fMin.at(icls).at(ivar) );
00327 gTools().AddAttr(varxml, "Max", fMax.at(icls).at(ivar) );
00328 }
00329 void* tgtsxml = gTools().AddChild(clsxml, "Targets");
00330 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
00331 void* tgtxml = gTools().AddChild(tgtsxml, "Target");
00332 gTools().AddAttr(tgtxml, "TargetIndex", itgt);
00333 gTools().AddAttr(tgtxml, "Min", fMin.at(icls).at(nvars+itgt) );
00334 gTools().AddAttr(tgtxml, "Max", fMax.at(icls).at(nvars+itgt) );
00335 }
00336 }
00337 }
00338
00339
00340 void TMVA::VariableNormalizeTransform::ReadFromXML( void* trfnode )
00341 {
00342
00343 UInt_t classindex, varindex, tgtindex, nvars, ntgts;
00344
00345 gTools().ReadAttr(trfnode, "NVariables", nvars);
00346 gTools().ReadAttr(trfnode, "NTargets", ntgts);
00347
00348 void* ch = gTools().GetChild( trfnode );
00349 while(ch) {
00350 gTools().ReadAttr(ch, "ClassIndex", classindex);
00351
00352 fMin.resize(classindex+1);
00353 fMax.resize(classindex+1);
00354 fMin[classindex].resize(nvars+ntgts,Float_t(0));
00355 fMax[classindex].resize(nvars+ntgts,Float_t(0));
00356
00357 void* clch = gTools().GetChild( ch );
00358 while(clch) {
00359 TString nodeName(gTools().GetName(clch));
00360 if(nodeName=="Variables") {
00361 void* varch = gTools().GetChild( clch );
00362 while(varch) {
00363 gTools().ReadAttr(varch, "VarIndex", varindex);
00364 gTools().ReadAttr(varch, "Min", fMin[classindex][varindex]);
00365 gTools().ReadAttr(varch, "Max", fMax[classindex][varindex]);
00366 varch = gTools().GetNextChild( varch );
00367 }
00368 } else if (nodeName=="Targets") {
00369 void* tgtch = gTools().GetChild( clch );
00370 while(tgtch) {
00371 gTools().ReadAttr(tgtch, "TargetIndex", tgtindex);
00372 gTools().ReadAttr(tgtch, "Min", fMin[classindex][nvars+tgtindex]);
00373 gTools().ReadAttr(tgtch, "Max", fMax[classindex][nvars+tgtindex]);
00374 tgtch = gTools().GetNextChild( tgtch );
00375 }
00376 }
00377 clch = gTools().GetNextChild( clch );
00378 }
00379 ch = gTools().GetNextChild( ch );
00380 }
00381 SetCreated();
00382 }
00383
00384
00385 void
00386 TMVA::VariableNormalizeTransform::BuildTransformationFromVarInfo( const std::vector<TMVA::VariableInfo>& var ) {
00387
00388
00389
00390
00391 UInt_t nvars = GetNVariables();
00392
00393 if(var.size() != nvars)
00394 Log() << kFATAL << "<BuildTransformationFromVarInfo> can't build transformation,"
00395 << " since the number of variables disagree" << Endl;
00396
00397 UInt_t numC = (GetNClasses()<=1)?1:GetNClasses()+1;
00398 fMin.clear();fMin.resize( numC );
00399 fMax.clear();fMax.resize( numC );
00400
00401
00402 for(UInt_t cls=0; cls<numC; ++cls) {
00403 fMin[cls].resize(nvars+GetNTargets(),0);
00404 fMax[cls].resize(nvars+GetNTargets(),0);
00405 UInt_t vidx(0);
00406 for(std::vector<TMVA::VariableInfo>::const_iterator v = var.begin(); v!=var.end(); ++v, ++vidx) {
00407 fMin[cls][vidx] = v->GetMin();
00408 fMax[cls][vidx] = v->GetMax();
00409 }
00410 }
00411 SetCreated();
00412 }
00413
00414
00415 void TMVA::VariableNormalizeTransform::ReadTransformationFromStream( std::istream& istr, const TString& )
00416 {
00417
00418
00419 UInt_t nvars = GetNVariables();
00420 UInt_t ntgts = GetNTargets();
00421 char buf[512];
00422 char buf2[512];
00423 istr.getline(buf,512);
00424 TString strvar, dummy;
00425 Int_t icls;
00426 TString test;
00427 while (!(buf[0]=='#'&& buf[1]=='#')) {
00428 char* p = buf;
00429 while (*p==' ' || *p=='\t') p++;
00430 if (*p=='#' || *p=='\0') {
00431 istr.getline(buf,512);
00432 continue;
00433 }
00434 std::stringstream sstr(buf);
00435 sstr >> icls;
00436 for (UInt_t ivar=0;ivar<nvars;ivar++) {
00437 istr.getline(buf2,512);
00438 std::stringstream sstr2(buf2);
00439 sstr2 >> fMin[icls][ivar] >> fMax[icls][ivar];
00440 }
00441 for (UInt_t itgt=0;itgt<ntgts;itgt++) {
00442 istr.getline(buf2,512);
00443 std::stringstream sstr2(buf2);
00444 sstr2 >> fMin[icls][nvars+itgt] >> fMax[icls][nvars+itgt];
00445 }
00446 istr.getline(buf,512);
00447 }
00448 SetCreated();
00449 }
00450
00451
00452 void TMVA::VariableNormalizeTransform::PrintTransformation( ostream& o )
00453 {
00454
00455
00456 Int_t numC = GetNClasses()+1;
00457 if (GetNClasses() <= 1 ) numC = 1;
00458
00459 UInt_t nvars = GetNVariables();
00460 UInt_t ntgts = GetNTargets();
00461 for (Int_t icls = 0; icls < numC; icls++ ) {
00462 Log() << kINFO << "Transformation for class " << icls << " based on these ranges:" << Endl;
00463 Log() << kINFO << "Variables:" << Endl;
00464 for (UInt_t ivar=0; ivar<nvars; ivar++)
00465 o << std::setw(20) << fMin[icls][ivar] << std::setw(20) << fMax[icls][ivar] << std::endl;
00466 Log() << kINFO << "Targets:" << Endl;
00467 for (UInt_t itgt=0; itgt<ntgts; itgt++)
00468 o << std::setw(20) << fMin[icls][nvars+itgt] << std::setw(20) << fMax[icls][nvars+itgt] << std::endl;
00469 }
00470 }
00471
00472
00473 void TMVA::VariableNormalizeTransform::MakeFunction( std::ostream& fout, const TString& fcncName,
00474 Int_t part, UInt_t trCounter, Int_t )
00475 {
00476
00477
00478 UInt_t numC = fMin.size();
00479 if (part==1) {
00480 fout << std::endl;
00481 fout << " double fMin_"<<trCounter<<"["<<numC<<"]["<<GetNVariables()<<"];" << std::endl;
00482 fout << " double fMax_"<<trCounter<<"["<<numC<<"]["<<GetNVariables()<<"];" << std::endl;
00483 }
00484
00485 if (part==2) {
00486 fout << std::endl;
00487 fout << "//_______________________________________________________________________" << std::endl;
00488 fout << "inline void " << fcncName << "::InitTransform_"<<trCounter<<"()" << std::endl;
00489 fout << "{" << std::endl;
00490 for (UInt_t ivar=0; ivar<GetNVariables(); ivar++) {
00491 Float_t min = FLT_MAX;
00492 Float_t max = -FLT_MAX;
00493 for (UInt_t icls = 0; icls < numC; icls++) {
00494 min = TMath::Min(min, fMin.at(icls).at(ivar) );
00495 max = TMath::Max(max, fMax.at(icls).at(ivar) );
00496 fout << " fMin_"<<trCounter<<"["<<icls<<"]["<<ivar<<"] = " << std::setprecision(12)
00497 << min << ";" << std::endl;
00498 fout << " fMax_"<<trCounter<<"["<<icls<<"]["<<ivar<<"] = " << std::setprecision(12)
00499 << max << ";" << std::endl;
00500 }
00501 }
00502 fout << "}" << std::endl;
00503 fout << std::endl;
00504 fout << "//_______________________________________________________________________" << std::endl;
00505 fout << "inline void " << fcncName << "::Transform_"<<trCounter<<"( std::vector<double>& iv, int cls) const" << std::endl;
00506 fout << "{" << std::endl;
00507 fout << "if (cls < 0 || cls > "<<GetNClasses()<<") {"<< std::endl;
00508 fout << " if ("<<GetNClasses()<<" > 1 ) cls = "<<GetNClasses()<<";"<< std::endl;
00509 fout << " else cls = "<<(fMin.size()==1?0:2)<<";"<< std::endl;
00510 fout << "}"<< std::endl;
00511 fout << " for (int ivar=0;ivar<"<<GetNVariables()<<";ivar++) {" << std::endl;
00512 fout << " double offset = fMin_"<<trCounter<<"[cls][ivar];" << std::endl;
00513 fout << " double scale = 1.0/(fMax_"<<trCounter<<"[cls][ivar]-fMin_"<<trCounter<<"[cls][ivar]);" << std::endl;
00514 fout << " iv[ivar] = (iv[ivar]-offset)*scale * 2 - 1;" << std::endl;
00515 fout << " }" << std::endl;
00516 fout << "}" << std::endl;
00517 }
00518 }