e2775271b51be439e29fd551b8d2e7d50f61b7eb
[libdai.git] / src / trwbp.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <dai/trwbp.h>
10
11
12 #define DAI_TRWBP_FAST 1
13
14
15 namespace dai {
16
17
18 using namespace std;
19
20
21 void TRWBP::setProperties( const PropertySet &opts ) {
22 BP::setProperties( opts );
23
24 if( opts.hasKey("nrtrees") )
25 nrtrees = opts.getStringAs<size_t>("nrtrees");
26 else
27 nrtrees = 0;
28 }
29
30
31 PropertySet TRWBP::getProperties() const {
32 PropertySet opts = BP::getProperties();
33 opts.set( "nrtrees", nrtrees );
34 return opts;
35 }
36
37
38 string TRWBP::printProperties() const {
39 stringstream s( stringstream::out );
40 string sbp = BP::printProperties();
41 s << sbp.substr( 0, sbp.size() - 1 );
42 s << ",";
43 s << "nrtrees=" << nrtrees << "]";
44 return s.str();
45 }
46
47
48 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
49 Real TRWBP::logZ() const {
50 Real sum = 0.0;
51 for( size_t I = 0; I < nrFactors(); I++ ) {
52 sum += (beliefF(I) * factor(I).log(true)).sum(); // TRWBP/FBP
53 sum += Weight(I) * beliefF(I).entropy(); // TRWBP/FBP
54 }
55 for( size_t i = 0; i < nrVars(); ++i ) {
56 Real c_i = 0.0;
57 bforeach( const Neighbor &I, nbV(i) )
58 c_i += Weight(I);
59 if( c_i != 1.0 )
60 sum += (1.0 - c_i) * beliefV(i).entropy(); // TRWBP/FBP
61 }
62 return sum;
63 }
64
65
66 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
67 Prob TRWBP::calcIncomingMessageProduct( size_t I, bool without_i, size_t i ) const {
68 Real c_I = Weight(I); // TRWBP: c_I
69
70 Factor Fprod( factor(I) );
71 Prob &prod = Fprod.p();
72 if( props.logdomain ) {
73 prod.takeLog();
74 prod /= c_I; // TRWBP
75 } else
76 prod ^= (1.0 / c_I); // TRWBP
77
78 // Calculate product of incoming messages and factor I
79 bforeach( const Neighbor &j, nbF(I) )
80 if( !(without_i && (j == i)) ) {
81 const Var &v_j = var(j);
82 // prod_j will be the product of messages coming into j
83 // TRWBP: corresponds to messages n_jI
84 Prob prod_j( v_j.states(), props.logdomain ? 0.0 : 1.0 );
85 bforeach( const Neighbor &J, nbV(j) ) {
86 Real c_J = Weight(J); // TRWBP
87 if( J != I ) { // for all J in nb(j) \ I
88 if( props.logdomain )
89 prod_j += message( j, J.iter ) * c_J;
90 else
91 prod_j *= message( j, J.iter ) ^ c_J;
92 } else if( c_J != 1.0 ) { // TRWBP: multiply by m_Ij^(c_I-1)
93 if( props.logdomain )
94 prod_j += message( j, J.iter ) * (c_J - 1.0);
95 else
96 prod_j *= message( j, J.iter ) ^ (c_J - 1.0);
97 }
98 }
99
100 // multiply prod with prod_j
101 if( !DAI_TRWBP_FAST ) {
102 // UNOPTIMIZED (SIMPLE TO READ, BUT SLOW) VERSION
103 if( props.logdomain )
104 Fprod += Factor( v_j, prod_j );
105 else
106 Fprod *= Factor( v_j, prod_j );
107 } else {
108 // OPTIMIZED VERSION
109 size_t _I = j.dual;
110 // ind is the precalculated IndexFor(j,I) i.e. to x_I == k corresponds x_j == ind[k]
111 const ind_t &ind = index(j, _I);
112
113 for( size_t r = 0; r < prod.size(); ++r )
114 if( props.logdomain )
115 prod.set( r, prod[r] + prod_j[ind[r]] );
116 else
117 prod.set( r, prod[r] * prod_j[ind[r]] );
118 }
119 }
120
121 return prod;
122 }
123
124
125 // This code has been copied from bp.cpp, except where comments indicate TRWBP-specific behaviour
126 void TRWBP::calcBeliefV( size_t i, Prob &p ) const {
127 p = Prob( var(i).states(), props.logdomain ? 0.0 : 1.0 );
128 bforeach( const Neighbor &I, nbV(i) ) {
129 Real c_I = Weight(I);
130 if( props.logdomain )
131 p += newMessage( i, I.iter ) * c_I;
132 else
133 p *= newMessage( i, I.iter ) ^ c_I;
134 }
135 }
136
137
138 void TRWBP::construct() {
139 BP::construct();
140 _weight.resize( nrFactors(), 1.0 );
141 sampleWeights( nrtrees );
142 if( props.verbose >= 2 )
143 cerr << "Weights: " << _weight << endl;
144 }
145
146
147 void TRWBP::addTreeToWeights( const RootedTree &tree ) {
148 for( RootedTree::const_iterator e = tree.begin(); e != tree.end(); e++ ) {
149 VarSet ij( var(e->first), var(e->second) );
150 size_t I = findFactor( ij );
151 _weight[I] += 1.0;
152 }
153 }
154
155
156 void TRWBP::sampleWeights( size_t nrTrees ) {
157 if( !nrTrees )
158 return;
159
160 // initialize weights to zero
161 fill( _weight.begin(), _weight.end(), 0.0 );
162
163 // construct Markov adjacency graph, with edges weighted with
164 // random weights drawn from the uniform distribution on the interval [0,1]
165 WeightedGraph<Real> wg;
166 for( size_t i = 0; i < nrVars(); ++i ) {
167 const Var &v_i = var(i);
168 VarSet di = delta(i);
169 for( VarSet::const_iterator j = di.begin(); j != di.end(); j++ )
170 if( v_i < *j )
171 wg[UEdge(i,findVar(*j))] = rnd_uniform();
172 }
173
174 // now repeatedly change the random weights, find the minimal spanning tree, and add it to the weights
175 for( size_t nr = 0; nr < nrTrees; nr++ ) {
176 // find minimal spanning tree
177 RootedTree randTree = MinSpanningTree( wg, true );
178 // add it to the weights
179 addTreeToWeights( randTree );
180 // resample weights of the graph
181 for( WeightedGraph<Real>::iterator e = wg.begin(); e != wg.end(); e++ )
182 e->second = rnd_uniform();
183 }
184
185 // normalize the weights and set the single-variable weights to 1.0
186 for( size_t I = 0; I < nrFactors(); I++ ) {
187 size_t sizeI = factor(I).vars().size();
188 if( sizeI == 1 )
189 _weight[I] = 1.0;
190 else if( sizeI == 2 )
191 _weight[I] /= nrTrees;
192 else
193 DAI_THROW(NOT_IMPLEMENTED);
194 }
195 }
196
197
198 } // end of namespace dai