[Patrick Pletscher] Fixed performance issue in FactorGraph::clamp and FactorGraph...
[libdai.git] / src / bp_dual.cpp
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
21 #include <iostream>
22 #include <sstream>
23 #include <algorithm>
24
25 #include <dai/bp_dual.h>
26 #include <dai/util.h>
27 #include <dai/bipgraph.h>
28
29
30 namespace dai {
31
32
33 using namespace std;
34
35
36 typedef BipartiteGraph::Neighbor Neighbor;
37
38
39 void BP_dual::init() {
40 regenerateMessages();
41 regenerateBeliefs();
42 calcMessages();
43 calcBeliefs();
44 }
45
46
47 void BP_dual::regenerateMessages() {
48 size_t nv = fg().nrVars();
49 _msgs.Zn.resize(nv);
50 _msgs.Zm.resize(nv);
51 _msgs.m.resize(nv);
52 _msgs.n.resize(nv);
53 for( size_t i = 0; i < nv; i++ ) {
54 size_t nvf = fg().nbV(i).size();
55 _msgs.Zn[i].resize(nvf, 1.0);
56 _msgs.Zm[i].resize(nvf, 1.0);
57 size_t states = fg().var(i).states();
58 _msgs.n[i].resize(nvf, Prob(states));
59 _msgs.m[i].resize(nvf, Prob(states));
60 }
61 }
62
63
64 void BP_dual::regenerateBeliefs() {
65 _beliefs.b1.clear();
66 _beliefs.b1.reserve(fg().nrVars());
67 _beliefs.Zb1.resize(fg().nrVars(), 1.0);
68 _beliefs.b2.clear();
69 _beliefs.b2.reserve(fg().nrFactors());
70 _beliefs.Zb2.resize(fg().nrFactors(), 1.0);
71
72 for( size_t i = 0; i < fg().nrVars(); i++ )
73 _beliefs.b1.push_back( Prob( fg().var(i).states() ) );
74 for( size_t I = 0; I < fg().nrFactors(); I++ )
75 _beliefs.b2.push_back( Prob( fg().factor(I).states() ) );
76 }
77
78
79 void BP_dual::calcMessages() {
80 // calculate 'n' messages from "factor marginal / factor"
81 for( size_t I = 0; I < fg().nrFactors(); I++ ) {
82 Factor f = _ia->beliefF(I) / fg().factor(I);
83 foreach( const Neighbor &i, fg().nbF(I) )
84 msgN(i, i.dual) = f.marginal( fg().var(i) ).p();
85 }
86 // calculate 'm' messages and normalizers from 'n' messages
87 for( size_t i = 0; i < fg().nrVars(); i++ )
88 foreach( const Neighbor &I, fg().nbV(i) )
89 calcNewM( i, I.iter );
90 // recalculate 'n' messages and normalizers from 'm' messages
91 for( size_t i = 0; i < fg().nrVars(); i++ )
92 foreach( const Neighbor &I, fg().nbV(i) )
93 calcNewN(i, I.iter);
94 }
95
96
97 void BP_dual::calcNewM( size_t i, size_t _I ) {
98 // calculate updated message I->i
99 const Neighbor &I = fg().nbV(i)[_I];
100 Prob prod( fg().factor(I).p() );
101 foreach( const Neighbor &j, fg().nbF(I) )
102 if( j != i ) { // for all j in I \ i
103 Prob &n = msgN(j,j.dual);
104 IndexFor ind( fg().var(j), fg().factor(I).vars() );
105 for( size_t x = 0; ind >= 0; x++, ++ind )
106 prod[x] *= n[ind];
107 }
108 // Marginalize onto i
109 Prob marg( fg().var(i).states(), 0.0 );
110 // ind is the precalculated Index(i,I) i.e. to x_I == k corresponds x_i == ind[k]
111 IndexFor ind( fg().var(i), fg().factor(I).vars() );
112 for( size_t x = 0; ind >= 0; x++, ++ind )
113 marg[ind] += prod[x];
114
115 _msgs.Zm[i][_I] = marg.normalize();
116 _msgs.m[i][_I] = marg;
117 }
118
119
120 void BP_dual::calcNewN( size_t i, size_t _I ) {
121 // calculate updated message i->I
122 const Neighbor &I = fg().nbV(i)[_I];
123 Prob prod( fg().var(i).states(), 1.0 );
124 foreach( const Neighbor &J, fg().nbV(i) )
125 if( J.node != I.node ) // for all J in i \ I
126 prod *= msgM(i,J.iter);
127 _msgs.Zn[i][_I] = prod.normalize();
128 _msgs.n[i][_I] = prod;
129 }
130
131
132 void BP_dual::calcBeliefs() {
133 for( size_t i = 0; i < fg().nrVars(); i++ )
134 calcBeliefV(i); // calculate b_i
135 for( size_t I = 0; I < fg().nrFactors(); I++ )
136 calcBeliefF(I); // calculate b_I
137 }
138
139
140 void BP_dual::calcBeliefV( size_t i ) {
141 Prob prod( fg().var(i).states(), 1.0 );
142 foreach( const Neighbor &I, fg().nbV(i) )
143 prod *= msgM(i,I.iter);
144 _beliefs.Zb1[i] = prod.normalize();
145 _beliefs.b1[i] = prod;
146 }
147
148
149 void BP_dual::calcBeliefF( size_t I ) {
150 Prob prod( fg().factor(I).p() );
151 foreach( const Neighbor &j, fg().nbF(I) ) {
152 IndexFor ind( fg().var(j), fg().factor(I).vars() );
153 Prob n( msgN(j,j.dual) );
154 for( size_t x = 0; ind >= 0; x++, ++ind )
155 prod[x] *= n[ind];
156 }
157 _beliefs.Zb2[I] = prod.normalize();
158 _beliefs.b2[I] = prod;
159 }
160
161
162 } // end of namespace dai