Merged jtree.h and jtree.cpp from SVN head
authorJoris Mooij <jorism@marvin.jorismooij.nl>
Sat, 27 Sep 2008 15:57:28 +0000 (17:57 +0200)
committerJoris Mooij <jorism@marvin.jorismooij.nl>
Sat, 27 Sep 2008 15:57:28 +0000 (17:57 +0200)
STATUS
include/dai/jtree.h
src/hak.cpp
src/jtree.cpp
src/lc.cpp
src/mr.cpp
src/treeep.cpp

diff --git a/STATUS b/STATUS
index ffad5f0..66d4ced 100644 (file)
--- a/STATUS
+++ b/STATUS
@@ -99,6 +99,8 @@ mf.h
 mf.cpp
 bp.h
 bp.cpp
+jtree.h
+jtree.cpp
 
 FILES IN SVN HEAD THAT ARE STILL RELEVANT:
 ChangeLog
@@ -107,8 +109,6 @@ TODO
 
 hak.h
 hak.cpp
-jtree.h
-jtree.cpp
 lc.h
 lc.cpp
 mr.h
index 4cfb7e0..ccb733c 100644 (file)
@@ -52,16 +52,23 @@ class JTree : public DAIAlgRG {
             DAI_ENUM(UpdateType,HUGIN,SHSH)
             UpdateType updates;
         } props;
+        /// Name of this inference method
+        static const char *Name;
 
     public:
+        /// Default constructor
         JTree() : DAIAlgRG(), _RTree(), _Qa(), _Qb(), _mes(), _logZ(), props() {}
-        JTree( const JTree& x ) : DAIAlgRG(x), _RTree(x._RTree), _Qa(x._Qa), _Qb(x._Qb), _mes(x._mes), _logZ(x._logZ), props(x.props) {}
-        JTree* clone() const { return new JTree(*this); }
-        /// Create (virtual constructor)
-        virtual JTree* create() const { return new JTree(); }
-        JTree & operator=( const JTree& x ) {
+
+        /// Construct JTree object using the specified properties
+        JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic=true );
+
+        /// Copy constructor
+        JTree( const JTree &x ) : DAIAlgRG(x), _RTree(x._RTree), _Qa(x._Qa), _Qb(x._Qb), _mes(x._mes), _logZ(x._logZ), props(x.props) {}
+
+        /// Assignment operator
+        JTree & operator=( const JTree &x ) {
             if( this != &x ) {
-                DAIAlgRG::operator=(x);
+                DAIAlgRG::operator=( x );
                 _RTree  = x._RTree;
                 _Qa     = x._Qa;
                 _Qb     = x._Qb;
@@ -71,33 +78,57 @@ class JTree : public DAIAlgRG {
             }
             return *this;
         }
-        JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic=true );
-        void GenerateJT( const std::vector<VarSet> &Cliques );
 
-        Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; }   
-        const Factor & message( size_t alpha, size_t _beta ) const { return _mes[alpha][_beta]; }   
+        /// Clone (virtual copy constructor)
+        virtual JTree* clone() const { return new JTree(*this); }
 
-        static const char *Name;
+        /// Create (virtual constructor)
+        virtual JTree* create() const { return new JTree(); }
+
+        /// Return number of passes over the factorgraph
+        virtual size_t Iterations() const { return 1UL; }
+
+        /// Return maximum difference between single node beliefs for two consecutive iterations
+        virtual double maxDiff() const { return 0.0; }
+
+        /// Identifies itself for logging purposes
         std::string identify() const;
-        void init() {}
-        /// Clear messages and beliefs corresponding to the nodes in ns
-        virtual void init( const VarSet &/*ns*/ ) {}
-        void runHUGIN();
-        void runShaferShenoy();
-        double run();
+
+        /// Get single node belief
         Factor belief( const Var &n ) const;
+
+        /// Get general belief
         Factor belief( const VarSet &ns ) const;
+
+        /// Get all beliefs
         std::vector<Factor> beliefs() const;
+
+        /// Get log partition sum
         Real logZ() const;
 
-        void restoreFactors( const VarSet &ns ) { RegionGraph::restoreFactors( ns ); init( ns ); }
+        /// Clear messages and beliefs
+        void init() {}
 
+        /// Clear messages and beliefs corresponding to the nodes in ns
+        virtual void init( const VarSet &/*ns*/ ) {}
+
+        /// The actual approximate inference algorithm
+        double run();
+
+
+        void GenerateJT( const std::vector<VarSet> &Cliques );
+
+        Factor & message( size_t alpha, size_t _beta ) { return _mes[alpha][_beta]; }   
+        const Factor & message( size_t alpha, size_t _beta ) const { return _mes[alpha][_beta]; }   
+
+        void runHUGIN();
+        void runShaferShenoy();
         size_t findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot=(size_t)-1 ) const;
         Factor calcMarginal( const VarSet& ns );
