00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026 #ifndef ROOT_TMVA_ModulekNN
00027 #define ROOT_TMVA_ModulekNN
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include <cassert>
00039 #include <iosfwd>
00040 #include <map>
00041 #include <string>
00042 #include <vector>
00043
00044
00045 #ifndef ROOT_Rtypes
00046 #include "Rtypes.h"
00047 #endif
00048 #ifndef ROOT_TRandom
00049 #include "TRandom3.h"
00050 #endif
00051
00052 #ifndef ROOT_TMVA_NodekNN
00053 #include "TMVA/NodekNN.h"
00054 #endif
00055
00056 namespace TMVA {
00057
00058 class MsgLogger;
00059
00060 namespace kNN {
00061
00062 typedef Float_t VarType;
00063 typedef std::vector<VarType> VarVec;
00064
00065 class Event {
00066 public:
00067
00068 Event();
00069 Event(const VarVec &vec, Double_t weight, Short_t type);
00070 Event(const VarVec &vec, Double_t weight, Short_t type, const VarVec &tvec);
00071 ~Event();
00072
00073 Double_t GetWeight() const;
00074
00075 VarType GetVar(UInt_t i) const;
00076 VarType GetTgt(UInt_t i) const;
00077
00078 UInt_t GetNVar() const;
00079 UInt_t GetNTgt() const;
00080
00081 Short_t GetType() const;
00082
00083
00084 VarType GetDist(VarType var, UInt_t ivar) const;
00085 VarType GetDist(const Event &other) const;
00086
00087 void SetTargets(const VarVec &tvec);
00088 const VarVec& GetTargets() const;
00089 const VarVec& GetVars() const;
00090
00091 void Print() const;
00092 void Print(std::ostream& os) const;
00093
00094 private:
00095
00096 VarVec fVar;
00097 VarVec fTgt;
00098
00099 Double_t fWeight;
00100 Short_t fType;
00101 };
00102
00103 typedef std::vector<TMVA::kNN::Event> EventVec;
00104 typedef std::pair<const Node<Event> *, VarType> Elem;
00105 typedef std::list<Elem> List;
00106
00107 std::ostream& operator<<(std::ostream& os, const Event& event);
00108
00109 class ModulekNN
00110 {
00111 public:
00112
00113 typedef std::map<int, std::vector<Double_t> > VarMap;
00114
00115 public:
00116
00117 ModulekNN();
00118 ~ModulekNN();
00119
00120 void Clear();
00121
00122 void Add(const Event &event);
00123
00124 Bool_t Fill(const UShort_t odepth, UInt_t ifrac, const std::string &option = "");
00125
00126 Bool_t Find(Event event, UInt_t nfind = 100, const std::string &option = "count") const;
00127 Bool_t Find(UInt_t nfind, const std::string &option) const;
00128
00129 const EventVec& GetEventVec() const;
00130
00131 const List& GetkNNList() const;
00132 const Event& GetkNNEvent() const;
00133
00134 const VarMap& GetVarMap() const;
00135
00136 const std::map<Int_t, Double_t>& GetMetric() const;
00137
00138 void Print() const;
00139 void Print(std::ostream &os) const;
00140
00141 private:
00142
00143 Node<Event>* Optimize(UInt_t optimize_depth);
00144
00145 void ComputeMetric(UInt_t ifrac);
00146
00147 const Event Scale(const Event &event) const;
00148
00149 private:
00150
00151 static TRandom3 fgRndm;
00152
00153 UInt_t fDimn;
00154
00155 Node<Event> *fTree;
00156
00157 std::map<Int_t, Double_t> fVarScale;
00158
00159 mutable List fkNNList;
00160 mutable Event fkNNEvent;
00161
00162 std::map<Short_t, UInt_t> fCount;
00163
00164 EventVec fEvent;
00165 VarMap fVar;
00166
00167 mutable MsgLogger* fLogger;
00168 MsgLogger& Log() const { return *fLogger; }
00169 };
00170
00171
00172
00173
00174 inline VarType Event::GetDist(const VarType var1, const UInt_t ivar) const
00175 {
00176 const VarType var2 = GetVar(ivar);
00177 return (var1 - var2) * (var1 - var2);
00178 }
00179 inline Double_t Event::GetWeight() const
00180 {
00181 return fWeight;
00182 }
00183 inline VarType Event::GetVar(const UInt_t i) const
00184 {
00185 return fVar[i];
00186 }
00187 inline VarType Event::GetTgt(const UInt_t i) const
00188 {
00189 return fTgt[i];
00190 }
00191
00192 inline UInt_t Event::GetNVar() const
00193 {
00194 return fVar.size();
00195 }
00196 inline UInt_t Event::GetNTgt() const
00197 {
00198 return fTgt.size();
00199 }
00200 inline Short_t Event::GetType() const
00201 {
00202 return fType;
00203 }
00204
00205
00206
00207
00208 inline const List& ModulekNN::GetkNNList() const
00209 {
00210 return fkNNList;
00211 }
00212 inline const Event& ModulekNN::GetkNNEvent() const
00213 {
00214 return fkNNEvent;
00215 }
00216 inline const EventVec& ModulekNN::GetEventVec() const
00217 {
00218 return fEvent;
00219 }
00220 inline const ModulekNN::VarMap& ModulekNN::GetVarMap() const
00221 {
00222 return fVar;
00223 }
00224 inline const std::map<Int_t, Double_t>& ModulekNN::GetMetric() const
00225 {
00226 return fVarScale;
00227 }
00228
00229 }
00230 }
00231
00232 #endif
00233