Updated copyrights
[libdai.git] / include / dai / treeep.h
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 #ifndef __defined_libdai_treeep_h
24 #define __defined_libdai_treeep_h
25
26
#include <map>
#include <string>
#include <vector>
#include <dai/daialg.h>
#include <dai/varset.h>
#include <dai/regiongraph.h>
#include <dai/factorgraph.h>
#include <dai/clustergraph.h>
#include <dai/weightedgraph.h>
#include <dai/jtree.h>
#include <dai/properties.h>
#include <dai/enum.h>
38
39
40 namespace dai {
41
42
43 class TreeEP : public JTree {
44 private:
45 /// Maximum difference encountered so far
46 double _maxdiff;
47 /// Number of iterations needed
48 size_t _iters;
49
50 public:
51 struct Properties {
52 size_t verbose;
53 size_t maxiter;
54 double tol;
55 DAI_ENUM(TypeType,ORG,ALT)
56 TypeType type;
57 } props; // FIXME: should be props2 because of conflict with JTree::props?
58 /// Name of this inference method
59 static const char *Name;
60
61 private:
62 class TreeEPSubTree {
63 private:
64 std::vector<Factor> _Qa;
65 std::vector<Factor> _Qb;
66 DEdgeVec _RTree;
67 std::vector<size_t> _a; // _Qa[alpha] <-> superTree.Qa[_a[alpha]]
68 std::vector<size_t> _b; // _Qb[beta] <-> superTree.Qb[_b[beta]]
69 // _Qb[beta] <-> _RTree[beta]
70 const Factor * _I;
71 VarSet _ns;
72 VarSet _nsrem;
73 double _logZ;
74
75
76 public:
77 TreeEPSubTree() : _Qa(), _Qb(), _RTree(), _a(), _b(), _I(NULL), _ns(), _nsrem(), _logZ(0.0) {}
78 TreeEPSubTree( const TreeEPSubTree &x) : _Qa(x._Qa), _Qb(x._Qb), _RTree(x._RTree), _a(x._a), _b(x._b), _I(x._I), _ns(x._ns), _nsrem(x._nsrem), _logZ(x._logZ) {}
79 TreeEPSubTree & operator=( const TreeEPSubTree& x ) {
80 if( this != &x ) {
81 _Qa = x._Qa;
82 _Qb = x._Qb;
83 _RTree = x._RTree;
84 _a = x._a;
85 _b = x._b;
86 _I = x._I;
87 _ns = x._ns;
88 _nsrem = x._nsrem;
89 _logZ = x._logZ;
90 }
91 return *this;
92 }
93
94 TreeEPSubTree( const DEdgeVec &subRTree, const DEdgeVec &jt_RTree, const std::vector<Factor> &jt_Qa, const std::vector<Factor> &jt_Qb, const Factor *I );
95 void init();
96 void InvertAndMultiply( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb );
97 void HUGIN_with_I( std::vector<Factor> &Qa, std::vector<Factor> &Qb );
98 double logZ( const std::vector<Factor> &Qa, const std::vector<Factor> &Qb ) const;
99 const Factor *& I() { return _I; }
100 };
101
102 std::map<size_t, TreeEPSubTree> _Q;
103
104 public:
105 /// Default constructor
106 TreeEP() : JTree(), _maxdiff(0.0), _iters(0), props(), _Q() {}
107
108 /// Construct from FactorGraph fg and PropertySet opts
109 TreeEP( const FactorGraph &fg, const PropertySet &opts );
110
111 /// Copy constructor
112 TreeEP( const TreeEP &x ) : JTree(x), _maxdiff(x._maxdiff), _iters(x._iters), props(x.props), _Q(x._Q) {
113 for( size_t I = 0; I < nrFactors(); I++ )
114 if( offtree( I ) )
115 _Q[I].I() = &factor(I);
116 }
117
118 /// Clone *this (virtual copy constructor)
119 virtual TreeEP* clone() const { return new TreeEP(*this); }
120
121 /// Create (virtual default constructor)
122 virtual TreeEP* create() const { return new TreeEP(); }
123
124 /// Assignment operator
125 TreeEP& operator=( const TreeEP &x ) {
126 if( this != &x ) {
127 JTree::operator=( x );
128 _maxdiff = x._maxdiff;
129 _iters = x._iters;
130 props = x.props;
131 _Q = x._Q;
132 for( size_t I = 0; I < nrFactors(); I++ )
133 if( offtree( I ) )
134 _Q[I].I() = &factor(I);
135 }
136 return *this;
137 }
138
139 /// Identifies itself for logging purposes
140 virtual std::string identify() const;
141
142 /// Get log partition sum
143 virtual Real logZ() const;
144
145 /// Clear messages and beliefs
146 virtual void init();
147
148 /// Clear messages and beliefs corresponding to the nodes in ns
149 virtual void init( const VarSet &/*ns*/ ) { init(); }
150
151 /// The actual approximate inference algorithm
152 virtual double run();
153
154 /// Return maximum difference between single node beliefs in the last pass
155 virtual double maxDiff() const { return _maxdiff; }
156
157 /// Return number of passes over the factorgraph
158 virtual size_t Iterations() const { return _iters; }
159
160
161 void ConstructRG( const DEdgeVec &tree );
162 bool offtree( size_t I ) const { return (fac2OR[I] == -1U); }
163
164 void setProperties( const PropertySet &opts );
165 PropertySet getProperties() const;
166 std::string printProperties() const;
167 };
168
169
170 } // end of namespace dai
171
172
173 #endif