Improved code in matlab/dai.cpp that tests whether findMaximum is supported
[libdai.git] / src / matlab / dai.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <dai/matlab/matlab.h>
14 #include "mex.h"
15 #include <dai/alldai.h>
16 #include <dai/bp.h>
17 #include <dai/jtree.h>
18
19
20 using namespace std;
21 using namespace dai;
22
23
/* Aliases for the positional input arguments of dai(psi,method,opts) */
#define PSI_IN      prhs[0]
#define METHOD_IN   prhs[1]
#define OPTS_IN     prhs[2]

/* Aliases for the positional output arguments [logZ,q,md,qv,qf,qmap] */
#define LOGZ_OUT    plhs[0]
#define Q_OUT       plhs[1]
#define MD_OUT      plhs[2]
#define QV_OUT      plhs[3]
#define QF_OUT      plhs[4]
#define QMAP_OUT    plhs[5]

/* Argument counts: NR_IN / NR_OUT arguments are required, and up to
 * NR_IN_OPT / NR_OUT_OPT further arguments may be supplied on top of them. */
static const int NR_IN      = 3;
static const int NR_IN_OPT  = 0;
static const int NR_OUT     = 3;
static const int NR_OUT_OPT = 3;
43
44
45 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
46 size_t buflen;
47
48 // Check for proper number of arguments
49 if( ((nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT)) || ((nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT)) ) {
50 mexErrMsgTxt("Usage: [logZ,q,md,qv,qf,qmap] = dai(psi,method,opts)\n\n"
51 "\n"
52 "INPUT: psi = linear cell array containing the factors\n"
53 " (psi{i} should be a structure with a Member field\n"
54 " and a P field).\n"
55 " method = name of the method\n"
56 " opts = string of options\n"
57 "\n"
58 "OUTPUT: logZ = approximation of the logarithm of the partition sum.\n"
59 " q = linear cell array containing all final beliefs.\n"
60 " md = maxdiff (final linf-dist between new and old single node beliefs).\n"
61 " qv = linear cell array containing all variable beliefs.\n"
62 " qf = linear cell array containing all factor beliefs.\n"
63 " qmap = linear array containing the MAP state (only for BP,JTree).\n");
64 }
65
66 char *method;
67 char *opts;
68
69
70 // Get psi and construct factorgraph
71 vector<Factor> factors = mx2Factors(PSI_IN, 0);
72 FactorGraph fg(factors);
73
74 // Get method
75 buflen = mxGetN( METHOD_IN ) + 1;
76 method = (char *)mxCalloc( buflen, sizeof(char) );
77 mxGetString( METHOD_IN, method, buflen );
78
79 // Get options string
80 buflen = mxGetN( OPTS_IN ) + 1;
81 opts = (char *)mxCalloc( buflen, sizeof(char) );
82 mxGetString( OPTS_IN, opts, buflen );
83 // Convert to options object props
84 stringstream ss;
85 ss << opts;
86 PropertySet props;
87 ss >> props;
88
89 // Construct InfAlg object, init and run
90 InfAlg *obj = newInfAlg( method, fg, props );
91 obj->init();
92 obj->run();
93
94 // Save logZ
95 double logZ = obj->logZ();
96
97 // Save maxdiff
98 double maxdiff = obj->maxDiff();
99
100 // Hand over results to MATLAB
101 LOGZ_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
102 *(mxGetPr(LOGZ_OUT)) = logZ;
103
104 Q_OUT = Factors2mx(obj->beliefs());
105
106 MD_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
107 *(mxGetPr(MD_OUT)) = maxdiff;
108
109 if( nlhs >= 4 ) {
110 vector<Factor> qv;
111 qv.reserve( fg.nrVars() );
112 for( size_t i = 0; i < fg.nrVars(); i++ )
113 qv.push_back( obj->belief( fg.var(i) ) );
114 QV_OUT = Factors2mx( qv );
115 }
116
117 if( nlhs >= 5 ) {
118 vector<Factor> qf;
119 qf.reserve( fg.nrFactors() );
120 for( size_t I = 0; I < fg.nrFactors(); I++ )
121 qf.push_back( obj->belief( fg.factor(I).vars() ) );
122 QF_OUT = Factors2mx( qf );
123 }
124
125 if( nlhs >= 6 ) {
126 std::vector<std::size_t> map_state;
127 bool supported = true;
128 try {
129 map_state = obj->findMaximum();
130 } catch( Exception &e ) {
131 if( e.code() == Exception::NOT_IMPLEMENTED )
132 supported = false;
133 else
134 throw;
135 }
136 if( supported ) {
137 QMAP_OUT = mxCreateNumericMatrix(map_state.size(), 1, mxUINT32_CLASS, mxREAL);
138 uint32_T* qmap_p = reinterpret_cast<uint32_T *>(mxGetPr(QMAP_OUT));
139 for (size_t n = 0; n < map_state.size(); ++n)
140 qmap_p[n] = map_state[n];
141 } else {
142 delete obj;
143 mexErrMsgTxt("Calculating a MAP state is not supported by this inference algorithm");
144 }
145 }
146
147 delete obj;
148 return;
149 }