7270279a5e6753ffbe18e518ec7b6518b537c5c7
1 /* This file is part of libDAI - http://www.libdai.org/
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
15 #include <dai/jtree.h>
26 void MR::setProperties( const PropertySet
&opts
) {
27 DAI_ASSERT( opts
.hasKey("tol") );
28 DAI_ASSERT( opts
.hasKey("updates") );
29 DAI_ASSERT( opts
.hasKey("inits") );
31 props
.tol
= opts
.getStringAs
<Real
>("tol");
32 props
.updates
= opts
.getStringAs
<Properties::UpdateType
>("updates");
33 props
.inits
= opts
.getStringAs
<Properties::InitType
>("inits");
34 if( opts
.hasKey("verbose") )
35 props
.verbose
= opts
.getStringAs
<size_t>("verbose");
41 PropertySet
MR::getProperties() const {
43 opts
.set( "tol", props
.tol
);
44 opts
.set( "verbose", props
.verbose
);
45 opts
.set( "updates", props
.updates
);
46 opts
.set( "inits", props
.inits
);
51 string
MR::printProperties() const {
52 stringstream
s( stringstream::out
);
54 s
<< "tol=" << props
.tol
<< ",";
55 s
<< "verbose=" << props
.verbose
<< ",";
56 s
<< "updates=" << props
.updates
<< ",";
57 s
<< "inits=" << props
.inits
<< "]";
62 Real
MR::T(size_t i
, sub_nb A
) {
63 sub_nb
_nbi_min_A(G
.nb(i
).size());
68 for( size_t _j
= 0; _j
< _nbi_min_A
.size(); _j
++ )
69 if( _nbi_min_A
.test(_j
) )
70 res
+= atanh(tJ
[i
][_j
] * M
[i
][_j
]);
75 Real
MR::T(size_t i
, size_t _j
) {
76 sub_nb
j(G
.nb(i
).size());
82 Real
MR::Omega(size_t i
, size_t _j
, size_t _l
) {
83 sub_nb
jl(G
.nb(i
).size());
87 return Tijl
/ (1.0 + tJ
[i
][_l
] * M
[i
][_l
] * Tijl
);
91 Real
MR::Gamma(size_t i
, size_t _j
, size_t _l1
, size_t _l2
) {
92 sub_nb
jll(G
.nb(i
).size());
97 Real Tijll
= T(i
,jll
);
99 return (Tijll
- Tij
) / (1.0 + tJ
[i
][_l1
] * tJ
[i
][_l2
] * M
[i
][_l1
] * M
[i
][_l2
] + tJ
[i
][_l1
] * M
[i
][_l1
] * Tijll
+ tJ
[i
][_l2
] * M
[i
][_l2
] * Tijll
);
103 Real
MR::Gamma(size_t i
, size_t _l1
, size_t _l2
) {
104 sub_nb
ll(G
.nb(i
).size());
110 return (Till
- Ti
) / (1.0 + tJ
[i
][_l1
] * tJ
[i
][_l2
] * M
[i
][_l1
] * M
[i
][_l2
] + tJ
[i
][_l1
] * M
[i
][_l1
] * Till
+ tJ
[i
][_l2
] * M
[i
][_l2
] * Till
);
114 Real
MR::_tJ(size_t i
, sub_nb A
) {
115 sub_nb::size_type _j
= A
.find_first();
116 if( _j
== sub_nb::npos
)
119 return tJ
[i
][_j
] * _tJ(i
, A
.reset(_j
));
123 Real
MR::appM(size_t i
, sub_nb A
) {
124 sub_nb::size_type _j
= A
.find_first();
125 if( _j
== sub_nb::npos
)
128 sub_nb
A_j(A
); A_j
.reset(_j
);
130 Real result
= M
[i
][_j
] * appM(i
, A_j
);
131 for( size_t _k
= 0; _k
< A_j
.size(); _k
++ )
133 sub_nb
A_jk(A_j
); A_jk
.reset(_k
);
134 result
+= cors
[i
][_j
][_k
] * appM(i
,A_jk
);
142 void MR::sum_subs(size_t j
, sub_nb A
, Real
*sum_even
, Real
*sum_odd
) {
149 *sum_odd
+= _tJ(j
,B
) * appM(j
,B
);
151 *sum_even
+= _tJ(j
,B
) * appM(j
,B
);
153 // calc next subset B
155 for( ; bit
< A
.size(); bit
++ )
168 void MR::propagateCavityFields() {
169 Real sum_even
, sum_odd
;
171 size_t maxruns
= 1000;
173 for( size_t i
= 0; i
< G
.nrNodes(); i
++ )
174 bforeach( const Neighbor
&j
, G
.nb(i
) )
181 for( size_t i
= 0; i
< G
.nrNodes(); i
++ ) {
182 bforeach( const Neighbor
&j
, G
.nb(i
) ) {
184 size_t _i
= G
.findNb(j
,i
);
185 DAI_ASSERT( G
.nb(j
,_i
) == i
);
188 if( props
.updates
== Properties::UpdateType::FULL
) {
189 // find indices in nb(j) that do not correspond with i
190 sub_nb
_nbj_min_i(G
.nb(j
).size());
192 _nbj_min_i
.reset(_i
);
194 // find indices in nb(i) that do not correspond with j
195 sub_nb
_nbi_min_j(G
.nb(i
).size());
197 _nbi_min_j
.reset(_j
);
199 sum_subs(j
, _nbj_min_i
, &sum_even
, &sum_odd
);
200 newM
= (tanh(theta
[j
]) * sum_even
+ sum_odd
) / (sum_even
+ tanh(theta
[j
]) * sum_odd
);
202 sum_subs(i
, _nbi_min_j
, &sum_even
, &sum_odd
);
203 Real denom
= sum_even
+ tanh(theta
[i
]) * sum_odd
;
205 for(size_t _k
=0; _k
< G
.nb(i
).size(); _k
++) if(_k
!= _j
) {
206 sub_nb
_nbi_min_jk(_nbi_min_j
);
207 _nbi_min_jk
.reset(_k
);
208 sum_subs(i
, _nbi_min_jk
, &sum_even
, &sum_odd
);
209 numer
+= tJ
[i
][_k
] * cors
[i
][_j
][_k
] * (tanh(theta
[i
]) * sum_even
+ sum_odd
);
211 newM
-= numer
/ denom
;
212 } else if( props
.updates
== Properties::UpdateType::LINEAR
) {
214 for(size_t _l
=0; _l
<G
.nb(i
).size(); _l
++) if( _l
!= _j
)
215 newM
-= Omega(i
,_j
,_l
) * tJ
[i
][_l
] * cors
[i
][_j
][_l
];
216 for(size_t _l1
=0; _l1
<G
.nb(j
).size(); _l1
++) if( _l1
!= _i
)
217 for( size_t _l2
=_l1
+1; _l2
<G
.nb(j
).size(); _l2
++) if( _l2
!= _i
)
218 newM
+= Gamma(j
,_i
,_l1
,_l2
) * tJ
[j
][_l1
] * tJ
[j
][_l2
] * cors
[j
][_l1
][_l2
];
221 Real dev
= newM
- M
[i
][_j
];
223 if( abs(dev
) >= maxdev
)
226 newM
= M
[i
][_j
] + dev
;
227 if( abs(newM
) > 1.0 )
228 newM
= (newM
> 0.0) ? 1.0 : -1.0;
232 } while((maxdev
>props
.tol
)&&(run
<maxruns
));
235 if( maxdev
> _maxdiff
)
239 if( props
.verbose
>= 1 )
240 cerr
<< "MR::propagateCavityFields: Convergence not reached (maxdev=" << maxdev
<< ")..." << endl
;
245 void MR::calcMagnetizations() {
246 for( size_t i
= 0; i
< G
.nrNodes(); i
++ ) {
247 if( props
.updates
== Properties::UpdateType::FULL
) {
248 // find indices in nb(i)
249 sub_nb
_nbi( G
.nb(i
).size() );
252 // calc numerator1 and denominator1
253 Real sum_even
, sum_odd
;
254 sum_subs(i
, _nbi
, &sum_even
, &sum_odd
);
256 Mag
[i
] = (tanh(theta
[i
]) * sum_even
+ sum_odd
) / (sum_even
+ tanh(theta
[i
]) * sum_odd
);
258 } else if( props
.updates
== Properties::UpdateType::LINEAR
) {
259 sub_nb
empty( G
.nb(i
).size() );
262 for( size_t _l1
= 0; _l1
< G
.nb(i
).size(); _l1
++ )
263 for( size_t _l2
= _l1
+ 1; _l2
< G
.nb(i
).size(); _l2
++ )
264 Mag
[i
] += Gamma(i
,_l1
,_l2
) * tJ
[i
][_l1
] * tJ
[i
][_l2
] * cors
[i
][_l1
][_l2
];
266 if( abs( Mag
[i
] ) > 1.0 )
267 Mag
[i
] = (Mag
[i
] > 0.0) ? 1.0 : -1.0;
272 Real
MR::calcCavityCorrelations() {
274 for( size_t i
= 0; i
< nrVars(); i
++ ) {
275 vector
<Factor
> pairq
;
276 if( props
.inits
== Properties::InitType::EXACT
) {
277 JTree
jtcav(*this, PropertySet()("updates", string("HUGIN"))("verbose", (size_t)0) );
278 jtcav
.makeCavity( i
);
279 pairq
= calcPairBeliefs( jtcav
, delta(i
), false, true );
280 } else if( props
.inits
== Properties::InitType::CLAMPING
) {
281 BP
bpcav(*this, PropertySet()("updates", string("SEQMAX"))("tol", (Real
)1.0e-9)("maxiter", (size_t)10000)("verbose", (size_t)0)("logdomain", false));
282 bpcav
.makeCavity( i
);
284 pairq
= calcPairBeliefs( bpcav
, delta(i
), false, true );
285 md
= std::max( md
, bpcav
.maxDiff() );
286 } else if( props
.inits
== Properties::InitType::RESPPROP
) {
287 BP
bpcav(*this, PropertySet()("updates", string("SEQMAX"))("tol", (Real
)1.0e-9)("maxiter", (size_t)10000)("verbose", (size_t)0)("logdomain", false));
288 bpcav
.makeCavity( i
);
289 bpcav
.makeCavity( i
);
293 BBP
bbp( &bpcav
, PropertySet()("verbose",(size_t)0)("tol",(Real
)1.0e-9)("maxiter",(size_t)10000)("damping",(Real
)0.0)("updates",string("SEQ_MAX")) );
294 bforeach( const Neighbor
&j
, G
.nb(i
) ) {
295 // Create weights for magnetization of some spin
300 // BBP cost function would be the magnetization of spin j
302 b1_adj
.reserve( nrVars() );
303 for( size_t l
= 0; l
< nrVars(); l
++ )
305 b1_adj
.push_back( p
);
307 b1_adj
.push_back( Prob( 2, 0.0 ) );
308 bbp
.init_V( b1_adj
);
310 // run BBP to estimate adjoints
313 bforeach( const Neighbor
&k
, G
.nb(i
) ) {
315 cors
[i
][j
.iter
][k
.iter
] = (bbp
.adj_psi_V(k
)[1] - bbp
.adj_psi_V(k
)[0]);
317 cors
[i
][j
.iter
][k
.iter
] = 0.0;
322 if( props
.inits
!= Properties::InitType::RESPPROP
) {
323 for( size_t jk
= 0; jk
< pairq
.size(); jk
++ ) {
324 VarSet::const_iterator kit
= pairq
[jk
].vars().begin();
325 size_t j
= findVar( *(kit
) );
326 size_t k
= findVar( *(++kit
) );
327 pairq
[jk
].normalize();
328 Real cor
= (pairq
[jk
][3] - pairq
[jk
][2] - pairq
[jk
][1] + pairq
[jk
][0]) - (pairq
[jk
][3] + pairq
[jk
][2] - pairq
[jk
][1] - pairq
[jk
][0]) * (pairq
[jk
][3] - pairq
[jk
][2] + pairq
[jk
][1] - pairq
[jk
][0]);
330 size_t _j
= G
.findNb(i
,j
);
331 size_t _k
= G
.findNb(i
,k
);
332 cors
[i
][_j
][_k
] = cor
;
333 cors
[i
][_k
][_j
] = cor
;
344 if( props
.verbose
>= 1 )
345 cerr
<< "Starting " << identify() << "...";
349 // approximate correlations of cavity spins
350 Real md
= calcCavityCorrelations();
355 propagateCavityFields();
357 // calculate magnetizations
358 calcMagnetizations();
360 if( props
.verbose
>= 1 )
361 cerr
<< name() << " needed " << toc() - tic
<< " seconds." << endl
;
369 Factor
MR::beliefV( size_t i
) const {
372 x
[0] = 0.5 - Mag
[i
] / 2.0;
373 x
[1] = 0.5 + Mag
[i
] / 2.0;
375 return Factor( var(i
), x
);
381 Factor
MR::belief (const VarSet
&ns
) const {
384 else if( ns
.size() == 1 )
385 return beliefV( findVar( *(ns
.begin()) ) );
387 DAI_THROW(BELIEF_NOT_AVAILABLE
);
393 vector
<Factor
> MR::beliefs() const {
394 vector
<Factor
> result
;
395 for( size_t i
= 0; i
< nrVars(); i
++ )
396 result
.push_back( beliefV( i
) );
401 MR::MR( const FactorGraph
&fg
, const PropertySet
&opts
) : DAIAlgFG(fg
), supported(true), _maxdiff(0.0), _iters(0) {
402 setProperties( opts
);
404 size_t N
= fg
.nrVars();
406 // check whether all vars in fg are binary
407 for( size_t i
= 0; i
< N
; i
++ )
408 if( (fg
.var(i
).states() > 2) ) {
413 DAI_THROWE(NOT_IMPLEMENTED
,"MR only supports binary variables");
415 // check whether all interactions are pairwise or single
416 // and construct Markov graph
418 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
419 const Factor
&psi
= fg
.factor(I
);
420 if( psi
.vars().size() > 2 ) {
423 } else if( psi
.vars().size() == 2 ) {
424 VarSet::const_iterator jit
= psi
.vars().begin();
425 size_t i
= fg
.findVar( *(jit
) );
426 size_t j
= fg
.findVar( *(++jit
) );
427 G
.addEdge( i
, j
, false );
431 DAI_THROWE(NOT_IMPLEMENTED
,"MR does not support higher order interactions (only single and pairwise are supported)");
435 theta
.resize( N
, 0.0 );
439 for( size_t i
= 0; i
< N
; i
++ )
440 tJ
[i
].resize( G
.nb(i
).size(), 0.0 );
442 // initialize theta and tJ
443 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ ) {
444 const Factor
&psi
= fg
.factor(I
);
445 if( psi
.vars().size() == 1 ) {
446 size_t i
= fg
.findVar( *(psi
.vars().begin()) );
447 theta
[i
] += 0.5 * log(psi
[1] / psi
[0]);
448 } else if( psi
.vars().size() == 2 ) {
449 VarSet::const_iterator jit
= psi
.vars().begin();
450 size_t i
= fg
.findVar( *(jit
) );
451 size_t j
= fg
.findVar( *(++jit
) );
453 Real w_ij
= 0.25 * log(psi
[3] * psi
[0] / (psi
[2] * psi
[1]));
454 tJ
[i
][G
.findNb(i
,j
)] += w_ij
;
455 tJ
[j
][G
.findNb(j
,i
)] += w_ij
;
457 theta
[i
] += 0.25 * log(psi
[3] / psi
[2] * psi
[1] / psi
[0]);
458 theta
[j
] += 0.25 * log(psi
[3] / psi
[1] * psi
[2] / psi
[0]);
461 for( size_t i
= 0; i
< N
; i
++ )
462 bforeach( const Neighbor
&j
, G
.nb(i
) )
463 tJ
[i
][j
.iter
] = tanh( tJ
[i
][j
.iter
] );
467 for( size_t i
= 0; i
< N
; i
++ )
468 M
[i
].resize( G
.nb(i
).size() );
472 for( size_t i
= 0; i
< N
; i
++ )
473 cors
[i
].resize( G
.nb(i
).size() );
474 for( size_t i
= 0; i
< N
; i
++ )
475 for( size_t _j
= 0; _j
< cors
[i
].size(); _j
++ )
476 cors
[i
][_j
].resize( G
.nb(i
).size() );
483 } // end of namespace dai