/*  This file is part of libDAI - http://www.libdai.org/
 *
 *  libDAI is licensed under the terms of the GNU General Public License version
 *  2, or (at your option) any later version. libDAI is distributed without any
 *  warranty. See the file COPYING for more details.
 *
 *  Copyright (C) 2006-2009  Joris Mooij  [joris dot mooij at libdai dot org]
 *  Copyright (C) 2006-2007  Radboud University Nijmegen, The Netherlands
 */
13 #include <dai/matlab/matlab.h>
15 #include <dai/alldai.h>
/* Input Arguments */

// Positional MEX inputs of `dai(psi, method, opts)`; each macro names one
// slot of the `prhs` array passed to mexFunction.
#define PSI_IN      prhs[0]   // linear cell array of factors (Member + P fields)
#define METHOD_IN   prhs[1]   // name of the inference method (string)
#define OPTS_IN     prhs[2]   // option string for the method

/* Output Arguments */

// Positional MEX outputs of `[logZ,q,md,qv,qf,qmap] = dai(...)`; each macro
// names one slot of the `plhs` array, in the order given by the usage string.
#define LOGZ_OUT    plhs[0]   // approximate log partition sum (1x1 double)
#define Q_OUT       plhs[1]   // cell array of all final beliefs
#define MD_OUT      plhs[2]   // maxdiff of the final iteration (1x1 double)
#define QV_OUT      plhs[3]   // cell array of single-variable beliefs
#define QF_OUT      plhs[4]   // cell array of factor beliefs
#define QMAP_OUT    plhs[5]   // (V,1) uint32 MAP labeling (BP only)

// NOTE(review): NR_IN, NR_IN_OPT, NR_OUT and NR_OUT_OPT are used by the
// argument-count check in mexFunction but are not defined in this view --
// confirm where they are provided.
44 void mexFunction( int nlhs
, mxArray
*plhs
[], int nrhs
, const mxArray
*prhs
[] ) {
47 // Check for proper number of arguments
48 if( ((nrhs
< NR_IN
) || (nrhs
> NR_IN
+ NR_IN_OPT
)) || ((nlhs
< NR_OUT
) || (nlhs
> NR_OUT
+ NR_OUT_OPT
)) ) {
49 mexErrMsgTxt("Usage: [logZ,q,md,qv,qf,qmap] = dai(psi,method,opts)\n\n"
51 "INPUT: psi = linear cell array containing the factors \n"
52 " psi{i} should be a structure with a Member field\n"
53 " and a P field, like a CPTAB).\n"
54 " method = name of the method (see README)\n"
55 " opts = string of options (see README)\n"
57 "OUTPUT: logZ = approximation of the logarithm of the partition sum.\n"
58 " q = linear cell array containing all final beliefs.\n"
59 " md = maxdiff (final linf-dist between new and old single node beliefs).\n"
60 " qv = linear cell array containing all variable beliefs.\n"
61 " qf = linear cell array containing all factor beliefs.\n"
62 " qmap = (V,1) array containing the MAP labeling (only for BP).\n");
69 // Get psi and construct factorgraph
70 vector
<Factor
> factors
= mx2Factors(PSI_IN
, 0);
71 FactorGraph
fg(factors
);
74 buflen
= mxGetN( METHOD_IN
) + 1;
75 method
= (char *)mxCalloc( buflen
, sizeof(char) );
76 mxGetString( METHOD_IN
, method
, buflen
);
79 buflen
= mxGetN( OPTS_IN
) + 1;
80 opts
= (char *)mxCalloc( buflen
, sizeof(char) );
81 mxGetString( OPTS_IN
, opts
, buflen
);
82 // Convert to options object props
88 // Construct InfAlg object, init and run
89 InfAlg
*obj
= newInfAlg( method
, fg
, props
);
95 double logZ
= obj
->logZ();
98 double maxdiff
= obj
->maxDiff();
101 // Hand over results to MATLAB
102 LOGZ_OUT
= mxCreateDoubleMatrix(1,1,mxREAL
);
103 *(mxGetPr(LOGZ_OUT
)) = logZ
;
105 Q_OUT
= Factors2mx(obj
->beliefs());
107 MD_OUT
= mxCreateDoubleMatrix(1,1,mxREAL
);
108 *(mxGetPr(MD_OUT
)) = maxdiff
;
112 qv
.reserve( fg
.nrVars() );
113 for( size_t i
= 0; i
< fg
.nrVars(); i
++ )
114 qv
.push_back( obj
->belief( fg
.var(i
) ) );
115 QV_OUT
= Factors2mx( qv
);
120 qf
.reserve( fg
.nrFactors() );
121 for( size_t I
= 0; I
< fg
.nrFactors(); I
++ )
122 qf
.push_back( obj
->belief( fg
.factor(I
).vars() ) );
123 QF_OUT
= Factors2mx( qf
);
127 BP
* obj_bp
= dynamic_cast<BP
*>(obj
);
129 mexErrMsgTxt("MAP state assignment works only for BP.\n");
134 std::vector
<std::size_t> map_state
= obj_bp
->findMaximum();
135 QMAP_OUT
= mxCreateNumericMatrix(map_state
.size(), 1, mxUINT32_CLASS
, mxREAL
);
136 uint32_T
* qmap_p
= reinterpret_cast<uint32_T
*>(mxGetPr(QMAP_OUT
));
137 for (size_t n
= 0; n
< map_state
.size(); ++n
)
138 qmap_p
[n
] = map_state
[n
];