/*  Copyright (C) 2006-2008  Joris Mooij  [j dot mooij at science dot ru dot nl]
    Radboud University Nijmegen, The Netherlands

    This file is part of libDAI.

    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*/
25 #include <dai/jtree.h>
26 #include <dai/treeep.h>
28 #include <dai/diffs.h>
37 const char *TreeEP::Name
= "TREEEP";
40 bool TreeEP::checkProperties() {
41 if( !HasProperty("type") )
43 if( !HasProperty("tol") )
45 if (!HasProperty("maxiter") )
47 if (!HasProperty("verbose") )
50 ConvertPropertyTo
<TypeType
>("type");
51 ConvertPropertyTo
<double>("tol");
52 ConvertPropertyTo
<size_t>("maxiter");
53 ConvertPropertyTo
<size_t>("verbose");
// Builds the local representation of one subtree of the junction tree.
// For every edge of subRTree it searches jt_RTree for the same undirected edge (index beta),
// gives each endpoint a local index through the conversion table _a (appending a new factor
// to _Qa the first time an endpoint is seen), and appends the local edge to _RTree together
// with a matching factor in _Qb. New factors are constructed from the variable set of the
// corresponding jt_Qa / jt_Qb entry and the value 1.0.
//   subRTree  edges of the subtree, expressed in the junction tree's (old) node indices
//   jt_RTree  edges of the full junction tree
//   jt_Qa     factors of the full tree, indexed by old node index
//   jt_Qb     factors of the full tree, indexed by old separator index
//   I         factor pointer stored in _I
// NOTE(review): the embedded listing numbers skip 60-61, 72, 74, 79-80, 85-86, 89-91 and 94,
// so statements and closing braces are missing from this extract; see the notes at each gap.
59 TreeEPSubTree::TreeEPSubTree( const DEdgeVec
&subRTree
, const DEdgeVec
&jt_RTree
, const vector
<Factor
> &jt_Qa
, const vector
<Factor
> &jt_Qb
, const Factor
*I
) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I
), _ns(), _nsrem(), _logZ(0.0) {
// NOTE(review): listing lines 60-61 are absent. Nothing visible assigns _ns, although it is
// read when _nsrem is computed at the end of this constructor; confirm against upstream.
62 // Make _Qa, _Qb, _a and _b corresponding to the subtree
63 _b
.reserve( subRTree
.size() );
64 _Qb
.reserve( subRTree
.size() );
65 _RTree
.reserve( subRTree
.size() );
66 for( size_t i
= 0; i
< subRTree
.size(); i
++ ) {
67 size_t alpha1
= subRTree
[i
].n1
; // old index 1
68 size_t alpha2
= subRTree
[i
].n2
; // old index 2
69 size_t beta
; // old sep index
// Linear search for the junction-tree edge that joins the same two nodes; UEdge is used
// for the comparison on both sides.
70 for( beta
= 0; beta
< jt_RTree
.size(); beta
++ )
71 if( UEdge( jt_RTree
[beta
].n1
, jt_RTree
[beta
].n2
) == UEdge( alpha1
, alpha2
) )
// NOTE(review): listing line 72 is absent, so the `if` above has no visible body
// (presumably a `break;` that leaves beta at the matching index); confirm upstream.
// beta == jt_RTree.size() means the search ran off the end without finding the edge.
73 assert( beta
!= jt_RTree
.size() );
// Local index of alpha1: its position in _a, or _a.size() when it has not been seen yet.
75 size_t newalpha1
= find(_a
.begin(), _a
.end(), alpha1
) - _a
.begin();
76 if( newalpha1
== _a
.size() ) {
77 _Qa
.push_back( Factor( jt_Qa
[alpha1
].vars(), 1.0 ) );
78 _a
.push_back( alpha1
); // save old index in index conversion table
// NOTE(review): listing lines 79-80 are absent (presumably the closing brace of this `if`).
// Same lookup-or-append step for the second endpoint.
81 size_t newalpha2
= find(_a
.begin(), _a
.end(), alpha2
) - _a
.begin();
82 if( newalpha2
== _a
.size() ) {
83 _Qa
.push_back( Factor( jt_Qa
[alpha2
].vars(), 1.0 ) );
84 _a
.push_back( alpha2
); // save old index in index conversion table
// NOTE(review): listing lines 85-86 are absent (presumably the closing brace of this `if`).
// The edge is stored in local indices; its separator factor is appended at the same position.
87 _RTree
.push_back( DEdge( newalpha1
, newalpha2
) );
88 _Qb
.push_back( Factor( jt_Qb
[beta
].vars(), 1.0 ) );
// NOTE(review): listing lines 89-91 are absent. No visible statement fills _b (it is only
// reserved above, yet other methods index it), and the loop's closing brace is missing.
92 // Find remaining variables (which are not in the new root)
93 _nsrem
= _ns
/ _Qa
[0].vars();
// NOTE(review): listing line 94 (presumably the constructor's closing brace) is absent.
97 void TreeEPSubTree::init() {
98 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
99 _Qa
[alpha
].fill( 1.0 );
100 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
101 _Qb
[beta
].fill( 1.0 );
105 void TreeEPSubTree::InvertAndMultiply( const vector
<Factor
> &Qa
, const vector
<Factor
> &Qb
) {
106 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
107 _Qa
[alpha
] = Qa
[_a
[alpha
]].divided_by( _Qa
[alpha
] );
109 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
110 _Qb
[beta
] = Qb
[_b
[beta
]].divided_by( _Qb
[beta
] );
// Runs a collect pass and a distribute pass over the subtree for each state index j of
// _nsrem, with the root _Qa[0] multiplied by the matching slice of *_I, and accumulates the
// resulting local factors into the caller's Qa and Qb (entries selected through _a and _b).
// Afterwards the touched entries of Qa and Qb are normalized with Prob::NORMPROB, and the
// logs of their totals are folded into _logZ (added for Qa, subtracted for Qb).
//   Qa, Qb  in/out: the touched entries are first filled with 0.0, then accumulated into
// NOTE(review): the embedded listing numbers skip 116, 120-121, 130, 133-134, 137, 141, 143,
// 146-148, 153-155, 161, 163-166, 168, 172 and 176-178, so statements and closing braces
// are missing from this extract; see the notes at each gap.
114 void TreeEPSubTree::HUGIN_with_I( vector
<Factor
> &Qa
, vector
<Factor
> &Qb
) {
// Multi-index over the variables in _nsrem; it supplies mi.max() and mi.vi(j) below.
115 multind
mi( _nsrem
);
117 // Backup _Qa and _Qb
118 vector
<Factor
> _Qa_old(_Qa
);
119 vector
<Factor
> _Qb_old(_Qb
);
// Zero the global entries this subtree touches so they can be used as accumulators.
122 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
123 Qa
[_a
[alpha
]].fill( 0.0 );
124 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
125 Qb
[_b
[beta
]].fill( 0.0 );
127 // For all states of _nsrem
128 for( size_t j
= 0; j
< mi
.max(); j
++ ) {
// NOTE(review): vi is not used by any visible line; presumably the absent listing line 141
// uses it (indexed by k) to pick the clamped state. Confirm upstream.
129 vector
<size_t> vi
= mi
.vi( j
);
131 // Multiply root with slice of I
132 _Qa
[0] *= _I
->slice( _nsrem
, j
);
// Collect pass: edges are visited from the last entry of _RTree to the first, and the
// n2 endpoint sends to the n1 endpoint.
135 for( size_t i
= _RTree
.size(); (i
--) != 0; ) {
136 // clamp variables in nsrem
// NOTE(review): listing line 137 is absent; k is incremented below but has no visible
// declaration (presumably `size_t k = 0;`). Confirm upstream.
138 for( VarSet::const_iterator n
= _nsrem
.begin(); n
!= _nsrem
.end(); n
++, k
++ )
139 if( _Qa
[_RTree
[i
].n2
].vars() >> *n
) {
140 Factor
delta( *n
, 0.0 );
// NOTE(review): listing line 141 is absent. As shown, delta is built with the value 0.0 and
// multiplied in unchanged; presumably one entry is set to 1.0 first. Confirm upstream.
142 _Qa
[_RTree
[i
].n2
] *= delta
;
// NOTE(review): listing line 143 is absent (presumably the closing brace of the `if`).
// New separator factor = the n2 factor reduced by part_sum to the separator's variables;
// the n1 factor is then multiplied by new / old separator.
144 Factor new_Qb
= _Qa
[_RTree
[i
].n2
].part_sum( _Qb
[i
].vars() );
145 _Qa
[_RTree
[i
].n1
] *= new_Qb
.divided_by( _Qb
[i
] );
// NOTE(review): listing lines 146-148 are absent. No visible statement stores new_Qb back
// into _Qb[i], and the loop's closing brace is missing. Confirm upstream.
149 // DistributeEvidence
// Distribute pass: edges are visited in forward order, and n1 sends to n2.
150 for( size_t i
= 0; i
< _RTree
.size(); i
++ ) {
151 Factor new_Qb
= _Qa
[_RTree
[i
].n1
].part_sum( _Qb
[i
].vars() );
152 _Qa
[_RTree
[i
].n2
] *= new_Qb
.divided_by( _Qb
[i
] );
// NOTE(review): listing lines 153-155 are absent (same remark as for lines 146-148).
156 // Store Qa's and Qb's
157 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
158 Qa
[_a
[alpha
]].p() += _Qa
[alpha
].p();
159 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
160 Qb
[_b
[beta
]].p() += _Qb
[beta
].p();
162 // Restore _Qa and _Qb
// NOTE(review): listing lines 163-166 are absent. No restore statements are visible
// (presumably assignments from _Qa_old / _Qb_old) and the closing brace of the state loop
// is missing. Confirm upstream.
167 // Normalize Qa and Qb
// NOTE(review): listing line 168 is absent. _logZ is only accumulated below; whether it is
// reset first cannot be seen here.
169 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ ) {
170 _logZ
+= log(Qa
[_a
[alpha
]].totalSum());
171 Qa
[_a
[alpha
]].normalize( Prob::NORMPROB
);
// NOTE(review): listing line 172 is absent (presumably the closing brace of this loop).
173 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ ) {
174 _logZ
-= log(Qb
[_b
[beta
]].totalSum());
175 Qb
[_b
[beta
]].normalize( Prob::NORMPROB
);
// NOTE(review): listing lines 176-178 are absent (presumably the closing braces of the loop
// and of the function).
180 double TreeEPSubTree::logZ( const vector
<Factor
> &Qa
, const vector
<Factor
> &Qb
) const {
182 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
183 sum
+= (Qa
[_a
[alpha
]] * _Qa
[alpha
].log0()).totalSum();
184 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
185 sum
-= (Qb
[_b
[beta
]] * _Qb
[beta
].log0()).totalSum();
// Constructs a TreeEP object for the factor graph fg.
// The JTree base is built from fg, opts("updates",string("HUGIN")) and `false`; the
// properties are then checked and fg.G is asserted to be connected.
// If opts has the key "tree", that DEdgeVec is handed to ConstructRG. The two visible type
// branches each fill a WeightedGraph<double> over pairs (i, findVar(*j)) with *j in delta(i)
// and pass MaxSpanningTreePrim(wg) to ConstructRG:
//   TypeType::ORG  weight = real( KL_dist( piet, pietf ) ), where pietf is the product of the
//                  two single-variable marginals of piet; a weight of 0 is also assigned
//   TypeType::ALT  weight = piet.strength( v_i, *j )
//   fg    factor graph to run on
//   opts  property set; the keys "updates" and "tree" are read here
// NOTE(review): the embedded listing numbers skip 192, 194, 197, 203, 206-207, 211, 214,
// 219, 221-223, 226, 238, 241-242, 246-247, 249-251 and 254-258. The declarations of v_i
// and piet, several branch bodies and all closing braces are missing from this extract.
190 TreeEP::TreeEP( const FactorGraph
&fg
, const Properties
&opts
) : JTree(fg
, opts("updates",string("HUGIN")), false) {
191 assert( checkProperties() );
193 assert( fg
.G
.isConnected() );
195 if( opts
.hasKey("tree") ) {
196 ConstructRG( opts
.GetAs
<DEdgeVec
>("tree") );
// NOTE(review): listing line 197 is absent (presumably `} else {`); as shown, the type
// dispatch below sits inside the "tree" branch. Confirm upstream.
198 if( Type() == TypeType::ORG
) {
199 // construct weighted graph with as weights a crude estimate of the
200 // mutual information between the nodes
201 WeightedGraph
<double> wg
;
202 for( size_t i
= 0; i
< nrVars(); ++i
) {
// NOTE(review): listing line 203 is absent; v_i is used below without a visible declaration
// (presumably the variable with index i). Confirm upstream.
204 VarSet di
= delta(i
);
205 for( VarSet::const_iterator j
= di
.begin(); j
!= di
.end(); j
++ )
// NOTE(review): listing lines 206-207 are absent; piet is used below without a visible
// declaration, and any guard on the pair (v_i, *j) cannot be seen.
208 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
209 VarSet Ivars
= factor(I
).vars();
// First case: the factor's variable set equals v_i alone or *j alone.
210 if( (Ivars
== v_i
) || (Ivars
== *j
) )
// NOTE(review): listing line 211 is absent, so the `if` above has no visible body.
// Second case: the factor's variable set is tested against the pair with `>>`; its
// marginal on the pair is multiplied into piet.
212 else if( Ivars
>> (v_i
| *j
) )
213 piet
*= factor(I
).marginal( v_i
| *j
);
// NOTE(review): listing line 214 is absent (presumably the closing brace of the factor loop).
215 if( piet
.vars() >> (v_i
| *j
) ) {
216 piet
= piet
.marginal( v_i
| *j
);
217 Factor pietf
= piet
.marginal(v_i
) * piet
.marginal(*j
);
218 wg
[UEdge(i
,findVar(*j
))] = real( KL_dist( piet
, pietf
) );
// NOTE(review): listing line 219 is absent (presumably `} else`); as shown, the weight
// computed above would be overwritten by 0 immediately. Confirm upstream.
220 wg
[UEdge(i
,findVar(*j
))] = 0;
// NOTE(review): listing lines 221-223 are absent (presumably closing braces).
224 // find maximal spanning tree
225 ConstructRG( MaxSpanningTreePrim( wg
) );
227 // cout << "Constructing maximum spanning tree..." << endl;
228 // DEdgeVec MST = MaxSpanningTreePrim( wg );
229 // cout << "Maximum spanning tree:" << endl;
230 // for( DEdgeVec::const_iterator e = MST.begin(); e != MST.end(); e++ )
231 // cout << *e << endl;
232 // ConstructRG( MST );
233 } else if( Type() == TypeType::ALT
) {
234 // construct weighted graph with as weights an upper bound on the
235 // effective interaction strength between pairs of nodes
236 WeightedGraph
<double> wg
;
237 for( size_t i
= 0; i
< nrVars(); ++i
) {
// NOTE(review): listing line 238 is absent (same remark about v_i as in the ORG branch).
239 VarSet di
= delta(i
);
240 for( VarSet::const_iterator j
= di
.begin(); j
!= di
.end(); j
++ )
// NOTE(review): listing lines 241-242 are absent (same remark about piet as in the ORG branch).
243 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
244 VarSet Ivars
= factor(I
).vars();
245 if( Ivars
>> (v_i
| *j
) )
// NOTE(review): listing lines 246-247 are absent, so the `if` above has no visible body
// (presumably it multiplies factor(I) into piet). Confirm upstream.
248 wg
[UEdge(i
,findVar(*j
))] = piet
.strength(v_i
, *j
);
// NOTE(review): listing lines 249-251 are absent (presumably closing braces).
252 // find maximal spanning tree
253 ConstructRG( MaxSpanningTreePrim( wg
) );
// NOTE(review): listing lines 254-258 are absent, so the remainder of the constructor,
// including any further type branches and the closing braces, cannot be seen here.
// Builds the region graph and the per-factor subtree approximations from a tree over the
// variables. Steps visible in this extract:
//   1. one clique per tree edge: var(n1) | var(n2);
//   2. a WeightedGraph<int> over the cliques, each pair weighted by the size of its
//      intersection, whose MaxSpanningTreePrim result becomes _RTree;
//   3. one outer region per clique, and one inner region per _RTree edge holding the
//      intersection of the two cliques, linked through the bipartite graph G;
//   4. _Qa copied from the outer regions and _Qb built from the inner regions with value 1.0;
//   5. a TreeEPSubTree constructed for factors, over the subtree that findEfficientTree
//      returns for the factor's variables.
//   tree  edges whose endpoints n1 / n2 are passed to var()
// NOTE(review): the embedded listing numbers skip many lines (265, 268, 274-275, 278, 280,
// 285, 290, 293, 296-298, 300-302, 306, 313-314, 317, 320, 322, 326-327, 331, 333, 336,
// 339, 341, 346-348, 363-366, 368-369, 372-373, 378, 380-383, 386-388). Declarations,
// branch bodies and closing braces are missing; see the notes at the main gaps.
261 void TreeEP::ConstructRG( const DEdgeVec
&tree
) {
262 vector
<VarSet
> Cliques
;
263 for( size_t i
= 0; i
< tree
.size(); i
++ )
264 Cliques
.push_back( var(tree
[i
].n1
) | var(tree
[i
].n2
) );
266 // Construct a weighted graph (each edge is weighted with the cardinality
267 // of the intersection of the nodes, where the nodes are the elements of
// NOTE(review): listing line 268 (the end of the sentence above) is absent.
269 WeightedGraph
<int> JuncGraph
;
270 for( size_t i
= 0; i
< Cliques
.size(); i
++ )
271 for( size_t j
= i
+1; j
< Cliques
.size(); j
++ ) {
272 size_t w
= (Cliques
[i
] & Cliques
[j
]).size();
273 JuncGraph
[UEdge(i
,j
)] = w
;
// NOTE(review): listing lines 274-275 are absent (presumably the closing brace of the loop).
276 // Construct maximal spanning tree using Prim's algorithm
277 _RTree
= MaxSpanningTreePrim( JuncGraph
);
279 // Construct corresponding region graph
// NOTE(review): listing line 280 is absent; whether existing regions are cleared before
// new ones are appended cannot be seen here.
281 // Create outer regions
282 ORs
.reserve( Cliques
.size() );
283 for( size_t i
= 0; i
< Cliques
.size(); i
++ )
284 ORs
.push_back( FRegion( Factor(Cliques
[i
], 1.0), 1.0 ) );
286 // For each factor, find an outer region that subsumes that factor.
287 // Then, multiply the outer region with that factor.
288 // If no outer region can be found subsuming that factor, label the
289 // factor as off-tree.
// Every entry of fac2OR starts at -1U.
291 fac2OR
.resize( nrFactors(), -1U );
292 for( size_t I
= 0; I
< nrFactors(); I
++ ) {
// NOTE(review): listing line 293 is absent; alpha is used below without a visible declaration.
294 for( alpha
= 0; alpha
< nrORs(); alpha
++ )
295 if( OR(alpha
).vars() >> factor(I
).vars() ) {
// NOTE(review): listing lines 296-298 are absent, so the body of the `if` above (what
// happens once a subsuming outer region is found) cannot be seen here.
299 // DIFF WITH JTree::GenerateJT: assert
// NOTE(review): listing lines 300-302 are absent (presumably closing braces).
303 // Create inner regions and edges
304 IRs
.reserve( _RTree
.size() );
305 typedef pair
<size_t,size_t> Edge
;
// NOTE(review): listing line 306 is absent; `edges` is used below without a visible
// declaration (presumably a container of Edge).
307 edges
.reserve( 2 * _RTree
.size() );
308 for( size_t i
= 0; i
< _RTree
.size(); i
++ ) {
// IRs.size() is read before the push_back below, so it is the index the new inner region
// is about to receive; both endpoints of the tree edge are linked to it.
309 edges
.push_back( Edge( _RTree
[i
].n1
, IRs
.size() ) );
310 edges
.push_back( Edge( _RTree
[i
].n2
, IRs
.size() ) );
311 // inner clusters have counting number -1
312 IRs
.push_back( Region( Cliques
[_RTree
[i
].n1
] & Cliques
[_RTree
[i
].n2
], -1.0 ) );
// NOTE(review): listing lines 313-314 are absent (presumably the closing brace of the loop).
315 // create bipartite graph
316 G
.create( nrORs(), nrIRs(), edges
.begin(), edges
.end() );
318 // Check counting numbers
319 Check_Counting_Numbers();
321 // Create messages and beliefs
// NOTE(review): listing lines 322 and 326-327 are absent; whether _Qa and _Qb are cleared
// before being refilled cannot be seen here.
323 _Qa
.reserve( nrORs() );
324 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
325 _Qa
.push_back( OR(alpha
) );
328 _Qb
.reserve( nrIRs() );
329 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
330 _Qb
.push_back( Factor( IR(beta
), 1.0 ) );
332 // DIFF with JTree::GenerateJT: no messages
334 // DIFF with JTree::GenerateJT:
335 // Create factor approximations
// NOTE(review): listing line 336 is absent.
// (size_t)-1 is the initial PreviousRoot, used before any subtree has been built.
337 size_t PreviousRoot
= (size_t)-1;
338 for( size_t I
= 0; I
< nrFactors(); I
++ )
// NOTE(review): listing lines 339 and 341 are absent; subTree is used below without a
// visible declaration, and any condition on I (presumably restricting the loop to
// off-tree factors) cannot be seen. Confirm upstream.
340 // find efficient subtree
342 /*size_t subTreeSize =*/ findEfficientTree( factor(I
).vars(), subTree
, PreviousRoot
);
// The n1 endpoint of the first subtree edge is carried over as PreviousRoot.
343 PreviousRoot
= subTree
[0].n1
;
344 //subTree.resize( subTreeSize ); // FIXME
345 // cout << "subtree " << I << " has size " << subTreeSize << endl;
// NOTE(review): listing lines 346-348 and 363-366 are absent and `fn` has no visible
// declaration. Unlike the rest of the file, the lines below are not broken into fragments,
// so presumably they sit inside a commented-out block whose delimiters are missing here.
// They write a Graphviz description of the factor graph and of `tree` to a file under /tmp.
349 sprintf( fn, "/tmp/subtree_%d.dot", I );
350 std::ofstream dots(fn);
351 dots << "graph G {" << endl;
352 dots << "graph[size=\"9,9\"];" << endl;
353 dots << "node[shape=circle,width=0.4,fixedsize=true];" << endl;
354 for( size_t i = 0; i < nrVars(); i++ )
355 dots << "\tx" << var(i).label() << ((factor(I).vars() >> var(i)) ? "[color=blue];" : ";") << endl;
356 dots << "node[shape=box,style=filled,color=lightgrey,width=0.3,height=0.3,fixedsize=true];" << endl;
357 for( size_t J = 0; J < nrFactors(); J++ )
358 dots << "\tp" << J << ";" << endl;
359 for( size_t iI = 0; iI < FactorGraph::nr_edges(); iI++ )
360 dots << "\tx" << var(FactorGraph::edge(iI).first).label() << " -- p" << FactorGraph::edge(iI).second << ";" << endl;
361 for( size_t a = 0; a < tree.size(); a++ )
362 dots << "\tx" << var(tree[a].n1).label() << " -- x" << var(tree[a].n2).label() << " [color=red];" << endl;
367 TreeEPSubTree
QI( subTree
, _RTree
, _Qa
, _Qb
, &factor(I
) );
// NOTE(review): listing lines 368-369 are absent; no visible statement stores QI
// (presumably into _Q, which run() and logZ() read). Confirm upstream.
370 // Previous root of first off-tree factor should be the root of the last off-tree factor
// Second pass over the factors: PreviousRoot now holds the value left by the first pass.
371 for( size_t I
= 0; I
< nrFactors(); I
++ )
// NOTE(review): listing lines 372-373 are absent (same remarks as for lines 339 and 341).
374 /*size_t subTreeSize =*/ findEfficientTree( factor(I
).vars(), subTree
, PreviousRoot
);
375 PreviousRoot
= subTree
[0].n1
;
376 //subTree.resize( subTreeSize ); // FIXME
377 // cout << "subtree " << I << " has size " << subTreeSize << endl;
379 TreeEPSubTree
QI( subTree
, _RTree
, _Qa
, _Qb
, &factor(I
) );
// NOTE(review): listing lines 380-383 are absent (same remark as for lines 368-369).
384 if( Verbose() >= 3 ) {
385 cout
<< "Resulting regiongraph: " << *this << endl
;
// NOTE(review): listing lines 386-388 are absent (presumably closing braces).
390 string
TreeEP::identify() const {
391 stringstream
result (stringstream::out
);
392 result
<< Name
<< GetProperties();
// Re-validates the properties, then visits every factor index to initialise the factor
// approximations.
// NOTE(review): the embedded listing numbers skip 399-401 and 404-406. Whatever runs
// between the assert and the loop, the loop body itself and the closing brace are missing
// from this extract; confirm against upstream.
397 void TreeEP::init() {
398 assert( checkProperties() );
402 // Init factor approximations
403 for( size_t I
= 0; I
< nrFactors(); I
++ )
// NOTE(review): no loop body is visible here (presumably a call that initialises the
// approximation stored for factor I).
// Iterates the TreeEP updates until MaxIter() passes have been made or diffs.max() no longer
// exceeds Tol(). In each pass, for factor index I, _Q[I].InvertAndMultiply, _Q[I].HUGIN_with_I
// and _Q[I].InvertAndMultiply again are applied to (_Qa, _Qb); then the Prob::DISTLINF distance
// between each new single-variable belief and its stored old value is pushed into `diffs`.
// Progress and convergence messages are written to cout.
// NOTE(review): the embedded listing numbers skip 410, 412-415, 417, 422-424, 429, 433-434,
// 439-442, 444-445, 447, 450-451, 453-454 and 457-462. The declarations of `iter` and `tic`,
// the verbosity guards, the return statement and all closing braces are missing from this
// extract; see the notes at the main gaps.
409 double TreeEP::run() {
// NOTE(review): listing line 410 is absent (presumably a Verbose() guard for this message).
411 cout
<< "Starting " << identify() << "...";
// NOTE(review): listing lines 412-415 are absent; `tic` is read in the messages at the end
// but has no visible declaration.
416 Diffs
diffs(nrVars(), 1.0);
// Snapshot of the single-variable beliefs before the first pass.
418 vector
<Factor
> old_beliefs
;
419 old_beliefs
.reserve( nrVars() );
420 for( size_t i
= 0; i
< nrVars(); i
++ )
421 old_beliefs
.push_back(belief(var(i
)));
// NOTE(review): listing lines 422-424 are absent; `iter` is used below without a visible
// declaration.
425 // do several passes over the network until maximum number of iterations has
426 // been reached or until the maximum belief difference is smaller than tolerance
427 for( iter
=0; iter
< MaxIter() && diffs
.max() > Tol(); iter
++ ) {
428 for( size_t I
= 0; I
< nrFactors(); I
++ )
// NOTE(review): listing line 429 is absent; any condition on I (presumably restricting the
// update to off-tree factors) and the opening brace of the loop body cannot be seen.
430 _Q
[I
].InvertAndMultiply( _Qa
, _Qb
);
431 _Q
[I
].HUGIN_with_I( _Qa
, _Qb
);
432 _Q
[I
].InvertAndMultiply( _Qa
, _Qb
);
// NOTE(review): listing lines 433-434 are absent (presumably closing braces).
435 // calculate new beliefs and compare with old ones
436 for( size_t i
= 0; i
< nrVars(); i
++ ) {
437 Factor
nb( belief(var(i
)) );
438 diffs
.push( dist( nb
, old_beliefs
[i
], Prob::DISTLINF
) );
// NOTE(review): listing lines 439-442 are absent; no visible statement refreshes
// old_beliefs[i] with nb, and any verbosity guard for the message below cannot be seen.
443 cout
<< "TreeEP::run: maxdiff " << diffs
.max() << " after " << iter
+1 << " passes" << endl
;
// NOTE(review): listing lines 444-445 are absent (presumably the closing brace of the
// pass loop).
446 updateMaxDiff( diffs
.max() );
448 if( Verbose() >= 1 ) {
// Not converged: the last maximum difference is still above the tolerance.
449 if( diffs
.max() > Tol() ) {
// NOTE(review): listing lines 450-451 are absent.
452 cout
<< "TreeEP::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic
<< " clocks)...final maxdiff:" << diffs
.max() << endl
;
// NOTE(review): listing lines 453-454 are absent (presumably `} else {`).
455 cout
<< "TreeEP::run: ";
456 cout
<< "converged in " << iter
<< " passes (" << toc() - tic
<< " clocks)." << endl
;
// NOTE(review): listing lines 457-462 are absent; the value returned by this double-valued
// function and the closing braces cannot be seen here.
// Accumulates into `sum` three groups of terms:
//   - the entropies of the _Qa factors (added) and of the _Qb factors (subtracted),
//   - for every outer region, the total of OR(alpha).log0() * _Qa[alpha],
//   - the logZ( _Qa, _Qb ) value of the TreeEPSubTree that _Q.find(I) points at.
// NOTE(review): the embedded listing numbers skip 465-466, 472, 476, 479 and 481-483. The
// declaration of `sum`, the return statement and the closing brace are missing from this
// extract; confirm against upstream.
464 Complex
TreeEP::logZ() const {
// NOTE(review): `sum` is used below without a visible declaration or initial value.
467 // entropy of the tree
468 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
469 sum
-= real(_Qb
[beta
].entropy());
470 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
471 sum
+= real(_Qa
[alpha
].entropy());
473 // energy of the on-tree factors
474 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
475 sum
+= (OR(alpha
).log0() * _Qa
[alpha
]).totalSum();
477 // energy of the off-tree factors
478 for( size_t I
= 0; I
< nrFactors(); I
++ )
// NOTE(review): listing line 479 is absent. As shown, the result of _Q.find(I) is
// dereferenced for every factor index without an end() check; presumably a guard that
// skips factors without an entry in _Q stood here. Confirm upstream.
480 sum
+= (_Q
.find(I
))->second
.logZ( _Qa
, _Qb
);
// NOTE(review): listing lines 481-483 are absent (presumably `return sum;` and the closing
// brace).
486 } // end of namespace dai