Improved properties.h/cpp and added unit tests
[libdai.git] / src / mf.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <sstream>
14 #include <map>
15 #include <set>
16 #include <dai/mf.h>
17 #include <dai/util.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
26 const char *MF::Name = "MF";
27
28
29 void MF::setProperties( const PropertySet &opts ) {
30 DAI_ASSERT( opts.hasKey("tol") );
31 DAI_ASSERT( opts.hasKey("maxiter") );
32
33 props.tol = opts.getStringAs<Real>("tol");
34 props.maxiter = opts.getStringAs<size_t>("maxiter");
35 if( opts.hasKey("verbose") )
36 props.verbose = opts.getStringAs<size_t>("verbose");
37 else
38 props.verbose = 0U;
39 if( opts.hasKey("damping") )
40 props.damping = opts.getStringAs<Real>("damping");
41 else
42 props.damping = 0.0;
43 }
44
45
46 PropertySet MF::getProperties() const {
47 PropertySet opts;
48 opts.set( "tol", props.tol );
49 opts.set( "maxiter", props.maxiter );
50 opts.set( "verbose", props.verbose );
51 opts.set( "damping", props.damping );
52 return opts;
53 }
54
55
56 string MF::printProperties() const {
57 stringstream s( stringstream::out );
58 s << "[";
59 s << "tol=" << props.tol << ",";
60 s << "maxiter=" << props.maxiter << ",";
61 s << "verbose=" << props.verbose << ",";
62 s << "damping=" << props.damping << "]";
63 return s.str();
64 }
65
66
67 void MF::construct() {
68 // create beliefs
69 _beliefs.clear();
70 _beliefs.reserve( nrVars() );
71 for( size_t i = 0; i < nrVars(); ++i )
72 _beliefs.push_back( Factor( var(i) ) );
73 }
74
75
76 string MF::identify() const {
77 return string(Name) + printProperties();
78 }
79
80
81 void MF::init() {
82 for( vector<Factor>::iterator qi = _beliefs.begin(); qi != _beliefs.end(); qi++ )
83 qi->fill(1.0);
84 }
85
86
87 Factor MF::calcNewBelief( size_t i ) {
88 Factor result;
89 foreach( const Neighbor &I, nbV(i) ) {
90 Factor henk;
91 foreach( const Neighbor &j, nbF(I) ) // for all j in I \ i
92 if( j != i )
93 henk *= _beliefs[j];
94 Factor piet = factor(I).log(true);
95 piet *= henk;
96 piet = piet.marginal(var(i), false);
97 piet = piet.exp();
98 result *= piet;
99 }
100 result.normalize();
101 return result;
102 }
103
104
105 Real MF::run() {
106 if( props.verbose >= 1 )
107 cerr << "Starting " << identify() << "...";
108
109 double tic = toc();
110
111 vector<size_t> update_seq;
112 update_seq.reserve( nrVars() );
113 for( size_t i = 0; i < nrVars(); i++ )
114 update_seq.push_back( i );
115
116 // do several passes over the network until maximum number of iterations has
117 // been reached or until the maximum belief difference is smaller than tolerance
118 Real maxDiff = INFINITY;
119 for( _iters = 0; _iters < props.maxiter && maxDiff > props.tol; _iters++ ) {
120 random_shuffle( update_seq.begin(), update_seq.end() );
121
122 maxDiff = -INFINITY;
123 foreach( const size_t &i, update_seq ) {
124 Factor nb = calcNewBelief( i );
125
126 if( nb.hasNaNs() ) {
127 cerr << Name << "::run(): ERROR: new belief of variable " << var(i) << " has NaNs!" << endl;
128 return 1.0;
129 }
130
131 if( props.damping != 0.0 )
132 nb = (nb^(1.0 - props.damping)) * (_beliefs[i]^props.damping);
133
134 maxDiff = std::max( maxDiff, dist( nb, _beliefs[i], Prob::DISTLINF ) );
135 _beliefs[i] = nb;
136 }
137
138 if( props.verbose >= 3 )
139 cerr << Name << "::run: maxdiff " << maxDiff << " after " << _iters+1 << " passes" << endl;
140 }
141
142 if( maxDiff > _maxdiff )
143 _maxdiff = maxDiff;
144
145 if( props.verbose >= 1 ) {
146 if( maxDiff > props.tol ) {
147 if( props.verbose == 1 )
148 cerr << endl;
149 cerr << Name << "::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " seconds)...final maxdiff:" << maxDiff << endl;
150 } else {
151 if( props.verbose >= 3 )
152 cerr << Name << "::run: ";
153 cerr << "converged in " << _iters << " passes (" << toc() - tic << " seconds)." << endl;
154 }
155 }
156
157 return maxDiff;
158 }
159
160
161 Factor MF::beliefV( size_t i ) const {
162 return _beliefs[i].normalized();
163 }
164
165
166 Factor MF::belief (const VarSet &ns) const {
167 if( ns.size() == 0 )
168 return Factor();
169 else if( ns.size() == 1 )
170 return beliefV( findVar( *(ns.begin()) ) );
171 else {
172 DAI_THROW(BELIEF_NOT_AVAILABLE);
173 return Factor();
174 }
175 }
176
177
178 vector<Factor> MF::beliefs() const {
179 vector<Factor> result;
180 for( size_t i = 0; i < nrVars(); i++ )
181 result.push_back( beliefV(i) );
182 return result;
183 }
184
185
186 Real MF::logZ() const {
187 Real s = 0.0;
188
189 for( size_t i = 0; i < nrVars(); i++ )
190 s -= beliefV(i).entropy();
191 for( size_t I = 0; I < nrFactors(); I++ ) {
192 Factor henk;
193 foreach( const Neighbor &j, nbF(I) ) // for all j in I
194 henk *= _beliefs[j];
195 henk.normalize();
196 Factor piet;
197 piet = factor(I).log(true);
198 piet *= henk;
199 s -= piet.sum();
200 }
201
202 return -s;
203 }
204
205
206 void MF::init( const VarSet &ns ) {
207 for( size_t i = 0; i < nrVars(); i++ ) {
208 if( ns.contains(var(i) ) )
209 _beliefs[i].fill( 1.0 );
210 }
211 }
212
213
214 } // end of namespace dai