00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifndef ROOT_TMVA_DataSet
00030 #define ROOT_TMVA_DataSet
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040 #include <vector>
00041 #include <map>
00042 #include <string>
00043
00044 #ifndef ROOT_TObject
00045 #include "TObject.h"
00046 #endif
00047 #ifndef ROOT_TString
00048 #include "TString.h"
00049 #endif
00050 #ifndef ROOT_TTree
00051 #include "TTree.h"
00052 #endif
00053 #ifndef ROOT_TCut
00054 #include "TCut.h"
00055 #endif
00056 #ifndef ROOT_TMatrixDfwd
00057 #include "TMatrixDfwd.h"
00058 #endif
00059 #ifndef ROOT_TPrincipal
00060 #include "TPrincipal.h"
00061 #endif
00062 #ifndef ROOT_TRandom3
00063 #include "TRandom3.h"
00064 #endif
00065
00066 #ifndef ROOT_TMVA_Types
00067 #include "TMVA/Types.h"
00068 #endif
00069 #ifndef ROOT_TMVA_VariableInfo
00070 #include "TMVA/VariableInfo.h"
00071 #endif
00072
00073 namespace TMVA {
00074
00075 class Event;
00076 class DataSetInfo;
00077 class MsgLogger;
00078 class Results;
00079
00080 class DataSet {
00081
00082 public:
00083
00084 DataSet(const DataSetInfo&);
00085 virtual ~DataSet();
00086
00087 void AddEvent( Event *, Types::ETreeType );
00088
00089 Long64_t GetNEvents( Types::ETreeType type = Types::kMaxTreeType ) const;
00090 Long64_t GetNTrainingEvents() const { return GetNEvents(Types::kTraining); }
00091 Long64_t GetNTestEvents() const { return GetNEvents(Types::kTesting); }
00092 Event* GetEvent() const;
00093 Event* GetEvent ( Long64_t ievt ) const { fCurrentEventIdx = ievt; return GetEvent(); }
00094 Event* GetTrainingEvent( Long64_t ievt ) const { return GetEvent(ievt, Types::kTraining); }
00095 Event* GetTestEvent ( Long64_t ievt ) const { return GetEvent(ievt, Types::kTesting); }
00096 Event* GetEvent ( Long64_t ievt, Types::ETreeType type ) const {
00097 fCurrentTreeIdx = TreeIndex(type); fCurrentEventIdx = ievt; return GetEvent();
00098 }
00099
00100 UInt_t GetNVariables() const;
00101 UInt_t GetNTargets() const;
00102 UInt_t GetNSpectators() const;
00103
00104 void SetCurrentEvent( Long64_t ievt ) const { fCurrentEventIdx = ievt; }
00105 void SetCurrentType ( Types::ETreeType type ) const { fCurrentTreeIdx = TreeIndex(type); }
00106 Types::ETreeType GetCurrentType() const;
00107
00108 void SetEventCollection( std::vector<Event*>*, Types::ETreeType );
00109 const std::vector<Event*>& GetEventCollection( Types::ETreeType type = Types::kMaxTreeType ) const;
00110 const TTree* GetEventCollectionAsTree();
00111
00112 Long64_t GetNEvtSigTest();
00113 Long64_t GetNEvtBkgdTest();
00114 Long64_t GetNEvtSigTrain();
00115 Long64_t GetNEvtBkgdTrain();
00116
00117 Bool_t HasNegativeEventWeights() const { return fHasNegativeEventWeights; }
00118
00119 Results* GetResults ( const TString &,
00120 Types::ETreeType type,
00121 Types::EAnalysisType analysistype );
00122 void DeleteResults ( const TString &,
00123 Types::ETreeType type,
00124 Types::EAnalysisType analysistype );
00125
00126 void SetVerbose( Bool_t ) {}
00127
00128
00129
00130 void DivideTrainingSet( UInt_t blockNum );
00131
00132
00133 void MoveTrainingBlock( Int_t blockInd,Types::ETreeType dest, Bool_t applyChanges = kTRUE );
00134
00135 void IncrementNClassEvents( Int_t type, UInt_t classNumber );
00136 Long64_t GetNClassEvents ( Int_t type, UInt_t classNumber );
00137 void ClearNClassEvents ( Int_t type );
00138
00139 TTree* GetTree( Types::ETreeType type );
00140
00141
00142 void InitSampling( Float_t fraction, Float_t weight, UInt_t seed = 0 );
00143 void EventResult( Bool_t successful, Long64_t evtNumber = -1 );
00144 void CreateSampling() const;
00145
00146 UInt_t TreeIndex(Types::ETreeType type) const;
00147
00148 private:
00149
00150
00151 DataSet();
00152 void DestroyCollection( Types::ETreeType type, Bool_t deleteEvents );
00153
00154 const DataSetInfo& fdsi;
00155
00156 std::vector<Event*>::iterator fEvtCollIt;
00157 std::vector< std::vector<Event*>* > fEventCollection;
00158
00159 std::vector< std::map< TString, Results* > > fResults;
00160
00161 mutable UInt_t fCurrentTreeIdx;
00162 mutable Long64_t fCurrentEventIdx;
00163
00164
00165 std::vector<Char_t> fSampling;
00166 std::vector<Int_t> fSamplingNEvents;
00167 std::vector<Float_t> fSamplingWeight;
00168 mutable std::vector< std::vector< std::pair< Float_t, Long64_t >* > > fSamplingEventList;
00169 mutable std::vector< std::vector< std::pair< Float_t, Long64_t >* > > fSamplingSelected;
00170 TRandom3 *fSamplingRandom;
00171
00172
00173
00174 std::vector< std::vector<Long64_t> > fClassEvents;
00175
00176
00177 Bool_t fHasNegativeEventWeights;
00178
00179 mutable MsgLogger* fLogger;
00180 MsgLogger& Log() const { return *fLogger; }
00181 std::vector<Char_t> fBlockBelongToTraining;
00182
00183
00184
00185 Long64_t fTrainingBlockSize;
00186
00187 void ApplyTrainingBlockDivision();
00188 void ApplyTrainingSetDivision();
00189 };
00190 }
00191
00192
00193
00194 inline UInt_t TMVA::DataSet::TreeIndex(Types::ETreeType type) const
00195 {
00196 switch (type) {
00197 case Types::kMaxTreeType : return fCurrentTreeIdx;
00198 case Types::kTraining : return 0;
00199 case Types::kTesting : return 1;
00200 case Types::kValidation : return 2;
00201 case Types::kTrainingOriginal : return 3;
00202 default : return fCurrentTreeIdx;
00203 }
00204 }
00205
00206
00207 inline TMVA::Types::ETreeType TMVA::DataSet::GetCurrentType() const
00208 {
00209 switch (fCurrentTreeIdx) {
00210 case 0: return Types::kTraining;
00211 case 1: return Types::kTesting;
00212 case 2: return Types::kValidation;
00213 case 3: return Types::kTrainingOriginal;
00214 }
00215 return Types::kMaxTreeType;
00216 }
00217
00218
00219 inline Long64_t TMVA::DataSet::GetNEvents(Types::ETreeType type) const
00220 {
00221 Int_t treeIdx = TreeIndex(type);
00222 if (fSampling.size() > UInt_t(treeIdx) && fSampling.at(treeIdx)) {
00223 return fSamplingSelected.at(treeIdx).size();
00224 }
00225 return GetEventCollection(type).size();
00226 }
00227
00228
00229 inline const std::vector<TMVA::Event*>& TMVA::DataSet::GetEventCollection( TMVA::Types::ETreeType type ) const
00230 {
00231 return *(fEventCollection.at(TreeIndex(type)));
00232 }
00233
00234
00235 #endif