ac6efd0a973ee61006b27fb5f2be10a7a5da4cac
[libdai.git] / examples / example.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
#include <exception>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include <dai/alldai.h>  // Include main libDAI header file
#include <dai/jtree.h>
#include <dai/bp.h>
#include <dai/decmap.h>
17 #include <dai/decmap.h>
18
19
20 using namespace dai;
21 using namespace std;
22
23
24 int main( int argc, char *argv[] ) {
25 if ( argc != 2 ) {
26 cout << "Usage: " << argv[0] << " <filename.fg>" << endl << endl;
27 cout << "Reads factor graph <filename.fg> and runs" << endl;
28 cout << "Belief Propagation and JunctionTree on it." << endl << endl;
29 return 1;
30 } else {
31 // Read FactorGraph from the file specified by the first command line argument
32 FactorGraph fg;
33 fg.ReadFromFile(argv[1]);
34
35 // Set some constants
36 size_t maxiter = 10000;
37 Real tol = 1e-9;
38 size_t verb = 1;
39
40 // Store the constants in a PropertySet object
41 PropertySet opts;
42 opts.set("maxiter",maxiter); // Maximum number of iterations
43 opts.set("tol",tol); // Tolerance for convergence
44 opts.set("verbose",verb); // Verbosity (amount of output generated)
45
46 // Construct a JTree (junction tree) object from the FactorGraph fg
47 // using the parameters specified by opts and an additional property
48 // that specifies the type of updates the JTree algorithm should perform
49 JTree jt( fg, opts("updates",string("HUGIN")) );
50 // Initialize junction tree algorithm
51 jt.init();
52 // Run junction tree algorithm
53 jt.run();
54
55 // Construct another JTree (junction tree) object that is used to calculate
56 // the joint configuration of variables that has maximum probability (MAP state)
57 JTree jtmap( fg, opts("updates",string("HUGIN"))("inference",string("MAXPROD")) );
58 // Initialize junction tree algorithm
59 jtmap.init();
60 // Run junction tree algorithm
61 jtmap.run();
62 // Calculate joint state of all variables that has maximum probability
63 vector<size_t> jtmapstate = jtmap.findMaximum();
64
65 // Construct a BP (belief propagation) object from the FactorGraph fg
66 // using the parameters specified by opts and two additional properties,
67 // specifying the type of updates the BP algorithm should perform and
68 // whether they should be done in the real or in the logdomain
69 BP bp(fg, opts("updates",string("SEQRND"))("logdomain",false));
70 // Initialize belief propagation algorithm
71 bp.init();
72 // Run belief propagation algorithm
73 bp.run();
74
75 // Construct a BP (belief propagation) object from the FactorGraph fg
76 // using the parameters specified by opts and two additional properties,
77 // specifying the type of updates the BP algorithm should perform and
78 // whether they should be done in the real or in the logdomain
79 //
80 // Note that inference is set to MAXPROD, which means that the object
81 // will perform the max-product algorithm instead of the sum-product algorithm
82 BP mp(fg, opts("updates",string("SEQRND"))("logdomain",false)("inference",string("MAXPROD"))("damping",string("0.1")));
83 // Initialize max-product algorithm
84 mp.init();
85 // Run max-product algorithm
86 mp.run();
87 // Calculate joint state of all variables that has maximum probability
88 // based on the max-product result
89 vector<size_t> mpstate = mp.findMaximum();
90
91 // Construct a decimation algorithm object from the FactorGraph fg
92 // using the parameters specified by opts and three additional properties,
93 // specifying that the decimation algorithm should use the max-product
94 // algorithm and should completely reinitalize its state at every step
95 DecMAP decmap(fg, opts("reinit",true)("ianame",string("BP"))("iaopts",string("[damping=0.1,inference=MAXPROD,logdomain=0,maxiter=1000,tol=1e-9,updates=SEQRND,verbose=1]")) );
96 decmap.init();
97 decmap.run();
98 vector<size_t> decmapstate = decmap.findMaximum();
99
100 // Report variable marginals for fg, calculated by the junction tree algorithm
101 cout << "Exact variable marginals:" << endl;
102 for( size_t i = 0; i < fg.nrVars(); i++ ) // iterate over all variables in fg
103 cout << jt.belief(fg.var(i)) << endl; // display the "belief" of jt for that variable
104
105 // Report variable marginals for fg, calculated by the belief propagation algorithm
106 cout << "Approximate (loopy belief propagation) variable marginals:" << endl;
107 for( size_t i = 0; i < fg.nrVars(); i++ ) // iterate over all variables in fg
108 cout << bp.belief(fg.var(i)) << endl; // display the belief of bp for that variable
109
110 // Report factor marginals for fg, calculated by the junction tree algorithm
111 cout << "Exact factor marginals:" << endl;
112 for( size_t I = 0; I < fg.nrFactors(); I++ ) // iterate over all factors in fg
113 cout << jt.belief(fg.factor(I).vars()) << endl; // display the "belief" of jt for the variables in that factor
114
115 // Report factor marginals for fg, calculated by the belief propagation algorithm
116 cout << "Approximate (loopy belief propagation) factor marginals:" << endl;
117 for( size_t I = 0; I < fg.nrFactors(); I++ ) // iterate over all factors in fg
118 cout << bp.belief(fg.factor(I).vars()) << endl; // display the belief of bp for the variables in that factor
119
120 // Report log partition sum (normalizing constant) of fg, calculated by the junction tree algorithm
121 cout << "Exact log partition sum: " << jt.logZ() << endl;
122
123 // Report log partition sum of fg, approximated by the belief propagation algorithm
124 cout << "Approximate (loopy belief propagation) log partition sum: " << bp.logZ() << endl;
125
126 // Report exact MAP variable marginals
127 cout << "Exact MAP variable marginals:" << endl;
128 for( size_t i = 0; i < fg.nrVars(); i++ )
129 cout << jtmap.belief(fg.var(i)) << endl;
130
131 // Report max-product variable marginals
132 cout << "Approximate (max-product) MAP variable marginals:" << endl;
133 for( size_t i = 0; i < fg.nrVars(); i++ )
134 cout << mp.belief(fg.var(i)) << endl;
135
136 // Report exact MAP factor marginals
137 cout << "Exact MAP factor marginals:" << endl;
138 for( size_t I = 0; I < fg.nrFactors(); I++ )
139 cout << jtmap.belief(fg.factor(I).vars()) << " == " << jtmap.beliefF(I) << endl;
140
141 // Report max-product factor marginals
142 cout << "Approximate (max-product) MAP factor marginals:" << endl;
143 for( size_t I = 0; I < fg.nrFactors(); I++ )
144 cout << mp.belief(fg.factor(I).vars()) << " == " << mp.beliefF(I) << endl;
145
146 // Report exact MAP joint state
147 cout << "Exact MAP state (log score = " << fg.logScore( jtmapstate ) << "):" << endl;
148 for( size_t i = 0; i < jtmapstate.size(); i++ )
149 cout << fg.var(i) << ": " << jtmapstate[i] << endl;
150
151 // Report max-product MAP joint state
152 cout << "Approximate (max-product) MAP state (log score = " << fg.logScore( mpstate ) << "):" << endl;
153 for( size_t i = 0; i < mpstate.size(); i++ )
154 cout << fg.var(i) << ": " << mpstate[i] << endl;
155
156 // Report DecMAP joint state
157 cout << "Approximate DecMAP state (log score = " << fg.logScore( decmapstate ) << "):" << endl;
158 for( size_t i = 0; i < decmapstate.size(); i++ )
159 cout << fg.var(i) << ": " << decmapstate[i] << endl;
160 }
161
162 return 0;
163 }