Multiple changes: changes in build system, one workaround and one bug fix
[libdai.git] / examples / example.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <iostream>
10 #include <map>
11 #include <dai/alldai.h> // Include main libDAI header file
12 #include <dai/jtree.h>
13 #include <dai/bp.h>
14 #include <dai/decmap.h>
15
16
17 using namespace dai;
18 using namespace std;
19
20
21 int main( int argc, char *argv[] ) {
22 #if defined(DAI_WITH_BP) && defined(DAI_WITH_JTREE)
23 if ( argc != 2 && argc != 3 ) {
24 cout << "Usage: " << argv[0] << " <filename.fg> [maxstates]" << endl << endl;
25 cout << "Reads factor graph <filename.fg> and runs" << endl;
26 cout << "Belief Propagation, Max-Product and JunctionTree on it." << endl;
27 cout << "JunctionTree is only run if a junction tree is found with" << endl;
28 cout << "total number of states less than <maxstates> (where 0 means unlimited)." << endl << endl;
29 return 1;
30 } else {
31 // Report inference algorithms built into libDAI
32 cout << "Builtin inference algorithms: " << builtinInfAlgNames() << endl << endl;
33
34 // Read FactorGraph from the file specified by the first command line argument
35 FactorGraph fg;
36 fg.ReadFromFile(argv[1]);
37 size_t maxstates = 1000000;
38 if( argc == 3 )
39 maxstates = fromString<size_t>( argv[2] );
40
41 // Set some constants
42 size_t maxiter = 10000;
43 Real tol = 1e-9;
44 size_t verb = 1;
45
46 // Store the constants in a PropertySet object
47 PropertySet opts;
48 opts.set("maxiter",maxiter); // Maximum number of iterations
49 opts.set("tol",tol); // Tolerance for convergence
50 opts.set("verbose",verb); // Verbosity (amount of output generated)
51
52 // Bound treewidth for junctiontree
53 bool do_jt = true;
54 try {
55 boundTreewidth(fg, &eliminationCost_MinFill, maxstates );
56 } catch( Exception &e ) {
57 if( e.getCode() == Exception::OUT_OF_MEMORY ) {
58 do_jt = false;
59 cout << "Skipping junction tree (need more than " << maxstates << " states)." << endl;
60 }
61 else
62 throw;
63 }
64
65 JTree jt, jtmap;
66 vector<size_t> jtmapstate;
67 if( do_jt ) {
68 // Construct a JTree (junction tree) object from the FactorGraph fg
69 // using the parameters specified by opts and an additional property
70 // that specifies the type of updates the JTree algorithm should perform
71 jt = JTree( fg, opts("updates",string("HUGIN")) );
72 // Initialize junction tree algorithm
73 jt.init();
74 // Run junction tree algorithm
75 jt.run();
76
77 // Construct another JTree (junction tree) object that is used to calculate
78 // the joint configuration of variables that has maximum probability (MAP state)
79 jtmap = JTree( fg, opts("updates",string("HUGIN"))("inference",string("MAXPROD")) );
80 // Initialize junction tree algorithm
81 jtmap.init();
82 // Run junction tree algorithm
83 jtmap.run();
84 // Calculate joint state of all variables that has maximum probability
85 jtmapstate = jtmap.findMaximum();
86 }
87
88 // Construct a BP (belief propagation) object from the FactorGraph fg
89 // using the parameters specified by opts and two additional properties,
90 // specifying the type of updates the BP algorithm should perform and
91 // whether they should be done in the real or in the logdomain
92 BP bp(fg, opts("updates",string("SEQRND"))("logdomain",false));
93 // Initialize belief propagation algorithm
94 bp.init();
95 // Run belief propagation algorithm
96 bp.run();
97
98 // Construct a BP (belief propagation) object from the FactorGraph fg
99 // using the parameters specified by opts and two additional properties,
100 // specifying the type of updates the BP algorithm should perform and
101 // whether they should be done in the real or in the logdomain
102 //
103 // Note that inference is set to MAXPROD, which means that the object
104 // will perform the max-product algorithm instead of the sum-product algorithm
105 BP mp(fg, opts("updates",string("SEQRND"))("logdomain",false)("inference",string("MAXPROD"))("damping",string("0.1")));
106 // Initialize max-product algorithm
107 mp.init();
108 // Run max-product algorithm
109 mp.run();
110 // Calculate joint state of all variables that has maximum probability
111 // based on the max-product result
112 vector<size_t> mpstate = mp.findMaximum();
113
114 // Construct a decimation algorithm object from the FactorGraph fg
115 // using the parameters specified by opts and three additional properties,
116 // specifying that the decimation algorithm should use the max-product
117 // algorithm and should completely reinitalize its state at every step
118 DecMAP decmap(fg, opts("reinit",true)("ianame",string("BP"))("iaopts",string("[damping=0.1,inference=MAXPROD,logdomain=0,maxiter=1000,tol=1e-9,updates=SEQRND,verbose=1]")) );
119 decmap.init();
120 decmap.run();
121 vector<size_t> decmapstate = decmap.findMaximum();
122
123 if( do_jt ) {
124 // Report variable marginals for fg, calculated by the junction tree algorithm
125 cout << "Exact variable marginals:" << endl;
126 for( size_t i = 0; i < fg.nrVars(); i++ ) // iterate over all variables in fg
127 cout << jt.belief(fg.var(i)) << endl; // display the "belief" of jt for that variable
128 }
129
130 // Report variable marginals for fg, calculated by the belief propagation algorithm
131 cout << "Approximate (loopy belief propagation) variable marginals:" << endl;
132 for( size_t i = 0; i < fg.nrVars(); i++ ) // iterate over all variables in fg
133 cout << bp.belief(fg.var(i)) << endl; // display the belief of bp for that variable
134
135 if( do_jt ) {
136 // Report factor marginals for fg, calculated by the junction tree algorithm
137 cout << "Exact factor marginals:" << endl;
138 for( size_t I = 0; I < fg.nrFactors(); I++ ) // iterate over all factors in fg
139 cout << jt.belief(fg.factor(I).vars()) << endl; // display the "belief" of jt for the variables in that factor
140 }
141
142 // Report factor marginals for fg, calculated by the belief propagation algorithm
143 cout << "Approximate (loopy belief propagation) factor marginals:" << endl;
144 for( size_t I = 0; I < fg.nrFactors(); I++ ) // iterate over all factors in fg
145 cout << bp.belief(fg.factor(I).vars()) << endl; // display the belief of bp for the variables in that factor
146
147 if( do_jt ) {
148 // Report log partition sum (normalizing constant) of fg, calculated by the junction tree algorithm
149 cout << "Exact log partition sum: " << jt.logZ() << endl;
150 }
151
152 // Report log partition sum of fg, approximated by the belief propagation algorithm
153 cout << "Approximate (loopy belief propagation) log partition sum: " << bp.logZ() << endl;
154
155 if( do_jt ) {
156 // Report exact MAP variable marginals
157 cout << "Exact MAP variable marginals:" << endl;
158 for( size_t i = 0; i < fg.nrVars(); i++ )
159 cout << jtmap.belief(fg.var(i)) << endl;
160 }
161
162 // Report max-product variable marginals
163 cout << "Approximate (max-product) MAP variable marginals:" << endl;
164 for( size_t i = 0; i < fg.nrVars(); i++ )
165 cout << mp.belief(fg.var(i)) << endl;
166
167 if( do_jt ) {
168 // Report exact MAP factor marginals
169 cout << "Exact MAP factor marginals:" << endl;
170 for( size_t I = 0; I < fg.nrFactors(); I++ )
171 cout << jtmap.belief(fg.factor(I).vars()) << " == " << jtmap.beliefF(I) << endl;
172 }
173
174 // Report max-product factor marginals
175 cout << "Approximate (max-product) MAP factor marginals:" << endl;
176 for( size_t I = 0; I < fg.nrFactors(); I++ )
177 cout << mp.belief(fg.factor(I).vars()) << " == " << mp.beliefF(I) << endl;
178
179 if( do_jt ) {
180 // Report exact MAP joint state
181 cout << "Exact MAP state (log score = " << fg.logScore( jtmapstate ) << "):" << endl;
182 for( size_t i = 0; i < jtmapstate.size(); i++ )
183 cout << fg.var(i) << ": " << jtmapstate[i] << endl;
184 }
185
186 // Report max-product MAP joint state
187 cout << "Approximate (max-product) MAP state (log score = " << fg.logScore( mpstate ) << "):" << endl;
188 for( size_t i = 0; i < mpstate.size(); i++ )
189 cout << fg.var(i) << ": " << mpstate[i] << endl;
190
191 // Report DecMAP joint state
192 cout << "Approximate DecMAP state (log score = " << fg.logScore( decmapstate ) << "):" << endl;
193 for( size_t i = 0; i < decmapstate.size(); i++ )
194 cout << fg.var(i) << ": " << decmapstate[i] << endl;
195 }
196
197 return 0;
198 #else
199 cout << "libDAI was configured without BP or JunctionTree (this can be changed in include/dai/dai_config.h)." << endl;
200 #endif
201 }