SVWorkingSet.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: SVWorkingSet.cxx 31458 2009-11-30 13:58:20Z stelzer $    
00002 // Author: Andrzej Zemla
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : SVWorkingSet                                                          *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Implementation                                                            *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Marcin Wolter  <Marcin.Wolter@cern.ch> - IFJ PAN, Krakow, Poland          *
00015  *      Andrzej Zemla  <azemla@cern.ch>        - IFJ PAN, Krakow, Poland          *
00016  *      (IFJ PAN: Henryk Niewodniczanski Inst. Nucl. Physics, Krakow, Poland)     *   
00017  *                                                                                *
00018  * Copyright (c) 2005:                                                            *
00019  *      CERN, Switzerland                                                         * 
00020  *      MPI-K Heidelberg, Germany                                                 * 
00021  *      PAN, Krakow, Poland                                                       *
00022  *                                                                                *
00023  * Redistribution and use in source and binary forms, with or without             *
00024  * modification, are permitted according to the terms listed in LICENSE           *
00025  * (http://tmva.sourceforge.net/LICENSE)                                          *
00026  **********************************************************************************/
00027 
00028 #include "TMath.h"
00029 #include "TRandom3.h"
00030 
00031 #ifndef ROOT_TMVA_MsgLogger
00032 #include "TMVA/MsgLogger.h"
00033 #endif
00034 #include "TMVA/SVWorkingSet.h"
00035 #include "TMVA/SVKernelFunction.h"
00036 #include "TMVA/SVEvent.h"
00037 #include "TMVA/SVKernelMatrix.h"
00038 
00039 #include <vector>
00040 #include <iostream>
00041 
00042 //_______________________________________________________________________
TMVA::SVWorkingSet::SVWorkingSet() 
   : fdoRegression(kFALSE),
     fInputData(0),
     fSupVec(0),
     fKFunction(0),
     fKMatrix(0),
     fTEventUp(0),
     fTEventLow(0),
     fB_low(1.),
     fB_up(-1.),
     fTolerance(0.01),
     fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
{
   // Default constructor: an empty working set in classification mode.
   // No input events, kernel function or kernel matrix are attached (all pointers
   // are 0) and the tolerance defaults to 0.01. Only the logger is allocated here;
   // it is released by the destructor.
}  
00058 
00059 //_______________________________________________________________________
00060 TMVA::SVWorkingSet::SVWorkingSet(std::vector<TMVA::SVEvent*>*inputVectors, SVKernelFunction* kernelFunction,Float_t tol, Bool_t doreg)
00061    : fdoRegression(doreg),
00062      fInputData(inputVectors),
00063      fSupVec(0),
00064      fKFunction(kernelFunction),
00065      fTEventUp(0),
00066      fTEventLow(0),
00067      fB_low(1.),
00068      fB_up(-1.),
00069      fTolerance(tol),      
00070      fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
00071 {
00072    // constructor
00073    fKMatrix = new TMVA::SVKernelMatrix(inputVectors, kernelFunction);
00074    Float_t *pt;
00075    for( UInt_t i = 0; i < fInputData->size(); i++){ 
00076       pt = fKMatrix->GetLine(i);
00077       fInputData->at(i)->SetLine(pt);
00078       fInputData->at(i)->SetNs(i);
00079       if(fdoRegression) fInputData->at(i)->SetErrorCache(fInputData->at(i)->GetTarget());
00080    }
00081    TRandom3 rand;
00082    UInt_t kk = rand.Integer(fInputData->size());
00083    if(fdoRegression) {
00084       fTEventLow = fTEventUp =fInputData->at(0);
00085       fB_low = fTEventUp ->GetTarget() - fTolerance;
00086       fB_up  = fTEventLow->GetTarget() + fTolerance;
00087    }
00088    else{
00089       while(1){
00090          if(fInputData->at(kk)->GetTypeFlag()==-1){ 
00091             fTEventLow = fInputData->at(kk);
00092             break;
00093          }
00094          kk = rand.Integer(fInputData->size());
00095       }
00096    
00097       while (1){
00098          if (fInputData->at(kk)->GetTypeFlag()==1) {
00099             fTEventUp = fInputData->at(kk);
00100             break;
00101          }
00102          kk = rand.Integer(fInputData->size());
00103       }
00104    }
00105    fTEventUp ->SetErrorCache(fTEventUp->GetTarget());
00106    fTEventLow->SetErrorCache(fTEventUp->GetTarget());
00107 }
00108 
00109 //_______________________________________________________________________
00110 TMVA::SVWorkingSet::~SVWorkingSet() 
00111 {
00112    // destructor
00113    if (fKMatrix   != 0) {delete fKMatrix; fKMatrix = 0;}
00114    delete fLogger;
00115 }
00116 
00117 //_______________________________________________________________________
00118 Bool_t TMVA::SVWorkingSet::ExamineExample( TMVA::SVEvent* jevt ) 
00119 {   
00120    SVEvent* ievt=0;
00121    Float_t fErrorC_J = 0.;
00122    if( jevt->GetIdx()==0) fErrorC_J = jevt->GetErrorCache();
00123    else{
00124       Float_t *fKVals = jevt->GetLine();
00125       fErrorC_J = 0.;
00126       std::vector<TMVA::SVEvent*>::iterator fIDIter;
00127       
00128       UInt_t k=0;
00129       for(fIDIter = fInputData->begin(); fIDIter != fInputData->end(); fIDIter++){
00130          if((*fIDIter)->GetAlpha()>0)
00131             fErrorC_J += (*fIDIter)->GetAlpha()*(*fIDIter)->GetTypeFlag()*fKVals[k];
00132          k++;
00133       }
00134       
00135      
00136       fErrorC_J -= jevt->GetTypeFlag();
00137       jevt->SetErrorCache(fErrorC_J);
00138       
00139       if((jevt->GetIdx() == 1) && (fErrorC_J < fB_up )){
00140          fB_up = fErrorC_J;
00141          fTEventUp = jevt;
00142       }
00143       else if ((jevt->GetIdx() == -1)&&(fErrorC_J > fB_low)) {
00144          fB_low = fErrorC_J;
00145          fTEventLow = jevt;
00146       }
00147    }
00148    Bool_t converged = kTRUE;
00149    
00150    if((jevt->GetIdx()>=0) && (fB_low - fErrorC_J > 2*fTolerance)) {
00151       converged = kFALSE;
00152       ievt = fTEventLow; 
00153    }
00154    
00155    if((jevt->GetIdx()<=0) && (fErrorC_J - fB_up > 2*fTolerance)) {
00156       converged = kFALSE;
00157       ievt = fTEventUp; 
00158    }
00159       
00160    if (converged) return kFALSE;
00161       
00162    if(jevt->GetIdx()==0){
00163       if(fB_low - fErrorC_J > fErrorC_J - fB_up) ievt = fTEventLow;
00164       else                                       ievt = fTEventUp;
00165    }
00166       
00167    if (TakeStep(ievt, jevt)) return kTRUE;
00168    else                      return kFALSE;
00169 }
00170 
00171 
00172 //_______________________________________________________________________
00173 Bool_t TMVA::SVWorkingSet::TakeStep(TMVA::SVEvent* ievt,TMVA::SVEvent* jevt ) 
00174 {
00175    if (ievt == jevt) return kFALSE;
00176    std::vector<TMVA::SVEvent*>::iterator fIDIter;
00177    const Float_t epsilon = 1e-8; //make it 1-e6 or 1-e5 to make it faster
00178    
00179    Float_t type_I,  type_J;
00180    Float_t errorC_I,  errorC_J;
00181    Float_t alpha_I, alpha_J;
00182    
00183    Float_t newAlpha_I, newAlpha_J;
00184    Int_t   s;  
00185    
00186    Float_t l, h, lobj = 0, hobj = 0;
00187    Float_t eta;
00188 
00189    type_I   = ievt->GetTypeFlag();
00190    alpha_I  = ievt->GetAlpha();
00191    errorC_I = ievt->GetErrorCache();
00192 
00193    type_J   = jevt->GetTypeFlag();
00194    alpha_J  = jevt->GetAlpha();
00195    errorC_J = jevt->GetErrorCache();
00196     
00197    s = Int_t( type_I * type_J );
00198 
00199    Float_t c_i = ievt->GetCweight();
00200    
00201    Float_t c_j =  jevt->GetCweight(); 
00202    
00203    // compute l, h objective function
00204 
00205    if (type_I == type_J) {
00206       Float_t gamma = alpha_I + alpha_J;
00207       
00208       if ( c_i > c_j ) {
00209          if ( gamma < c_j ) {
00210             l = 0;
00211             h = gamma;
00212          }
00213          else{
00214             h = c_j;
00215             if ( gamma < c_i )
00216                l = 0;
00217             else
00218                l = gamma - c_i;
00219          }
00220       }           
00221       else {
00222          if ( gamma < c_i ){
00223             l = 0;
00224             h = gamma;
00225          }
00226          else {
00227             l = gamma - c_i;
00228             if ( gamma < c_j )
00229                h = gamma;
00230             else
00231                h = c_j;
00232          }
00233       }
00234    }
00235    else {
00236       Float_t gamma = alpha_I - alpha_J;
00237       if (gamma > 0) {
00238          l = 0;
00239          if ( gamma >= (c_i - c_j) ) 
00240             h = c_i - gamma;
00241          else
00242             h = c_j;
00243       }
00244       else {
00245          l = -gamma;
00246          if ( (c_i - c_j) >= gamma)
00247             h = c_j;
00248          else 
00249             h = c_i - gamma;
00250       }
00251    }
00252   
00253    if (l == h)  return kFALSE;
00254    Float_t kernel_II, kernel_IJ, kernel_JJ;
00255 
00256    kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
00257    kernel_IJ = fKMatrix->GetElement(ievt->GetNs(), jevt->GetNs());
00258    kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
00259    
00260    eta = 2*kernel_IJ - kernel_II - kernel_JJ; 
00261    if (eta < 0) {
00262       newAlpha_J = alpha_J + (type_J*( errorC_J - errorC_I ))/eta;
00263       if      (newAlpha_J < l) newAlpha_J = l;
00264       else if (newAlpha_J > h) newAlpha_J = h;
00265       
00266    }
00267 
00268    else {
00269 
00270       Float_t c_I = eta/2;
00271       Float_t c_J = type_J*( errorC_I - errorC_J ) - eta * alpha_J;
00272       lobj = c_I * l * l + c_J * l;
00273       hobj = c_I * h * h + c_J * h;
00274 
00275       if      (lobj > hobj + epsilon)  newAlpha_J = l;
00276       else if (lobj < hobj - epsilon)  newAlpha_J = h; 
00277       else                              newAlpha_J = alpha_J;
00278    }
00279 
00280    if (TMath::Abs( newAlpha_J - alpha_J ) < ( epsilon * ( newAlpha_J + alpha_J+ epsilon ))){
00281       return kFALSE;
00282       //it spends here to much time... it is stupido
00283    }
00284    newAlpha_I = alpha_I - s*( newAlpha_J - alpha_J );
00285 
00286    if (newAlpha_I < 0) {
00287       newAlpha_J += s* newAlpha_I;
00288       newAlpha_I = 0;
00289    }
00290    else if (newAlpha_I > c_i) {
00291       Float_t temp = newAlpha_I - c_i;
00292       newAlpha_J += s * temp;
00293       newAlpha_I = c_i;
00294    }
00295   
00296    Float_t dL_I = type_I * ( newAlpha_I - alpha_I );
00297    Float_t dL_J = type_J * ( newAlpha_J - alpha_J );  
00298 
00299    Int_t k = 0; 
00300    for(fIDIter = fInputData->begin(); fIDIter != fInputData->end(); fIDIter++){
00301       k++;
00302       if((*fIDIter)->GetIdx()==0){
00303          Float_t ii = fKMatrix->GetElement(ievt->GetNs(), (*fIDIter)->GetNs());
00304          Float_t jj = fKMatrix->GetElement(jevt->GetNs(), (*fIDIter)->GetNs());
00305          
00306          (*fIDIter)->UpdateErrorCache(dL_I * ii + dL_J * jj);       
00307       }
00308    }
00309    ievt->SetAlpha(newAlpha_I);
00310    jevt->SetAlpha(newAlpha_J);
00311    // set new indexes
00312    SetIndex(ievt);
00313    SetIndex(jevt);
00314 
00315    // update error cache
00316    ievt->SetErrorCache(errorC_I + dL_I*kernel_II + dL_J*kernel_IJ);
00317    jevt->SetErrorCache(errorC_J + dL_I*kernel_IJ + dL_J*kernel_JJ);
00318 
00319    // compute fI_low, fB_low
00320 
00321    fB_low = -1*1e30;
00322    fB_up = 1e30;
00323    
00324    for(fIDIter = fInputData->begin(); fIDIter != fInputData->end(); fIDIter++){
00325       if((*fIDIter)->GetIdx()==0){
00326          if((*fIDIter)->GetErrorCache()> fB_low){
00327             fB_low = (*fIDIter)->GetErrorCache();
00328             fTEventLow = (*fIDIter);
00329          }
00330          if( (*fIDIter)->GetErrorCache()< fB_up){
00331             fB_up =(*fIDIter)->GetErrorCache();
00332             fTEventUp = (*fIDIter);
00333          }
00334       }
00335    }                             
00336 
00337    // for optimized alfa's
00338    if (fB_low < TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
00339       if (ievt->GetErrorCache() > fB_low) {
00340          fB_low = ievt->GetErrorCache();
00341          fTEventLow = ievt;
00342       }
00343       else {
00344          fB_low = jevt->GetErrorCache();
00345          fTEventLow = jevt;
00346       }
00347    }
00348   
00349    if (fB_up > TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
00350       if (ievt->GetErrorCache()< fB_low) {
00351          fB_up =ievt->GetErrorCache();
00352          fTEventUp = ievt;
00353       }
00354       else {
00355          fB_up =jevt->GetErrorCache() ;
00356          fTEventUp = jevt;
00357       }
00358    }  
00359    return kTRUE;
00360 }
00361 
00362 //_______________________________________________________________________
00363 Bool_t  TMVA::SVWorkingSet::Terminated() 
00364 {
00365    if((fB_up > fB_low - 2*fTolerance)) return kTRUE;
00366    return kFALSE;
00367 }
00368 
00369 //_______________________________________________________________________
00370 void TMVA::SVWorkingSet::Train(UInt_t nMaxIter) 
00371 {
00372    // train the SVM
00373    
00374    
00375    Int_t numChanged  = 0;
00376    Int_t examineAll  = 1;
00377 
00378    Float_t numChangedOld = 0;
00379    Int_t deltaChanges = 0;
00380    UInt_t numit    = 0;
00381    
00382    std::vector<TMVA::SVEvent*>::iterator fIDIter;
00383 
00384    while ((numChanged > 0) || (examineAll > 0)) {
00385       numChanged = 0;
00386       if (examineAll) {
00387          for (fIDIter = fInputData->begin(); fIDIter!=fInputData->end(); fIDIter++){
00388             if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*fIDIter);
00389             else numChanged += (UInt_t)ExamineExampleReg(*fIDIter);
00390          }    
00391       }
00392       else {
00393          for (fIDIter = fInputData->begin(); fIDIter!=fInputData->end(); fIDIter++) {
00394             if ((*fIDIter)->IsInI0()) {
00395                if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*fIDIter);
00396                else numChanged += (UInt_t)ExamineExampleReg(*fIDIter);
00397                if (Terminated()) {
00398                   numChanged = 0;
00399                   break;
00400                }
00401             }
00402          }
00403       }
00404 
00405       if      (examineAll == 1) examineAll = 0;
00406       else if (numChanged == 0 || numChanged < 10 || deltaChanges > 3 ) examineAll = 1;
00407 
00408       if (numChanged == numChangedOld) deltaChanges++;
00409       else                             deltaChanges = 0;
00410       numChangedOld = numChanged;
00411       ++numit;
00412 
00413       if (numit >= nMaxIter) {
00414          *fLogger << kWARNING 
00415                   << "Max number of iterations exceeded. "
00416                   << "Training may not be completed. Try use less Cost parameter" << Endl;
00417          break;
00418       }
00419    }
00420 }
00421 
00422 //_______________________________________________________________________
00423 void TMVA::SVWorkingSet::SetIndex( TMVA::SVEvent* event ) 
00424 {
00425    if( (0< event->GetAlpha()) && (event->GetAlpha()< event->GetCweight()))
00426       event->SetIdx(0);
00427 
00428    if( event->GetTypeFlag() == 1){
00429       if( event->GetAlpha() == 0)
00430          event->SetIdx(1);
00431       else if( event->GetAlpha() == event->GetCweight() )
00432          event->SetIdx(-1);
00433    }
00434    if( event->GetTypeFlag() == -1){
00435       if( event->GetAlpha() == 0)
00436          event->SetIdx(-1);
00437       else if( event->GetAlpha() == event->GetCweight() )
00438          event->SetIdx(1);
00439    }
00440 }
00441 
00442 //_______________________________________________________________________
00443 void TMVA::SVWorkingSet::PrintStat() 
00444 {
00445    std::vector<TMVA::SVEvent*>::iterator fIDIter;
00446    UInt_t counter = 0;
00447    for( fIDIter = fInputData->begin(); fIDIter != fInputData->end(); fIDIter++)
00448       if((*fIDIter)->GetAlpha() !=0) counter++;
00449 }
00450 
00451 //_______________________________________________________________________
00452 std::vector<TMVA::SVEvent*>* TMVA::SVWorkingSet::GetSupportVectors() 
00453 {
00454    std::vector<TMVA::SVEvent*>::iterator fIDIter;
00455    if( fSupVec != 0) {delete fSupVec; fSupVec = 0; }
00456    fSupVec = new std::vector<TMVA::SVEvent*>(0);
00457    
00458    for( fIDIter = fInputData->begin(); fIDIter != fInputData->end(); fIDIter++){
00459       if((*fIDIter)->GetDeltaAlpha() !=0){
00460          fSupVec->push_back((*fIDIter));
00461       }
00462    }
00463    return fSupVec;
00464 }
00465 
00466 //for regression
00467 
Bool_t TMVA::SVWorkingSet::TakeStepReg(TMVA::SVEvent* ievt,TMVA::SVEvent* jevt )
{
   // One SMO optimisation step for regression on the pair (ievt, jevt).
   //
   // Each event carries two multipliers, GetAlpha() and GetAlpha_p(). Working on
   // local copies (b_alpha_*), the loop below tries up to four cases, each at most
   // once:
   //   A: (alpha_i,   alpha_j)      B: (alpha_i,   alpha_j_p)
   //   C: (alpha_i_p, alpha_j)      D: (alpha_i_p, alpha_j_p)
   // Each case clips the new value of the j multiplier to [low, high] and adjusts
   // the i multiplier. It stops when a case has an empty segment or no case applies.
   // If any of the four multipliers changed significantly (IsDiffSignificant), the
   // error caches of events with GetIdx()==0 are updated, the new multipliers are
   // stored, fB_low/fB_up are recomputed and kTRUE is returned. Otherwise kFALSE.
   if (ievt == jevt) return kFALSE;
   std::vector<TMVA::SVEvent*>::iterator fIDIter;
   const Float_t epsilon = 0.001*fTolerance;//TODO

   const Float_t kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
   const Float_t kernel_IJ = fKMatrix->GetElement(ievt->GetNs(),jevt->GetNs());
   const Float_t kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
   
   //compute eta & gamma
   const Float_t eta = -2*kernel_IJ + kernel_II + kernel_JJ; 
   const Float_t gamma = ievt->GetDeltaAlpha() + jevt->GetDeltaAlpha(); 
   
   //TODO CHECK WHAT IF ETA <0 
   //w.r.t Mercer's conditions it should never happen, but what if?
   // NOTE(review): eta == 0 (e.g. identical kernel rows) makes every deltafi/eta
   // below a division by zero -- confirm this cannot occur.
   
   Bool_t caseA, caseB, caseC, caseD, terminated;
   caseA = caseB = caseC = caseD = terminated = kFALSE;
   Float_t b_alpha_i, b_alpha_j, b_alpha_i_p, b_alpha_j_p; //temporary lagrange multipliers
   const Float_t b_cost_i = ievt->GetCweight();
   const Float_t b_cost_j = jevt->GetCweight();

   b_alpha_i   = ievt->GetAlpha();
   b_alpha_j   = jevt->GetAlpha();
   b_alpha_i_p = ievt->GetAlpha_p();
   b_alpha_j_p = jevt->GetAlpha_p();

   //calculate deltafi (difference of the two error caches)
   Float_t deltafi = ievt->GetErrorCache()-jevt->GetErrorCache();
   // NOTE(review): deltafi is never updated inside the loop, so every case below
   // uses the value computed before any multiplier moved -- confirm intended.
   
   // main loop
   while(!terminated) {
      const Float_t null = 0.; // Float_t zero so TMath::Max/Min get two Float_t arguments (a literal 0. is a Double_t)
      Float_t low, high;
      Float_t tmp_alpha_i, tmp_alpha_j;
      tmp_alpha_i = tmp_alpha_j = 0.;
      
      //TODO check this conditions, are they proper
      // case A: optimise alpha_i and alpha_j
      if((caseA == kFALSE) && (b_alpha_i > 0 || (b_alpha_i_p == 0 && deltafi > 0)) && (b_alpha_j > 0 || (b_alpha_j_p == 0 && deltafi < 0)))
         {
            //compute low, high w.r.t a_i, a_j
            low  = TMath::Max( null, gamma - b_cost_j );
            high = TMath::Min( b_cost_i , gamma);
         
            if(low<high){
               tmp_alpha_j = b_alpha_j - (deltafi/eta);
               tmp_alpha_j = TMath::Min(tmp_alpha_j,high      );
               tmp_alpha_j = TMath::Max(low        ,tmp_alpha_j);
               tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j);
            
               //update Li & Lj if change is significant (??)
               if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) ||  IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
                  b_alpha_j = tmp_alpha_j;
                  b_alpha_i = tmp_alpha_i;
               }
            
            }
            else
               terminated = kTRUE;
         
            caseA = kTRUE;
         }
      // case B: optimise alpha_i and alpha_j_p
      else if((caseB==kFALSE) && (b_alpha_i>0 || (b_alpha_i_p==0 && deltafi >2*epsilon )) && (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi>2*epsilon)))
         {
            //compute LH w.r.t. a_i, a_j*
            low  = TMath::Max( null, gamma );  //TODO 
            high = TMath::Min( b_cost_i , b_cost_j + gamma);

         
            if(low<high){
               tmp_alpha_j = b_alpha_j_p - ((deltafi-2*epsilon)/eta);
               tmp_alpha_j = TMath::Min(tmp_alpha_j,high);
               tmp_alpha_j = TMath::Max(low,tmp_alpha_j);
               tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j_p);
            
               //update alphai alphaj_p
               if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) ||  IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
                  b_alpha_j_p = tmp_alpha_j;
                  b_alpha_i   = tmp_alpha_i;
               }
            }
            else
               terminated = kTRUE;
         
            caseB = kTRUE;
         }
      // case C: optimise alpha_i_p and alpha_j
      else if((caseC==kFALSE) && (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi < -2*epsilon )) && (b_alpha_j>0 || (b_alpha_j_p==0 && deltafi< -2*epsilon)))
         {
            //compute LH w.r.t. alphai_p alphaj
            low  = TMath::Max(null, -gamma  );
            high = TMath::Min(b_cost_i, -gamma+b_cost_j);
         
            if(low<high){
               tmp_alpha_j = b_alpha_j - ((deltafi+2*epsilon)/eta);
               tmp_alpha_j = TMath::Min(tmp_alpha_j,high      );
               tmp_alpha_j = TMath::Max(low        ,tmp_alpha_j);
               tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j);
            
               //update alphai_p alphaj
               if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) ||  IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
                  b_alpha_j     = tmp_alpha_j;
                  b_alpha_i_p   = tmp_alpha_i;
               } 
            }
            else
               terminated = kTRUE;
         
            caseC = kTRUE;
         }
      // case D: optimise alpha_i_p and alpha_j_p
      else if((caseD == kFALSE) && 
              (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi <0 )) && 
              (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi >0 )))
         {
            //compute LH w.r.t. alphai_p alphaj_p
            low  = TMath::Max(null,-gamma - b_cost_j);
            high = TMath::Min(b_cost_i, -gamma);
         
            if(low<high){
               tmp_alpha_j = b_alpha_j_p + (deltafi/eta);
               tmp_alpha_j = TMath::Min(tmp_alpha_j,high      );
               tmp_alpha_j = TMath::Max(low        ,tmp_alpha_j);
               tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j_p);
            
               if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) ||  IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
                  b_alpha_j_p   = tmp_alpha_j;
                  b_alpha_i_p   = tmp_alpha_i;
               } 
            }
            else
               terminated = kTRUE;
         
            caseD = kTRUE;
         }
      else
         terminated = kTRUE;
   }
   // TODO ad commment how it was calculated
   // NOTE(review): the updated deltafi is not read anywhere after this line, so the
   // statement has no effect -- it looks like it belongs inside the loop; confirm.
   deltafi += ievt->GetDeltaAlpha()*(kernel_II - kernel_IJ) + jevt->GetDeltaAlpha()*(kernel_IJ - kernel_JJ); 

   // commit only if at least one multiplier moved significantly
   if( IsDiffSignificant(b_alpha_i, ievt->GetAlpha(), epsilon) ||
       IsDiffSignificant(b_alpha_j, jevt->GetAlpha(), epsilon) ||
       IsDiffSignificant(b_alpha_i_p, ievt->GetAlpha_p(), epsilon) ||
       IsDiffSignificant(b_alpha_j_p, jevt->GetAlpha_p(), epsilon) ){
         
      //TODO check if these conditions might be easier
      //TODO write documentation for this
      // NOTE(review): presumably meant to be the change of each event's
      // GetDeltaAlpha(); as written the new b_alpha_i / b_alpha_j do not enter.
      // Depends on how SVEvent::GetDeltaAlpha() is defined -- confirm.
      const Float_t diff_alpha_i = ievt->GetDeltaAlpha()+b_alpha_i_p - ievt->GetAlpha();
      const Float_t diff_alpha_j = jevt->GetDeltaAlpha()+b_alpha_j_p - jevt->GetAlpha();

      //update error cache
      // NOTE(review): selects events by GetIdx()==0, while ExamineExampleReg uses
      // IsInI0(); GetIdx() is not updated for regression in this file -- confirm.
      Int_t k = 0; // unused counter
      for(fIDIter = fInputData->begin(); fIDIter != fInputData->end(); fIDIter++){
         k++;
         //there will be some changes in Idx notation
         if((*fIDIter)->GetIdx()==0){
            Float_t k_ii = fKMatrix->GetElement(ievt->GetNs(), (*fIDIter)->GetNs());
            Float_t k_jj = fKMatrix->GetElement(jevt->GetNs(), (*fIDIter)->GetNs());
         
            (*fIDIter)->UpdateErrorCache(diff_alpha_i * k_ii + diff_alpha_j * k_jj);
         }
      }
         
      //store new alphas in SVevents
      ievt->SetAlpha(b_alpha_i);
      jevt->SetAlpha(b_alpha_j);
      ievt->SetAlpha_p(b_alpha_i_p);
      jevt->SetAlpha_p(b_alpha_j_p);
         
      //TODO update Idexes
         
      // recompute the thresholds: fB_low is the largest error cache among events
      // not in I3, fB_up the smallest among events not in I2
      fB_low = -1*1e30;
      fB_up =1e30;
   
      for(fIDIter = fInputData->begin(); fIDIter != fInputData->end(); fIDIter++){
         if((!(*fIDIter)->IsInI3()) && ((*fIDIter)->GetErrorCache()> fB_low)){
            fB_low = (*fIDIter)->GetErrorCache();
            fTEventLow = (*fIDIter);
                  
         }
         if((!(*fIDIter)->IsInI2()) && ((*fIDIter)->GetErrorCache()< fB_up)){
            fB_up =(*fIDIter)->GetErrorCache();
            fTEventUp = (*fIDIter);
         }
      }
      return kTRUE;
   } else return kFALSE;
}
00658 
00659 
00660 //_______________________________________________________________________
Bool_t TMVA::SVWorkingSet::ExamineExampleReg(TMVA::SVEvent* jevt)
{
   // Examine one regression event.
   //
   // For events outside I0 the error is recomputed as
   //    GetTarget() - sum_k GetDeltaAlpha()_k * K(k, j)
   // and stored in the cache. Depending on whether jevt is in I1, I2 or I3, it may
   // then tighten (fB_up, fTEventUp) or (fB_low, fTEventLow), shifted by +-feps.
   // The five cases below test the optimality conditions against 2*fTolerance and
   // choose a partner ievt. On a violation TakeStepReg(ievt, jevt) is attempted.
   // Returns kTRUE only if TakeStepReg changed the multipliers.
   Float_t feps = 1e-7;// TODO check which value is the best
   // NOTE(review): feps shifts the error in the checks below; in SMO regression this
   // role is usually played by the epsilon-insensitive tube width -- confirm that a
   // fixed 1e-7 is intended.
   SVEvent* ievt=0;
   Float_t fErrorC_J = 0.;
   if( jevt->IsInI0()) {
      fErrorC_J = jevt->GetErrorCache();
   }
   else{
      // recompute the error of a bound event from its kernel row
      Float_t *fKVals = jevt->GetLine();
      fErrorC_J = 0.;
      std::vector<TMVA::SVEvent*>::iterator fIDIter;
      
      UInt_t k=0;
      for(fIDIter = fInputData->begin(); fIDIter != fInputData->end(); fIDIter++){
         fErrorC_J -= (*fIDIter)->GetDeltaAlpha()*fKVals[k];
         k++;
      }
      
      fErrorC_J += jevt->GetTarget();
      jevt->SetErrorCache(fErrorC_J);
      
      // the freshly evaluated event may improve one of the thresholds
      if(jevt->IsInI1()){
         if(fErrorC_J + feps < fB_up ){
            fB_up = fErrorC_J + feps;
            fTEventUp = jevt;
         }
         else if(fErrorC_J -feps > fB_low) {
            fB_low = fErrorC_J - feps;
            fTEventLow = jevt;
         }
      }else if((jevt->IsInI2()) && (fErrorC_J + feps > fB_low)){
         fB_low = fErrorC_J + feps;
         fTEventLow = jevt;
      }else if((jevt->IsInI3()) && (fErrorC_J - feps < fB_up)){
         fB_up = fErrorC_J - feps;
         fTEventUp = jevt;
      }
   }
   
   // Optimality checks. The cases are tested in sequence; if more than one applies,
   // the last one that finds a violation determines ievt.
   Bool_t converged = kTRUE;
   //case 1: jevt in I0a
   if(jevt->IsInI0a()){
      if( fB_low -fErrorC_J + feps > 2*fTolerance){
         converged = kFALSE;
         ievt = fTEventLow;
         if(fErrorC_J-feps-fB_up > fB_low-fErrorC_J+feps){
            ievt = fTEventUp;
         }
      }else if(fErrorC_J -feps - fB_up > 2*fTolerance){
         converged = kFALSE;
         ievt = fTEventUp;
         if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
            ievt = fTEventLow;
         }
      }
   }
   
   //case 2: jevt in I0b
   if(jevt->IsInI0b()){
      if( fB_low -fErrorC_J - feps > 2*fTolerance){
         converged = kFALSE;
         ievt = fTEventLow;
         if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
            ievt = fTEventUp;
         }
      }else if(fErrorC_J + feps - fB_up > 2*fTolerance){
         converged = kFALSE;
         ievt = fTEventUp;
         if(fB_low - fErrorC_J-feps > fErrorC_J+feps -fB_up){
            ievt = fTEventLow;
         }
      }
   }
   
   //case 3: jevt in I1
   if(jevt->IsInI1()){
      if( fB_low -fErrorC_J - feps > 2*fTolerance){
         converged = kFALSE;
         ievt = fTEventLow;
         if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
            ievt = fTEventUp;
         }
      }else if(fErrorC_J - feps - fB_up > 2*fTolerance){
         converged = kFALSE;
         ievt = fTEventUp;
         if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
            ievt = fTEventLow;
         }
      }
   }
   
   //case 4: jevt in I2 (only the upper threshold is checked)
   if(jevt->IsInI2()){
      if( fErrorC_J + feps -fB_up > 2*fTolerance){
         converged = kFALSE;
         ievt = fTEventUp;
      }
   }
   
   //case 5: jevt in I3 (only the lower threshold is checked)
   if(jevt->IsInI3()){
      if(fB_low -fErrorC_J +feps > 2*fTolerance){
         converged = kFALSE;
         ievt = fTEventLow;
      }
   }
   
   if(converged) return kFALSE;
   if (TakeStepReg(ievt, jevt)) return kTRUE;
   else return kFALSE;
}
00773 
00774 Bool_t TMVA::SVWorkingSet::IsDiffSignificant(Float_t a_i, Float_t a_j, Float_t eps)
00775 {   
00776    if( TMath::Abs(a_i - a_j) > eps*(a_i + a_j + eps)) return kTRUE;
00777    else return kFALSE;
00778 }
00779 

Generated on Tue Jul 5 15:25:36 2011 for ROOT_528-00b_version by  doxygen 1.5.1