1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
4 This file is part of libDAI.
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
23 #include <dai/daialg.h>
32 /// Calculate the marginal of obj on ns by clamping
33 /// all variables in ns and calculating logZ for each joint state of ns.
///
/// \param obj     inference algorithm to query; a clone of it is clamped,
///                so obj itself is not modified.
/// \param ns      variables whose joint marginal is computed.
/// \param reInit  not referenced in the lines visible here; presumably it
///                selects whether the clone is re-initialized before each
///                run -- NOTE(review): confirm against the full source.
/// \return        Pns normalized with Prob::NORMPROB; per the comment inside
///                the loop, Pns receives one value per joint state of ns.
///
/// Each value is presumably based on Z = exp(logZ(clamped) - logZ0), where
/// logZ0 is a reference value subtracted only to keep the exponent small.
34 Factor
calcMarginal( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
// Work on a clone so that clamping never touches the caller's obj.
// NOTE(review): no matching `delete clamped` is visible in this view --
// confirm the clone is released before returning.
39 InfAlg
*clamped
= obj
.clone();
// Visit every joint state j of ns. mi is declared outside this view;
// presumably a multi-index over the states of ns -- confirm.
44 for( size_t j
= 0; j
< mi
.max(); j
++ ) {
45 // save unclamped factors connected to ns
46 clamped
->saveProbs( ns
);
48 // set clamping Factors to delta functions
// vi presumably holds, for joint state j, the state index of each
// variable of ns in iteration order.
49 vector
<size_t> vi
= mi
.vi( j
);
// k (declared outside this view) indexes vi in lock-step with n.
51 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++, k
++ )
52 clamped
->clamp( *n
, vi
[k
] );
54 // run DAIAlg, calc logZ, store in Pns
55 if( clamped
->Verbose() >= 2 )
// NOTE(review): the condition guarding this assignment is not visible
// here; presumably logZ0 is taken only from the first clamped run.
63 logZ0
= clamped
->logZ();
66 // subtract logZ0 to avoid very large numbers
67 Z
= exp(clamped
->logZ() - logZ0
);
// A non-negligible imaginary part of Z is only reported on cout, not
// treated as an error.
68 if( fabs(imag(Z
)) > 1e-5 )
69 cout
<< "Marginal:: WARNING: complex Z (" << Z
<< ")" << endl
;
74 // restore clamped factors
// undoProbs(ns) presumably reverts the saveProbs(ns) snapshot taken at the
// top of this iteration, so the next joint state starts unclamped.
75 clamped
->undoProbs( ns
);
// Normalize with Prob::NORMPROB (presumably: entries sum to one).
80 return( Pns
.normalized(Prob::NORMPROB
) );
/// Approximate the pairwise beliefs of every pair of variables in ns by
/// clamping one variable at a time.
///
/// For each variable x_j of ns and each of its states, a clone of obj is
/// clamped to that state. For each x_k, the clamped belief b(x_k), scaled
/// by Z_xj, is written into the pair Factor stored for (j,k). This yields
/// two one-sided estimates per unordered pair, which are combined by a
/// geometric average.
///
/// \param obj     inference algorithm to query; only a clone is clamped.
/// \param ns      variables whose pairwise beliefs are computed.
/// \param reInit  not referenced in the lines visible here --
///                NOTE(review): confirm its use in the full source.
/// \return        N*(N-1)/2 Factors, one per pair j < k (j outer, k inner)
///                in the order of vns. No normalization step is visible in
///                this view.
84 vector
<Factor
> calcPairBeliefs( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
85 // convert ns to vector<VarSet>
// NOTE(review): N and vns are declared outside this view; presumably
// N == ns.size() and vns lists the variables of ns in iteration order.
// vns[j].label() below suggests its elements are Vars, not VarSets.
89 for( VarSet::const_iterator n
= ns
.begin(); n
!= ns
.end(); n
++ )
// pairbeliefs is an N x N table stored row-major: entry (j,k) lives at
// index j*N + k.
92 vector
<Factor
> pairbeliefs
;
93 pairbeliefs
.reserve( N
* N
);
// The condition choosing between the empty Factor and the pair Factor is
// not visible here; presumably the empty Factor fills the diagonal j == k.
94 for( size_t j
= 0; j
< N
; j
++ )
95 for( size_t k
= 0; k
< N
; k
++ )
97 pairbeliefs
.push_back(Factor());
99 pairbeliefs
.push_back(Factor(vns
[j
] | vns
[k
]));
// Clone obj so the caller's instance is never clamped.
// NOTE(review): no matching `delete clamped` is visible in this view.
101 InfAlg
*clamped
= obj
.clone();
106 for( size_t j
= 0; j
< N
; j
++ ) {
107 // clamp Var j to its possible values
108 for( size_t j_val
= 0; j_val
< vns
[j
].states(); j_val
++ ) {
// Progress output "j/N-1 (j_val/states): " when obj is verbose.
109 if( obj
.Verbose() >= 2 )
110 cout
<< j
<< "/" << N
-1 << " (" << j_val
<< "/" << vns
[j
].states() << "): ";
112 // save unclamped factors connected to ns
113 clamped
->saveProbs( ns
);
115 clamped
->clamp( vns
[j
], j_val
);
121 // logZ0 = obj.logZ();
// logZ0 is assigned only at the very first clamped run (j == 0,
// j_val == 0), so every Z below shares a single reference point.
123 if( j
== 0 && j_val
== 0 ) {
124 logZ0
= clamped
->logZ();
126 // subtract logZ0 to avoid very large numbers
127 Complex Z
= exp(clamped
->logZ() - logZ0
);
128 if( fabs(imag(Z
)) > 1e-5 )
129 cout
<< "calcPairBelief:: Warning: complex Z: " << Z
<< endl
;
// Fill row j of the table. Z_xj is declared outside this view;
// presumably it is derived from Z above.
// NOTE(review): any skip of k == j is not visible here; writing into
// the (presumably empty) diagonal Factor would be invalid -- confirm.
133 for( size_t k
= 0; k
< N
; k
++ )
135 Factor b_k
= clamped
->belief(vns
[k
]);
// Flat index into the pair Factor: the variable with the smaller
// label varies fastest. NOTE(review): this assumes Factor orders its
// entries by variable label -- confirm.
136 for( size_t k_val
= 0; k_val
< vns
[k
].states(); k_val
++ )
137 if( vns
[j
].label() < vns
[k
].label() )
138 pairbeliefs
[j
* N
+ k
][j_val
+ (k_val
* vns
[j
].states())] = Z_xj
* b_k
[k_val
];
140 pairbeliefs
[j
* N
+ k
][k_val
+ (j_val
* vns
[k
].states())] = Z_xj
* b_k
[k_val
];
143 // restore clamped factors
144 clamped
->undoProbs( ns
);
150 // Calculate result by taking the geometric average
// Each unordered pair {j,k} has two one-sided estimates: one from clamping
// x_j (entry j*N + k) and one from clamping x_k (entry k*N + j).
151 vector
<Factor
> result
;
152 result
.reserve( N
* (N
- 1) / 2 );
153 for( size_t j
= 0; j
< N
; j
++ )
154 for( size_t k
= j
+1; k
< N
; k
++ )
155 result
.push_back( (pairbeliefs
[j
* N
+ k
] * pairbeliefs
[k
* N
+ j
]) ^ 0.5 );
161 Factor
calcMarginal2ndO( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
162 // returns a probability distribution whose 1st order interactions
163 // are unspecified, whose 2nd order interactions approximate those of
164 // the marginal on ns, and whose higher order interactions are absent.
//
// obj, ns and reInit are forwarded unchanged to calcPairBeliefs; the
// result is the product of all returned pair beliefs, normalized with
// Prob::NORMPROB.
166 vector
<Factor
> pairbeliefs
= calcPairBeliefs( obj
, ns
, reInit
);
// Pns is declared outside this view; presumably it starts as a constant
// Factor over ns, so only the pair beliefs shape the result -- confirm.
169 for( size_t ij
= 0; ij
< pairbeliefs
.size(); ij
++ )
170 Pns
*= pairbeliefs
[ij
];
172 return( Pns
.normalized(Prob::NORMPROB
) );
/// Compute pair beliefs for every unordered pair of variables in ns by
/// clamping both variables of the pair jointly.
///
/// Unlike calcPairBeliefs, each entry comes from a run with both x_j and
/// x_k clamped, and no geometric averaging is applied here.
///
/// \param obj     inference algorithm to query; only a clone is clamped.
/// \param ns      variables whose pairwise beliefs are computed.
/// \param reInit  not referenced in the lines visible here --
///                NOTE(review): confirm its use in the full source.
/// \return        ns.size()*(ns.size()-1)/2 Factors, one per pair (nj, nk)
///                with nk after nj in ns iteration order (checked by the
///                assert at the end).
176 vector
<Factor
> calcPairBeliefsNew( const InfAlg
& obj
, const VarSet
& ns
, bool reInit
) {
177 vector
<Factor
> result
;
178 result
.reserve( ns
.size() * (ns
.size() - 1) / 2 );
// Clone obj so the caller's instance is never clamped.
// NOTE(review): no matching `delete clamped` is visible in this view.
180 InfAlg
*clamped
= obj
.clone();
// The (long) cast in the loop bound makes it -1 for an empty ns instead of
// wrapping around as size_t, so the loop body is skipped.
185 VarSet::const_iterator nj
= ns
.begin();
186 for( long j
= 0; j
< (long)ns
.size() - 1; j
++, nj
++ ) {
// nk starts at nj and is pre-incremented in the condition, so it visits
// only the variables after nj. k is declared outside this view
// (presumably a pair counter).
188 for( VarSet::const_iterator nk
= nj
; (++nk
) != ns
.end(); k
++ ) {
189 Factor
pairbelief( *nj
| *nk
);
191 // clamp Vars j and k to their possible values
192 for( size_t j_val
= 0; j_val
< nj
->states(); j_val
++ )
193 for( size_t k_val
= 0; k_val
< nk
->states(); k_val
++ ) {
194 // save unclamped factors connected to ns
195 clamped
->saveProbs( ns
);
197 clamped
->clamp( *nj
, j_val
);
198 clamped
->clamp( *nk
, k_val
);
// logZ0 is re-taken for every pair (at j_val == 0 && k_val == 0), so
// values are relative within one pairbelief only.
204 if( j_val
== 0 && k_val
== 0 ) {
205 logZ0
= clamped
->logZ();
207 // subtract logZ0 to avoid very large numbers
208 Complex Z
= exp(clamped
->logZ() - logZ0
);
209 if( fabs(imag(Z
)) > 1e-5 )
210 cout
<< "calcPairBelief:: Warning: complex Z: " << Z
<< endl
;
214 // we assume that j.label() < k.label()
215 // i.e. we make an assumption here about the indexing
// NOTE(review): this holds if VarSet iterates in increasing label
// order, since nk always comes after nj -- confirm. Z_xj is declared
// outside this view; presumably derived from Z above.
216 pairbelief
[j_val
+ (k_val
* nj
->states())] = Z_xj
;
218 // restore clamped factors
219 clamped
->undoProbs( ns
);
222 result
.push_back( pairbelief
);
// Sanity check: exactly one Factor per unordered pair of ns.
228 assert( result
.size() == (ns
.size() * (ns
.size() - 1) / 2) );
234 } // end of namespace dai