[Benjamin Piwowarski] Renamed "foreach" macro into "bforeach" to avoid conflicts...
[libdai.git] / src / mf.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <iostream>
10 #include <sstream>
11 #include <map>
12 #include <set>
13 #include <dai/mf.h>
14 #include <dai/util.h>
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
23 void MF::setProperties( const PropertySet &opts ) {
24 DAI_ASSERT( opts.hasKey("tol") );
25 DAI_ASSERT( opts.hasKey("maxiter") );
26
27 props.tol = opts.getStringAs<Real>("tol");
28 props.maxiter = opts.getStringAs<size_t>("maxiter");
29 if( opts.hasKey("verbose") )
30 props.verbose = opts.getStringAs<size_t>("verbose");
31 else
32 props.verbose = 0U;
33 if( opts.hasKey("damping") )
34 props.damping = opts.getStringAs<Real>("damping");
35 else
36 props.damping = 0.0;
37 if( opts.hasKey("init") )
38 props.init = opts.getStringAs<Properties::InitType>("init");
39 else
40 props.init = Properties::InitType::UNIFORM;
41 if( opts.hasKey("updates") )
42 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
43 else
44 props.updates = Properties::UpdateType::NAIVE;
45 }
46
47
48 PropertySet MF::getProperties() const {
49 PropertySet opts;
50 opts.set( "tol", props.tol );
51 opts.set( "maxiter", props.maxiter );
52 opts.set( "verbose", props.verbose );
53 opts.set( "damping", props.damping );
54 opts.set( "init", props.init );
55 opts.set( "updates", props.updates );
56 return opts;
57 }
58
59
60 string MF::printProperties() const {
61 stringstream s( stringstream::out );
62 s << "[";
63 s << "tol=" << props.tol << ",";
64 s << "maxiter=" << props.maxiter << ",";
65 s << "verbose=" << props.verbose << ",";
66 s << "init=" << props.init << ",";
67 s << "updates=" << props.updates << ",";
68 s << "damping=" << props.damping << "]";
69 return s.str();
70 }
71
72
73 void MF::construct() {
74 // create beliefs
75 _beliefs.clear();
76 _beliefs.reserve( nrVars() );
77 for( size_t i = 0; i < nrVars(); ++i )
78 _beliefs.push_back( Factor( var(i) ) );
79 }
80
81
82 void MF::init() {
83 if( props.init == Properties::InitType::UNIFORM )
84 for( size_t i = 0; i < nrVars(); i++ )
85 _beliefs[i].fill( 1.0 );
86 else
87 for( size_t i = 0; i < nrVars(); i++ )
88 _beliefs[i].randomize();
89 }
90
91
92 Factor MF::calcNewBelief( size_t i ) {
93 Factor result;
94 bforeach( const Neighbor &I, nbV(i) ) {
95 Factor belief_I_minus_i;
96 bforeach( const Neighbor &j, nbF(I) ) // for all j in I \ i
97 if( j != i )
98 belief_I_minus_i *= _beliefs[j];
99 Factor f_I = factor(I);
100 if( props.updates == Properties::UpdateType::NAIVE )
101 f_I.takeLog(true);
102 Factor msg_I_i = (belief_I_minus_i * f_I).marginal( var(i), false );
103 if( props.updates == Properties::UpdateType::NAIVE )
104 result *= msg_I_i.exp();
105 else
106 result *= msg_I_i;
107 }
108 result.normalize();
109 return result;
110 }
111
112
113 Real MF::run() {
114 if( props.verbose >= 1 )
115 cerr << "Starting " << identify() << "...";
116
117 double tic = toc();
118
119 vector<size_t> update_seq;
120 update_seq.reserve( nrVars() );
121 for( size_t i = 0; i < nrVars(); i++ )
122 update_seq.push_back( i );
123
124 // do several passes over the network until maximum number of iterations has
125 // been reached or until the maximum belief difference is smaller than tolerance
126 Real maxDiff = INFINITY;
127 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol; _iters++ ) {
128 random_shuffle( update_seq.begin(), update_seq.end(), rnd );
129
130 maxDiff = -INFINITY;
131 bforeach( const size_t &i, update_seq ) {
132 Factor nb = calcNewBelief( i );
133
134 if( nb.hasNaNs() ) {
135 cerr << name() << "::run(): ERROR: new belief of variable " << var(i) << " has NaNs!" << endl;
136 return 1.0;
137 }
138
139 if( props.damping != 0.0 )
140 nb = (nb^(1.0 - props.damping)) * (_beliefs[i]^props.damping);
141
142 maxDiff = std::max( maxDiff, dist( nb, _beliefs[i], DISTLINF ) );
143 _beliefs[i] = nb;
144 }
145
146 if( props.verbose >= 3 )
147 cerr << name() << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
148 }
149
150 if( maxDiff > _maxdiff )
151 _maxdiff = maxDiff;
152
153 if( props.verbose >= 1 ) {
154 if( maxDiff > props.tol ) {
155 if( props.verbose == 1 )
156 cerr << endl;
157 cerr << name() << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
158 } else {
159 if( props.verbose >= 3 )
160 cerr << name() << "::run: ";
161 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
162 }
163 }
164
165 return maxDiff;
166 }
167
168
169 Factor MF::beliefV( size_t i ) const {
170 return _beliefs[i].normalized();
171 }
172
173
174 Factor MF::belief (const VarSet &ns) const {
175 if( ns.size() == 0 )
176 return Factor();
177 else if( ns.size() == 1 )
178 return beliefV( findVar( *(ns.begin()) ) );
179 else {
180 DAI_THROW(BELIEF_NOT_AVAILABLE);
181 return Factor();
182 }
183 }
184
185
186 vector<Factor> MF::beliefs() const {
187 vector<Factor> result;
188 for( size_t i = 0; i < nrVars(); i++ )
189 result.push_back( beliefV(i) );
190 return result;
191 }
192
193
194 Real MF::logZ() const {
195 Real s = 0.0;
196
197 for( size_t i = 0; i < nrVars(); i++ )
198 s -= beliefV(i).entropy();
199 for( size_t I = 0; I < nrFactors(); I++ ) {
200 Factor henk;
201 bforeach( const Neighbor &j, nbF(I) ) // for all j in I
202 henk *= _beliefs[j];
203 henk.normalize();
204 Factor piet;
205 piet = factor(I).log(true);
206 piet *= henk;
207 s -= piet.sum();
208 }
209
210 return -s;
211 }
212
213
214 void MF::init( const VarSet &ns ) {
215 for( size_t i = 0; i < nrVars(); i++ )
216 if( ns.contains(var(i) ) ) {
217 if( props.init == Properties::InitType::UNIFORM )
218 _beliefs[i].fill( 1.0 );
219 else
220 _beliefs[i].randomize();
221 }
222 }
223
224
225 } // end of namespace dai