Replaced ENUM2,ENUM3,ENUM4,ENUM5,ENUM6 by single DAI_ENUM macro.
[libdai.git] / src / treeep.cpp
index f0b0ba7..6f5b797 100644 (file)
@@ -37,26 +37,30 @@ using namespace std;
 const char *TreeEP::Name = "TREEEP";
 
 
-bool TreeEP::checkProperties() {
-    if( !HasProperty("type") )
-        return false;
-    if( !HasProperty("tol") )
-        return false;
-    if (!HasProperty("maxiter") )
-        return false;
-    if (!HasProperty("verbose") )
-        return false;
+void TreeEP::setProperties( const PropertySet &opts ) {
+    assert( opts.hasKey("tol") );
+    assert( opts.hasKey("maxiter") );
+    assert( opts.hasKey("verbose") );
+    assert( opts.hasKey("type") );
     
-    ConvertPropertyTo<TypeType>("type");
-    ConvertPropertyTo<double>("tol");
-    ConvertPropertyTo<size_t>("maxiter");
-    ConvertPropertyTo<size_t>("verbose");
+    props.tol = opts.getStringAs<double>("tol");
+    props.maxiter = opts.getStringAs<size_t>("maxiter");
+    props.verbose = opts.getStringAs<size_t>("verbose");
+    props.type = opts.getStringAs<Properties::TypeType>("type");
+}
+
 
-    return true;
+PropertySet TreeEP::getProperties() const {
+    PropertySet opts;
+    opts.Set( "tol", props.tol );
+    opts.Set( "maxiter", props.maxiter );
+    opts.Set( "verbose", props.verbose );
+    opts.Set( "type", props.type );
+    return opts;
 }
 
 
-TreeEPSubTree::TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const vector<Factor> &jt_Qa, const vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
+TreeEPSubTree::TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
     _ns = _I->vars();
 
     // Make _Qa, _Qb, _a and _b corresponding to the subtree
@@ -102,7 +106,7 @@ void TreeEPSubTree::init() {
 }
 
 
-void TreeEPSubTree::InvertAndMultiply( const vector<Factor> &Qa, const vector<Factor> &Qb ) {
+void TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
     for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
         _Qa[alpha] = Qa[_a[alpha]].divided_by( _Qa[alpha] );
 
@@ -111,9 +115,7 @@ void TreeEPSubTree::InvertAndMultiply( const vector<Factor> &Qa, const vector<Fa
 }
 
 
-void TreeEPSubTree::HUGIN_with_I( vector<Factor> &Qa, vector<Factor> &Qb ) {
-    multind mi( _nsrem );
-
+void TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
     // Backup _Qa and _Qb
     vector<Factor> _Qa_old(_Qa);
     vector<Factor> _Qb_old(_Qb);
@@ -125,20 +127,17 @@ void TreeEPSubTree::HUGIN_with_I( vector<Factor> &Qa, vector<Factor> &Qb ) {
         Qb[_b[beta]].fill( 0.0 );
     
     // For all states of _nsrem
-    for( size_t j = 0; j < mi.max(); j++ ) {
-        vector<size_t> vi = mi.vi( j );
-        
+    for( State s(_nsrem); s.valid(); s++ ) {
         // Multiply root with slice of I
-        _Qa[0] *= _I->slice( _nsrem, j );
+        _Qa[0] *= _I->slice( _nsrem, s );
 
         // CollectEvidence
         for( size_t i = _RTree.size(); (i--) != 0; ) {
             // clamp variables in nsrem
-            size_t k = 0;
-            for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++, k++ )
+            for( VarSet::const_iterator n = _nsrem.begin(); n != _nsrem.end(); n++ )
                 if( _Qa[_RTree[i].n2].vars() >> *n ) {
                     Factor delta( *n, 0.0 );
-                    delta[vi[k]] = 1.0;
+                    delta[s(*n)] = 1.0;
                     _Qa[_RTree[i].n2] *= delta;
                 }
             Factor new_Qb = _Qa[_RTree[i].n2].part_sum( _Qb[i].vars() );
@@ -177,7 +176,7 @@ void TreeEPSubTree::HUGIN_with_I( vector<Factor> &Qa, vector<Factor> &Qb ) {
 }
 
 
-double TreeEPSubTree::logZ( const vector<Factor> &Qa, const vector<Factor> &Qb ) const {
+double TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
     double sum = 0.0;
     for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
         sum += (Qa[_a[alpha]] * _Qa[alpha].log0()).totalSum();
@@ -187,70 +186,72 @@ double TreeEPSubTree::logZ( const vector<Factor> &Qa, const vector<Factor> &Qb )
 }
 
 
-TreeEP::TreeEP( const FactorGraph &fg, const Properties &opts ) : JTree(fg, opts("updates",string("HUGIN")), false) {
-    assert( checkProperties() );
+TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), props(), maxdiff(0.0) {
+    setProperties( opts );
 
-    assert( fg.isConnected() );
+    assert( fg.G.isConnected() );
 
     if( opts.hasKey("tree") ) {
         ConstructRG( opts.GetAs<DEdgeVec>("tree") );
     } else {
-        if( Type() == TypeType::ORG ) {
+        if( props.type == Properties::TypeType::ORG ) {
             // construct weighted graph with as weights a crude estimate of the
             // mutual information between the nodes
             WeightedGraph<double> wg;
-            for( vector<Var>::const_iterator i = vars().begin(); i != vars().end(); i++ ) {
-                VarSet di = delta(*i);
+            for( size_t i = 0; i < nrVars(); ++i ) {
+                Var v_i = var(i);
+                VarSet di = delta(i);
                 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
-                    if( *i < *j ) {
+                    if( v_i < *j ) {
                         Factor piet;
                         for( size_t I = 0; I < nrFactors(); I++ ) {
                             VarSet Ivars = factor(I).vars();
-                            if( (Ivars == *i) || (Ivars == *j) )
+                            if( (Ivars == v_i) || (Ivars == *j) )
                                 piet *= factor(I);
-                            else if( Ivars >> (*i | *j) )
-                                piet *= factor(I).marginal( *i | *j );
+                            else if( Ivars >> (v_i | *j) )
+                                piet *= factor(I).marginal( v_i | *j );
                         }
-                        if( piet.vars() >> (*i | *j) ) {
-                            piet = piet.marginal( *i | *j );
-                            Factor pietf = piet.marginal(*i) * piet.marginal(*j);
-                            wg[UEdge(findVar(*i),findVar(*j))] = real( KL_dist( piet, pietf ) );
+                        if( piet.vars() >> (v_i | *j) ) {
+                            piet = piet.marginal( v_i | *j );
+                            Factor pietf = piet.marginal(v_i) * piet.marginal(*j);
+                            wg[UEdge(i,findVar(*j))] = KL_dist( piet, pietf );
                         } else
-                            wg[UEdge(findVar(*i),findVar(*j))] = 0;
+                            wg[UEdge(i,findVar(*j))] = 0;
                     }
             }
 
             // find maximal spanning tree
-            ConstructRG( MaxSpanningTreePrim( wg ) );
+            ConstructRG( MaxSpanningTreePrims( wg ) );
 
 //            cout << "Constructing maximum spanning tree..." << endl;
-//            DEdgeVec MST = MaxSpanningTreePrim( wg );
+//            DEdgeVec MST = MaxSpanningTreePrims( wg );
 //            cout << "Maximum spanning tree:" << endl;
 //            for( DEdgeVec::const_iterator e = MST.begin(); e != MST.end(); e++ )
 //                cout << *e << endl; 
 //            ConstructRG( MST );
-        } else if( Type() == TypeType::ALT ) {
+        } else if( props.type == Properties::TypeType::ALT ) {
             // construct weighted graph with as weights an upper bound on the
             // effective interaction strength between pairs of nodes
             WeightedGraph<double> wg;
-            for( vector<Var>::const_iterator i = vars().begin(); i != vars().end(); i++ ) {
-                VarSet di = delta(*i);
+            for( size_t i = 0; i < nrVars(); ++i ) {
+                Var v_i = var(i);
+                VarSet di = delta(i);
                 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
-                    if( *i < *j ) {
+                    if( v_i < *j ) {
                         Factor piet;
                         for( size_t I = 0; I < nrFactors(); I++ ) {
                             VarSet Ivars = factor(I).vars();
-                            if( Ivars >> (*i | *j) )
+                            if( Ivars >> (v_i | *j) )
                                 piet *= factor(I);
                         }
-                        wg[UEdge(findVar(*i),findVar(*j))] = piet.strength(*i, *j);
+                        wg[UEdge(i,findVar(*j))] = piet.strength(v_i, *j);
                     }
             }
 
             // find maximal spanning tree
-            ConstructRG( MaxSpanningTreePrim( wg ) );
+            ConstructRG( MaxSpanningTreePrims( wg ) );
         } else {
-            assert( 0 == 1 );
+            DAI_THROW(INTERNAL_ERROR);
         }
     }
 }
@@ -272,24 +273,26 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
         }
     
     // Construct maximal spanning tree using Prim's algorithm
-    _RTree = MaxSpanningTreePrim( JuncGraph );
+    _RTree = MaxSpanningTreePrims( JuncGraph );
 
     // Construct corresponding region graph
 
     // Create outer regions
-    ORs().reserve( Cliques.size() );
+    ORs.reserve( Cliques.size() );
     for( size_t i = 0; i < Cliques.size(); i++ )
-        ORs().push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
+        ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
 
     // For each factor, find an outer region that subsumes that factor.
     // Then, multiply the outer region with that factor.
     // If no outer region can be found subsuming that factor, label the
     // factor as off-tree.
+    fac2OR.clear();
+    fac2OR.resize( nrFactors(), -1U );
     for( size_t I = 0; I < nrFactors(); I++ ) {
         size_t alpha;
-        for( alpha = 0; alpha < nr_ORs(); alpha++ )
+        for( alpha = 0; alpha < nrORs(); alpha++ )
             if( OR(alpha).vars() >> factor(I).vars() ) {
-                _fac2OR[I] = alpha;
+                fac2OR[I] = alpha;
                 break;
             }
     // DIFF WITH JTree::GenerateJT:      assert
@@ -297,30 +300,31 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
     RecomputeORs();
 
     // Create inner regions and edges
-    IRs().reserve( _RTree.size() );
-    Redges().reserve( 2 * _RTree.size() );
+    IRs.reserve( _RTree.size() );
+    vector<Edge> edges;
+    edges.reserve( 2 * _RTree.size() );
     for( size_t i = 0; i < _RTree.size(); i++ ) {
-        Redges().push_back( R_edge_t( _RTree[i].n1, IRs().size() ) );
-        Redges().push_back( R_edge_t( _RTree[i].n2, IRs().size() ) );
+        edges.push_back( Edge( _RTree[i].n1, IRs.size() ) );
+        edges.push_back( Edge( _RTree[i].n2, IRs.size() ) );
         // inner clusters have counting number -1
-        IRs().push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
+        IRs.push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
     }
 
-    // Regenerate BipartiteGraph internals
-    Regenerate();
+    // create bipartite graph
+    G.create( nrORs(), nrIRs(), edges.begin(), edges.end() );
 
     // Check counting numbers
     Check_Counting_Numbers();
     
     // Create messages and beliefs
     _Qa.clear();
-    _Qa.reserve( nr_ORs() );
-    for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
+    _Qa.reserve( nrORs() );
+    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
         _Qa.push_back( OR(alpha) );
 
     _Qb.clear();
-    _Qb.reserve( nr_IRs() );
-    for( size_t beta = 0; beta < nr_IRs(); beta++ ) 
+    _Qb.reserve( nrIRs() );
+    for( size_t beta = 0; beta < nrIRs(); beta++ ) 
         _Qb.push_back( Factor( IR(beta), 1.0 ) );
 
     // DIFF with JTree::GenerateJT:  no messages
@@ -375,7 +379,7 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
             break;
         }
 
-    if( Verbose() >= 3 ) {
+    if( props.verbose >= 3 ) {
         cout << "Resulting regiongraph: " << *this << endl;
     }
 }
@@ -383,14 +387,12 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
 
 string TreeEP::identify() const { 
     stringstream result (stringstream::out);
-    result << Name << GetProperties();
+    result << Name << getProperties();
     return result.str();
 }
 
 
 void TreeEP::init() {
-    assert( checkProperties() );
-
     runHUGIN();
 
     // Init factor approximations
@@ -401,12 +403,12 @@ void TreeEP::init() {
 
 
 double TreeEP::run() {
-    if( Verbose() >= 1 )
+    if( props.verbose >= 1 )
         cout << "Starting " << identify() << "...";
-    if( Verbose() >= 3)
+    if( props.verbose >= 3)
         cout << endl;
 
-    clock_t tic = toc();
+    double tic = toc();
     Diffs diffs(nrVars(), 1.0);
 
     vector<Factor> old_beliefs;
@@ -418,7 +420,7 @@ double TreeEP::run() {
     
     // do several passes over the network until maximum number of iterations has
     // been reached or until the maximum belief difference is smaller than tolerance
-    for( iter=0; iter < MaxIter() && diffs.max() > Tol(); iter++ ) {
+    for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; iter++ ) {
         for( size_t I = 0; I < nrFactors(); I++ )
             if( offtree(I) ) {  
                 _Q[I].InvertAndMultiply( _Qa, _Qb );
@@ -433,39 +435,40 @@ double TreeEP::run() {
             old_beliefs[i] = nb;
         }
 
-        if( Verbose() >= 3 )
-            cout << "TreeEP::run:  maxdiff " << diffs.max() << " after " << iter+1 << " passes" << endl;
+        if( props.verbose >= 3 )
+            cout << "TreeEP::run:  maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
     }
 
-    updateMaxDiff( diffs.max() );
+    if( diffs.maxDiff() > maxdiff )
+        maxdiff = diffs.maxDiff();
 
-    if( Verbose() >= 1 ) {
-        if( diffs.max() > Tol() ) {
-            if( Verbose() == 1 )
+    if( props.verbose >= 1 ) {
+        if( diffs.maxDiff() > props.tol ) {
+            if( props.verbose == 1 )
                 cout << endl;
-            cout << "TreeEP::run:  WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.max() << endl;
+            cout << "TreeEP::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
         } else {
-            if( Verbose() >= 3 )
+            if( props.verbose >= 3 )
                 cout << "TreeEP::run:  ";
             cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
         }
     }
 
-    return diffs.max();
+    return diffs.maxDiff();
 }
 
 
-Complex TreeEP::logZ() const {
+Real TreeEP::logZ() const {
     double sum = 0.0;
 
     // entropy of the tree
-    for( size_t beta = 0; beta < nr_IRs(); beta++ )
-        sum -= real(_Qb[beta].entropy());
-    for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
-        sum += real(_Qa[alpha].entropy());
+    for( size_t beta = 0; beta < nrIRs(); beta++ )
+        sum -= _Qb[beta].entropy();
+    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
+        sum += _Qa[alpha].entropy();
 
     // energy of the on-tree factors
-    for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
+    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
         sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
 
     // energy of the off-tree factors