Partial adoption of contributions by Giuseppe:
[libdai.git] / src / mf.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <sstream>
24 #include <map>
25 #include <set>
26 #include <dai/mf.h>
27 #include <dai/diffs.h>
28 #include <dai/util.h>
29
30
31 namespace dai {
32
33
34 using namespace std;
35
36
37 const char *MF::Name = "MF";
38
39
40 bool MF::checkProperties() {
41 if( !HasProperty("tol") )
42 return false;
43 if (!HasProperty("maxiter") )
44 return false;
45 if (!HasProperty("verbose") )
46 return false;
47
48 ConvertPropertyTo<double>("tol");
49 ConvertPropertyTo<size_t>("maxiter");
50 ConvertPropertyTo<size_t>("verbose");
51
52 return true;
53 }
54
55
56 void MF::Regenerate() {
57 DAIAlgFG::Regenerate();
58
59 // clear beliefs
60 _beliefs.clear();
61 _beliefs.reserve( nrVars() );
62
63 // create beliefs
64 for( vector<Var>::const_iterator i = vars().begin(); i != vars().end(); i++ )
65 _beliefs.push_back(Factor(*i));
66 }
67
68
69 string MF::identify() const {
70 stringstream result (stringstream::out);
71 result << Name << GetProperties();
72 return result.str();
73 }
74
75
76 void MF::init() {
77 assert( checkProperties() );
78
79 for( vector<Factor>::iterator qi = _beliefs.begin(); qi != _beliefs.end(); qi++ )
80 qi->fill(1.0);
81 }
82
83
84 double MF::run() {
85 clock_t tic = toc();
86
87 if( Verbose() >= 1 )
88 cout << "Starting " << identify() << "...";
89
90 size_t pass_size = _beliefs.size();
91 Diffs diffs(pass_size * 3, 1.0);
92
93 size_t t=0;
94 for( t=0; t < (MaxIter()*pass_size) && diffs.max() > Tol(); t++ ) {
95 // choose random Var i
96 size_t i = (size_t) (nrVars() * rnd_uniform());
97
98 Factor jan;
99 Factor piet;
100 for( _nb_cit I = nb1(i).begin(); I != nb1(i).end(); I++ ) {
101
102 Factor henk;
103 for( _nb_cit j = nb2(*I).begin(); j != nb2(*I).end(); j++ ) // for all j in I \ i
104 if( *j != i )
105 henk *= _beliefs[*j];
106 piet = factor(*I).log0();
107 piet *= henk;
108 piet = piet.part_sum(var(i));
109 piet = piet.exp();
110 jan *= piet;
111 }
112
113 jan.normalize( _normtype );
114
115 if( jan.hasNaNs() ) {
116 cout << "MF::run(): ERROR: jan has NaNs!" << endl;
117 return NAN;
118 }
119
120 diffs.push( dist( jan, _beliefs[i], Prob::DISTLINF ) );
121
122 _beliefs[i] = jan;
123 }
124
125 updateMaxDiff( diffs.max() );
126
127 if( Verbose() >= 1 ) {
128 if( diffs.max() > Tol() ) {
129 if( Verbose() == 1 )
130 cout << endl;
131 cout << "MF::run: WARNING: not converged within " << MaxIter() << " passes (" << toc() - tic << " clocks)...final maxdiff:" << diffs.max() << endl;
132 } else {
133 if( Verbose() >= 2 )
134 cout << "MF::run: ";
135 cout << "converged in " << t / pass_size << " passes (" << toc() - tic << " clocks)." << endl;
136 }
137 }
138
139 return diffs.max();
140 }
141
142
143 Factor MF::belief1 (size_t i) const {
144 Factor piet;
145 piet = _beliefs[i];
146 piet.normalize( Prob::NORMPROB );
147 return(piet);
148 }
149
150
151 Factor MF::belief (const VarSet &ns) const {
152 if( ns.size() == 1 )
153 return belief( *(ns.begin()) );
154 else {
155 assert( ns.size() == 1 );
156 return Factor();
157 }
158 }
159
160
161 Factor MF::belief (const Var &n) const {
162 return( belief1( findVar( n) ) );
163 }
164
165
166 vector<Factor> MF::beliefs() const {
167 vector<Factor> result;
168 for( size_t i = 0; i < nrVars(); i++ )
169 result.push_back( belief1(i) );
170 return result;
171 }
172
173
174 Complex MF::logZ() const {
175 Complex sum = 0.0;
176
177 for(size_t i=0; i < nrVars(); i++ )
178 sum -= belief1(i).entropy();
179 for(size_t I=0; I < nrFactors(); I++ ) {
180 Factor henk;
181 for( _nb_cit j = nb2(I).begin(); j != nb2(I).end(); j++ ) // for all j in I
182 henk *= _beliefs[*j];
183 henk.normalize( Prob::NORMPROB );
184 Factor piet;
185 piet = factor(I).log0();
186 piet *= henk;
187 sum -= Complex( piet.totalSum() );
188 }
189
190 return -sum;
191 }
192
193
194 void MF::init( const VarSet &ns ) {
195 for( size_t i = 0; i < nrVars(); i++ ) {
196 if( ns && var(i) )
197 _beliefs[i].fill( 1.0 );
198 }
199 }
200
201
202 } // end of namespace dai