Merge branch 'master' of git@git.tuebingen.mpg.de:libdai
[libdai.git] / src / trwbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2010 Joris Mooij [joris dot mooij at libdai dot org]
8 */
9
10
11 #include <dai/trwbp.h>
12
13
// Selects how calcIncomingMessageProduct() multiplies the per-variable message
// product into the factor: nonzero uses the precalculated index (optimized
// version), zero uses plain Factor arithmetic (simple to read, but slow).
14 #define DAI_TRWBP_FAST 1
15
16
17 namespace dai {
18
19
20 using namespace std;
21
22
// Name of this inference algorithm; identify() prefixes the property string with it.
23 const char *TRWBP::Name = "TRWBP";
24
25
26 void TRWBP::setProperties( const PropertySet &opts ) {
27 BP::setProperties( opts );
28
29 if( opts.hasKey("nrtrees") )
30 nrtrees = opts.getStringAs<size_t>("nrtrees");
31 else
32 nrtrees = 0;
33 }
34
35
36 PropertySet TRWBP::getProperties() const {
37 PropertySet opts = BP::getProperties();
38 opts.set( "nrtrees", nrtrees );
39 return opts;
40 }
41
42
43 string TRWBP::printProperties() const {
44 stringstream s( stringstream::out );
45 string sbp = BP::printProperties();
46 s << sbp.substr( 0, sbp.size() - 1 );
47 s << ",";
48 s << "nrtrees=" << nrtrees << "]";
49 return s.str();
50 }
51
52
53 string TRWBP::identify() const {
54 return string(Name) + printProperties();
55 }
56
57
58 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
59 Real TRWBP::logZ() const {
60 Real sum = 0.0;
61 for( size_t I = 0; I < nrFactors(); I++ ) {
62 sum += (beliefF(I) * factor(I).log(true)).sum(); // TRWBP/FBP
63 sum += Weight(I) * beliefF(I).entropy(); // TRWBP/FBP
64 }
65 for( size_t i = 0; i < nrVars(); ++i ) {
66 Real c_i = 0.0;
67 foreach( const Neighbor &I, nbV(i) )
68 c_i += Weight(I);
69 sum += (1.0 - c_i) * beliefV(i).entropy(); // TRWBP/FBP
70 }
71 return sum;
72 }
73
74
75 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
76 Prob TRWBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
77 Real c_I = Weight(I); // TRWBP: c_I
78
79 Factor Fprod( factor(I) );
80 Prob &prod = Fprod.p();
81 if( props.logdomain ) {
82 prod.takeLog();
83 prod /= c_I; // TRWBP
84 } else
85 prod ^= (1.0 / c_I); // TRWBP
86
87 // Calculate product of incoming messages and factor I
88 foreach( const Neighbor &j, nbF(I) )
89 if( !(without_i && (j == i)) ) {
90 const Var &v_j = var(j);
91 // prod_j will be the product of messages coming into j
92 // TRWBP: corresponds to messages n_jI
93 Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
94 foreach( const Neighbor &J, nbV(j) ) {
95 Real c_J = Weight(J); // TRWBP
96 if( J != I ) { // for all J in nb(j) \ I
97 if( props.logdomain )
98 prod_j += message( j, J.iter ) * c_J;
99 else
100 prod_j *= message( j, J.iter ) ^ c_J;
101 } else { // TRWBP: multiply by m_Ij^(c_I-1)
102 if( props.logdomain )
103 prod_j += message( j, J.iter ) * (c_J - 1.0);
104 else
105 prod_j *= message( j, J.iter ) ^ (c_J - 1.0);
106 }
107 }
108
109 // multiply prod with prod_j
110 if( !DAI_TRWBP_FAST ) {
111 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
112 if( props.logdomain )
113 Fprod += Factor( v_j, prod_j );
114 else
115 Fprod *= Factor( v_j, prod_j );
116 } else {
117 // OPTIMIZED VERSION
118 size_t _I = j.dual;
119 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
120 const ind_t &ind = index(j, _I);
121
122 for( size_t r = 0; r < prod.size(); ++r )
123 if( props.logdomain )
124 prod.set( r, prod[r] + prod_j[ind[r]] );
125 else
126 prod.set( r, prod[r] * prod_j[ind[r]] );
127 }
128 }
129
130 return prod;
131 }
132
133
134 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
135 void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
136 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
137 foreach( const Neighbor &I, nbV(i) ) {
138 Real c_I = Weight(I);
139 if( props.logdomain )
140 p += newMessage( i, I.iter ) * c_I;
141 else
142 p *= newMessage( i, I.iter ) ^ c_I;
143 }
144 }
145
146
147 void TRWBP::construct() {
148 BP::construct();
149 _weight.resize( nrFactors(), 1.0 );
150 sampleWeights( nrtrees );
151 if( props.verbose >= 2 )
152 cerr << "Weights: " << _weight << endl;
153 }
154
155
156 void TRWBP::addTreeToWeights( const RootedTree &tree ) {
157 for( RootedTree::const_iterator e = tree.begin(); e != tree.end(); e++ ) {
158 VarSet ij( var(e->first), var(e->second) );
159 size_t I = findFactor( ij );
160 _weight[I] += 1.0;
161 }
162 }
163
164
165 void TRWBP::sampleWeights( size_t nrTrees ) {
166 if( !nrTrees )
167 return;
168
169 // initialize weights to zero
170 fill( _weight.begin(), _weight.end(), 0.0 );
171
172 // construct Markov adjacency graph, with edges weighted with
173 // random weights drawn from the uniform distribution on the interval [0,1]
174 WeightedGraph<Real> wg;
175 for( size_t i = 0; i < nrVars(); ++i ) {
176 const Var &v_i = var(i);
177 VarSet di = delta(i);
178 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
179 if( v_i < *j )
180 wg[UEdge(i,findVar(*j))] = rnd_uniform();
181 }
182
183 // now repeatedly change the random weights, find the minimal spanning tree, and add it to the weights
184 for( size_t nr = 0; nr < nrTrees; nr++ ) {
185 // find minimal spanning tree
186 RootedTree randTree = MinSpanningTree( wg, true );
187 // add it to the weights
188 addTreeToWeights( randTree );
189 // resample weights of the graph
190 for( WeightedGraph<Real>::iterator e = wg.begin(); e != wg.end(); e++ )
191 e->second = rnd_uniform();
192 }
193
194 // normalize the weights and set the single-variable weights to 1.0
195 for( size_t I = 0; I < nrFactors(); I++ ) {
196 size_t sizeI = factor(I).vars().size();
197 if( sizeI == 1 )
198 _weight[I] = 1.0;
199 else if( sizeI == 2 )
200 _weight[I] /= nrTrees;
201 else
202 DAI_THROW(NOT_IMPLEMENTED);
203 }
204 }
205
206
207 } // end of namespace dai