Adopted contributions by Christian.
[libdai.git] / x2x.cpp
1 /*
2 Copyright (C) 2005 Martijn Leisink m.leisink at science.ru.nl
3
4 This program is free software; you can redistribute it and/or
5 modify it under the terms of the GNU General Public License
6 as published by the Free Software Foundation; either version 2
7 of the License, or (at your option) any later version.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program; if not, write to the Free Software
16 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
17
18 CHANGES:
19 2006-11-20 Joris Mooij
20 * removed MATLAB interface
21 * put into namespace "x2x"
22 */
23
24
#include <cmath>
#include <cstring>
#include <vector>
27
28
29
namespace x2x {

// Conventions shared by every function below:
//  * A table x holds 2^N doubles indexed by a bit mask s over N binary
//    variables; bit n of s refers to variable n.
//  * All conversions work in place on x.
//  * Masks are built with `1l << n`, so N must be smaller than the number of
//    bits in `long`.

// Helper functions to compute the sum over all set partitions.

// psum: sums, over every partition of the one bits of s (those at positions
// >= n; callers normally start at n = 0) into non-empty blocks, the product
// of x[block] over the blocks of that partition. The empty set has exactly
// one (empty) partition, so psum(x, 0) == 1.
double psum (double *x, long s, int n=0);

// psumx: recursive worker for psum. `a` is the block currently being built
// (it starts as the lowest remaining one bit of s) and n is the position of
// the last bit of s that has been decided. Each remaining one bit of s above
// n is either added to `a` or left out; once all are decided, x[a] is
// multiplied by the partition sum of the bits that were left out (s ^ a).
double psumx (double *x, long a, long s, int n) {
  if ((s >> n) > 1) {
    // Another one bit of s lies above position n: locate it ...
    do ++n; while (!((s >> n) & 1));
    // ... and sum over both choices: keep it out of block a, or put it in.
    return psumx(x, a, s, n) + psumx(x, a ^ (1l << n), s, n);
  }
  // Block a is complete; partition whatever is left.
  return x[a] * psum(x, s ^ a, 0);
}

double psum (double *x, long s, int n) {
  // No bits left: the empty partition contributes the empty product.
  if (!(s >> n)) return 1;
  // The lowest remaining one bit always starts the first block.
  while (!((s >> n) & 1)) ++n;
  return psumx(x, 1l << n, s, n);
}

// Enumeration of all b-element subsets of {0, ..., N-1} (used by m2c/c2m).
// The members are kept in c[0] > c[1] > ... > c[b-1] with the sentinel
// c[b] == -1, and s is the corresponding bit mask. c must have room for b+1
// ints.

// Set (c, s) to the first subset, {0, ..., b-1}.
static void first_subset (int b, int *c, long &s) {
  c[b] = -1;
  s = 0;
  for (int z = b - 1; z >= 0; --z) {
    c[z] = c[z + 1] + 1;
    s |= 1l << c[z];
  }
}

// Advance (c, s) to the next subset; returns false once all have been seen.
static bool next_subset (int N, int b, int *c, long &s) {
  int z = 0;
  for (; z < b; ++z) {
    s ^= 3l << c[z];              // move member z up by one position
    if (++c[z] != N - z) break;   // still below its ceiling N-1-z: no carry
    s ^= 1l << c[z];              // overflowed: drop it and carry into z+1
  }
  if (z == b) return false;
  // Restart members z-1 .. 0 at consecutive positions just above c[z].
  while (--z >= 0) {
    c[z] = c[z + 1] + 1;
    s ^= 1l << c[z];
  }
  return true;
}

// Convert moments to cumulants up to order k (in place).
// Entries of order 1 .. min(k, N) are converted, lowest order first, so all
// lower-order cumulants are available when psum expands x[s]. x[0] is
// replaced by its logarithm; entries of order above k are left untouched.
void m2c (int N, double *x, int k) {
  // No subset has more than N elements; without this clamp a k > N would
  // index past the end of the 2^N-entry table.
  const int kmax = (k < N) ? k : N;
  std::vector<int> c((kmax > 0 ? kmax : 0) + 1);
  x[0] = std::log(x[0]);  // to get the correct answer if not normalized
  for (int b = 1; b <= kmax; ++b) {
    // start with marginals, then two-point correlations and so on
    long s;
    first_subset(b, &c[0], s);
    do {
      // c_ijk = m_ijk - c_i*c_jk - c_j*c_ik - c_k*c_ij - c_i*c_j*c_k
      // psum also contains the trivial partition, i.e. x[s] itself, once;
      // hence 2*x[s] - psum.
      x[s] = 2 * x[s] - psum(x, s);
    } while (next_subset(N, b, &c[0], s));
  }
}

// Convert cumulants to moments up to order k (in place); inverse of m2c.
// Orders min(k, N) .. 1 are processed highest first, so psum still sees the
// lower-order cumulants. Finally x[0] is replaced by its exponential.
void c2m (int N, double *x, int k) {
  // No subset has more than N elements; clamp so k > N cannot index past
  // the end of the 2^N-entry table.
  const int kmax = (k < N) ? k : N;
  std::vector<int> c((kmax > 0 ? kmax : 0) + 1);
  for (int b = kmax; b >= 1; --b) {
    // start with k-point cumulants, then k-1-point cumulants and so on
    long s;
    first_subset(b, &c[0], s);
    do {
      // m_ijk = c_ijk + c_i*c_jk + c_j*c_ik + c_k*c_ij + c_i*c_j*c_k
      x[s] = psum(x, s);
    } while (next_subset(N, b, &c[0], s));
  }
  x[0] = std::exp(x[0]);  // to get the correct answer if not normalized
}

// The four transforms below are in-place butterflies, one stage per
// variable. With sigma_i(t) = +1 if bit i of t is set and -1 otherwise:
//   p2m:    m[s]    = sum_t p[t] * prod_{i in s} sigma_i(t)
//   w2logp: logp[t] = sum_s w[s] * prod_{i in s} sigma_i(t)
// m2p and logp2w are their inverses (including the factor 2^-N).

// convert (generalized) weights to log probability or energy
void w2logp (int N, double *x) {
  for (long s = 1l << N; s >>= 1;)                 // one stage per variable
    for (long j = 1l << N; (j -= (s << 1)) >= 0;)  // blocks of size 2s
      for (long k = j + s; --k >= j;) {           // pairs (k, k+s)
        x[k + s] += x[k];
        x[k] = 2 * x[k] - x[k + s];
      }
}

// convert log probability or energy to (generalized) weights
void logp2w (int N, double *x) {
  for (long s = 1l << N; s >>= 1;)
    for (long j = 1l << N; (j -= (s << 1)) >= 0;)
      for (long k = j + s; --k >= j;) {
        x[k] = (x[k] + x[k + s]) / 2;
        x[k + s] -= x[k];
      }
}

// convert probability to moments
void p2m (int N, double *x) {
  for (long s = 1l << N; s >>= 1;)
    for (long j = 1l << N; (j -= (s << 1)) >= 0;)
      for (long k = j + s; --k >= j;) {
        x[k + s] -= x[k];
        x[k] = 2 * x[k] + x[k + s];
      }
}

// convert moments to probability
void m2p (int N, double *x) {
  for (long s = 1l << N; s >>= 1;)
    for (long j = 1l << N; (j -= (s << 1)) >= 0;)
      for (long k = j + s; --k >= j;) {
        x[k] = (x[k] - x[k + s]) / 2;
        x[k + s] += x[k];
      }
}

// convert log probability to probability (elementwise exp)
void logp2p (int N, double *x) {
  for (long s = 1l << N; s--;) x[s] = std::exp(x[s]);
}

// convert probability to log probability (elementwise log)
void p2logp (int N, double *x) {
  for (long s = 1l << N; s--;) x[s] = std::log(x[s]);
}

// Normalize a log probability table: subtract log(sum_s exp(x[s])).
// The log-sum-exp is accumulated pairwise with log1p so that no exp()
// overflows. Entries equal to -inf (zero probability) contribute nothing
// and are skipped; folding them in would evaluate exp(-inf - -inf) = NaN
// whenever the running sum is still -inf.
void logpnorm (int N, double *x) {
  double f = x[0];
  for (long s = 1l << N; --s;) {
    const double v = x[s];
    if (std::isinf(v) && v < 0) continue;
    if (f > v) f += std::log1p(std::exp(v - f));
    else f = v + std::log1p(std::exp(f - v));
  }
  for (long s = 1l << N; s--;) x[s] -= f;
}

// Normalize a probability table by dividing by its sum; use logpnorm
// whenever possible. A table summing to zero yields non-finite entries.
void pnorm (int N, double *x) {
  double z = 0;
  for (long s = 1l << N; s--;) z += x[s];
  for (long s = 1l << N; s--;) x[s] /= z;
}

// Set x[s] = v for every index s with more than k bits set. This is used,
// for example, when cumulants or moments were converted only up to some
// order.
void fill (int N, double *x, int k, double v) {
  if (k > N) return;  // no index has more than N bits set
  // Visit all indices in Gray-code order: consecutive codes differ in exactly
  // one bit, so the bit count n can be maintained incrementally.
  long ss = 0;  // index being visited (starts at 0)
  int n = 0;    // number of bits set in ss
  for (long i = 1l << N; i--;) {
    if (n > k) x[ss] = v;
    const long s = i ^ (i >> 1);  // Gray code of i
    if (s > ss) ++n; else --n;    // exactly one bit was added or removed
    ss = s;
  }
}

}  // namespace x2x