+
         void setProperties( const PropertySet &opts );
         PropertySet getProperties() const;
         std::string printProperties() const;
-        double maxDiff() const { return 0.0; }
 };
 
 
index dd57706..3d2f887 100644 (file)
@@ -241,7 +241,7 @@ double HAK::doGBP() {
                 Qb_new *= muab(alpha,_beta) ^ (1 / (nbIR(beta).size() + IR(beta).c()));
             }
 
-            Qb_new.normalize( Prob::NORMPROB );
+            Qb_new.normalize();
             if( Qb_new.hasNaNs() ) {
                 cout << "HAK::doGBP:  Qb_new has NaNs!" << endl;
                 return 1.0;
@@ -258,7 +258,7 @@ double HAK::doGBP() {
                 foreach( const Neighbor &gamma, nbOR(alpha) )
                     Qa_new *= muba(alpha,gamma.iter);
                 Qa_new ^= (1.0 / OR(alpha).c());
-                Qa_new.normalize( Prob::NORMPROB );
+                Qa_new.normalize();
                 if( Qa_new.hasNaNs() ) {
                     cout << "HAK::doGBP:  Qa_new has NaNs!" << endl;
                     return 1.0;
index 810dfa5..24a5eed 100644 (file)
@@ -61,8 +61,11 @@ string JTree::printProperties() const {
 JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(fg), _RTree(), _Qa(), _Qb(), _mes(), _logZ(), props() {
     setProperties( opts );
 
+    if( !isConnected() ) 
+       DAI_THROW(FACTORGRAPH_NOT_CONNECTED); 
+
     if( automatic ) {
-        // Copy VarSets of factors
+        // Create ClusterGraph which contains factors as clusters
         vector<VarSet> cl;
         cl.reserve( fg.nrFactors() );
         for( size_t I = 0; I < nrFactors(); I++ )
@@ -94,7 +97,8 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
     for( size_t i = 0; i < Cliques.size(); i++ )
         for( size_t j = i+1; j < Cliques.size(); j++ ) {
             size_t w = (Cliques[i] & Cliques[j]).size();
-            JuncGraph[UEdge(i,j)] = w;
+            if( w ) 
+                JuncGraph[UEdge(i,j)] = w;
         }
     
     // Construct maximal spanning tree using Prim's algorithm
@@ -113,7 +117,6 @@ void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
         size_t alpha;
         for( alpha = 0; alpha < nrORs(); alpha++ )
             if( OR(alpha).vars() >> factor(I).vars() ) {
-//              OR(alpha) *= factor(I);
                 fac2OR.push_back( alpha );
                 break;
             }
@@ -216,14 +219,14 @@ void JTree::runHUGIN() {
 //      Make outer region _RTree[i].n1 consistent with outer region _RTree[i].n2
 //      IR(i) = seperator OR(_RTree[i].n1) && OR(_RTree[i].n2)
         Factor new_Qb = _Qa[_RTree[i].n2].partSum( IR( i ) );
-        _logZ += log(new_Qb.normalize( Prob::NORMPROB ));
+        _logZ += log(new_Qb.normalize());
         _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] ); 
         _Qb[i] = new_Qb;
     }
     if( _RTree.empty() )
-        _logZ += log(_Qa[0].normalize( Prob::NORMPROB ) );
+        _logZ += log(_Qa[0].normalize() );
     else
-        _logZ += log(_Qa[_RTree[0].n1].normalize( Prob::NORMPROB ));
+        _logZ += log(_Qa[_RTree[0].n1].normalize());
 
     // DistributeEvidence
     for( size_t i = 0; i < _RTree.size(); i++ ) {
@@ -236,7 +239,7 @@ void JTree::runHUGIN() {
 
     // Normalize
     for( size_t alpha = 0; alpha < nrORs(); alpha++ )
-        _Qa[alpha].normalize( Prob::NORMPROB );
+        _Qa[alpha].normalize();
 }
 
 
@@ -257,7 +260,7 @@ void JTree::runShaferShenoy() {
             if( k != e ) 
                 piet *= message( i, k.iter );
         message( j, _e ) = piet.partSum( IR(e) );
-        _logZ += log( message(j,_e).normalize( Prob::NORMPROB ) );
+        _logZ += log( message(j,_e).normalize() );
     }
 
     // Second pass
@@ -268,7 +271,7 @@ void JTree::runShaferShenoy() {
         
         Factor piet = OR(i);
         foreach( const Neighbor &k, nbOR(i) )
-            if(  k != e )
+            if( k != e )
                 piet *= message( i, k.iter );
         message( j, _e ) = piet.marginal( IR(e) );
     }
@@ -279,13 +282,13 @@ void JTree::runShaferShenoy() {
         foreach( const Neighbor &k, nbOR(alpha) )
             piet *= message( alpha, k.iter );
         if( nrIRs() == 0 ) {
-            _logZ += log( piet.normalize( Prob::NORMPROB ) );
+            _logZ += log( piet.normalize() );
             _Qa[alpha] = piet;
         } else if( alpha == nbIR(0)[0].node /*_RTree[0].n1*/ ) {
-            _logZ += log( piet.normalize( Prob::NORMPROB ) );
+            _logZ += log( piet.normalize() );
             _Qa[alpha] = piet;
         } else
-            _Qa[alpha] = piet.normalized( Prob::NORMPROB );
+            _Qa[alpha] = piet.normalized();
     }
 
     // Only for logZ (and for belief)...
@@ -480,7 +483,6 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
                 
             // For all states of nsrem
             for( State s(nsrem); s.valid(); s++ ) {
-                
                 // CollectEvidence
                 double logZ = 0.0;
                 for( size_t i = Tsize; (i--) != 0; ) {
@@ -495,11 +497,11 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
                         }
 
                     Factor new_Qb = _Qa[T[i].n2].partSum( IR( b[i] ) );
-                    logZ += log(new_Qb.normalize( Prob::NORMPROB ));
+                    logZ += log(new_Qb.normalize());
                     _Qa[T[i].n1] *= new_Qb.divided_by( _Qb[b[i]] ); 
                     _Qb[b[i]] = new_Qb;
                 }
-                logZ += log(_Qa[T[0].n1].normalize( Prob::NORMPROB ));
+                logZ += log(_Qa[T[0].n1].normalize());
 
                 Factor piet( nsrem, 0.0 );
                 piet[s] = exp(logZ);
@@ -512,7 +514,7 @@ Factor JTree::calcMarginal( const VarSet& ns ) {
                     _Qb[beta->first] = beta->second;
             }
 
-            return( Pns.normalized(Prob::NORMPROB) );
+            return( Pns.normalized() );
         }
     }
 }
index fc63e6b..0aec5d9 100644 (file)
@@ -149,7 +149,7 @@ double LC::CalcCavityDist (size_t i, const std::string &name, const PropertySet
         maxdiff = cav->maxDiff();
         delete cav;
     }
-    Bi.normalize( Prob::NORMPROB );
+    Bi.normalize();
     _cavitydists[i] = Bi;
 
     return maxdiff;
@@ -219,7 +219,7 @@ void LC::init() {
               _pancakes[i] *= _phis[i][I.iter];
         }
         
-        _pancakes[i].normalize( Prob::NORMPROB );
+        _pancakes[i].normalize();
 
         CalcBelief(i);
     }
@@ -241,10 +241,10 @@ Factor LC::NewPancake (size_t i, size_t _I, bool & hasNaNs) {
     Factor A_Ii = (_pancakes[i] * factor(I).inverse() * _phis[i][_I].inverse()).partSum( Ivars / var(i) );
     Factor quot = A_I.divided_by(A_Ii);
 
-    piet *= quot.divided_by( _phis[i][_I] ).normalized( Prob::NORMPROB );
-    _phis[i][_I] = quot.normalized( Prob::NORMPROB );
+    piet *= quot.divided_by( _phis[i][_I] ).normalized();
+    _phis[i][_I] = quot.normalized();
 
-    piet.normalize( Prob::NORMPROB );
+    piet.normalize();
 
     if( piet.hasNaNs() ) {
         cout << "LC::NewPancake(" << i << ", " << _I << "):  has NaNs!" << endl;
index 0daab37..13f3734 100644 (file)
@@ -519,7 +519,7 @@ void MR::init_cor() {
             VarSet::const_iterator kit = pairq[jk].vars().begin();
             size_t j = findVar( *(kit) );
             size_t k = findVar( *(++kit) );
-            pairq[jk].normalize(Prob::NORMPROB);
+            pairq[jk].normalize();
             double cor = (pairq[jk][3] - pairq[jk][2] - pairq[jk][1] + pairq[jk][0]) - (pairq[jk][3] + pairq[jk][2] - pairq[jk][1] - pairq[jk][0]) * (pairq[jk][3] - pairq[jk][2] + pairq[jk][1] - pairq[jk][0]);
             for( size_t _j = 0; _j < con[i]; _j++ ) if( nb[i][_j] == j )
                 for( size_t _k = 0; _k < con[i]; _k++ ) if( nb[i][_k] == k ) {
index 56924ae..b780d26 100644 (file)
@@ -178,11 +178,11 @@ void TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &
     _logZ = 0.0;
     for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
         _logZ += log(Qa[_a[alpha]].totalSum());
-        Qa[_a[alpha]].normalize( Prob::NORMPROB );
+        Qa[_a[alpha]].normalize();
     }
     for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
         _logZ -= log(Qb[_b[beta]].totalSum());
-        Qb[_b[beta]].normalize( Prob::NORMPROB );
+        Qb[_b[beta]].normalize();
     }
 }