ModulekNN.h

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: ModulekNN.h 35738 2010-09-26 09:17:57Z stelzer $
00002 // Author: Rustem Ospanov
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : ModulekNN                                                             *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Module for k-nearest neighbor algorithm                                   *
00012  *                                                                                *
00013  * Author:                                                                        *
00014  *      Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA             *
00015  *                                                                                *
00016  * Copyright (c) 2007:                                                            *
00017  *      CERN, Switzerland                                                         *
00018  *      MPI-K Heidelberg, Germany                                                 *
00019  *      U. of Texas at Austin, USA                                                *
00020  *                                                                                *
00021  * Redistribution and use in source and binary forms, with or without             *
00022  * modification, are permitted according to the terms listed in LICENSE           *
00023  * (http://tmva.sourceforge.net/LICENSE)                                          *
00024  **********************************************************************************/
00025 
00026 #ifndef ROOT_TMVA_ModulekNN
00027 #define ROOT_TMVA_ModulekNN
00028 
00029 //______________________________________________________________________
00030 /*
00031   kNN::Event describes a point in the input-variable vector space, with
00032   additional functionality such as computing the distance between points
00033 */
00034 //______________________________________________________________________
00035 
00036 
// C++
#include <cassert>
#include <iosfwd>
#include <list>
#include <map>
#include <string>
#include <utility>
#include <vector>

// ROOT
#ifndef ROOT_Rtypes
#include "Rtypes.h"
#endif
#ifndef ROOT_TRandom3
#include "TRandom3.h"
#endif

#ifndef ROOT_TMVA_NodekNN
#include "TMVA/NodekNN.h"
#endif
00055 
00056 namespace TMVA {
00057 
00058    class MsgLogger;
00059 
00060    namespace kNN {
00061       
00062       typedef Float_t VarType;
00063       typedef std::vector<VarType> VarVec;
00064       
00065       class Event {
00066       public:
00067 
00068          Event();
00069          Event(const VarVec &vec, Double_t weight, Short_t type);
00070          Event(const VarVec &vec, Double_t weight, Short_t type, const VarVec &tvec);
00071          ~Event();
00072 
00073          Double_t GetWeight() const;
00074 
00075          VarType GetVar(UInt_t i) const;
00076          VarType GetTgt(UInt_t i) const;
00077 
00078          UInt_t GetNVar() const;
00079          UInt_t GetNTgt() const;
00080 
00081          Short_t GetType() const;
00082 
00083          // keep these two function separate
00084          VarType GetDist(VarType var, UInt_t ivar) const;
00085          VarType GetDist(const Event &other) const;
00086 
00087          void SetTargets(const VarVec &tvec);
00088          const VarVec& GetTargets() const;
00089          const VarVec& GetVars() const;
00090 
00091          void Print() const;
00092          void Print(std::ostream& os) const;
00093 
00094       private:
00095 
00096          VarVec fVar; // coordinates (variables) for knn search
00097          VarVec fTgt; // targets for regression analysis
00098 
00099          Double_t fWeight; // event weight
00100          Short_t fType; // event type ==0 or == 1, expand it to arbitrary class types? 
00101       };
00102 
00103       typedef std::vector<TMVA::kNN::Event> EventVec;
00104       typedef std::pair<const Node<Event> *, VarType> Elem;
00105       typedef std::list<Elem> List;
00106 
00107       std::ostream& operator<<(std::ostream& os, const Event& event);
00108 
00109       class ModulekNN
00110       {
00111       public:
00112 
00113          typedef std::map<int, std::vector<Double_t> > VarMap;
00114 
00115       public:
00116 
00117          ModulekNN();
00118          ~ModulekNN();
00119 
00120          void Clear();
00121 
00122          void Add(const Event &event);
00123 
00124          Bool_t Fill(const UShort_t odepth, UInt_t ifrac, const std::string &option = "");
00125 
00126          Bool_t Find(Event event, UInt_t nfind = 100, const std::string &option = "count") const;
00127          Bool_t Find(UInt_t nfind, const std::string &option) const;
00128       
00129          const EventVec& GetEventVec() const;
00130 
00131          const List& GetkNNList() const;
00132          const Event& GetkNNEvent() const;
00133 
00134          const VarMap& GetVarMap() const;
00135 
00136          const std::map<Int_t, Double_t>& GetMetric() const;
00137       
00138          void Print() const;
00139          void Print(std::ostream &os) const;
00140 
00141       private:
00142 
00143          Node<Event>* Optimize(UInt_t optimize_depth);
00144 
00145          void ComputeMetric(UInt_t ifrac);
00146 
00147          const Event Scale(const Event &event) const;
00148 
00149       private:
00150 
00151          static TRandom3 fgRndm;
00152 
00153          UInt_t fDimn;
00154 
00155          Node<Event> *fTree;
00156 
00157          std::map<Int_t, Double_t> fVarScale;
00158 
00159          mutable List  fkNNList;     // latest result from kNN search
00160          mutable Event fkNNEvent;    // latest event used for kNN search
00161          
00162          std::map<Short_t, UInt_t> fCount; // count number of events of each type
00163 
00164          EventVec fEvent; // vector of all events used to build tree and analysis
00165          VarMap   fVar;   // sorted map of variables in each dimension for all event types
00166 
00167          mutable MsgLogger* fLogger;   // message logger
00168          MsgLogger& Log() const { return *fLogger; }
00169       };
00170 
00171       //
00172       // inlined functions for Event class
00173       //
00174       inline VarType Event::GetDist(const VarType var1, const UInt_t ivar) const
00175       {
00176          const VarType var2 = GetVar(ivar);
00177          return (var1 - var2) * (var1 - var2);
00178       }
00179       inline Double_t Event::GetWeight() const
00180       {
00181          return fWeight;
00182       }
00183       inline VarType Event::GetVar(const UInt_t i) const
00184       {
00185          return fVar[i];
00186       }
00187       inline VarType Event::GetTgt(const UInt_t i) const
00188       {
00189          return fTgt[i];
00190       }
00191 
00192       inline UInt_t Event::GetNVar() const
00193       {
00194          return fVar.size();
00195       }
00196       inline UInt_t Event::GetNTgt() const
00197       {
00198          return fTgt.size();
00199       }
00200       inline Short_t Event::GetType() const
00201       {
00202          return fType;
00203       }
00204 
00205       //
00206       // inline functions for ModulekNN class
00207       //
00208       inline const List& ModulekNN::GetkNNList() const
00209       {
00210          return fkNNList;
00211       }
00212       inline const Event& ModulekNN::GetkNNEvent() const
00213       {
00214          return fkNNEvent;
00215       }
00216       inline const EventVec& ModulekNN::GetEventVec() const
00217       {
00218          return fEvent;
00219       }
00220       inline const ModulekNN::VarMap& ModulekNN::GetVarMap() const
00221       {
00222          return fVar;
00223       }
00224       inline const std::map<Int_t, Double_t>& ModulekNN::GetMetric() const
00225       {
00226          return fVarScale;
00227       }
00228 
00229    } // end of kNN namespace
00230 } // end of TMVA namespace
00231 
00232 #endif
00233 

Generated on Tue Jul 5 14:27:32 2011 for ROOT_528-00b_version by doxygen 1.5.1