/*  This file is part of libDAI - http://www.libdai.org/
 *
 *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
 *
 *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
 */


#include <dai/dai_config.h>

#include <iostream>
#include <stack>
#include <dai/jtree.h>


namespace dai {


using namespace std;


void JTree::setProperties( const PropertySet &opts ) {
    DAI_ASSERT( opts.hasKey("updates") );

    props.updates = opts.getStringAs<Properties::UpdateType>("updates");
    if( opts.hasKey("verbose") )
        props.verbose = opts.getStringAs<size_t>("verbose");
    else
        props.verbose = 0;
    if( opts.hasKey("inference") )
        props.inference = opts.getStringAs<Properties::InfType>("inference");
    else
        props.inference = Properties::InfType::SUMPROD;
    if( opts.hasKey("heuristic") )
        props.heuristic = opts.getStringAs<Properties::HeuristicType>("heuristic");
    else
        props.heuristic = Properties::HeuristicType::MINFILL;
    if( opts.hasKey("maxmem") )
        props.maxmem = opts.getStringAs<size_t>("maxmem");
    else
        props.maxmem = 0;
}


PropertySet JTree::getProperties() const {
    PropertySet opts;
    opts.set( "verbose", props.verbose );
    opts.set( "updates", props.updates );
    opts.set( "inference", props.inference );
    opts.set( "heuristic", props.heuristic );
    opts.set( "maxmem", props.maxmem );
    return opts;
}


string JTree::printProperties() const {
    stringstream s( stringstream::out );
    s << "[";
    s << "verbose=" << props.verbose << ",";
    s << "updates=" << props.updates << ",";
    s << "heuristic=" << props.heuristic << ",";
    s << "inference=" << props.inference << ",";
    s << "maxmem=" << props.maxmem << "]";
    return s.str();
}


JTree::JTree( const FactorGraph &fg, const PropertySet &opts, bool automatic ) : DAIAlgRG(), _mes(), _logZ(), RTree(), Qa(), Qb(), props() {
    setProperties( opts );

    if( automatic ) {
        // Create ClusterGraph which contains maximal factors as clusters
        ClusterGraph _cg( fg, true );
        if( props.verbose >= 3 )
            cerr << "Initial clusters: " << _cg << endl;

        // Use heuristic to guess optimal elimination sequence
        greedyVariableElimination::eliminationCostFunction ec(NULL);
        switch( (size_t)props.heuristic ) {
            case Properties::HeuristicType::MINNEIGHBORS:
                ec = eliminationCost_MinNeighbors;
                break;
            case Properties::HeuristicType::MINWEIGHT:
                ec = eliminationCost_MinWeight;
                break;
            case Properties::HeuristicType::MINFILL:
                ec = eliminationCost_MinFill;
                break;
            case Properties::HeuristicType::WEIGHTEDMINFILL:
                ec = eliminationCost_WeightedMinFill;
                break;
            default:
                DAI_THROW(UNKNOWN_ENUM_VALUE);
        }
        size_t fudge = 6; // this yields a rough estimate of the memory needed (for some reason not yet clearly understood)
        vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( ec ), props.maxmem / (sizeof(Real) * fudge) ).eraseNonMaximal().clusters();
        if( props.verbose >= 3 )
            cerr << "VarElim result: " << ElimVec << endl;

        // Estimate memory needed (rough upper bound)
        BigInt memneeded = 0;
        bforeach( const VarSet &cl, ElimVec )
            memneeded += cl.nrStates();
        memneeded *= (BigInt)sizeof(Real) * (BigInt)fudge;
        if( props.verbose >= 1 ) {
            cerr << "Estimate of needed memory: " << memneeded / 1024 << "kB" << endl;
            cerr << "Maximum memory: ";
            if( props.maxmem )
                cerr << props.maxmem / 1024 << "kB" << endl;
            else
                cerr << "unlimited" << endl;
        }
        if( props.maxmem && memneeded > (BigInt)props.maxmem )
            DAI_THROW(OUT_OF_MEMORY);

        // Generate the junction tree corresponding to the elimination sequence
        GenerateJT( fg, ElimVec );
    }
}
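
/*  Usage sketch (caller-side code, not part of this file; assumes a valid
 *  FactorGraph fg and the standard InfAlg interface init()/run()):
 *
 *      PropertySet opts;
 *      opts.set( "updates", std::string("HUGIN") );
 *      JTree jt( fg, opts );
 *      jt.init();
 *      jt.run();                              // exact inference on the junction tree
 *      Real logZ = jt.logZ();                 // log partition sum
 *      Factor b = jt.belief( fg.var(0) );     // exact marginal of variable 0
 */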


void JTree::construct( const FactorGraph &fg, const std::vector<VarSet> &cl, bool verify ) {
    // Copy the factor graph
    FactorGraph::operator=( fg );

    // Construct a weighted graph (each edge is weighted with the cardinality
    // of the intersection of the nodes, where the nodes are the elements of cl).
    WeightedGraph<int> JuncGraph;
    // Start by connecting all clusters with cluster zero, and weight zero,
    // in order to get a connected weighted graph
    for( size_t i = 1; i < cl.size(); i++ )
        JuncGraph[UEdge(i,0)] = 0;
    for( size_t i = 0; i < cl.size(); i++ ) {
        for( size_t j = i + 1; j < cl.size(); j++ ) {
            size_t w = (cl[i] & cl[j]).size();
            if( w )
                JuncGraph[UEdge(i,j)] = w;
        }
    }
    if( props.verbose >= 3 )
        cerr << "Weightedgraph: " << JuncGraph << endl;

    // Construct maximal spanning tree
    RTree = MaxSpanningTree( JuncGraph, /*true*/false ); // WORKAROUND FOR BUG IN BOOST GRAPH LIBRARY VERSION 1.54
    if( props.verbose >= 3 )
        cerr << "Spanning tree: " << RTree << endl;
    DAI_DEBASSERT( RTree.size() == cl.size() - 1 );

    // Construct corresponding region graph

    // Create outer regions
    _ORs.clear();
    _ORs.reserve( cl.size() );
    for( size_t i = 0; i < cl.size(); i++ )
        _ORs.push_back( FRegion( Factor(cl[i], 1.0), 1.0 ) );

    // For each factor, find an outer region that subsumes that factor.
    // Then, multiply the outer region with that factor.
    _fac2OR.clear();
    _fac2OR.resize( nrFactors(), -1U );
    for( size_t I = 0; I < nrFactors(); I++ ) {
        size_t alpha;
        for( alpha = 0; alpha < nrORs(); alpha++ )
            if( OR(alpha).vars() >> factor(I).vars() ) {
                _fac2OR[I] = alpha;
                break;
            }
        if( verify )
            DAI_ASSERT( alpha != nrORs() );
    }
    recomputeORs();

    // Create inner regions and edges
    _IRs.clear();
    _IRs.reserve( RTree.size() );
    vector<Edge> edges;
    edges.reserve( 2 * RTree.size() );
    for( size_t i = 0; i < RTree.size(); i++ ) {
        edges.push_back( Edge( RTree[i].first, nrIRs() ) );
        edges.push_back( Edge( RTree[i].second, nrIRs() ) );
        // inner clusters have counting number -1, except if they are empty
        VarSet intersection = cl[RTree[i].first] & cl[RTree[i].second];
        _IRs.push_back( Region( intersection, intersection.size() ? -1.0 : 0.0 ) );
    }

    // create bipartite graph
    _G.construct( nrORs(), nrIRs(), edges.begin(), edges.end() );

    // Check counting numbers
#ifdef DAI_DEBUG
    checkCountingNumbers();
#endif

    // Create beliefs
    Qa.clear();
    Qa.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa.push_back( OR(alpha) );

    Qb.clear();
    Qb.reserve( nrIRs() );
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb.push_back( Factor( IR(beta), 1.0 ) );
}


void JTree::GenerateJT( const FactorGraph &fg, const std::vector<VarSet> &cl ) {
    construct( fg, cl, true );

    // Create messages
    _mes.clear();
    _mes.reserve( nrORs() );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        _mes.push_back( vector<Factor>() );
        _mes[alpha].reserve( nbOR(alpha).size() );
        bforeach( const Neighbor &beta, nbOR(alpha) )
            _mes[alpha].push_back( Factor( IR(beta), 1.0 ) );
    }

    if( props.verbose >= 3 )
        cerr << "Regiongraph generated by JTree::GenerateJT: " << *this << endl;
}


Factor JTree::belief( const VarSet &vs ) const {
    vector<Factor>::const_iterator beta;
    for( beta = Qb.begin(); beta != Qb.end(); beta++ )
        if( beta->vars() >> vs )
            break;
    if( beta != Qb.end() ) {
        if( props.inference == Properties::InfType::SUMPROD )
            return( beta->marginal(vs) );
        else
            return( beta->maxMarginal(vs) );
    } else {
        vector<Factor>::const_iterator alpha;
        for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
            if( alpha->vars() >> vs )
                break;
        if( alpha == Qa.end() ) {
            DAI_THROW(BELIEF_NOT_AVAILABLE);
            return Factor();
        } else {
            if( props.inference == Properties::InfType::SUMPROD )
                return( alpha->marginal(vs) );
            else
                return( alpha->maxMarginal(vs) );
        }
    }
}


vector<Factor> JTree::beliefs() const {
    vector<Factor> result;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        result.push_back( Qb[beta] );
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        result.push_back( Qa[alpha] );
    return result;
}


void JTree::runHUGIN() {
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa[alpha] = OR(alpha);

    for( size_t beta = 0; beta < nrIRs(); beta++ )
        Qb[beta].fill( 1.0 );

    // CollectEvidence
    _logZ = 0.0;
    for( size_t i = RTree.size(); (i--) != 0; ) {
        // Make outer region RTree[i].first consistent with outer region RTree[i].second
        // IR(i) = separator between OR(RTree[i].first) and OR(RTree[i].second)
        Factor new_Qb;
        if( props.inference == Properties::InfType::SUMPROD )
            new_Qb = Qa[RTree[i].second].marginal( IR( i ), false );
        else
            new_Qb = Qa[RTree[i].second].maxMarginal( IR( i ), false );

        _logZ += log(new_Qb.normalize());
        Qa[RTree[i].first] *= new_Qb / Qb[i];
        Qb[i] = new_Qb;
    }
    if( RTree.empty() )
        _logZ += log(Qa[0].normalize());
    else
        _logZ += log(Qa[RTree[0].first].normalize());

    // DistributeEvidence
    for( size_t i = 0; i < RTree.size(); i++ ) {
        // Make outer region RTree[i].second consistent with outer region RTree[i].first
        // IR(i) = separator between OR(RTree[i].first) and OR(RTree[i].second)
        Factor new_Qb;
        if( props.inference == Properties::InfType::SUMPROD )
            new_Qb = Qa[RTree[i].first].marginal( IR( i ) );
        else
            new_Qb = Qa[RTree[i].first].maxMarginal( IR( i ) );

        Qa[RTree[i].second] *= new_Qb / Qb[i];
        Qb[i] = new_Qb;
    }

    // Normalize
    for( size_t alpha = 0; alpha < nrORs(); alpha++ )
        Qa[alpha].normalize();
}
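
/*  Note on the updates above (documentation only): runHUGIN() performs the
 *  standard two-pass HUGIN scheme on the rooted junction tree RTree. For a
 *  tree edge (A,B) with separator S = IR(i), the collect/distribute step is
 *
 *      Qb_S^new = sum_{x_A \ x_S} Qa_A        (max instead of sum for max-product)
 *      Qa_B     = Qa_B * Qb_S^new / Qb_S^old
 *
 *  and the normalization constants accumulated along the way give _logZ.
 */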


void JTree::runShaferShenoy() {
    // First pass
    _logZ = 0.0;
    for( size_t e = nrIRs(); (e--) != 0; ) {
        // send a message from RTree[e].second to RTree[e].first
        // or, actually, from the separator IR(e) to RTree[e].first

        size_t i = nbIR(e)[1].node; // = RTree[e].second
        size_t j = nbIR(e)[0].node; // = RTree[e].first
        size_t _e = nbIR(e)[0].dual;

        Factor msg = OR(i);
        bforeach( const Neighbor &k, nbOR(i) )
            if( k != e )
                msg *= message( i, k.iter );
        if( props.inference == Properties::InfType::SUMPROD )
            message( j, _e ) = msg.marginal( IR(e), false );
        else
            message( j, _e ) = msg.maxMarginal( IR(e), false );
        _logZ += log( message(j,_e).normalize() );
    }

    // Second pass
    for( size_t e = 0; e < nrIRs(); e++ ) {
        size_t i = nbIR(e)[0].node; // = RTree[e].first
        size_t j = nbIR(e)[1].node; // = RTree[e].second
        size_t _e = nbIR(e)[1].dual;

        Factor msg = OR(i);
        bforeach( const Neighbor &k, nbOR(i) )
            if( k != e )
                msg *= message( i, k.iter );
        if( props.inference == Properties::InfType::SUMPROD )
            message( j, _e ) = msg.marginal( IR(e) );
        else
            message( j, _e ) = msg.maxMarginal( IR(e) );
    }

    // Calculate beliefs
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        Factor piet = OR(alpha);
        bforeach( const Neighbor &k, nbOR(alpha) )
            piet *= message( alpha, k.iter );
        if( nrIRs() == 0 ) {
            _logZ += log( piet.normalize() );
            Qa[alpha] = piet;
        } else if( alpha == nbIR(0)[0].node /*RTree[0].first*/ ) {
            _logZ += log( piet.normalize() );
            Qa[alpha] = piet;
        } else
            Qa[alpha] = piet.normalized();
    }

    // Only for logZ (and for belief)...
    for( size_t beta = 0; beta < nrIRs(); beta++ ) {
        if( props.inference == Properties::InfType::SUMPROD )
            Qb[beta] = Qa[nbIR(beta)[0].node].marginal( IR(beta) );
        else
            Qb[beta] = Qa[nbIR(beta)[0].node].maxMarginal( IR(beta) );
    }
}
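
/*  Note on the message passing above (documentation only): in the
 *  Shafer-Shenoy variant, the message sent from clique i through separator
 *  IR(e) is the clique potential OR(i) multiplied by all incoming messages
 *  except the one arriving through IR(e), marginalized onto the separator:
 *
 *      m_{i -> e} = marg_{IR(e)}( OR(i) * prod_{k != e} m_{k -> i} )
 *
 *  Clique beliefs Qa are the product of the clique potential with all incoming
 *  messages; normalizing only the root clique contributes the last term of _logZ.
 */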


Real JTree::run() {
    if( props.updates == Properties::UpdateType::HUGIN )
        runHUGIN();
    else if( props.updates == Properties::UpdateType::SHSH )
        runShaferShenoy();
    return 0.0;
}


Real JTree::logZ() const {
/*  Consistency check (disabled): the region-graph free energy should equal _logZ
    Real s = 0.0;
    for( size_t beta = 0; beta < nrIRs(); beta++ )
        s += IR(beta).c() * Qb[beta].entropy();
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        s += OR(alpha).c() * Qa[alpha].entropy();
        s += (OR(alpha).log(true) * Qa[alpha]).sum();
    }
    DAI_ASSERT( abs( _logZ - s ) < 1e-8 );
    return s;
*/
    return _logZ;
}


size_t JTree::findEfficientTree( const VarSet &vs, RootedTree &Tree, size_t PreviousRoot ) const {
    // find new root clique (the one with maximal statespace overlap with vs)
    BigInt maxval = 0;
    size_t maxalpha = 0;
    for( size_t alpha = 0; alpha < nrORs(); alpha++ ) {
        BigInt val = VarSet(vs & OR(alpha).vars()).nrStates();
        if( val > maxval ) {
            maxval = val;
            maxalpha = alpha;
        }
    }

    // reorder the tree edges such that maxalpha becomes the new root
    RootedTree newTree( GraphEL( RTree.begin(), RTree.end() ), maxalpha );

    // identify subtree that contains all variables of vs which are not in the new root
    set<DEdge> subTree;
    // for each variable in vs
    for( VarSet::const_iterator n = vs.begin(); n != vs.end(); n++ ) {
        for( size_t e = 0; e < newTree.size(); e++ ) {
            if( OR(newTree[e].second).vars().contains( *n ) ) {
                // add this edge and the path from it back to the root
                size_t f = e;
                subTree.insert( newTree[f] );
                size_t pos = newTree[f].first;
                for( ; f > 0; f-- )
                    if( newTree[f-1].second == pos ) {
                        subTree.insert( newTree[f-1] );
                        pos = newTree[f-1].first;
                    }
            }
        }
    }
    if( PreviousRoot != (size_t)-1 && PreviousRoot != maxalpha ) {
        // find first occurrence of PreviousRoot in the tree, which is closest to the new root
        size_t e = 0;
        for( ; e != newTree.size(); e++ ) {
            if( newTree[e].second == PreviousRoot )
                break;
        }
        DAI_ASSERT( e != newTree.size() );

        // track-back path to root and add edges to subTree
        subTree.insert( newTree[e] );
        size_t pos = newTree[e].first;
        for( ; e > 0; e-- )
            if( newTree[e-1].second == pos ) {
                subTree.insert( newTree[e-1] );
                pos = newTree[e-1].first;
            }
    }

    // Resulting Tree is a reordered copy of newTree
    // First add edges in subTree to Tree
    Tree.clear();
    vector<DEdge> remTree;
    for( RootedTree::const_iterator e = newTree.begin(); e != newTree.end(); e++ )
        if( subTree.count( *e ) )
            Tree.push_back( *e );
        else
            remTree.push_back( *e );
    size_t subTreeSize = Tree.size();
    // Then add remaining edges
    copy( remTree.begin(), remTree.end(), back_inserter( Tree ) );

    return subTreeSize;
}
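
/*  Note (documentation only): findEfficientTree() reorders the junction tree
 *  so that the clique with the largest state-space overlap with vs becomes the
 *  root, places the edges needed to cover all variables of vs (and, if given,
 *  the path to PreviousRoot) at the front of Tree, and returns the number of
 *  those leading edges; calcMarginal() below only propagates over that prefix.
 */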


Factor JTree::calcMarginal( const VarSet &vs ) {
    vector<Factor>::const_iterator beta;
    for( beta = Qb.begin(); beta != Qb.end(); beta++ )
        if( beta->vars() >> vs )
            break;
    if( beta != Qb.end() ) {
        if( props.inference == Properties::InfType::SUMPROD )
            return( beta->marginal(vs) );
        else
            return( beta->maxMarginal(vs) );
    } else {
        vector<Factor>::const_iterator alpha;
        for( alpha = Qa.begin(); alpha != Qa.end(); alpha++ )
            if( alpha->vars() >> vs )
                break;
        if( alpha != Qa.end() ) {
            if( props.inference == Properties::InfType::SUMPROD )
                return( alpha->marginal(vs) );
            else
                return( alpha->maxMarginal(vs) );
        } else {
            // Find subtree to do efficient inference
            RootedTree T;
            size_t Tsize = findEfficientTree( vs, T );

            // Find remaining variables (which are not in the new root)
            VarSet vsrem = vs / OR(T.front().first).vars();
            Factor Pvs( vs, 0.0 );

            // Save Qa and Qb on the subtree
            map<size_t,Factor> Qa_old;
            map<size_t,Factor> Qb_old;
            vector<size_t> b( Tsize, 0 );
            for( size_t i = Tsize; (i--) != 0; ) {
                size_t alpha1 = T[i].first;
                size_t alpha2 = T[i].second;
                size_t beta;
                for( beta = 0; beta < nrIRs(); beta++ )
                    if( UEdge( RTree[beta].first, RTree[beta].second ) == UEdge( alpha1, alpha2 ) )
                        break;
                DAI_ASSERT( beta != nrIRs() );
                b[i] = beta;

                if( !Qa_old.count( alpha1 ) )
                    Qa_old[alpha1] = Qa[alpha1];
                if( !Qa_old.count( alpha2 ) )
                    Qa_old[alpha2] = Qa[alpha2];
                if( !Qb_old.count( beta ) )
                    Qb_old[beta] = Qb[beta];
            }

            // For all states of vsrem
            for( State s(vsrem); s.valid(); s++ ) {
                // CollectEvidence
                Real logZ = 0.0;
                for( size_t i = Tsize; (i--) != 0; ) {
                    // Make outer region T[i].first consistent with outer region T[i].second
                    // IR(i) = separator between OR(T[i].first) and OR(T[i].second)

                    // Clamp the variables of vsrem occurring in this outer region to the current state s
                    for( VarSet::const_iterator n = vsrem.begin(); n != vsrem.end(); n++ )
                        if( Qa[T[i].second].vars() >> *n ) {
                            Factor piet( *n, 0.0 );
                            piet.set( s(*n), 1.0 );
                            Qa[T[i].second] *= piet;
                        }

                    Factor new_Qb;
                    if( props.inference == Properties::InfType::SUMPROD )
                        new_Qb = Qa[T[i].second].marginal( IR( b[i] ), false );
                    else
                        new_Qb = Qa[T[i].second].maxMarginal( IR( b[i] ), false );
                    logZ += log(new_Qb.normalize());
                    Qa[T[i].first] *= new_Qb / Qb[b[i]];
                    Qb[b[i]] = new_Qb;
                }
                logZ += log(Qa[T[0].first].normalize());

                Factor piet( vsrem, 0.0 );
                piet.set( s, exp(logZ) );
                if( props.inference == Properties::InfType::SUMPROD )
                    Pvs += piet * Qa[T[0].first].marginal( vs / vsrem, false ); // OPTIMIZE ME
                else
                    Pvs += piet * Qa[T[0].first].maxMarginal( vs / vsrem, false ); // OPTIMIZE ME

                // Restore clamped beliefs
                for( map<size_t,Factor>::const_iterator alpha = Qa_old.begin(); alpha != Qa_old.end(); alpha++ )
                    Qa[alpha->first] = alpha->second;
                for( map<size_t,Factor>::const_iterator beta = Qb_old.begin(); beta != Qb_old.end(); beta++ )
                    Qb[beta->first] = beta->second;
            }

            return( Pvs.normalized() );
        }
    }
}
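
/*  Note (documentation only): the fallback branch of calcMarginal() enumerates
 *  every joint state of vsrem, i.e. of the requested variables that lie outside
 *  the root clique of the subtree. Its cost is therefore exponential in the
 *  number of requested variables not covered by a single clique, and it
 *  temporarily clamps and afterwards restores the Qa/Qb beliefs on that subtree.
 */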


std::pair<size_t,BigInt> boundTreewidth( const FactorGraph &fg, greedyVariableElimination::eliminationCostFunction fn, size_t maxStates ) {
    // Create cluster graph from factor graph
    ClusterGraph _cg( fg, true );

    // Obtain elimination sequence
    vector<VarSet> ElimVec = _cg.VarElim( greedyVariableElimination( fn ), maxStates ).eraseNonMaximal().clusters();

    // Calculate treewidth
    size_t treewidth = 0;
    BigInt nrstates = 0.0;
    for( size_t i = 0; i < ElimVec.size(); i++ ) {
        if( ElimVec[i].size() > treewidth )
            treewidth = ElimVec[i].size();
        BigInt s = ElimVec[i].nrStates();
        if( s > nrstates )
            nrstates = s;
    }

    return make_pair( treewidth, nrstates );
}
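
/*  Usage sketch (caller-side code, not part of this file): boundTreewidth() can
 *  be used to check whether exact inference looks feasible before constructing
 *  a JTree; the threshold below is an arbitrary example value:
 *
 *      std::pair<size_t,BigInt> tw = boundTreewidth( fg, &eliminationCost_MinFill, 0 );
 *      if( tw.first > 25 )
 *          ; // induced width too large: fall back to an approximate algorithm such as BP
 */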


std::vector<size_t> JTree::findMaximum() const {
    vector<size_t> maximum( nrVars() );
    vector<bool> visitedVars( nrVars(), false );
    vector<bool> visitedORs( nrORs(), false );
    stack<size_t> scheduledORs;
    scheduledORs.push( 0 );
    while( !scheduledORs.empty() ) {
        size_t alpha = scheduledORs.top();
        scheduledORs.pop();
        if( visitedORs[alpha] )
            continue;
        visitedORs[alpha] = true;

        // Get marginal of outer region alpha
        Prob probF = Qa[alpha].p();

        // The allowed configuration is restricted according to the variables assigned so far:
        // pick the argmax amongst the allowed states
        Real maxProb = -numeric_limits<Real>::max();
        State maxState( OR(alpha).vars() );
        for( State s( OR(alpha).vars() ); s.valid(); ++s ) {
            // First, calculate whether this state is consistent with variables that
            // have been assigned already
            bool allowedState = true;
            bforeach( const Var &j, OR(alpha).vars() ) {
                size_t j_index = findVar(j);
                if( visitedVars[j_index] && maximum[j_index] != s(j) ) {
                    allowedState = false;
                    break;
                }
            }
            // If it is consistent, check if its probability is larger than what we have seen so far
            if( allowedState )
                if( probF[s] > maxProb ) {
                    maxProb = probF[s];
                    maxState = s;
                }
        }
        DAI_ASSERT( maxProb != 0.0 );
        DAI_ASSERT( Qa[alpha][maxState] != 0.0 );

        // Assign the argmax to the variables of this outer region
        bforeach( const Var &j, OR(alpha).vars() ) {
            size_t j_index = findVar(j);
            if( visitedVars[j_index] ) {
                // We have already visited j earlier - hopefully our state is consistent
                if( maximum[j_index] != maxState( j ) )
                    DAI_THROWE(RUNTIME_ERROR,"MAP state inconsistent due to loops");
            } else {
                // We found a consistent state for variable j
                visitedVars[j_index] = true;
                maximum[j_index] = maxState( j );
                // Schedule neighboring outer regions for processing
                bforeach( const Neighbor &beta, nbOR(alpha) )
                    bforeach( const Neighbor &alpha2, nbIR(beta) )
                        if( !visitedORs[alpha2] )
                            scheduledORs.push( alpha2 );
            }
        }
    }
    return maximum;
}
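
/*  Note (documentation only): findMaximum() decodes a joint MAP state from the
 *  clique beliefs by walking the junction tree and picking, in every clique,
 *  the most probable state consistent with the variables fixed so far. It is
 *  intended to be used after run() with max-product inference (e.g. the
 *  "MAXPROD" setting); with sum-product beliefs the result need not be a MAP state.
 */
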
} // end of namespace dai