Added example_imagesegmentation, BipartiteGraph::nb1Set() and nb2Set()
[libdai.git] / src / bp_dual.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
8 */
9
10
11 #include <iostream>
12 #include <sstream>
13 #include <algorithm>
14
15 #include <dai/bp_dual.h>
16 #include <dai/util.h>
17 #include <dai/bipgraph.h>
18
19
20 namespace dai {
21
22
23 using namespace std;
24
25
26 typedef BipartiteGraph::Neighbor Neighbor;
27
28
29 void BP_dual::init() {
30 regenerateMessages();
31 regenerateBeliefs();
32 calcMessages();
33 calcBeliefs();
34 }
35
36
37 void BP_dual::regenerateMessages() {
38 size_t nv = fg().nrVars();
39 _msgs.Zn.resize(nv);
40 _msgs.Zm.resize(nv);
41 _msgs.m.resize(nv);
42 _msgs.n.resize(nv);
43 for( size_t i = 0; i < nv; i++ ) {
44 size_t nvf = fg().nbV(i).size();
45 _msgs.Zn[i].resize(nvf, 1.0);
46 _msgs.Zm[i].resize(nvf, 1.0);
47 size_t states = fg().var(i).states();
48 _msgs.n[i].resize(nvf, Prob(states));
49 _msgs.m[i].resize(nvf, Prob(states));
50 }
51 }
52
53
54 void BP_dual::regenerateBeliefs() {
55 _beliefs.b1.clear();
56 _beliefs.b1.reserve(fg().nrVars());
57 _beliefs.Zb1.resize(fg().nrVars(), 1.0);
58 _beliefs.b2.clear();
59 _beliefs.b2.reserve(fg().nrFactors());
60 _beliefs.Zb2.resize(fg().nrFactors(), 1.0);
61
62 for( size_t i = 0; i < fg().nrVars(); i++ )
63 _beliefs.b1.push_back( Prob( fg().var(i).states() ) );
64 for( size_t I = 0; I < fg().nrFactors(); I++ )
65 _beliefs.b2.push_back( Prob( fg().factor(I).nrStates() ) );
66 }
67
68
69 void BP_dual::calcMessages() {
70 // calculate 'n' messages from "factor marginal / factor"
71 for( size_t I = 0; I < fg().nrFactors(); I++ ) {
72 Factor f = _ia->beliefF(I) / fg().factor(I);
73 foreach( const Neighbor &i, fg().nbF(I) )
74 msgN(i, i.dual) = f.marginal( fg().var(i) ).p();
75 }
76 // calculate 'm' messages and normalizers from 'n' messages
77 for( size_t i = 0; i < fg().nrVars(); i++ )
78 foreach( const Neighbor &I, fg().nbV(i) )
79 calcNewM( i, I.iter );
80 // recalculate 'n' messages and normalizers from 'm' messages
81 for( size_t i = 0; i < fg().nrVars(); i++ )
82 foreach( const Neighbor &I, fg().nbV(i) )
83 calcNewN(i, I.iter);
84 }
85
86
87 void BP_dual::calcNewM( size_t i, size_t _I ) {
88 // calculate updated message I->i
89 const Neighbor &I = fg().nbV(i)[_I];
90 Prob prod( fg().factor(I).p() );
91 foreach( const Neighbor &j, fg().nbF(I) )
92 if( j != i ) { // for all j in I \ i
93 Prob &n = msgN(j,j.dual);
94 IndexFor ind( fg().var(j), fg().factor(I).vars() );
95 for( size_t x = 0; ind.valid(); x++, ++ind )
96 prod.set( x, prod[x] * n[ind] );
97 }
98 // Marginalize onto i
99 Prob marg( fg().var(i).states(), 0.0 );
100 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
101 IndexFor ind( fg().var(i), fg().factor(I).vars() );
102 for( size_t x = 0; ind.valid(); x++, ++ind )
103 marg.set( ind, marg[ind] + prod[x] );
104
105 _msgs.Zm[i][_I] = marg.normalize();
106 _msgs.m[i][_I] = marg;
107 }
108
109
110 void BP_dual::calcNewN( size_t i, size_t _I ) {
111 // calculate updated message i->I
112 const Neighbor &I = fg().nbV(i)[_I];
113 Prob prod( fg().var(i).states(), 1.0 );
114 foreach( const Neighbor &J, fg().nbV(i) )
115 if( J.node != I.node ) // for all J in i \ I
116 prod *= msgM(i,J.iter);
117 _msgs.Zn[i][_I] = prod.normalize();
118 _msgs.n[i][_I] = prod;
119 }
120
121
122 void BP_dual::calcBeliefs() {
123 for( size_t i = 0; i < fg().nrVars(); i++ )
124 calcBeliefV(i); // calculate b_i
125 for( size_t I = 0; I < fg().nrFactors(); I++ )
126 calcBeliefF(I); // calculate b_I
127 }
128
129
130 void BP_dual::calcBeliefV( size_t i ) {
131 Prob prod( fg().var(i).states(), 1.0 );
132 foreach( const Neighbor &I, fg().nbV(i) )
133 prod *= msgM(i,I.iter);
134 _beliefs.Zb1[i] = prod.normalize();
135 _beliefs.b1[i] = prod;
136 }
137
138
139 void BP_dual::calcBeliefF( size_t I ) {
140 Prob prod( fg().factor(I).p() );
141 foreach( const Neighbor &j, fg().nbF(I) ) {
142 IndexFor ind( fg().var(j), fg().factor(I).vars() );
143 Prob n( msgN(j,j.dual) );
144 for( size_t x = 0; ind.valid(); x++, ++ind )
145 prod.set( x, prod[x] * n[ind] );
146 }
147 _beliefs.Zb2[I] = prod.normalize();
148 _beliefs.b2[I] = prod;
149 }
150
151
152 } // end of namespace dai