7ad11b84a63f90c7699fde010df0208e20b6c9ae
[libdai.git] / matlab / dai.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 /*=================================================================*
23 * *
24 * This is a MEX-file for MATLAB. *
25 * *
26 * [logZ,q,md] = dai(psi,method,opts); *
27 * *
28 *=================================================================*/
29
30
31 #include <iostream>
32 #include "matlab.h"
33 #include "mex.h"
34 #include "alldai.h"
35
36
37 using namespace std;
38
39
/* Positional arguments of  [logZ,q,md] = dai(psi,method,opts).
 * Each macro expands to an element of the prhs / plhs parameters of
 * mexFunction, so they are only usable inside that function. */

/* Inputs: prhs[0..2] */
#define PSI_IN      (prhs[0])   /* cell array holding the factors */
#define METHOD_IN   (prhs[1])   /* name of the inference method   */
#define OPTS_IN     (prhs[2])   /* option string                  */
#define NR_IN       (3)         /* required number of inputs      */

/* Outputs: plhs[0..2] */
#define LOGZ_OUT    (plhs[0])   /* logZ                           */
#define Q_OUT       (plhs[1])   /* final beliefs                  */
#define MD_OUT      (plhs[2])   /* maxdiff                        */
#define NR_OUT      (3)         /* required number of outputs     */
54
55
56 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
57 {
58 size_t buflen;
59
60 /* Check for proper number of arguments */
61 if( (nrhs != NR_IN) || (nlhs != NR_OUT) ) {
62 mexErrMsgTxt("Usage: [logZ,q,md] = dai(psi,method,opts)\n\n"
63 "\n"
64 "INPUT: psi = linear cell array containing the factors \n"
65 " psi{i} should be a structure with a Member field\n"
66 " and a P field, like a CPTAB).\n"
67 " method = name of the method (see README)\n"
68 " opts = string of options (see README)\n"
69 "\n"
70 "OUTPUT: logZ = approximation of the logarithm of the partition sum.\n"
71 " q = linear cell array containing all final beliefs.\n"
72 " md = maxdiff (final linf-dist between new and old single node beliefs).\n");
73 }
74
75 char *method;
76 char *opts;
77
78
79 // Get psi and construct factorgraph
80 vector<Factor> factors = mx2Factors(PSI_IN, 0);
81 FactorGraph fg(factors);
82 long nr_v = fg.nrVars();
83
84 // Get method
85 buflen = mxGetN( METHOD_IN ) + 1;
86 method = (char *)mxCalloc( buflen, sizeof(char) );
87 mxGetString( METHOD_IN, method, buflen );
88
89 // Get options string
90 buflen = mxGetN( OPTS_IN ) + 1;
91 opts = (char *)mxCalloc( buflen, sizeof(char) );
92 mxGetString( OPTS_IN, opts, buflen );
93 // Convert to options object props
94 stringstream ss;
95 ss << opts;
96 Properties props;
97 ss >> props;
98
99 // Construct InfAlg object, init and run
100 InfAlg *obj = newInfAlg( method, fg, props );
101 obj->init();
102 obj->run();
103
104
105 // Save logZ
106 double logZ = obj->logZ();
107
108 // Save maxdiff
109 double maxdiff = obj->MaxDiff();
110
111
112 // Hand over results to MATLAB
113 LOGZ_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
114 *(mxGetPr(LOGZ_OUT)) = logZ;
115
116 Q_OUT = Factors2mx(obj->beliefs());
117
118 MD_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
119 *(mxGetPr(MD_OUT)) = maxdiff;
120
121
122 return;
123 }