-/* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
- Radboud University Nijmegen, The Netherlands
-
+/* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
+ Radboud University Nijmegen, The Netherlands /
+ Max Planck Institute for Biological Cybernetics, Germany
+
This file is part of libDAI.
libDAI is free software; you can redistribute it and/or modify
const char *JTree::Name = "JTREE";
-bool JTree::checkProperties() {
- if (!HasProperty("verbose") )
- return false;
- if( !HasProperty("updates") )
- return false;
-
- ConvertPropertyTo<size_t>("verbose");
- ConvertPropertyTo<UpdateType>("updates");
+void JTree::setProperties( const PropertySet &opts ) {
+ assert( opts.hasKey("verbose") );
+ assert( opts.hasKey("updates") );
- return true;
+ props.verbose = opts.getStringAs<size_t>("verbose");
+ props.updates = opts.getStringAs<Properties::UpdateType>("updates");
}
-JTree::JTree( const FactorGraph &fg, const Properties &opts, bool automatic) : DAIAlgRG(fg, opts), _RTree(), _Qa(), _Qb(), _mes(), _logZ() {
- assert( checkProperties() );
+PropertySet JTree::getProperties() const {
+ PropertySet opts;
+ opts.Set( "verbose", props.verbose );
+ opts.Set( "updates", props.updates );
+ return opts;
+}
+
+
+string JTree::printProperties() const {
+ stringstream s( stringstream::out );
+ s << "[";
+ s << "verbose=" << props.verbose << ",";
+ s << "updates=" << props.updates << "]";
+ return s.str();
+}
- if( automatic ) {
- ClusterGraph _cg;
- // Copy factors
+JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(fg), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {
+ setProperties( opts );
+
+ if( !isConnected() )
+ DAI_THROW(FACTORGRAPH_NOT_CONNECTED);
+
+ if( automatic ) {
+ // Create ClusterGraph which contains factors as clusters
+ vector<VarSet> cl;
+ cl.reserve( fg.nrFactors() );
for( size_t I = 0; I < nrFactors(); I++ )
- _cg.insert( factor(I).vars() );
- if( Verbose() >= 3 )
- cout << "Initial clusters: " << _cg << endl;
+ cl.push_back( factor(I).vars() );
+ ClusterGraph _cg( cl );
+
+ if( props.verbose >= 3 )
+ cerr << "Initial clusters: " << _cg << endl;
// Retain only maximal clusters
_cg.eraseNonMaximal();
- if( Verbose() >= 3 )
- cout << "Maximal clusters: " << _cg << endl;
+ if( props.verbose >= 3 )
+ cerr << "Maximal clusters: " << _cg << endl;
vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
- if( Verbose() >= 3 ) {
- cout << "VarElim_MinFill result: {" << endl;
- for( size_t i = 0; i < ElimVec.size(); i++ ) {
- if( i != 0 )
- cout << ", ";
- cout << ElimVec[i];
- }
- cout << "}" << endl;
- }
+ if( props.verbose >= 3 )
+ cerr << "VarElim_MinFill result: " << ElimVec << endl;
GenerateJT( ElimVec );
}
}
-void JTree::GenerateJT( const vector<VarSet> &Cliques ) {
- // Construct a weighted graph (each edge is weighted with the cardinality
+void JTree::GenerateJT( const std::vector<VarSet> &Cliques ) {
+ // Construct a weighted graph (each edge is weighted with the cardinality
// of the intersection of the nodes, where the nodes are the elements of
// Cliques).
WeightedGraph<int> JuncGraph;
for( size_t i = 0; i < Cliques.size(); i++ )
for( size_t j = i+1; j < Cliques.size(); j++ ) {
size_t w = (Cliques[i] & Cliques[j]).size();
- JuncGraph[UEdge(i,j)] = w;
+ if( w )
+ JuncGraph[UEdge(i,j)] = w;
}
-
+
// Construct maximal spanning tree using Prim's algorithm
- _RTree = MaxSpanningTreePrim( JuncGraph );
+ RTree = MaxSpanningTreePrims( JuncGraph );
// Construct corresponding region graph
// Create outer regions
- ORs().reserve( Cliques.size() );
+ ORs.reserve( Cliques.size() );
for( size_t i = 0; i < Cliques.size(); i++ )
- ORs().push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
+ ORs.push_back( FRegion( Factor(Cliques[i], 1.0), 1.0 ) );
// For each factor, find an outer region that subsumes that factor.
// Then, multiply the outer region with that factor.
for( size_t I = 0; I < nrFactors(); I++ ) {
size_t alpha;
- for( alpha = 0; alpha < nr_ORs(); alpha++ )
+ for( alpha = 0; alpha < nrORs(); alpha++ )
if( OR(alpha).vars() >> factor(I).vars() ) {
-// OR(alpha) *= factor(I);
- _fac2OR[I] = alpha;
+ fac2OR.push_back( alpha );
break;
}
- assert( alpha != nr_ORs() );
+ assert( alpha != nrORs() );
}
RecomputeORs();
// Create inner regions and edges
- IRs().reserve( _RTree.size() );
- Redges().reserve( 2 * _RTree.size() );
- for( size_t i = 0; i < _RTree.size(); i++ ) {
- Redges().push_back( R_edge_t( _RTree[i].n1, IRs().size() ) );
- Redges().push_back( R_edge_t( _RTree[i].n2, IRs().size() ) );
+ IRs.reserve( RTree.size() );
+ vector<Edge> edges;
+ edges.reserve( 2 * RTree.size() );
+ for( size_t i = 0; i < RTree.size(); i++ ) {
+ edges.push_back( Edge( RTree[i].n1, nrIRs() ) );
+ edges.push_back( Edge( RTree[i].n2, nrIRs() ) );
// inner clusters have counting number -1
- IRs().push_back( Region( Cliques[_RTree[i].n1] & Cliques[_RTree[i].n2], -1.0 ) );
+ IRs.push_back( Region( Cliques[RTree[i].n1] & Cliques[RTree[i].n2], -1.0 ) );
}
- // Regenerate BipartiteGraph internals
- Regenerate();
+ // create bipartite graph
+ G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );
// Create messages and beliefs
- _Qa.clear();
- _Qa.reserve( nr_ORs() );
- for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
- _Qa.push_back( OR(alpha) );
+ Qa.clear();
+ Qa.reserve( nrORs() );
+ for( size_t alpha = 0; alpha < nrORs(); alpha++ )
+ Qa.push_back( OR(alpha) );
- _Qb.clear();
- _Qb.reserve( nr_IRs() );
- for( size_t beta = 0; beta < nr_IRs(); beta++ )
- _Qb.push_back( Factor( IR(beta), 1.0 ) );
+ Qb.clear();
+ Qb.reserve( nrIRs() );
+ for( size_t beta = 0; beta < nrIRs(); beta++ )
+ Qb.push_back( Factor( IR(beta), 1.0 ) );
_mes.clear();
- _mes.reserve( nr_Redges() );
- for( size_t e = 0; e < nr_Redges(); e++ )
- _mes.push_back( Factor( IR(Redge(e).second), 1.0 ) );
+ _mes.reserve( nrORs() );
+ for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
+ _mes.push_back( vector<Factor>() );
+ _mes[alpha].reserve( nbOR(alpha).size() );
+ foreach( const Neighbor &beta, nbOR(alpha) )
+ _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
+ }
// Check counting numbers
Check_Counting_Numbers();
- if( Verbose() >= 3 ) {
- cout << "Resulting regiongraph: " << *this << endl;
+ if( props.verbose >= 3 ) {
+ cerr << "Resulting regiongraph: " << *this << endl;
}
}
string JTree::identify() const {
- stringstream result (stringstream::out);
- result << Name << GetProperties();
- return result.str();
+ return string(Name) + printProperties();
}
Factor JTree::belief( const VarSet &ns ) const {
vector<Factor>::const_iterator beta;
- for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
+ for( beta = Qb.begin(); beta != Qb.end(); beta++ )
if( beta->vars() >> ns )
break;
- if( beta != _Qb.end() )
+ if( beta != Qb.end() )
return( beta->marginal(ns) );
else {
vector<Factor>::const_iterator alpha;
- for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
+ for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
if( alpha->vars() >> ns )
break;
- assert( alpha != _Qa.end() );
+ assert( alpha != Qa.end() );
return( alpha->marginal(ns) );
}
}
vector<Factor> JTree::beliefs() const {
vector<Factor> result;
- for( size_t beta = 0; beta < nr_IRs(); beta++ )
- result.push_back( _Qb[beta] );
- for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
- result.push_back( _Qa[alpha] );
+ for( size_t beta = 0; beta < nrIRs(); beta++ )
+ result.push_back( Qb[beta] );
+ for( size_t alpha = 0; alpha < nrORs(); alpha++ )
+ result.push_back( Qa[alpha] );
return result;
}
// Needs no init
void JTree::runHUGIN() {
- for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
- _Qa[alpha] = OR(alpha);
+ for( size_t alpha = 0; alpha < nrORs(); alpha++ )
+ Qa[alpha] = OR(alpha);
- for( size_t beta = 0; beta < nr_IRs(); beta++ )
- _Qb[beta].fill( 1.0 );
+ for( size_t beta = 0; beta < nrIRs(); beta++ )
+ Qb[beta].fill( 1.0 );
// CollectEvidence
_logZ = 0.0;
- for( size_t i = _RTree.size(); (i--) != 0; ) {
-// Make outer region _RTree[i].n1 consistent with outer region _RTree[i].n2
-// IR(i) = seperator OR(_RTree[i].n1) && OR(_RTree[i].n2)
- Factor new_Qb = _Qa[_RTree[i].n2].part_sum( IR( i ) );
- _logZ += log(new_Qb.normalize( Prob::NORMPROB ));
- _Qa[_RTree[i].n1] *= new_Qb.divided_by( _Qb[i] );
- _Qb[i] = new_Qb;
+ for( size_t i = RTree.size(); (i--) != 0; ) {
+// Make outer region RTree[i].n1 consistent with outer region RTree[i].n2
+// IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
+ Factor new_Qb = Qa[RTree[i].n2].marginal( IR( i ), false );
+ _logZ += log(new_Qb.normalize());
+ Qa[RTree[i].n1] *= new_Qb / Qb[i];
+ Qb[i] = new_Qb;
}
- if( _RTree.empty() )
- _logZ += log(_Qa[0].normalize( Prob::NORMPROB ) );
+ if( RTree.empty() )
+ _logZ += log(Qa[0].normalize() );
else
- _logZ += log(_Qa[_RTree[0].n1].normalize( Prob::NORMPROB ));
+ _logZ += log(Qa[RTree[0].n1].normalize());
// DistributeEvidence
- for( size_t i = 0; i < _RTree.size(); i++ ) {
-// Make outer region _RTree[i].n2 consistent with outer region _RTree[i].n1
-// IR(i) = seperator OR(_RTree[i].n1) && OR(_RTree[i].n2)
- Factor new_Qb = _Qa[_RTree[i].n1].marginal( IR( i ) );
- _Qa[_RTree[i].n2] *= new_Qb.divided_by( _Qb[i] );
- _Qb[i] = new_Qb;
+ for( size_t i = 0; i < RTree.size(); i++ ) {
+// Make outer region RTree[i].n2 consistent with outer region RTree[i].n1
+// IR(i) = seperator OR(RTree[i].n1) && OR(RTree[i].n2)
+ Factor new_Qb = Qa[RTree[i].n1].marginal( IR( i ) );
+ Qa[RTree[i].n2] *= new_Qb / Qb[i];
+ Qb[i] = new_Qb;
}
// Normalize
- for( size_t alpha = 0; alpha < nr_ORs(); alpha++ )
- _Qa[alpha].normalize( Prob::NORMPROB );
+ for( size_t alpha = 0; alpha < nrORs(); alpha++ )
+ Qa[alpha].normalize();
}
void JTree::runShaferShenoy() {
// First pass
_logZ = 0.0;
- for( size_t e = _RTree.size(); (e--) != 0; ) {
- // send a message from _RTree[e].n2 to _RTree[e].n1
- // or, actually, from the seperator IR(e) to _RTree[e].n1
+ for( size_t e = nrIRs(); (e--) != 0; ) {
+ // send a message from RTree[e].n2 to RTree[e].n1
+ // or, actually, from the seperator IR(e) to RTree[e].n1
+
+ size_t i = nbIR(e)[1].node; // = RTree[e].n2
+ size_t j = nbIR(e)[0].node; // = RTree[e].n1
+ size_t _e = nbIR(e)[0].dual;
- size_t i = _RTree[e].n2;
- size_t j = _RTree[e].n1;
-
Factor piet = OR(i);
- for( R_nb_cit k = nbOR(i).begin(); k != nbOR(i).end(); k++ )
- if( *k != e )
- piet *= message( i, *k );
- message( j, e ) = piet.part_sum( IR(e) );
- _logZ += log( message(j,e).normalize( Prob::NORMPROB ) );
+ foreach( const Neighbor &k, nbOR(i) )
+ if( k != e )
+ piet *= message( i, k.iter );
+ message( j, _e ) = piet.marginal( IR(e), false );
+ _logZ += log( message(j,_e).normalize() );
}
// Second pass
- for( size_t e = 0; e < _RTree.size(); e++ ) {
- size_t i = _RTree[e].n1;
- size_t j = _RTree[e].n2;
-
+ for( size_t e = 0; e < nrIRs(); e++ ) {
+ size_t i = nbIR(e)[0].node; // = RTree[e].n1
+ size_t j = nbIR(e)[1].node; // = RTree[e].n2
+ size_t _e = nbIR(e)[1].dual;
+
Factor piet = OR(i);
- for( R_nb_cit k = nbOR(i).begin(); k != nbOR(i).end(); k++ )
- if( *k != e )
- piet *= message( i, *k );
- message( j, e ) = piet.marginal( IR(e) );
+ foreach( const Neighbor &k, nbOR(i) )
+ if( k != e )
+ piet *= message( i, k.iter );
+ message( j, _e ) = piet.marginal( IR(e) );
}
// Calculate beliefs
- for( size_t alpha = 0; alpha < nr_ORs(); alpha++ ) {
+ for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
Factor piet = OR(alpha);
- for( R_nb_cit k = nbOR(alpha).begin(); k != nbOR(alpha).end(); k++ )
- piet *= message( alpha, *k );
- if( _RTree.empty() ) {
- _logZ += log( piet.normalize( Prob::NORMPROB ) );
- _Qa[alpha] = piet;
- } else if( alpha == _RTree[0].n1 ) {
- _logZ += log( piet.normalize( Prob::NORMPROB ) );
- _Qa[alpha] = piet;
+ foreach( const Neighbor &k, nbOR(alpha) )
+ piet *= message( alpha, k.iter );
+ if( nrIRs() == 0 ) {
+ _logZ += log( piet.normalize() );
+ Qa[alpha] = piet;
+ } else if( alpha == nbIR(0)[0].node /*RTree[0].n1*/ ) {
+ _logZ += log( piet.normalize() );
+ Qa[alpha] = piet;
} else
- _Qa[alpha] = piet.normalized( Prob::NORMPROB );
+ Qa[alpha] = piet.normalized();
}
// Only for logZ (and for belief)...
- for( size_t beta = 0; beta < nr_IRs(); beta++ )
- _Qb[beta] = _Qa[nbIR(beta)[0]].marginal( IR(beta) );
+ for( size_t beta = 0; beta < nrIRs(); beta++ )
+ Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
}
double JTree::run() {
- if( Updates() == UpdateType::HUGIN )
+ if( props.updates == Properties::UpdateType::HUGIN )
runHUGIN();
- else if( Updates() == UpdateType::SHSH )
+ else if( props.updates == Properties::UpdateType::SHSH )
runShaferShenoy();
return 0.0;
}
-Complex JTree::logZ() const {
- Complex sum = 0.0;
- for( size_t beta = 0; beta < nr_IRs(); beta++ )
- sum += Complex(IR(beta).c()) * _Qb[beta].entropy();
- for( size_t alpha = 0; alpha < nr_ORs(); alpha++ ) {
- sum += Complex(OR(alpha).c()) * _Qa[alpha].entropy();
- sum += (OR(alpha).log0() * _Qa[alpha]).totalSum();
+Real JTree::logZ() const {
+ Real s = 0.0;
+ for( size_t beta = 0; beta < nrIRs(); beta++ )
+ s += IR(beta).c() * Qb[beta].entropy();
+ for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
+ s += OR(alpha).c() * Qa[alpha].entropy();
+ s += (OR(alpha).log(true) * Qa[alpha]).sum();
}
- return sum;
+ return s;
}
size_t JTree::findEfficientTree( const VarSet& ns, DEdgeVec &Tree, size_t PreviousRoot ) const {
// find new root clique (the one with maximal statespace overlap with ns)
size_t maxval = 0, maxalpha = 0;
- for( size_t alpha = 0; alpha < nr_ORs(); alpha++ ) {
- size_t val = (ns & OR(alpha).vars()).stateSpace();
+ for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
+ size_t val = VarSet(ns & OR(alpha).vars()).nrStates();
if( val > maxval ) {
maxval = val;
maxalpha = alpha;
}
}
-// for( size_t e = 0; e < _RTree.size(); e++ )
-// cout << OR(_RTree[e].n1).vars() << "->" << OR(_RTree[e].n2).vars() << ", ";
-// cout << endl;
// grow new tree
Graph oldTree;
- for( DEdgeVec::const_iterator e = _RTree.begin(); e != _RTree.end(); e++ )
+ for( DEdgeVec::const_iterator e = RTree.begin(); e != RTree.end(); e++ )
oldTree.insert( UEdge(e->n1, e->n2) );
DEdgeVec newTree = GrowRootedTree( oldTree, maxalpha );
-// cout << ns << ": ";
-// for( size_t e = 0; e < newTree.size(); e++ )
-// cout << OR(newTree[e].n1).vars() << "->" << OR(newTree[e].n2).vars() << ", ";
-// cout << endl;
-
+
// identify subtree that contains variables of ns which are not in the new root
VarSet nsrem = ns / OR(maxalpha).vars();
-// cout << "nsrem:" << nsrem << endl;
set<DEdge> subTree;
// for each variable in ns that is not in the root clique
for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ ) {
// find first occurence of *n in the tree, which is closest to the root
size_t e = 0;
for( ; e != newTree.size(); e++ ) {
- if( OR(newTree[e].n2).vars() && *n )
+ if( OR(newTree[e].n2).vars().contains( *n ) )
break;
}
assert( e != newTree.size() );
pos = newTree[e-1].n1;
}
}
-// cout << "subTree: " << endl;
-// for( set<DEdge>::const_iterator sTi = subTree.begin(); sTi != subTree.end(); sTi++ )
-// cout << OR(sTi->n1).vars() << "->" << OR(sTi->n2).vars() << ", ";
-// cout << endl;
// Resulting Tree is a reordered copy of newTree
// First add edges in subTree to Tree
for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
if( subTree.count( *e ) ) {
Tree.push_back( *e );
-// cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
}
-// cout << endl;
// Then add edges pointing away from nsrem
// FIXME
/* for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
if( e->n1 == sTi->n1 || e->n1 == sTi->n2 ||
e->n2 == sTi->n1 || e->n2 == sTi->n2 ) {
Tree.push_back( *e );
-// cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
}
}*/
// FIXME
}
if( found ) {
Tree.push_back( *e );
- cout << OR(e->n1).vars() << "->" << OR(e->n2).vars() << ", ";
}
- }
- cout << endl;*/
+ }*/
size_t subTreeSize = Tree.size();
// Then add remaining edges
for( DEdgeVec::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
// assumes that run() has been called already
Factor JTree::calcMarginal( const VarSet& ns ) {
vector<Factor>::const_iterator beta;
- for( beta = _Qb.begin(); beta != _Qb.end(); beta++ )
+ for( beta = Qb.begin(); beta != Qb.end(); beta++ )
if( beta->vars() >> ns )
break;
- if( beta != _Qb.end() )
+ if( beta != Qb.end() )
return( beta->marginal(ns) );
else {
vector<Factor>::const_iterator alpha;
- for( alpha = _Qa.begin(); alpha != _Qa.end(); alpha++ )
+ for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
if( alpha->vars() >> ns )
break;
- if( alpha != _Qa.end() )
+ if( alpha != Qa.end() )
return( alpha->marginal(ns) );
else {
// Find subtree to do efficient inference
// Find remaining variables (which are not in the new root)
VarSet nsrem = ns / OR(T.front().n1).vars();
Factor Pns (ns, 0.0);
-
- multind mi( nsrem );
- // Save _Qa and _Qb on the subtree
- map<size_t,Factor> _Qa_old;
- map<size_t,Factor> _Qb_old;
+ // Save Qa and Qb on the subtree
+ map<size_t,Factor> Qa_old;
+ map<size_t,Factor> Qb_old;
vector<size_t> b(Tsize, 0);
for( size_t i = Tsize; (i--) != 0; ) {
size_t alpha1 = T[i].n1;
size_t alpha2 = T[i].n2;
size_t beta;
- for( beta = 0; beta < nr_IRs(); beta++ )
- if( UEdge( _RTree[beta].n1, _RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
+ for( beta = 0; beta < nrIRs(); beta++ )
+ if( UEdge( RTree[beta].n1, RTree[beta].n2 ) == UEdge( alpha1, alpha2 ) )
break;
- assert( beta != nr_IRs() );
+ assert( beta != nrIRs() );
b[i] = beta;
- if( !_Qa_old.count( alpha1 ) )
- _Qa_old[alpha1] = _Qa[alpha1];
- if( !_Qa_old.count( alpha2 ) )
- _Qa_old[alpha2] = _Qa[alpha2];
- if( !_Qb_old.count( beta ) )
- _Qb_old[beta] = _Qb[beta];
+ if( !Qa_old.count( alpha1 ) )
+ Qa_old[alpha1] = Qa[alpha1];
+ if( !Qa_old.count( alpha2 ) )
+ Qa_old[alpha2] = Qa[alpha2];
+ if( !Qb_old.count( beta ) )
+ Qb_old[beta] = Qb[beta];
}
-
+
// For all states of nsrem
- for( size_t j = 0; j < mi.max(); j++ ) {
- vector<size_t> vi = mi.vi( j );
-
+ for( State s(nsrem); s.valid(); s++ ) {
// CollectEvidence
double logZ = 0.0;
for( size_t i = Tsize; (i--) != 0; ) {
- // Make outer region T[i].n1 consistent with outer region T[i].n2
- // IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)
+ // Make outer region T[i].n1 consistent with outer region T[i].n2
+ // IR(i) = seperator OR(T[i].n1) && OR(T[i].n2)
- size_t k = 0;
- for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++, k++ )
- if( _Qa[T[i].n2].vars() >> *n ) {
+ for( VarSet::const_iterator n = nsrem.begin(); n != nsrem.end(); n++ )
+ if( Qa[T[i].n2].vars() >> *n ) {
Factor piet( *n, 0.0 );
- piet[vi[k]] = 1.0;
- _Qa[T[i].n2] *= piet;
+ piet[s(*n)] = 1.0;
+ Qa[T[i].n2] *= piet;
}
- Factor new_Qb = _Qa[T[i].n2].part_sum( IR( b[i] ) );
- logZ += log(new_Qb.normalize( Prob::NORMPROB ));
- _Qa[T[i].n1] *= new_Qb.divided_by( _Qb[b[i]] );
- _Qb[b[i]] = new_Qb;
+ Factor new_Qb = Qa[T[i].n2].marginal( IR( b[i] ), false );
+ logZ += log(new_Qb.normalize());
+ Qa[T[i].n1] *= new_Qb / Qb[b[i]];
+ Qb[b[i]] = new_Qb;
}
- logZ += log(_Qa[T[0].n1].normalize( Prob::NORMPROB ));
+ logZ += log(Qa[T[0].n1].normalize());
Factor piet( nsrem, 0.0 );
- piet[j] = exp(logZ);
- Pns += piet * _Qa[T[0].n1].part_sum( ns / nsrem ); // OPTIMIZE ME
+ piet[s] = exp(logZ);
+ Pns += piet * Qa[T[0].n1].marginal( ns / nsrem, false ); // OPTIMIZE ME
// Restore clamped beliefs
- for( map<size_t,Factor>::const_iterator alpha = _Qa_old.begin(); alpha != _Qa_old.end(); alpha++ )
- _Qa[alpha->first] = alpha->second;
- for( map<size_t,Factor>::const_iterator beta = _Qb_old.begin(); beta != _Qb_old.end(); beta++ )
- _Qb[beta->first] = beta->second;
+ for( map<size_t,Factor>::const_iterator alpha = Qa_old.begin(); alpha != Qa_old.end(); alpha++ )
+ Qa[alpha->first] = alpha->second;
+ for( map<size_t,Factor>::const_iterator beta = Qb_old.begin(); beta != Qb_old.end(); beta++ )
+ Qb[beta->first] = beta->second;
}
- return( Pns.normalized(Prob::NORMPROB) );
+ return( Pns.normalized() );
}
}
}
+/// Calculates upper bound to the treewidth of a FactorGraph
+/** \relates JTree
+ * \return a pair (number of variables in largest clique, number of states in largest clique)
+ */
+std::pair<size_t,size_t> treewidth( const FactorGraph & fg ) {
+ ClusterGraph _cg;
+
+ // Copy factors
+ for( size_t I = 0; I < fg.nrFactors(); I++ )
+ _cg.insert( fg.factor(I).vars() );
+
+ // Retain only maximal clusters
+ _cg.eraseNonMaximal();
+
+ // Obtain elimination sequence
+ vector<VarSet> ElimVec = _cg.VarElim_MinFill().eraseNonMaximal().toVector();
+
+ // Calculate treewidth
+ size_t treewidth = 0;
+ size_t nrstates = 0;
+ for( size_t i = 0; i < ElimVec.size(); i++ ) {
+ if( ElimVec[i].size() > treewidth )
+ treewidth = ElimVec[i].size();
+ size_t s = ElimVec[i].nrStates();
+ if( s > nrstates )
+ nrstates = s;
+ }
+
+ return pair<size_t,size_t>(treewidth, nrstates);
+}
+
+
} // end of namespace dai