Extended SWIG python interface (inspired by Kyle Ellrott): inference is possible...
[libdai.git] / src / matlab / dai_jtree.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <iostream>
10 #include <dai/matlab/matlab.h>
11 #include "mex.h"
12 #include <dai/jtree.h>
13
14
15 using namespace std;
16 using namespace dai;
17
18
19 /* Convert cell vector of Matlab sets to vector<VarSet> */
20 vector<VarSet> mx2VarSets(const mxArray *vs, const FactorGraph &fg, long verbose, vector<Permute> &perms) {
21 vector<VarSet> varsets;
22
23 int n1 = mxGetM(vs);
24 int n2 = mxGetN(vs);
25 if( n2 != 1 && n1 != 1 )
26 mexErrMsgTxt("varsets should be a Nx1 or 1xN cell matrix.");
27 size_t nr_vs = n1;
28 if( n1 == 1 )
29 nr_vs = n2;
30
31 // interpret vs, linear cell array of varsets
32 varsets.reserve( nr_vs );
33 perms.clear();
34 perms.reserve( nr_vs );
35 for( size_t cellind = 0; cellind < nr_vs; cellind++ ) {
36 if( verbose >= 3 )
37 cerr << "reading varset " << cellind << ": " << endl;
38 mxArray *cell = mxGetCell(vs, cellind);
39 if( verbose >= 3 )
40 cerr << " got cell " << endl;
41 size_t nr_mem = mxGetN(cell);
42 if( verbose >= 3 )
43 cerr << " number members: " << nr_mem << endl;
44 double *members = mxGetPr(cell);
45 if( verbose >= 3 )
46 cerr << " got them! " << endl;
47
48 // add variables
49 VarSet vsvars;
50 if( verbose >= 3 )
51 cerr << " vars: ";
52 vector<long> labels(nr_mem,0);
53 vector<size_t> dims(nr_mem,0);
54 for( size_t mi = 0; mi < nr_mem; mi++ ) {
55 labels[mi] = (long)members[mi];
56 dims[mi] = fg.var(labels[mi]).states();
57 if( verbose >= 3 )
58 cerr << labels[mi] << " ";
59 vsvars.insert( fg.var(labels[mi]) );
60 }
61 if( verbose >= 3 )
62 cerr << endl;
63 DAI_ASSERT( nr_mem == vsvars.size() );
64 varsets.push_back(vsvars);
65
66 // calculate permutation matrix
67 vector<size_t> perm(nr_mem,0);
68 VarSet::iterator j = vsvars.begin();
69 for( size_t mi = 0; mi < nr_mem; mi++,j++ ) {
70 long gezocht = j->label();
71 vector<long>::iterator piet = find(labels.begin(),labels.end(),gezocht);
72 perm[mi] = piet - labels.begin();
73 }
74 if( verbose >= 3 ) {
75 cerr << endl << " perm: ";
76 for( vector<size_t>::iterator r=perm.begin(); r!=perm.end(); r++ )
77 cerr << *r << " ";
78 cerr << endl;
79 }
80 // create Permute object
81 vector<size_t> di(nr_mem,0);
82 size_t prod = 1;
83 for( size_t k = 0; k < nr_mem; k++ ) {
84 di[k] = dims[k];
85 prod *= dims[k];
86 }
87 Permute permindex( di, perm );
88 perms.push_back( permindex );
89 }
90
91 if( verbose >= 3 ) {
92 for(vector<VarSet>::const_iterator I=varsets.begin(); I!=varsets.end(); I++ )
93 cerr << *I << endl;
94 }
95
96 return( varsets );
97 }
98
99
/* Input arguments (indices into prhs) */

#define PSI_IN       prhs[0]   /* cell array of factors                */
#define VARSETS_IN   prhs[1]   /* cell array of requested varsets      */
#define OPTS_IN      prhs[2]   /* options string                       */
#define NR_IN        3         /* number of mandatory input arguments  */
#define NR_IN_OPT    0         /* number of optional input arguments   */


/* Output arguments (slots in plhs) */

