Replaced sub_nb class in mr.h by boost::dynamic_bitset
[libdai.git] / src / jtree.cpp
index b846d35..0f28d5b 100644 (file)
@@ -32,40 +32,56 @@ using namespace std;
 const char *JTree::Name = "JTREE";
 
 
-bool JTree::checkProperties() {
-    if (!HasProperty("verbose") )
-        return false;
-    if( !HasProperty("updates") )
-        return false;
+void JTree::setProperties( const PropertySet &opts ) {
+    assert( opts.hasKey("verbose") );
+    assert( opts.hasKey("updates") );
     
-    ConvertPropertyTo<size_t>("verbose");
-    ConvertPropertyTo<UpdateType>("updates");
+    props.verbose = opts.getStringAs<size_t>("verbose");
+    props.updates = opts.getStringAs<Properties::UpdateType>("updates");
+}
+
+
+PropertySet JTree::getProperties() const {
+    PropertySet opts;
+    opts.Set( "verbose", props.verbose );
+    opts.Set( "updates", props.updates );
+    return opts;
+}
+
 
-    return true;
+string JTree::printProperties() const {
+    stringstream s( stringstream::out );
+    s << "[";
+    s << "verbose=" << props.verbose << ",";
+    s << "updates=" << props.updates << "]";
+    return s.str();
 }
 
 
-JTree::JTree( const FactorGraph &fg, const Properties &opts, bool automatic ) : DAIAlgRG(fg, opts), _RTree(), _Qa(), _Qb(), _mes(), _logZ() {
-    assert( checkProperties() );
+JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(fg), _RTree(), _Qa(), _Qb(), _mes(), _logZ(), props() {
+    setProperties( opts );
+
+    if( !isConnected() ) 
+       DAI_THROW(FACTORGRAPH_NOT_CONNECTED); 
 
     if( automatic ) {
-        // Copy VarSets of factors
+        // Create ClusterGraph which contains factors as clusters
         vector<VarSet> cl;
         cl.reserve( fg.nrFactors() );
         for( size_t I = 0; I < nrFactors(); I++ )
             cl.push_back( factor(I).vars() );
         ClusterGraph _cg( cl );
 
-        if( Verbose() >= 3 )
+        if( props.verbose >= 3 )
             cout << "Initial clusters: " << _cg << endl;
 
         // Retain only maximal clusters
         _cg.eraseNonMaximal();
-        if( Verbose() >= 3 )
+        if( props.verbose >= 3 )
             cout << "Maximal clusters: " << _cg << endl;
 
         vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
-        if( Verbose() >= 3 )
+        if( props.verbose >= 3 )
             cout << "VarElim_MinFill result: " << ElimVec << endl;
 
         GenerateJT( ElimVec );
@@ -81,7 +97,8 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
     for( size_t i = 0; i < Cliques.size(); i++ )
         for( size_t j = i+1; j < Cliques.size(); j++ ) {
             size_t w = (Cliques[i] & Cliques[j]).size();
-            JuncGraph[UEdge(i,j)] = w;
+            if( w ) 
+                JuncGraph[UEdge(i,j)] = w;
         }
     
     // Construct maximal spanning tree using Prim's algorithm
@@ -100,7 +117,6 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
         size_t alpha;
         for( alpha = 0; alpha < nrORs(); alpha++ )
             if( OR(alpha).vars() >> factor(I).vars() ) {
-//              OR(alpha) *= factor(I);
                 fac2OR.push_back( alpha );
                 break;
             }
@@ -120,7 +136,7 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
     }
 
     // create bipartite graph
-    G.create( nrORs(), nrIRs(), edges.begin(), edges.end() );
+    G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
 
     // Create messages and beliefs
     _Qa.clear();
@@ -145,16 +161,14 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
     // Check counting numbers
     Check_Counting_Numbers();
 
-    if( Verbose() >= 3 ) {
+    if( props.verbose >= 3 ) {
         cout << "Resulting regiongraph: " << *this << endl;
     }
 }
 
 
 string JTree::identify() const {
-    stringstream result (stringstream::out);
-    result << Name << GetProperties();
-    return result.str();
+    return string(Name) + printProperties();
 }
 
 
@@ -204,15 +218,15 @@ void JTree::runHUGIN() {
     for( size_t i = _RTree.size(); (i--) != 0; ) {
 //      Make outer region _RTree[i].n1 consistent with outer region _RTree[i].n2
 //      IR(i) = seperator OR(_RTree[i].n1) && OR(_RTree[i].n2)
-        Factor new_Qb = _Qa[_RTree[i].n2].part_sum( IR( i ) );
-        _logZ += log(new_Qb.normalize( Prob::NORMPROB ));
+        Factor new_Qb = _Qa[_RTree[i].n2].partSum( IR( i ) );
+        _logZ += log(new_Qb.normalize());
         _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] ); 
         _Qb[i] = new_Qb;
     }
     if( _RTree.empty() )
-        _logZ += log(_Qa[0].normalize( Prob::NORMPROB ) );
+        _logZ += log(_Qa[0].normalize() );
     else
-        _logZ += log(_Qa[_RTree[0].n1].normalize( Prob::NORMPROB ));
+        _logZ += log(_Qa[_RTree[0].n1].normalize());
 
     // DistributeEvidence
     for( size_t i = 0; i < _RTree.size(); i++ ) {
@@ -225,7 +239,7 @@ void JTree::runHUGIN() {
 
     // Normalize
     for( size_t alpha = 0; alpha < nrORs(); alpha++ )
-        _Qa[alpha].normalize( Prob::NORMPROB );
+        _Qa[alpha].normalize();
 }
 
 
@@ -245,8 +259,8 @@ void JTree::runShaferShenoy() {
         foreach( const Neighbor &k, nbOR(i) )
             if( k != e ) 
                 piet *= message( i, k.iter );
-        message( j, _e ) = piet.part_sum( IR(e) );
-        _logZ += log( message(j,_e).normalize( Prob::NORMPROB ) );
+        message( j, _e ) = piet.partSum( IR(e) );
+        _logZ += log( message(j,_e).normalize() );
     }
 
     // Second pass
@@ -257,7 +271,7 @@ void JTree::runShaferShenoy() {
         
         Factor piet = OR(i);
         foreach( const Neighbor &k, nbOR(i) )
-            if(  k != e )
+            if( k != e )
                 piet *= message( i, k.iter );
         message( j, _e ) = piet.marginal( IR(e) );
     }
@@ -268,13 +282,13 @@ void JTree::runShaferShenoy() {
         foreach( const Neighbor &k, nbOR(alpha) )
             piet *= message( alpha, k.iter );
         if( nrIRs() == 0 ) {
-            _logZ += log( piet.normalize( Prob::NORMPROB ) );
+            _logZ += log( piet.normalize() );
             _Qa[alpha] = piet;
         } else if( alpha == nbIR(0)[0].node /*_RTree[0].n1*/ ) {
-            _logZ += log( piet.normalize( Prob::NORMPROB ) );
+            _logZ += log( piet.normalize() );
             _Qa[alpha] = piet;
         } else
-            _Qa[alpha] = piet.normalized( Prob::NORMPROB );
+            _Qa[alpha] = piet.normalized();
     }
 
     // Only for logZ (and for belief)...
@@ -284,20 +298,20 @@ void JTree::runShaferShenoy() {
 
 
 double JTree::run() {
-    if( Updates() == UpdateType::HUGIN )
+    if( props.updates == Properties::UpdateType::HUGIN )
         runHUGIN();
-    else if( Updates() == UpdateType::SHSH )
+    else if( props.updates == Properties::UpdateType::SHSH )
         runShaferShenoy();
     return 0.0;
 }
 
 
-Complex JTree::logZ() const {
-    Complex sum = 0.0;
+Real JTree::logZ() const {
+    Real sum = 0.0;
     for( size_t beta = 0; beta < nrIRs(); beta++ )
-        sum += Complex(IR(beta).c()) * _Qb[beta].entropy();
+        sum += IR(beta).c() * _Qb[beta].entropy();
     for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
-        sum += Complex(OR(alpha).c()) * _Qa[alpha].entropy();
+        sum += OR(alpha).c() * _Qa[alpha].entropy();
         sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
     }
     return sum;
@@ -316,29 +330,21 @@ size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t Previo
         }
     }
 
-//  for( size_t e = 0; e < _RTree.size(); e++ )
-//      cout << OR(_RTree[e].n1).vars() << "->" << OR(_RTree[e].n2).vars() << ",  ";
-//  cout << endl;
     // grow new tree
     Graph oldTree;
     for( DEdgeVec::const_iterator e = _RTree.begin(); e != _RTree.end(); e++ )
         oldTree.insert( UEdge(e->n1, e->n2) );
     DEdgeVec newTree = GrowRootedTree( oldTree, maxalpha );
-//  cout << ns << ": ";
-//  for( size_t e = 0; e < newTree.size(); e++ )
-//      cout << OR(newTree[e].n1).vars() << "->" << OR(newTree[e].n2).vars() << ",  ";
-//  cout << endl;
     
     // identify subtree that contains variables of ns which are not in the new root
     VarSet nsrem = ns / OR(maxalpha).vars();
-//  cout << "nsrem:" << nsrem << endl;
     set<DEdge> subTree;
     // for each variable in ns that is not in the root clique
     for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ ) {
         // find first occurence of *n in the tree, which is closest to the root
         size_t e = 0;
         for( ; e != newTree.size(); e++ ) {
-            if( OR(newTree[e].n2).vars() && *n )
+            if( OR(newTree[e].n2).vars().contains( *n ) )
                 break;
         }
         assert( e != newTree.size() );
@@ -370,10 +376,6 @@ size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t Previo
                 pos = newTree[e-1].n1;
             }
     }
-//  cout << "subTree: " << endl;
-//  for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
-//      cout << OR(sTi->n1).vars() << "->" << OR(sTi->n2).vars() << ",  ";
-//  cout << endl;
 
     // Resulting Tree is a reordered copy of newTree
     // First add edges in subTree to Tree
@@ -381,9 +383,7 @@ size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t Previo
     for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
         if( subTree.count( *e ) ) {
             Tree.push_back( *e );
-//          cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ",  ";
         }
-//  cout << endl;
     // Then add edges pointing away from nsrem
     // FIXME
 /*  for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
@@ -392,7 +392,6 @@ size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t Previo
                 if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
                     e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
                     Tree.push_back( *e );
-//                  cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ",  ";
                 }
             }*/
     // FIXME
@@ -406,10 +405,8 @@ size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t Previo
                 }
             if( found ) {
                 Tree.push_back( *e );
-                cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ",  ";
             }
