[Sebastian Nowozin] MEX file dai now also optionally returns the MAP state (only...
[libdai.git] / src / matlab / dai.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
/*=================================================================*
 *                                                                 *
 * This is a MEX-file for MATLAB.                                  *
 *                                                                 *
 *   [logZ,q,md] = dai(psi,method,opts);                           *
 *   [logZ,q,md,qv,qf] = dai(psi,method,opts);                     *
 *   or                                                            *
 *   [logZ,q,md,qv,qf,qmap] = dai(psi,method,opts);                *
 *                                                                 *
 * The outputs qv, qf and qmap are optional; qmap (the MAP state)  *
 * can only be requested when method is BP.                        *
 *                                                                 *
 *=================================================================*/
32
33
#include <exception>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <dai/matlab/matlab.h>
#include "mex.h"
#include <dai/alldai.h>
#include <dai/bp.h>
39
40
// File-scope using-directives: the code below refers to standard library
// names (vector, stringstream, ...) and libDAI names (Factor, FactorGraph,
// PropertySet, InfAlg, BP, ...) without namespace qualification.
using namespace std;
using namespace dai;
43
44
/* Input Arguments */

// These macros expand to elements of mexFunction's prhs parameter, so they
// can only be used where a variable named prhs is in scope.
#define PSI_IN prhs[0]      // psi: linear cell array of factors (Member and P fields)
#define METHOD_IN prhs[1]   // method: name of the inference method
#define OPTS_IN prhs[2]     // opts: string of options
#define NR_IN 3             // number of required input arguments
#define NR_IN_OPT 0         // number of optional input arguments
52
53
/* Output Arguments */

// These macros expand to elements of mexFunction's plhs parameter, so they
// can only be used where a variable named plhs is in scope.
#define LOGZ_OUT plhs[0]    // logZ: approximation of the log partition sum
#define Q_OUT plhs[1]       // q: all final beliefs
#define MD_OUT plhs[2]      // md: maxdiff of the final iteration
#define QV_OUT plhs[3]      // qv: variable beliefs (optional)
#define QF_OUT plhs[4]      // qf: factor beliefs (optional)
#define QMAP_OUT plhs[5]    // qmap: MAP labeling, only for BP (optional)
#define NR_OUT 3            // number of required output arguments
#define NR_OUT_OPT 3        // number of optional output arguments
64
65
66 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
67 {
68 size_t buflen;
69
70 // Check for proper number of arguments
71 if( ((nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT)) || ((nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT)) ) {
72 mexErrMsgTxt("Usage: [logZ,q,md,qv,qf,qmap] = dai(psi,method,opts)\n\n"
73 "\n"
74 "INPUT: psi = linear cell array containing the factors \n"
75 " psi{i} should be a structure with a Member field\n"
76 " and a P field, like a CPTAB).\n"
77 " method = name of the method (see README)\n"
78 " opts = string of options (see README)\n"
79 "\n"
80 "OUTPUT: logZ = approximation of the logarithm of the partition sum.\n"
81 " q = linear cell array containing all final beliefs.\n"
82 " md = maxdiff (final linf-dist between new and old single node beliefs).\n"
83 " qv = linear cell array containing all variable beliefs.\n"
84 " qf = linear cell array containing all factor beliefs.\n"
85 " qmap = (V,1) array containing the MAP labeling (only for BP).\n");
86 }
87
88 char *method;
89 char *opts;
90
91
92 // Get psi and construct factorgraph
93 vector<Factor> factors = mx2Factors(PSI_IN, 0);
94 FactorGraph fg(factors);
95
96 // Get method
97 buflen = mxGetN( METHOD_IN ) + 1;
98 method = (char *)mxCalloc( buflen, sizeof(char) );
99 mxGetString( METHOD_IN, method, buflen );
100
101 // Get options string
102 buflen = mxGetN( OPTS_IN ) + 1;
103 opts = (char *)mxCalloc( buflen, sizeof(char) );
104 mxGetString( OPTS_IN, opts, buflen );
105 // Convert to options object props
106 stringstream ss;
107 ss << opts;
108 PropertySet props;
109 ss >> props;
110
111 // Construct InfAlg object, init and run
112 InfAlg *obj = newInfAlg( method, fg, props );
113 obj->init();
114 obj->run();
115
116
117 // Save logZ
118 double logZ = obj->logZ();
119
120 // Save maxdiff
121 double maxdiff = obj->maxDiff();
122
123
124 // Hand over results to MATLAB
125 LOGZ_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
126 *(mxGetPr(LOGZ_OUT)) = logZ;
127
128 Q_OUT = Factors2mx(obj->beliefs());
129
130 MD_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
131 *(mxGetPr(MD_OUT)) = maxdiff;
132
133 if( nlhs >= 4 ) {
134 vector<Factor> qv;
135 qv.reserve( fg.nrVars() );
136 for( size_t i = 0; i < fg.nrVars(); i++ )
137 qv.push_back( obj->belief( fg.var(i) ) );
138 QV_OUT = Factors2mx( qv );
139 }
140
141 if( nlhs >= 5 ) {
142 vector<Factor> qf;
143 qf.reserve( fg.nrFactors() );
144 for( size_t I = 0; I < fg.nrFactors(); I++ )
145 qf.push_back( obj->belief( fg.factor(I).vars() ) );
146 QF_OUT = Factors2mx( qf );
147 }
148
149 if( nlhs >= 6 ) {
150 BP* obj_bp = dynamic_cast<BP *>(obj);
151 if (obj_bp == 0) {
152 mexErrMsgTxt("MAP state assignment works only for BP.\n");
153 delete obj;
154
155 return;
156 }
157 std::vector<std::size_t> map_state = obj_bp->findMaximum();
158 QMAP_OUT = mxCreateNumericMatrix(map_state.size(), 1, mxUINT32_CLASS, mxREAL);
159 uint32_T* qmap_p = reinterpret_cast<uint32_T *>(mxGetPr(QMAP_OUT));
160 for (size_t n = 0; n < map_state.size(); ++n)
161 qmap_p[n] = map_state[n];
162 }
163 delete obj;
164
165 return;
166 }