Improvements to BP_dual code
[libdai.git] / include / dai / bp_dual.h
1 /* Copyright (C) 2009 Frederik Eaton [frederik at ofb dot net]
2
3 This file is part of libDAI.
4
5 libDAI is free software; you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation; either version 2 of the License, or
8 (at your option) any later version.
9
10 libDAI is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with libDAI; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18 */
19
20
21 /// \file
22 /// \brief Defines class BP_dual
23 /// \todo Improve documentation
24 /// \todo Clean up
25
26
27 #ifndef ____defined_libdai_bp_dual_h__
28 #define ____defined_libdai_bp_dual_h__
29
30
31 #include <dai/daialg.h>
32 #include <dai/factorgraph.h>
33 #include <dai/enum.h>
34 #include <dai/bp.h>
35
36
37 namespace dai {
38
39
40 struct BP_dual_messages {
41 // messages:
42 // indexed by edge index (using VV2E)
43 std::vector<Prob> n;
44 std::vector<Real> Zn;
45 std::vector<Prob> m;
46 std::vector<Real> Zm;
47 };
48
49
50 struct BP_dual_beliefs {
51 // beliefs:
52 // indexed by node
53 std::vector<Prob> b1;
54 std::vector<Real> Zb1;
55 // indexed by factor
56 std::vector<Prob> b2;
57 std::vector<Real> Zb2;
58 };
59
60
61 void _clamp( FactorGraph &g, const Var &n, const std::vector<size_t> &is );
62
63
64 /// Clamp a factor to have one of a set of values
65 void _clampFactor( FactorGraph &g, size_t I, const std::vector<size_t> &is );
66
67
68 class BP_dual : public DAIAlgFG {
69 public:
70 typedef std::vector<size_t> _ind_t;
71
72 protected:
73 // indexed by edge index. for each edge i->I, contains a
74 // vector whose entries correspond to those of I, and the
75 // value of each entry is the corresponding entry of i
76 std::vector<_ind_t> _indices;
77
78 BP_dual_messages _msgs;
79 BP_dual_messages _new_msgs;
80 public:
81 BP_dual_beliefs _beliefs;
82
83 size_t _iters;
84 double _maxdiff;
85
86 struct Properties {
87 typedef BP::Properties::UpdateType UpdateType;
88 UpdateType updates;
89 double tol;
90 size_t maxiter;
91 size_t verbose;
92 } props;
93
94 /// List of property names
95 static const char *PropertyList[];
96 /// Name of this inference algorithm
97 static const char *Name;
98
99 public:
100 void Regenerate(); // used by constructor
101 void RegenerateIndices();
102 void RegenerateMessages();
103 void RegenerateBeliefs();
104
105 void CalcBelief1(size_t i);
106 void CalcBelief2(size_t I);
107 void CalcBeliefs(); // called after run()
108
109 void calcNewM(size_t iI);
110 void calcNewN(size_t iI);
111 void upMsgM(size_t iI);
112 void upMsgN(size_t iI);
113
114 /* DAI_ENUM(UpdateType,SEQFIX,SEQRND,SEQMAX,PARALL) */
115 typedef BP::Properties::UpdateType UpdateType;
116 UpdateType Updates() const { return props.updates; }
117 size_t Verbose() const { return props.verbose; }
118
119 /// Default constructor
120 BP_dual() {}
121
122 /// construct BP_dual object from FactorGraph
123 BP_dual(const FactorGraph & fg, const PropertySet &opts) : DAIAlgFG(fg) {
124 setProperties(opts);
125 Regenerate();
126 }
127
128 DAI_ACCMUT(Prob & msgM(size_t I, size_t i), { return _msgs.m[VV2E(i,I)]; });
129 DAI_ACCMUT(Prob & msgN(size_t i, size_t I), { return _msgs.n[VV2E(i,I)]; });
130 DAI_ACCMUT(Prob & msgM(size_t iI), { return _msgs.m[iI]; });
131 DAI_ACCMUT(Prob & msgN(size_t iI), { return _msgs.n[iI]; });
132 DAI_ACCMUT(Real & zM(size_t I, size_t i), { return _msgs.Zm[VV2E(i,I)]; });
133 DAI_ACCMUT(Real & zN(size_t i, size_t I), { return _msgs.Zn[VV2E(i,I)]; });
134 DAI_ACCMUT(Real & zM(size_t iI), { return _msgs.Zm[iI]; });
135 DAI_ACCMUT(Real & zN(size_t iI), { return _msgs.Zn[iI]; });
136 DAI_ACCMUT(Prob & newMsgM(size_t I, size_t i), { return _new_msgs.m[VV2E(i,I)]; });
137 DAI_ACCMUT(Prob & newMsgN(size_t i, size_t I), { return _new_msgs.n[VV2E(i,I)]; });
138 DAI_ACCMUT(Real & newZM(size_t I, size_t i), { return _new_msgs.Zm[VV2E(i,I)]; });
139 DAI_ACCMUT(Real & newZN(size_t i, size_t I), { return _new_msgs.Zn[VV2E(i,I)]; });
140
141 DAI_ACCMUT(_ind_t & index(size_t i, size_t I), { return( _indices[VV2E(i,I)] ); });
142
143 Real belief1Z(size_t i) const { return _beliefs.Zb1[i]; }
144 Real belief2Z(size_t I) const { return _beliefs.Zb2[I]; }
145
146 size_t doneIters() const { return _iters; }
147
148
149 /// @name General InfAlg interface
150 //@{
151 virtual BP_dual* clone() const { return new BP_dual(*this); }
152 virtual BP_dual* create() const { return new BP_dual(); }
153 // virtual BP_dual* create() const { return NULL; }
154 virtual std::string identify() const;
155 virtual Factor belief (const Var &n) const { return( belief1( findVar( n ) ) ); }
156 virtual Factor belief (const VarSet &n) const;
157 virtual std::vector<Factor> beliefs() const;
158 virtual Real logZ() const;
159 virtual void init();
160 virtual void init( const VarSet &ns );
161 virtual double run();
162 virtual double maxDiff() const { return _maxdiff; }
163 virtual size_t Iterations() const { return _iters; }
164 //@}
165
166 void init(const std::vector<size_t>& state);
167 Factor belief1 (size_t i) const { return Factor(var(i), _beliefs.b1[i]); }
168 Factor belief2 (size_t I) const { return Factor(factor(I).vars(), _beliefs.b2[I]); }
169
170 void updateMaxDiff( double maxdiff ) { if( maxdiff > _maxdiff ) _maxdiff = maxdiff; }
171
172 /// Set Props according to the PropertySet opts, where the values can be stored as std::strings or as the type of the corresponding Props member
173 void setProperties( const PropertySet &opts );
174 PropertySet getProperties() const;
175 std::string printProperties() const;
176 };
177
178
179 } // end of namespace dai
180
181
182 #endif