Fixed tabs and trailing whitespaces
[libdai.git] / src / bp.cpp
index eb851cf..07c54bc 100644 (file)
@@ -49,7 +49,7 @@ void BP::setProperties( const PropertySet &opts ) {
     assert( opts.hasKey("maxiter") );
     assert( opts.hasKey("logdomain") );
     assert( opts.hasKey("updates") );
-    
+
     props.tol = opts.getStringAs<double>("tol");
     props.maxiter = opts.getStringAs<size_t>("maxiter");
     props.logdomain = opts.getStringAs<bool>("logdomain");
@@ -106,7 +106,7 @@ void BP::construct() {
         _edge2lut.reserve( nrVars() );
     for( size_t i = 0; i < nrVars(); ++i ) {
         _edges.push_back( vector<EdgeProp>() );
-        _edges[i].reserve( nbV(i).size() ); 
+        _edges[i].reserve( nbV(i).size() );
         if( props.updates == Properties::UpdateType::SEQMAX ) {
             _edge2lut.push_back( vector<LutType::iterator>() );
             _edge2lut[i].reserve( nbV(i).size() );
@@ -125,7 +125,7 @@ void BP::construct() {
             newEP.residual = 0.0;
             _edges[i].push_back( newEP );
             if( props.updates == Properties::UpdateType::SEQMAX )
-                _edge2lut[i].push_back( _lut.insert( std::make_pair( newEP.residual, std::make_pair( i, _edges[i].size() - 1 ))) );
+                _edge2lut[i].push_back( _lut.insert( make_pair( newEP.residual, make_pair( i, _edges[i].size() - 1 ))) );
         }
     }
 }
@@ -138,25 +138,13 @@ void BP::init() {
             message( i, I.iter ).fill( c );
             newMessage( i, I.iter ).fill( c );
             if( props.updates == Properties::UpdateType::SEQMAX )
-                               updateResidual( i, I.iter, 0.0 );
+                updateResidual( i, I.iter, 0.0 );
         }
     }
 }
 
 
 void BP::findMaxResidual( size_t &i, size_t &_I ) {
-/*
-    i = 0;
-    _I = 0;
-    double maxres = residual( i, _I );
-    for( size_t j = 0; j < nrVars(); ++j )
-        foreach( const Neighbor &I, nbV(j) )
-            if( residual( j, I.iter ) > maxres ) {
-                i = j;
-                _I = I.iter;
-                maxres = residual( i, _I );
-            }
-*/
     assert( !_lut.empty() );
     LutType::const_iterator largestEl = _lut.end();
     --largestEl;
@@ -169,41 +157,36 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
     // calculate updated message I->i
     size_t I = nbV(i,_I);
 
-    if( !DAI_BP_FAST ) {
-        /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
-        Factor prod( factor( I ) );
-        foreach( const Neighbor &j, nbF(I) )
-            if( j != i ) {     // for all j in I \ i
-                foreach( const Neighbor &J, nbV(j) )
-                    if( J != I ) {     // for all J in nb(j) \ I 
-                        prod *= Factor( var(j), message(j, J.iter) );
-                    }
-            }
-        newMessage(i,_I) = prod.marginal( var(i) ).p();
-    } else {
-        /* OPTIMIZED VERSION */
-        Prob prod( factor(I).p() );
-        if( props.logdomain ) 
-            prod.takeLog();
+    Factor Fprod( factor(I) );
+    Prob &prod = Fprod.p();
+    if( props.logdomain )
+        prod.takeLog();
+
+    // Calculate product of incoming messages and factor I
+    foreach( const Neighbor &j, nbF(I) )
+        if( j != i ) { // for all j in I \ i
+            // prod_j will be the product of messages coming into j
+            Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
+            foreach( const Neighbor &J, nbV(j) )
+                if( J != I ) { // for all J in nb(j) \ I
+                    if( props.logdomain )
+                        prod_j += message( j, J.iter );
+                    else
+                        prod_j *= message( j, J.iter );
+                }
 
-        // Calculate product of incoming messages and factor I
-        foreach( const Neighbor &j, nbF(I) ) {
-            if( j != i ) {     // for all j in I \ i
+            // multiply prod with prod_j
+            if( !DAI_BP_FAST ) {
+                /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+                if( props.logdomain )
+                    Fprod += Factor( var(j), prod_j );
+                else
+                    Fprod *= Factor( var(j), prod_j );
+            } else {
+                /* OPTIMIZED VERSION */
                 size_t _I = j.dual;
                 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
                 const ind_t &ind = index(j, _I);
-
-                // prod_j will be the product of messages coming into j
-                Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); 
-                foreach( const Neighbor &J, nbV(j) )
-                    if( J != I ) { // for all J in nb(j) \ I 
-                        if( props.logdomain )
-                            prod_j += message( j, J.iter );
-                        else
-                            prod_j *= message( j, J.iter );
-                    }
-
-                // multiply prod with prod_j
                 for( size_t r = 0; r < prod.size(); ++r )
                     if( props.logdomain )
                         prod[r] += prod_j[ind[r]];
@@ -211,31 +194,41 @@ void BP::calcNewMessage( size_t i, size_t _I ) {
                         prod[r] *= prod_j[ind[r]];
             }
         }
-        if( props.logdomain ) {
-            prod -= prod.max();
-            prod.takeExp();
-        }
 
-        // Marginalize onto i
-        Prob marg( var(i).states(), 0.0 );
+    if( props.logdomain ) {
+        prod -= prod.max();
+        prod.takeExp();
+    }
+
+    // Marginalize onto i
+    Prob marg;
+    if( !DAI_BP_FAST ) {
+        /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
+        if( props.inference == Properties::InfType::SUMPROD )
+            marg = Fprod.marginal( var(i) ).p();
+        else
+            marg = Fprod.maxMarginal( var(i) ).p();
+    } else {
+        /* OPTIMIZED VERSION */
+        marg = Prob( var(i).states(), 0.0 );
         // ind is the precalculated IndexFor(i,I) i.e. to x_I == k corresponds x_i == ind[k]
         const ind_t ind = index(i,_I);
-        if( props.inference == Properties::InfType::SUMPROD ) 
+        if( props.inference == Properties::InfType::SUMPROD )
             for( size_t r = 0; r < prod.size(); ++r )
                 marg[ind[r]] += prod[r];
         else
             for( size_t r = 0; r < prod.size(); ++r )
-                if( prod[r] > marg[ind[r]] ) 
+                if( prod[r] > marg[ind[r]] )
                     marg[ind[r]] = prod[r];
         marg.normalize();
-
-        // Store result
-        if( props.logdomain )
-            newMessage(i,_I) = marg.log();
-        else
-            newMessage(i,_I) = marg;
     }
 
+    // Store result
+    if( props.logdomain )
+        newMessage(i,_I) = marg.log();
+    else
+        newMessage(i,_I) = marg;
+
     // Update the residual if necessary
     if( props.updates == Properties::UpdateType::SEQMAX )
         updateResidual( i, _I , dist( newMessage( i, _I ), message( i, _I ), Prob::DISTLINF ) );
@@ -252,7 +245,7 @@ double BP::run() {
 
     double tic = toc();
     Diffs diffs(nrVars(), 1.0);
-    
+
     vector<Edge> update_seq;
 
     vector<Factor> old_beliefs;
@@ -300,7 +293,7 @@ double BP::run() {
                 }
             }
         } else if( props.updates == Properties::UpdateType::PARALL ) {
-            // Parallel updates 
+            // Parallel updates
             for( size_t i = 0; i < nrVars(); ++i )
                 foreach( const Neighbor &I, nbV(i) )
                     calcNewMessage( i, I.iter );
@@ -312,7 +305,7 @@ double BP::run() {
             // Sequential updates
             if( props.updates == Properties::UpdateType::SEQRND )
                 random_shuffle( update_seq.begin(), update_seq.end() );
-            
+
             foreach( const Edge &e, update_seq ) {
                 calcNewMessage( e.first, e.second );
                 updateMessage( e.first, e.second );
@@ -350,7 +343,7 @@ double BP::run() {
 
 
 void BP::calcBeliefV( size_t i, Prob &p ) const {
-    p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 ); 
+    p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
     foreach( const Neighbor &I, nbV(i) )
         if( props.logdomain )
             p += newMessage( i, I.iter );
@@ -360,34 +353,46 @@ void BP::calcBeliefV( size_t i, Prob &p ) const {
 
 
 void BP::calcBeliefF( size_t I, Prob &p ) const {
-    p = factor(I).p();
+    Factor Fprod( factor( I ) );
+    Prob &prod = Fprod.p();
+
     if( props.logdomain )
-        p.takeLog();
+        prod.takeLog();
 
     foreach( const Neighbor &j, nbF(I) ) {
-        size_t _I = j.dual;
-        // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
-        const ind_t & ind = index(j, _I);
-
         // prod_j will be the product of messages coming into j
-        Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 ); 
-        foreach( const Neighbor &J, nbV(j) ) {
-            if( J != I ) { // for all J in nb(j) \ I 
+        Prob prod_j( var(j).states(), props.logdomain ? 0.0 : 1.0 );
+        foreach( const Neighbor &J, nbV(j) )
+            if( J != I ) { // for all J in nb(j) \ I
                 if( props.logdomain )
                     prod_j += newMessage( j, J.iter );
                 else
                     prod_j *= newMessage( j, J.iter );
             }
-        }
 
-        // multiply p with prod_j
-        for( size_t r = 0; r < p.size(); ++r ) {
+        // multiply prod with prod_j
+        if( !DAI_BP_FAST ) {
+            /* UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
             if( props.logdomain )
-                p[r] += prod_j[ind[r]];
+                Fprod += Factor( var(j), prod_j );
             else
-                p[r] *= prod_j[ind[r]];
+                Fprod *= Factor( var(j), prod_j );
+        } else {
+            /* OPTIMIZED VERSION */
+            size_t _I = j.dual;
+            // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
+            const ind_t & ind = index(j, _I);
+
+            for( size_t r = 0; r < prod.size(); ++r ) {
+                if( props.logdomain )
+                    prod[r] += prod_j[ind[r]];
+                else
+                    prod[r] *= prod_j[ind[r]];
+            }
         }
     }
+
+    p = prod;
 }
 
 
@@ -399,12 +404,26 @@ Factor BP::beliefV( size_t i ) const {
         p -= p.max();
         p.takeExp();
     }
-
     p.normalize();
+
     return( Factor( var(i), p ) );
 }
 
 
+Factor BP::beliefF( size_t I ) const {
+    Prob p;
+    calcBeliefF( I, p );
+
+    if( props.logdomain ) {
+        p -= p.max();
+        p.takeExp();
+    }
+    p.normalize();
+
+    return( Factor( factor(I).vars(), p ) );
+}
+
+
 Factor BP::belief( const Var &n ) const {
     return( beliefV( findVar( n ) ) );
 }
@@ -434,37 +453,6 @@ Factor BP::belief( const VarSet &ns ) const {
 }
 
 
-Factor BP::beliefF( size_t I ) const {
-    if( !DAI_BP_FAST ) {
-        /*  UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION */
-
-        Factor prod( factor(I) );
-        foreach( const Neighbor &j, nbF(I) ) {
-            foreach( const Neighbor &J, nbV(j) ) {
-                if( J != I )  // for all J in nb(j) \ I
-                    prod *= Factor( var(j), newMessage(j, J.iter) );
-            }
-        }
-        return prod.normalized();
-    } else {
-        /* OPTIMIZED VERSION */
-
-        Prob prod;
-        calcBeliefF( I, prod );
-
-        if( props.logdomain ) {
-            prod -= prod.max();
-            prod.takeExp();
-        }
-        prod.normalize();
-
-        Factor result( factor(I).vars(), prod );
-
-        return( result );
-    }
-}
-
-
 Real BP::logZ() const {
     Real sum = 0.0;
     for(size_t i = 0; i < nrVars(); ++i )
@@ -475,7 +463,7 @@ Real BP::logZ() const {
 }
 
 
-string BP::identify() const { 
+string BP::identify() const {
     return string(Name) + printProperties();
 }
 
@@ -510,32 +498,32 @@ void BP::updateMessage( size_t i, size_t _I ) {
 
 
 void BP::updateResidual( size_t i, size_t _I, double r ) {
-       EdgeProp* pEdge = &_edges[i][_I];
-       pEdge->residual = r;
-       
-       // rearrange look-up table (delete and reinsert new key)
-       _lut.erase( _edge2lut[i][_I] );
-       _edge2lut[i][_I] = _lut.insert( std::make_pair( r, std::make_pair(i, _I) ) );
+    EdgeProp* pEdge = &_edges[i][_I];
+    pEdge->residual = r;
+
+    // rearrange look-up table (delete and reinsert new key)
+    _lut.erase( _edge2lut[i][_I] );
+    _edge2lut[i][_I] = _lut.insert( make_pair( r, make_pair(i, _I) ) );
 }
 
 
 std::vector<size_t> BP::findMaximum() const {
-    std::vector<size_t> maximum( nrVars() );
-    std::vector<bool> visitedVars( nrVars(), false );
-    std::vector<bool> visitedFactors( nrFactors(), false );
-    std::stack<size_t> scheduledFactors;
+    vector<size_t> maximum( nrVars() );
+    vector<bool> visitedVars( nrVars(), false );
+    vector<bool> visitedFactors( nrFactors(), false );
+    stack<size_t> scheduledFactors;
     for( size_t i = 0; i < nrVars(); ++i ) {
         if( visitedVars[i] )
             continue;
         visitedVars[i] = true;
-        
+
         // Maximise with respect to variable i
         Prob prod;
         calcBeliefV( i, prod );
-        maximum[i] = std::max_element( prod.begin(), prod.end() ) - prod.begin();
-        
+        maximum[i] = max_element( prod.begin(), prod.end() ) - prod.begin();
+
         foreach( const Neighbor &I, nbV(i) )
-            if( !visitedFactors[I] ) 
+            if( !visitedFactors[I] )
                 scheduledFactors.push(I);
 
         while( !scheduledFactors.empty() ){
@@ -544,24 +532,24 @@ std::vector<size_t> BP::findMaximum() const {
             if( visitedFactors[I] )
                 continue;
             visitedFactors[I] = true;
-            
+
             // Evaluate if some neighboring variables still need to be fixed; if not, we're done
             bool allDetermined = true;
-            foreach( const Neighbor &j, nbF(I) ) 
+            foreach( const Neighbor &j, nbF(I) )
                 if( !visitedVars[j.node] ) {
                     allDetermined = false;
                     break;
                 }
             if( allDetermined )
                 continue;
-            
+
             // Calculate product of incoming messages on factor I
             Prob prod2;
             calcBeliefF( I, prod2 );
 
             // The allowed configuration is restrained according to the variables assigned so far:
             // pick the argmax amongst the allowed states
-            Real maxProb = std::numeric_limits<Real>::min();
+            Real maxProb = numeric_limits<Real>::min();
             State maxState( factor(I).vars() );
             for( State s( factor(I).vars() ); s.valid(); ++s ){
                 // First, calculate whether this state is consistent with variables that
@@ -578,19 +566,19 @@ std::vector<size_t> BP::findMaximum() const {
                     maxProb = prod2[s];
                 }
             }
-            
+
             // Decode the argmax
             foreach( const Neighbor &j, nbF(I) ) {
                 if( visitedVars[j.node] ) {
                     // We have already visited j earlier - hopefully our state is consistent
                     if( maximum[j.node] != maxState(var(j.node)) && props.verbose >= 1 )
-                        std::cerr << "BP::findMaximum - warning: maximum not consistent due to loops." << std::endl;
+                        cerr << "BP::findMaximum - warning: maximum not consistent due to loops." << endl;
                 } else {
                     // We found a consistent state for variable j
                     visitedVars[j.node] = true;
                     maximum[j.node] = maxState( var(j.node) );
                     foreach( const Neighbor &J, nbV(j) )
-                        if( !visitedFactors[J] ) 
+                        if( !visitedFactors[J] )
                             scheduledFactors.push(J);
                 }
             }