1 /* This file is part of libDAI - http://www.libdai.org/
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
11 #include <dai/alldai.h> // Include main libDAI header file
12 #include <dai/jtree.h>
14 #include <dai/decmap.h>
21 int main( int argc
, char *argv
[] ) {
22 #if defined(DAI_WITH_BP) && defined(DAI_WITH_JTREE)
23 if ( argc
!= 2 && argc
!= 3 ) {
24 cout
<< "Usage: " << argv
[0] << " <filename.fg> [maxstates]" << endl
<< endl
;
25 cout
<< "Reads factor graph <filename.fg> and runs" << endl
;
26 cout
<< "Belief Propagation, Max-Product and JunctionTree on it." << endl
;
27 cout
<< "JunctionTree is only run if a junction tree is found with" << endl
;
28 cout
<< "total number of states less than <maxstates> (where 0 means unlimited)." << endl
<< endl
;
31 // Report inference algorithms built into libDAI
32 cout
<< "Builtin inference algorithms: " << builtinInfAlgNames() << endl
<< endl
;
34 // Read FactorGraph from the file specified by the first command line argument
36 fg
.ReadFromFile(argv
[1]);
37 size_t maxstates
= 1000000;
39 maxstates
= fromString
<size_t>( argv
[2] );
42 size_t maxiter
= 10000;
46 // Store the constants in a PropertySet object
48 opts
.set("maxiter",maxiter
); // Maximum number of iterations
49 opts
.set("tol",tol
); // Tolerance for convergence
50 opts
.set("verbose",verb
); // Verbosity (amount of output generated)
52 // Bound treewidth for junctiontree
55 boundTreewidth(fg
, &eliminationCost_MinFill
, maxstates
);
56 } catch( Exception
&e
) {
57 if( e
.getCode() == Exception::OUT_OF_MEMORY
) {
59 cout
<< "Skipping junction tree (need more than " << maxstates
<< " states)." << endl
;
66 vector
<size_t> jtmapstate
;
68 // Construct a JTree (junction tree) object from the FactorGraph fg
69 // using the parameters specified by opts and an additional property
70 // that specifies the type of updates the JTree algorithm should perform
71 jt
= JTree( fg
, opts("updates",string("HUGIN")) );
72 // Initialize junction tree algorithm
74 // Run junction tree algorithm
77 // Construct another JTree (junction tree) object that is used to calculate
78 // the joint configuration of variables that has maximum probability (MAP state)
79 jtmap
= JTree( fg
, opts("updates",string("HUGIN"))("inference",string("MAXPROD")) );
80 // Initialize junction tree algorithm
82 // Run junction tree algorithm
84 // Calculate joint state of all variables that has maximum probability
85 jtmapstate
= jtmap
.findMaximum();
88 // Construct a BP (belief propagation) object from the FactorGraph fg
89 // using the parameters specified by opts and two additional properties,
90 // specifying the type of updates the BP algorithm should perform and
91 // whether they should be done in the real or in the logdomain
92 BP
bp(fg
, opts("updates",string("SEQRND"))("logdomain",false));
93 // Initialize belief propagation algorithm
95 // Run belief propagation algorithm
98 // Construct a BP (belief propagation) object from the FactorGraph fg
99 // using the parameters specified by opts and two additional properties,
100 // specifying the type of updates the BP algorithm should perform and
101 // whether they should be done in the real or in the logdomain
103 // Note that inference is set to MAXPROD, which means that the object
104 // will perform the max-product algorithm instead of the sum-product algorithm
105 BP
mp(fg
, opts("updates",string("SEQRND"))("logdomain",false)("inference",string("MAXPROD"))("damping",string("0.1")));
106 // Initialize max-product algorithm
108 // Run max-product algorithm
110 // Calculate joint state of all variables that has maximum probability
111 // based on the max-product result
112 vector
<size_t> mpstate
= mp
.findMaximum();
114 // Construct a decimation algorithm object from the FactorGraph fg
115 // using the parameters specified by opts and three additional properties,
116 // specifying that the decimation algorithm should use the max-product
117 // algorithm and should completely reinitalize its state at every step
118 DecMAP
decmap(fg
, opts("reinit",true)("ianame",string("BP"))("iaopts",string("[damping=0.1,inference=MAXPROD,logdomain=0,maxiter=1000,tol=1e-9,updates=SEQRND,verbose=1]")) );
121 vector
<size_t> decmapstate
= decmap
.findMaximum();
124 // Report variable marginals for fg, calculated by the junction tree algorithm
125 cout
<< "Exact variable marginals:" << endl
;
126 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) // iterate over all variables in fg
127 cout
<< jt
.belief(fg
.var(i
)) << endl
; // display the "belief" of jt for that variable
130 // Report variable marginals for fg, calculated by the belief propagation algorithm
131 cout
<< "Approximate (loopy belief propagation) variable marginals:" << endl
;
132 for( size_t i
= 0; i
< fg
.nrVars(); i
++ ) // iterate over all variables in fg
133 cout
<< bp
.belief(fg
.var(i
)) << endl
; // display the belief of bp for that variable
136 // Report factor marginals for fg, calculated by the junction tree algorithm
137 cout
<< "Exact factor marginals:" << endl
;
138 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) // iterate over all factors in fg
139 cout
<< jt
.belief(fg
.factor(I
).vars()) << endl
; // display the "belief" of jt for the variables in that factor
142 // Report factor marginals for fg, calculated by the belief propagation algorithm
143 cout
<< "Approximate (loopy belief propagation) factor marginals:" << endl
;
144 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) // iterate over all factors in fg
145 cout
<< bp
.belief(fg
.factor(I
).vars()) << endl
; // display the belief of bp for the variables in that factor
148 // Report log partition sum (normalizing constant) of fg, calculated by the junction tree algorithm
149 cout
<< "Exact log partition sum: " << jt
.logZ() << endl
;
152 // Report log partition sum of fg, approximated by the belief propagation algorithm
153 cout
<< "Approximate (loopy belief propagation) log partition sum: " << bp
.logZ() << endl
;
156 // Report exact MAP variable marginals
157 cout
<< "Exact MAP variable marginals:" << endl
;
158 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
159 cout
<< jtmap
.belief(fg
.var(i
)) << endl
;
162 // Report max-product variable marginals
163 cout
<< "Approximate (max-product) MAP variable marginals:" << endl
;
164 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
165 cout
<< mp
.belief(fg
.var(i
)) << endl
;
168 // Report exact MAP factor marginals
169 cout
<< "Exact MAP factor marginals:" << endl
;
170 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ )
171 cout
<< jtmap
.belief(fg
.factor(I
).vars()) << " == " << jtmap
.beliefF(I
) << endl
;
174 // Report max-product factor marginals
175 cout
<< "Approximate (max-product) MAP factor marginals:" << endl
;
176 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ )
177 cout
<< mp
.belief(fg
.factor(I
).vars()) << " == " << mp
.beliefF(I
) << endl
;
180 // Report exact MAP joint state
181 cout
<< "Exact MAP state (log score = " << fg
.logScore( jtmapstate
) << "):" << endl
;
182 for( size_t i
= 0; i
< jtmapstate
.size(); i
++ )
183 cout
<< fg
.var(i
) << ": " << jtmapstate
[i
] << endl
;
186 // Report max-product MAP joint state
187 cout
<< "Approximate (max-product) MAP state (log score = " << fg
.logScore( mpstate
) << "):" << endl
;
188 for( size_t i
= 0; i
< mpstate
.size(); i
++ )
189 cout
<< fg
.var(i
) << ": " << mpstate
[i
] << endl
;
191 // Report DecMAP joint state
192 cout
<< "Approximate DecMAP state (log score = " << fg
.logScore( decmapstate
) << "):" << endl
;
193 for( size_t i
= 0; i
< decmapstate
.size(); i
++ )
194 cout
<< fg
.var(i
) << ": " << decmapstate
[i
] << endl
;
199 cout
<< "libDAI was configured without BP or JunctionTree (this can be changed in include/dai/dai_config.h)." << endl
;