Merge branch 'master' of git.tuebingen.mpg.de:libdai
[libdai.git] / src / mf.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <sstream>
14 #include <map>
15 #include <set>
16 #include <dai/mf.h>
17 #include <dai/util.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
26 void MF::setProperties( const PropertySet &opts ) {
27 DAI_ASSERT( opts.hasKey("tol") );
28 DAI_ASSERT( opts.hasKey("maxiter") );
29
30 props.tol = opts.getStringAs<Real>("tol");
31 props.maxiter = opts.getStringAs<size_t>("maxiter");
32 if( opts.hasKey("verbose") )
33 props.verbose = opts.getStringAs<size_t>("verbose");
34 else
35 props.verbose = 0U;
36 if( opts.hasKey("damping") )
37 props.damping = opts.getStringAs<Real>("damping");
38 else
39 props.damping = 0.0;
40 if( opts.hasKey("init") )
41 props.init = opts.getStringAs<Properties::InitType>("init");
42 else
43 props.init = Properties::InitType::UNIFORM;
44 if( opts.hasKey("updates") )
45 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
46 else
47 props.updates = Properties::UpdateType::NAIVE;
48 }
49
50
51 PropertySet MF::getProperties() const {
52 PropertySet opts;
53 opts.set( "tol", props.tol );
54 opts.set( "maxiter", props.maxiter );
55 opts.set( "verbose", props.verbose );
56 opts.set( "damping", props.damping );
57 opts.set( "init", props.init );
58 opts.set( "updates", props.updates );
59 return opts;
60 }
61
62
63 string MF::printProperties() const {
64 stringstream s( stringstream::out );
65 s << "[";
66 s << "tol=" << props.tol << ",";
67 s << "maxiter=" << props.maxiter << ",";
68 s << "verbose=" << props.verbose << ",";
69 s << "init=" << props.init << ",";
70 s << "updates=" << props.updates << ",";
71 s << "damping=" << props.damping << "]";
72 return s.str();
73 }
74
75
76 void MF::construct() {
77 // create beliefs
78 _beliefs.clear();
79 _beliefs.reserve( nrVars() );
80 for( size_t i = 0; i < nrVars(); ++i )
81 _beliefs.push_back( Factor( var(i) ) );
82 }
83
84
85 void MF::init() {
86 if( props.init == Properties::InitType::UNIFORM )
87 for( size_t i = 0; i < nrVars(); i++ )
88 _beliefs[i].fill( 1.0 );
89 else
90 for( size_t i = 0; i < nrVars(); i++ )
91 _beliefs[i].randomize();
92 }
93
94
95 Factor MF::calcNewBelief( size_t i ) {
96 Factor result;
97 foreach( const Neighbor &I, nbV(i) ) {
98 Factor belief_I_minus_i;
99 foreach( const Neighbor &j, nbF(I) ) // for all j in I \ i
100 if( j != i )
101 belief_I_minus_i *= _beliefs[j];
102 Factor f_I = factor(I);
103 if( props.updates == Properties::UpdateType::NAIVE )
104 f_I.takeLog();
105 Factor msg_I_i = (belief_I_minus_i * f_I).marginal( var(i), false );
106 if( props.updates == Properties::UpdateType::NAIVE )
107 result *= msg_I_i.exp();
108 else
109 result *= msg_I_i;
110 }
111 result.normalize();
112 return result;
113 }
114
115
116 Real MF::run() {
117 if( props.verbose >= 1 )
118 cerr << "Starting " << identify() << "...";
119
120 double tic = toc();
121
122 vector<size_t> update_seq;
123 update_seq.reserve( nrVars() );
124 for( size_t i = 0; i < nrVars(); i++ )
125 update_seq.push_back( i );
126
127 // do several passes over the network until maximum number of iterations has
128 // been reached or until the maximum belief difference is smaller than tolerance
129 Real maxDiff = INFINITY;
130 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol; _iters++ ) {
131 random_shuffle( update_seq.begin(), update_seq.end() );
132
133 maxDiff = -INFINITY;
134 foreach( const size_t &i, update_seq ) {
135 Factor nb = calcNewBelief( i );
136
137 if( nb.hasNaNs() ) {
138 cerr << name() << "::run(): ERROR: new belief of variable " << var(i) << " has NaNs!" << endl;
139 return 1.0;
140 }
141
142 if( props.damping != 0.0 )
143 nb = (nb^(1.0 - props.damping)) * (_beliefs[i]^props.damping);
144
145 maxDiff = std::max( maxDiff, dist( nb, _beliefs[i], DISTLINF ) );
146 _beliefs[i] = nb;
147 }
148
149 if( props.verbose >= 3 )
150 cerr << name() << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
151 }
152
153 if( maxDiff > _maxdiff )
154 _maxdiff = maxDiff;
155
156 if( props.verbose >= 1 ) {
157 if( maxDiff > props.tol ) {
158 if( props.verbose == 1 )
159 cerr << endl;
160 cerr << name() << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
161 } else {
162 if( props.verbose >= 3 )
163 cerr << name() << "::run: ";
164 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
165 }
166 }
167
168 return maxDiff;
169 }
170
171
172 Factor MF::beliefV( size_t i ) const {
173 return _beliefs[i].normalized();
174 }
175
176
177 Factor MF::belief (const VarSet &ns) const {
178 if( ns.size() == 0 )
179 return Factor();
180 else if( ns.size() == 1 )
181 return beliefV( findVar( *(ns.begin()) ) );
182 else {
183 DAI_THROW(BELIEF_NOT_AVAILABLE);
184 return Factor();
185 }
186 }
187
188
189 vector<Factor> MF::beliefs() const {
190 vector<Factor> result;
191 for( size_t i = 0; i < nrVars(); i++ )
192 result.push_back( beliefV(i) );
193 return result;
194 }
195
196
197 Real MF::logZ() const {
198 Real s = 0.0;
199
200 for( size_t i = 0; i < nrVars(); i++ )
201 s -= beliefV(i).entropy();
202 for( size_t I = 0; I < nrFactors(); I++ ) {
203 Factor henk;
204 foreach( const Neighbor &j, nbF(I) ) // for all j in I
205 henk *= _beliefs[j];
206 henk.normalize();
207 Factor piet;
208 piet = factor(I).log(true);
209 piet *= henk;
210 s -= piet.sum();
211 }
212
213 return -s;
214 }
215
216
217 void MF::init( const VarSet &ns ) {
218 for( size_t i = 0; i < nrVars(); i++ )
219 if( ns.contains(var(i) ) ) {
220 if( props.init == Properties::InitType::UNIFORM )
221 _beliefs[i].fill( 1.0 );
222 else
223 _beliefs[i].randomize();
224 }
225 }
226
227
228 } // end of namespace dai