New git HEAD version
[libdai.git] / src / mf.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <dai/dai_config.h>
10 #ifdef DAI_WITH_MF
11
12
13 #include <iostream>
14 #include <sstream>
15 #include <map>
16 #include <set>
17 #include <dai/mf.h>
18 #include <dai/util.h>
19
20
21 namespace dai {
22
23
24 using namespace std;
25
26
27 void MF::setProperties( const PropertySet &opts ) {
28 DAI_ASSERT( opts.hasKey("tol") );
29 DAI_ASSERT( opts.hasKey("maxiter") );
30
31 props.tol = opts.getStringAs<Real>("tol");
32 props.maxiter = opts.getStringAs<size_t>("maxiter");
33 if( opts.hasKey("verbose") )
34 props.verbose = opts.getStringAs<size_t>("verbose");
35 else
36 props.verbose = 0U;
37 if( opts.hasKey("damping") )
38 props.damping = opts.getStringAs<Real>("damping");
39 else
40 props.damping = 0.0;
41 if( opts.hasKey("init") )
42 props.init = opts.getStringAs<Properties::InitType>("init");
43 else
44 props.init = Properties::InitType::UNIFORM;
45 if( opts.hasKey("updates") )
46 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
47 else
48 props.updates = Properties::UpdateType::NAIVE;
49 }
50
51
52 PropertySet MF::getProperties() const {
53 PropertySet opts;
54 opts.set( "tol", props.tol );
55 opts.set( "maxiter", props.maxiter );
56 opts.set( "verbose", props.verbose );
57 opts.set( "damping", props.damping );
58 opts.set( "init", props.init );
59 opts.set( "updates", props.updates );
60 return opts;
61 }
62
63
64 string MF::printProperties() const {
65 stringstream s( stringstream::out );
66 s << "[";
67 s << "tol=" << props.tol << ",";
68 s << "maxiter=" << props.maxiter << ",";
69 s << "verbose=" << props.verbose << ",";
70 s << "init=" << props.init << ",";
71 s << "updates=" << props.updates << ",";
72 s << "damping=" << props.damping << "]";
73 return s.str();
74 }
75
76
77 void MF::construct() {
78 // create beliefs
79 _beliefs.clear();
80 _beliefs.reserve( nrVars() );
81 for( size_t i = 0; i < nrVars(); ++i )
82 _beliefs.push_back( Factor( var(i) ) );
83 }
84
85
86 void MF::init() {
87 if( props.init == Properties::InitType::UNIFORM )
88 for( size_t i = 0; i < nrVars(); i++ )
89 _beliefs[i].fill( 1.0 );
90 else
91 for( size_t i = 0; i < nrVars(); i++ )
92 _beliefs[i].randomize();
93 }
94
95
96 Factor MF::calcNewBelief( size_t i ) {
97 Factor result;
98 bforeach( const Neighbor &I, nbV(i) ) {
99 Factor belief_I_minus_i;
100 bforeach( const Neighbor &j, nbF(I) ) // for all j in I \ i
101 if( j != i )
102 belief_I_minus_i *= _beliefs[j];
103 Factor f_I = factor(I);
104 if( props.updates == Properties::UpdateType::NAIVE )
105 f_I.takeLog(true);
106 Factor msg_I_i = (belief_I_minus_i * f_I).marginal( var(i), false );
107 if( props.updates == Properties::UpdateType::NAIVE )
108 result *= msg_I_i.exp();
109 else
110 result *= msg_I_i;
111 }
112 result.normalize();
113 return result;
114 }
115
116
117 Real MF::run() {
118 if( props.verbose >= 1 )
119 cerr << "Starting " << identify() << "...";
120
121 double tic = toc();
122
123 vector<size_t> update_seq;
124 update_seq.reserve( nrVars() );
125 for( size_t i = 0; i < nrVars(); i++ )
126 update_seq.push_back( i );
127
128 // do several passes over the network until maximum number of iterations has
129 // been reached or until the maximum belief difference is smaller than tolerance
130 Real maxDiff = INFINITY;
131 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol; _iters++ ) {
132 random_shuffle( update_seq.begin(), update_seq.end(), rnd );
133
134 maxDiff = -INFINITY;
135 bforeach( const size_t &i, update_seq ) {
136 Factor nb = calcNewBelief( i );
137
138 if( nb.hasNaNs() ) {
139 cerr << name() << "::run(): ERROR: new belief of variable " << var(i) << " has NaNs!" << endl;
140 return 1.0;
141 }
142
143 if( props.damping != 0.0 )
144 nb = (nb^(1.0 - props.damping)) * (_beliefs[i]^props.damping);
145
146 maxDiff = std::max( maxDiff, dist( nb, _beliefs[i], DISTLINF ) );
147 _beliefs[i] = nb;
148 }
149
150 if( props.verbose >= 3 )
151 cerr << name() << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
152 }
153
154 if( maxDiff > _maxdiff )
155 _maxdiff = maxDiff;
156
157 if( props.verbose >= 1 ) {
158 if( maxDiff > props.tol ) {
159 if( props.verbose == 1 )
160 cerr << endl;
161 cerr << name() << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
162 } else {
163 if( props.verbose >= 3 )
164 cerr << name() << "::run: ";
165 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
166 }
167 }
168
169 return maxDiff;
170 }
171
172
173 Factor MF::beliefV( size_t i ) const {
174 return _beliefs[i].normalized();
175 }
176
177
178 Factor MF::belief (const VarSet &ns) const {
179 if( ns.size() == 0 )
180 return Factor();
181 else if( ns.size() == 1 )
182 return beliefV( findVar( *(ns.begin()) ) );
183 else {
184 DAI_THROW(BELIEF_NOT_AVAILABLE);
185 return Factor();
186 }
187 }
188
189
190 vector<Factor> MF::beliefs() const {
191 vector<Factor> result;
192 for( size_t i = 0; i < nrVars(); i++ )
193 result.push_back( beliefV(i) );
194 return result;
195 }
196
197
198 Real MF::logZ() const {
199 Real s = 0.0;
200
201 for( size_t i = 0; i < nrVars(); i++ )
202 s -= beliefV(i).entropy();
203 for( size_t I = 0; I < nrFactors(); I++ ) {
204 Factor henk;
205 bforeach( const Neighbor &j, nbF(I) ) // for all j in I
206 henk *= _beliefs[j];
207 henk.normalize();
208 Factor piet;
209 piet = factor(I).log(true);
210 piet *= henk;
211 s -= piet.sum();
212 }
213
214 return -s;
215 }
216
217
218 void MF::init( const VarSet &ns ) {
219 for( size_t i = 0; i < nrVars(); i++ )
220 if( ns.contains(var(i) ) ) {
221 if( props.init == Properties::InitType::UNIFORM )
222 _beliefs[i].fill( 1.0 );
223 else
224 _beliefs[i].randomize();
225 }
226 }
227
228
229 } // end of namespace dai
230
231
232 #endif