Improved documentation of include/dai/jtree.h and did some cleanups
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Thu, 5 Nov 2009 16:20:44 +0000 (17:20 +0100)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Thu, 5 Nov 2009 16:20:44 +0000 (17:20 +0100)
12 files changed:
ChangeLog
OBSOLETE
TODO
include/dai/exactinf.h
include/dai/jtree.h
include/dai/treeep.h
include/dai/weightedgraph.h
src/exactinf.cpp
src/jtree.cpp
src/treeep.cpp
src/weightedgraph.cpp
utils/createfg.cpp

index 19ec573..7281092 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,8 @@
+* Renamed Treewidth(const FactorGraph &) into boundTreewidth(const FactorGraph &)
+* Added ExactInf::calcMarginal(const VarSet &)
+* Replaced UEdgeVec by Graph
+* Replaced DEdgeVec by new RootedTree class
+* Moved functionality of GrowRootedTree() into constructor of RootedTree
 * Extended InfAlg interface with setProperties(), getProperties() and printProperties()
 * Declared class Diffs as obsolete
 * Declared calcPairBeliefsNew() as obsolete (its functionality is now part
index ce6c62e..010da58 100644 (file)
--- a/OBSOLETE
+++ b/OBSOLETE
@@ -28,3 +28,7 @@ void RegionGraph::Calc_Counting_Numbers();
 std::vector<Factor> calcPairBeliefsNew( const InfAlg& obj, const VarSet& vs, bool reInit );
 Factor calcMarginal2ndO( const InfAlg& obj, const VarSet& vs, bool reInit );
 class Diffs;
+std::pair<size_t,size_t> Treewidth( const FactorGraph & fg );
+typedef UEdgeVec;
+typedef DEdgeVec;
+DEdgeVec GrowRootedTree( const Graph &T, size_t Root );
diff --git a/TODO b/TODO
index 86e6181..055f118 100644 (file)
--- a/TODO
+++ b/TODO
@@ -2,7 +2,6 @@ To do for the next release (0.2.3):
 
 Improve documentation:
 
-       jtree.h
        hak.h
        treeep.h
        lc.h
@@ -15,6 +14,7 @@ Improve documentation:
        emalg.h
        doc.h
                - add http to reference about maximum-residual BP
+               - merge COPYING
 
 Write a concept/notations page for the documentation,
 explaining the concepts of "state" (index into a 
index 48f15e4..ed1512f 100644 (file)
@@ -84,6 +84,15 @@ class ExactInf : public DAIAlgFG {
         virtual std::string printProperties() const;
     //@}
 
+    /// \name Additional interface specific for JTree
+    //@{
+        /// Calculates marginal probability distribution for variables \a vs
+        /** \note The complexity of this calculation is exponential in the number of variables.
+         */
+        Factor calcMarginal( const VarSet &vs ) const;
+    //@}
+
+
     private:
         /// Helper function for constructors
         void construct();
index 799cfe4..a67988a 100644 (file)
@@ -10,8 +10,7 @@
 
 
 /// \file
-/// \brief Defines class JTree
-/// \todo Improve documentation
+/// \brief Defines class JTree, which implements the junction tree algorithm
 
 
 #ifndef __defined_libdai_jtree_h
@@ -34,20 +33,31 @@ namespace dai {
 
 
 /// Exact inference algorithm using junction tree
+/** The junction tree algorithm uses message passing on a junction tree to calculate
+ *  exact marginal probability distributions ("beliefs") for specified cliques
+ *  (outer regions) and separators (intersections of pairs of cliques).
+ *
+ *  There are two variants, the sum-product algorithm (corresponding to 
+ *  finite temperature) and the max-product algorithm (corresponding to 
+ *  zero temperature).
+ */
 class JTree : public DAIAlgRG {
     private:
+        /// Stores the messages
         std::vector<std::vector<Factor> >  _mes;
+
+        /// Stores the logarithm of the partition sum
         Real _logZ;
 
     public:
-        /// Rooted tree
-        DEdgeVec             RTree;
+        /// The junction tree (stored as a rooted tree)
+        RootedTree RTree;
 
         /// Outer region beliefs
-        std::vector<Factor>  Qa;
+        std::vector<Factor> Qa;
 
         /// Inner region beliefs
-        std::vector<Factor>  Qb;
+        std::vector<Factor> Qb;
 
         /// Parameters of this inference algorithm
         struct Properties {
@@ -60,10 +70,10 @@ class JTree : public DAIAlgRG {
             /// Verbosity
             size_t verbose;
 
-            /// Type of updates
+            /// Type of updates: HUGIN or Shafer-Shenoy
             UpdateType updates;
 
-            /// Type of inference: sum-product or max-product?
+            /// Type of inference: sum-product or max-product
             InfType inference;
         } props;
 
@@ -71,11 +81,18 @@ class JTree : public DAIAlgRG {
         static const char *Name;
 
     public:
+    /// \name Constructors/destructors
+    //@{
         /// Default constructor
         JTree() : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {}
 
-        /// Construct from FactorGraph fg and PropertySet opts
+        /// Construct from FactorGraph \a fg and PropertySet \a opts
+        /** \param fg factor graph (which has to be connected);
+         *  \param opts parameters;
+         *  \param automatic if \c true, construct the junction tree automatically, using the MinFill heuristic.
+         */
         JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic=true );
+    //@}
 
 
     /// \name General InfAlg interface
@@ -99,37 +116,67 @@ class JTree : public DAIAlgRG {
 
     /// \name Additional interface specific for JTree
     //@{
-        void GenerateJT( const std::vector<VarSet> &Cliques );
+        /// Constructs a junction tree based on the cliques \a cl (corresponding to some elimination sequence).
+        /** First, constructs a weighted graph, where the nodes are the elements of \a cl, and 
+         *  each edge is weighted with the cardinality of the intersection of the state spaces of the nodes. 
+         *  Then, a maximal spanning tree for this weighted graph is calculated.
+         *  Finally, a corresponding region graph is built:
+         *    - the outer regions correspond with the cliques and have counting number 1;
+         *    - the inner regions correspond with the seperators, i.e., the intersections of two 
+         *      cliques that are neighbors in the spanning tree, and have counting number -1;
+         *    - inner and outer regions are connected by an edge if the inner region is a
+         *      seperator for the outer region.
+         */
+        void GenerateJT( const std::vector<VarSet> &cl );
 
-        /// Returns reference the message from outer region alpha to its _beta'th neighboring inner region
-        Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; }
-        /// Returns const reference to the message from outer region alpha to its _beta'th neighboring inner region
+        /// Returns constant reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
         const Factor & message( size_t alpha, size_t _beta ) const { return _mes[alpha][_beta]; }
+        /// Returns reference to the message from outer region \a alpha to its \a _beta 'th neighboring inner region
+        Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; }
 
-        /// Runs junction-tree with HUGIN updates
+        /// Runs junction tree algorithm using HUGIN updates
+        /** \note The initial messages may be arbitrary.
+         */
         void runHUGIN();
 
-        /// Runs junction-tree with Shafer-Shenoy updates
+        /// Runs junction tree algorithm using Shafer-Shenoy updates
+        /** \note The initial messages may be arbitrary.
+         */
         void runShaferShenoy();
 
-        /// Finds an efficient tree for calculating the marginal of some variables
-        size_t findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot=(size_t)-1 ) const;
+        /// Finds an efficient subtree for calculating the marginal of the variables in \a vs
+        /** First, the current junction tree is reordered such that it gets as root the clique 
+         *  that has maximal state space overlap with \a vs. Then, the minimal subtree
+         *  (starting from the root) is identified that contains all the variables in \a vs
+         *  and also the outer region with index \a PreviousRoot (if specified). Finally,
+         *  the current junction tree is reordered such that this minimal subtree comes
+         *  before the other edges, and the size of the minimal subtree is returned.
+         */
+        size_t findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot=(size_t)-1 ) const;
 
-        /// Calculates the marginal of a set of variables
-        Factor calcMarginal( const VarSet& ns );
+        /// Calculates the marginal of a set of variables (using cutset conditioning, if necessary)
+        /** \pre assumes that run() has been called already
+         */
+        Factor calcMarginal( const VarSet& vs );
 
         /// Calculates the joint state of all variables that has maximum probability
-        /** Assumes that run() has been called and that props.inference == MAXPROD
+        /** \pre Assumes that run() has been called and that \a props.inference == \c MAXPROD
          */
         std::vector<std::size_t> findMaximum() const;
     //@}
 };
 
 
-/// Calculates upper bound to the treewidth of a FactorGraph
+/// Calculates upper bound to the treewidth of a FactorGraph, using the MinFill heuristic
 /** \relates JTree
  *  \return a pair (number of variables in largest clique, number of states in largest clique)
  */
+std::pair<size_t,size_t> boundTreewidth( const FactorGraph & fg );
+
+
+/// Calculates upper bound to the treewidth of a FactorGraph, using the MinFill heuristic
+/** \deprecated Renamed into boundTreewidth()
+ */
 std::pair<size_t,size_t> treewidth( const FactorGraph & fg );
 
 
index 4f56094..40050a0 100644 (file)
@@ -69,7 +69,7 @@ class TreeEP : public JTree {
             private:
                 std::vector<Factor>  _Qa;
                 std::vector<Factor>  _Qb;
-                DEdgeVec             _RTree;
+                RootedTree           _RTree;
                 std::vector<size_t>  _a;        // _Qa[alpha]  <->  superTree.Qa[_a[alpha]]
                 std::vector<size_t>  _b;        // _Qb[beta]   <->  superTree.Qb[_b[beta]]
                                                 // _Qb[beta]   <->  _RTree[beta]
@@ -97,7 +97,7 @@ class TreeEP : public JTree {
                     return *this;
                 }
 
-                TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I );
+                TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I );
                 void init();
                 void InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb );
                 void HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb );
