[Benjamin Piwowarski] Renamed "foreach" macro into "bforeach" to avoid conflicts...
[libdai.git] / examples / example.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <iostream>
10 #include <map>
11 #include <dai/alldai.h> // Include main libDAI header file
12 #include <dai/jtree.h>
13 #include <dai/bp.h>
14 #include <dai/decmap.h>
15
16
17 using namespace dai;
18 using namespace std;
19
20
21 int main( int argc, char *argv[] ) {
22 if ( argc != 2 && argc != 3 ) {
23 cout << "Usage: " << argv[0] << " <filename.fg> [maxstates]" << endl << endl;
24 cout << "Reads factor graph <filename.fg> and runs" << endl;
25 cout << "Belief Propagation, Max-Product and JunctionTree on it." << endl;
26 cout << "JunctionTree is only run if a junction tree is found with" << endl;
27 cout << "total number of states less than <maxstates> (where 0 means unlimited)." << endl << endl;
28 return 1;
29 } else {
30 // Report inference algorithms built into libDAI
31 cout << "Builtin inference algorithms: " << builtinInfAlgNames() << endl << endl;
32
33 // Read FactorGraph from the file specified by the first command line argument
34 FactorGraph fg;
35 fg.ReadFromFile(argv[1]);
36 size_t maxstates = 1000000;
37 if( argc == 3 )
38 maxstates = fromString<size_t>( argv[2] );
39
40 // Set some constants
41 size_t maxiter = 10000;
42 Real tol = 1e-9;
43 size_t verb = 1;
44
45 // Store the constants in a PropertySet object
46 PropertySet opts;
47 opts.set("maxiter",maxiter); // Maximum number of iterations
48 opts.set("tol",tol); // Tolerance for convergence
49 opts.set("verbose",verb); // Verbosity (amount of output generated)
50
51 // Bound treewidth for junctiontree
52 bool do_jt = true;
53 try {
54 boundTreewidth(fg, &eliminationCost_MinFill, maxstates );
55 } catch( Exception &e ) {
56 if( e.getCode() == Exception::OUT_OF_MEMORY ) {
57 do_jt = false;
58 cout << "Skipping junction tree (need more than " << maxstates << " states)." << endl;
59 }
60 else
61 throw;
62 }
63
64 JTree jt, jtmap;
65 vector<size_t> jtmapstate;
66 if( do_jt ) {
67 // Construct a JTree (junction tree) object from the FactorGraph fg
68 // using the parameters specified by opts and an additional property
69 // that specifies the type of updates the JTree algorithm should perform
70 jt = JTree( fg, opts("updates",string("HUGIN")) );
71 // Initialize junction tree algorithm
72 jt.init();
73 // Run junction tree algorithm
74 jt.run();
75
76 // Construct another JTree (junction tree) object that is used to calculate
77 // the joint configuration of variables that has maximum probability (MAP state)
78 jtmap = JTree( fg, opts("updates",string("HUGIN"))("inference",string("MAXPROD")) );
79 // Initialize junction tree algorithm
80 jtmap.init();
81 // Run junction tree algorithm
82 jtmap.run();
83 // Calculate joint state of all variables that has maximum probability
84 jtmapstate = jtmap.findMaximum();
85 }
86
87 // Construct a BP (belief propagation) object from the FactorGraph fg
88 // using the parameters specified by opts and two additional properties,
89 // specifying the type of updates the BP algorithm should perform and
90 // whether they should be done in the real or in the logdomain
91 BP bp(fg, opts("updates",string("SEQRND"))("logdomain",false));
92 // Initialize belief propagation algorithm
93 bp.init();
94 // Run belief propagation algorithm
95 bp.run();
96
97 // Construct a BP (belief propagation) object from the FactorGraph fg
98 // using the parameters specified by opts and two additional properties,
99 // specifying the type of updates the BP algorithm should perform and
100 // whether they should be done in the real or in the logdomain
101 //
102 // Note that inference is set to MAXPROD, which means that the object
103 // will perform the max-product algorithm instead of the sum-product algorithm
104 BP mp(fg, opts("updates",string("SEQRND"))("logdomain",false)("inference",string("MAXPROD"))("damping",string("0.1")));
105 // Initialize max-product algorithm
106 mp.init();
107 // Run max-product algorithm
108 mp.run();
109 // Calculate joint state of all variables that has maximum probability
110 // based on the max-product result
111 vector<size_t> mpstate = mp.findMaximum();
112
113 // Construct a decimation algorithm object from the FactorGraph fg
114 // using the parameters specified by opts and three additional properties,
115 // specifying that the decimation algorithm should use the max-product
116 // algorithm and should completely reinitalize its state at every step
117 DecMAP decmap(fg, opts("reinit",true)("ianame",string("BP"))("iaopts",string("[damping=0.1,inference=MAXPROD,logdomain=0,maxiter=1000,tol=1e-9,updates=SEQRND,verbose=1]")) );
118 decmap.init();
119 decmap.run();
120 vector<size_t> decmapstate = decmap.findMaximum();
121
122 if( do_jt ) {
123 // Report variable marginals for fg, calculated by the junction tree algorithm
124 cout << "Exact variable marginals:" << endl;
125 for( size_t i = 0; i < fg.nrVars(); i++ ) // iterate over all variables in fg
126 cout << jt.belief(fg.var(i)) << endl; // display the "belief" of jt for that variable
127 }
128
129 // Report variable marginals for fg, calculated by the belief propagation algorithm
130 cout << "Approximate (loopy belief propagation) variable marginals:" << endl;
131 for( size_t i = 0; i < fg.nrVars(); i++ ) // iterate over all variables in fg
132 cout << bp.belief(fg.var(i)) << endl; // display the belief of bp for that variable
133
134 if( do_jt ) {
135 // Report factor marginals for fg, calculated by the junction tree algorithm
136 cout << "Exact factor marginals:" << endl;
137 for( size_t I = 0; I < fg.nrFactors(); I++ ) // iterate over all factors in fg
138 cout << jt.belief(fg.factor(I).vars()) << endl; // display the "belief" of jt for the variables in that factor
139 }
140
141 // Report factor marginals for fg, calculated by the belief propagation algorithm
142 cout << "Approximate (loopy belief propagation) factor marginals:" << endl;
143 for( size_t I = 0; I < fg.nrFactors(); I++ ) // iterate over all factors in fg
144 cout << bp.belief(fg.factor(I).vars()) << endl; // display the belief of bp for the variables in that factor
145
146 if( do_jt ) {
147 // Report log partition sum (normalizing constant) of fg, calculated by the junction tree algorithm
148 cout << "Exact log partition sum: " << jt.logZ() << endl;
149 }
150
151 // Report log partition sum of fg, approximated by the belief propagation algorithm
152 cout << "Approximate (loopy belief propagation) log partition sum: " << bp.logZ() << endl;
153
154 if( do_jt ) {
155 // Report exact MAP variable marginals
156 cout << "Exact MAP variable marginals:" << endl;
157 for( size_t i = 0; i < fg.nrVars(); i++ )
158 cout << jtmap.belief(fg.var(i)) << endl;
159 }
160
161 // Report max-product variable marginals
162 cout << "Approximate (max-product) MAP variable marginals:" << endl;
163 for( size_t i = 0; i < fg.nrVars(); i++ )
164 cout << mp.belief(fg.var(i)) << endl;
165
166 if( do_jt ) {
167 // Report exact MAP factor marginals
168 cout << "Exact MAP factor marginals:" << endl;
169 for( size_t I = 0; I < fg.nrFactors(); I++ )
170 cout << jtmap.belief(fg.factor(I).vars()) << " == " << jtmap.beliefF(I) << endl;
171 }
172
173 // Report max-product factor marginals
174 cout << "Approximate (max-product) MAP factor marginals:" << endl;
175 for( size_t I = 0; I < fg.nrFactors(); I++ )
176 cout << mp.belief(fg.factor(I).vars()) << " == " << mp.beliefF(I) << endl;
177
178 if( do_jt ) {
179 // Report exact MAP joint state
180 cout << "Exact MAP state (log score = " << fg.logScore( jtmapstate ) << "):" << endl;
181 for( size_t i = 0; i < jtmapstate.size(); i++ )
182 cout << fg.var(i) << ": " << jtmapstate[i] << endl;
183 }
184
185 // Report max-product MAP joint state
186 cout << "Approximate (max-product) MAP state (log score = " << fg.logScore( mpstate ) << "):" << endl;
187 for( size_t i = 0; i < mpstate.size(); i++ )
188 cout << fg.var(i) << ": " << mpstate[i] << endl;
189
190 // Report DecMAP joint state
191 cout << "Approximate DecMAP state (log score = " << fg.logScore( decmapstate ) << "):" << endl;
192 for( size_t i = 0; i < decmapstate.size(); i++ )
193 cout << fg.var(i) << ": " << decmapstate[i] << endl;
194 }
195
196 return 0;
197 }