-        }
-    cout << endl;*/
+        }*/
     size_t subTreeSize = Tree.size();
     // Then add remaining edges
     for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
@@ -469,7 +466,6 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
                 
             // For all states of nsrem
             for( State s(nsrem); s.valid(); s++ ) {
-                
                 // CollectEvidence
                 double logZ = 0.0;
                 for( size_t i = Tsize; (i--) != 0; ) {
@@ -483,16 +479,16 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
                             _Qa[T[i].n2] *= piet; 
                         }
 
-                    Factor new_Qb = _Qa[T[i].n2].part_sum( IR( b[i] ) );
-                    logZ += log(new_Qb.normalize( Prob::NORMPROB ));
+                    Factor new_Qb = _Qa[T[i].n2].partSum( IR( b[i] ) );
+                    logZ += log(new_Qb.normalize());
                     _Qa[T[i].n1] *= new_Qb.divided_by( _Qb[b[i]] ); 
                     _Qb[b[i]] = new_Qb;
                 }
-                logZ += log(_Qa[T[0].n1].normalize( Prob::NORMPROB ));
+                logZ += log(_Qa[T[0].n1].normalize());
 
                 Factor piet( nsrem, 0.0 );
                 piet[s] = exp(logZ);
-                Pns += piet * _Qa[T[0].n1].part_sum( ns / nsrem );      // OPTIMIZE ME
+                Pns += piet * _Qa[T[0].n1].partSum( ns / nsrem );      // OPTIMIZE ME
 
                 // Restore clamped beliefs
                 for( map<size_t,Factor>::const_iterator alpha = _Qa_old.begin(); alpha != _Qa_old.end(); alpha++ )
@@ -501,10 +497,39 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
                     _Qb[beta->first] = beta->second;
             }
 
-            return( Pns.normalized(Prob::NORMPROB) );
+            return( Pns.normalized() );
         }
     }
 }
 
 
+// first return value is treewidth
+// second return value is number of states in largest clique
+pair<size_t,size_t> treewidth( const FactorGraph & fg ) {
+    ClusterGraph _cg;
+
+    // Copy factors
+    for( size_t I = 0; I < fg.nrFactors(); I++ )
+        _cg.insert( fg.factor(I).vars() );
+
+    // Retain only maximal clusters
+    _cg.eraseNonMaximal();
+
+    // Obtain elimination sequence
+    vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
+
+    // Calculate treewidth
+    size_t treewidth = 0;
+    size_t nrstates = 0;
+    for( size_t i = 0; i < ElimVec.size(); i++ ) {
+        if( ElimVec[i].size() > treewidth )
+            treewidth = ElimVec[i].size();
+        if( ElimVec[i].states() > nrstates )
+            nrstates = ElimVec[i].states();
+    }
+
+    return pair<size_t,size_t>(treewidth, nrstates);
+}
+
+
 } // end of namespace dai