Fixed bug in findMaximum(): inconsistent max-beliefs are now detected,
[libdai.git] / src / matlab / dai.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 * Copyright (C) 2009 Sebastian Nowozin
10 */
11
12
13 #include <iostream>
14 #include <dai/matlab/matlab.h>
15 #include "mex.h"
16 #include <dai/alldai.h>
17 #include <dai/bp.h>
18 #include <dai/jtree.h>
19
20
21 using namespace std;
22 using namespace dai;
23
24
25 /* Input Arguments */
/* Symbolic names for the right-hand-side (input) arguments of mexFunction.
 * They expand to entries of the prhs array, so they are only usable inside
 * mexFunction, where prhs is in scope. */
26
/* psi: passed to mx2Factors() to build the factors of the factor graph */
27 #define PSI_IN prhs[0]
/* method: read with mxGetString() and handed to newInfAlg() as the algorithm name */
28 #define METHOD_IN prhs[1]
/* opts: read with mxGetString() and parsed into a PropertySet */
29 #define OPTS_IN prhs[2]
/* Number of required inputs; mexFunction rejects calls with nrhs < NR_IN */
30 #define NR_IN 3
/* Number of additional optional inputs; mexFunction rejects nrhs > NR_IN + NR_IN_OPT */
31 #define NR_IN_OPT 0
32
33
34 /* Output Arguments */
/* Symbolic names for the left-hand-side (output) arguments of mexFunction.
 * They expand to entries of the plhs array, so they are only usable inside
 * mexFunction, where plhs is in scope. */
35
/* logZ: 1x1 double holding the value returned by InfAlg::logZ() */
36 #define LOGZ_OUT plhs[0]
/* q: all beliefs (InfAlg::beliefs()) converted with Factors2mx() */
37 #define Q_OUT plhs[1]
/* md: 1x1 double holding the value returned by InfAlg::maxDiff() */
38 #define MD_OUT plhs[2]
/* qv: one belief per variable of the factor graph; only filled when nlhs >= 4 */
39 #define QV_OUT plhs[3]
/* qf: one belief per factor of the factor graph; only filled when nlhs >= 5 */
40 #define QF_OUT plhs[4]
/* qmap: uint32 column vector with the result of InfAlg::findMaximum(); only filled when nlhs >= 6 */
41 #define QMAP_OUT plhs[5]
/* Number of required outputs; mexFunction rejects calls with nlhs < NR_OUT */
42 #define NR_OUT 3
/* Number of additional optional outputs; mexFunction rejects nlhs > NR_OUT + NR_OUT_OPT */
43 #define NR_OUT_OPT 3
44
45
46 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
47 size_t buflen;
48
49 // Check for proper number of arguments
50 if( ((nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT)) || ((nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT)) ) {
51 mexErrMsgTxt("Usage: [logZ,q,md,qv,qf,qmap] = dai(psi,method,opts)\n\n"
52 "\n"
53 "INPUT: psi = linear cell array containing the factors\n"
54 " (psi{i} should be a structure with a Member field\n"
55 " and a P field).\n"
56 " method = name of the method\n"
57 " opts = string of options\n"
58 "\n"
59 "OUTPUT: logZ = approximation of the logarithm of the partition sum.\n"
60 " q = linear cell array containing all final beliefs.\n"
61 " md = maxdiff (final linf-dist between new and old single node beliefs).\n"
62 " qv = linear cell array containing all variable beliefs.\n"
63 " qf = linear cell array containing all factor beliefs.\n"
64 " qmap = linear array containing the MAP state (only for BP,JTree).\n");
65 }
66
67 char *method;
68 char *opts;
69
70
71 // Get psi and construct factorgraph
72 vector<Factor> factors = mx2Factors(PSI_IN, 0);
73 FactorGraph fg(factors);
74
75 // Get method
76 buflen = mxGetN( METHOD_IN ) + 1;
77 method = (char *)mxCalloc( buflen, sizeof(char) );
78 mxGetString( METHOD_IN, method, buflen );
79
80 // Get options string
81 buflen = mxGetN( OPTS_IN ) + 1;
82 opts = (char *)mxCalloc( buflen, sizeof(char) );
83 mxGetString( OPTS_IN, opts, buflen );
84 // Convert to options object props
85 stringstream ss;
86 ss << opts;
87 PropertySet props;
88 ss >> props;
89
90 // Construct InfAlg object, init and run
91 InfAlg *obj = newInfAlg( method, fg, props );
92 obj->init();
93 obj->run();
94
95 // Save logZ
96 double logZ = obj->logZ();
97
98 // Save maxdiff
99 double maxdiff = obj->maxDiff();
100
101 // Hand over results to MATLAB
102 LOGZ_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
103 *(mxGetPr(LOGZ_OUT)) = logZ;
104
105 Q_OUT = Factors2mx(obj->beliefs());
106
107 MD_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
108 *(mxGetPr(MD_OUT)) = maxdiff;
109
110 if( nlhs >= 4 ) {
111 vector<Factor> qv;
112 qv.reserve( fg.nrVars() );
113 for( size_t i = 0; i < fg.nrVars(); i++ )
114 qv.push_back( obj->belief( fg.var(i) ) );
115 QV_OUT = Factors2mx( qv );
116 }
117
118 if( nlhs >= 5 ) {
119 vector<Factor> qf;
120 qf.reserve( fg.nrFactors() );
121 for( size_t I = 0; I < fg.nrFactors(); I++ )
122 qf.push_back( obj->belief( fg.factor(I).vars() ) );
123 QF_OUT = Factors2mx( qf );
124 }
125
126 if( nlhs >= 6 ) {
127 std::vector<std::size_t> map_state;
128 bool supported = true;
129 try {
130 map_state = obj->findMaximum();
131 } catch( Exception &e ) {
132 if( e.code() == Exception::NOT_IMPLEMENTED )
133 supported = false;
134 else
135 throw;
136 }
137 if( supported ) {
138 QMAP_OUT = mxCreateNumericMatrix(map_state.size(), 1, mxUINT32_CLASS, mxREAL);
139 uint32_T* qmap_p = reinterpret_cast<uint32_T *>(mxGetPr(QMAP_OUT));
140 for (size_t n = 0; n < map_state.size(); ++n)
141 qmap_p[n] = map_state[n];
142 } else {
143 delete obj;
144 mexErrMsgTxt("Calculating a MAP state is not supported by this inference algorithm");
145 }
146 }
147
148 delete obj;
149 return;
150 }