Removed stuff from InfAlg, moved it to individual inference algorithms
[libdai.git] / src / treeep.cpp
index f076aa8..9ce7354 100644 (file)
@@ -37,22 +37,26 @@ using namespace std;
 const char *TreeEP::Name = "TREEEP";
 
 
-bool TreeEP::checkProperties() {
-    if( !HasProperty("type") )
-        return false;
-    if( !HasProperty("tol") )
-        return false;
-    if (!HasProperty("maxiter") )
-        return false;
-    if (!HasProperty("verbose") )
-        return false;
+void TreeEP::setProperties( const PropertySet &opts ) {
+    assert( opts.hasKey("tol") );
+    assert( opts.hasKey("maxiter") );
+    assert( opts.hasKey("verbose") );
+    assert( opts.hasKey("type") );
     
-    ConvertPropertyTo<TypeType>("type");
-    ConvertPropertyTo<double>("tol");
-    ConvertPropertyTo<size_t>("maxiter");
-    ConvertPropertyTo<size_t>("verbose");
+    props.tol = opts.getStringAs<double>("tol");
+    props.maxiter = opts.getStringAs<size_t>("maxiter");
+    props.verbose = opts.getStringAs<size_t>("verbose");
+    props.type = opts.getStringAs<Properties::TypeType>("type");
+}
+
 
-    return true;
+PropertySet TreeEP::getProperties() const {
+    PropertySet opts;
+    opts.Set( "tol", props.tol );
+    opts.Set( "maxiter", props.maxiter );
+    opts.Set( "verbose", props.verbose );
+    opts.Set( "type", props.type );
+    return opts;
 }
 
 
@@ -182,15 +186,15 @@ double TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Fac
 }
 
 
-TreeEP::TreeEP( const FactorGraph &fg, const Properties &opts ) : JTree(fg, opts("updates",string("HUGIN")), false) {
-    assert( checkProperties() );
+TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), props(), maxdiff(0.0) {
+    setProperties( opts );
 
     assert( fg.G.isConnected() );
 
     if( opts.hasKey("tree") ) {
         ConstructRG( opts.GetAs<DEdgeVec>("tree") );
     } else {
-        if( Type() == TypeType::ORG ) {
+        if( props.type == Properties::TypeType::ORG ) {
             // construct weighted graph with as weights a crude estimate of the
             // mutual information between the nodes
             WeightedGraph<double> wg;
@@ -217,15 +221,15 @@ TreeEP::TreeEP( const FactorGraph &fg, const Properties &opts ) : JTree(fg, opts
             }
 
             // find maximal spanning tree
-            ConstructRG( MaxSpanningTreePrim( wg ) );
+            ConstructRG( MaxSpanningTreePrims( wg ) );
 
 //            cout << "Constructing maximum spanning tree..." << endl;
-//            DEdgeVec MST = MaxSpanningTreePrim( wg );
+//            DEdgeVec MST = MaxSpanningTreePrims( wg );
 //            cout << "Maximum spanning tree:" << endl;
 //            for( DEdgeVec::const_iterator e = MST.begin(); e != MST.end(); e++ )
 //                cout << *e << endl; 
 //            ConstructRG( MST );
-        } else if( Type() == TypeType::ALT ) {
+        } else if( props.type == Properties::TypeType::ALT ) {
             // construct weighted graph with as weights an upper bound on the
             // effective interaction strength between pairs of nodes
             WeightedGraph<double> wg;
@@ -245,7 +249,7 @@ TreeEP::TreeEP( const FactorGraph &fg, const Properties &opts ) : JTree(fg, opts
             }
 
             // find maximal spanning tree
-            ConstructRG( MaxSpanningTreePrim( wg ) );
+            ConstructRG( MaxSpanningTreePrims( wg ) );
         } else {
             assert( 0 == 1 );
         }
@@ -269,7 +273,7 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
         }
     
     // Construct maximal spanning tree using Prim's algorithm
-    _RTree = MaxSpanningTreePrim( JuncGraph );
+    _RTree = MaxSpanningTreePrims( JuncGraph );
 
     // Construct corresponding region graph
 
@@ -297,7 +301,6 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
 
     // Create inner regions and edges
     IRs.reserve( _RTree.size() );
-    typedef pair<size_t,size_t> Edge;
     vector<Edge> edges;
     edges.reserve( 2 * _RTree.size() );
     for( size_t i = 0; i < _RTree.size(); i++ ) {
@@ -376,7 +379,7 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
             break;
         }
 
-    if( Verbose() >= 3 ) {
+    if( props.verbose >= 3 ) {
         cout << "Resulting regiongraph: " << *this << endl;
     }
 }
@@ -384,14 +387,12 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
 
 string TreeEP::identify() const { 
     stringstream result (stringstream::out);
-    result << Name << GetProperties();
+    result << Name << getProperties();
     return result.str();
 }
 
 
 void TreeEP::init() {
-    assert( checkProperties() );
-
     runHUGIN();
 
     // Init factor approximations
@@ -402,9 +403,9 @@ void TreeEP::init() {
 
 
 double TreeEP::run() {
-    if( Verbose() >= 1 )
+    if( props.verbose >= 1 )
         cout << "Starting " << identify() << "...";
-    if( Verbose() >= 3)
+    if( props.verbose >= 3)
         cout << endl;
 
     double tic = toc();
@@ -419,7 +420,7 @@ double TreeEP::run() {
     
     // do several passes over the network until maximum number of iterations has
     // been reached or until the maximum belief difference is smaller than tolerance
-    for( iter=0; iter < MaxIter() && diffs.maxDiff() > Tol(); iter++ ) {
+    for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; iter++ ) {
         for( size_t I = 0; I < nrFactors(); I++ )
             if( offtree(I) ) {  
                 _Q[I].InvertAndMultiply( _Qa, _Qb );
@@ -434,19 +435,20 @@ double TreeEP::run() {
             old_beliefs[i] = nb;
         }
 
-        if( Verbose() >= 3 )
+        if( props.verbose >= 3 )
             cout << "TreeEP::run:  maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
     }
 
-    updateMaxDiff( diffs.maxDiff() );
+    if( diffs.maxDiff() > maxdiff )
+        maxdiff = diffs.maxDiff();
 
-    if( Verbose() >= 1 ) {
-        if( diffs.maxDiff() > Tol() ) {
-            if( Verbose() == 1 )
+    if( props.verbose >= 1 ) {
+        if( diffs.maxDiff() > props.tol ) {
+            if( props.verbose == 1 )
                 cout << endl;
-            cout << "TreeEP::run:  WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
+            cout << "TreeEP::run:  WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
         } else {
-            if( Verbose() >= 3 )
+            if( props.verbose >= 3 )
                 cout << "TreeEP::run:  ";
             cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
         }