Multiple changes: changes in build system, one workaround and one bug fix
[libdai.git] / src / bp_dual.cpp
index 8e67da9..4cf7bf3 100644 (file)
@@ -1,21 +1,13 @@
-/*  Copyright (C) 2009  Frederik Eaton [frederik at ofb dot net]
+/*  This file is part of libDAI - http://www.libdai.org/
+ *
+ *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
+ */
 
-    This file is part of libDAI.
 
-    libDAI is free software; you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation; either version 2 of the License, or
-    (at your option) any later version.
-
-    libDAI is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with libDAI; if not, write to the Free Software
-    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
-*/
+#include <dai/dai_config.h>
+#ifdef DAI_WITH_CBP
 
 
 #include <iostream>
@@ -33,9 +25,6 @@ namespace dai {
 using namespace std;
 
 
-typedef BipartiteGraph::Neighbor Neighbor;
-
-
 void BP_dual::init() {
     regenerateMessages();
     regenerateBeliefs();
@@ -72,7 +61,7 @@ void BP_dual::regenerateBeliefs() {
     for( size_t i = 0; i < fg().nrVars(); i++ )
         _beliefs.b1.push_back( Prob( fg().var(i).states() ) );
     for( size_t I = 0; I < fg().nrFactors(); I++ )
-        _beliefs.b2.push_back( Prob( fg().factor(I).states() ) );
+        _beliefs.b2.push_back( Prob( fg().factor(I).nrStates() ) );
 }
 
 
@@ -80,16 +69,16 @@ void BP_dual::calcMessages() {
     // calculate 'n' messages from "factor marginal / factor"
     for( size_t I = 0; I < fg().nrFactors(); I++ ) {
         Factor f = _ia->beliefF(I) / fg().factor(I);
-        foreach( const Neighbor &i, fg().nbF(I) )
+        bforeach( const Neighbor &i, fg().nbF(I) )
             msgN(i, i.dual) = f.marginal( fg().var(i) ).p();
     }
     // calculate 'm' messages and normalizers from 'n' messages
     for( size_t i = 0; i < fg().nrVars(); i++ )
-        foreach( const Neighbor &I, fg().nbV(i) )
+        bforeach( const Neighbor &I, fg().nbV(i) )
             calcNewM( i, I.iter );
     // recalculate 'n' messages and normalizers from 'm' messages
     for( size_t i = 0; i < fg().nrVars(); i++ )
-        foreach( const Neighbor &I, fg().nbV(i) )
+        bforeach( const Neighbor &I, fg().nbV(i) )
             calcNewN(i, I.iter);
 }
 
@@ -98,19 +87,19 @@ void BP_dual::calcNewM( size_t i, size_t _I ) {
     // calculate updated message I->i
     const Neighbor &I = fg().nbV(i)[_I];
     Prob prod( fg().factor(I).p() );
-    foreach( const Neighbor &j, fg().nbF(I) )
+    bforeach( const Neighbor &j, fg().nbF(I) )
         if( j != i ) { // for all j in I \ i
             Prob &n = msgN(j,j.dual);
             IndexFor ind( fg().var(j), fg().factor(I).vars() );
-            for( size_t x = 0; ind >= 0; x++, ++ind )
-                prod[x] *= n[ind];
+            for( size_t x = 0; ind.valid(); x++, ++ind )
+                prod.set( x, prod[x] * n[ind] );
         }
     // Marginalize onto i
     Prob marg( fg().var(i).states(), 0.0 );
     // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
     IndexFor ind( fg().var(i), fg().factor(I).vars() );
-    for( size_t x = 0; ind >= 0; x++, ++ind )
-        marg[ind] += prod[x];
+    for( size_t x = 0; ind.valid(); x++, ++ind )
+        marg.set( ind, marg[ind] + prod[x] );
 
     _msgs.Zm[i][_I] = marg.normalize();
     _msgs.m[i][_I] = marg;
@@ -121,7 +110,7 @@ void BP_dual::calcNewN( size_t i, size_t _I ) {
     // calculate updated message i->I
     const Neighbor &I = fg().nbV(i)[_I];
     Prob prod( fg().var(i).states(), 1.0 );
-    foreach( const Neighbor &J, fg().nbV(i) )
+    bforeach( const Neighbor &J, fg().nbV(i) )
         if( J.node != I.node ) // for all J in i \ I
             prod *= msgM(i,J.iter);
     _msgs.Zn[i][_I] = prod.normalize();
@@ -139,7 +128,7 @@ void BP_dual::calcBeliefs() {
 
 void BP_dual::calcBeliefV( size_t i ) {
     Prob prod( fg().var(i).states(), 1.0 );
-    foreach( const Neighbor &I, fg().nbV(i) )
+    bforeach( const Neighbor &I, fg().nbV(i) )
         prod *= msgM(i,I.iter);
     _beliefs.Zb1[i] = prod.normalize();
     _beliefs.b1[i] = prod;
@@ -148,11 +137,11 @@ void BP_dual::calcBeliefV( size_t i ) {
 
 void BP_dual::calcBeliefF( size_t I ) {
     Prob prod( fg().factor(I).p() );
-    foreach( const Neighbor &j, fg().nbF(I) ) {
+    bforeach( const Neighbor &j, fg().nbF(I) ) {
         IndexFor ind( fg().var(j), fg().factor(I).vars() );
         Prob n( msgN(j,j.dual) );
-        for( size_t x = 0; ind >= 0; x++, ++ind )
-            prod[x] *= n[ind];
+        for( size_t x = 0; ind.valid(); x++, ++ind )
+            prod.set( x, prod[x] * n[ind] );
     }
     _beliefs.Zb2[I] = prod.normalize();
     _beliefs.b2[I] = prod;
@@ -160,3 +149,6 @@ void BP_dual::calcBeliefF( size_t I ) {
 
 
 } // end of namespace dai
+
+
+#endif