#define LOGZ_OUT     plhs[0]   /* logarithm of the partition sum       */
#define Q_OUT        plhs[1]   /* all calculated marginals             */
#define QV_OUT       plhs[2]   /* variable marginals                   */
#define QF_OUT       plhs[3]   /* factor marginals                     */
#define QMAP_OUT     plhs[4]   /* MAP state                            */
#define MARGS_OUT    plhs[5]   /* requested marginals                  */
#define NR_OUT       3         /* minimum number of output arguments   */
#define NR_OUT_OPT   3         /* number of optional output arguments  */
119
120
121 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
122 // Check for proper number of arguments
123 if( ((nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT)) || ((nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT)) ) {
124 mexErrMsgTxt("Usage: [logZ,q,qv,qf,qmap,margs] = dai_jtree(psi,varsets,opts)\n\n"
125 "\n"
126 "INPUT: psi = linear cell array containing the factors\n"
127 " (psi{i} should be a structure with a Member field\n"
128 " and a P field).\n"
129 " varsets = linear cell array containing varsets for which marginals\n"
130 " are requested.\n"
131 " opts = string of options.\n"
132 "\n"
133 "OUTPUT: logZ = logarithm of the partition sum.\n"
134 " q = linear cell array containing all calculated marginals.\n"
135 " qv = linear cell array containing all variable marginals.\n"
136 " qf = linear cell array containing all factor marginals.\n"
137 " qmap = linear array containing the MAP state.\n"
138 " margs = linear cell array containing all requested marginals.\n");
139 }
140
141 // Get psi and construct factorgraph
142 vector<Factor> factors = mx2Factors(PSI_IN, 0);
143 FactorGraph fg(factors);
144
145 // Get varsets
146 vector<Permute> perms;
147 vector<VarSet> varsets = mx2VarSets(VARSETS_IN,fg,0,perms);
148
149 // Get options string
150 char *opts;
151 size_t buflen = mxGetN( OPTS_IN ) + 1;
152 opts = (char *)mxCalloc( buflen, sizeof(char) );
153 mxGetString( OPTS_IN, opts, buflen );
154 // Convert to options object props
155 stringstream ss;
156 ss << opts;
157 PropertySet props;
158 ss >> props;
159
160 // Construct InfAlg object, init and run
161 JTree jt = JTree( fg, props );
162 jt.init();
163 jt.run();
164
165 // Save logZ
166 double logZ = NAN;
167 logZ = jt.logZ();
168
169 // Hand over results to MATLAB
170 LOGZ_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
171 *(mxGetPr(LOGZ_OUT)) = logZ;
172
173 Q_OUT = Factors2mx(jt.beliefs());
174
175 if( nlhs >= 3 ) {
176 vector<Factor> qv;
177 qv.reserve( fg.nrVars() );
178 for( size_t i = 0; i < fg.nrVars(); i++ )
179 qv.push_back( jt.belief( fg.var(i) ) );
180 QV_OUT = Factors2mx( qv );
181 }
182
183 if( nlhs >= 4 ) {
184 vector<Factor> qf;
185 qf.reserve( fg.nrFactors() );
186 for( size_t I = 0; I < fg.nrFactors(); I++ )
187 qf.push_back( jt.belief( fg.factor(I).vars() ) );
188 QF_OUT = Factors2mx( qf );
189 }
190
191 if( nlhs >= 5 ) {
192 std::vector<size_t> map_state;
193 bool supported = true;
194 try {
195 map_state = jt.findMaximum();
196 } catch( Exception &e ) {
197 if( e.getCode() == Exception::NOT_IMPLEMENTED )
198 supported = false;
199 else
200 throw;
201 }
202 if( supported ) {
203 QMAP_OUT = mxCreateNumericMatrix(map_state.size(), 1, mxUINT32_CLASS, mxREAL);
204 uint32_T* qmap_p = reinterpret_cast<uint32_T *>(mxGetPr(QMAP_OUT));
205 for (size_t n = 0; n < map_state.size(); ++n)
206 qmap_p[n] = map_state[n];
207 } else {
208 mexErrMsgTxt("Calculating a MAP state is not supported by this inference algorithm.");
209 }
210 }
211
212 if( nlhs >= 6 ) {
213 vector<Factor> margs;
214 margs.reserve( varsets.size() );
215 for( size_t s = 0; s < varsets.size(); s++ ) {
216 Factor marg;
217 jt.init();
218 jt.run();
219 marg = jt.calcMarginal( varsets[s] );
220
221 // permute entries of marg
222 Factor margperm = marg;
223 for( size_t li = 0; li < marg.nrStates(); li++ )
224 margperm.set( li, marg[perms[s].convertLinearIndex(li)] );
225 margs.push_back( margperm );
226 }
227 MARGS_OUT = Factors2mx( margs );
228 }
229
230 return;
231 }