00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033 #ifndef ROOT_TMVA_MethodANNBase
00034 #define ROOT_TMVA_MethodANNBase
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044 #ifndef ROOT_TString
00045 #include "TString.h"
00046 #endif
00047 #include <vector>
00048 #ifndef ROOT_TTree
00049 #include "TTree.h"
00050 #endif
00051 #ifndef ROOT_TObjArray
00052 #include "TObjArray.h"
00053 #endif
00054 #ifndef ROOT_TRandom3
00055 #include "TRandom3.h"
00056 #endif
00057
00058 #ifndef ROOT_TMVA_MethodBase
00059 #include "TMVA/MethodBase.h"
00060 #endif
00061 #ifndef ROOT_TMVA_TActivation
00062 #include "TMVA/TActivation.h"
00063 #endif
00064 #ifndef ROOT_TMVA_TNeuron
00065 #include "TMVA/TNeuron.h"
00066 #endif
00067 #ifndef ROOT_TMVA_TNeuronInput
00068 #include "TMVA/TNeuronInput.h"
00069 #endif
00070
00071 class TH1;
00072 class TH1F;
00073
00074 namespace TMVA {
00075
00076 class MethodANNBase : public MethodBase {
00077
00078 public:
00079
00080
00081 MethodANNBase( const TString& jobName,
00082 Types::EMVA methodType,
00083 const TString& methodTitle,
00084 DataSetInfo& theData,
00085 const TString& theOption,
00086 TDirectory* theTargetDir );
00087
00088 MethodANNBase( Types::EMVA methodType,
00089 DataSetInfo& theData,
00090 const TString& theWeightFile,
00091 TDirectory* theTargetDir );
00092
00093 virtual ~MethodANNBase();
00094
00095
00096 void InitANNBase();
00097
00098
00099 void SetActivation(TActivation* activation) {
00100 if (fActivation != NULL) delete fActivation; fActivation = activation;
00101 }
00102 void SetNeuronInputCalculator(TNeuronInput* inputCalculator) {
00103 if (fInputCalculator != NULL) delete fInputCalculator;
00104 fInputCalculator = inputCalculator;
00105 }
00106
00107
00108 virtual void Train() = 0;
00109
00110
00111 virtual void PrintNetwork() const;
00112
00113 using MethodBase::ReadWeightsFromStream;
00114
00115
00116 void AddWeightsXMLTo( void* parent ) const;
00117 void ReadWeightsFromXML( void* wghtnode );
00118
00119
00120 virtual void ReadWeightsFromStream( istream& istr );
00121
00122
00123 virtual Double_t GetMvaValue( Double_t* err = 0, Double_t* errUpper = 0 );
00124
00125 virtual const std::vector<Float_t> &GetRegressionValues();
00126
00127 virtual const std::vector<Float_t> &GetMulticlassValues();
00128
00129
00130 virtual void WriteMonitoringHistosToFile() const;
00131
00132
00133 const Ranking* CreateRanking();
00134
00135
00136 virtual void DeclareOptions();
00137 virtual void ProcessOptions();
00138
00139 Bool_t Debug() const;
00140
00141 enum EEstimator { kMSE=0,kCE};
00142
00143 protected:
00144
00145 virtual void MakeClassSpecific( std::ostream&, const TString& ) const;
00146
00147 std::vector<Int_t>* ParseLayoutString( TString layerSpec );
00148 virtual void BuildNetwork( std::vector<Int_t>* layout, std::vector<Double_t>* weights=NULL,
00149 Bool_t fromFile = kFALSE );
00150 void ForceNetworkInputs( const Event* ev, Int_t ignoreIndex = -1 );
00151 Double_t GetNetworkOutput() { return GetOutputNeuron()->GetActivationValue(); }
00152
00153
00154 void PrintMessage( TString message, Bool_t force = kFALSE ) const;
00155 void ForceNetworkCalculations();
00156 void WaitForKeyboard();
00157
00158
00159 Int_t NumCycles() { return fNcycles; }
00160 TNeuron* GetInputNeuron(Int_t index) { return (TNeuron*)fInputLayer->At(index); }
00161 TNeuron* GetOutputNeuron( Int_t index = 0) { return fOutputNeurons.at(index); }
00162
00163
00164 TObjArray* fNetwork;
00165 TObjArray* fSynapses;
00166 TActivation* fActivation;
00167 TActivation* fOutput;
00168 TActivation* fIdentity;
00169 TRandom3* frgen;
00170 TNeuronInput* fInputCalculator;
00171
00172 std::vector<Int_t> fRegulatorIdx;
00173 std::vector<Double_t> fRegulators;
00174 EEstimator fEstimator;
00175 TString fEstimatorS;
00176
00177
00178 TH1F* fEstimatorHistTrain;
00179 TH1F* fEstimatorHistTest;
00180
00181
00182 void CreateWeightMonitoringHists( const TString& bulkname, std::vector<TH1*>* hv = 0 ) const;
00183 std::vector<TH1*> fEpochMonHistS;
00184 std::vector<TH1*> fEpochMonHistB;
00185 std::vector<TH1*> fEpochMonHistW;
00186
00187
00188
00189 TMatrixD fInvHessian;
00190 bool fUseRegulator;
00191
00192 protected:
00193 Int_t fRandomSeed;
00194
00195 private:
00196
00197
00198 void BuildLayers(std::vector<Int_t>* layout, Bool_t from_file = false);
00199 void BuildLayer(Int_t numNeurons, TObjArray* curLayer, TObjArray* prevLayer,
00200 Int_t layerIndex, Int_t numLayers, Bool_t from_file = false);
00201 void AddPreLinks(TNeuron* neuron, TObjArray* prevLayer);
00202
00203
00204 void InitWeights();
00205 void ForceWeights(std::vector<Double_t>* weights);
00206
00207
00208 void DeleteNetwork();
00209 void DeleteNetworkLayer(TObjArray*& layer);
00210
00211
00212 void PrintLayer(TObjArray* layer) const;
00213 void PrintNeuron(TNeuron* neuron) const;
00214
00215
00216 Int_t fNcycles;
00217 TString fNeuronType;
00218 TString fNeuronInputType;
00219 TObjArray* fInputLayer;
00220 std::vector<TNeuron*> fOutputNeurons;
00221 TString fLayerSpec;
00222
00223
00224 static const Bool_t fgDEBUG = kTRUE;
00225
00226 ClassDef(MethodANNBase,0)
00227 };
00228
00229 }
00230
00231 #endif