1f1faf05f2f051403b1e69b052f3bbaf9377fbff
[libdai.git] / src / bp_dual.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <iostream>
10 #include <sstream>
11 #include <algorithm>
12
13 #include <dai/bp_dual.h>
14 #include <dai/util.h>
15 #include <dai/bipgraph.h>
16
17
18 namespace dai {
19
20
21 using namespace std;
22
23
24 void BP_dual::init() {
25 regenerateMessages();
26 regenerateBeliefs();
27 calcMessages();
28 calcBeliefs();
29 }
30
31
32 void BP_dual::regenerateMessages() {
33 size_t nv = fg().nrVars();
34 _msgs.Zn.resize(nv);
35 _msgs.Zm.resize(nv);
36 _msgs.m.resize(nv);
37 _msgs.n.resize(nv);
38 for( size_t i = 0; i < nv; i++ ) {
39 size_t nvf = fg().nbV(i).size();
40 _msgs.Zn[i].resize(nvf, 1.0);
41 _msgs.Zm[i].resize(nvf, 1.0);
42 size_t states = fg().var(i).states();
43 _msgs.n[i].resize(nvf, Prob(states));
44 _msgs.m[i].resize(nvf, Prob(states));
45 }
46 }
47
48
49 void BP_dual::regenerateBeliefs() {
50 _beliefs.b1.clear();
51 _beliefs.b1.reserve(fg().nrVars());
52 _beliefs.Zb1.resize(fg().nrVars(), 1.0);
53 _beliefs.b2.clear();
54 _beliefs.b2.reserve(fg().nrFactors());
55 _beliefs.Zb2.resize(fg().nrFactors(), 1.0);
56
57 for( size_t i = 0; i < fg().nrVars(); i++ )
58 _beliefs.b1.push_back( Prob( fg().var(i).states() ) );
59 for( size_t I = 0; I < fg().nrFactors(); I++ )
60 _beliefs.b2.push_back( Prob( fg().factor(I).nrStates() ) );
61 }
62
63
64 void BP_dual::calcMessages() {
65 // calculate 'n' messages from "factor marginal / factor"
66 for( size_t I = 0; I < fg().nrFactors(); I++ ) {
67 Factor f = _ia->beliefF(I) / fg().factor(I);
68 bforeach( const Neighbor &i, fg().nbF(I) )
69 msgN(i, i.dual) = f.marginal( fg().var(i) ).p();
70 }
71 // calculate 'm' messages and normalizers from 'n' messages
72 for( size_t i = 0; i < fg().nrVars(); i++ )
73 bforeach( const Neighbor &I, fg().nbV(i) )
74 calcNewM( i, I.iter );
75 // recalculate 'n' messages and normalizers from 'm' messages
76 for( size_t i = 0; i < fg().nrVars(); i++ )
77 bforeach( const Neighbor &I, fg().nbV(i) )
78 calcNewN(i, I.iter);
79 }
80
81
82 void BP_dual::calcNewM( size_t i, size_t _I ) {
83 // calculate updated message I->i
84 const Neighbor &I = fg().nbV(i)[_I];
85 Prob prod( fg().factor(I).p() );
86 bforeach( const Neighbor &j, fg().nbF(I) )
87 if( j != i ) { // for all j in I \ i
88 Prob &n = msgN(j,j.dual);
89 IndexFor ind( fg().var(j), fg().factor(I).vars() );
90 for( size_t x = 0; ind.valid(); x++, ++ind )
91 prod.set( x, prod[x] * n[ind] );
92 }
93 // Marginalize onto i
94 Prob marg( fg().var(i).states(), 0.0 );
95 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
96 IndexFor ind( fg().var(i), fg().factor(I).vars() );
97 for( size_t x = 0; ind.valid(); x++, ++ind )
98 marg.set( ind, marg[ind] + prod[x] );
99
100 _msgs.Zm[i][_I] = marg.normalize();
101 _msgs.m[i][_I] = marg;
102 }
103
104
105 void BP_dual::calcNewN( size_t i, size_t _I ) {
106 // calculate updated message i->I
107 const Neighbor &I = fg().nbV(i)[_I];
108 Prob prod( fg().var(i).states(), 1.0 );
109 bforeach( const Neighbor &J, fg().nbV(i) )
110 if( J.node != I.node ) // for all J in i \ I
111 prod *= msgM(i,J.iter);
112 _msgs.Zn[i][_I] = prod.normalize();
113 _msgs.n[i][_I] = prod;
114 }
115
116
117 void BP_dual::calcBeliefs() {
118 for( size_t i = 0; i < fg().nrVars(); i++ )
119 calcBeliefV(i); // calculate b_i
120 for( size_t I = 0; I < fg().nrFactors(); I++ )
121 calcBeliefF(I); // calculate b_I
122 }
123
124
125 void BP_dual::calcBeliefV( size_t i ) {
126 Prob prod( fg().var(i).states(), 1.0 );
127 bforeach( const Neighbor &I, fg().nbV(i) )
128 prod *= msgM(i,I.iter);
129 _beliefs.Zb1[i] = prod.normalize();
130 _beliefs.b1[i] = prod;
131 }
132
133
134 void BP_dual::calcBeliefF( size_t I ) {
135 Prob prod( fg().factor(I).p() );
136 bforeach( const Neighbor &j, fg().nbF(I) ) {
137 IndexFor ind( fg().var(j), fg().factor(I).vars() );
138 Prob n( msgN(j,j.dual) );
139 for( size_t x = 0; ind.valid(); x++, ++ind )
140 prod.set( x, prod[x] * n[ind] );
141 }
142 _beliefs.Zb2[I] = prod.normalize();
143 _beliefs.b2[I] = prod;
144 }
145
146
147 } // end of namespace dai