00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #ifndef ROOT_TMVA_MethodFisher
00035 #define ROOT_TMVA_MethodFisher
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045 #include <vector>
00046
00047 #ifndef ROOT_TMVA_MethodBase
00048 #include "TMVA/MethodBase.h"
00049 #endif
00050 #ifndef ROOT_TMatrixDfwd
00051 #include "TMatrixDfwd.h"
00052 #endif
00053
00054 class TH1D;
00055
00056 namespace TMVA {
00057
00058 class MethodFisher : public MethodBase {
00059
00060 public:
00061
00062 MethodFisher( const TString& jobName,
00063 const TString& methodTitle,
00064 DataSetInfo& dsi,
00065 const TString& theOption = "Fisher",
00066 TDirectory* theTargetDir = 0 );
00067
00068 MethodFisher( DataSetInfo& dsi,
00069 const TString& theWeightFile,
00070 TDirectory* theTargetDir = NULL );
00071
00072 virtual ~MethodFisher( void );
00073
00074 virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );
00075
00076
00077
00078 void Train( void );
00079
00080 using MethodBase::ReadWeightsFromStream;
00081
00082
00083 void AddWeightsXMLTo ( void* parent ) const;
00084
00085
00086 void ReadWeightsFromStream( std::istream & i );
00087 void ReadWeightsFromXML ( void* wghtnode );
00088
00089
00090 Double_t GetMvaValue( Double_t* err = 0, Double_t* errUpper = 0 );
00091
00092 enum EFisherMethod { kFisher, kMahalanobis };
00093 EFisherMethod GetFisherMethod( void ) { return fFisherMethod; }
00094
00095
00096 const Ranking* CreateRanking();
00097
00098 protected:
00099
00100
00101 void MakeClassSpecific( std::ostream&, const TString& ) const;
00102
00103
00104 void GetHelpMessage() const;
00105
00106 private:
00107
00108
00109 void DeclareOptions();
00110 void ProcessOptions();
00111
00112
00113 void InitMatrices( void );
00114
00115
00116 void GetMean( void );
00117
00118
00119 void GetCov_WithinClass( void );
00120
00121
00122 void GetCov_BetweenClass( void );
00123
00124
00125 void GetCov_Full( void );
00126
00127
00128 void GetDiscrimPower( void );
00129
00130
00131 void PrintCoefficients( void );
00132
00133
00134 void GetFisherCoeff( void );
00135
00136
00137 TMatrixD *fMeanMatx;
00138
00139
00140 TString fTheMethod;
00141 EFisherMethod fFisherMethod;
00142
00143
00144 TMatrixD *fBetw;
00145 TMatrixD *fWith;
00146 TMatrixD *fCov;
00147
00148
00149 Double_t fSumOfWeightsS;
00150 Double_t fSumOfWeightsB;
00151
00152 std::vector<Double_t>* fDiscrimPow;
00153 std::vector<Double_t>* fFisherCoeff;
00154 Double_t fF0;
00155
00156
00157 void Init( void );
00158
00159 ClassDef(MethodFisher,0)
00160 };
00161
00162 }
00163
00164 #endif // MethodFisher_H