New git HEAD version
[libdai.git] / src / clustergraph.cpp
index 96d424c..f899c4d 100644 (file)
@@ -1,11 +1,8 @@
 /*  This file is part of libDAI - http://www.libdai.org/
  *
- *  libDAI is licensed under the terms of the GNU General Public License version
- *  2, or (at your option) any later version. libDAI is distributed without any
- *  warranty. See the file COPYING for more details.
+ *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
  *
- *  Copyright (C) 2006-2010  Joris Mooij  [joris dot mooij at libdai dot org]
- *  Copyright (C) 2006-2007  Radboud University Nijmegen, The Netherlands
+ *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
  */
 
 
@@ -22,41 +19,51 @@ namespace dai {
 using namespace std;
 
 
-ClusterGraph::ClusterGraph( const std::vector<VarSet> & cls ) : G(), vars(), clusters() {
+ClusterGraph::ClusterGraph( const std::vector<VarSet> & cls ) : _G(), _vars(), _clusters() {
     // construct vars, clusters and edge list
     vector<Edge> edges;
-    foreach( const VarSet &cl, cls ) {
-        if( find( clusters.begin(), clusters.end(), cl ) == clusters.end() ) {
+    bforeach( const VarSet &cl, cls ) {
+        if( find( clusters().begin(), clusters().end(), cl ) == clusters().end() ) {
             // add cluster
-            size_t n2 = clusters.size();
-            clusters.push_back( cl );
+            size_t n2 = nrClusters();
+            _clusters.push_back( cl );
             for( VarSet::const_iterator n = cl.begin(); n != cl.end(); n++ ) {
-                size_t n1 = find( vars.begin(), vars.end(), *n ) - vars.begin();
-                if( n1 == vars.size() )
+                size_t n1 = find( vars().begin(), vars().end(), *n ) - vars().begin();
+                if( n1 == nrVars() )
                     // add variable
-                    vars.push_back( *n );
+                    _vars.push_back( *n );
                 edges.push_back( Edge( n1, n2 ) );
             }
         } // disregard duplicate clusters
     }
 
     // Create bipartite graph
-    G.construct( vars.size(), clusters.size(), edges.begin(), edges.end() );
+    _G.construct( nrVars(), nrClusters(), edges.begin(), edges.end() );
 }
 
 
-size_t ClusterGraph::eliminationCost( size_t i ) const {
-    return eliminationCost_MinFill( *this, i );
-}
-
-
-ClusterGraph ClusterGraph::VarElim( const std::vector<Var> &ElimSeq ) const {
-    return VarElim( sequentialVariableElimination( ElimSeq ) );
-}
-
+ClusterGraph::ClusterGraph( const FactorGraph& fg, bool onlyMaximal ) : _G( fg.nrVars(), 0 ), _vars(), _clusters() {
+    // copy variables
+    _vars.reserve( fg.nrVars() );
+    for( size_t i = 0; i < fg.nrVars(); i++ )
+        _vars.push_back( fg.var(i) );
 
-ClusterGraph ClusterGraph::VarElim_MinFill() const {
-    return VarElim( greedyVariableElimination( &eliminationCost_MinFill ) );
+    if( onlyMaximal ) {
+        for( size_t I = 0; I < fg.nrFactors(); I++ )
+            if( fg.isMaximal( I ) ) {
+                _clusters.push_back( fg.factor(I).vars() );
+                size_t clind = _G.addNode2();
+                bforeach( const Neighbor &i, fg.nbF(I) )
+                    _G.addEdge( i, clind, true );
+            }
+    } else {
+        // copy clusters
+        _clusters.reserve( fg.nrFactors() );
+        for( size_t I = 0; I < fg.nrFactors(); I++ )
+            _clusters.push_back( fg.factor(I).vars() );
+        // copy bipartite graph
+        _G = fg.bipGraph();
+    }
 }
 
 
@@ -80,49 +87,50 @@ size_t greedyVariableElimination::operator()( const ClusterGraph &cl, const std:
 
 
 size_t eliminationCost_MinNeighbors( const ClusterGraph &cl, size_t i ) {
-    std::vector<size_t> id_n = cl.G.delta1( i );
-    return id_n.size();
+    return cl.bipGraph().delta1( i ).size();
 }
 
 
 size_t eliminationCost_MinWeight( const ClusterGraph &cl, size_t i ) {
-    std::vector<size_t> id_n = cl.G.delta1( i );
+    SmallSet<size_t> id_n = cl.bipGraph().delta1( i );
     
     size_t cost = 1;
-    for( size_t _i = 0; _i < id_n.size(); _i++ )
-        cost *= cl.vars[id_n[_i]].states();
+    for( SmallSet<size_t>::const_iterator it = id_n.begin(); it != id_n.end(); it++ )
+        cost *= cl.vars()[*it].states();
 
     return cost;
 }
 
 
 size_t eliminationCost_MinFill( const ClusterGraph &cl, size_t i ) {
-    std::vector<size_t> id_n = cl.G.delta1( i );
+    SmallSet<size_t> id_n = cl.bipGraph().delta1( i );
 
     size_t cost = 0;
     // for each unordered pair {i1,i2} adjacent to n
-    for( size_t _i1 = 0; _i1 < id_n.size(); _i1++ )
-        for( size_t _i2 = _i1 + 1; _i2 < id_n.size(); _i2++ ) {
-            // if i1 and i2 are not adjacent, eliminating n would make them adjacent
-            if( !cl.adj(id_n[_i1], id_n[_i2]) )
-                cost++;
-        }
+    for( SmallSet<size_t>::const_iterator it1 = id_n.begin(); it1 != id_n.end(); it1++ )
+        for( SmallSet<size_t>::const_iterator it2 = it1; it2 != id_n.end(); it2++ )
+            if( it1 != it2 ) {
+                // if i1 and i2 are not adjacent, eliminating n would make them adjacent
+                if( !cl.adj(*it1, *it2) )
+                    cost++;
+            }
 
     return cost;
 }
 
 
 size_t eliminationCost_WeightedMinFill( const ClusterGraph &cl, size_t i ) {
-    std::vector<size_t> id_n = cl.G.delta1( i );
+    SmallSet<size_t> id_n = cl.bipGraph().delta1( i );
 
     size_t cost = 0;
     // for each unordered pair {i1,i2} adjacent to n
-    for( size_t _i1 = 0; _i1 < id_n.size(); _i1++ )
-        for( size_t _i2 = _i1 + 1; _i2 < id_n.size(); _i2++ ) {
-            // if i1 and i2 are not adjacent, eliminating n would make them adjacent
-            if( !cl.adj(id_n[_i1], id_n[_i2]) )
-                cost += cl.vars[id_n[_i1]].states() * cl.vars[id_n[_i2]].states();
-        }
+    for( SmallSet<size_t>::const_iterator it1 = id_n.begin(); it1 != id_n.end(); it1++ )
+        for( SmallSet<size_t>::const_iterator it2 = it1; it2 != id_n.end(); it2++ )
+            if( it1 != it2 ) {
+                // if i1 and i2 are not adjacent, eliminating n would make them adjacent
+                if( !cl.adj(*it1, *it2) )
+                    cost += cl.vars()[*it1].states() * cl.vars()[*it2].states();
+            }
 
     return cost;
 }