Added decimation algorithm DECMAP
[libdai.git] / examples / example.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
#include <exception>
#include <iostream>
#include <map>
#include <dai/alldai.h> // Include main libDAI header file
15
16
17 using namespace dai;
18 using namespace std;
19
20
21 // Evaluates the log probability of the joint configuration vstate of the variables in factor graph fg
22 Real evalState( const FactorGraph& fg, const std::vector<size_t> vstate ) {
23 // First, construct a map from Var objects to their states
24 // This decouples the representation of the joint state in vstate
25 // from the factor graph
26 map<Var, size_t> state;
27 for( size_t i = 0; i < vstate.size(); i++ )
28 state[fg.var(i)] = vstate[i];
29
30 // Evaluate the log probability of the joint configuration in state
31 // by summing the log factor entries of the factors in factor graph fg
32 // that correspond to the joint configuration
33 Real logZ = 0.0;
34 for( size_t I = 0; I < fg.nrFactors(); I++ )
35 logZ += dai::log( fg.factor(I)[calcLinearState( fg.factor(I).vars(), state )] );
36 return logZ;
37 }
38
39
40 int main( int argc, char *argv[] ) {
41 if ( argc != 2 ) {
42 cout << "Usage: " << argv[0] << " <filename.fg>" << endl << endl;
43 cout << "Reads factor graph <filename.fg> and runs" << endl;
44 cout << "Belief Propagation and JunctionTree on it." << endl << endl;
45 return 1;
46 } else {
47 // Read FactorGraph from the file specified by the first command line argument
48 FactorGraph fg;
49 fg.ReadFromFile(argv[1]);
50
51 // Set some constants
52 size_t maxiter = 10000;
53 Real tol = 1e-9;
54 size_t verb = 1;
55
56 // Store the constants in a PropertySet object
57 PropertySet opts;
58 opts.set("maxiter",maxiter); // Maximum number of iterations
59 opts.set("tol",tol); // Tolerance for convergence
60 opts.set("verbose",verb); // Verbosity (amount of output generated)
61
62 // Construct a JTree (junction tree) object from the FactorGraph fg
63 // using the parameters specified by opts and an additional property
64 // that specifies the type of updates the JTree algorithm should perform
65 JTree jt( fg, opts("updates",string("HUGIN")) );
66 // Initialize junction tree algorithm
67 jt.init();
68 // Run junction tree algorithm
69 jt.run();
70
71 // Construct another JTree (junction tree) object that is used to calculate
72 // the joint configuration of variables that has maximum probability (MAP state)
73 JTree jtmap( fg, opts("updates",string("HUGIN"))("inference",string("MAXPROD")) );
74 // Initialize junction tree algorithm
75 jtmap.init();
76 // Run junction tree algorithm
77 jtmap.run();
78 // Calculate joint state of all variables that has maximum probability
79 vector<size_t> jtmapstate = jtmap.findMaximum();
80
81 // Construct a BP (belief propagation) object from the FactorGraph fg
82 // using the parameters specified by opts and two additional properties,
83 // specifying the type of updates the BP algorithm should perform and
84 // whether they should be done in the real or in the logdomain
85 BP bp(fg, opts("updates",string("SEQRND"))("logdomain",false));
86 // Initialize belief propagation algorithm
87 bp.init();
88 // Run belief propagation algorithm
89 bp.run();
90
91 // Construct a BP (belief propagation) object from the FactorGraph fg
92 // using the parameters specified by opts and two additional properties,
93 // specifying the type of updates the BP algorithm should perform and
94 // whether they should be done in the real or in the logdomain
95 //
96 // Note that inference is set to MAXPROD, which means that the object
97 // will perform the max-product algorithm instead of the sum-product algorithm
98 BP mp(fg, opts("updates",string("SEQRND"))("logdomain",false)("inference",string("MAXPROD"))("damping",string("0.1")));
99 // Initialize max-product algorithm
100 mp.init();
101 // Run max-product algorithm
102 mp.run();
103 // Calculate joint state of all variables that has maximum probability
104 // based on the max-product result
105 vector<size_t> mpstate = mp.findMaximum();
106
107 // Construct a decimation algorithm object from the FactorGraph fg
108 // using the parameters specified by opts and three additional properties,
109 // specifying that the decimation algorithm should use the max-product
110 // algorithm and should completely reinitalize its state at every step
111 DecMAP decmap(fg, opts("reinit",true)("ianame",string("BP"))("iaopts",string("[damping=0.1,inference=MAXPROD,logdomain=0,maxiter=1000,tol=1e-9,updates=SEQRND,verbose=1]")) );
112 decmap.init();
113 decmap.run();
114 vector<size_t> decmapstate = decmap.findMaximum();
115
116 // Report variable marginals for fg, calculated by the junction tree algorithm
117 cout << "Exact variable marginals:" << endl;
118 for( size_t i = 0; i < fg.nrVars(); i++ ) // iterate over all variables in fg
119 cout << jt.belief(fg.var(i)) << endl; // display the "belief" of jt for that variable
120
121 // Report variable marginals for fg, calculated by the belief propagation algorithm
122 cout << "Approximate (loopy belief propagation) variable marginals:" << endl;
123 for( size_t i = 0; i < fg.nrVars(); i++ ) // iterate over all variables in fg
124 cout << bp.belief(fg.var(i)) << endl; // display the belief of bp for that variable
125
126 // Report factor marginals for fg, calculated by the junction tree algorithm
127 cout << "Exact factor marginals:" << endl;
128 for( size_t I = 0; I < fg.nrFactors(); I++ ) // iterate over all factors in fg
129 cout << jt.belief(fg.factor(I).vars()) << endl; // display the "belief" of jt for the variables in that factor
130
131 // Report factor marginals for fg, calculated by the belief propagation algorithm
132 cout << "Approximate (loopy belief propagation) factor marginals:" << endl;
133 for( size_t I = 0; I < fg.nrFactors(); I++ ) // iterate over all factors in fg
134 cout << bp.belief(fg.factor(I).vars()) << endl; // display the belief of bp for the variables in that factor
135
136 // Report log partition sum (normalizing constant) of fg, calculated by the junction tree algorithm
137 cout << "Exact log partition sum: " << jt.logZ() << endl;
138
139 // Report log partition sum of fg, approximated by the belief propagation algorithm
140 cout << "Approximate (loopy belief propagation) log partition sum: " << bp.logZ() << endl;
141
142 // Report exact MAP variable marginals
143 cout << "Exact MAP variable marginals:" << endl;
144 for( size_t i = 0; i < fg.nrVars(); i++ )
145 cout << jtmap.belief(fg.var(i)) << endl;
146
147 // Report max-product variable marginals
148 cout << "Approximate (max-product) MAP variable marginals:" << endl;
149 for( size_t i = 0; i < fg.nrVars(); i++ )
150 cout << mp.belief(fg.var(i)) << endl;
151
152 // Report exact MAP factor marginals
153 cout << "Exact MAP factor marginals:" << endl;
154 for( size_t I = 0; I < fg.nrFactors(); I++ )
155 cout << jtmap.belief(fg.factor(I).vars()) << "=" << jtmap.beliefF(I) << endl;
156
157 // Report max-product factor marginals
158 cout << "Approximate (max-product) MAP factor marginals:" << endl;
159 for( size_t I = 0; I < fg.nrFactors(); I++ )
160 cout << mp.belief(fg.factor(I).vars()) << "=" << mp.beliefF(I) << endl;
161
162 // Report exact MAP joint state
163 cout << "Exact MAP state (log probability = " << evalState( fg, jtmapstate ) << "):" << endl;
164 for( size_t i = 0; i < jtmapstate.size(); i++ )
165 cout << fg.var(i) << ": " << jtmapstate[i] << endl;
166
167 // Report max-product MAP joint state
168 cout << "Approximate (max-product) MAP state (log probability = " << evalState( fg, mpstate ) << "):" << endl;
169 for( size_t i = 0; i < mpstate.size(); i++ )
170 cout << fg.var(i) << ": " << mpstate[i] << endl;
171
172 // Report DecMAP joint state
173 cout << "Approximate DecMAP state (log probability = " << evalState( fg, decmapstate ) << "):" << endl;
174 for( size_t i = 0; i < decmapstate.size(); i++ )
175 cout << fg.var(i) << ": " << decmapstate[i] << endl;
176 }
177
178 return 0;
179 }