Changed FactorGraph::clamp and DAIAlg::clamp interfaces
authorJoris Mooij <joris.mooij@tuebingen.mpg.de>
Wed, 2 Sep 2009 14:55:59 +0000 (16:55 +0200)
committerJoris Mooij <joris.mooij@tuebingen.mpg.de>
Wed, 2 Sep 2009 14:55:59 +0000 (16:55 +0200)
The variable to be clamped is now indicated by its index, not as a Var.
The old interface is marked obsolete

ChangeLog
include/dai/bipgraph.h
include/dai/daialg.h
include/dai/factorgraph.h
include/dai/lc.h
src/cbp.cpp
src/daialg.cpp
src/evidence.cpp
src/factorgraph.cpp

index 6432dcf..dbe384d 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,7 @@
+* Changed FactorGraph::clamp and DAIAlg::clamp interfaces (the variable to be
+  clamped is now indicated by its index, not as a Var) and marked the old
+  interface as obsolete
+* [Patrick Pletscher] Fixed performance issue in FactorGraph::clamp
 * [Sebastian Nowozin] MEX file dai now also optionally returns the MAP state
   (only for BP)
 * [Sebastian Nowozin] Fixed memory leak in MEX file dai
index 15845de..3d5479b 100644 (file)
@@ -146,6 +146,7 @@ class BipartiteGraph {
             std::vector<size_t> ind2;       // indices of nodes of type 2
         };
 
+        // OBSOLETE
         /// @name Backwards compatibility layer (to be removed soon)
         //@{
         /// Enable backwards compatibility layer?
