2a3a47ede619b6f9f55d9059a00de199a443b443
1 /* This file is part of libDAI - http://www.libdai.org/
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
12 #include <dai/jtree.h>
13 #include <dai/treeep.h>
23 void TreeEP::setProperties( const PropertySet
&opts
) {
24 DAI_ASSERT( opts
.hasKey("tol") );
25 DAI_ASSERT( opts
.hasKey("type") );
27 props
.tol
= opts
.getStringAs
<Real
>("tol");
28 props
.type
= opts
.getStringAs
<Properties::TypeType
>("type");
29 if( opts
.hasKey("maxiter") )
30 props
.maxiter
= opts
.getStringAs
<size_t>("maxiter");
32 props
.maxiter
= 10000;
33 if( opts
.hasKey("maxtime") )
34 props
.maxtime
= opts
.getStringAs
<Real
>("maxtime");
36 props
.maxtime
= INFINITY
;
37 if( opts
.hasKey("verbose") )
38 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
44 PropertySet
TreeEP::getProperties() const {
46 opts
.set( "tol", props
.tol
);
47 opts
.set( "maxiter", props
.maxiter
);
48 opts
.set( "maxtime", props
.maxtime
);
49 opts
.set( "verbose", props
.verbose
);
50 opts
.set( "type", props
.type
);
55 string
TreeEP::printProperties() const {
56 stringstream
s( stringstream::out
);
58 s
<< "tol=" << props
.tol
<< ",";
59 s
<< "maxiter=" << props
.maxiter
<< ",";
60 s
<< "maxtime=" << props
.maxtime
<< ",";
61 s
<< "verbose=" << props
.verbose
<< ",";
62 s
<< "type=" << props
.type
<< "]";
67 TreeEP::TreeEP( const FactorGraph
&fg
, const PropertySet
&opts
) : JTree(fg
, opts("updates",string("HUGIN")), false), _maxdiff(0.0), _iters(0), props(), _Q() {
68 setProperties( opts
);
70 if( opts
.hasKey("tree") ) {
71 construct( fg
, opts
.getAs
<RootedTree
>("tree") );
73 if( props
.type
== Properties::TypeType::ORG
|| props
.type
== Properties::TypeType::ALT
) {
74 // ORG: construct weighted graph with as weights a crude estimate of the
75 // mutual information between the nodes
76 // ALT: construct weighted graph with as weights an upper bound on the
77 // effective interaction strength between pairs of nodes
79 WeightedGraph
<Real
> wg
;
80 // in order to get a connected weighted graph, we start
81 // by connecting every variable to the zero'th variable with weight 0
82 for( size_t i
= 1; i
< fg
.nrVars(); i
++ )
84 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) {
85 SmallSet
<size_t> delta_i
= fg
.bipGraph().delta1( i
, false );
86 const Var
& v_i
= fg
.var(i
);
87 bforeach( size_t j
, delta_i
)
89 const Var
& v_j
= fg
.var(j
);
90 VarSet
v_ij( v_i
, v_j
);
91 SmallSet
<size_t> nb_ij
= fg
.bipGraph().nb1Set( i
) | fg
.bipGraph().nb1Set( j
);
93 bforeach( size_t I
, nb_ij
) {
94 const VarSet
& Ivars
= fg
.factor(I
).vars();
95 if( props
.type
== Properties::TypeType::ORG
) {
96 if( (Ivars
== v_i
) || (Ivars
== v_j
) )
98 else if( Ivars
>> v_ij
)
99 piet
*= fg
.factor(I
).marginal( v_ij
);
102 piet
*= fg
.factor(I
);
105 if( props
.type
== Properties::TypeType::ORG
) {
106 if( piet
.vars() >> v_ij
) {
107 piet
= piet
.marginal( v_ij
);
108 Factor pietf
= piet
.marginal(v_i
) * piet
.marginal(v_j
);
109 wg
[UEdge(i
,j
)] = dist( piet
, pietf
, DISTKL
);
111 // this should never happen...
112 DAI_ASSERT( 0 == 1 );
116 wg
[UEdge(i
,j
)] = piet
.strength(v_i
, v_j
);
120 // find maximal spanning tree
121 if( props
.verbose
>= 3 )
122 cerr
<< "WeightedGraph: " << wg
<< endl
;
123 RootedTree t
= MaxSpanningTree( wg
, true );
124 if( props
.verbose
>= 3 )
125 cerr
<< "Spanningtree: " << t
<< endl
;
128 DAI_THROW(UNKNOWN_ENUM_VALUE
);
133 void TreeEP::construct( const FactorGraph
& fg
, const RootedTree
& tree
) {
134 // Copy the factor graph
135 FactorGraph::operator=( fg
);
138 for( size_t i
= 0; i
< tree
.size(); i
++ )
139 cl
.push_back( VarSet( var(tree
[i
].first
), var(tree
[i
].second
) ) );
141 // If no outer region can be found subsuming that factor, label the
142 // factor as off-tree.
143 JTree::construct( *this, cl
, false );
145 if( props
.verbose
>= 1 )
146 cerr
<< "TreeEP::construct: The tree has size " << JTree::RTree
.size() << endl
;
147 if( props
.verbose
>= 3 )
148 cerr
<< " it is " << JTree::RTree
<< " with cliques " << cl
<< endl
;
150 // Create factor approximations
152 size_t PreviousRoot
= (size_t)-1;
153 // Second repetition: previous root of first off-tree factor should be the root of the last off-tree factor
154 for( size_t repeats
= 0; repeats
< 2; repeats
++ )
155 for( size_t I
= 0; I
< nrFactors(); I
++ )
157 // find efficient subtree
159 size_t subTreeSize
= findEfficientTree( factor(I
).vars(), subTree
, PreviousRoot
);
160 PreviousRoot
= subTree
[0].first
;
161 subTree
.resize( subTreeSize
);
162 if( props
.verbose
>= 1 )
163 cerr
<< "Subtree " << I
<< " has size " << subTreeSize
<< endl
;
164 if( props
.verbose
>= 3 )
165 cerr
<< " it is " << subTree
<< endl
;
166 _Q
[I
] = TreeEPSubTree( subTree
, RTree
, Qa
, Qb
, &factor(I
) );
171 if( props
.verbose
>= 3 )
172 cerr
<< "Resulting regiongraph: " << *this << endl
;
176 void TreeEP::init() {
179 // Init factor approximations
180 for( size_t I
= 0; I
< nrFactors(); I
++ )
187 if( props
.verbose
>= 1 )
188 cerr
<< "Starting " << identify() << "...";
189 if( props
.verbose
>= 3 )
194 vector
<Factor
> oldBeliefs
= beliefs();
196 // do several passes over the network until maximum number of iterations has
197 // been reached or until the maximum belief difference is smaller than tolerance
198 Real maxDiff
= INFINITY
;
199 for( _iters
= 0; _iters
< props
.maxiter
&& maxDiff
> props
.tol
&& (toc() - tic
) < props
.maxtime
; _iters
++ ) {
200 for( size_t I
= 0; I
< nrFactors(); I
++ )
202 _Q
[I
].InvertAndMultiply( Qa
, Qb
);
203 _Q
[I
].HUGIN_with_I( Qa
, Qb
);
204 _Q
[I
].InvertAndMultiply( Qa
, Qb
);
207 // calculate new beliefs and compare with old ones
208 vector
<Factor
> newBeliefs
= beliefs();
210 for( size_t t
= 0; t
< oldBeliefs
.size(); t
++ )
211 maxDiff
= std::max( maxDiff
, dist( newBeliefs
[t
], oldBeliefs
[t
], DISTLINF
) );
212 swap( newBeliefs
, oldBeliefs
);
214 if( props
.verbose
>= 3 )
215 cerr
<< name() << "::run: maxdiff " << maxDiff
<< " after " << _iters
+1 << " passes" << endl
;
218 if( maxDiff
> _maxdiff
)
221 if( props
.verbose
>= 1 ) {
222 if( maxDiff
> props
.tol
) {
223 if( props
.verbose
== 1 )
225 cerr
<< name() << "::run: WARNING: not converged after " << _iters
<< " passes (" << toc() - tic
<< " seconds)...final maxdiff:" << maxDiff
<< endl
;
227 if( props
.verbose
>= 3 )
228 cerr
<< name() << "::run: ";
229 cerr
<< "converged in " << _iters
<< " passes (" << toc() - tic
<< " seconds)." << endl
;
237 Real
TreeEP::logZ() const {
240 // entropy of the tree
241 for( size_t beta
= 0; beta
< nrIRs(); beta
++ )
242 s
-= Qb
[beta
].entropy();
243 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
244 s
+= Qa
[alpha
].entropy();
246 // energy of the on-tree factors
247 for( size_t alpha
= 0; alpha
< nrORs(); alpha
++ )
248 s
+= (OR(alpha
).log(true) * Qa
[alpha
]).sum();
250 // energy of the off-tree factors
251 for( size_t I
= 0; I
< nrFactors(); I
++ )
253 s
+= (_Q
.find(I
))->second
.logZ( Qa
, Qb
);
259 TreeEP::TreeEPSubTree::TreeEPSubTree( const RootedTree
&subRTree
, const RootedTree
&jt_RTree
, const std::vector
<Factor
> &jt_Qa
, const std::vector
<Factor
> &jt_Qb
, const Factor
*I
) : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(I
), _ns(), _nsrem(), _logZ(0.0) {
262 // Make _Qa, _Qb, _a and _b corresponding to the subtree
263 _b
.reserve( subRTree
.size() );
264 _Qb
.reserve( subRTree
.size() );
265 _RTree
.reserve( subRTree
.size() );
266 for( size_t i
= 0; i
< subRTree
.size(); i
++ ) {
267 size_t alpha1
= subRTree
[i
].first
; // old index 1
268 size_t alpha2
= subRTree
[i
].second
; // old index 2
269 size_t beta
; // old sep index
270 for( beta
= 0; beta
< jt_RTree
.size(); beta
++ )
271 if( UEdge( jt_RTree
[beta
].first
, jt_RTree
[beta
].second
) == UEdge( alpha1
, alpha2
) )
273 DAI_ASSERT( beta
!= jt_RTree
.size() );
275 size_t newalpha1
= find(_a
.begin(), _a
.end(), alpha1
) - _a
.begin();
276 if( newalpha1
== _a
.size() ) {
277 _Qa
.push_back( Factor( jt_Qa
[alpha1
].vars(), 1.0 ) );
278 _a
.push_back( alpha1
); // save old index in index conversion table
281 size_t newalpha2
= find(_a
.begin(), _a
.end(), alpha2
) - _a
.begin();
282 if( newalpha2
== _a
.size() ) {
283 _Qa
.push_back( Factor( jt_Qa
[alpha2
].vars(), 1.0 ) );
284 _a
.push_back( alpha2
); // save old index in index conversion table
287 _RTree
.push_back( DEdge( newalpha1
, newalpha2
) );
288 _Qb
.push_back( Factor( jt_Qb
[beta
].vars(), 1.0 ) );
289 _b
.push_back( beta
);
292 // Find remaining variables (which are not in the new root)
293 _nsrem
= _ns
/ _Qa
[0].vars();
297 void TreeEP::TreeEPSubTree::init() {
298 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
299 _Qa
[alpha
].fill( 1.0 );
300 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
301 _Qb
[beta
].fill( 1.0 );
305 void TreeEP::TreeEPSubTree::InvertAndMultiply( const std::vector
<Factor
> &Qa
, const std::vector
<Factor
> &Qb
) {
306 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
307 _Qa
[alpha
] = Qa
[_a
[alpha
]] / _Qa
[alpha
];
309 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
310 _Qb
[beta
] = Qb
[_b
[beta
]] / _Qb
[beta
];
314 void TreeEP::TreeEPSubTree::HUGIN_with_I( std::vector
<Factor
> &Qa
, std::vector
<Factor
> &Qb
) {
315 // Backup _Qa and _Qb
316 vector
<Factor
> _Qa_old(_Qa
);
317 vector
<Factor
> _Qb_old(_Qb
);
320 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
321 Qa
[_a
[alpha
]].fill( 0.0 );
322 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
323 Qb
[_b
[beta
]].fill( 0.0 );
325 // For all states of _nsrem
326 for( State
s(_nsrem
); s
.valid(); s
++ ) {
327 // Multiply root with slice of I
328 _Qa
[0] *= _I
->slice( _nsrem
, s
);
331 for( size_t i
= _RTree
.size(); (i
--) != 0; ) {
332 // clamp variables in nsrem
333 for( VarSet::const_iterator n
= _nsrem
.begin(); n
!= _nsrem
.end(); n
++ )
334 if( _Qa
[_RTree
[i
].second
].vars() >> *n
)
335 _Qa
[_RTree
[i
].second
] *= createFactorDelta( *n
, s(*n
) );
336 Factor new_Qb
= _Qa
[_RTree
[i
].second
].marginal( _Qb
[i
].vars(), false );
337 _Qa
[_RTree
[i
].first
] *= new_Qb
/ _Qb
[i
];
341 // DistributeEvidence
342 for( size_t i
= 0; i
< _RTree
.size(); i
++ ) {
343 Factor new_Qb
= _Qa
[_RTree
[i
].first
].marginal( _Qb
[i
].vars(), false );
344 _Qa
[_RTree
[i
].second
] *= new_Qb
/ _Qb
[i
];
348 // Store Qa's and Qb's
349 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
350 Qa
[_a
[alpha
]].p() += _Qa
[alpha
].p();
351 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
352 Qb
[_b
[beta
]].p() += _Qb
[beta
].p();
354 // Restore _Qa and _Qb
359 // Normalize Qa and Qb
361 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ ) {
362 _logZ
+= log(Qa
[_a
[alpha
]].sum());
363 Qa
[_a
[alpha
]].normalize();
365 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ ) {
366 _logZ
-= log(Qb
[_b
[beta
]].sum());
367 Qb
[_b
[beta
]].normalize();
372 Real
TreeEP::TreeEPSubTree::logZ( const std::vector
<Factor
> &Qa
, const std::vector
<Factor
> &Qb
) const {
374 for( size_t alpha
= 0; alpha
< _Qa
.size(); alpha
++ )
375 s
+= (Qa
[_a
[alpha
]] * _Qa
[alpha
].log(true)).sum();
376 for( size_t beta
= 0; beta
< _Qb
.size(); beta
++ )
377 s
-= (Qb
[_b
[beta
]] * _Qb
[beta
].log(true)).sum();
382 } // end of namespace dai