// TMVA::MethodKNN
//
// Implementation of the k-nearest neighbor (k-NN) classification and
// regression method: a kd-tree based nearest-neighbor search with optional
// polynomial or Gaussian kernel weighting.
#include <cmath>
#include <cstdlib>
#include <sstream>   // std::stringstream, used for the XML weight I/O below
#include <string>

#include "TFile.h"
#include "TMath.h"
#include "TTree.h"

#include "TMVA/ClassifierFactory.h"
#include "TMVA/MethodKNN.h"
#include "TMVA/Ranking.h"
#include "TMVA/Tools.h"

// register the method with the classifier factory and generate the ROOT
// dictionary entry
REGISTER_METHOD(KNN)

ClassImp(TMVA::MethodKNN)

////////////////////////////////////////////////////////////////////////////////
/// standard constructor

TMVA::MethodKNN::MethodKNN( const TString& jobName,
                            const TString& methodTitle,
                            DataSetInfo& theData,
                            const TString& theOption,
                            TDirectory* theTargetDir )
   : TMVA::MethodBase(jobName, Types::kKNN, methodTitle, theData, theOption, theTargetDir)
   , fSumOfWeightsS(0)
   , fSumOfWeightsB(0)
   , fModule(0)
   , fnkNN(0)
   , fBalanceDepth(0)
   , fScaleFrac(0)
   , fSigmaFact(0)
   , fTrim(kFALSE)
   , fUseKernel(kFALSE)
   , fUseWeight(kFALSE)
   , fUseLDA(kFALSE)
   , fTreeOptDepth(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// constructor from weight file

TMVA::MethodKNN::MethodKNN( DataSetInfo& theData,
                            const TString& theWeightFile,
                            TDirectory* theTargetDir )
   : TMVA::MethodBase( Types::kKNN, theData, theWeightFile, theTargetDir)
   , fSumOfWeightsS(0)
   , fSumOfWeightsB(0)
   , fModule(0)
   , fnkNN(0)
   , fBalanceDepth(0)
   , fScaleFrac(0)
   , fSigmaFact(0)
   , fTrim(kFALSE)
   , fUseKernel(kFALSE)
   , fUseWeight(kFALSE)
   , fUseLDA(kFALSE)
   , fTreeOptDepth(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::MethodKNN::~MethodKNN()
{
   if (fModule) delete fModule;
}

////////////////////////////////////////////////////////////////////////////////
/// declare the configurable options and their defaults. Note that UseKernel
/// switches kernel weighting on, while the Kernel option selects which kernel
/// (polynomial or Gaussian) is applied.

void TMVA::MethodKNN::DeclareOptions()
{
   DeclareOptionRef(fnkNN         = 20,     "nkNN",         "Number of k-nearest neighbors");
   DeclareOptionRef(fBalanceDepth = 6,      "BalanceDepth", "Binary tree balance depth");
   DeclareOptionRef(fScaleFrac    = 0.80,   "ScaleFrac",    "Fraction of events used to compute variable width");
   DeclareOptionRef(fSigmaFact    = 1.0,    "SigmaFact",    "Scale factor for sigma in Gaussian kernel");
   DeclareOptionRef(fKernel       = "Gaus", "Kernel",       "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
   DeclareOptionRef(fTrim         = kFALSE, "Trim",         "Use equal number of signal and background events");
   DeclareOptionRef(fUseKernel    = kFALSE, "UseKernel",    "Use polynomial kernel weight");
   DeclareOptionRef(fUseWeight    = kTRUE,  "UseWeight",    "Use weight to count kNN events");
   DeclareOptionRef(fUseLDA       = kFALSE, "UseLDA",       "Use local linear discriminant - experimental feature");
}
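
// Example (editor's sketch, option values illustrative): booking this method
// through the TMVA Factory with the options declared above.
//
//    factory->BookMethod( TMVA::Types::kKNN, "KNN",
//       "nkNN=20:ScaleFrac=0.8:UseKernel=F:UseWeight=T:Trim=F:BalanceDepth=6" );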

////////////////////////////////////////////////////////////////////////////////
/// options that are used only by the reader, for backward compatibility

void TMVA::MethodKNN::DeclareCompatibilityOptions() {
   MethodBase::DeclareCompatibilityOptions();
   DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
}

////////////////////////////////////////////////////////////////////////////////
/// process the options specified by the user

void TMVA::MethodKNN::ProcessOptions()
{
   if (!(fnkNN > 0)) {
      fnkNN = 10;
      Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
   }
   if (fScaleFrac < 0.0) {
      fScaleFrac = 0.0;
      Log() << kWARNING << "ScaleFrac cannot be negative: set ScaleFrac = " << fScaleFrac << Endl;
   }
   if (fScaleFrac > 1.0) {
      fScaleFrac = 1.0;   // values above 1 are clamped silently
   }
   if (!(fBalanceDepth > 0)) {
      fBalanceDepth = 6;
      Log() << kWARNING << "BalanceDepth must be a positive integer: set BalanceDepth = " << fBalanceDepth << Endl;
   }

   Log() << kVERBOSE
         << "kNN options:\n"
         << "  kNN = "          << fnkNN         << "\n"
         << "  UseKernel = "    << fUseKernel    << "\n"
         << "  SigmaFact = "    << fSigmaFact    << "\n"
         << "  ScaleFrac = "    << fScaleFrac    << "\n"
         << "  Kernel = "       << fKernel       << "\n"
         << "  Trim = "         << fTrim         << "\n"
         << "  BalanceDepth = " << fBalanceDepth << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// kNN can handle two-class classification and (multi-target) regression

Bool_t TMVA::MethodKNN::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t )
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   if (type == Types::kRegression) return kTRUE;
   return kFALSE;
}

////////////////////////////////////////////////////////////////////////////////
/// initialization

void TMVA::MethodKNN::Init()
{
   fModule = new kNN::ModulekNN();
   fSumOfWeightsS = 0;
   fSumOfWeightsB = 0;
}

////////////////////////////////////////////////////////////////////////////////
/// create the kd-tree from the cached list of training events

void TMVA::MethodKNN::MakeKNN()
{
   if (!fModule) {
      Log() << kFATAL << "ModulekNN is not created" << Endl;
   }

   fModule->Clear();

   std::string option;
   if (fScaleFrac > 0.0) {
      option += "metric";   // scale variables using the central ScaleFrac fraction of events
   }
   if (fTrim) {
      option += "trim";     // use equal numbers of signal and background events
   }

   Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
      fModule->Add(*event);
   }

   // build the tree: balance depth, scale fraction as an integer percentage, options
   fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
                 static_cast<UInt_t>(100.0*fScaleFrac),
                 option);
}
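
// Editor's note (assumption about ModulekNN internals): Fill receives the tree
// balance depth, the scale fraction converted to an integer percentage (e.g.
// 80 for ScaleFrac = 0.8), and the option string assembled above; "metric"
// enables variable scaling and "trim" equalises signal and background counts.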

////////////////////////////////////////////////////////////////////////////////
/// kNN training: cache the training events and fill the kd-tree;
/// no parameters are fitted

void TMVA::MethodKNN::Train()
{
   Log() << kINFO << "<Train> start..." << Endl;

   if (IsNormalised()) {
      Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
      fScaleFrac = 0.0;
   }

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }
   if (GetNVariables() < 1)
      Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;

   Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;

   for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
      // read the training event
      const Event* evt_   = GetEvent(ievt);
      Double_t     weight = evt_->GetWeight();

      // skip events with non-positive weight if requested
      if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;

      kNN::VarVec vvec(GetNVariables(), 0.0);
      for (UInt_t ivar = 0; ivar < evt_->GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);

      Short_t event_type = 0;

      if (DataInfo().IsSignal(evt_)) { // signal type = 1
         fSumOfWeightsS += weight;
         event_type = 1;
      }
      else { // background type = 2
         fSumOfWeightsB += weight;
         event_type = 2;
      }

      // create an event and cache it for the kd-tree
      kNN::Event event_knn(vvec, weight, event_type);
      event_knn.SetTargets(evt_->GetTargets());
      fEvent.push_back(event_knn);
   }
   Log() << kINFO
         << "Number of signal events (weighted) " << fSumOfWeightsS << Endl
         << "Number of background events (weighted) " << fSumOfWeightsB << Endl;

   // create the kd-tree containing all training events
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// compute classifier response

Double_t TMVA::MethodKNN::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // cannot determine error
   NoErrorCalc(err, errUpper);

   //
   // define local variables
   //
   const Event *ev = GetEvent();
   const Int_t nvar = GetNVariables();
   const Double_t weight = ev->GetWeight();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = ev->GetValue(ivar);
   }

   //
   // search the kd-tree for the k nearest neighbors; ask for two extra
   // neighbors to guard against entries with zero distance to the query
   // (e.g. the event itself when evaluating on the training sample)
   //
   const kNN::Event event_knn(vvec, weight, 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list has unexpected size" << Endl;
      return -100.0;
   }

   if (fUseLDA) return MethodKNN::getLDAValue(rlist, event_knn);

   //
   // set flags for the kernel option: Gaussian or polynomial
   //
   Bool_t use_gaus = false, use_poln = false;
   if (fUseKernel) {
      if      (fKernel == "Gaus") use_gaus = true;
      else if (fKernel == "Poln") use_poln = true;
   }

   //
   // polynomial kernel: compute the radius that maps distances into (0,1)
   //
   Double_t kradius = -1.0;
   if (use_poln) {
      kradius = MethodKNN::getKernelRadius(rlist);

      if (!(kradius > 0.0)) {
         Log() << kFATAL << "kNN radius is not positive" << Endl;
         return -100.0;
      }

      kradius = 1.0/TMath::Sqrt(kradius);
   }

   //
   // Gaussian kernel: compute the RMS of each variable among the neighbors
   //
   std::vector<Double_t> rms_vec;
   if (use_gaus) {
      rms_vec = TMVA::MethodKNN::getRMS(rlist, event_knn);

      if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
         Log() << kFATAL << "Failed to compute RMS vector" << Endl;
         return -100.0;
      }
   }

   UInt_t count_all = 0;
   Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get the training event corresponding to this neighbor
      const kNN::Node<kNN::Event> &node = *(lit->first);

      // lit->second is the squared distance to the query event
      if (lit->second < 0.0) {
         Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
      }
      else if (!(lit->second > 0.0)) {
         Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
      }

      // get the event weight and apply the kernel weighting if requested
      Double_t evweight = node.GetWeight();
      if      (use_gaus) evweight *= MethodKNN::GausKernel(event_knn, node.GetEvent(), rms_vec);
      else if (use_poln) evweight *= MethodKNN::PolnKernel(TMath::Sqrt(lit->second)*kradius);

      if (fUseWeight) weight_all += evweight;
      else            ++weight_all;

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         if (fUseWeight) weight_sig += evweight;
         else            ++weight_sig;
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         if (fUseWeight) weight_bac += evweight;
         else            ++weight_bac;
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }

      // use only the first knn neighbors
      ++count_all;

      if (count_all >= knn) {
         break;
      }
   }

   // require at least one neighbor
   if (!(count_all > 0)) {
      Log() << kFATAL << "kNN result list is empty" << Endl;
      return -100.0;
   }

   if (count_all < knn) {
      Log() << kDEBUG << "Found fewer neighbors than requested: " << count_all << " < " << knn << Endl;
   }

   // require a positive total weight
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "kNN result total weight is not positive" << Endl;
      return -100.0;
   }

   return weight_sig/weight_all;
}
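
// Editor's note: with UseWeight enabled the returned discriminant is
//
//    D = sum_{i in kNN, signal} w_i / sum_{i in kNN} w_i ,
//
// where w_i includes the optional kernel factor; with UseWeight disabled the
// sums become plain event counts. In both cases D lies in [0, 1].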

////////////////////////////////////////////////////////////////////////////////
/// return the vector of averages of the target values of the k-nearest
/// neighbors

const std::vector< Float_t >& TMVA::MethodKNN::GetRegressionValues()
{
   if (fRegressionReturnVal == 0)
      fRegressionReturnVal = new std::vector<Float_t>;
   else
      fRegressionReturnVal->clear();

   //
   // define local variables
   //
   const Event *evt = GetEvent();
   const Int_t nvar = GetNVariables();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);
   std::vector<float> reg_vec;

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = evt->GetValue(ivar);
   }

   //
   // search the kd-tree for the k nearest neighbors (with two extra entries,
   // as in GetMvaValue)
   //
   const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list has unexpected size" << Endl;
      return *fRegressionReturnVal;
   }

   // accumulate the (weighted) sum of the neighbor target vectors
   Double_t weight_all = 0;
   UInt_t count_all = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get the training event corresponding to this neighbor
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetTargets();
      const Double_t weight = node.GetEvent().GetWeight();

      if (reg_vec.empty()) {
         reg_vec = kNN::VarVec(tvec.size(), 0.0);
      }

      for (UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
         if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
         else            reg_vec[ivar] += tvec[ivar];
      }

      if (fUseWeight) weight_all += weight;
      else            ++weight_all;

      // use only the first knn neighbors
      ++count_all;

      if (count_all == knn) {
         break;
      }
   }

   // require a positive total weight
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
      return *fRegressionReturnVal;
   }

   // normalize the accumulated targets by the total weight
   for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
      reg_vec[ivar] /= weight_all;
   }

   // copy the result
   fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());

   return *fRegressionReturnVal;
}
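
// Editor's note: each regression output is the (weighted) mean of the
// neighbor target values,
//
//    t_j = sum_{i in kNN} w_i * t_ij / sum_{i in kNN} w_i ,
//
// with w_i = 1 for all neighbors when UseWeight is disabled.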

////////////////////////////////////////////////////////////////////////////////
/// no ranking information available for this method

const TMVA::Ranking* TMVA::MethodKNN::CreateRanking()
{
   return 0;
}

////////////////////////////////////////////////////////////////////////////////
/// write the cached training events to an XML weight node

void TMVA::MethodKNN::AddWeightsXMLTo( void* parent ) const {
   void* wght = gTools().AddChild(parent, "Weights");
   gTools().AddAttr(wght,"NEvents",fEvent.size());
   if (fEvent.size()>0) gTools().AddAttr(wght,"NVar",fEvent.begin()->GetNVar());
   if (fEvent.size()>0) gTools().AddAttr(wght,"NTgt",fEvent.begin()->GetNTgt());

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {

      // write the variable and target values as the text content of each event node
      std::stringstream s("");
      s.precision( 16 );
      for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
         if (ivar>0) s << " ";
         s << std::scientific << event->GetVar(ivar);
      }

      for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
         s << " " << std::scientific << event->GetTgt(itgt);
      }

      void* evt = gTools().AddChild(wght, "Event", s.str().c_str());
      gTools().AddAttr(evt,"Type",   event->GetType());
      gTools().AddAttr(evt,"Weight", event->GetWeight());
   }
}
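
// Resulting XML layout (editor's reconstruction from the calls above; the
// exact formatting is delegated to gTools()):
//
//    <Weights NEvents="..." NVar="..." NTgt="...">
//       <Event Type="1" Weight="...">var0 var1 ... tgt0 tgt1 ...</Event>
//       ...
//    </Weights>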

////////////////////////////////////////////////////////////////////////////////
/// read the training events back from an XML weight node and rebuild the
/// kd-tree

void TMVA::MethodKNN::ReadWeightsFromXML( void* wghtnode ) {

   void* ch = gTools().GetChild(wghtnode); // first event
   UInt_t nvar = 0, ntgt = 0;
   gTools().ReadAttr( wghtnode, "NVar", nvar );
   gTools().ReadAttr( wghtnode, "NTgt", ntgt );

   Short_t evtType(0);
   Double_t evtWeight(0);

   while (ch) {
      // build the event
      kNN::VarVec vvec(nvar, 0);
      kNN::VarVec tvec(ntgt, 0);

      gTools().ReadAttr( ch, "Type",   evtType );
      gTools().ReadAttr( ch, "Weight", evtWeight );
      std::stringstream s( gTools().GetContent(ch) );

      for (UInt_t ivar=0; ivar<nvar; ivar++)
         s >> vvec[ivar];

      for (UInt_t itgt=0; itgt<ntgt; itgt++)
         s >> tvec[itgt];

      ch = gTools().GetNextChild(ch);

      kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
      fEvent.push_back(event_knn);
   }

   // create the kd-tree containing all events
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// read the weights (one comma-separated event per line) from a text stream
/// and rebuild the kd-tree

void TMVA::MethodKNN::ReadWeightsFromStream(istream& is)
{
   Log() << kINFO << "Starting ReadWeightsFromStream(istream& is) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   UInt_t nvar = 0;

   while (!is.eof()) {
      std::string line;
      std::getline(is, line);

      // skip empty lines and comments
      if (line.empty() || line.find("#") != std::string::npos) {
         continue;
      }

      // count the comma delimiters; the first line fixes the variable count
      UInt_t count = 0;
      std::string::size_type pos=0;
      while( (pos=line.find(',',pos)) != std::string::npos ) { count++; pos++; }

      if (nvar == 0) {
         nvar = count - 2;
      }
      if (count < 3 || nvar != count - 2) {
         Log() << kFATAL << "Missing comma delimiter(s)" << Endl;
      }

      Int_t ievent = -1, type = -1;
      Double_t weight = -1.0;

      kNN::VarVec vvec(nvar, 0.0);

      // split the line at the commas and convert each field
      UInt_t vcount = 0;
      std::string::size_type prev = 0;

      for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
         if (line[ipos] != ',' && ipos + 1 != line.size()) {
            continue;
         }

         if (!(ipos > prev)) {
            Log() << kFATAL << "Wrong substring limits" << Endl;
         }

         std::string vstring = line.substr(prev, ipos - prev);
         if (ipos + 1 == line.size()) {
            vstring = line.substr(prev, ipos - prev + 1);
         }

         if (vstring.empty()) {
            Log() << kFATAL << "Failed to parse string" << Endl;
         }

         if (vcount == 0) {
            ievent = std::atoi(vstring.c_str());
         }
         else if (vcount == 1) {
            type = std::atoi(vstring.c_str());
         }
         else if (vcount == 2) {
            weight = std::atof(vstring.c_str());
         }
         else if (vcount - 3 < vvec.size()) {
            vvec[vcount - 3] = std::atof(vstring.c_str());
         }
         else {
            Log() << kFATAL << "Wrong variable count" << Endl;
         }

         prev = ipos + 1;
         ++vcount;
      }

      fEvent.push_back(kNN::Event(vvec, weight, type));
   }

   Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;

   // create the kd-tree containing all events
   MakeKNN();
}
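
// Expected text format (editor's reconstruction from the parser above), one
// event per line:
//
//    ievent,type,weight,var0,var1,...,var<nvar-1>
//
// e.g. "0,1,1.0,0.25,-1.3" for a signal event (type 1) with two variables;
// the event index is parsed but not otherwise used.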

////////////////////////////////////////////////////////////////////////////////
/// save the event list to a ROOT file as a TTree of kNN::Event objects

void TMVA::MethodKNN::WriteWeightsToStream(TFile &rf) const
{
   Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;

   if (fEvent.empty()) {
      Log() << kWARNING << "MethodKNN contains no events " << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   TTree *tree = new TTree("knn", "event tree");
   tree->SetDirectory(0);
   tree->Branch("event", "TMVA::kNN::Event", &event);

   Double_t size = 0.0;
   for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
      (*event) = (*it);
      size += tree->Fill();   // Fill() returns the number of bytes written
   }

   rf.WriteTObject(tree, "knn", "Overwrite");

   // scale to megabytes
   size /= 1048576.0;

   Log() << kINFO << "Wrote " << size << "MB and " << fEvent.size()
         << " events to ROOT file" << Endl;

   delete tree;
   delete event;
}

////////////////////////////////////////////////////////////////////////////////
/// read the event list from a ROOT file and rebuild the kd-tree

void TMVA::MethodKNN::ReadWeightsFromStream(TFile &rf)
{
   Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   // search the file for the "knn" tree written by WriteWeightsToStream
   TTree *tree = dynamic_cast<TTree *>(rf.Get("knn"));
   if (!tree) {
      Log() << kFATAL << "Failed to find knn tree" << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   tree->SetBranchAddress("event", &event);

   const Int_t nevent = tree->GetEntries();

   Double_t size = 0.0;
   for (Int_t i = 0; i < nevent; ++i) {
      size += tree->GetEntry(i);   // GetEntry() returns the number of bytes read
      fEvent.push_back(*event);
   }

   // scale to megabytes
   size /= 1048576.0;

   Log() << kINFO << "Read " << size << "MB and " << fEvent.size()
         << " events from ROOT file" << Endl;

   delete event;

   // create the kd-tree containing all events
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// write specific classifier response (not implemented for k-NN)

void TMVA::MethodKNN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
}

////////////////////////////////////////////////////////////////////////////////
/// get help message text

void TMVA::MethodKNN::GetHelpMessage() const
{
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
         << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
         << "training events for which a classification category/regression target is known. " << Endl
         << "The k-NN method compares a test event to all training events using a distance " << Endl
         << "function, which is the Euclidean distance in a space defined by the input variables. "<< Endl
         << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
         << "quick search for the k events with the shortest distance to the test event. The method" << Endl
         << "returns the fraction of signal events among the k neighbors. It is recommended" << Endl
         << "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
         << "between 0 and 1." << Endl;

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The k-NN method estimates the density of signal and background events in a "<< Endl
         << "neighborhood around the test event. The method assumes that the density of the " << Endl
         << "signal and background events is uniform and constant within the neighborhood. " << Endl
         << "k is an adjustable parameter that determines the average size of the " << Endl
         << "neighborhood. Small k values (less than 10) are sensitive to statistical " << Endl
         << "fluctuations and large (greater than 100) values might not sufficiently capture " << Endl
         << "local differences between events in the training set. The evaluation time of the" << Endl
         << "k-NN method also increases with larger values of k. " << Endl;
   Log() << Endl;
   Log() << "The k-NN method assigns equal weight to all input variables. Different scales " << Endl
         << "among the input variables are compensated for using the ScaleFrac parameter: the input " << Endl
         << "variables are scaled so that the widths of the central ScaleFrac*100% of events are " << Endl
         << "equal among all the input variables." << Endl;

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Additional configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The method includes an option to use a Gaussian kernel to smooth out the k-NN" << Endl
         << "response. The kernel re-weights events using a distance to the test event." << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// polynomial kernel: K(x) = (1 - |x|^3)^3 for |x| < 1, and 0 otherwise

Double_t TMVA::MethodKNN::PolnKernel(const Double_t value) const
{
   const Double_t avalue = TMath::Abs(value);

   if (!(avalue < 1.0)) {
      return 0.0;
   }

   const Double_t prod = 1.0 - avalue * avalue * avalue;

   return (prod * prod * prod);
}
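
// Editor's note: this is the tricube kernel. In GetMvaValue() the argument is
// the neighbor distance divided by the kernel radius from getKernelRadius(),
// so scaled distances of 1 or more receive zero weight.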

////////////////////////////////////////////////////////////////////////////////
/// Gaussian kernel: weight the neighbor by a product of per-variable
/// Gaussians centred on the test event

Double_t TMVA::MethodKNN::GausKernel(const kNN::Event &event_knn,
                                     const kNN::Event &event, const std::vector<Double_t> &svec) const
{
   if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
      Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
      return 0.0;
   }

   //
   // compute the exponent sum over all variables
   //
   double sum_exp = 0.0;

   for (unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {

      const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
      const Double_t sigm_ = svec[ivar];
      if (!(sigm_ > 0.0)) {
         Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
         return 0.0;
      }

      sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
   }

   //
   // return exp(-sum_exp); the 1/(sqrt(2*pi)*sigma) normalisation factors are
   // omitted since only relative weights enter the discriminant
   //
   return std::exp(-sum_exp);
}

////////////////////////////////////////////////////////////////////////////////
/// get the squared distance to the farthest of the k nearest neighbors,
/// skipping entries at zero distance from the query event

Double_t TMVA::MethodKNN::getKernelRadius(const kNN::List &rlist) const
{
   Double_t kradius = -1.0;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
   {
      if (!(lit->second > 0.0)) continue;

      if (kradius < lit->second || kradius < 0.0) kradius = lit->second;

      ++kcount;
      if (kcount >= knn) break;
   }

   return kradius;
}

////////////////////////////////////////////////////////////////////////////////
/// get the per-variable RMS of the distance from the query event to the
/// k nearest neighbors, scaled by SigmaFact

const std::vector<Double_t> TMVA::MethodKNN::getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
{
   std::vector<Double_t> rvec;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
   {
      // skip entries at zero distance from the query event
      if (!(lit->second > 0.0)) continue;

      const kNN::Node<kNN::Event> *node_  = lit->first;
      const kNN::Event            &event_ = node_->GetEvent();

      if (rvec.empty()) {
         rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
      }
      else if (rvec.size() != event_.GetNVar()) {
         Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
         rvec.clear();
         return rvec;
      }

      // accumulate the squared deviations per variable
      for (unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
         const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
         rvec[ivar] += diff_*diff_;
      }

      ++kcount;
      if (kcount >= knn) break;
   }

   if (kcount < 1) {
      Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
      rvec.clear();
      return rvec;
   }

   for (unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
      if (!(rvec[ivar] > 0.0)) {
         Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
         rvec.clear();
         return rvec;
      }

      rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
   }

   return rvec;
}
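
// Editor's note: the per-variable sigma handed to GausKernel() is therefore
//
//    sigma_i = |SigmaFact| * sqrt( (1/k) * sum_{n in kNN} (x_ni - q_i)^2 ) ,
//
// where q is the query event, so SigmaFact directly rescales the kernel width.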

////////////////////////////////////////////////////////////////////////////////
/// experimental feature: compute a local linear discriminant from the
/// variable vectors of the k nearest neighbors and return the signal
/// probability for the query event

Double_t TMVA::MethodKNN::getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
{
   LDAEvents sig_vec, bac_vec;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get the training event corresponding to this neighbor
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetVars();

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         sig_vec.push_back(tvec);
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         bac_vec.push_back(tvec);
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }
   }

   fLDA.Initialize(sig_vec, bac_vec);

   return fLDA.GetProb(event_knn.GetVars(), 1);
}