Merge branch 'joris'
[libdai.git] / src / matlab / dai.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 /*=================================================================*
24 * *
25 * This is a MEX-file for MATLAB. *
26 * *
27 * [logZ,q,md] = dai(psi,method,opts); *
28 * *
29 *=================================================================*/
30
31
32 #include <iostream>
33 #include <dai/matlab/matlab.h>
34 #include "mex.h"
35 #include <dai/alldai.h>
36
37
38 using namespace std;
39 using namespace dai;
40
41
/* MEX input slots (indices into prhs[]) and how many inputs are required */
#define PSI_IN      (prhs[0])
#define METHOD_IN   (prhs[1])
#define OPTS_IN     (prhs[2])
#define NR_IN       (3)

/* MEX output slots (indices into plhs[]) and how many outputs are required */
#define LOGZ_OUT    (plhs[0])
#define Q_OUT       (plhs[1])
#define MD_OUT      (plhs[2])
#define NR_OUT      (3)
56
57
58 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
59 {
60 size_t buflen;
61
62 /* Check for proper number of arguments */
63 if( (nrhs != NR_IN) || (nlhs != NR_OUT) ) {
64 mexErrMsgTxt("Usage: [logZ,q,md] = dai(psi,method,opts)\n\n"
65 "\n"
66 "INPUT: psi = linear cell array containing the factors \n"
67 " psi{i} should be a structure with a Member field\n"
68 " and a P field, like a CPTAB).\n"
69 " method = name of the method (see README)\n"
70 " opts = string of options (see README)\n"
71 "\n"
72 "OUTPUT: logZ = approximation of the logarithm of the partition sum.\n"
73 " q = linear cell array containing all final beliefs.\n"
74 " md = maxdiff (final linf-dist between new and old single node beliefs).\n");
75 }
76
77 char *method;
78 char *opts;
79
80
81 // Get psi and construct factorgraph
82 vector<Factor> factors = mx2Factors(PSI_IN, 0);
83 FactorGraph fg(factors);
84
85 // Get method
86 buflen = mxGetN( METHOD_IN ) + 1;
87 method = (char *)mxCalloc( buflen, sizeof(char) );
88 mxGetString( METHOD_IN, method, buflen );
89
90 // Get options string
91 buflen = mxGetN( OPTS_IN ) + 1;
92 opts = (char *)mxCalloc( buflen, sizeof(char) );
93 mxGetString( OPTS_IN, opts, buflen );
94 // Convert to options object props
95 stringstream ss;
96 ss << opts;
97 PropertySet props;
98 ss >> props;
99
100 // Construct InfAlg object, init and run
101 InfAlg *obj = newInfAlg( method, fg, props );
102 obj->init();
103 obj->run();
104
105
106 // Save logZ
107 double logZ = obj->logZ();
108
109 // Save maxdiff
110 double maxdiff = obj->maxDiff();
111
112
113 // Hand over results to MATLAB
114 LOGZ_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
115 *(mxGetPr(LOGZ_OUT)) = logZ;
116
117 Q_OUT = Factors2mx(obj->beliefs());
118
119 MD_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
120 *(mxGetPr(MD_OUT)) = maxdiff;
121
122
123 return;
124 }