fd039a63955210d50dce7790c26222913bb9da11
[libdai.git] / include / dai / treeep.h
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #ifndef __defined_libdai_treeep_h
23 #define __defined_libdai_treeep_h
24
25
#include <map>
#include <string>
#include <vector>
#include <dai/daialg.h>
#include <dai/varset.h>
#include <dai/regiongraph.h>
#include <dai/factorgraph.h>
#include <dai/clustergraph.h>
#include <dai/weightedgraph.h>
#include <dai/jtree.h>
#include <dai/properties.h>
#include <dai/enum.h>
37
38
39 namespace dai {
40
41
42 class TreeEP : public JTree {
43 protected:
44 /// Maximum difference encountered so far
45 double _maxdiff;
46 /// Number of iterations needed
47 size_t _iters;
48
49 public:
50 struct Properties {
51 size_t verbose;
52 size_t maxiter;
53 double tol;
54 DAI_ENUM(TypeType,ORG,ALT)
55 TypeType type;
56 } props; // FIXME: should be props2 because of conflict with JTree::props?
57 /// Name of this inference method
58 static const char *Name;
59
60 protected:
61 class TreeEPSubTree {
62 protected:
63 std::vector<Factor> _Qa;
64 std::vector<Factor> _Qb;
65 DEdgeVec _RTree;
66 std::vector<size_t> _a; // _Qa[alpha] <-> superTree._Qa[_a[alpha]]
67 std::vector<size_t> _b; // _Qb[beta] <-> superTree._Qb[_b[beta]]
68 // _Qb[beta] <-> _RTree[beta]
69 const Factor * _I;
70 VarSet _ns;
71 VarSet _nsrem;
72 double _logZ;
73
74
75 public:
76 TreeEPSubTree() : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(NULL), _ns(), _nsrem(), _logZ(0.0) {}
77 TreeEPSubTree( const TreeEPSubTree &x) : _Qa(x._Qa), _Qb(x._Qb), _RTree(x._RTree), _a(x._a), _b(x._b), _I(x._I), _ns(x._ns), _nsrem(x._nsrem), _logZ(x._logZ) {}
78 TreeEPSubTree & operator=( const TreeEPSubTree& x ) {
79 if( this != &x ) {
80 _Qa = x._Qa;
81 _Qb = x._Qb;
82 _RTree = x._RTree;
83 _a = x._a;
84 _b = x._b;
85 _I = x._I;
86 _ns = x._ns;
87 _nsrem = x._nsrem;
88 _logZ = x._logZ;
89 }
90 return *this;
91 }
92
93 TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I );
94 void init();
95 void InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb );
96 void HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb );
97 double logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const;
98 const Factor *& I() { return _I; }
99 };
100
101 std::map<size_t, TreeEPSubTree> _Q;
102
103 public:
104 /// Default constructor
105 TreeEP() : JTree(), _maxdiff(0.0), _iters(0), props(), _Q() {}
106
107 /// Construct from FactorGraph fg and PropertySet opts
108 TreeEP( const FactorGraph &fg, const PropertySet &opts );
109
110 /// Copy constructor
111 TreeEP( const TreeEP &x ) : JTree(x), _maxdiff(x._maxdiff), _iters(x._iters), props(x.props), _Q(x._Q) {
112 for( size_t I = 0; I < nrFactors(); I++ )
113 if( offtree( I ) )
114 _Q[I].I() = &factor(I);
115 }
116
117 /// Clone *this (virtual copy constructor)
118 virtual TreeEP* clone() const { return new TreeEP(*this); }
119
120 /// Create (virtual default constructor)
121 virtual TreeEP* create() const { return new TreeEP(); }
122
123 /// Assignment operator
124 TreeEP& operator=( const TreeEP &x ) {
125 if( this != &x ) {
126 JTree::operator=( x );
127 _maxdiff = x._maxdiff;
128 _iters = x._iters;
129 props = x.props;
130 _Q = x._Q;
131 for( size_t I = 0; I < nrFactors(); I++ )
132 if( offtree( I ) )
133 _Q[I].I() = &factor(I);
134 }
135 return *this;
136 }
137
138 /// Identifies itself for logging purposes
139 virtual std::string identify() const;
140
141 /// Get log partition sum
142 virtual Real logZ() const;
143
144 /// Clear messages and beliefs
145 virtual void init();
146
147 /// Clear messages and beliefs corresponding to the nodes in ns
148 virtual void init( const VarSet &/*ns*/ ) { init(); }
149
150 /// The actual approximate inference algorithm
151 virtual double run();
152
153 /// Return maximum difference between single node beliefs in the last pass
154 virtual double maxDiff() const { return _maxdiff; }
155
156 /// Return number of passes over the factorgraph
157 virtual size_t Iterations() const { return _iters; }
158
159
160 void ConstructRG( const DEdgeVec &tree );
161 bool offtree( size_t I ) const { return (fac2OR[I] == -1U); }
162
163 void setProperties( const PropertySet &opts );
164 PropertySet getProperties() const;
165 std::string printProperties() const;
166 };
167
168
169 } // end of namespace dai
170
171
172 #endif