153df2747184055d5145e3c1e9f9215db1e04d2a
[libdai.git] / src / lc.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <iostream>
10 #include <algorithm>
11 #include <map>
12 #include <set>
13 #include <dai/lc.h>
14 #include <dai/util.h>
15 #include <dai/alldai.h>
16
17
18 namespace dai {
19
20
21 using namespace std;
22
23
24 void LC::setProperties( const PropertySet &opts ) {
25 DAI_ASSERT( opts.hasKey("tol") );
26 DAI_ASSERT( opts.hasKey("maxiter") );
27 DAI_ASSERT( opts.hasKey("cavity") );
28 DAI_ASSERT( opts.hasKey("updates") );
29
30 props.tol = opts.getStringAs<Real>("tol");
31 props.maxiter = opts.getStringAs<size_t>("maxiter");
32 props.cavity = opts.getStringAs<Properties::CavityType>("cavity");
33 props.updates = opts.getStringAs<Properties::UpdateType>("updates");
34 if( opts.hasKey("verbose") )
35 props.verbose = opts.getStringAs<size_t>("verbose");
36 else
37 props.verbose = 0;
38 if( opts.hasKey("cavainame") )
39 props.cavainame = opts.getStringAs<string>("cavainame");
40 if( opts.hasKey("cavaiopts") )
41 props.cavaiopts = opts.getStringAs<PropertySet>("cavaiopts");
42 if( opts.hasKey("reinit") )
43 props.reinit = opts.getStringAs<bool>("reinit");
44 if( opts.hasKey("damping") )
45 props.damping = opts.getStringAs<Real>("damping");
46 else
47 props.damping = 0.0;
48 }
49
50
51 PropertySet LC::getProperties() const {
52 PropertySet opts;
53 opts.set( "tol", props.tol );
54 opts.set( "maxiter", props.maxiter );
55 opts.set( "verbose", props.verbose );
56 opts.set( "cavity", props.cavity );
57 opts.set( "updates", props.updates );
58 opts.set( "cavainame", props.cavainame );
59 opts.set( "cavaiopts", props.cavaiopts );
60 opts.set( "reinit", props.reinit );
61 opts.set( "damping", props.damping );
62 return opts;
63 }
64
65
66 string LC::printProperties() const {
67 stringstream s( stringstream::out );
68 s << "[";
69 s << "tol=" << props.tol << ",";
70 s << "maxiter=" << props.maxiter << ",";
71 s << "verbose=" << props.verbose << ",";
72 s << "cavity=" << props.cavity << ",";
73 s << "updates=" << props.updates << ",";
74 s << "cavainame=" << props.cavainame << ",";
75 s << "cavaiopts=" << props.cavaiopts << ",";
76 s << "reinit=" << props.reinit << ",";
77 s << "damping=" << props.damping << "]";
78 return s.str();
79 }
80
81
82 LC::LC( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _pancakes(), _cavitydists(), _phis(), _beliefs(), _maxdiff(0.0), _iters(0), props() {
83 setProperties( opts );
84
85 // create pancakes
86 _pancakes.resize( nrVars() );
87
88 // create cavitydists
89 for( size_t i=0; i < nrVars(); i++ )
90 _cavitydists.push_back(Factor( delta(i) ));
91
92 // create phis
93 _phis.reserve( nrVars() );
94 for( size_t i = 0; i < nrVars(); i++ ) {
95 _phis.push_back( vector<Factor>() );
96 _phis[i].reserve( nbV(i).size() );
97 bforeach( const Neighbor &I, nbV(i) )
98 _phis[i].push_back( Factor( factor(I).vars() / var(i) ) );
99 }
100
101 // create beliefs
102 _beliefs.reserve( nrVars() );
103 for( size_t i=0; i < nrVars(); i++ )
104 _beliefs.push_back(Factor(var(i)));
105 }
106
107
108 void LC::CalcBelief (size_t i) {
109 _beliefs[i] = _pancakes[i].marginal(var(i));
110 }
111
112
113 Factor LC::belief (const VarSet &ns) const {
114 if( ns.size() == 0 )
115 return Factor();
116 else if( ns.size() == 1 )
117 return beliefV( findVar( *(ns.begin()) ) );
118 else {
119 DAI_THROW(BELIEF_NOT_AVAILABLE);
120 return Factor();
121 }
122 }
123
124
125 Real LC::CalcCavityDist (size_t i, const std::string &name, const PropertySet &opts) {
126 Factor Bi;
127 Real maxdiff = 0;
128
129 if( props.verbose >= 2 )
130 cerr << "Initing cavity " << var(i) << "(" << delta(i).size() << " vars, " << delta(i).nrStates() << " states)" << endl;
131
132 if( props.cavity == Properties::CavityType::UNIFORM )
133 Bi = Factor(delta(i));
134 else {
135 InfAlg *cav = newInfAlg( name, *this, opts );
136 cav->makeCavity( i );
137
138 if( props.cavity == Properties::CavityType::FULL )
139 Bi = calcMarginal( *cav, cav->fg().delta(i), props.reinit );
140 else if( props.cavity == Properties::CavityType::PAIR ) {
141 vector<Factor> pairbeliefs = calcPairBeliefs( *cav, cav->fg().delta(i), props.reinit, false );
142 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
143 Bi *= pairbeliefs[ij];
144 } else if( props.cavity == Properties::CavityType::PAIR2 ) {
145 vector<Factor> pairbeliefs = calcPairBeliefs( *cav, cav->fg().delta(i), props.reinit, true );
146 for( size_t ij = 0; ij < pairbeliefs.size(); ij++ )
147 Bi *= pairbeliefs[ij];
148 }
149 maxdiff = cav->maxDiff();
150 delete cav;
151 }
152 Bi.normalize();
153 _cavitydists[i] = Bi;
154
155 return maxdiff;
156 }
157
158
159 Real LC::InitCavityDists( const std::string &name, const PropertySet &opts ) {
160 double tic = toc();
161
162 if( props.verbose >= 1 ) {
163 cerr << this->name() << "::InitCavityDists: ";
164 if( props.cavity == Properties::CavityType::UNIFORM )
165 cerr << "Using uniform initial cavity distributions" << endl;
166 else if( props.cavity == Properties::CavityType::FULL )
167 cerr << "Using full " << name << opts << "...";
168 else if( props.cavity == Properties::CavityType::PAIR )
169 cerr << "Using pairwise " << name << opts << "...";
170 else if( props.cavity == Properties::CavityType::PAIR2 )
171 cerr << "Using pairwise(new) " << name << opts << "...";
172 }
173
174 Real maxdiff = 0.0;
175 for( size_t i = 0; i < nrVars(); i++ ) {
176 Real md = CalcCavityDist(i, name, opts);
177 if( md > maxdiff )
178 maxdiff = md;
179 }
180
181 if( props.verbose >= 1 ) {
182 cerr << this->name() << "::InitCavityDists used " << toc() - tic << " seconds." << endl;
183 }
184
185 return maxdiff;
186 }
187
188
189 long LC::SetCavityDists( std::vector<Factor> &Q ) {
190 if( props.verbose >= 1 )
191 cerr << name() << "::SetCavityDists: Setting initial cavity distributions" << endl;
192 if( Q.size() != nrVars() )
193 return -1;
194 for( size_t i = 0; i < nrVars(); i++ ) {
195 if( _cavitydists[i].vars() != Q[i].vars() ) {
196 return i+1;
197 } else
198 _cavitydists[i] = Q[i];
199 }
200 return 0;
201 }
202
203
204 void LC::init() {
205 for( size_t i = 0; i < nrVars(); ++i )
206 bforeach( const Neighbor &I, nbV(i) )
207 if( props.updates == Properties::UpdateType::SEQRND )
208 _phis[i][I.iter].randomize();
209 else
210 _phis[i][I.iter].fill(1.0);
211 }
212
213
214 Factor LC::NewPancake (size_t i, size_t _I, bool & hasNaNs) {
215 size_t I = nbV(i)[_I];
216 Factor piet = _pancakes[i];
217
218 // recalculate _pancake[i]
219 VarSet Ivars = factor(I).vars();
220 Factor A_I;
221 for( VarSet::const_iterator k = Ivars.begin(); k != Ivars.end(); k++ )
222 if( var(i) != *k )
223 A_I *= (_pancakes[findVar(*k)] * factor(I).inverse()).marginal( Ivars / var(i), false );
224 if( Ivars.size() > 1 )
225 A_I ^= (1.0 / (Ivars.size() - 1));
226 Factor A_Ii = (_pancakes[i] * factor(I).inverse() * _phis[i][_I].inverse()).marginal( Ivars / var(i), false );
227 Factor quot = A_I / A_Ii;
228 if( props.damping != 0.0 )
229 quot = (quot^(1.0 - props.damping)) * (_phis[i][_I]^props.damping);
230
231 piet *= quot / _phis[i][_I].normalized();
232 _phis[i][_I] = quot.normalized();
233
234 piet.normalize();
235
236 if( piet.hasNaNs() ) {
237 cerr << name() << "::NewPancake(" << i << ", " << _I << "): has NaNs!" << endl;
238 hasNaNs = true;
239 }
240
241 return piet;
242 }
243
244
245 Real LC::run() {
246 if( props.verbose >= 1 )
247 cerr << "Starting " << identify() << "...";
248 if( props.verbose >= 2 )
249 cerr << endl;
250
251 double tic = toc();
252
253 Real md = InitCavityDists( props.cavainame, props.cavaiopts );
254 if( md > _maxdiff )
255 _maxdiff = md;
256
257 for( size_t i = 0; i < nrVars(); i++ ) {
258 _pancakes[i] = _cavitydists[i];
259
260 bforeach( const Neighbor &I, nbV(i) ) {
261 _pancakes[i] *= factor(I);
262 if( props.updates == Properties::UpdateType::SEQRND )
263 _pancakes[i] *= _phis[i][I.iter];
264 }
265
266 _pancakes[i].normalize();
267
268 CalcBelief(i);
269 }
270
271 vector<Factor> oldBeliefsV;
272 for( size_t i = 0; i < nrVars(); i++ )
273 oldBeliefsV.push_back( beliefV(i) );
274
275 bool hasNaNs = false;
276 for( size_t i=0; i < nrVars(); i++ )
277 if( _pancakes[i].hasNaNs() ) {
278 hasNaNs = true;
279 break;
280 }
281 if( hasNaNs ) {
282 cerr << name() << "::run: initial _pancakes has NaNs!" << endl;
283 return 1.0;
284 }
285
286 size_t nredges = nrEdges();
287 vector<Edge> update_seq;
288 update_seq.reserve( nredges );
289 for( size_t i = 0; i < nrVars(); ++i )
290 bforeach( const Neighbor &I, nbV(i) )
291 update_seq.push_back( Edge( i, I.iter ) );
292
293 // do several passes over the network until maximum number of iterations has
294 // been reached or until the maximum belief difference is smaller than tolerance
295 Real maxDiff = INFINITY;
296 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol; _iters++ ) {
297 // Sequential updates
298 if( props.updates == Properties::UpdateType::SEQRND )
299 random_shuffle( update_seq.begin(), update_seq.end(), rnd );
300
301 for( size_t t=0; t < nredges; t++ ) {
302 size_t i = update_seq[t].first;
303 size_t _I = update_seq[t].second;
304 _pancakes[i] = NewPancake( i, _I, hasNaNs);
305 if( hasNaNs )
306 return 1.0;
307 CalcBelief( i );
308 }
309
310 // compare new beliefs with old ones
311 maxDiff = -INFINITY;
312 for( size_t i = 0; i < nrVars(); i++ ) {
313 maxDiff = std::max( maxDiff, dist( beliefV(i), oldBeliefsV[i], DISTLINF ) );
314 oldBeliefsV[i] = beliefV(i);
315 }
316
317 if( props.verbose >= 3 )
318 cerr << name() << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
319 }
320
321 if( maxDiff > _maxdiff )
322 _maxdiff = maxDiff;
323
324 if( props.verbose >= 1 ) {
325 if( maxDiff > props.tol ) {
326 if( props.verbose == 1 )
327 cerr << endl;
328 cerr << name() << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
329 } else {
330 if( props.verbose >= 2 )
331 cerr << name() << "::run: ";
332 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
333 }
334 }
335
336 return maxDiff;
337 }
338
339
340 } // end of namespace dai