Small misc changes
[libdai.git] / src / bp.cpp
index 3b7477b..ce27640 100644 (file)
@@ -58,45 +58,57 @@ bool BP::checkProperties() {
 }
 
 
-void BP::Regenerate() {
-    DAIAlgFG::Regenerate();
-    
-    // clear messages
-    _messages.clear();
-    _messages.reserve(nr_edges());
-
-    // clear indices
-    _indices.clear();
-    _indices.reserve(nr_edges());
-
-    // create messages and indices
-    for( vector<_edge_t>::const_iterator iI=edges().begin(); iI!=edges().end(); ++iI ) {
-        _messages.push_back( Prob( var(iI->first).states() ) );
-
-        vector<size_t> ind( factor(iI->second).stateSpace(), 0 );
-        Index i (var(iI->first), factor(iI->second).vars() );
-        for( size_t j = 0; i >= 0; ++i,++j )
-            ind[j] = i; 
-        _indices.push_back( ind );
+void BP::create() {
+    // create edge properties
+    edges.clear();
+    edges.reserve( nrVars() );
+    for( size_t i = 0; i < nrVars(); ++i ) {
+        edges.push_back( vector<EdgeProp>() );
+        edges[i].reserve( nbV(i).size() ); 
+        foreach( const Neighbor &I, nbV(i) ) {
+            EdgeProp newEP;
+            newEP.message = Prob( var(i).states() );
+            newEP.newMessage = Prob( var(i).states() );
+
+            newEP.index.reserve( factor(I).states() );
+            for( Index k( var(i), factor(I).vars() ); k >= 0; ++k )
+                newEP.index.push_back( k );
+
+            newEP.residual = 0.0;
+            edges[i].push_back( newEP );
+        }
     }
-
-    // create new_messages
-    _newmessages = _messages;
 }
 
 
 void BP::init() {
     assert( checkProperties() );
-    for( vector<Prob>::iterator mij = _messages.begin(); mij != _messages.end(); ++mij )
-        mij->fill(1.0 / mij->size());
-    _newmessages = _messages;
+    for( size_t i = 0; i < nrVars(); ++i ) {
+        foreach( const Neighbor &I, nbV(i) ) {
+            message( i, I.iter ).fill( 1.0 );
+            newMessage( i, I.iter ).fill( 1.0 );
+        }
+    }
+}
+
+
+void BP::findMaxResidual( size_t &i, size_t &_I ) {
+    i = 0;
+    _I = 0;
+    double maxres = residual( i, _I );
+    for( size_t j = 0; j < nrVars(); ++j )
+        foreach( const Neighbor &I, nbV(j) )
+            if( residual( j, I.iter ) > maxres ) {
+                i = j;
+                _I = I.iter;
+                maxres = residual( i, _I );
+            }
 }
 
 
-void BP::calcNewMessage (size_t iI) {
+void BP::calcNewMessage( size_t i, size_t _I ) {
     // calculate updated message I->i
-    size_t i = edge(iI).first;
-    size_t I = edge(iI).second;
+    size_t I = nbV(i,_I);
 
 /*  UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
 
@@ -112,32 +124,35 @@ void BP::calcNewMessage (size_t iI) {
     Prob prod( factor(I).p() );
 
     // Calculate product of incoming messages and factor I
-    for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); ++j )
-        if( *j != i ) {     // for all j in I \ i
+    foreach( const Neighbor &j, nbF(I) ) {
+        if( j != i ) {     // for all j in I \ i
+            size_t _I = j.dual;
             // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
-            _ind_t* ind = &(index(*j,I));
+            const ind_t & ind = index(j, _I);
 
             // prod_j will be the product of messages coming into j
-            Prob prod_j( var(*j).states() ); 
-            for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); ++J )
-                if( *J != I )   // for all J in nb(j) \ I 
-                    prod_j *= message(*j,*J);
+            Prob prod_j( var(j).states() ); 
+            foreach( const Neighbor &J, nbV(j) ) {
+                if( J != I )   // for all J in nb(j) \ I 
+                    prod_j *= message( j, J.iter );
+            }
 
             // multiply prod with prod_j
             for( size_t r = 0; r < prod.size(); ++r )
-                prod[r] *= prod_j[(*ind)[r]];
+                prod[r] *= prod_j[ind[r]];
         }
+    }
 
     // Marginalize onto i
     Prob marg( var(i).states(), 0.0 );
     // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
-    _ind_t* ind = &(index(i,I));
+    const ind_t ind = index(i,_I);
     for( size_t r = 0; r < prod.size(); ++r )
-        marg[(*ind)[r]] += prod[r];
+        marg[ind[r]] += prod[r];
     marg.normalize( _normtype );
     
     // Store result
-    _newmessages[iI] = marg;
+    newMessage(i,_I) = marg;
 }
 
 
@@ -152,29 +167,30 @@ double BP::run() {
     clock_t tic = toc();
     Diffs diffs(nrVars(), 1.0);
     
-    vector<size_t> edge_seq;
-    vector<double> residuals;
+    typedef pair<size_t,size_t> Edge;
+    vector<Edge> update_seq;
 
     vector<Factor> old_beliefs;
     old_beliefs.reserve( nrVars() );
     for( size_t i = 0; i < nrVars(); ++i )
-        old_beliefs.push_back(belief1(i));
+        old_beliefs.push_back( beliefV(i) );
 
     size_t iter = 0;
+    size_t nredges = nrEdges();
 
     if( Updates() == UpdateType::SEQMAX ) {
         // do the first pass
-        for(size_t iI = 0; iI < nr_edges(); ++iI ) 
-            calcNewMessage(iI);
-
-        // calculate initial residuals
-        residuals.reserve(nr_edges());
-        for( size_t iI = 0; iI < nr_edges(); ++iI )
-            residuals.push_back( dist( _newmessages[iI], _messages[iI], Prob::DISTLINF ) );
+        for( size_t i = 0; i < nrVars(); ++i )
+            foreach( const Neighbor &I, nbV(i) ) {
+                calcNewMessage( i, I.iter );
+                // calculate initial residuals
+                residual( i, I.iter ) = dist( newMessage( i, I.iter ), message( i, I.iter ), Prob::DISTLINF );
+            }
     } else {
-        edge_seq.reserve( nr_edges() );
-        for( size_t i = 0; i < nr_edges(); ++i )
-            edge_seq.push_back( i );
+        update_seq.reserve( nredges );
+        for( size_t i = 0; i < nrVars(); ++i )
+            foreach( const Neighbor &I, nbV(i) )
+                update_seq.push_back( Edge( i, I.iter ) );
     }
 
     // do several passes over the network until maximum number of iterations has
@@ -182,47 +198,51 @@ double BP::run() {
     for( iter=0; iter < MaxIter() && diffs.max() > Tol(); ++iter ) {
         if( Updates() == UpdateType::SEQMAX ) {
             // Residuals-BP by Koller et al.
-            for( size_t t = 0; t < nr_edges(); ++t ) {
+            for( size_t t = 0; t < nredges; ++t ) {
                 // update the message with the largest residual
-                size_t iI = max_element(residuals.begin(), residuals.end()) - residuals.begin();
-                _messages[iI] = _newmessages[iI];
-                residuals[iI] = 0;
+
+                size_t i, _I;
+                findMaxResidual( i, _I );
+                message( i, _I ) = newMessage( i, _I );
+                residual( i, _I ) = 0.0;
 
                 // I->i has been updated, which means that residuals for all
                 // J->j with J in nb[i]\I and j in nb[J]\i have to be updated
-                size_t i = edge(iI).first;
-                size_t I = edge(iI).second;
-                for( _nb_cit J = nb1(i).begin(); J != nb1(i).end(); ++J ) 
-                    if( *J != I )
-                        for( _nb_cit j = nb2(*J).begin(); j != nb2(*J).end(); ++j )
-                            if( *j != i ) {
-                                size_t jJ = VV2E(*j,*J);
-                                calcNewMessage(jJ);
-                                residuals[jJ] = dist( _newmessages[jJ], _messages[jJ], Prob::DISTLINF );
+                foreach( const Neighbor &J, nbV(i) ) {
+                    if( J.iter != _I ) {
+                        foreach( const Neighbor &j, nbF(J) ) {
+                            size_t _J = j.dual;
+                            if( j != i ) {
+                                calcNewMessage( j, _J );
+                                residual( j, _J ) = dist( newMessage( j, _J ), message( j, _J ), Prob::DISTLINF );
                             }
+                        }
+                    }
+                }
             }
         } else if( Updates() == UpdateType::PARALL ) {
             // Parallel updates 
-            for( size_t t = 0; t < nr_edges(); ++t )
-                calcNewMessage(t);
+            for( size_t i = 0; i < nrVars(); ++i )
+                foreach( const Neighbor &I, nbV(i) )
+                    calcNewMessage( i, I.iter );
 
-            for( size_t t = 0; t < nr_edges(); ++t )
-                _messages[t] = _newmessages[t];
+            for( size_t i = 0; i < nrVars(); ++i )
+                foreach( const Neighbor &I, nbV(i) )
+                    message( i, I.iter ) = newMessage( i, I.iter );
         } else {
             // Sequential updates
             if( Updates() == UpdateType::SEQRND )
-                random_shuffle( edge_seq.begin(), edge_seq.end() );
+                random_shuffle( update_seq.begin(), update_seq.end() );
             
-            for( size_t t = 0; t < nr_edges(); ++t ) {
-                size_t k = edge_seq[t];
-                calcNewMessage(k);
-                _messages[k] = _newmessages[k];
+            foreach( const Edge &e, update_seq ) {
+                calcNewMessage( e.first, e.second );
+                message( e.first, e.second ) = newMessage( e.first, e.second );
             }
         }
 
         // calculate new beliefs and compare with old ones
         for( size_t i = 0; i < nrVars(); ++i ) {
-            Factor nb( belief1(i) );
+            Factor nb( beliefV(i) );
             diffs.push( dist( nb, old_beliefs[i], Prob::DISTLINF ) );
             old_beliefs[i] = nb;
         }
@@ -249,10 +269,10 @@ double BP::run() {
 }
 
 
-Factor BP::belief1( size_t i ) const {
+Factor BP::beliefV( size_t i ) const {
     Prob prod( var(i).states() ); 
-    for( _nb_cit I = nb1(i).begin(); I != nb1(i).end(); ++I ) 
-        prod *= newMessage(i,*I);
+    foreach( const Neighbor &I, nbV(i) )
+        prod *= newMessage( i, I.iter );
 
     prod.normalize( Prob::NORMPROB );
     return( Factor( var(i), prod ) );
@@ -260,16 +280,16 @@ Factor BP::belief1( size_t i ) const {
 
 
 Factor BP::belief (const Var &n) const {
-    return( belief1( findVar( n ) ) );
+    return( beliefV( findVar( n ) ) );
 }
 
 
 vector<Factor> BP::beliefs() const {
     vector<Factor> result;
     for( size_t i = 0; i < nrVars(); ++i )
-        result.push_back( belief1(i) );
+        result.push_back( beliefV(i) );
     for( size_t I = 0; I < nrFactors(); ++I )
-        result.push_back( belief2(I) );
+        result.push_back( beliefF(I) );
     return result;
 }
 
@@ -283,27 +303,29 @@ Factor BP::belief( const VarSet &ns ) const {
             if( factor(I).vars() >> ns )
                 break;
         assert( I != nrFactors() );
-        return belief2(I).marginal(ns);
+        return beliefF(I).marginal(ns);
     }
 }
 
 
-Factor BP::belief2 (size_t I) const {
+Factor BP::beliefF (size_t I) const {
     Prob prod( factor(I).p() );
 
-    for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); ++j ) {
+    foreach( const Neighbor &j, nbF(I) ) {
+        size_t _I = j.dual;
         // ind is the precalculated Index(j,I) i.e. to x_I == k corresponds x_j == ind[k]
-        const _ind_t *ind = &(index(*j, I));
+        const ind_t & ind = index(j, _I);
 
         // prod_j will be the product of messages coming into j
-        Prob prod_j( var(*j).states() ); 
-        for( _nb_cit J = nb1(*j).begin(); J != nb1(*j).end(); ++J ) 
-            if( *J != I )   // for all J in nb(j) \ I 
-                prod_j *= newMessage(*j,*J);
+        Prob prod_j( var(j).states() ); 
+        foreach( const Neighbor &J, nbV(j) ) {
+            if( J != I )   // for all J in nb(j) \ I 
+                prod_j *= newMessage( j, J.iter );
+        }
 
         // multiply prod with prod_j
         for( size_t r = 0; r < prod.size(); ++r )
-            prod[r] *= prod_j[(*ind)[r]];
+            prod[r] *= prod_j[ind[r]];
     }
 
     Factor result( factor(I).vars(), prod );
@@ -326,9 +348,9 @@ Factor BP::belief2 (size_t I) const {
 Complex BP::logZ() const {
     Complex sum = 0.0;
     for(size_t i = 0; i < nrVars(); ++i )
-        sum += Complex(1.0 - nb1(i).size()) * belief1(i).entropy();
+        sum += Complex(1.0 - nbV(i).size()) * beliefV(i).entropy();
     for( size_t I = 0; I < nrFactors(); ++I )
-        sum -= KL_dist( belief2(I), factor(I) );
+        sum -= KL_dist( beliefF(I), factor(I) );
     return sum;
 }
 
@@ -343,8 +365,8 @@ string BP::identify() const {
 void BP::init( const VarSet &ns ) {
     for( VarSet::const_iterator n = ns.begin(); n != ns.end(); ++n ) {
         size_t ni = findVar( *n );
-        for( _nb_cit I = nb1(ni).begin(); I != nb1(ni).end(); ++I )
-            message(ni,*I).fill( 1.0 );
+        foreach( const Neighbor &I, nbV( ni ) )
+            message( ni, I.iter ).fill( 1.0 );
     }
 }