00001 // @(#)root/tmva $Id: TNeuron.h 33928 2010-06-15 16:19:31Z stelzer $ 00002 // Author: Matt Jachowski 00003 00004 /********************************************************************************** 00005 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * 00006 * Package: TMVA * 00007 * Class : TMVA::TNeuron * 00008 * Web : http://tmva.sourceforge.net * 00009 * * 00010 * Description: * 00011 * Neuron class to be used in MethodANNBase and its derivatives. * 00012 * * 00013 * Authors (alphabetical): * 00014 * Matt Jachowski <jachowski@stanford.edu> - Stanford University, USA * 00015 * * 00016 * Copyright (c) 2005: * 00017 * CERN, Switzerland * 00018 * * 00019 * Redistribution and use in source and binary forms, with or without * 00020 * modification, are permitted according to the terms listed in LICENSE * 00021 * (http://tmva.sourceforge.net/LICENSE) * 00022 **********************************************************************************/ 00023 00024 #ifndef ROOT_TMVA_TNeuron 00025 #define ROOT_TMVA_TNeuron 00026 00027 ////////////////////////////////////////////////////////////////////////// 00028 // // 00029 // TNeuron // 00030 // // 00031 // Neuron used by derivatives of MethodANNBase // 00032 // // 00033 ////////////////////////////////////////////////////////////////////////// 00034 00035 #include <iostream> 00036 00037 #ifndef ROOT_TString 00038 #include "TString.h" 00039 #endif 00040 #ifndef ROOT_TObjArray 00041 #include "TObjArray.h" 00042 #endif 00043 #ifndef ROOT_TFormula 00044 #include "TFormula.h" 00045 #endif 00046 00047 #ifndef ROOT_TMVA_TSynapse 00048 #include "TMVA/TSynapse.h" 00049 #endif 00050 #ifndef ROOT_TMVA_TActivation 00051 #include "TMVA/TActivation.h" 00052 #endif 00053 #ifndef ROOT_TMVA_Types 00054 #include "TMVA/Types.h" 00055 #endif 00056 00057 namespace TMVA { 00058 00059 class TNeuronInput; 00060 00061 class TNeuron : public TObject { 00062 00063 public: 00064 00065 TNeuron(); 00066 virtual ~TNeuron(); 00067 00068 // force the input value 00069 void ForceValue(Double_t value); 00070 00071 // calculate the input value 00072 void CalculateValue(); 00073 00074 // calculate the activation value 00075 void CalculateActivationValue(); 00076 00077 // calculate the error field of the neuron 00078 void CalculateDelta(); 00079 00080 // set the activation function 00081 void SetActivationEqn(TActivation* activation); 00082 00083 // set the input calculator 00084 void SetInputCalculator(TNeuronInput* calculator); 00085 00086 // add a synapse as a pre-link 00087 void AddPreLink(TSynapse* pre); 00088 00089 // add a synapse as a post-link 00090 void AddPostLink(TSynapse* post); 00091 00092 // delete all pre-links 00093 void DeletePreLinks(); 00094 00095 // set the error 00096 void SetError(Double_t error); 00097 00098 // update the error fields of all pre-synapses, batch mode 00099 // to actually update the weights, call adjust synapse weights 00100 void UpdateSynapsesBatch(); 00101 00102 // update the error fields and weights of all pre-synapses, sequential mode 00103 void UpdateSynapsesSequential(); 00104 00105 // update the weights of the all pre-synapses, batch mode 00106 //(call UpdateSynapsesBatch first) 00107 void AdjustSynapseWeights(); 00108 00109 // explicitly initialize error fields of pre-synapses, batch mode 00110 void InitSynapseDeltas(); 00111 00112 // print activation equation, for debugging 00113 void PrintActivationEqn(); 00114 00115 // inlined functions 00116 Double_t GetValue() const { return fValue; } 00117 Double_t GetActivationValue() const { return fActivationValue; } 00118 Double_t GetDelta() const { return fDelta; } 00119 Double_t GetDEDw() const { return fDEDw; } 00120 Int_t NumPreLinks() const { return NumLinks(fLinksIn); } 00121 Int_t NumPostLinks() const { return NumLinks(fLinksOut); } 00122 TSynapse* PreLinkAt ( Int_t index ) const { return (TSynapse*)fLinksIn->At(index); } 00123 TSynapse* PostLinkAt( Int_t index ) const { return (TSynapse*)fLinksOut->At(index); } 00124 void SetInputNeuron() { NullifyLinks(fLinksIn); } 00125 void SetOutputNeuron() { NullifyLinks(fLinksOut); } 00126 void SetBiasNeuron() { NullifyLinks(fLinksIn); } 00127 void SetDEDw( Double_t DEDw ) { fDEDw = DEDw; } 00128 Bool_t IsInputNeuron() const { return fLinksIn == NULL; } 00129 Bool_t IsOutputNeuron() const { return fLinksOut == NULL; } 00130 void PrintPreLinks() const { PrintLinks(fLinksIn); return; } 00131 void PrintPostLinks() const { PrintLinks(fLinksOut); return; } 00132 00133 virtual void Print(Option_t* = "") const { 00134 std::cout << fValue << std::endl; 00135 //PrintPreLinks(); PrintPostLinks(); 00136 } 00137 00138 private: 00139 00140 // prviate helper functions 00141 void InitNeuron(); 00142 void DeleteLinksArray( TObjArray*& links ); 00143 void PrintLinks ( TObjArray* links ) const; 00144 void PrintMessage ( EMsgType, TString message ); 00145 00146 // inlined helper functions 00147 Int_t NumLinks(TObjArray* links) const { 00148 if (links == NULL) return 0; return links->GetEntriesFast(); 00149 } 00150 void NullifyLinks(TObjArray*& links) { 00151 if (links != NULL) delete links; links = NULL; 00152 } 00153 00154 // private member variables 00155 TObjArray* fLinksIn; // array of input synapses 00156 TObjArray* fLinksOut; // array of output synapses 00157 Double_t fValue; // input value 00158 Double_t fActivationValue; // activation/output value 00159 Double_t fDelta; // error field of neuron 00160 Double_t fDEDw; // sum of all deltas 00161 Double_t fError; // error, only set for output neurons 00162 Bool_t fForcedValue; // flag for forced input value 00163 TActivation* fActivation; // activation equation 00164 TNeuronInput* fInputCalculator; // input calculator 00165 00166 static MsgLogger* fgLogger; //! message logger, static to save resources 00167 MsgLogger& Log() const { return *fgLogger; } 00168 00169 ClassDef(TNeuron,0) // Neuron class used by MethodANNBase derivative ANNs 00170 }; 00171 00172 } // namespace TMVA 00173 00174 #endif