Several small changes
[libdai.git] / src / mf.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <sstream>
24 #include <map>
25 #include <set>
26 #include <dai/mf.h>
27 #include <dai/diffs.h>
28 #include <dai/util.h>
29
30
31 namespace dai {
32
33
34 using namespace std;
35
36
37 const char *MF::Name = "MF";
38
39
40 bool MF::checkProperties() {
41 if( !HasProperty("tol") )
42 return false;
43 if (!HasProperty("maxiter") )
44 return false;
45 if (!HasProperty("verbose") )
46 return false;
47
48 ConvertPropertyTo<double>("tol");
49 ConvertPropertyTo<size_t>("maxiter");
50 ConvertPropertyTo<size_t>("verbose");
51
52 return true;
53 }
54
55
56 void MF::create() {
57 // clear beliefs
58 _beliefs.clear();
59 _beliefs.reserve( nrVars() );
60
61 // create beliefs
62 for( size_t i = 0; i < nrVars(); ++i )
63 _beliefs.push_back(Factor(var(i)));
64 }
65
66
67 string MF::identify() const {
68 stringstream result (stringstream::out);
69 result << Name << GetProperties();
70 return result.str();
71 }
72
73
74 void MF::init() {
75 assert( checkProperties() );
76
77 for( vector<Factor>::iterator qi = _beliefs.begin(); qi != _beliefs.end(); qi++ )
78 qi->fill(1.0);
79 }
80
81
82 double MF::run() {
83 double tic = toc();
84
85 if( Verbose() >= 1 )
86 cout << "Starting " << identify() << "...";
87
88 size_t pass_size = _beliefs.size();
89 Diffs diffs(pass_size * 3, 1.0);
90
91 size_t t=0;
92 for( t=0; t < (MaxIter()*pass_size) && diffs.maxDiff() > Tol(); t++ ) {
93 // choose random Var i
94 size_t i = (size_t) (nrVars() * rnd_uniform());
95
96 Factor jan;
97 Factor piet;
98 foreach( const Neighbor &I, nbV(i) ) {
99 Factor henk;
100 foreach( const Neighbor &j, nbF(I) ) // for all j in I \ i
101 if( j != i )
102 henk *= _beliefs[j];
103 piet = factor(I).log0();
104 piet *= henk;
105 piet = piet.part_sum(var(i));
106 piet = piet.exp();
107 jan *= piet;
108 }
109
110 jan.normalize( Prob::NORMPROB );
111
112 if( jan.hasNaNs() ) {
113 cout << "MF::run(): ERROR: jan has NaNs!" << endl;
114 return 1.0;
115 }
116
117 diffs.push( dist( jan, _beliefs[i], Prob::DISTLINF ) );
118
119 _beliefs[i] = jan;
120 }
121
122 updateMaxDiff( diffs.maxDiff() );
123
124 if( Verbose() >= 1 ) {
125 if( diffs.maxDiff() > Tol() ) {
126 if( Verbose() == 1 )
127 cout << endl;
128 cout << "MF::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.maxDiff() << endl;
129 } else {
130 if( Verbose() >= 2 )
131 cout << "MF::run: ";
132 cout << "converged in " << t / pass_size << " passes (" << toc() - tic << " clocks)." << endl;
133 }
134 }
135
136 return diffs.maxDiff();
137 }
138
139
140 Factor MF::beliefV (size_t i) const {
141 Factor piet;
142 piet = _beliefs[i];
143 piet.normalize( Prob::NORMPROB );
144 return(piet);
145 }
146
147
148 Factor MF::belief (const VarSet &ns) const {
149 if( ns.size() == 1 )
150 return belief( *(ns.begin()) );
151 else {
152 assert( ns.size() == 1 );
153 return Factor();
154 }
155 }
156
157
158 Factor MF::belief (const Var &n) const {
159 return( beliefV( findVar( n ) ) );
160 }
161
162
163 vector<Factor> MF::beliefs() const {
164 vector<Factor> result;
165 for( size_t i = 0; i < nrVars(); i++ )
166 result.push_back( beliefV(i) );
167 return result;
168 }
169
170
171 Real MF::logZ() const {
172 Real sum = 0.0;
173
174 for(size_t i=0; i < nrVars(); i++ )
175 sum -= beliefV(i).entropy();
176 for(size_t I=0; I < nrFactors(); I++ ) {
177 Factor henk;
178 foreach( const Neighbor &j, nbF(I) ) // for all j in I
179 henk *= _beliefs[j];
180 henk.normalize( Prob::NORMPROB );
181 Factor piet;
182 piet = factor(I).log0();
183 piet *= henk;
184 sum -= piet.totalSum();
185 }
186
187 return -sum;
188 }
189
190
191 void MF::init( const VarSet &ns ) {
192 for( size_t i = 0; i < nrVars(); i++ ) {
193 if( ns.contains(var(i) ) )
194 _beliefs[i].fill( 1.0 );
195 }
196 }
197
198
199 } // end of namespace dai