582bd5770eafee4180f7bf324750c41e39b6c4cd
[libdai.git] / src / mf.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <sstream>
24 #include <map>
25 #include <set>
26 #include <dai/mf.h>
27 #include <dai/diffs.h>
28 #include <dai/util.h>
29
30
31 namespace dai {
32
33
34 using namespace std;
35
36
37 const char *MF::Name = "MF";
38
39
40 void MF::setProperties( const PropertySet &opts ) {
41 assert( opts.hasKey("tol") );
42 assert( opts.hasKey("maxiter") );
43 assert( opts.hasKey("verbose") );
44
45 props.tol = opts.getStringAs<double>("tol");
46 props.maxiter = opts.getStringAs<size_t>("maxiter");
47 props.verbose = opts.getStringAs<size_t>("verbose");
48 }
49
50
51 PropertySet MF::getProperties() const {
52 PropertySet opts;
53 opts.Set( "tol", props.tol );
54 opts.Set( "maxiter", props.maxiter );
55 opts.Set( "verbose", props.verbose );
56 return opts;
57 }
58
59
60 string MF::printProperties() const {
61 stringstream s( stringstream::out );
62 s << "[";
63 s << "tol=" << props.tol << ",";
64 s << "maxiter=" << props.maxiter << ",";
65 s << "verbose=" << props.verbose << "]";
66 return s.str();
67 }
68
69
70 void MF::create() {
71 // clear beliefs
72 _beliefs.clear();
73 _beliefs.reserve( nrVars() );
74
75 // create beliefs
76 for( size_t i = 0; i < nrVars(); ++i )
77 _beliefs.push_back(Factor(var(i)));
78 }
79
80
81 string MF::identify() const {
82 return string(Name) + printProperties();
83 }
84
85
86 void MF::init() {
87 for( vector<Factor>::iterator qi = _beliefs.begin(); qi != _beliefs.end(); qi++ )
88 qi->fill(1.0);
89 }
90
91
92 double MF::run() {
93 double tic = toc();
94
95 if( props.verbose >= 1 )
96 cout << "Starting " << identify() << "...";
97
98 size_t pass_size = _beliefs.size();
99 Diffs diffs(pass_size * 3, 1.0);
100
101 size_t t=0;
102 for( t=0; t < (props.maxiter*pass_size) && diffs.maxDiff() > props.tol; t++ ) {
103 // choose random Var i
104 size_t i = (size_t) (nrVars() * rnd_uniform());
105
106 Factor jan;
107 Factor piet;
108 foreach( const Neighbor &I, nbV(i) ) {
109 Factor henk;
110 foreach( const Neighbor &j, nbF(I) ) // for all j in I \ i
111 if( j != i )
112 henk *= _beliefs[j];
113 piet = factor(I).log0();
114 piet *= henk;
115 piet = piet.part_sum(var(i));
116 piet = piet.exp();
117 jan *= piet;
118 }
119
120 jan.normalize( Prob::NORMPROB );
121
122 if( jan.hasNaNs() ) {
123 cout << "MF::run(): ERROR: jan has NaNs!" << endl;
124 return 1.0;
125 }
126
127 diffs.push( dist( jan, _beliefs[i], Prob::DISTLINF ) );
128
129 _beliefs[i] = jan;
130 }
131
132 if( diffs.maxDiff() > maxdiff )
133 maxdiff = diffs.maxDiff();
134
135 if( props.verbose >= 1 ) {
136 if( diffs.maxDiff() > props.tol ) {
137 if( props.verbose == 1 )
138 cout << endl;
139 cout << "MF::run: WARNING: not converged within " << props.maxiter << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
140 } else {
141 if( props.verbose >= 2 )
142 cout << "MF::run: ";
143 cout << "converged in " << t / pass_size << " passes (" << toc() - tic << " clocks)." << endl;
144 }
145 }
146
147 return diffs.maxDiff();
148 }
149
150
151 Factor MF::beliefV (size_t i) const {
152 Factor piet;
153 piet = _beliefs[i];
154 piet.normalize( Prob::NORMPROB );
155 return(piet);
156 }
157
158
159 Factor MF::belief (const VarSet &ns) const {
160 if( ns.size() == 1 )
161 return belief( *(ns.begin()) );
162 else {
163 assert( ns.size() == 1 );
164 return Factor();
165 }
166 }
167
168
169 Factor MF::belief (const Var &n) const {
170 return( beliefV( findVar( n ) ) );
171 }
172
173
174 vector<Factor> MF::beliefs() const {
175 vector<Factor> result;
176 for( size_t i = 0; i < nrVars(); i++ )
177 result.push_back( beliefV(i) );
178 return result;
179 }
180
181
182 Real MF::logZ() const {
183 Real sum = 0.0;
184
185 for(size_t i=0; i < nrVars(); i++ )
186 sum -= beliefV(i).entropy();
187 for(size_t I=0; I < nrFactors(); I++ ) {
188 Factor henk;
189 foreach( const Neighbor &j, nbF(I) ) // for all j in I
190 henk *= _beliefs[j];
191 henk.normalize( Prob::NORMPROB );
192 Factor piet;
193 piet = factor(I).log0();
194 piet *= henk;
195 sum -= piet.totalSum();
196 }
197
198 return -sum;
199 }
200
201
202 void MF::init( const VarSet &ns ) {
203 for( size_t i = 0; i < nrVars(); i++ ) {
204 if( ns.contains(var(i) ) )
205 _beliefs[i].fill( 1.0 );
206 }
207 }
208
209
210 } // end of namespace dai