Improved WeightedGraph code and added unit tests
[libdai.git] / src / matlab / dai.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2010 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <typeinfo>
#include <vector>
#include <dai/matlab/matlab.h>
#include "mex.h"
#include <dai/alldai.h>
#include <dai/bp.h>
#include <dai/jtree.h>
18
19
20 using namespace std;
21 using namespace dai;
22
23
// ---------------------------------------------------------------------------
// Input arguments of  dai(psi, method, opts)
// The accessors are macros (not constants) because they have to expand to the
// local 'prhs' parameter of whichever function uses them.
// ---------------------------------------------------------------------------
#define PSI_IN      prhs[0]   // cell array of factors
#define METHOD_IN   prhs[1]   // name of the inference method
#define OPTS_IN     prhs[2]   // options string
static const int NR_IN     = 3;   // number of mandatory inputs
static const int NR_IN_OPT = 0;   // number of optional inputs


// ---------------------------------------------------------------------------
// Output arguments of  [logZ, q, md, qv, qf, qmap] = dai(...)
// Same remark as above: these expand to the local 'plhs' parameter.
// ---------------------------------------------------------------------------
#define LOGZ_OUT    plhs[0]   // log partition sum
#define Q_OUT       plhs[1]   // final beliefs
#define MD_OUT      plhs[2]   // maxdiff
#define QV_OUT      plhs[3]   // variable beliefs
#define QF_OUT      plhs[4]   // factor beliefs
#define QMAP_OUT    plhs[5]   // MAP labeling
static const int NR_OUT     = 3;  // number of mandatory outputs
static const int NR_OUT_OPT = 3;  // number of optional outputs
43
44
45 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
46 size_t buflen;
47
48 // Check for proper number of arguments
49 if( ((nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT)) || ((nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT)) ) {
50 mexErrMsgTxt("Usage: [logZ,q,md,qv,qf,qmap] = dai(psi,method,opts)\n\n"
51 "\n"
52 "INPUT: psi = linear cell array containing the factors \n"
53 " psi{i} should be a structure with a Member field\n"
54 " and a P field, like a CPTAB).\n"
55 " method = name of the method (see README)\n"
56 " opts = string of options (see README)\n"
57 "\n"
58 "OUTPUT: logZ = approximation of the logarithm of the partition sum.\n"
59 " q = linear cell array containing all final beliefs.\n"
60 " md = maxdiff (final linf-dist between new and old single node beliefs).\n"
61 " qv = linear cell array containing all variable beliefs.\n"
62 " qf = linear cell array containing all factor beliefs.\n"
63 " qmap = (V,1) array containing the MAP labeling (only for BP,JTree).\n");
64 }
65
66 char *method;
67 char *opts;
68
69
70 // Get psi and construct factorgraph
71 vector<Factor> factors = mx2Factors(PSI_IN, 0);
72 FactorGraph fg(factors);
73
74 // Get method
75 buflen = mxGetN( METHOD_IN ) + 1;
76 method = (char *)mxCalloc( buflen, sizeof(char) );
77 mxGetString( METHOD_IN, method, buflen );
78
79 // Get options string
80 buflen = mxGetN( OPTS_IN ) + 1;
81 opts = (char *)mxCalloc( buflen, sizeof(char) );
82 mxGetString( OPTS_IN, opts, buflen );
83 // Convert to options object props
84 stringstream ss;
85 ss << opts;
86 PropertySet props;
87 ss >> props;
88
89 // Construct InfAlg object, init and run
90 InfAlg *obj = newInfAlg( method, fg, props );
91 obj->init();
92 obj->run();
93
94
95 // Save logZ
96 double logZ = obj->logZ();
97
98 // Save maxdiff
99 double maxdiff = obj->maxDiff();
100
101
102 // Hand over results to MATLAB
103 LOGZ_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
104 *(mxGetPr(LOGZ_OUT)) = logZ;
105
106 Q_OUT = Factors2mx(obj->beliefs());
107
108 MD_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
109 *(mxGetPr(MD_OUT)) = maxdiff;
110
111 if( nlhs >= 4 ) {
112 vector<Factor> qv;
113 qv.reserve( fg.nrVars() );
114 for( size_t i = 0; i < fg.nrVars(); i++ )
115 qv.push_back( obj->belief( fg.var(i) ) );
116 QV_OUT = Factors2mx( qv );
117 }
118
119 if( nlhs >= 5 ) {
120 vector<Factor> qf;
121 qf.reserve( fg.nrFactors() );
122 for( size_t I = 0; I < fg.nrFactors(); I++ )
123 qf.push_back( obj->belief( fg.factor(I).vars() ) );
124 QF_OUT = Factors2mx( qf );
125 }
126
127 if( nlhs >= 6 ) {
128 std::vector<std::size_t> map_state;
129 if( obj->identify() == "BP" ) {
130 BP* obj_bp = dynamic_cast<BP *>(obj);
131 DAI_ASSERT( obj_bp != 0 );
132 map_state = obj_bp->findMaximum();
133 } else if( obj->identify() == "JTREE" ) {
134 JTree* obj_jtree = dynamic_cast<JTree *>(obj);
135 DAI_ASSERT( obj_jtree != 0 );
136 map_state = obj_jtree->findMaximum();
137 } else {
138 mexErrMsgTxt("MAP state assignment works only for BP, JTree.\n");
139 delete obj;
140 return;
141 }
142 QMAP_OUT = mxCreateNumericMatrix(map_state.size(), 1, mxUINT32_CLASS, mxREAL);
143 uint32_T* qmap_p = reinterpret_cast<uint32_T *>(mxGetPr(QMAP_OUT));
144 for (size_t n = 0; n < map_state.size(); ++n)
145 qmap_p[n] = map_state[n];
146 }
147 delete obj;
148
149 return;
150 }