00001 // @(#)root/tmva $Id: TSynapse.h 33928 2010-06-15 16:19:31Z stelzer $ 00002 // Author: Matt Jachowski 00003 00004 /********************************************************************************** 00005 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * 00006 * Package: TMVA * 00007 * Class : TMVA::TSynapse * 00008 * Web : http://tmva.sourceforge.net * 00009 * * 00010 * Description: * 00011 * Synapse class for use in derivatives of MethodANNBase * 00012 * * 00013 * Authors (alphabetical): * 00014 * Matt Jachowski <jachowski@stanford.edu> - Stanford University, USA * 00015 * * 00016 * Copyright (c) 2005: * 00017 * CERN, Switzerland * 00018 * * 00019 * Redistribution and use in source and binary forms, with or without * 00020 * modification, are permitted according to the terms listed in LICENSE * 00021 * (http://tmva.sourceforge.net/LICENSE) * 00022 **********************************************************************************/ 00023 00024 #ifndef ROOT_TMVA_TSynapse 00025 #define ROOT_TMVA_TSynapse 00026 00027 ////////////////////////////////////////////////////////////////////////// 00028 // // 00029 // TSynapse // 00030 // // 00031 // Synapse used by derivatives of MethodANNBase // 00032 // // 00033 ////////////////////////////////////////////////////////////////////////// 00034 00035 #ifndef ROOT_TString 00036 #include "TString.h" 00037 #endif 00038 #ifndef ROOT_TFormula 00039 #include "TFormula.h" 00040 #endif 00041 00042 00043 namespace TMVA { 00044 00045 class TNeuron; 00046 class MsgLogger; 00047 00048 class TSynapse : public TObject { 00049 00050 public: 00051 00052 TSynapse(); 00053 virtual ~TSynapse(); 00054 00055 // set the weight of the synapse 00056 void SetWeight(Double_t weight); 00057 00058 // get the weight of the synapse 00059 Double_t GetWeight() { return fWeight; } 00060 00061 // set the learning rate 00062 void SetLearningRate(Double_t rate) { fLearnRate = rate; } 00063 00064 // get the learning rate 00065 Double_t GetLearningRate() { return fLearnRate; } 00066 00067 // decay the learning rate 00068 void DecayLearningRate(Double_t rate){ fLearnRate *= (1-rate); } 00069 00070 // set the pre-neuron 00071 void SetPreNeuron(TNeuron* pre) { fPreNeuron = pre; } 00072 00073 // set hte post-neuron 00074 void SetPostNeuron(TNeuron* post) { fPostNeuron = post; } 00075 00076 // get the weighted output of the pre-neuron 00077 Double_t GetWeightedValue(); 00078 00079 // get the weighted error field of the post-neuron 00080 Double_t GetWeightedDelta(); 00081 00082 // force the synapse to adjust its weight according to its error field 00083 void AdjustWeight(); 00084 00085 // calulcate the error field of the synapse 00086 void CalculateDelta(); 00087 00088 // initialize the error field of the synpase to 0 00089 void InitDelta() { fDelta = 0.0; fCount = 0; } 00090 00091 void SetDEDw(Double_t DEDw) { fDEDw = DEDw; } 00092 Double_t GetDEDw() { return fDEDw; } 00093 Double_t GetDelta() { return fDelta; } 00094 00095 private: 00096 00097 Double_t fWeight; // weight of the synapse 00098 Double_t fLearnRate; // learning rate parameter 00099 Double_t fDelta; // local error field 00100 Double_t fDEDw; // sum of deltas 00101 Int_t fCount; // number of updates contributing to error field 00102 TNeuron* fPreNeuron; // pointer to pre-neuron 00103 TNeuron* fPostNeuron; // pointer to post-neuron 00104 00105 static MsgLogger* fgLogger; //! message logger, static to save resources 00106 MsgLogger& Log() const { return *fgLogger; } 00107 00108 ClassDef(TSynapse,0) // Synapse class used by MethodANNBase and derivatives 00109 }; 00110 00111 } // namespace TMVA 00112 00113 #endif