/*  This file is part of libDAI - http://www.libdai.org/
 *
 *  libDAI is licensed under the terms of the GNU General Public License version
 *  2, or (at your option) any later version. libDAI is distributed without any
 *  warranty. See the file COPYING for more details.
 *
 *  Copyright (C) 2006-2009  Joris Mooij  [joris dot mooij at libdai dot org]
 *  Copyright (C) 2006-2007  Radboud University Nijmegen, The Netherlands
 */
15 #include <dai/jtree.h>
16 #include <dai/treeep.h>
26 const char *TreeEP::Name
= "TREEEP";
29 void TreeEP::setProperties( const PropertySet
&opts
) {
30 DAI_ASSERT( opts
.hasKey("tol") );
31 DAI_ASSERT( opts
.hasKey("maxiter") );
32 DAI_ASSERT( opts
.hasKey("verbose") );
33 DAI_ASSERT( opts
.hasKey("type") );
35 props
.tol
= opts
.getStringAs
<Real
>("tol");
36 props
.maxiter
= opts
.getStringAs
<size_t>("maxiter");
37 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
38 props
.type
= opts
.getStringAs
<Properties::TypeType
>("type");
// Exports the current settings: stores props.tol, props.maxiter,
// props.verbose and props.type into `opts` under the keys "tol", "maxiter",
// "verbose" and "type" (the same keys that setProperties() requires).
// NOTE(review): the embedded listing numbers skip 43 and 48-51, so the
// declaration of `opts`, the return statement and the closing brace are not
// visible here -- presumably a local PropertySet that is returned; confirm
// against the complete file.
42 PropertySet
TreeEP::getProperties() const {
44 opts
.Set( "tol", props
.tol
);
45 opts
.Set( "maxiter", props
.maxiter
);
46 opts
.Set( "verbose", props
.verbose
);
47 opts
.Set( "type", props
.type
);
// Formats the current settings as text: writes "tol=", "maxiter=",
// "verbose=" and "type=" followed by the respective props value into an
// output-only stringstream, separating entries with "," and ending with "]".
// NOTE(review): the embedded listing numbers skip 54 and 59-62, so whatever
// is written before "tol=" (presumably an opening "[") and the statement
// that returns the string are not visible here -- confirm against the
// complete file.
52 string
TreeEP::printProperties() const {
53 stringstream
s( stringstream::out
);
55 s
<< "tol=" << props
.tol
<< ",";
56 s
<< "maxiter=" << props
.maxiter
<< ",";
57 s
<< "verbose=" << props
.verbose
<< ",";
58 s
<< "type=" << props
.type
<< "]";
// Builds the subtree-local data structures for one factor.
//
//   subRTree : edges (n1, n2) of the subtree, expressed in "old" indices,
//              i.e. indices into jt_Qa
//   jt_RTree : edge list that is searched for the undirected edge matching
//              each subtree edge; the position found (beta) indexes jt_Qb
//   jt_Qa    : factors indexed by old index; only their vars() are read
//   jt_Qb    : factors indexed by beta; only their vars() are read
//   I        : pointer stored in _I
//
// For every subtree edge, both endpoints are translated to local indices
// through the table _a: an endpoint that is not yet in _a gets a new Factor
// over the same variables (constructed with value 1.0) appended to _Qa and
// its old index appended to _a. The re-indexed edge is appended to _RTree and
// a Factor over the variables of jt_Qb[beta] is appended to _Qb.
// Finally _nsrem is set to _ns / _Qa[0].vars().
//
// NOTE(review): the embedded listing numbers skip 64-65, 76, 78, 83-84,
// 89-90 and 93-95. As a result the statement controlled by the `if` on
// listing line 75 (presumably a `break`), any assignment to _ns, any
// push_back into _b, and all closing braces are not visible in this text --
// confirm against the complete file before relying on this description.
63 TreeEP::TreeEPSubTree::TreeEPSubTree( const DEdgeVec
&subRTree
, const DEdgeVec
&jt_RTree
, const std::vector
<Factor
> &jt_Qa
, const std::vector
<Factor
> &jt_Qb
, const Factor
*I
) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I
), _ns(), _nsrem(), _logZ(0.0) {
66 // Make _Qa, _Qb, _a and _b corresponding to the subtree
67 _b
.reserve( subRTree
.size() );
68 _Qb
.reserve( subRTree
.size() );
69 _RTree
.reserve( subRTree
.size() );
70 for( size_t i
= 0; i
< subRTree
.size(); i
++ ) {
71 size_t alpha1
= subRTree
[i
].n1
; // old index 1
72 size_t alpha2
= subRTree
[i
].n2
; // old index 2
73 size_t beta
; // old sep index
// Linear search for the edge of jt_RTree that joins the same two nodes,
// irrespective of direction (comparison is done on UEdge).
74 for( beta
= 0; beta
< jt_RTree
.size(); beta
++ )
75 if( UEdge( jt_RTree
[beta
].n1
, jt_RTree
[beta
].n2
) == UEdge( alpha1
, alpha2
) )
// beta == jt_RTree.size() would mean that no matching edge was found.
77 DAI_ASSERT( beta
!= jt_RTree
.size() );
// find() returns _a.end() when alpha1 is new, so newalpha1 == _a.size()
// signals "not present yet" and is also the index the new entry will get.
79 size_t newalpha1
= find(_a
.begin(), _a
.end(), alpha1
) - _a
.begin();
80 if( newalpha1
== _a
.size() ) {
81 _Qa
.push_back( Factor( jt_Qa
[alpha1
].vars(), 1.0 ) );
82 _a
.push_back( alpha1
); // save old index in index conversion table
// Same lookup-or-append step for the second endpoint.
85 size_t newalpha2
= find(_a
.begin(), _a
.end(), alpha2
) - _a
.begin();
86 if( newalpha2
== _a
.size() ) {
87 _Qa
.push_back( Factor( jt_Qa
[alpha2
].vars(), 1.0 ) );
88 _a
.push_back( alpha2
); // save old index in index conversion table
91 _RTree
.push_back( DEdge( newalpha1
, newalpha2
) );
92 _Qb
.push_back( Factor( jt_Qb
[beta
].vars(), 1.0 ) );
96 // Find remaining variables (which are not in the new root)
97 _nsrem
= _ns
/ _Qa
[0].vars();
101 void TreeEP::TreeEPSubTree::init() {
102 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
103 _Qa
[alpha
].fill( 1.0 );
104 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
105 _Qb
[beta
].fill( 1.0 );
109 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector
<Factor
> &Qa
, const std::vector
<Factor
> &Qb
) {
110 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
111 _Qa
[alpha
] = Qa
[_a
[alpha
]] / _Qa
[alpha
];
113 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
114 _Qb
[beta
] = Qb
[_b
[beta
]] / _Qb
[beta
];
// Runs the subtree-local update that involves the factor *_I and writes the
// result into the entries of Qa and Qb addressed by _a and _b.
//
// Steps that are visible in this text:
//   1. _Qa and _Qb are copied into _Qa_old and _Qb_old.
//   2. The addressed entries of Qa and Qb are filled with 0.0.
//   3. A loop runs over every state s of the variable set _nsrem. In it,
//      _Qa[0] is multiplied by _I->slice( _nsrem, s ); the edges of _RTree
//      are then visited in reverse order (child n2 -> parent n1) and in
//      forward order (parent n1 -> child n2), each time multiplying the
//      receiving factor by new_Qb / _Qb[i].
//   4. The p() of every local factor is added to the p() of the matching
//      Qa / Qb entry.
//   5. _logZ is increased by log(sum) of each addressed Qa entry and
//      decreased by log(sum) of each addressed Qb entry, and those entries
//      are normalized.
//
// NOTE(review): the embedded listing numbers skip 122-123, 128, 133-134, 140,
// 142, 145-147, 152-154, 160, 162-165, 167, 171 and 175-178. Not visible,
// therefore: every closing brace (so the exact nesting of steps 3-5 cannot be
// read off this text), whatever sets an entry of `delta` between its
// construction and its use, any update of _Qb[i] after new_Qb is computed,
// the statements announced by "Restore _Qa and _Qb", and any reset of _logZ
// before it is accumulated. Confirm all of these against the complete file.
118 void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector
<Factor
> &Qa
, std::vector
<Factor
> &Qb
) {
119 // Backup _Qa and _Qb
120 vector
<Factor
> _Qa_old(_Qa
);
121 vector
<Factor
> _Qb_old(_Qb
);
// Zero the output entries; contributions are added to them further below.
124 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
125 Qa
[_a
[alpha
]].fill( 0.0 );
126 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
127 Qb
[_b
[beta
]].fill( 0.0 );
129 // For all states of _nsrem
130 for( State
s(_nsrem
); s
.valid(); s
++ ) {
131 // Multiply root with slice of I
132 _Qa
[0] *= _I
->slice( _nsrem
, s
);
// Reverse sweep over the edges: i starts at _RTree.size()-1 and ends at 0
// (the post-decrement in the condition keeps the unsigned index from wrapping).
135 for( size_t i
= _RTree
.size(); (i
--) != 0; ) {
136 // clamp variables in nsrem
137 for( VarSet::const_iterator n
= _nsrem
.begin(); n
!= _nsrem
.end(); n
++ )
138 if( _Qa
[_RTree
[i
].n2
].vars() >> *n
) {
// delta starts as an all-0.0 factor over the single variable *n; the line
// that would set the entry for the current state is not visible (listing
// line 140 is absent) -- TODO confirm.
139 Factor
delta( *n
, 0.0 );
141 _Qa
[_RTree
[i
].n2
] *= delta
;
143 Factor new_Qb
= _Qa
[_RTree
[i
].n2
].marginal( _Qb
[i
].vars(), false );
144 _Qa
[_RTree
[i
].n1
] *= new_Qb
/ _Qb
[i
];
148 // DistributeEvidence
149 for( size_t i
= 0; i
< _RTree
.size(); i
++ ) {
150 Factor new_Qb
= _Qa
[_RTree
[i
].n1
].marginal( _Qb
[i
].vars(), false );
151 _Qa
[_RTree
[i
].n2
] *= new_Qb
/ _Qb
[i
];
155 // Store Qa's and Qb's
156 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
157 Qa
[_a
[alpha
]].p() += _Qa
[alpha
].p();
158 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
159 Qb
[_b
[beta
]].p() += _Qb
[beta
].p();
161 // Restore _Qa and _Qb
166 // Normalize Qa and Qb
// The sum is taken before normalize(), i.e. on the unnormalized entry.
168 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ ) {
169 _logZ
+= log(Qa
[_a
[alpha
]].sum());
170 Qa
[_a
[alpha
]].normalize();
172 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ ) {
173 _logZ
-= log(Qb
[_b
[beta
]].sum());
174 Qb
[_b
[beta
]].normalize();
// Accumulates into `s` the contribution of this subtree given the global
// factors Qa and Qb:
//   + sum over alpha of ( Qa[_a[alpha]] * _Qa[alpha].log(true) ).sum()
//   - sum over beta  of ( Qb[_b[beta]]  * _Qb[beta].log(true)  ).sum()
// NOTE(review): the embedded listing numbers skip 180 and 185-188, so the
// declaration and initial value of `s`, the return statement and the closing
// brace are not visible here; the meaning of the `true` argument of log() is
// defined elsewhere -- confirm against the complete file.
179 Real
TreeEP::TreeEPSubTree::logZ( const std::vector
<Factor
> &Qa
, const std::vector
<Factor
> &Qb
) const {
181 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
182 s
+= (Qa
[_a
[alpha
]] * _Qa
[alpha
].log(true)).sum();
183 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
184 s
-= (Qb
[_b
[beta
]] * _Qb
[beta
].log(true)).sum();
// Constructs a TreeEP object for the factor graph fg.
//
// The JTree base is initialised with fg, an "updates" option that is always
// "HUGIN", and `false` as third argument. The settings are then read with
// setProperties( opts ) and the factor graph is asserted to be connected.
//
// Choice of the tree:
//   * if opts has a key "tree", its DEdgeVec value is handed to ConstructRG();
//   * if props.type is ORG or ALT, a weighted graph wg over pairs of
//     variables is filled and ConstructRG() is called with
//     MaxSpanningTreePrims( wg );
//   * DAI_THROW(UNKNOWN_ENUM_VALUE) is also present, presumably for any
//     other props.type -- TODO confirm.
//
// Edge weights that are visible: for ORG either dist( piet, pietf, DISTKL )
// where pietf is the product of the two single-variable marginals of piet,
// or 0; otherwise piet.strength( v_i, *j ).
//
// NOTE(review): the embedded listing numbers skip 191, 193, 196, 202, 205,
// 208-210, 215, 218-222, 228, 230, 232-235, 238 and 240-243. The
// declarations of v_i, ij and piet, the statement controlled by the `if` on
// listing line 214, the non-ORG branch of the factor loop, all `else`
// keywords and all closing braces are therefore not visible, and the exact
// branch structure cannot be read off this text. Confirm against the
// complete file.
189 TreeEP::TreeEP( const FactorGraph
&fg
, const PropertySet
&opts
) : JTree(fg
, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
190 setProperties( opts
);
192 DAI_ASSERT( fg
.isConnected() );
194 if( opts
.hasKey("tree") ) {
195 ConstructRG( opts
.GetAs
<DEdgeVec
>("tree") );
197 if( props
.type
== Properties::TypeType::ORG
|| props
.type
== Properties::TypeType::ALT
) {
198 // ORG: construct weighted graph with as weights a crude estimate of the
199 // mutual information between the nodes
200 // ALT: construct weighted graph with as weights an upper bound on the
201 // effective interaction strength between pairs of nodes
203 WeightedGraph
<Real
> wg
;
204 for( size_t i
= 0; i
< nrVars(); ++i
) {
// di = delta(i); the inner loop visits every variable *j in that set.
206 VarSet di
= delta(i
);
207 for( VarSet::const_iterator j
= di
.begin(); j
!= di
.end(); j
++ )
// Scan all factors and combine the relevant ones into piet.
211 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
212 VarSet Ivars
= factor(I
).vars();
213 if( props
.type
== Properties::TypeType::ORG
) {
214 if( (Ivars
== v_i
) || (Ivars
== *j
) )
216 else if( Ivars
>> ij
)
217 piet
*= factor(I
).marginal( ij
);
223 if( props
.type
== Properties::TypeType::ORG
) {
224 if( piet
.vars() >> ij
) {
225 piet
= piet
.marginal( ij
);
226 Factor pietf
= piet
.marginal(v_i
) * piet
.marginal(*j
);
227 wg
[UEdge(i
,findVar(*j
))] = dist( piet
, pietf
, Prob::DISTKL
);
229 wg
[UEdge(i
,findVar(*j
))] = 0;
231 wg
[UEdge(i
,findVar(*j
))] = piet
.strength(v_i
, *j
);
236 // find maximal spanning tree
237 ConstructRG( MaxSpanningTreePrims( wg
) );
239 DAI_THROW(UNKNOWN_ENUM_VALUE
);
// Builds the region graph and the per-factor subtree approximations from a
// tree over the variables.
//
//   tree : directed edges (n1, n2) between variable indices; every edge
//          becomes one two-variable clique { var(n1), var(n2) }.
//
// Visible steps:
//   1. Cliques: one VarSet per edge of `tree`.
//   2. JuncGraph: for every pair of cliques i < j, the edge (i, j) is
//      weighted with the size of their intersection; RTree is the result of
//      MaxSpanningTreePrims on that graph.
//   3. ORs: one FRegion per clique, built from Factor( clique, 1.0 ) with
//      second argument 1.0.
//   4. fac2OR is resized to nrFactors() entries initialised to -1U, and for
//      every factor the outer regions are scanned for one whose variables
//      contain the variables of the factor.
//   5. IRs / edges: for every edge of RTree, two Edge entries connect its
//      endpoints to the new inner region, which is the intersection of the
//      two cliques with second argument -1.0.
//   6. G.construct() builds the bipartite graph and checkCountingNumbers()
//      is called.
//   7. Qa gets a copy of every outer region, Qb a Factor( IR(beta), 1.0 )
//      for every inner region.
//   8. Two passes over all factors call findEfficientTree() with the
//      running PreviousRoot, update PreviousRoot to subTree[0].n1 and build
//      a TreeEPSubTree QI from the resulting subtree.
//   9. With props.verbose >= 3 the object is printed to cerr.
//
// NOTE(review): the embedded listing numbers skip 248, 251, 256, 258-259,
// 262, 264, 269, 274, 277, 280-282, 284-286, 289, 296-297, 300, 303, 305,
// 309-310, 314, 316, 319, 322, 324, 329, 331-332, 335-336, 341, 343-346 and
// 349-352. Not visible, therefore: the end of the comment on listing lines
// 249-250, any guard around the JuncGraph assignment, the declaration of
// `alpha`, the body of the `if` on listing line 279, the declarations of
// `edges` and `subTree`, the conditions that select which factors take part
// in the two factor loops, what is done with QI after it is constructed
// (presumably it is stored in _Q), and all closing braces. Confirm against
// the complete file.
244 void TreeEP::ConstructRG( const DEdgeVec
&tree
) {
245 vector
<VarSet
> Cliques
;
246 for( size_t i
= 0; i
< tree
.size(); i
++ )
247 Cliques
.push_back( VarSet( var(tree
[i
].n1
), var(tree
[i
].n2
) ) );
249 // Construct a weighted graph (each edge is weighted with the cardinality
250 // of the intersection of the nodes, where the nodes are the elements of
252 WeightedGraph
<int> JuncGraph
;
253 for( size_t i
= 0; i
< Cliques
.size(); i
++ )
254 for( size_t j
= i
+1; j
< Cliques
.size(); j
++ ) {
255 size_t w
= (Cliques
[i
] & Cliques
[j
]).size();
257 JuncGraph
[UEdge(i
,j
)] = w
;
260 // Construct maximal spanning tree using Prim's algorithm
261 RTree
= MaxSpanningTreePrims( JuncGraph
);
263 // Construct corresponding region graph
265 // Create outer regions
266 ORs
.reserve( Cliques
.size() );
267 for( size_t i
= 0; i
< Cliques
.size(); i
++ )
268 ORs
.push_back( FRegion( Factor(Cliques
[i
], 1.0), 1.0 ) );
270 // For each factor, find an outer region that subsumes that factor.
271 // Then, multiply the outer region with that factor.
272 // If no outer region can be found subsuming that factor, label the
273 // factor as off-tree.
// -1U is the initial value of every fac2OR entry.
275 fac2OR
.resize( nrFactors(), -1U );
276 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
278 for( alpha
= 0; alpha
< nrORs(); alpha
++ )
279 if( OR(alpha
).vars() >> factor(I
).vars() ) {
283 // DIFF WITH JTree::GenerateJT: assert
287 // Create inner regions and edges
288 IRs
.reserve( RTree
.size() );
290 edges
.reserve( 2 * RTree
.size() );
291 for( size_t i
= 0; i
< RTree
.size(); i
++ ) {
// IRs.size() is read before the push_back below, so both Edge entries refer
// to the inner region that is about to be appended.
292 edges
.push_back( Edge( RTree
[i
].n1
, IRs
.size() ) );
293 edges
.push_back( Edge( RTree
[i
].n2
, IRs
.size() ) );
294 // inner clusters have counting number -1
295 IRs
.push_back( Region( Cliques
[RTree
[i
].n1
] & Cliques
[RTree
[i
].n2
], -1.0 ) );
298 // create bipartite graph
299 G
.construct( nrORs(), nrIRs(), edges
.begin(), edges
.end() );
301 // Check counting numbers
302 checkCountingNumbers();
304 // Create messages and beliefs
306 Qa
.reserve( nrORs() );
307 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
308 Qa
.push_back( OR(alpha
) );
311 Qb
.reserve( nrIRs() );
312 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
313 Qb
.push_back( Factor( IR(beta
), 1.0 ) );
315 // DIFF with JTree::GenerateJT: no messages
317 // DIFF with JTree::GenerateJT:
318 // Create factor approximations
// (size_t)-1 is the starting value of PreviousRoot for the first pass.
320 size_t PreviousRoot
= (size_t)-1;
321 for( size_t I
= 0; I
< nrFactors(); I
++ )
323 // find efficient subtree
325 /*size_t subTreeSize =*/ findEfficientTree( factor(I
).vars(), subTree
, PreviousRoot
);
326 PreviousRoot
= subTree
[0].n1
;
327 //subTree.resize( subTreeSize ); // FIXME
328 // cerr << "subtree " << I << " has size " << subTreeSize << endl;
330 TreeEPSubTree
QI( subTree
, RTree
, Qa
, Qb
, &factor(I
) );
333 // Previous root of first off-tree factor should be the root of the last off-tree factor
// Second pass: same steps as the first, but PreviousRoot now carries the
// value left behind by the first pass.
334 for( size_t I
= 0; I
< nrFactors(); I
++ )
337 /*size_t subTreeSize =*/ findEfficientTree( factor(I
).vars(), subTree
, PreviousRoot
);
338 PreviousRoot
= subTree
[0].n1
;
339 //subTree.resize( subTreeSize ); // FIXME
340 // cerr << "subtree " << I << " has size " << subTreeSize << endl;
342 TreeEPSubTree
QI( subTree
, RTree
, Qa
, Qb
, &factor(I
) );
347 if( props
.verbose
>= 3 ) {
348 cerr
<< "Resulting regiongraph: " << *this << endl
;
353 string
TreeEP::identify() const {
354 return string(Name
) + printProperties();
// Initialises the object before a run; the visible part is a loop over all
// factor indices I in [0, nrFactors()).
// NOTE(review): the embedded listing numbers skip 359-360 and 363-368, so
// anything done before the loop, the loop body (presumably a call to init()
// on the subtree approximation of each off-tree factor) and the closing
// brace are not visible here -- confirm against the complete file.
358 void TreeEP::init() {
361 // Init factor approximations
362 for( size_t I
= 0; I
< nrFactors(); I
++ )
// Body of the iteration driver.
// NOTE(review): the function header is not visible in this text (embedded
// listing numbers 363-368 are absent). The "::run:" prefix of the log
// messages suggests this is TreeEP::run(), and the final statement shows
// that it returns diffs.maxDiff() -- confirm against the complete file.
//
// Visible behaviour:
//   * verbose >= 1 prints "Starting <identify()>..." to cerr;
//   * the belief of every variable is stored in old_beliefs;
//   * passes are repeated while _iters < props.maxiter and
//     diffs.maxDiff() > props.tol. In a pass, for factors I, _Q[I] performs
//     InvertAndMultiply, HUGIN_with_I and InvertAndMultiply again, after
//     which the new belief of every variable is compared with the old one
//     using the DISTLINF distance and the result is pushed into diffs;
//   * _maxdiff is raised to diffs.maxDiff() when that value is larger;
//   * verbose >= 1 prints either a non-convergence warning or a
//     "converged in ... passes" message, both with the elapsed time
//     toc() - tic.
//
// NOTE(review): further gaps in the listing numbers (372-374, 376, 381, 386,
// 390-391, 396-398, 401-402, 405, 409, 411, 415-417, 419-421) hide the
// statement controlled by `if( props.verbose >= 3)` on listing line 371, the
// declaration of `tic`, the condition selecting which factors I are updated,
// whether old_beliefs is refreshed after each comparison, the statement
// controlled by `if( props.verbose == 1 )`, the `else` between the two
// messages and all closing braces.
369 if( props
.verbose
>= 1 )
370 cerr
<< "Starting " << identify() << "...";
371 if( props
.verbose
>= 3)
375 Diffs
diffs(nrVars(), 1.0);
377 vector
<Factor
> old_beliefs
;
378 old_beliefs
.reserve( nrVars() );
379 for( size_t i
= 0; i
< nrVars(); i
++ )
380 old_beliefs
.push_back(belief(var(i
)));
382 // do several passes over the network until maximum number of iterations has
383 // been reached or until the maximum belief difference is smaller than tolerance
384 for( _iters
=0; _iters
< props
.maxiter
&& diffs
.maxDiff() > props
.tol
; _iters
++ ) {
385 for( size_t I
= 0; I
< nrFactors(); I
++ )
387 _Q
[I
].InvertAndMultiply( Qa
, Qb
);
388 _Q
[I
].HUGIN_with_I( Qa
, Qb
);
389 _Q
[I
].InvertAndMultiply( Qa
, Qb
);
392 // calculate new beliefs and compare with old ones
393 for( size_t i
= 0; i
< nrVars(); i
++ ) {
394 Factor
nb( belief(var(i
)) );
395 diffs
.push( dist( nb
, old_beliefs
[i
], Prob::DISTLINF
) );
// _iters+1 because the counter has not yet been incremented for this pass.
399 if( props
.verbose
>= 3 )
400 cerr
<< Name
<< "::run: maxdiff " << diffs
.maxDiff() << " after " << _iters
+1 << " passes" << endl
;
403 if( diffs
.maxDiff() > _maxdiff
)
404 _maxdiff
= diffs
.maxDiff();
406 if( props
.verbose
>= 1 ) {
407 if( diffs
.maxDiff() > props
.tol
) {
408 if( props
.verbose
== 1 )
410 cerr
<< Name
<< "::run: WARNING: not converged within " << props
.maxiter
<< " passes (" << toc() - tic
<< " seconds)...final maxdiff:" << diffs
.maxDiff() << endl
;
412 if( props
.verbose
>= 3 )
413 cerr
<< Name
<< "::run: ";
414 cerr
<< "converged in " << _iters
<< " passes (" << toc() - tic
<< " seconds)." << endl
;
418 return diffs
.maxDiff();
// Accumulates into `s` the following terms:
//   - the entropy of every inner-region factor Qb[beta]
//   + the entropy of every outer-region factor Qa[alpha]
//   + ( OR(alpha).log(true) * Qa[alpha] ).sum() for every outer region
//   + _Q.find(I)->second.logZ( Qa, Qb ) for factors I
// NOTE(review): the embedded listing numbers skip 423-424, 430, 434, 437 and
// 439-443, so the declaration and initial value of `s`, the condition that
// selects which factors I contribute in the last loop (the comment says
// off-tree factors), the return statement and the closing brace are not
// visible here -- confirm against the complete file.
422 Real
TreeEP::logZ() const {
425 // entropy of the tree
426 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
427 s
-= Qb
[beta
].entropy();
428 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
429 s
+= Qa
[alpha
].entropy();
431 // energy of the on-tree factors
432 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
433 s
+= (OR(alpha
).log(true) * Qa
[alpha
]).sum();
435 // energy of the off-tree factors
// find(I) is dereferenced without an end() check, so I must be a key of _Q.
436 for( size_t I
= 0; I
< nrFactors(); I
++ )
438 s
+= (_Q
.find(I
))->second
.logZ( Qa
, Qb
);
444 } // end of namespace dai