[Frederik Eaton] Major cleanup of BBP and CBP code and documentation
[libdai.git] / src / treeep.cpp
index d21ed5f..2c88f98 100644 (file)
@@ -1,6 +1,7 @@
-/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
-    Radboud University Nijmegen, The Netherlands
-    
+/*  Copyright (C) 2006-2008  Joris Mooij  [joris dot mooij at tuebingen dot mpg dot de]
+    Radboud University Nijmegen, The Netherlands /
+    Max Planck Institute for Biological Cybernetics, Germany
+
     This file is part of libDAI.
 
     libDAI is free software; you can redistribute it and/or modify
@@ -25,7 +26,6 @@
 #include <dai/jtree.h>
 #include <dai/treeep.h>
 #include <dai/util.h>
-#include <dai/diffs.h>
 
 
 namespace dai {
@@ -119,10 +119,10 @@ void TreeEP::TreeEPSubTree::init() {
 
 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
     for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
-        _Qa[alpha] = Qa[_a[alpha]].divided_by( _Qa[alpha] );
+        _Qa[alpha] = Qa[_a[alpha]] / _Qa[alpha];
 
     for( size_t beta = 0; beta < _Qb.size(); beta++ )
-        _Qb[beta] = Qb[_b[beta]].divided_by( _Qb[beta] );
+        _Qb[beta] = Qb[_b[beta]] / _Qb[beta];
 }
 
 
@@ -151,15 +151,15 @@ void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<F
                     delta[s(*n)] = 1.0;
                     _Qa[_RTree[i].n2] *= delta;
                 }
-            Factor new_Qb = _Qa[_RTree[i].n2].partSum( _Qb[i].vars() );
-            _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] )
+            Factor new_Qb = _Qa[_RTree[i].n2].marginal( _Qb[i].vars(), false );
+            _Qa[_RTree[i].n1] *= new_Qb / _Qb[i]
             _Qb[i] = new_Qb;
         }
 
         // DistributeEvidence
         for( size_t i = 0; i < _RTree.size(); i++ ) {
-            Factor new_Qb = _Qa[_RTree[i].n1].partSum( _Qb[i].vars() );
-            _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] )
+            Factor new_Qb = _Qa[_RTree[i].n1].marginal( _Qb[i].vars(), false );
+            _Qa[_RTree[i].n2] *= new_Qb / _Qb[i]
             _Qb[i] = new_Qb;
         }
 
@@ -177,23 +177,23 @@ void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<F
     // Normalize Qa and Qb
     _logZ = 0.0;
     for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
-        _logZ += log(Qa[_a[alpha]].totalSum());
+        _logZ += log(Qa[_a[alpha]].sum());
         Qa[_a[alpha]].normalize();
     }
     for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
-        _logZ -= log(Qb[_b[beta]].totalSum());
+        _logZ -= log(Qb[_b[beta]].sum());
         Qb[_b[beta]].normalize();
     }
 }
 
 
 double TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
-    double sum = 0.0;
+    double s = 0.0;
     for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
-        sum += (Qa[_a[alpha]] * _Qa[alpha].log0()).totalSum();
+        s += (Qa[_a[alpha]] * _Qa[alpha].log(true)).sum();
     for( size_t beta = 0; beta < _Qb.size(); beta++ )
-        sum -= (Qb[_b[beta]] * _Qb[beta].log0()).totalSum();
-    return sum + _logZ;
+        s -= (Qb[_b[beta]] * _Qb[beta].log(true)).sum();
+    return s + _logZ;
 }
 
 
@@ -235,7 +235,7 @@ TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opt
                             if( piet.vars() >> ij ) {
                                 piet = piet.marginal( ij );
                                 Factor pietf = piet.marginal(v_i) * piet.marginal(*j);
-                                wg[UEdge(i,findVar(*j))] = KL_dist( piet, pietf );
+                                wg[UEdge(i,findVar(*j))] = dist( piet, pietf, Prob::DISTKL );
                             } else
                                 wg[UEdge(i,findVar(*j))] = 0;
                         } else {
@@ -255,7 +255,7 @@ TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opt
 void TreeEP::ConstructRG( const DEdgeVec &tree ) {
     vector<VarSet> Cliques;
     for( size_t i = 0; i < tree.size(); i++ )
-        Cliques.push_back( var(tree[i].n1) | var(tree[i].n2) );
+        Cliques.push_back( VarSet( var(tree[i].n1), var(tree[i].n2) ) );
     
     // Construct a weighted graph (each edge is weighted with the cardinality 
     // of the intersection of the nodes, where the nodes are the elements of
@@ -336,7 +336,7 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
             /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
             PreviousRoot = subTree[0].n1;
             //subTree.resize( subTreeSize );  // FIXME
-//          cout << "subtree " << I << " has size " << subTreeSize << endl;
+//          cerr << "subtree " << I << " has size " << subTreeSize << endl;
 
             TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) );
             _Q[I] = QI;
@@ -348,7 +348,7 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
             /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
             PreviousRoot = subTree[0].n1;
             //subTree.resize( subTreeSize ); // FIXME
-//          cout << "subtree " << I << " has size " << subTreeSize << endl;
+//          cerr << "subtree " << I << " has size " << subTreeSize << endl;
 
             TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) );
             _Q[I] = QI;
@@ -356,7 +356,7 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
         }
 
     if( props.verbose >= 3 ) {
-        cout << "Resulting regiongraph: " << *this << endl;
+        cerr << "Resulting regiongraph: " << *this << endl;
     }
 }
 
@@ -378,9 +378,9 @@ void TreeEP::init() {
 
 double TreeEP::run() {
     if( props.verbose >= 1 )
-        cout << "Starting " << identify() << "...";
+        cerr << "Starting " << identify() << "...";
     if( props.verbose >= 3)
-        cout << endl;
+        cerr << endl;
 
     double tic = toc();
     Diffs diffs(nrVars(), 1.0);
@@ -408,7 +408,7 @@ double TreeEP::run() {
         }
 
         if( props.verbose >= 3 )
-            cout << Name << "::run:  maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
+            cerr << Name << "::run:  maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
     }
 
     if( diffs.maxDiff() > _maxdiff )
@@ -417,12 +417,12 @@ double TreeEP::run() {
     if( props.verbose >= 1 ) {
         if( diffs.maxDiff() > props.tol ) {
             if( props.verbose == 1 )
-                cout << endl;
-            cout << Name << "::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
+                cerr << endl;
+            cerr << Name << "::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
         } else {
             if( props.verbose >= 3 )
-                cout << Name << "::run:  ";
-            cout << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
+                cerr << Name << "::run:  ";
+            cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
         }
     }
 
@@ -431,24 +431,24 @@ double TreeEP::run() {
 
 
 Real TreeEP::logZ() const {
-    double sum = 0.0;
+    double s = 0.0;
 
     // entropy of the tree
     for( size_t beta = 0; beta < nrIRs(); beta++ )
-        sum -= Qb[beta].entropy();
+        s -= Qb[beta].entropy();
     for( size_t alpha = 0; alpha < nrORs(); alpha++ )
-        sum += Qa[alpha].entropy();
+        s += Qa[alpha].entropy();
 
     // energy of the on-tree factors
     for( size_t alpha = 0; alpha < nrORs(); alpha++ )
-        sum += (OR(alpha).log0() * Qa[alpha]).totalSum();
+        s += (OR(alpha).log(true) * Qa[alpha]).sum();
 
     // energy of the off-tree factors
     for( size_t I = 0; I < nrFactors(); I++ )
         if( offtree(I) )
-            sum += (_Q.find(I))->second.logZ( Qa, Qb );
+            s += (_Q.find(I))->second.logZ( Qa, Qb );
     
-    return sum;
+    return s;
 }