Multiple changes: changes in build system, one workaround and one bug fix
[libdai.git] / src / trwbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <dai/dai_config.h>
10 #ifdef DAI_WITH_TRWBP
11
12
13 #include <dai/trwbp.h>
14
15
16 #define DAI_TRWBP_FAST 1
17
18
19 namespace dai {
20
21
22 using namespace std;
23
24
25 void TRWBP::setProperties( const PropertySet &opts ) {
26 BP::setProperties( opts );
27
28 if( opts.hasKey("nrtrees") )
29 nrtrees = opts.getStringAs<size_t>("nrtrees");
30 else
31 nrtrees = 0;
32 }
33
34
35 PropertySet TRWBP::getProperties() const {
36 PropertySet opts = BP::getProperties();
37 opts.set( "nrtrees", nrtrees );
38 return opts;
39 }
40
41
42 string TRWBP::printProperties() const {
43 stringstream s( stringstream::out );
44 string sbp = BP::printProperties();
45 s << sbp.substr( 0, sbp.size() - 1 );
46 s << ",";
47 s << "nrtrees=" << nrtrees << "]";
48 return s.str();
49 }
50
51
52 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
53 Real TRWBP::logZ() const {
54 Real sum = 0.0;
55 for( size_t I = 0; I < nrFactors(); I++ ) {
56 sum += (beliefF(I) * factor(I).log(true)).sum(); // TRWBP/FBP
57 sum += Weight(I) * beliefF(I).entropy(); // TRWBP/FBP
58 }
59 for( size_t i = 0; i < nrVars(); ++i ) {
60 Real c_i = 0.0;
61 bforeach( const Neighbor &I, nbV(i) )
62 c_i += Weight(I);
63 if( c_i != 1.0 )
64 sum += (1.0 - c_i) * beliefV(i).entropy(); // TRWBP/FBP
65 }
66 return sum;
67 }
68
69
70 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
71 Prob TRWBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
72 Real c_I = Weight(I); // TRWBP: c_I
73
74 Factor Fprod( factor(I) );
75 Prob &prod = Fprod.p();
76 if( props.logdomain ) {
77 prod.takeLog();
78 prod /= c_I; // TRWBP
79 } else
80 prod ^= (1.0 / c_I); // TRWBP
81
82 // Calculate product of incoming messages and factor I
83 bforeach( const Neighbor &j, nbF(I) )
84 if( !(without_i && (j == i)) ) {
85 const Var &v_j = var(j);
86 // prod_j will be the product of messages coming into j
87 // TRWBP: corresponds to messages n_jI
88 Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
89 bforeach( const Neighbor &J, nbV(j) ) {
90 Real c_J = Weight(J); // TRWBP
91 if( J != I ) { // for all J in nb(j) \ I
92 if( props.logdomain )
93 prod_j += message( j, J.iter ) * c_J;
94 else
95 prod_j *= message( j, J.iter ) ^ c_J;
96 } else if( c_J != 1.0 ) { // TRWBP: multiply by m_Ij^(c_I-1)
97 if( props.logdomain )
98 prod_j += message( j, J.iter ) * (c_J - 1.0);
99 else
100 prod_j *= message( j, J.iter ) ^ (c_J - 1.0);
101 }
102 }
103
104 // multiply prod with prod_j
105 if( !DAI_TRWBP_FAST ) {
106 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
107 if( props.logdomain )
108 Fprod += Factor( v_j, prod_j );
109 else
110 Fprod *= Factor( v_j, prod_j );
111 } else {
112 // OPTIMIZED VERSION
113 size_t _I = j.dual;
114 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
115 const ind_t &ind = index(j, _I);
116
117 for( size_t r = 0; r < prod.size(); ++r )
118 if( props.logdomain )
119 prod.set( r, prod[r] + prod_j[ind[r]] );
120 else
121 prod.set( r, prod[r] * prod_j[ind[r]] );
122 }
123 }
124
125 return prod;
126 }
127
128
129 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
130 void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
131 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
132 bforeach( const Neighbor &I, nbV(i) ) {
133 Real c_I = Weight(I);
134 if( props.logdomain )
135 p += newMessage( i, I.iter ) * c_I;
136 else
137 p *= newMessage( i, I.iter ) ^ c_I;
138 }
139 }
140
141
142 void TRWBP::construct() {
143 BP::construct();
144 _weight.resize( nrFactors(), 1.0 );
145 sampleWeights( nrtrees );
146 if( props.verbose >= 2 )
147 cerr << "Weights: " << _weight << endl;
148 }
149
150
151 void TRWBP::addTreeToWeights( const RootedTree &tree ) {
152 for( RootedTree::const_iterator e = tree.begin(); e != tree.end(); e++ ) {
153 VarSet ij( var(e->first), var(e->second) );
154 size_t I = findFactor( ij );
155 _weight[I] += 1.0;
156 }
157 }
158
159
160 void TRWBP::sampleWeights( size_t nrTrees ) {
161 if( !nrTrees )
162 return;
163
164 // initialize weights to zero
165 fill( _weight.begin(), _weight.end(), 0.0 );
166
167 // construct Markov adjacency graph, with edges weighted with
168 // random weights drawn from the uniform distribution on the interval [0,1]
169 WeightedGraph<Real> wg;
170 for( size_t i = 0; i < nrVars(); ++i ) {
171 const Var &v_i = var(i);
172 VarSet di = delta(i);
173 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
174 if( v_i < *j )
175 wg[UEdge(i,findVar(*j))] = rnd_uniform();
176 }
177
178 // now repeatedly change the random weights, find the minimal spanning tree, and add it to the weights
179 for( size_t nr = 0; nr < nrTrees; nr++ ) {
180 // find minimal spanning tree
181 RootedTree randTree = MinSpanningTree( wg, /*true*/false ); // WORKAROUND FOR BUG IN BOOST GRAPH LIBRARY VERSION 1.54
182 // add it to the weights
183 addTreeToWeights( randTree );
184 // resample weights of the graph
185 for( WeightedGraph<Real>::iterator e = wg.begin(); e != wg.end(); e++ )
186 e->second = rnd_uniform();
187 }
188
189 // normalize the weights and set the single-variable weights to 1.0
190 for( size_t I = 0; I < nrFactors(); I++ ) {
191 size_t sizeI = factor(I).vars().size();
192 if( sizeI == 1 )
193 _weight[I] = 1.0;
194 else if( sizeI == 2 )
195 _weight[I] /= nrTrees;
196 else
197 DAI_THROW(NOT_IMPLEMENTED);
198 }
199 }
200
201
202 } // end of namespace dai
203
204
205 #endif