@@ -373,6 +374,7 @@ class BipartiteGraph {
         /// Writes this BipartiteGraph to an output stream in GraphViz .dot syntax
         void printDot( std::ostream& os ) const;
 
+        // OBSOLETE
         /// @name Backwards compatibility layer (to be removed soon)
         //@{
         void indexEdges() {
index 4dbae3a..2d5274f 100644 (file)
@@ -98,8 +98,14 @@ class InfAlg {
          */
         virtual double run() = 0;
 
-        /// Clamp variable n to value i (i.e. multiply with a Kronecker delta \f$\delta_{x_n, i}\f$)
-        virtual void clamp( const Var & n, size_t i, bool backup = false ) = 0;
+        /// Clamp variable with index i to value x (i.e. multiply with a Kronecker delta \f$\delta_{x_i, x}\f$)
+        /** If backup == true, make a backup of all factors that are changed
+         */
+        virtual void clamp( size_t i, size_t x, bool backup = false ) = 0;
+
+        // OBSOLETE
+        /// Only for backwards compatibility (to be removed soon)
+        virtual void clamp( const Var &v, size_t x, bool backup = false ) = 0;
 
         /// Set all factors interacting with var(i) to 1
         virtual void makeCavity( size_t i, bool backup = false ) = 0;
@@ -158,8 +164,15 @@ class DAIAlg : public InfAlg, public GRM {
         /// Restore Factors involving ns
         void restoreFactors( const VarSet &ns ) { GRM::restoreFactors( ns ); }
 
-        /// Clamp variable n to value i (i.e. multiply with a Kronecker delta \f$\delta_{x_n, i}\f$)
-        void clamp( const Var & n, size_t i, bool backup = false ) { GRM::clamp( n, i, backup ); }
+        /// Clamp variable with index i to value x (i.e. multiply with a Kronecker delta \f$\delta_{x_i, x}\f$)
+        void clamp( size_t i, size_t x, bool backup = false ) { GRM::clamp( i, x, backup ); }
+
+        // OBSOLETE
+        /// Only for backwards compatibility (to be removed soon)
+        void clamp( const Var &v, size_t x, bool backup = false ) { 
+            GRM::clamp( v, x, backup );
+            std::cerr << "Warning: this DAIAlg<...>::clamp(const Var&,...) interface is obsolete!" << std::endl;
+        }
 
         /// Set all factors interacting with var(i) to 1
         void makeCavity( size_t i, bool backup = false ) { GRM::makeCavity( i, backup ); }
index 5beae95..8ca8ef6 100644 (file)
@@ -157,7 +157,7 @@ class FactorGraph {
         Neighbor & nbF( size_t I, size_t _i ) { return G.nb2(I)[_i]; }
 
         /// Returns the index of a particular variable
-        size_t findVar( const Var & n ) const {
+        size_t findVar( const Var &n ) const {
             size_t i = find( vars().begin(), vars().end(), n ) - vars().begin();
             assert( i != nrVars() );
             return i;
@@ -212,10 +212,17 @@ class FactorGraph {
             }
         }
 
-        /// Clamp variable n to value i (i.e. multiply with a Kronecker delta \f$\delta_{x_n, i}\f$)
+        /// Clamp variable with index i to value x (i.e. multiply with a Kronecker delta \f$\delta_{x_i, x}\f$)
         /** If backup == true, make a backup of all factors that are changed
          */
-        virtual void clamp( const Var & n, size_t i, bool backup = false );
+        virtual void clamp( size_t i, size_t x, bool backup = false );
+
+        // OBSOLETE
+        /// Only for backwards compatibility (to be removed soon)
+        virtual void clamp( const Var &v, size_t x, bool backup = false ) { 
+            std::cerr << "Warning: this FactorGraph::clamp(const Var&,...) interface is obsolete!" << std::endl;
+            clamp( findVar(v), x, backup );
+        }
 
         /// Clamp a variable in a factor graph to have one out of a list of values
         /** If backup == true, make a backup of all factors that are changed
@@ -249,10 +256,10 @@ class FactorGraph {
         bool isBinary() const;
 
         /// Reads a FactorGraph from a file
-        void ReadFromFile(const char *filename);
+        void ReadFromFile( const char *filename );
 
         /// Writes a FactorGraph to a file
-        void WriteToFile(const char *filename, size_t precision=15) const;
+        void WriteToFile( const char *filename, size_t precision=15 ) const;
 
         /// Writes a FactorGraph to a GraphViz .dot file
         void printDot( std::ostream& os ) const;
@@ -260,11 +267,11 @@ class FactorGraph {
         /// Returns the cliques in this FactorGraph
         std::vector<VarSet> Cliques() const;
 
-        /// Clamp variable v_i to value state (i.e. multiply with a Kronecker delta \f$\delta_{x_{v_i},x}\f$);
+        /// Clamp variable v to value x (i.e. multiply with a Kronecker delta \f$\delta_{x_v,x}\f$);
         /** This version changes the factor graph structure and thus returns a newly constructed FactorGraph
          *  and keeps the current one constant, contrary to clamp()
          */
-        FactorGraph clamped( const Var & v_i, size_t x ) const;
+        FactorGraph clamped( const Var &v, size_t x ) const;
 
         /// Returns a copy of *this, where all factors that are subsumed by some larger factor are merged with the larger factors.
         FactorGraph maximalFactors() const;
@@ -284,6 +291,7 @@ class FactorGraph {
         friend std::ostream& operator << (std::ostream& os, const FactorGraph& fg);
         friend std::istream& operator >> (std::istream& is, FactorGraph& fg);
 
+        // OBSOLETE
         /// @name Backwards compatibility layer (to be removed soon)
         //@{
         size_t VV2E(size_t n1, size_t n2) const { return G.VV2E(n1,n2); }
index 536dc8d..1acdab6 100644 (file)
@@ -119,6 +119,7 @@ class LC : public DAIAlgFG {
         virtual size_t Iterations() const { return _iters; }
         //@}
 
+        Factor beliefV( size_t i ) const { return _beliefs[i]; }
 
         /// @name Additional interface specific for LC
         //@{ 
index a5b82f0..a641746 100644 (file)
@@ -349,7 +349,7 @@ bool CBP::chooseNextClampVar( InfAlg *bp, vector<size_t> &clamped_vars_list, siz
                 if( bp->beliefV(k)[xk] < tiny ) 
                     continue;
                 InfAlg *bp1 = bp->clone();
-                bp1->clamp( var(k), xk );
+                bp1->clamp( k, xk );
                 bp1->init( var(k) );
                 bp1->run();
                 Real cost = 0;
index ce18883..47f61eb 100644 (file)
@@ -33,13 +33,17 @@ using namespace std;
 /// Calculates the marginal of obj on ns by clamping all variables in ns and calculating logZ for each joined state.
 /*  reInit should be set to true if at least one of the possible clamped states would be invalid (leading to a factor graph with zero partition sum).
  */
-Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
+Factor calcMarginal( const InfAlg &obj, const VarSet &ns, bool reInit ) {
     Factor Pns (ns);
     
     InfAlg *clamped = obj.clone();
     if( !reInit )
         clamped->init();
 
+    map<Var,size_t> varindices;
+    for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
+        varindices[*n] = obj.fg().findVar( *n );
+
     Real logZ0 = -INFINITY;
     for( State s(ns); s.valid(); s++ ) {
         // save unclamped factors connected to ns
@@ -47,7 +51,7 @@ Factor calcMarginal( const InfAlg & obj, const VarSet & ns, bool reInit ) {
 
         // set clamping Factors to delta functions
         for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
-            clamped->clamp( *n, s(*n) );
+            clamped->clamp( varindices[*n], s(*n) );
         
         // run DAIAlg, calc logZ, store in Pns
         if( reInit )
@@ -93,8 +97,11 @@ vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reIni
     size_t N = ns.size();
     vector<Var> vns;
     vns.reserve( N );
-    for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
+    map<Var,size_t> varindices;
+    for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ ) {
         vns.push_back( *n );
+        varindices[*n] = obj.fg().findVar( *n );
+    }
 
     vector<Factor> pairbeliefs;
     pairbeliefs.reserve( N * N );
@@ -113,7 +120,7 @@ vector<Factor> calcPairBeliefs( const InfAlg & obj, const VarSet& ns, bool reIni
     for( size_t j = 0; j < N; j++ ) {
         // clamp Var j to its possible values
         for( size_t j_val = 0; j_val < vns[j].states(); j_val++ ) {
-            clamped->clamp( vns[j], j_val, true );
+            clamped->clamp( varindices[vns[j]], j_val, true );
             if( reInit )
                 clamped->init();
             else
@@ -195,6 +202,10 @@ vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool re
     if( !reInit )
         clamped->init();
 
+    map<Var,size_t> varindices;
+    for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
+        varindices[*n] = obj.fg().findVar( *n );
+
     Real logZ0 = 0.0;
     VarSet::const_iterator nj = ns.begin();
     for( long j = 0; j < (long)ns.size() - 1; j++, nj++ ) {
@@ -207,9 +218,9 @@ vector<Factor> calcPairBeliefsNew( const InfAlg & obj, const VarSet& ns, bool re
                 for( size_t k_val = 0; k_val < nk->states(); k_val++ ) {
                     // save unclamped factors connected to ns
                     clamped->backupFactors( ns );
-            
-                    clamped->clamp( *nj, j_val );
-                    clamped->clamp( *nk, k_val );
+
+                    clamped->clamp( varindices[*nj], j_val );
+                    clamped->clamp( varindices[*nk], k_val );
                     if( reInit )
                         clamped->init();
                     else
index e003807..10185ec 100644 (file)
@@ -37,7 +37,7 @@ void Observation::addObservation( Var node, size_t setting ) {
 
 void Observation::applyEvidence( InfAlg &alg ) const {
     for( std::map<Var, size_t>::const_iterator i = _obs.begin(); i != _obs.end(); ++i )
-        alg.clamp( i->first, i->second );
+        alg.clamp( alg.fg().findVar(i->first), i->second );
 }
   
 
index e49ee5e..11ea17f 100644 (file)
@@ -310,18 +310,14 @@ vector<VarSet> FactorGraph::Cliques() const {
 }
 
 
-void FactorGraph::clamp( const Var & n, size_t i, bool backup ) {
-    assert( i <= n.states() );
-
-    // Multiply each factor that contains the variable with a delta function
-
-    Factor delta_n_i(n,0.0);
-    delta_n_i[i] = 1.0;
+void FactorGraph::clamp( size_t i, size_t x, bool backup ) {
+    assert( x <= var(i).states() );
+    Factor mask( var(i), 0.0 );
+    mask[x] = 1.0;
 
     map<size_t, Factor> newFacs;
-    size_t n_index = findVar(n);
-       foreach( const BipartiteGraph::Neighbor &I, nbV(n_index) )
-        newFacs[I] = factor(I) * delta_n_i;
+       foreach( const BipartiteGraph::Neighbor &I, nbV(i) )
+        newFacs[I] = factor(I) * mask;
     setFactors( newFacs, backup );
 
     return;
@@ -338,8 +334,7 @@ void FactorGraph::clampVar( size_t i, const vector<size_t> &is, bool backup ) {
     }
 
     map<size_t, Factor> newFacs;
-    size_t n_index = findVar(n);
-       foreach( const BipartiteGraph::Neighbor &I, nbV(n_index) )
+       foreach( const BipartiteGraph::Neighbor &I, nbV(i) )
         newFacs[I] = factor(I) * mask_n;
     setFactors( newFacs, backup );
 }
@@ -425,14 +420,14 @@ bool FactorGraph::isBinary() const {
 }
 
 
-FactorGraph FactorGraph::clamped( const Var & v_i, size_t state ) const {
+FactorGraph FactorGraph::clamped( const Var &v, size_t state ) const {
     Real zeroth_order = 1.0;
     vector<Factor> clamped_facs;
     for( size_t I = 0; I < nrFactors(); I++ ) {
         VarSet v_I = factor(I).vars();
         Factor new_factor;
-        if( v_I.intersects( v_i ) )
-            new_factor = factor(I).slice( v_i, state );
+        if( v_I.intersects( v ) )
+            new_factor = factor(I).slice( v, state );
         else
             new_factor = factor(I);