Merged regiongraph.* and daialg.* from SVN head,
[libdai.git] / src / daialg.cpp
index b8adfb9..0fabbfb 100644 (file)
@@ -29,8 +29,6 @@ namespace dai {
 using namespace std;
 
 
-/// Calculate the marginal of obj on ns by clamping 
-/// all variables in ns and calculating logZ for each joined state
 Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
     Factor Pns (ns);
     
@@ -50,6 +48,8 @@ Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
         // run DAIAlg, calc logZ, store in Pns
         if( reInit )
             clamped->init();
+        else
+            clamped->init(ns);
         clamped->run();
 
         Real Z;
@@ -69,7 +69,7 @@ Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
 
     delete clamped;
 
-    return( Pns.normalized(Prob::NORMPROB) );
+    return( Pns.normalized() );
 }
 
 
@@ -86,9 +86,9 @@ vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reIni
     for( size_t j = 0; j < N; j++ )
         for( size_t k = 0; k < N; k++ )
             if( j == k )
-                pairbeliefs.push_back(Factor());
+                pairbeliefs.push_back( Factor() );
             else
-                pairbeliefs.push_back(Factor(vns[j] | vns[k]));
+                pairbeliefs.push_back( Factor( vns[j] | vns[k] ) );
 
     InfAlg *clamped = obj.clone();
     if( !reInit )
@@ -98,12 +98,11 @@ vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reIni
     for( size_t j = 0; j < N; j++ ) {
         // clamp Var j to its possible values
         for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
-            // save unclamped factors connected to ns
-            clamped->backupFactors( ns );
-            
-            clamped->clamp( vns[j], j_val );
+            clamped->clamp( vns[j], j_val, true );
             if( reInit )
                 clamped->init();
+            else
+                clamped->init(ns);
             clamped->run();
 
             //if( j == 0 )
@@ -155,7 +154,7 @@ Factor calcMarginal2ndO( const InfAlg & obj, const VarSet& ns, bool reInit ) {
     for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
         Pns *= pairbeliefs[ij];
     
-    return( Pns.normalized(Prob::NORMPROB) );
+    return( Pns.normalized() );
 }
 
 
@@ -184,6 +183,8 @@ vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool re
                     clamped->clamp( *nk, k_val );
                     if( reInit )
                         clamped->init();
+                    else
+                        clamped->init(ns);
                     clamped->run();
 
                     double Z_xj = 1.0;