Added max-product version of BP
[libdai.git] / include / dai / bp.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_bp_h
23 #define __defined_libdai_bp_h
24
25
26 #include <string>
27 #include <dai/daialg.h>
28 #include <dai/factorgraph.h>
29 #include <dai/properties.h>
30 #include <dai/enum.h>
31
32
33 namespace dai {
34
35
36 class BP : public DAIAlgFG {
37 private:
38 typedef std::vector<size_t> ind_t;
39 struct EdgeProp {
40 ind_t index;
41 Prob message;
42 Prob newMessage;
43 double residual;
44 };
45 std::vector<std::vector<EdgeProp> > _edges;
46 /// Maximum difference encountered so far
47 double _maxdiff;
48 /// Number of iterations needed
49 size_t _iters;
50
51 public:
52 struct Properties {
53 size_t verbose;
54 size_t maxiter;
55 double tol;
56 bool logdomain;
57 double damping;
58 DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL)
59 UpdateType updates;
60 DAI_ENUM(InfType,SUMPROD,MAXPROD)
61 InfType inference;
62 } props;
63 static const char *Name;
64
65 public:
66 /// Default constructor
67 BP() : DAIAlgFG(), _edges(), _maxdiff(0.0), _iters(0U), props() {}
68
69 /// Construct from FactorGraph fg and PropertySet opts
70 BP( const FactorGraph & fg, const PropertySet &opts ) : DAIAlgFG(fg), _edges(), _maxdiff(0.0), _iters(0U), props() {
71 setProperties( opts );
72 construct();
73 }
74
75 /// Copy constructor
76 BP( const BP &x ) : DAIAlgFG(x), _edges(x._edges), _maxdiff(x._maxdiff), _iters(x._iters), props(x.props) {}
77
78 /// Clone *this (virtual copy constructor)
79 virtual BP* clone() const { return new BP(*this); }
80
81 /// Create (virtual default constructor)
82 virtual BP* create() const { return new BP(); }
83
84 /// Assignment operator
85 BP& operator=( const BP &x ) {
86 if( this != &x ) {
87 DAIAlgFG::operator=( x );
88 _edges = x._edges;
89 _maxdiff = x._maxdiff;
90 _iters = x._iters;
91 props = x.props;
92 }
93 return *this;
94 }
95
96 /// Identifies itself for logging purposes
97 virtual std::string identify() const;
98
99 /// Get single node belief
100 virtual Factor belief( const Var &n ) const;
101
102 /// Get general belief
103 virtual Factor belief( const VarSet &ns ) const;
104
105 /// Get all beliefs
106 virtual std::vector<Factor> beliefs() const;
107
108 /// Get log partition sum
109 virtual Real logZ() const;
110
111 /// Clear messages and beliefs
112 virtual void init();
113
114 /// Clear messages and beliefs corresponding to the nodes in ns
115 virtual void init( const VarSet &ns );
116
117 /// The actual approximate inference algorithm
118 virtual double run();
119
120 /// Return maximum difference between single node beliefs in the last pass
121 virtual double maxDiff() const { return _maxdiff; }
122
123 /// Return number of passes over the factorgraph
124 virtual size_t Iterations() const { return _iters; }
125
126
127 Factor beliefV( size_t i ) const;
128 Factor beliefF( size_t I ) const;
129
130 private:
131 const Prob & message(size_t i, size_t _I) const { return _edges[i][_I].message; }
132 Prob & message(size_t i, size_t _I) { return _edges[i][_I].message; }
133 Prob & newMessage(size_t i, size_t _I) { return _edges[i][_I].newMessage; }
134 const Prob & newMessage(size_t i, size_t _I) const { return _edges[i][_I].newMessage; }
135 ind_t & index(size_t i, size_t _I) { return _edges[i][_I].index; }
136 const ind_t & index(size_t i, size_t _I) const { return _edges[i][_I].index; }
137 double & residual(size_t i, size_t _I) { return _edges[i][_I].residual; }
138 const double & residual(size_t i, size_t _I) const { return _edges[i][_I].residual; }
139
140 void calcNewMessage( size_t i, size_t _I );
141 void updateMessage( size_t i, size_t _I ) {
142 if( props.damping == 0.0 ) {
143 message(i,_I) = newMessage(i,_I);
144 residual(i,_I) = 0.0;
145 } else {
146 message(i,_I) = (message(i,_I) ^ props.damping) * (newMessage(i,_I) ^ (1.0 - props.damping));
147 residual(i,_I) = dist( newMessage(i,_I), message(i,_I), Prob::DISTLINF );
148 }
149 }
150 void findMaxResidual( size_t &i, size_t &_I );
151
152 void construct();
153 /// Set Props according to the PropertySet opts, where the values can be stored as std::strings or as the type of the corresponding Props member
154 void setProperties( const PropertySet &opts );
155 PropertySet getProperties() const;
156 std::string printProperties() const;
157 };
158
159
160 } // end of namespace dai
161
162
163 #endif