Improved mf.h/cpp by adding "init" and "updates" options
[libdai.git] / src / mf.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <sstream>
14 #include <map>
15 #include <set>
16 #include <dai/mf.h>
17 #include <dai/util.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
26 const char *MF::Name = "MF";
27
28
29 void MF::setProperties( const PropertySet &opts ) {
30 DAI_ASSERT( opts.hasKey("tol") );
31 DAI_ASSERT( opts.hasKey("maxiter") );
32
33 props.tol = opts.getStringAs<Real>("tol");
34 props.maxiter = opts.getStringAs<size_t>("maxiter");
35 if( opts.hasKey("verbose") )
36 props.verbose = opts.getStringAs<size_t>("verbose");
37 else
38 props.verbose = 0U;
39 if( opts.hasKey("damping") )
40 props.damping = opts.getStringAs<Real>("damping");
41 else
42 props.damping = 0.0;
43 if( opts.hasKey("init") )
44 props.init = opts.getStringAs<Properties::InitType>("init");
45 else
46 props.init = Properties::InitType::UNIFORM;
47 if( opts.hasKey("updates") )
48 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
49 else
50 props.updates = Properties::UpdateType::NAIVE;
51 }
52
53
54 PropertySet MF::getProperties() const {
55 PropertySet opts;
56 opts.set( "tol", props.tol );
57 opts.set( "maxiter", props.maxiter );
58 opts.set( "verbose", props.verbose );
59 opts.set( "damping", props.damping );
60 opts.set( "init", props.init );
61 opts.set( "updates", props.updates );
62 return opts;
63 }
64
65
66 string MF::printProperties() const {
67 stringstream s( stringstream::out );
68 s << "[";
69 s << "tol=" << props.tol << ",";
70 s << "maxiter=" << props.maxiter << ",";
71 s << "verbose=" << props.verbose << ",";
72 s << "init=" << props.init << ",";
73 s << "updates=" << props.updates << ",";
74 s << "damping=" << props.damping << "]";
75 return s.str();
76 }
77
78
79 void MF::construct() {
80 // create beliefs
81 _beliefs.clear();
82 _beliefs.reserve( nrVars() );
83 for( size_t i = 0; i < nrVars(); ++i )
84 _beliefs.push_back( Factor( var(i) ) );
85 }
86
87
88 string MF::identify() const {
89 return string(Name) + printProperties();
90 }
91
92
93 void MF::init() {
94 if( props.init == Properties::InitType::UNIFORM )
95 for( size_t i = 0; i < nrVars(); i++ )
96 _beliefs[i].fill( 1.0 );
97 else
98 for( size_t i = 0; i < nrVars(); i++ )
99 _beliefs[i].randomize();
100 }
101
102
103 Factor MF::calcNewBelief( size_t i ) {
104 Factor result;
105 foreach( const Neighbor &I, nbV(i) ) {
106 Factor belief_I_minus_i;
107 foreach( const Neighbor &j, nbF(I) ) // for all j in I \ i
108 if( j != i )
109 belief_I_minus_i *= _beliefs[j];
110 Factor f_I = factor(I);
111 if( props.updates == Properties::UpdateType::NAIVE )
112 f_I.takeLog();
113 Factor msg_I_i = (belief_I_minus_i * f_I).marginal( var(i), false );
114 if( props.updates == Properties::UpdateType::NAIVE )
115 result *= msg_I_i.exp();
116 else
117 result *= msg_I_i;
118 }
119 result.normalize();
120 return result;
121 }
122
123
124 Real MF::run() {
125 if( props.verbose >= 1 )
126 cerr << "Starting " << identify() << "...";
127
128 double tic = toc();
129
130 vector<size_t> update_seq;
131 update_seq.reserve( nrVars() );
132 for( size_t i = 0; i < nrVars(); i++ )
133 update_seq.push_back( i );
134
135 // do several passes over the network until maximum number of iterations has
136 // been reached or until the maximum belief difference is smaller than tolerance
137 Real maxDiff = INFINITY;
138 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol; _iters++ ) {
139 random_shuffle( update_seq.begin(), update_seq.end() );
140
141 maxDiff = -INFINITY;
142 foreach( const size_t &i, update_seq ) {
143 Factor nb = calcNewBelief( i );
144
145 if( nb.hasNaNs() ) {
146 cerr << Name << "::run(): ERROR: new belief of variable " << var(i) << " has NaNs!" << endl;
147 return 1.0;
148 }
149
150 if( props.damping != 0.0 )
151 nb = (nb^(1.0 - props.damping)) * (_beliefs[i]^props.damping);
152
153 maxDiff = std::max( maxDiff, dist( nb, _beliefs[i], DISTLINF ) );
154 _beliefs[i] = nb;
155 }
156
157 if( props.verbose >= 3 )
158 cerr << Name << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
159 }
160
161 if( maxDiff > _maxdiff )
162 _maxdiff = maxDiff;
163
164 if( props.verbose >= 1 ) {
165 if( maxDiff > props.tol ) {
166 if( props.verbose == 1 )
167 cerr << endl;
168 cerr << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
169 } else {
170 if( props.verbose >= 3 )
171 cerr << Name << "::run: ";
172 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
173 }
174 }
175
176 return maxDiff;
177 }
178
179
180 Factor MF::beliefV( size_t i ) const {
181 return _beliefs[i].normalized();
182 }
183
184
185 Factor MF::belief (const VarSet &ns) const {
186 if( ns.size() == 0 )
187 return Factor();
188 else if( ns.size() == 1 )
189 return beliefV( findVar( *(ns.begin()) ) );
190 else {
191 DAI_THROW(BELIEF_NOT_AVAILABLE);
192 return Factor();
193 }
194 }
195
196
197 vector<Factor> MF::beliefs() const {
198 vector<Factor> result;
199 for( size_t i = 0; i < nrVars(); i++ )
200 result.push_back( beliefV(i) );
201 return result;
202 }
203
204
205 Real MF::logZ() const {
206 Real s = 0.0;
207
208 for( size_t i = 0; i < nrVars(); i++ )
209 s -= beliefV(i).entropy();
210 for( size_t I = 0; I < nrFactors(); I++ ) {
211 Factor henk;
212 foreach( const Neighbor &j, nbF(I) ) // for all j in I
213 henk *= _beliefs[j];
214 henk.normalize();
215 Factor piet;
216 piet = factor(I).log(true);
217 piet *= henk;
218 s -= piet.sum();
219 }
220
221 return -s;
222 }
223
224
225 void MF::init( const VarSet &ns ) {
226 for( size_t i = 0; i < nrVars(); i++ )
227 if( ns.contains(var(i) ) ) {
228 if( props.init == Properties::InitType::UNIFORM )
229 _beliefs[i].fill( 1.0 );
230 else
231 _beliefs[i].randomize();
232 }
233 }
234
235
236 } // end of namespace dai