MethodKNN.h

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: MethodKNN.h 36966 2010-11-26 09:50:13Z evt $
00002 // Author: Rustem Ospanov
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : MethodKNN                                                             *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Analysis of k-nearest neighbor                                            *
00012  *                                                                                *
00013  * Author:                                                                        *
00014  *      Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA             *
00015  *                                                                                *
00016  * Copyright (c) 2007:                                                            *
00017  *      CERN, Switzerland                                                         * 
00018  *      MPI-K Heidelberg, Germany                                                 * 
00019  *      U. of Texas at Austin, USA                                                *
00020  *                                                                                *
00021  * Redistribution and use in source and binary forms, with or without             *
00022  * modification, are permitted according to the terms listed in LICENSE           *
00023  * (http://tmva.sourceforge.net/LICENSE)                                          *
00024  **********************************************************************************/
00025 
00026 #ifndef ROOT_TMVA_MethodKNN
00027 #define ROOT_TMVA_MethodKNN
00028 
00029 //////////////////////////////////////////////////////////////////////////
00030 //                                                                      //
00031 // MethodKNN                                                            //
00032 //                                                                      //
00033 // Analysis of k-nearest neighbor                                       //
00034 //                                                                      //
00035 //////////////////////////////////////////////////////////////////////////
00036 
00037 #include <vector>
00038 #include <map>
00039 
00040 // Local
00041 #ifndef ROOT_TMVA_MethodBase
00042 #include "TMVA/MethodBase.h"
00043 #endif
00044 #ifndef ROOT_TMVA_ModulekNN
00045 #include "TMVA/ModulekNN.h"
00046 #endif
00047 
00048 // SVD and linear discriminant code
00049 #ifndef ROOT_TMVA_LDA
00050 #include "TMVA/LDA.h"
00051 #endif
00052 
00053 namespace TMVA
00054 {   
00055    namespace kNN
00056    {
00057       class ModulekNN;   // forward declaration; note that TMVA/ModulekNN.h is also included above
00058    }
00059 
00060    class MethodKNN : public MethodBase
00061    {
00062    public:
00063       // constructor taking job name, method title, data set info, an option string (default "KNN") and an optional target directory
00064       MethodKNN(const TString& jobName, 
00065                 const TString& methodTitle, 
00066                 DataSetInfo& theData,
00067                 const TString& theOption = "KNN",
00068                 TDirectory* theTargetDir = NULL);
00069       // constructor taking data set info, a weight file name and an optional target directory
00070       MethodKNN(DataSetInfo& theData, 
00071                 const TString& theWeightFile,  
00072                 TDirectory* theTargetDir = NULL);
00073       
00074       virtual ~MethodKNN( void );
00075       // presumably reports whether the given analysis type / class count / target count is supported - confirm in the implementation file
00076       virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );
00077 
00078       void Train( void );
00079       // err and errUpper are optional pointers (default 0); presumably they receive error estimates - confirm in the implementation file
00080       Double_t GetMvaValue( Double_t* err = 0, Double_t* errUpper = 0 );
00081       const std::vector<Float_t>& GetRegressionValues();
00082       // re-expose the MethodBase overloads, which the ReadWeightsFromStream declarations below would otherwise hide
00083       using MethodBase::ReadWeightsFromStream;
00084       // writing and reading of weights: TFile, XML node and std::istream variants
00085       void WriteWeightsToStream(TFile& rf) const;
00086       void AddWeightsXMLTo( void* parent ) const;
00087       void ReadWeightsFromXML( void* wghtnode );
00088 
00089       void ReadWeightsFromStream(std::istream& istr);
00090       void ReadWeightsFromStream(TFile &rf);
00091 
00092       const Ranking* CreateRanking();
00093 
00094    protected:
00095 
00096       // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
00097       void MakeClassSpecific( std::ostream&, const TString& ) const;
00098 
00099       // get help message text
00100       void GetHelpMessage() const;
00101 
00102    private:
00103 
00104       // the option handling methods
00105       void DeclareOptions();
00106       void ProcessOptions();
00107       void DeclareCompatibilityOptions();
00108 
00109       // default initialisation called by all constructors
00110       void Init( void );
00111 
00112       // create kd-tree (binary tree) structure
00113       void MakeKNN( void );
00114 
00115       // polynomial and Gaussian kernel weight functions
00116       Double_t PolnKernel(Double_t value) const;
00117       Double_t GausKernel(const kNN::Event &event_knn, const kNN::Event &event, const std::vector<Double_t> &svec) const;
00118       // helpers taking a kNN::List (presumably the list of neighbours found for event_knn - confirm in the implementation file)
00119       Double_t getKernelRadius(const kNN::List &rlist) const;
00120       const std::vector<Double_t> getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const;
00121       
00122       double getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn);
00123 
00124    private:
00125       // NOTE: a trailing member comment that starts with "//!" is ROOT's marker for a transient (non-persistent) data member; keep those markers intact
00126       // number of events (sumOfWeights)
00127       Double_t fSumOfWeightsS;        // sum-of-weights for signal training events
00128       Double_t fSumOfWeightsB;        // sum-of-weights for background training events      
00129 
00130       kNN::ModulekNN *fModule;        //! module where all work is done
00131 
00132       Int_t fnkNN;            // number of k-nearest neighbors 
00133       Int_t fBalanceDepth;    // number of binary tree levels used for balancing tree
00134 
00135       Float_t fScaleFrac;     // fraction of events used to compute variable width
00136       Float_t fSigmaFact;     // scale factor for Gaussian sigma in Gaus. kernel
00137 
00138       TString fKernel;        // ="Gaus","Poln" - kernel type for smoothing
00139 
00140       Bool_t fTrim;           // set equal number of signal and background events
00141       Bool_t fUseKernel;      // use polynomial kernel weight function
00142       Bool_t fUseWeight;      // use weights to count kNN
00143       Bool_t fUseLDA;         // use local linear discriminant analysis to compute MVA
00144 
00145       kNN::EventVec fEvent;   //! (untouched) events used for learning
00146 
00147       LDA fLDA;               //! Experimental feature for local knn analysis
00148 
00149       // for backward compatibility
00150       Int_t fTreeOptDepth;    // number of binary tree levels used for optimization
00151 
00152       ClassDef(MethodKNN,0) // k Nearest Neighbour classifier
00153    };
00154 
00155 } // namespace TMVA
00156 
00157 #endif // MethodKNN

Generated on Tue Jul 5 14:27:31 2011 for ROOT_528-00b_version by  doxygen 1.5.1