@@ -154,7 +154,7 @@ class TreeEP : public JTree {
 
 
     private:
-        void ConstructRG( const DEdgeVec &tree );
+        void ConstructRG( const RootedTree &tree );
         bool offtree( size_t I ) const { return (fac2OR[I] == -1U); }
 };
 
index 96a3830..36c98fe 100644 (file)
@@ -10,7 +10,7 @@
 
 
 /** \file
- *  \brief Defines some utility functions for (weighted) undirected graphs
+ *  \brief Defines some utility functions for (weighted) undirected graphs, trees and rooted trees.
  *  \todo Improve general support for graphs and trees.
  */
 
@@ -36,9 +36,10 @@ namespace dai {
 /// Represents a directed edge
 class DEdge {
     public:
-        /// First node index
+        /// First node index (source of edge)
         size_t n1;
-        /// Second node index
+
+        /// Second node index (sink of edge)
         size_t n2;
 
         /// Default constructor
@@ -110,30 +111,64 @@ class UEdge {
 };
 
 
-/// Vector of UEdge
-typedef std::vector<UEdge>  UEdgeVec;
+/// Represents an undirected graph, implemented as a std::set of undirected edges
+class Graph : public std::set<UEdge> {
+    public:
+        /// Default constructor
+        Graph() {}
 
-/// Vector of DEdge
-typedef std::vector<DEdge>  DEdgeVec;
+        /// Construct from range of objects that can be cast to DEdge
+        template <class InputIterator>
+        Graph( InputIterator begin, InputIterator end ) {
+            insert( begin, end );
+        }
+};
 
-/// Represents an undirected weighted graph, with weights of type \a T
+
+/// Represents an undirected weighted graph, with weights of type \a T, implemented as a std::map mapping undirected edges to weights
 template<class T> class WeightedGraph : public std::map<UEdge, T> {};
 
-/// Represents an undirected graph
-typedef std::set<UEdge>     Graph;
+
+/// Represents a rooted tree, implemented as a vector of directed edges
+/** By convention, the edges are stored such that they point away from 
+ *  the root and such that edges nearer to the root come before edges
+ *  farther away from the root.
+ */
+class RootedTree : public std::vector<DEdge> {
+    public:
+        /// Default constructor
+        RootedTree() {}
+
+        /// Constructs a rooted tree from a tree and a root
+        /** \pre T has no cycles and contains node \a Root
+         */
+        RootedTree( const Graph &T, size_t Root );
+};
 
 
+// OBSOLETE
+/// Vector of UEdge
+/** \deprecated Please use Graph instead
+ */
+typedef std::vector<UEdge>  UEdgeVec;
+// OBSOLETE
+/// Vector of DEdge
+/** \deprecated Please use RootedTree instead
+ */
+typedef std::vector<DEdge>  DEdgeVec;
+
+// OBSOLETE
 /// Constructs a rooted tree from a tree and a root
 /** \pre T has no cycles and contains node \a Root
+ *  \deprecated Please use RootedTree::RootedTree(const Graph &, size_t) instead
  */
 DEdgeVec GrowRootedTree( const Graph &T, size_t Root );
 
-
 /// Uses Prim's algorithm to construct a minimal spanning tree from the (positively) weighted graph \a G.
 /** Uses implementation in Boost Graph Library.
  */
-template<typename T> DEdgeVec MinSpanningTreePrims( const WeightedGraph<T> &G ) {
-    DEdgeVec result;
+template<typename T> RootedTree MinSpanningTreePrims( const WeightedGraph<T> &G ) {
+    RootedTree result;
     if( G.size() > 0 ) {
         using namespace boost;
         using namespace std;
@@ -164,20 +199,19 @@ template<typename T> DEdgeVec MinSpanningTreePrims( const WeightedGraph<T> &G )
                 tree.insert( UEdge( p[i], i ) );
             else
                 root = i;
-        // Order them
-        result = GrowRootedTree( tree, root );
+        // Order them to obtain a rooted tree
+        result = RootedTree( tree, root );
     }
-
     return result;
 }
 
 
-/// Use Prim's algorithm to construct a minimal spanning tree from the (positively) weighted graph \a G.
+/// Use Prim's algorithm to construct a maximal spanning tree from the (positively) weighted graph \a G.
 /** Uses implementation in Boost Graph Library.
  */
-template<typename T> DEdgeVec MaxSpanningTreePrims( const WeightedGraph<T> &G ) {
+template<typename T> RootedTree MaxSpanningTreePrims( const WeightedGraph<T> &G ) {
     if( G.size() == 0 )
-        return DEdgeVec();
+        return RootedTree();
     else {
         T maxweight = G.begin()->second;
         for( typename WeightedGraph<T>::const_iterator it = G.begin(); it != G.end(); it++ )
@@ -201,7 +235,7 @@ template<typename T> DEdgeVec MaxSpanningTreePrims( const WeightedGraph<T> &G )
  *  (which becomes uniform in the limit that \a d is small and \a N goes
  *  to infinity).
  */
-UEdgeVec RandomDRegularGraph( size_t N, size_t d );
+Graph RandomDRegularGraph( size_t N, size_t d );
 
 
 } // end of namespace dai
index 312b797..f4a60d3 100644 (file)
@@ -89,6 +89,14 @@ Real ExactInf::run() {
 }
 
 
+Factor ExactInf::calcMarginal( const VarSet &vs ) const {
+    Factor P;
+    for( size_t I = 0; I < nrFactors(); I++ )
+        P *= factor(I);
+    return P.marginal( vs, true );
+}
+
+
 vector<Factor> ExactInf::beliefs() const {
     vector<Factor> result = _beliefsV;
     result.insert( result.end(), _beliefsF.begin(), _beliefsF.end() );
index a938ee7..8d96691 100644 (file)
@@ -77,10 +77,12 @@ JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) :
         if( props.verbose >= 3 )
             cerr << "Maximal clusters: " << _cg << endl;
 
+        // Use MinFill heuristic to guess optimal elimination sequence
         vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
         if( props.verbose >= 3 )
             cerr << "VarElim_MinFill result: " << ElimVec << endl;
 
+        // Generate the junction tree corresponding to the elimination sequence
         GenerateJT( ElimVec );
     }
 }
@@ -156,11 +158,11 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
     }
 
     // Check counting numbers
-    checkCountingNumbers();
+    if( DAI_DEBUG )
+        checkCountingNumbers();
 
-    if( props.verbose >= 3 ) {
-        cerr << "Resulting regiongraph: " << *this << endl;
-    }
+    if( props.verbose >= 3 )
+        cerr << "Regiongraph generated by JTree::GenerateJT: " << *this << endl;
 }
 
 
@@ -169,20 +171,20 @@ string JTree::identify() const {
 }
 
 
-Factor JTree::belief( const VarSet &ns ) const {
+Factor JTree::belief( const VarSet &vs ) const {
     vector<Factor>::const_iterator beta;
     for( beta = Qb.begin(); beta != Qb.end(); beta++ )
-        if( beta->vars() >> ns )
+        if( beta->vars() >> vs )
             break;
     if( beta != Qb.end() )
-        return( beta->marginal(ns) );
+        return( beta->marginal(vs) );
     else {
         vector<Factor>::const_iterator alpha;
         for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
-            if( alpha->vars() >> ns )
+            if( alpha->vars() >> vs )
                 break;
         DAI_ASSERT( alpha != Qa.end() );
-        return( alpha->marginal(ns) );
+        return( alpha->marginal(vs) );
     }
 }
 
@@ -197,12 +199,11 @@ vector<Factor> JTree::beliefs() const {
 }
 
 
-Factor JTree::belief( const Var &n ) const {
-    return belief( (VarSet)n );
+Factor JTree::belief( const Var &v ) const {
+    return belief( (VarSet)v );
 }
 
 
-// Needs no init
 void JTree::runHUGIN() {
     for( size_t alpha = 0; alpha < nrORs(); alpha++ )
         Qa[alpha] = OR(alpha);
@@ -250,7 +251,6 @@ void JTree::runHUGIN() {
 }
 
 
-// Really needs no init! Initial messages can be anything.
 void JTree::runShaferShenoy() {
     // First pass
     _logZ = 0.0;
@@ -335,29 +335,25 @@ Real JTree::logZ() const {
 }
 
 
-
-size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot ) const {
-    // find new root clique (the one with maximal statespace overlap with ns)
+size_t JTree::findEfficientTree( const VarSet& vs, RootedTree &Tree, size_t PreviousRoot ) const {
+    // find new root clique (the one with maximal statespace overlap with vs)
     size_t maxval = 0, maxalpha = 0;
     for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
-        size_t val = VarSet(ns & OR(alpha).vars()).nrStates();
+        size_t val = VarSet(vs & OR(alpha).vars()).nrStates();
         if( val > maxval ) {
             maxval = val;
             maxalpha = alpha;
         }
     }
 
-    // grow new tree
-    Graph oldTree;
-    for( DEdgeVec::const_iterator e = RTree.begin(); e != RTree.end(); e++ )
-        oldTree.insert( UEdge(e->n1, e->n2) );
-    DEdgeVec newTree = GrowRootedTree( oldTree, maxalpha );
+    // reorder the tree edges such that maxalpha becomes the new root
+    RootedTree newTree( Graph( RTree.begin(), RTree.end() ), maxalpha );
 
-    // identify subtree that contains variables of ns which are not in the new root
-    VarSet nsrem = ns / OR(maxalpha).vars();
+    // identify subtree that contains all variables of vs which are not in the new root
+    VarSet vsrem = vs / OR(maxalpha).vars();
     set<DEdge> subTree;
-    // for each variable in ns that is not in the root clique
-    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ ) {
+    // for each variable in vs that is not in the root clique
+    for( VarSet::const_iterator n = vsrem.begin(); n != vsrem.end(); n++ ) {
         // find first occurence of *n in the tree, which is closest to the root
         size_t e = 0;
         for( ; e != newTree.size(); e++ ) {
@@ -397,67 +393,42 @@ size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t Previo
     // Resulting Tree is a reordered copy of newTree
     // First add edges in subTree to Tree
     Tree.clear();
-    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
-        if( subTree.count( *e ) ) {
+    vector<DEdge> remTree;
+    for( RootedTree::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
+        if( subTree.count( *e ) )
             Tree.push_back( *e );
-        }
-    // Then add edges pointing away from nsrem
-    // FIXME
-/*  for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
-        for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
-            if( *e != *sTi ) {
-                if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
-                    e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
-                    Tree.push_back( *e );
-                }
-            }*/
-    // FIXME
-/*  for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
-        if( find( Tree.begin(), Tree.end(), *e) == Tree.end() ) {
-            bool found = false;
-            for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
-                if( (OR(e->n1).vars() && *n) ) {
-                    found = true;
-                    break;
-                }
-            if( found ) {
-                Tree.push_back( *e );
-            }
-        }*/
+        else
+            remTree.push_back( *e );
     size_t subTreeSize = Tree.size();
     // Then add remaining edges
-    for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
-        if( find( Tree.begin(), Tree.end(), *e ) == Tree.end() )
-            Tree.push_back( *e );
+    copy( remTree.begin(), remTree.end(), back_inserter( Tree ) );
 
     return subTreeSize;
 }
 
 
-// Cutset conditioning
-// assumes that run() has been called already
-Factor JTree::calcMarginal( const VarSet& ns ) {
+Factor JTree::calcMarginal( const VarSet& vs ) {
     vector<Factor>::const_iterator beta;
     for( beta = Qb.begin(); beta != Qb.end(); beta++ )
-        if( beta->vars() >> ns )
+        if( beta->vars() >> vs )
             break;
     if( beta != Qb.end() )
-        return( beta->marginal(ns) );
+        return( beta->marginal(vs) );
     else {
         vector<Factor>::const_iterator alpha;
         for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
-            if( alpha->vars() >> ns )
+            if( alpha->vars() >> vs )
                 break;
         if( alpha != Qa.end() )
-            return( alpha->marginal(ns) );
+            return( alpha->marginal(vs) );
         else {
             // Find subtree to do efficient inference
-            DEdgeVec T;
-            size_t Tsize = findEfficientTree( ns, T );
+            RootedTree T;
+            size_t Tsize = findEfficientTree( vs, T );
 
             // Find remaining variables (which are not in the new root)
-            VarSet nsrem = ns / OR(T.front().n1).vars();
-            Factor Pns (ns, 0.0);
+            VarSet vsrem = vs / OR(T.front().n1).vars();
+            Factor Pvs (vs, 0.0);
 
             // Save Qa and Qb on the subtree
             map<size_t,Factor> Qa_old;
@@ -481,15 +452,15 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
                     Qb_old[beta] = Qb[beta];
             }
 
-            // For all states of nsrem
-            for( State s(nsrem); s.valid(); s++ ) {
+            // For all states of vsrem
+            for( State s(vsrem); s.valid(); s++ ) {
                 // CollectEvidence
                 Real logZ = 0.0;
                 for( size_t i = Tsize; (i--) != 0; ) {
                 // Make outer region T[i].n1 consistent with outer region T[i].n2
                 // IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)
 
-                    for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
+                    for( VarSet::const_iterator n = vsrem.begin(); n != vsrem.end(); n++ )
                         if( Qa[T[i].n2].vars() >> *n ) {
                             Factor piet( *n, 0.0 );
                             piet[s(*n)] = 1.0;
@@ -503,9 +474,9 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
                 }
                 logZ += log(Qa[T[0].n1].normalize());
 
-                Factor piet( nsrem, 0.0 );
+                Factor piet( vsrem, 0.0 );
                 piet[s] = exp(logZ);
-                Pns += piet * Qa[T[0].n1].marginal( ns / nsrem, false );      // OPTIMIZE ME
+                Pvs += piet * Qa[T[0].n1].marginal( vs / vsrem, false );      // OPTIMIZE ME
 
                 // Restore clamped beliefs
                 for( map<size_t,Factor>::const_iterator alpha = Qa_old.begin(); alpha != Qa_old.end(); alpha++ )
@@ -514,17 +485,13 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
                     Qb[beta->first] = beta->second;
             }
 
-            return( Pns.normalized() );
+            return( Pvs.normalized() );
         }
     }
 }
 
 
-/// Calculates upper bound to the treewidth of a FactorGraph
-/** \relates JTree
- *  \return a pair (number of variables in largest clique, number of states in largest clique)
- */
-std::pair<size_t,size_t> treewidth( const FactorGraph & fg ) {
+std::pair<size_t,size_t> boundTreewidth( const FactorGraph & fg ) {
     ClusterGraph _cg;
 
     // Copy factors
@@ -552,6 +519,12 @@ std::pair<size_t,size_t> treewidth( const FactorGraph & fg ) {
 }
 
 
+std::pair<size_t,size_t> treewidth( const FactorGraph & fg )
+{
+    return boundTreewidth( fg );
+}
+
+
 std::vector<size_t> JTree::findMaximum() const {
     vector<size_t> maximum( nrVars() );
     vector<bool> visitedVars( nrVars(), false );
index 989d37e..ed4e87c 100644 (file)
@@ -60,7 +60,7 @@ string TreeEP::printProperties() const {
 }
 
 
-TreeEP::TreeEPSubTree::TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
+TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree &subRTree, const RootedTree &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
     _ns = _I->vars();
 
     // Make _Qa, _Qb, _a and _b corresponding to the subtree
@@ -192,7 +192,7 @@ TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opt
     DAI_ASSERT( fg.isConnected() );
 
     if( opts.hasKey("tree") ) {
-        ConstructRG( opts.GetAs<DEdgeVec>("tree") );
+        ConstructRG( opts.GetAs<RootedTree>("tree") );
     } else {
         if( props.type == Properties::TypeType::ORG || props.type == Properties::TypeType::ALT ) {
             // ORG: construct weighted graph with as weights a crude estimate of the
@@ -241,7 +241,7 @@ TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opt
 }
 
 
-void TreeEP::ConstructRG( const DEdgeVec &tree ) {
+void TreeEP::ConstructRG( const RootedTree &tree ) {
     vector<VarSet> Cliques;
     for( size_t i = 0; i < tree.size(); i++ )
         Cliques.push_back( VarSet( var(tree[i].n1), var(tree[i].n2) ) );
@@ -321,7 +321,7 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
     for( size_t I = 0; I < nrFactors(); I++ )
         if( offtree(I) ) {
             // find efficient subtree
-            DEdgeVec subTree;
+            RootedTree subTree;
             /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
             PreviousRoot = subTree[0].n1;
             //subTree.resize( subTreeSize );  // FIXME
@@ -333,7 +333,7 @@ void TreeEP::ConstructRG( const DEdgeVec &tree ) {
     // Previous root of first off-tree factor should be the root of the last off-tree factor
     for( size_t I = 0; I < nrFactors(); I++ )
         if( offtree(I) ) {
-            DEdgeVec subTree;
+            RootedTree subTree;
             /*size_t subTreeSize =*/ findEfficientTree( factor(I).vars(), subTree, PreviousRoot );
             PreviousRoot = subTree[0].n1;
             //subTree.resize( subTreeSize ); // FIXME
index 94bb8c1..1a0ccfb 100644 (file)
@@ -21,51 +21,48 @@ namespace dai {
 using namespace std;
 
 
-DEdgeVec GrowRootedTree( const Graph &T, size_t Root ) {
-    DEdgeVec result;
-    if( T.size() == 0 )
-        return result;
-    else {
+RootedTree::RootedTree( const Graph &T, size_t Root ) {
+    if( T.size() != 0 ) {
         // Make a copy
         Graph Gr = T;
 
         // Nodes in the tree
-        set<size_t> treeV;
+        set<size_t> nodes;
 
         // Start with the root
-        treeV.insert( Root );
+        nodes.insert( Root );
 
         // Keep adding edges until done
         while( !(Gr.empty()) )
             for( Graph::iterator e = Gr.begin(); e != Gr.end(); ) {
-                bool e1_in_treeV = treeV.count( e->n1 );
-                bool e2_in_treeV = treeV.count( e->n2 );
-                DAI_ASSERT( !(e1_in_treeV && e2_in_treeV) );
-                if( e1_in_treeV ) {
+                bool e1_in_nodes = nodes.count( e->n1 );
+                bool e2_in_nodes = nodes.count( e->n2 );
+                DAI_ASSERT( !(e1_in_nodes && e2_in_nodes) );
+                if( e1_in_nodes ) {
                     // Add directed edge, pointing away from the root
-                    result.push_back( DEdge( e->n1, e->n2 ) );
-                    treeV.insert( e->n2 );
+                    push_back( DEdge( e->n1, e->n2 ) );
+                    nodes.insert( e->n2 );
                     // Erase the edge
                     Gr.erase( e++ );
-                } else if( e2_in_treeV ) {
-                    result.push_back( DEdge( e->n2, e->n1 ) );
-                    treeV.insert( e->n1 );
+                } else if( e2_in_nodes ) {
+                    // Add directed edge, pointing away from the root
+                    push_back( DEdge( e->n2, e->n1 ) );
+                    nodes.insert( e->n1 );
                     // Erase the edge
                     Gr.erase( e++ );
                 } else
                     e++;
             }
 
-        return result;
     }
 }
 
 
-UEdgeVec RandomDRegularGraph( size_t N, size_t d ) {
+Graph RandomDRegularGraph( size_t N, size_t d ) {
     DAI_ASSERT( (N * d) % 2 == 0 );
 
     bool ready = false;
-    UEdgeVec G;
+    std::vector<UEdge> G;
 
     size_t tries = 0;
     while( !ready ) {
@@ -107,9 +104,9 @@ UEdgeVec RandomDRegularGraph( size_t N, size_t d ) {
 
             vector<size_t> degrees;
             degrees.resize( N, 0 );
-            for( UEdgeVec::const_iterator e = G.begin(); e != G.end(); e++ ) {
-                degrees[e->n1]++;
-                degrees[e->n2]++;
+            foreach( const UEdge &e, G ) {
+                degrees[e.n1]++;
+                degrees[e.n2]++;
             }
             ready = true;
             for( size_t n = 0; n < N; n++ )
@@ -121,7 +118,7 @@ UEdgeVec RandomDRegularGraph( size_t N, size_t d ) {
             ready = false;
     }
 
-    return G;
+    return Graph( G.begin(), G.end() );
 }
 
 
index 5edf476..6ffbf42 100644 (file)
@@ -264,9 +264,9 @@ void MakeDRegFG( size_t N, size_t d, Real mean_w, Real mean_th, Real sigma_w, Re
     matrix w(N,N,(d*N)/2);
     vector<Real> th(N,0.0);
 
-    UEdgeVec g = RandomDRegularGraph( N, d );
-    for( size_t i = 0; i < g.size(); i++ )
-        w(g[i].n1, g[i].n2) = rnd_stdnormal() * sigma_w + mean_w;
+    Graph g = RandomDRegularGraph( N, d );
+    foreach( const UEdge &e, g )
+        w(e.n1, e.n2) = rnd_stdnormal() * sigma_w + mean_w;
 
     for( size_t i = 0; i < N; i++ )
         th[i] = rnd_stdnormal() * sigma_th + mean_th;