Removed stuff from InfAlg, moved it to individual inference algorithms
[libdai.git] / src / jtree.cpp
index 7fb6d4b..39c7cfb 100644 (file)
@@ -32,53 +32,52 @@ using namespace std;
 const char *JTree::Name = "JTREE";
 
 
-bool JTree::checkProperties() {
-    if (!HasProperty("verbose") )
-        return false;
-    if( !HasProperty("updates") )
-        return false;
+void JTree::setProperties( const PropertySet &opts ) {
+    assert( opts.hasKey("verbose") );
+    assert( opts.hasKey("updates") );
     
-    ConvertPropertyTo<size_t>("verbose");
-    ConvertPropertyTo<UpdateType>("updates");
+    props.verbose = opts.getStringAs<size_t>("verbose");
+    props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+}
+
 
-    return true;
+PropertySet JTree::getProperties() const {
+    PropertySet opts;
+    opts.Set( "verbose", props.verbose );
+    opts.Set( "updates", props.updates );
+    return opts;
 }
 
 
-JTree::JTree( const FactorGraph &fg, const Properties &opts, bool automatic ) : DAIAlgRG(fg, opts), _RTree(), _Qa(), _Qb(), _mes(), _logZ() {
-    assert( checkProperties() );
+JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(fg), _RTree(), _Qa(), _Qb(), _mes(), _logZ(), props() {
+    setProperties( opts );
 
     if( automatic ) {
-        ClusterGraph _cg;
-
-        // Copy factors
+        // Copy VarSets of factors
+        vector<VarSet> cl;
+        cl.reserve( fg.nrFactors() );
         for( size_t I = 0; I < nrFactors(); I++ )
-            _cg.insert( factor(I).vars() );
-        if( Verbose() >= 3 )
+            cl.push_back( factor(I).vars() );
+        ClusterGraph _cg( cl );
+
+        if( props.verbose >= 3 )
             cout << "Initial clusters: " << _cg << endl;
 
         // Retain only maximal clusters
         _cg.eraseNonMaximal();
-        if( Verbose() >= 3 )
+        if( props.verbose >= 3 )
             cout << "Maximal clusters: " << _cg << endl;
 
         vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
-        if( Verbose() >= 3 ) {
-            cout << "VarElim_MinFill result: {" << endl;
-            for( size_t i = 0; i < ElimVec.size(); i++ ) {
-                if( i != 0 )
-                    cout << ", ";
-                cout << ElimVec[i];
-            }
-            cout << "}" << endl;
-        }
+        if( props.verbose >= 3 )
+            cout << "VarElim_MinFill result: " << ElimVec << endl;
 
         GenerateJT( ElimVec );
     }
 }
 
 
-void JTree::GenerateJT( const vector<VarSet> &Cliques ) {
+void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
     // Construct a weighted graph (each edge is weighted with the cardinality 
     // of the intersection of the nodes, where the nodes are the elements of
     // Cliques).
@@ -90,7 +89,7 @@ void JTree::GenerateJT( const vector<VarSet> &Cliques ) {
         }
     
     // Construct maximal spanning tree using Prim's algorithm
-    _RTree = MaxSpanningTreePrim( JuncGraph );
+    _RTree = MaxSpanningTreePrims( JuncGraph );
 
     // Construct corresponding region graph
 
@@ -115,7 +114,6 @@ void JTree::GenerateJT( const vector<VarSet> &Cliques ) {
 
     // Create inner regions and edges
     IRs.reserve( _RTree.size() );
-    typedef pair<size_t,size_t> Edge;
     vector<Edge> edges;
     edges.reserve( 2 * _RTree.size() );
     for( size_t i = 0; i < _RTree.size(); i++ ) {
@@ -151,7 +149,7 @@ void JTree::GenerateJT( const vector<VarSet> &Cliques ) {
     // Check counting numbers
     Check_Counting_Numbers();
 
-    if( Verbose() >= 3 ) {
+    if( props.verbose >= 3 ) {
         cout << "Resulting regiongraph: " << *this << endl;
     }
 }
@@ -159,7 +157,7 @@ void JTree::GenerateJT( const vector<VarSet> &Cliques ) {
 
 string JTree::identify() const {
     stringstream result (stringstream::out);
-    result << Name << GetProperties();
+    result << Name << getProperties();
     return result.str();
 }
 
@@ -290,9 +288,9 @@ void JTree::runShaferShenoy() {
 
 
 double JTree::run() {
-    if( Updates() == UpdateType::HUGIN )
+    if( props.updates == Properties::UpdateType::HUGIN )
         runHUGIN();
-    else if( Updates() == UpdateType::SHSH )
+    else if( props.updates == Properties::UpdateType::SHSH )
         runShaferShenoy();
     return 0.0;
 }
@@ -315,7 +313,7 @@ size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t Previo
     // find new root clique (the one with maximal statespace overlap with ns)
     size_t maxval = 0, maxalpha = 0;
     for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
-        size_t val = (ns & OR(alpha).vars()).stateSpace();
+        size_t val = (ns & OR(alpha).vars()).states();
         if( val > maxval ) {
             maxval = val;
             maxalpha = alpha;
@@ -344,7 +342,7 @@ size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t Previo
         // find first occurence of *n in the tree, which is closest to the root
         size_t e = 0;
         for( ; e != newTree.size(); e++ ) {
-            if( OR(newTree[e].n2).vars() && *n )
+            if( OR(newTree[e].n2).vars().contains( *n ) )
                 break;
         }
         assert( e != newTree.size() );
@@ -451,8 +449,6 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
             VarSet nsrem = ns / OR(T.front().n1).vars();
             Factor Pns (ns, 0.0);
             
-            multind mi( nsrem );
-
             // Save _Qa and _Qb on the subtree
             map<size_t,Factor> _Qa_old;
             map<size_t,Factor> _Qb_old;
@@ -476,20 +472,18 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
             }
                 
             // For all states of nsrem
-            for( size_t j = 0; j < mi.max(); j++ ) {
-                vector<size_t> vi = mi.vi( j );
+            for( State s(nsrem); s.valid(); s++ ) {
                 
                 // CollectEvidence
                 double logZ = 0.0;
                 for( size_t i = Tsize; (i--) != 0; ) {
-            //      Make outer region T[i].n1 consistent with outer region T[i].n2
-            //      IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)
+                // Make outer region T[i].n1 consistent with outer region T[i].n2
+                // IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)
 
-                    size_t k = 0;
-                    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++, k++ )
+                    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
                         if( _Qa[T[i].n2].vars() >> *n ) {
                             Factor piet( *n, 0.0 );
-                            piet[vi[k]] = 1.0;
+                            piet[s(*n)] = 1.0;
                             _Qa[T[i].n2] *= piet; 
                         }
 
@@ -501,7 +495,7 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
                 logZ += log(_Qa[T[0].n1].normalize( Prob::NORMPROB ));
 
                 Factor piet( nsrem, 0.0 );
-                piet[j] = exp(logZ);
+                piet[s] = exp(logZ);
                 Pns += piet * _Qa[T[0].n1].part_sum( ns / nsrem );      // OPTIMIZE ME
 
                 // Restore clamped beliefs