Merge branch 'pletscher'
[libdai.git] / src / lc.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <algorithm>
14 #include <map>
15 #include <set>
16 #include <dai/lc.h>
17 #include <dai/util.h>
18 #include <dai/alldai.h>
19
20
21 namespace dai {
22
23
24 using namespace std;
25
26
27 const char *LC::Name = "LC";
28
29
30 void LC::setProperties( const PropertySet &opts ) {
31 DAI_ASSERT( opts.hasKey("tol") );
32 DAI_ASSERT( opts.hasKey("maxiter") );
33 DAI_ASSERT( opts.hasKey("verbose") );
34 DAI_ASSERT( opts.hasKey("cavity") );
35 DAI_ASSERT( opts.hasKey("updates") );
36
37 props.tol = opts.getStringAs<double>("tol");
38 props.maxiter = opts.getStringAs<size_t>("maxiter");
39 props.verbose = opts.getStringAs<size_t>("verbose");
40 props.cavity = opts.getStringAs<Properties::CavityType>("cavity");
41 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
42 if( opts.hasKey("cavainame") )
43 props.cavainame = opts.getStringAs<string>("cavainame");
44 if( opts.hasKey("cavaiopts") )
45 props.cavaiopts = opts.getStringAs<PropertySet>("cavaiopts");
46 if( opts.hasKey("reinit") )
47 props.reinit = opts.getStringAs<bool>("reinit");
48 if( opts.hasKey("damping") )
49 props.damping = opts.getStringAs<double>("damping");
50 else
51 props.damping = 0.0;
52 }
53
54
55 PropertySet LC::getProperties() const {
56 PropertySet opts;
57 opts.Set( "tol", props.tol );
58 opts.Set( "maxiter", props.maxiter );
59 opts.Set( "verbose", props.verbose );
60 opts.Set( "cavity", props.cavity );
61 opts.Set( "updates", props.updates );
62 opts.Set( "cavainame", props.cavainame );
63 opts.Set( "cavaiopts", props.cavaiopts );
64 opts.Set( "reinit", props.reinit );
65 opts.Set( "damping", props.damping );
66 return opts;
67 }
68
69
70 string LC::printProperties() const {
71 stringstream s( stringstream::out );
72 s << "[";
73 s << "tol=" << props.tol << ",";
74 s << "maxiter=" << props.maxiter << ",";
75 s << "verbose=" << props.verbose << ",";
76 s << "cavity=" << props.cavity << ",";
77 s << "updates=" << props.updates << ",";
78 s << "cavainame=" << props.cavainame << ",";
79 s << "cavaiopts=" << props.cavaiopts << ",";
80 s << "reinit=" << props.reinit << ",";
81 s << "damping=" << props.damping << "]";
82 return s.str();
83 }
84
85
86 LC::LC( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _pancakes(), _cavitydists(), _phis(), _beliefs(), _maxdiff(0.0), _iters(0), props() {
87 setProperties( opts );
88
89 // create pancakes
90 _pancakes.resize( nrVars() );
91
92 // create cavitydists
93 for( size_t i=0; i < nrVars(); i++ )
94 _cavitydists.push_back(Factor( delta(i) ));
95
96 // create phis
97 _phis.reserve( nrVars() );
98 for( size_t i = 0; i < nrVars(); i++ ) {
99 _phis.push_back( vector<Factor>() );
100 _phis[i].reserve( nbV(i).size() );
101 foreach( const Neighbor &I, nbV(i) )
102 _phis[i].push_back( Factor( factor(I).vars() / var(i) ) );
103 }
104
105 // create beliefs
106 _beliefs.reserve( nrVars() );
107 for( size_t i=0; i < nrVars(); i++ )
108 _beliefs.push_back(Factor(var(i)));
109 }
110
111
112 string LC::identify() const {
113 return string(Name) + printProperties();
114 }
115
116
117 void LC::CalcBelief (size_t i) {
118 _beliefs[i] = _pancakes[i].marginal(var(i));
119 }
120
121
122 double LC::CalcCavityDist (size_t i, const std::string &name, const PropertySet &opts) {
123 Factor Bi;
124 double maxdiff = 0;
125
126 if( props.verbose >= 2 )
127 cerr << "Initing cavity " << var(i) << "(" << delta(i).size() << " vars, " << delta(i).nrStates() << " states)" << endl;
128
129 if( props.cavity == Properties::CavityType::UNIFORM )
130 Bi = Factor(delta(i));
131 else {
132 InfAlg *cav = newInfAlg( name, *this, opts );
133 cav->makeCavity( i );
134
135 if( props.cavity == Properties::CavityType::FULL )
136 Bi = calcMarginal( *cav, cav->fg().delta(i), props.reinit );
137 else if( props.cavity == Properties::CavityType::PAIR )
138 Bi = calcMarginal2ndO( *cav, cav->fg().delta(i), props.reinit );
139 else if( props.cavity == Properties::CavityType::PAIR2 ) {
140 vector<Factor> pairbeliefs = calcPairBeliefsNew( *cav, cav->fg().delta(i), props.reinit );
141 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
142 Bi *= pairbeliefs[ij];
143 }
144 maxdiff = cav->maxDiff();
145 delete cav;
146 }
147 Bi.normalize();
148 _cavitydists[i] = Bi;
149
150 return maxdiff;
151 }
152
153
154 double LC::InitCavityDists( const std::string &name, const PropertySet &opts ) {
155 double tic = toc();
156
157 if( props.verbose >= 1 ) {
158 cerr << Name << "::InitCavityDists: ";
159 if( props.cavity == Properties::CavityType::UNIFORM )
160 cerr << "Using uniform initial cavity distributions" << endl;
161 else if( props.cavity == Properties::CavityType::FULL )
162 cerr << "Using full " << name << opts << "...";
163 else if( props.cavity == Properties::CavityType::PAIR )
164 cerr << "Using pairwise " << name << opts << "...";
165 else if( props.cavity == Properties::CavityType::PAIR2 )
166 cerr << "Using pairwise(new) " << name << opts << "...";
167 }
168
169 double maxdiff = 0.0;
170 for( size_t i = 0; i < nrVars(); i++ ) {
171 double md = CalcCavityDist(i, name, opts);
172 if( md > maxdiff )
173 maxdiff = md;
174 }
175
176 if( props.verbose >= 1 ) {
177 cerr << Name << "::InitCavityDists used " << toc() - tic << " seconds." << endl;
178 }
179
180 return maxdiff;
181 }
182
183
184 long LC::SetCavityDists( std::vector<Factor> &Q ) {
185 if( props.verbose >= 1 )
186 cerr << Name << "::SetCavityDists: Setting initial cavity distributions" << endl;
187 if( Q.size() != nrVars() )
188 return -1;
189 for( size_t i = 0; i < nrVars(); i++ ) {
190 if( _cavitydists[i].vars() != Q[i].vars() ) {
191 return i+1;
192 } else
193 _cavitydists[i] = Q[i];
194 }
195 return 0;
196 }
197
198
199 void LC::init() {
200 for( size_t i = 0; i < nrVars(); ++i )
201 foreach( const Neighbor &I, nbV(i) )
202 if( props.updates == Properties::UpdateType::SEQRND )
203 _phis[i][I.iter].randomize();
204 else
205 _phis[i][I.iter].fill(1.0);
206 }
207
208
209 Factor LC::NewPancake (size_t i, size_t _I, bool & hasNaNs) {
210 size_t I = nbV(i)[_I];
211 Factor piet = _pancakes[i];
212
213 // recalculate _pancake[i]
214 VarSet Ivars = factor(I).vars();
215 Factor A_I;
216 for( VarSet::const_iterator k = Ivars.begin(); k != Ivars.end(); k++ )
217 if( var(i) != *k )
218 A_I *= (_pancakes[findVar(*k)] * factor(I).inverse()).marginal( Ivars / var(i), false );
219 if( Ivars.size() > 1 )
220 A_I ^= (1.0 / (Ivars.size() - 1));
221 Factor A_Ii = (_pancakes[i] * factor(I).inverse() * _phis[i][_I].inverse()).marginal( Ivars / var(i), false );
222 Factor quot = A_I / A_Ii;
223 if( props.damping != 0.0 )
224 quot = (quot^(1.0 - props.damping)) * (_phis[i][_I]^props.damping);
225
226 piet *= quot / _phis[i][_I].normalized();
227 _phis[i][_I] = quot.normalized();
228
229 piet.normalize();
230
231 if( piet.hasNaNs() ) {
232 cerr << Name << "::NewPancake(" << i << ", " << _I << "): has NaNs!" << endl;
233 hasNaNs = true;
234 }
235
236 return piet;
237 }
238
239
240 double LC::run() {
241 if( props.verbose >= 1 )
242 cerr << "Starting " << identify() << "...";
243 if( props.verbose >= 2 )
244 cerr << endl;
245
246 double tic = toc();
247 Diffs diffs(nrVars(), 1.0);
248
249 double md = InitCavityDists( props.cavainame, props.cavaiopts );
250 if( md > _maxdiff )
251 _maxdiff = md;
252
253 for( size_t i = 0; i < nrVars(); i++ ) {
254 _pancakes[i] = _cavitydists[i];
255
256 foreach( const Neighbor &I, nbV(i) ) {
257 _pancakes[i] *= factor(I);
258 if( props.updates == Properties::UpdateType::SEQRND )
259 _pancakes[i] *= _phis[i][I.iter];
260 }
261
262 _pancakes[i].normalize();
263
264 CalcBelief(i);
265 }
266
267 vector<Factor> old_beliefs;
268 for(size_t i=0; i < nrVars(); i++ )
269 old_beliefs.push_back(belief(i));
270
271 bool hasNaNs = false;
272 for( size_t i=0; i < nrVars(); i++ )
273 if( _pancakes[i].hasNaNs() ) {
274 hasNaNs = true;
275 break;
276 }
277 if( hasNaNs ) {
278 cerr << Name << "::run: initial _pancakes has NaNs!" << endl;
279 return 1.0;
280 }
281
282 size_t nredges = nrEdges();
283 vector<Edge> update_seq;
284 update_seq.reserve( nredges );
285 for( size_t i = 0; i < nrVars(); ++i )
286 foreach( const Neighbor &I, nbV(i) )
287 update_seq.push_back( Edge( i, I.iter ) );
288
289 // do several passes over the network until maximum number of iterations has
290 // been reached or until the maximum belief difference is smaller than tolerance
291 for( _iters=0; _iters < props.maxiter && diffs.maxDiff() > props.tol; _iters++ ) {
292 // Sequential updates
293 if( props.updates == Properties::UpdateType::SEQRND )
294 random_shuffle( update_seq.begin(), update_seq.end() );
295
296 for( size_t t=0; t < nredges; t++ ) {
297 size_t i = update_seq[t].first;
298 size_t _I = update_seq[t].second;
299 _pancakes[i] = NewPancake( i, _I, hasNaNs);
300 if( hasNaNs )
301 return 1.0;
302 CalcBelief( i );
303 }
304
305 // compare new beliefs with old ones
306 for(size_t i=0; i < nrVars(); i++ ) {
307 diffs.push( dist( belief(i), old_beliefs[i], Prob::DISTLINF ) );
308 old_beliefs[i] = belief(i);
309 }
310
311 if( props.verbose >= 3 )
312 cerr << Name << "::run: maxdiff " << diffs.maxDiff() << " after " << _iters+1 << " passes" << endl;
313 }
314
315 if( diffs.maxDiff() > _maxdiff )
316 _maxdiff = diffs.maxDiff();
317
318 if( props.verbose >= 1 ) {
319 if( diffs.maxDiff() > props.tol ) {
320 if( props.verbose == 1 )
321 cerr << endl;
322 cerr << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << diffs.maxDiff() << endl;
323 } else {
324 if( props.verbose >= 2 )
325 cerr << Name << "::run: ";
326 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
327 }
328 }
329
330 return diffs.maxDiff();
331 }
332
333
334 } // end of namespace dai