Updated ChangeLog and added some backwards compatibility stuff
[libdai.git] / src / matlab / dai.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <dai/matlab/matlab.h>
#include "mex.h"
#include <dai/alldai.h>
#include <dai/bp.h>
17
18
19 using namespace std;
20 using namespace dai;
21
22
/* Input arguments: positions in prhs[] */

#define PSI_IN      (prhs[0])   /* cell array of factors */
#define METHOD_IN   (prhs[1])   /* name of the inference method */
#define OPTS_IN     (prhs[2])   /* options string */

/* Number of mandatory / optional input arguments */
static const int NR_IN     = 3;
static const int NR_IN_OPT = 0;


/* Output arguments: positions in plhs[] */

#define LOGZ_OUT    (plhs[0])   /* logZ */
#define Q_OUT       (plhs[1])   /* q    */
#define MD_OUT      (plhs[2])   /* md   */
#define QV_OUT      (plhs[3])   /* qv   */
#define QF_OUT      (plhs[4])   /* qf   */
#define QMAP_OUT    (plhs[5])   /* qmap */

/* Number of mandatory / optional output arguments */
static const int NR_OUT     = 3;
static const int NR_OUT_OPT = 3;
42
43
44 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
45 size_t buflen;
46
47 // Check for proper number of arguments
48 if( ((nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT)) || ((nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT)) ) {
49 mexErrMsgTxt("Usage: [logZ,q,md,qv,qf,qmap] = dai(psi,method,opts)\n\n"
50 "\n"
51 "INPUT: psi = linear cell array containing the factors \n"
52 " psi{i} should be a structure with a Member field\n"
53 " and a P field, like a CPTAB).\n"
54 " method = name of the method (see README)\n"
55 " opts = string of options (see README)\n"
56 "\n"
57 "OUTPUT: logZ = approximation of the logarithm of the partition sum.\n"
58 " q = linear cell array containing all final beliefs.\n"
59 " md = maxdiff (final linf-dist between new and old single node beliefs).\n"
60 " qv = linear cell array containing all variable beliefs.\n"
61 " qf = linear cell array containing all factor beliefs.\n"
62 " qmap = (V,1) array containing the MAP labeling (only for BP,JTree).\n");
63 }
64
65 char *method;
66 char *opts;
67
68
69 // Get psi and construct factorgraph
70 vector<Factor> factors = mx2Factors(PSI_IN, 0);
71 FactorGraph fg(factors);
72
73 // Get method
74 buflen = mxGetN( METHOD_IN ) + 1;
75 method = (char *)mxCalloc( buflen, sizeof(char) );
76 mxGetString( METHOD_IN, method, buflen );
77
78 // Get options string
79 buflen = mxGetN( OPTS_IN ) + 1;
80 opts = (char *)mxCalloc( buflen, sizeof(char) );
81 mxGetString( OPTS_IN, opts, buflen );
82 // Convert to options object props
83 stringstream ss;
84 ss << opts;
85 PropertySet props;
86 ss >> props;
87
88 // Construct InfAlg object, init and run
89 InfAlg *obj = newInfAlg( method, fg, props );
90 obj->init();
91 obj->run();
92
93
94 // Save logZ
95 double logZ = obj->logZ();
96
97 // Save maxdiff
98 double maxdiff = obj->maxDiff();
99
100
101 // Hand over results to MATLAB
102 LOGZ_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
103 *(mxGetPr(LOGZ_OUT)) = logZ;
104
105 Q_OUT = Factors2mx(obj->beliefs());
106
107 MD_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
108 *(mxGetPr(MD_OUT)) = maxdiff;
109
110 if( nlhs >= 4 ) {
111 vector<Factor> qv;
112 qv.reserve( fg.nrVars() );
113 for( size_t i = 0; i < fg.nrVars(); i++ )
114 qv.push_back( obj->belief( fg.var(i) ) );
115 QV_OUT = Factors2mx( qv );
116 }
117
118 if( nlhs >= 5 ) {
119 vector<Factor> qf;
120 qf.reserve( fg.nrFactors() );
121 for( size_t I = 0; I < fg.nrFactors(); I++ )
122 qf.push_back( obj->belief( fg.factor(I).vars() ) );
123 QF_OUT = Factors2mx( qf );
124 }
125
126 if( nlhs >= 6 ) {
127 std::vector<std::size_t> map_state;
128 if( obj->identify() == "BP" ) {
129 BP* obj_bp = dynamic_cast<BP *>(obj);
130 DAI_ASSERT( obj_bp != 0 );
131 map_state = obj_bp->findMaximum();
132 } else if( obj->identify() == "JTREE" ) {
133 JTree* obj_jtree = dynamic_cast<JTree *>(obj);
134 DAI_ASSERT( obj_jtree != 0 );
135 map_state = obj_jtree->findMaximum();
136 } else {
137 mexErrMsgTxt("MAP state assignment works only for BP, JTree.\n");
138 delete obj;
139 return;
140 }
141 QMAP_OUT = mxCreateNumericMatrix(map_state.size(), 1, mxUINT32_CLASS, mxREAL);
142 uint32_T* qmap_p = reinterpret_cast<uint32_T *>(mxGetPr(QMAP_OUT));
143 for (size_t n = 0; n < map_state.size(); ++n)
144 qmap_p[n] = map_state[n];
145 }
146 delete obj;
147
148 return;
149 }