X-Git-Url: http://git.tuebingen.mpg.de/?p=libdai.git;a=blobdiff_plain;f=src%2Fdaialg.cpp;h=0fabbfb4986551012e46daf44867227f3a61d7ab;hp=a688f36aba54fcb36d092d1c8b52a84ee83cb87c;hb=3e17c64f34ffad3ff2305fb6a05859033b1009bb;hpb=a82c5d4dc9da2f0ebbd421d67cf3c846e47c0443 diff --git a/src/daialg.cpp b/src/daialg.cpp index a688f36..0fabbfb 100644 --- a/src/daialg.cpp +++ b/src/daialg.cpp @@ -29,8 +29,6 @@ namespace dai { using namespace std; -/// Calculate the marginal of obj on ns by clamping -/// all variables in ns and calculating logZ for each joined state Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) { Factor Pns (ns); @@ -38,42 +36,40 @@ Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) { if( !reInit ) clamped->init(); - Complex logZ0; + Real logZ0 = 0.0; for( State s(ns); s.valid(); s++ ) { // save unclamped factors connected to ns - clamped->saveProbs( ns ); + clamped->backupFactors( ns ); // set clamping Factors to delta functions for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ ) clamped->clamp( *n, s(*n) ); // run DAIAlg, calc logZ, store in Pns - if( clamped->Verbose() >= 2 ) - cout << s << ": "; if( reInit ) clamped->init(); + else + clamped->init(ns); clamped->run(); - Complex Z; + Real Z; if( s == 0 ) { logZ0 = clamped->logZ(); Z = 1.0; } else { // subtract logZ0 to avoid very large numbers Z = exp(clamped->logZ() - logZ0); - if( fabs(imag(Z)) > 1e-5 ) - cout << "Marginal:: WARNING: complex Z (" << Z << ")" << endl; } - Pns[s] = real(Z); + Pns[s] = Z; // restore clamped factors - clamped->undoProbs( ns ); + clamped->restoreFactors( ns ); } delete clamped; - return( Pns.normalized(Prob::NORMPROB) ); + return( Pns.normalized() ); } @@ -90,27 +86,23 @@ vector calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reIni for( size_t j = 0; j < N; j++ ) for( size_t k = 0; k < N; k++ ) if( j == k ) - pairbeliefs.push_back(Factor()); + pairbeliefs.push_back( Factor() ); else - pairbeliefs.push_back(Factor(vns[j] | vns[k])); + pairbeliefs.push_back( Factor( vns[j] | vns[k] ) ); InfAlg *clamped = obj.clone(); if( !reInit ) clamped->init(); - Complex logZ0; + Real logZ0 = 0.0; for( size_t j = 0; j < N; j++ ) { // clamp Var j to its possible values for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) { - if( obj.Verbose() >= 2 ) - cout << j << "/" << N-1 << " (" << j_val << "/" << vns[j].states() << "): "; - - // save unclamped factors connected to ns - clamped->saveProbs( ns ); - - clamped->clamp( vns[j], j_val ); + clamped->clamp( vns[j], j_val, true ); if( reInit ) clamped->init(); + else + clamped->init(ns); clamped->run(); //if( j == 0 ) @@ -120,10 +112,7 @@ vector calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reIni logZ0 = clamped->logZ(); } else { // subtract logZ0 to avoid very large numbers - Complex Z = exp(clamped->logZ() - logZ0); - if( fabs(imag(Z)) > 1e-5 ) - cout << "calcPairBelief:: Warning: complex Z: " << Z << endl; - Z_xj = real(Z); + Z_xj = exp(clamped->logZ() - logZ0); } for( size_t k = 0; k < N; k++ ) @@ -137,7 +126,7 @@ vector calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reIni } // restore clamped factors - clamped->undoProbs( ns ); + clamped->restoreFactors( ns ); } } @@ -165,7 +154,7 @@ Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) { for( size_t ij = 0; ij < pairbeliefs.size(); ij++ ) Pns *= pairbeliefs[ij]; - return( Pns.normalized(Prob::NORMPROB) ); + return( Pns.normalized() ); } @@ -177,7 +166,7 @@ vector calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool re if( !reInit ) clamped->init(); - Complex logZ0; + Real logZ0 = 0.0; VarSet::const_iterator nj = ns.begin(); for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) { size_t k = 0; @@ -188,12 +177,14 @@ vector calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool re for( size_t j_val = 0; j_val < nj->states(); j_val++ ) for( size_t k_val = 0; k_val < nk->states(); k_val++ ) { // save unclamped factors connected to ns - clamped->saveProbs( ns ); + clamped->backupFactors( ns ); clamped->clamp( *nj, j_val ); clamped->clamp( *nk, k_val ); if( reInit ) clamped->init(); + else + clamped->init(ns); clamped->run(); double Z_xj = 1.0; @@ -201,10 +192,7 @@ vector calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool re logZ0 = clamped->logZ(); } else { // subtract logZ0 to avoid very large numbers - Complex Z = exp(clamped->logZ() - logZ0); - if( fabs(imag(Z)) > 1e-5 ) - cout << "calcPairBelief:: Warning: complex Z: " << Z << endl; - Z_xj = real(Z); + Z_xj = exp(clamped->logZ() - logZ0); } // we assume that j.label() < k.label() @@ -212,7 +200,7 @@ vector calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool re pairbelief[j_val + (k_val * nj->states())] = Z_xj; // restore clamped factors - clamped->undoProbs( ns ); + clamped->restoreFactors( ns ); } result.push_back( pairbelief );