Made members Neighbor, Neighbors and Edge of Graph, BipartiteGraph and DAG global
[libdai.git] / src / bp_dual.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
8 */
9
10
11 #include <iostream>
12 #include <sstream>
13 #include <algorithm>
14
15 #include <dai/bp_dual.h>
16 #include <dai/util.h>
17 #include <dai/bipgraph.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
26 void BP_dual::init() {
27 regenerateMessages();
28 regenerateBeliefs();
29 calcMessages();
30 calcBeliefs();
31 }
32
33
34 void BP_dual::regenerateMessages() {
35 size_t nv = fg().nrVars();
36 _msgs.Zn.resize(nv);
37 _msgs.Zm.resize(nv);
38 _msgs.m.resize(nv);
39 _msgs.n.resize(nv);
40 for( size_t i = 0; i < nv; i++ ) {
41 size_t nvf = fg().nbV(i).size();
42 _msgs.Zn[i].resize(nvf, 1.0);
43 _msgs.Zm[i].resize(nvf, 1.0);
44 size_t states = fg().var(i).states();
45 _msgs.n[i].resize(nvf, Prob(states));
46 _msgs.m[i].resize(nvf, Prob(states));
47 }
48 }
49
50
51 void BP_dual::regenerateBeliefs() {
52 _beliefs.b1.clear();
53 _beliefs.b1.reserve(fg().nrVars());
54 _beliefs.Zb1.resize(fg().nrVars(), 1.0);
55 _beliefs.b2.clear();
56 _beliefs.b2.reserve(fg().nrFactors());
57 _beliefs.Zb2.resize(fg().nrFactors(), 1.0);
58
59 for( size_t i = 0; i < fg().nrVars(); i++ )
60 _beliefs.b1.push_back( Prob( fg().var(i).states() ) );
61 for( size_t I = 0; I < fg().nrFactors(); I++ )
62 _beliefs.b2.push_back( Prob( fg().factor(I).nrStates() ) );
63 }
64
65
66 void BP_dual::calcMessages() {
67 // calculate 'n' messages from "factor marginal / factor"
68 for( size_t I = 0; I < fg().nrFactors(); I++ ) {
69 Factor f = _ia->beliefF(I) / fg().factor(I);
70 foreach( const Neighbor &i, fg().nbF(I) )
71 msgN(i, i.dual) = f.marginal( fg().var(i) ).p();
72 }
73 // calculate 'm' messages and normalizers from 'n' messages
74 for( size_t i = 0; i < fg().nrVars(); i++ )
75 foreach( const Neighbor &I, fg().nbV(i) )
76 calcNewM( i, I.iter );
77 // recalculate 'n' messages and normalizers from 'm' messages
78 for( size_t i = 0; i < fg().nrVars(); i++ )
79 foreach( const Neighbor &I, fg().nbV(i) )
80 calcNewN(i, I.iter);
81 }
82
83
84 void BP_dual::calcNewM( size_t i, size_t _I ) {
85 // calculate updated message I->i
86 const Neighbor &I = fg().nbV(i)[_I];
87 Prob prod( fg().factor(I).p() );
88 foreach( const Neighbor &j, fg().nbF(I) )
89 if( j != i ) { // for all j in I \ i
90 Prob &n = msgN(j,j.dual);
91 IndexFor ind( fg().var(j), fg().factor(I).vars() );
92 for( size_t x = 0; ind.valid(); x++, ++ind )
93 prod.set( x, prod[x] * n[ind] );
94 }
95 // Marginalize onto i
96 Prob marg( fg().var(i).states(), 0.0 );
97 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
98 IndexFor ind( fg().var(i), fg().factor(I).vars() );
99 for( size_t x = 0; ind.valid(); x++, ++ind )
100 marg.set( ind, marg[ind] + prod[x] );
101
102 _msgs.Zm[i][_I] = marg.normalize();
103 _msgs.m[i][_I] = marg;
104 }
105
106
107 void BP_dual::calcNewN( size_t i, size_t _I ) {
108 // calculate updated message i->I
109 const Neighbor &I = fg().nbV(i)[_I];
110 Prob prod( fg().var(i).states(), 1.0 );
111 foreach( const Neighbor &J, fg().nbV(i) )
112 if( J.node != I.node ) // for all J in i \ I
113 prod *= msgM(i,J.iter);
114 _msgs.Zn[i][_I] = prod.normalize();
115 _msgs.n[i][_I] = prod;
116 }
117
118
119 void BP_dual::calcBeliefs() {
120 for( size_t i = 0; i < fg().nrVars(); i++ )
121 calcBeliefV(i); // calculate b_i
122 for( size_t I = 0; I < fg().nrFactors(); I++ )
123 calcBeliefF(I); // calculate b_I
124 }
125
126
127 void BP_dual::calcBeliefV( size_t i ) {
128 Prob prod( fg().var(i).states(), 1.0 );
129 foreach( const Neighbor &I, fg().nbV(i) )
130 prod *= msgM(i,I.iter);
131 _beliefs.Zb1[i] = prod.normalize();
132 _beliefs.b1[i] = prod;
133 }
134
135
136 void BP_dual::calcBeliefF( size_t I ) {
137 Prob prod( fg().factor(I).p() );
138 foreach( const Neighbor &j, fg().nbF(I) ) {
139 IndexFor ind( fg().var(j), fg().factor(I).vars() );
140 Prob n( msgN(j,j.dual) );
141 for( size_t x = 0; ind.valid(); x++, ++ind )
142 prod.set( x, prod[x] * n[ind] );
143 }
144 _beliefs.Zb2[I] = prod.normalize();
145 _beliefs.b2[I] = prod;
146 }
147
148
149 } // end of namespace dai