00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030 #ifndef ROOT_TMVA_RuleFitAPI
00031 #define ROOT_TMVA_RuleFitAPI
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041 #include <fstream>
00042
00043 #include "TMVA/MsgLogger.h"
00044
00045 namespace TMVA {
00046
00047 class MethodRuleFit;
00048
00049 class RuleFitAPI {
00050
00051 public:
00052
00053 RuleFitAPI( const TMVA::MethodRuleFit *rfbase, TMVA::RuleFit *rulefit, EMsgType minType );
00054
00055 virtual ~RuleFitAPI();
00056
00057
00058 void WelcomeMessage();
00059
00060
00061 void HowtoSetupRF();
00062
00063
00064 void SetRFWorkDir(const char * wdir);
00065
00066
00067 void CheckRFWorkDir();
00068
00069
00070 inline void TrainRuleFit();
00071 inline void TestRuleFit();
00072 inline void VarImp();
00073
00074
00075 Bool_t ReadModelSum();
00076
00077
00078 const TString GetRFWorkDir() const { return fRFWorkDir; }
00079
00080 protected:
00081
00082 enum ERFMode { kRfRegress=1, kRfClass=2 };
00083 enum EModel { kRfLinear=0, kRfRules=1, kRfBoth=2 };
00084 enum ERFProgram { kRfTrain=0, kRfPredict, kRfVarimp };
00085
00086
00087 typedef struct {
00088 Int_t mode;
00089 Int_t lmode;
00090 Int_t n;
00091 Int_t p;
00092 Int_t max_rules;
00093 Int_t tree_size;
00094 Int_t path_speed;
00095 Int_t path_xval;
00096 Int_t path_steps;
00097 Int_t path_testfreq;
00098 Int_t tree_store;
00099 Int_t cat_store;
00100 } IntParms;
00101
00102
00103 typedef struct {
00104 Float_t xmiss;
00105 Float_t trim_qntl;
00106 Float_t huber;
00107 Float_t inter_supp;
00108 Float_t memory_par;
00109 Float_t samp_fract;
00110 Float_t path_inc;
00111 Float_t conv_fac;
00112 } RealParms;
00113
00114
00115 void InitRuleFit();
00116 void FillRealParmsDef();
00117 void FillIntParmsDef();
00118 void ImportSetup();
00119 void SetTrainParms();
00120 void SetTestParms();
00121
00122
00123 Int_t RunRuleFit();
00124
00125
00126 void SetRFTrain() { fRFProgram = kRfTrain; }
00127 void SetRFPredict() { fRFProgram = kRfPredict; }
00128 void SetRFVarimp() { fRFProgram = kRfVarimp; }
00129
00130
00131 inline TString GetRFName(TString name);
00132 inline Bool_t OpenRFile(TString name, std::ofstream & f);
00133 inline Bool_t OpenRFile(TString name, std::ifstream & f);
00134
00135
00136 inline Bool_t WriteInt(ofstream & f, const Int_t *v, Int_t n=1);
00137 inline Bool_t WriteFloat(ofstream & f, const Float_t *v, Int_t n=1);
00138 inline Int_t ReadInt(ifstream & f, Int_t *v, Int_t n=1) const;
00139 inline Int_t ReadFloat(ifstream & f, Float_t *v, Int_t n=1) const;
00140
00141
00142 Bool_t WriteAll();
00143 Bool_t WriteIntParms();
00144 Bool_t WriteRealParms();
00145 Bool_t WriteLx();
00146 Bool_t WriteProgram();
00147 Bool_t WriteRealVarImp();
00148 Bool_t WriteRfOut();
00149 Bool_t WriteRfStatus();
00150 Bool_t WriteRuleFitMod();
00151 Bool_t WriteRuleFitSum();
00152 Bool_t WriteTrain();
00153 Bool_t WriteVarNames();
00154 Bool_t WriteVarImp();
00155 Bool_t WriteYhat();
00156 Bool_t WriteTest();
00157
00158
00159 Bool_t ReadYhat();
00160 Bool_t ReadIntParms();
00161 Bool_t ReadRealParms();
00162 Bool_t ReadLx();
00163 Bool_t ReadProgram();
00164 Bool_t ReadRealVarImp();
00165 Bool_t ReadRfOut();
00166 Bool_t ReadRfStatus();
00167 Bool_t ReadRuleFitMod();
00168 Bool_t ReadRuleFitSum();
00169 Bool_t ReadTrainX();
00170 Bool_t ReadTrainY();
00171 Bool_t ReadTrainW();
00172 Bool_t ReadVarNames();
00173 Bool_t ReadVarImp();
00174
00175 private:
00176
00177 RuleFitAPI();
00178 const MethodRuleFit *fMethodRuleFit;
00179 RuleFit *fRuleFit;
00180
00181 std::vector<Float_t> fRFYhat;
00182 std::vector<Float_t> fRFVarImp;
00183 std::vector<Int_t> fRFVarImpInd;
00184 TString fRFWorkDir;
00185 IntParms fRFIntParms;
00186 RealParms fRFRealParms;
00187 std::vector<int> fRFLx;
00188 ERFProgram fRFProgram;
00189 TString fModelType;
00190
00191 mutable MsgLogger fLogger;
00192
00193 ClassDef(RuleFitAPI,0)
00194
00195 };
00196
00197 }
00198
00199
00200 void TMVA::RuleFitAPI::TrainRuleFit()
00201 {
00202
00203 SetTrainParms();
00204 WriteAll();
00205 RunRuleFit();
00206 }
00207
00208
00209 void TMVA::RuleFitAPI::TestRuleFit()
00210 {
00211
00212 SetTestParms();
00213 WriteAll();
00214 RunRuleFit();
00215 ReadYhat();
00216 }
00217
00218
00219 void TMVA::RuleFitAPI::VarImp()
00220 {
00221
00222 SetRFVarimp();
00223 WriteAll();
00224 RunRuleFit();
00225 ReadVarImp();
00226 }
00227
00228
00229 TString TMVA::RuleFitAPI::GetRFName(TString name)
00230 {
00231
00232 return fRFWorkDir+"/"+name;
00233 }
00234
00235
00236 Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ofstream & f)
00237 {
00238
00239 TString fullName = GetRFName(name);
00240 f.open(fullName);
00241 if (!f.is_open()) {
00242 fLogger << kERROR << "Error opening RuleFit file for output: "
00243 << fullName << Endl;
00244 return kFALSE;
00245 }
00246 return kTRUE;
00247 }
00248
00249
00250 Bool_t TMVA::RuleFitAPI::OpenRFile(TString name, std::ifstream & f)
00251 {
00252
00253 TString fullName = GetRFName(name);
00254 f.open(fullName);
00255 if (!f.is_open()) {
00256 fLogger << kERROR << "Error opening RuleFit file for input: "
00257 << fullName << Endl;
00258 return kFALSE;
00259 }
00260 return kTRUE;
00261 }
00262
00263
00264 Bool_t TMVA::RuleFitAPI::WriteInt(ofstream & f, const Int_t *v, Int_t n)
00265 {
00266
00267 if (!f.is_open()) return kFALSE;
00268 return f.write(reinterpret_cast<char const *>(v), n*sizeof(Int_t));
00269 }
00270
00271
00272 Bool_t TMVA::RuleFitAPI::WriteFloat(ofstream & f, const Float_t *v, Int_t n)
00273 {
00274
00275 if (!f.is_open()) return kFALSE;
00276 return f.write(reinterpret_cast<char const *>(v), n*sizeof(Float_t));
00277 }
00278
00279
00280 Int_t TMVA::RuleFitAPI::ReadInt(ifstream & f, Int_t *v, Int_t n) const
00281 {
00282
00283 if (!f.is_open()) return 0;
00284 if (f.read(reinterpret_cast<char *>(v), n*sizeof(Int_t))) return 1;
00285 return 0;
00286 }
00287
00288
00289 Int_t TMVA::RuleFitAPI::ReadFloat(ifstream & f, Float_t *v, Int_t n) const
00290 {
00291
00292 if (!f.is_open()) return 0;
00293 if (f.read(reinterpret_cast<char *>(v), n*sizeof(Float_t))) return 1;
00294 return 0;
00295 }
00296
00297 #endif // ROOT_TMVA_RuleFitAPI