Merged regiongraph.* and daialg.* from SVN head,
[libdai.git] / src / daialg.cpp
index a688f36..0fabbfb 100644 (file)
@@ -29,8 +29,6 @@ namespace dai {
 using namespace std;
 
 
-/// Calculate the marginal of obj on ns by clamping 
-/// all variables in ns and calculating logZ for each joined state
 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
     Factor Pns (ns);
     
@@ -38,42 +36,40 @@ Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
     if( !reInit )
         clamped->init();
 
-    Complex logZ0;
+    Real logZ0 = 0.0;
     for( State s(ns); s.valid(); s++ ) {
         // save unclamped factors connected to ns
-        clamped->saveProbs( ns );
+        clamped->backupFactors( ns );
 
         // set clamping Factors to delta functions
         for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
             clamped->clamp( *n, s(*n) );
         
         // run DAIAlg, calc logZ, store in Pns
-        if( clamped->Verbose() >= 2 )
-            cout << s << ": ";
         if( reInit )
             clamped->init();
+        else
+            clamped->init(ns);
         clamped->run();
 
-        Complex Z;
+        Real Z;
         if( s == 0 ) {
             logZ0 = clamped->logZ();
             Z = 1.0;
         } else {
             // subtract logZ0 to avoid very large numbers
             Z = exp(clamped->logZ() - logZ0);
-            if( fabs(imag(Z)) > 1e-5 )
-                cout << "Marginal:: WARNING: complex Z (" << Z << ")" << endl;
         }
 
-        Pns[s] = real(Z);
+        Pns[s] = Z;
         
         // restore clamped factors
-        clamped->undoProbs( ns );
+        clamped->restoreFactors( ns );
     }
 
     delete clamped;
 
-    return( Pns.normalized(Prob::NORMPROB) );
+    return( Pns.normalized() );
 }
 
 
@@ -90,27 +86,23 @@ vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reIni
     for( size_t j = 0; j < N; j++ )
         for( size_t k = 0; k < N; k++ )
             if( j == k )
-                pairbeliefs.push_back(Factor());
+                pairbeliefs.push_back( Factor() );
             else
-                pairbeliefs.push_back(Factor(vns[j] | vns[k]));
+                pairbeliefs.push_back( Factor( vns[j] | vns[k] ) );
 
     InfAlg *clamped = obj.clone();
     if( !reInit )
         clamped->init();
 
-    Complex logZ0;
+    Real logZ0 = 0.0;
     for( size_t j = 0; j < N; j++ ) {
         // clamp Var j to its possible values
         for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
-            if( obj.Verbose() >= 2 )
-                cout << j << "/" << N-1 << " (" << j_val << "/" << vns[j].states() << "): ";
-
-            // save unclamped factors connected to ns
-            clamped->saveProbs( ns );
-            
-            clamped->clamp( vns[j], j_val );
+            clamped->clamp( vns[j], j_val, true );
             if( reInit )
                 clamped->init();
+            else
+                clamped->init(ns);
             clamped->run();
 
             //if( j == 0 )
@@ -120,10 +112,7 @@ vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reIni
                 logZ0 = clamped->logZ();
             } else {
                 // subtract logZ0 to avoid very large numbers
-                Complex Z = exp(clamped->logZ() - logZ0);
-                if( fabs(imag(Z)) > 1e-5 )
-                    cout << "calcPairBelief::  Warning: complex Z: " << Z << endl;
-                Z_xj = real(Z);
+                Z_xj = exp(clamped->logZ() - logZ0);
             }
 
             for( size_t k = 0; k < N; k++ ) 
@@ -137,7 +126,7 @@ vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reIni
                 }
 
             // restore clamped factors
-            clamped->undoProbs( ns );
+            clamped->restoreFactors( ns );
         }
     }
     
@@ -165,7 +154,7 @@ Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) {
     for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
         Pns *= pairbeliefs[ij];
     
-    return( Pns.normalized(Prob::NORMPROB) );
+    return( Pns.normalized() );
 }
 
 
@@ -177,7 +166,7 @@ vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool re
     if( !reInit )
         clamped->init();
 
-    Complex logZ0;
+    Real logZ0 = 0.0;
     VarSet::const_iterator nj = ns.begin();
     for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) {
         size_t k = 0;
@@ -188,12 +177,14 @@ vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool re
             for( size_t j_val = 0; j_val < nj->states(); j_val++ ) 
                 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
                     // save unclamped factors connected to ns
-                    clamped->saveProbs( ns );
+                    clamped->backupFactors( ns );
             
                     clamped->clamp( *nj, j_val );
                     clamped->clamp( *nk, k_val );
                     if( reInit )
                         clamped->init();
+                    else
+                        clamped->init(ns);
                     clamped->run();
 
                     double Z_xj = 1.0;
@@ -201,10 +192,7 @@ vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool re
                         logZ0 = clamped->logZ();
                     } else {
                         // subtract logZ0 to avoid very large numbers
-                        Complex Z = exp(clamped->logZ() - logZ0);
-                        if( fabs(imag(Z)) > 1e-5 )
-                            cout << "calcPairBelief::  Warning: complex Z: " << Z << endl;
-                        Z_xj = real(Z);
+                        Z_xj = exp(clamped->logZ() - logZ0);
                     }
 
                     // we assume that j.label() < k.label()
@@ -212,7 +200,7 @@ vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool re
                     pairbelief[j_val + (k_val * nj->states())] = Z_xj;
 
                     // restore clamped factors
-                    clamped->undoProbs( ns );
+                    clamped->restoreFactors( ns );
                 }
         
             result.push_back( pairbelief );