-/* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
- Radboud University Nijmegen, The Netherlands
-
+/* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
+ Radboud University Nijmegen, The Netherlands /
+ Max Planck Institute for Biological Cybernetics, Germany
+
This file is part of libDAI.
libDAI is free software; you can redistribute it and/or modify
#include <dai/jtree.h>
#include <dai/treeep.h>
#include <dai/util.h>
-#include <dai/diffs.h>
namespace dai {
}
-TreeEPSubTree::TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
+string TreeEP::printProperties() const {
+ stringstream s( stringstream::out );
+ s << "[";
+ s << "tol=" << props.tol << ",";
+ s << "maxiter=" << props.maxiter << ",";
+ s << "verbose=" << props.verbose << ",";
+ s << "type=" << props.type << "]";
+ return s.str();
+}
+
+
+TreeEP::TreeEPSubTree::TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I ) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I), _ns(), _nsrem(), _logZ(0.0) {
_ns = _I->vars();
// Make _Qa, _Qb, _a and _b corresponding to the subtree
// Find remaining variables (which are not in the new root)
_nsrem = _ns / _Qa[0].vars();
-};
+}
-void TreeEPSubTree::init() {
+void TreeEP::TreeEPSubTree::init() {
for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
_Qa[alpha].fill( 1.0 );
for( size_t beta = 0; beta < _Qb.size(); beta++ )
}
-void TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
+void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) {
for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
_Qa[alpha] = Qa[_a[alpha]].divided_by( _Qa[alpha] );
}
-void TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
+void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb ) {
// Backup _Qa and _Qb
vector<Factor> _Qa_old(_Qa);
vector<Factor> _Qb_old(_Qb);
delta[s(*n)] = 1.0;
_Qa[_RTree[i].n2] *= delta;
}
- Factor new_Qb = _Qa[_RTree[i].n2].part_sum( _Qb[i].vars() );
+ Factor new_Qb = _Qa[_RTree[i].n2].partSum( _Qb[i].vars() );
_Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
_Qb[i] = new_Qb;
}
// DistributeEvidence
for( size_t i = 0; i < _RTree.size(); i++ ) {
- Factor new_Qb = _Qa[_RTree[i].n1].part_sum( _Qb[i].vars() );
+ Factor new_Qb = _Qa[_RTree[i].n1].partSum( _Qb[i].vars() );
_Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
_Qb[i] = new_Qb;
}
_logZ = 0.0;
for( size_t alpha = 0; alpha < _Qa.size(); alpha++ ) {
_logZ += log(Qa[_a[alpha]].totalSum());
- Qa[_a[alpha]].normalize( Prob::NORMPROB );
+ Qa[_a[alpha]].normalize();
}
for( size_t beta = 0; beta < _Qb.size(); beta++ ) {
_logZ -= log(Qb[_b[beta]].totalSum());
- Qb[_b[beta]].normalize( Prob::NORMPROB );
+ Qb[_b[beta]].normalize();
}
}
-double TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
+double TreeEP::TreeEPSubTree::logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const {
double sum = 0.0;
for( size_t alpha = 0; alpha < _Qa.size(); alpha++ )
sum += (Qa[_a[alpha]] * _Qa[alpha].log0()).totalSum();
}
-TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), props(), maxdiff(0.0) {
+TreeEP::TreeEP( const FactorGraph &fg, const PropertySet &opts ) : JTree(fg, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
setProperties( opts );
- assert( fg.G.isConnected() );
+ assert( fg.isConnected() );
if( opts.hasKey("tree") ) {
ConstructRG( opts.GetAs<DEdgeVec>("tree") );
} else {
- if( props.type == Properties::TypeType::ORG ) {
- // construct weighted graph with as weights a crude estimate of the
+ if( props.type == Properties::TypeType::ORG || props.type == Properties::TypeType::ALT ) {
+ // ORG: construct weighted graph with as weights a crude estimate of the
// mutual information between the nodes
- WeightedGraph<double> wg;
- for( size_t i = 0; i < nrVars(); ++i ) {
- Var v_i = var(i);
- VarSet di = delta(i);
- for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
- if( v_i < *j ) {
- Factor piet;
- for( size_t I = 0; I < nrFactors(); I++ ) {
- VarSet Ivars = factor(I).vars();
- if( (Ivars == v_i) || (Ivars == *j) )
- piet *= factor(I);
- else if( Ivars >> (v_i | *j) )
- piet *= factor(I).marginal( v_i | *j );
- }
- if( piet.vars() >> (v_i | *j) ) {
- piet = piet.marginal( v_i | *j );
- Factor pietf = piet.marginal(v_i) * piet.marginal(*j);
- wg[UEdge(i,findVar(*j))] = real( KL_dist( piet, pietf ) );
- } else
- wg[UEdge(i,findVar(*j))] = 0;
- }
- }
-
- // find maximal spanning tree
- ConstructRG( MaxSpanningTreePrims( wg ) );
-
-// cout << "Constructing maximum spanning tree..." << endl;
-// DEdgeVec MST = MaxSpanningTreePrims( wg );
-// cout << "Maximum spanning tree:" << endl;
-// for( DEdgeVec::const_iterator e = MST.begin(); e != MST.end(); e++ )
-// cout << *e << endl;
-// ConstructRG( MST );
- } else if( props.type == Properties::TypeType::ALT ) {
- // construct weighted graph with as weights an upper bound on the
+ // ALT: construct weighted graph with as weights an upper bound on the
// effective interaction strength between pairs of nodes
+
WeightedGraph<double> wg;
for( size_t i = 0; i < nrVars(); ++i ) {
Var v_i = var(i);
VarSet di = delta(i);
for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
if( v_i < *j ) {
+ VarSet ij(v_i,*j);
Factor piet;
for( size_t I = 0; I < nrFactors(); I++ ) {
VarSet Ivars = factor(I).vars();
- if( Ivars >> (v_i | *j) )
- piet *= factor(I);
+ if( props.type == Properties::TypeType::ORG ) {
+ if( (Ivars == v_i) || (Ivars == *j) )
+ piet *= factor(I);
+ else if( Ivars >> ij )
+ piet *= factor(I).marginal( ij );
+ } else {
+ if( Ivars >> ij )
+ piet *= factor(I);
+ }
+ }
+ if( props.type == Properties::TypeType::ORG ) {
+ if( piet.vars() >> ij ) {
+ piet = piet.marginal( ij );
+ Factor pietf = piet.marginal(v_i) * piet.marginal(*j);
+ wg[UEdge(i,findVar(*j))] = dist( piet, pietf, Prob::DISTKL );
+ } else
+ wg[UEdge(i,findVar(*j))] = 0;
+ } else {
+ wg[UEdge(i,findVar(*j))] = piet.strength(v_i, *j);
}
- wg[UEdge(i,findVar(*j))] = piet.strength(v_i, *j);
}
}
// find maximal spanning tree
ConstructRG( MaxSpanningTreePrims( wg ) );
- } else {
- assert( 0 == 1 );
- }
+ } else
+ DAI_THROW(INTERNAL_ERROR);
}
}
void TreeEP::ConstructRG( const DEdgeVec &tree ) {
vector<VarSet> Cliques;
for( size_t i = 0; i < tree.size(); i++ )
- Cliques.push_back( var(tree[i].n1) | var(tree[i].n2) );
+ Cliques.push_back( VarSet( var(tree[i].n1), var(tree[i].n2) ) );
// Construct a weighted graph (each edge is weighted with the cardinality
// of the intersection of the nodes, where the nodes are the elements of
for( size_t i = 0; i < Cliques.size(); i++ )
for( size_t j = i+1; j < Cliques.size(); j++ ) {
size_t w = (Cliques[i] & Cliques[j]).size();
- JuncGraph[UEdge(i,j)] = w;
+ if( w )
+ JuncGraph[UEdge(i,j)] = w;
}
// Construct maximal spanning tree using Prim's algorithm
- _RTree = MaxSpanningTreePrims( JuncGraph );
+ RTree = MaxSpanningTreePrims( JuncGraph );
// Construct corresponding region graph
RecomputeORs();
// Create inner regions and edges
- IRs.reserve( _RTree.size() );
+ IRs.reserve( RTree.size() );
vector<Edge> edges;
- edges.reserve( 2 * _RTree.size() );
- for( size_t i = 0; i < _RTree.size(); i++ ) {
- edges.push_back( Edge( _RTree[i].n1, IRs.size() ) );
- edges.push_back( Edge( _RTree[i].n2, IRs.size() ) );
+ edges.reserve( 2 * RTree.size() );
+ for( size_t i = 0; i < RTree.size(); i++ ) {
+ edges.push_back( Edge( RTree[i].n1, IRs.size() ) );
+ edges.push_back( Edge( RTree[i].n2, IRs.size() ) );
// inner clusters have counting number -1
- IRs.push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
+ IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) );
}
// create bipartite graph
- G.create( nrORs(), nrIRs(), edges.begin(), edges.end() );
+ G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
// Check counting numbers
Check_Counting_Numbers();
// Create messages and beliefs
- _Qa.clear();
- _Qa.reserve( nrORs() );
+ Qa.clear();
+ Qa.reserve( nrORs() );
for( size_t alpha = 0; alpha < nrORs(); alpha++ )
- _Qa.push_back( OR(alpha) );
+ Qa.push_back( OR(alpha) );
- _Qb.clear();
- _Qb.reserve( nrIRs() );
+ Qb.clear();
+ Qb.reserve( nrIRs() );
for( size_t beta = 0; beta < nrIRs(); beta++ )
- _Qb.push_back( Factor( IR(beta), 1.0 ) );
+ Qb.push_back( Factor( IR(beta), 1.0 ) );
// DIFF with JTree::GenerateJT: no messages
//subTree.resize( subTreeSize ); // FIXME
// cout << "subtree " << I << " has size " << subTreeSize << endl;
-/*
- char fn[30];
- sprintf( fn, "/tmp/subtree_%d.dot", I );
- std::ofstream dots(fn);
- dots << "graph G {" << endl;
- dots << "graph[size=\"9,9\"];" << endl;
- dots << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
- for( size_t i = 0; i < nrVars(); i++ )
- dots << "\tx" << var(i).label() << ((factor(I).vars() >> var(i)) ? "[color=blue];" : ";") << endl;
- dots << "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl;
- for( size_t J = 0; J < nrFactors(); J++ )
- dots << "\tp" << J << ";" << endl;
- for( size_t iI = 0; iI < FactorGraph::nr_edges(); iI++ )
- dots << "\tx" << var(FactorGraph::edge(iI).first).label() << " -- p" << FactorGraph::edge(iI).second << ";" << endl;
- for( size_t a = 0; a < tree.size(); a++ )
- dots << "\tx" << var(tree[a].n1).label() << " -- x" << var(tree[a].n2).label() << " [color=red];" << endl;
- dots << "}" << endl;
- dots.close();
-*/
-
- TreeEPSubTree QI( subTree, _RTree, _Qa, _Qb, &factor(I) );
+ TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) );
_Q[I] = QI;
}
// Previous root of first off-tree factor should be the root of the last off-tree factor
//subTree.resize( subTreeSize ); // FIXME
// cout << "subtree " << I << " has size " << subTreeSize << endl;
- TreeEPSubTree QI( subTree, _RTree, _Qa, _Qb, &factor(I) );
+ TreeEPSubTree QI( subTree, RTree, Qa, Qb, &factor(I) );
_Q[I] = QI;
break;
}
string TreeEP::identify() const {
- stringstream result (stringstream::out);
- result << Name << getProperties();
- return result.str();
+ return string(Name) + printProperties();
}
for( size_t i = 0; i < nrVars(); i++ )
old_beliefs.push_back(belief(var(i)));
- size_t iter = 0;
-
// do several passes over the network until maximum number of iterations has
// been reached or until the maximum belief difference is smaller than tolerance
- for( iter=0; iter < props.maxiter && diffs.maxDiff() > props.tol; iter++ ) {
+ for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; _iters++ ) {
for( size_t I = 0; I < nrFactors(); I++ )
if( offtree(I) ) {
- _Q[I].InvertAndMultiply( _Qa, _Qb );
- _Q[I].HUGIN_with_I( _Qa, _Qb );
- _Q[I].InvertAndMultiply( _Qa, _Qb );
+ _Q[I].InvertAndMultiply( Qa, Qb );
+ _Q[I].HUGIN_with_I( Qa, Qb );
+ _Q[I].InvertAndMultiply( Qa, Qb );
}
// calculate new beliefs and compare with old ones
}
if( props.verbose >= 3 )
- cout << "TreeEP::run: maxdiff " << diffs.maxDiff() << " after " << iter+1 << " passes" << endl;
+ cout << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
}
- if( diffs.maxDiff() > maxdiff )
- maxdiff = diffs.maxDiff();
+ if( diffs.maxDiff() > _maxdiff )
+ _maxdiff = diffs.maxDiff();
if( props.verbose >= 1 ) {
if( diffs.maxDiff() > props.tol ) {
if( props.verbose == 1 )
cout << endl;
- cout << "TreeEP::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
+ cout << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
} else {
if( props.verbose >= 3 )
- cout << "TreeEP::run: ";
- cout << "converged in " << iter << " passes (" << toc() - tic << " clocks)." << endl;
+ cout << Name << "::run: ";
+ cout << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
}
}
}
-Complex TreeEP::logZ() const {
+Real TreeEP::logZ() const {
double sum = 0.0;
// entropy of the tree
for( size_t beta = 0; beta < nrIRs(); beta++ )
- sum -= real(_Qb[beta].entropy());
+ sum -= Qb[beta].entropy();
for( size_t alpha = 0; alpha < nrORs(); alpha++ )
- sum += real(_Qa[alpha].entropy());
+ sum += Qa[alpha].entropy();
// energy of the on-tree factors
for( size_t alpha = 0; alpha < nrORs(); alpha++ )
- sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
+ sum += (OR(alpha).log0() * Qa[alpha]).totalSum();
// energy of the off-tree factors
for( size_t I = 0; I < nrFactors(); I++ )
if( offtree(I) )
- sum += (_Q.find(I))->second.logZ( _Qa, _Qb );
+ sum += (_Q.find(I))->second.logZ( Qa, Qb );
return sum;
}