Merged SVN head ...
[libdai.git] / include / dai / factorgraph.h
index adb6722..7e7f4d9 100644 (file)
@@ -25,7 +25,6 @@
 
 #include <iostream>
 #include <map>
-#include <tr1/unordered_map>
 #include <dai/bipgraph.h>
 #include <dai/factor.h>
 
 namespace dai {
 
 
-bool hasShortLoops( const std::vector<Factor> &P );
-void RemoveShortLoops( std::vector<Factor> &P );
-
-
 class FactorGraph {
     public:
         BipartiteGraph         G;
         std::vector<Var>       vars;
-        std::vector<Factor>    factors;
         typedef BipartiteGraph::Neighbor  Neighbor;
         typedef BipartiteGraph::Neighbors Neighbors;
+        typedef BipartiteGraph::Edge      Edge;
 
-    protected:
-        std::map<size_t,Prob>  _undoProbs;
-        Prob::NormType         _normtype;
+    private:
+        std::vector<Factor>      _factors;
+        std::map<size_t,Factor>  _backup;
 
     public:
         /// Default constructor
-        FactorGraph() : G(), vars(), factors(), _undoProbs(), _normtype(Prob::NORMPROB) {};
+        FactorGraph() : G(), vars(), _factors(), _backup() {}
         /// Copy constructor
-        FactorGraph(const FactorGraph & x) : G(x.G), vars(x.vars), factors(x.factors), _undoProbs(x._undoProbs), _normtype(x._normtype) {};
+        FactorGraph(const FactorGraph & x) : G(x.G), vars(x.vars), _factors(x._factors), _backup(x._backup) {}
         /// Construct FactorGraph from vector of Factors
         FactorGraph(const std::vector<Factor> &P);
         // Construct a FactorGraph from given factor and variable iterators
@@ -65,22 +60,34 @@ class FactorGraph {
             if( this != &x ) {
                 G          = x.G;
                 vars       = x.vars;
-                factors    = x.factors;
-                _undoProbs = x._undoProbs;
-                _normtype  = x._normtype;
+                _factors   = x._factors;
+                _backup    = x._backup;
             }
             return *this;
         }
         virtual ~FactorGraph() {}
 
+        /// Clone *this (virtual copy constructor)
+        virtual FactorGraph* clone() const { return new FactorGraph(); }
+
+        /// Create (virtual default constructor)
+        virtual FactorGraph* create() const { return new FactorGraph(*this); }
+
         // aliases
         Var & var(size_t i) { return vars[i]; }
+        /// Get const reference to i'th variable
         const Var & var(size_t i) const { return vars[i]; }
-        Factor & factor(size_t I) { return factors[I]; }
-        const Factor & factor(size_t I) const { return factors[I]; }
-
+        /// Get const reference to I'th factor
+        Factor & factor(size_t I) { return _factors[I]; }
+        /// Get const reference to I'th factor
+        const Factor & factor(size_t I) const { return _factors[I]; }
+        /// Get const reference to all factors
+        const std::vector<Factor> & factors() const { return _factors; }
+
+        /// Get number of variables
         size_t nrVars() const { return vars.size(); }
-        size_t nrFactors() const { return factors.size(); }
+        /// Get number of factors
+        size_t nrFactors() const { return _factors.size(); }
         size_t nrEdges() const { return G.nrEdges(); }
 
         /// Provides read access to neighbors of variable
@@ -100,11 +107,22 @@ class FactorGraph {
         /// Provides full access to neighbor of factor
         Neighbor & nbF( size_t I, size_t _i ) { return G.nb2(I)[_i]; }
 
-        size_t findVar(const Var & n) const {
+        /// Get index of variable n
+        size_t findVar( const Var & n ) const {
             size_t i = find( vars.begin(), vars.end(), n ) - vars.begin();
             assert( i != nrVars() );
             return i;
         }
+
+        /// Get set of indexes for set of variables
+        std::set<size_t> findVars( VarSet &ns ) const {
+            std::set<size_t> indexes;
+            for( VarSet::const_iterator n = ns.begin(); n != ns.end(); n++ )
+                indexes.insert( findVar( *n ) );
+            return indexes;
+        }
+
+        /// Get index of first factor involving ns
         size_t findFactor(const VarSet &ns) const {
             size_t I;
             for( I = 0; I < nrFactors(); I++ )
@@ -114,55 +132,99 @@ class FactorGraph {
             return I;
         }
 
-        friend std::ostream& operator << (std::ostream& os, const FactorGraph& fg);
-        friend std::istream& operator >> (std::istream& is, FactorGraph& fg);
+        /// Return all variables that occur in a factor involving the i'th variable, itself included
+        VarSet Delta( unsigned i ) const;
 
+        /// Return all variables that occur in a factor involving some variable in ns, ns itself included
+        VarSet Delta( const VarSet &ns ) const;
+
+        /// Return all variables that occur in a factor involving the i'th variable, n itself excluded
         VarSet delta( unsigned i ) const;
-        VarSet Delta( unsigned i ) const;
-        virtual void makeCavity( unsigned i );
 
-        long ReadFromFile(const char *filename);
-        long WriteToFile(const char *filename) const;
-        long WriteToDotFile(const char *filename) const;
+        /// Return all variables that occur in a factor involving some variable in ns, ns itself excluded
+        VarSet delta( const VarSet & ns ) const {
+            return Delta( ns ) / ns;
+        }
 
-        virtual void clamp( const Var & n, size_t i );
-        
-        bool hasNegatives() const;
-        Prob::NormType NormType() const { return _normtype; }
+        /// Set the content of the I'th factor and make a backup of its old content if backup == true
+        virtual void setFactor( size_t I, const Factor &newFactor, bool backup = false ) {
+            assert( newFactor.vars() == factor(I).vars() ); 
+            if( backup )
+                backupFactor( I );
+            _factors[I] = newFactor; 
+        }
+
+        /// Set the contents of all factors as specified by facs and make a backup of the old contents if backup == true
+        virtual void setFactors( const std::map<size_t, Factor> & facs, bool backup = false ) {
+            for( std::map<size_t, Factor>::const_iterator fac = facs.begin(); fac != facs.end(); fac++ ) {
+                if( backup )
+                    backupFactor( fac->first );
+                setFactor( fac->first, fac->second );
+            }
+        }
+
+        /// Clamp variable n to value i (i.e. multiply with a Kronecker delta \f$\delta_{x_n, i}\f$);
+        /// If backup == true, make a backup of all factors that are changed
+        virtual void clamp( const Var & n, size_t i, bool backup = false );
+
+        /// Set all factors interacting with the i'th variable 1
+        virtual void makeCavity( unsigned i, bool backup = false );
+
+        /// Backup the factors specified by indices in facs
+        virtual void backupFactors( const std::set<size_t> & facs );
+
+        /// Restore all factors to the backup copies
+        virtual void restoreFactors();
+
+        bool isConnected() const { return G.isConnected(); }
+        bool isTree() const { return G.isTree(); }
+
+        friend std::ostream& operator << (std::ostream& os, const FactorGraph& fg);
+        friend std::istream& operator >> (std::istream& is, FactorGraph& fg);
+
+        void ReadFromFile(const char *filename);
+        void WriteToFile(const char *filename) const;
+        void printDot( std::ostream& os ) const;
         
         std::vector<VarSet> Cliques() const;
 
-        virtual void undoProbs( const VarSet &ns );
-        void saveProbs( const VarSet &ns );
-        virtual void undoProb( size_t I );
-        void saveProb( size_t I );
+        // Clamp variable v_i to value state (i.e. multiply with a Kronecker delta \f$\delta_{x_{v_i},x}\f$);
+        // This version changes the factor graph structure and thus returns a newly constructed FactorGraph
+        // and keeps the current one constant, contrary to clamp()
+        FactorGraph clamped( const Var & v_i, size_t x ) const;
 
-        virtual void updatedFactor( size_t /*I*/ ) {};
+        FactorGraph maximalFactors() const;
 
-    private:
+        bool isPairwise() const;
+        bool isBinary() const;
+
+        void restoreFactor( size_t I );
+        void backupFactor( size_t I );
+        void restoreFactors( const VarSet &ns );
+        void backupFactors( const VarSet &ns );
         /// Part of constructors (creates edges, neighbors and adjacency matrix)
-        void createGraph( size_t nrEdges );
+        void constructGraph( size_t nrEdges );
 };
 
 
 // assumes that the set of variables in [var_begin,var_end) is the union of the variables in the factors in [fact_begin, fact_end)
 template<typename FactorInputIterator, typename VarInputIterator>
-FactorGraph::FactorGraph(FactorInputIterator fact_begin, FactorInputIterator fact_end, VarInputIterator var_begin, VarInputIterator var_end, size_t nr_fact_hint, size_t nr_var_hint ) : G(), _undoProbs(), _normtype(Prob::NORMPROB) {
+FactorGraph::FactorGraph(FactorInputIterator fact_begin, FactorInputIterator fact_end, VarInputIterator var_begin, VarInputIterator var_end, size_t nr_fact_hint, size_t nr_var_hint ) : G(), _backup() {
     // add factors
     size_t nrEdges = 0;
-    factors.reserve( nr_fact_hint );
+    _factors.reserve( nr_fact_hint );
     for( FactorInputIterator p2 = fact_begin; p2 != fact_end; ++p2 ) {
-        factors.push_back( *p2 );
+        _factors.push_back( *p2 );
         nrEdges += p2->vars().size();
     }
+
     // add variables
     vars.reserve( nr_var_hint );
     for( VarInputIterator p1 = var_begin; p1 != var_end; ++p1 )
         vars.push_back( *p1 );
 
     // create graph structure
-    createGraph( nrEdges );
+    constructGraph( nrEdges );
 }