a3f2f536c2d4e394a06a71a597d96a41d4f90d89
[libdai.git] / src / matlab / dai.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <iostream>
10 #include <dai/matlab/matlab.h>
11 #include "mex.h"
12 #include <dai/alldai.h>
13 #include <dai/bp.h>
14 #include <dai/jtree.h>
15
16
17 using namespace std;
18 using namespace dai;
19
20
/* Input Arguments */

// psi: linear cell array containing the factors (each a struct with Member and P fields)
#define PSI_IN prhs[0]
// method: name of the inference method
#define METHOD_IN prhs[1]
// opts: string of options for the inference method
#define OPTS_IN prhs[2]
// Number of mandatory input arguments
#define NR_IN 3
// Number of optional input arguments (nrhs must lie in [NR_IN, NR_IN + NR_IN_OPT])
#define NR_IN_OPT 0


/* Output Arguments */

// logZ: approximation of the logarithm of the partition sum
#define LOGZ_OUT plhs[0]
// q: cell array containing all final beliefs
#define Q_OUT plhs[1]
// md: maxdiff (final linf-dist between new and old single node beliefs)
#define MD_OUT plhs[2]
// qv: cell array containing all variable beliefs (optional)
#define QV_OUT plhs[3]
// qf: cell array containing all factor beliefs (optional)
#define QF_OUT plhs[4]
// qmap: array containing the MAP state (optional)
#define QMAP_OUT plhs[5]
// Number of mandatory output arguments
#define NR_OUT 3
// Number of optional output arguments (nlhs must lie in [NR_OUT, NR_OUT + NR_OUT_OPT])
#define NR_OUT_OPT 3
40
41
42 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
43 size_t buflen;
44
45 // Check for proper number of arguments
46 if( ((nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT)) || ((nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT)) ) {
47 mexErrMsgTxt("Usage: [logZ,q,md,qv,qf,qmap] = dai(psi,method,opts)\n\n"
48 "\n"
49 "INPUT: psi = linear cell array containing the factors\n"
50 " (psi{i} should be a structure with a Member field\n"
51 " and a P field).\n"
52 " method = name of the method\n"
53 " opts = string of options\n"
54 "\n"
55 "OUTPUT: logZ = approximation of the logarithm of the partition sum.\n"
56 " q = linear cell array containing all final beliefs.\n"
57 " md = maxdiff (final linf-dist between new and old single node beliefs).\n"
58 " qv = linear cell array containing all variable beliefs.\n"
59 " qf = linear cell array containing all factor beliefs.\n"
60 " qmap = linear array containing the MAP state (only for BP,JTree).\n");
61 }
62
63 char *method;
64 char *opts;
65
66
67 // Get psi and construct factorgraph
68 vector<Factor> factors = mx2Factors(PSI_IN, 0);
69 FactorGraph fg(factors);
70
71 // Get method
72 buflen = mxGetN( METHOD_IN ) + 1;
73 method = (char *)mxCalloc( buflen, sizeof(char) );
74 mxGetString( METHOD_IN, method, buflen );
75
76 // Get options string
77 buflen = mxGetN( OPTS_IN ) + 1;
78 opts = (char *)mxCalloc( buflen, sizeof(char) );
79 mxGetString( OPTS_IN, opts, buflen );
80 // Convert to options object props
81 stringstream ss;
82 ss << opts;
83 PropertySet props;
84 ss >> props;
85
86 // Construct InfAlg object, init and run
87 InfAlg *obj = newInfAlg( method, fg, props );
88 obj->init();
89 obj->run();
90
91 // Save logZ
92 double logZ = NAN;
93 try {
94 logZ = obj->logZ();
95 }
96 catch( Exception &e ) {
97 if( e.getCode() == Exception::NOT_IMPLEMENTED )
98 mexWarnMsgTxt("Calculating the log-partition function is not supported by this inference algorithm.");
99 else
100 throw;
101 }
102
103 // Save maxdiff
104 double maxdiff = NAN;
105 try {
106 maxdiff = obj->maxDiff();
107 }
108 catch( Exception &e ) {
109 if( e.getCode() == Exception::NOT_IMPLEMENTED )
110 mexWarnMsgTxt("Calculating the max-differences is not supported by this inference algorithm.");
111 else
112 throw;
113 }
114
115 // Hand over results to MATLAB
116 LOGZ_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
117 *(mxGetPr(LOGZ_OUT)) = logZ;
118
119 Q_OUT = Factors2mx(obj->beliefs());
120
121 MD_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
122 *(mxGetPr(MD_OUT)) = maxdiff;
123
124 if( nlhs >= 4 ) {
125 vector<Factor> qv;
126 qv.reserve( fg.nrVars() );
127 for( size_t i = 0; i < fg.nrVars(); i++ )
128 qv.push_back( obj->belief( fg.var(i) ) );
129 QV_OUT = Factors2mx( qv );
130 }
131
132 if( nlhs >= 5 ) {
133 vector<Factor> qf;
134 qf.reserve( fg.nrFactors() );
135 for( size_t I = 0; I < fg.nrFactors(); I++ )
136 qf.push_back( obj->belief( fg.factor(I).vars() ) );
137 QF_OUT = Factors2mx( qf );
138 }
139
140 if( nlhs >= 6 ) {
141 std::vector<std::size_t> map_state;
142 bool supported = true;
143 try {
144 map_state = obj->findMaximum();
145 } catch( Exception &e ) {
146 if( e.getCode() == Exception::NOT_IMPLEMENTED )
147 supported = false;
148 else
149 throw;
150 }
151 if( supported ) {
152 QMAP_OUT = mxCreateNumericMatrix(map_state.size(), 1, mxUINT32_CLASS, mxREAL);
153 uint32_T* qmap_p = reinterpret_cast<uint32_T *>(mxGetPr(QMAP_OUT));
154 for (size_t n = 0; n < map_state.size(); ++n)
155 qmap_p[n] = map_state[n];
156 } else {
157 delete obj;
158 mexErrMsgTxt("Calculating a MAP state is not supported by this inference algorithm.");
159 }
160 }
161
162 delete obj;
163
164 return;
165 }