MEX file dai now also returns variable and factor beliefs.
[libdai.git] / src / matlab / dai.cpp
index 149d7b0..1169e35 100644 (file)
@@ -24,7 +24,7 @@
  *                                                                 * 
  * This is a MEX-file for MATLAB.                                  *
  *                                                                 * 
- *   [logZ,q,md] = dai(psi,method,opts);                           *
+ *   [logZ,q,md,qv,qf] = dai(psi,method,opts);                     *
  *                                                                 * 
  *=================================================================*/
 
@@ -45,6 +45,7 @@ using namespace dai;
 #define METHOD_IN       prhs[1]
 #define OPTS_IN         prhs[2]
 #define NR_IN           3
+#define NR_IN_OPT       0
 
 
 /* Output Arguments */
@@ -52,16 +53,19 @@ using namespace dai;
 #define LOGZ_OUT        plhs[0]
 #define Q_OUT           plhs[1]
 #define MD_OUT          plhs[2]
+#define QV_OUT          plhs[3]
+#define QF_OUT          plhs[4]
 #define NR_OUT          3
+#define NR_OUT_OPT      2
 
 
 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
 { 
     size_t buflen;
     
-    /* Check for proper number of arguments */
-    if( (nrhs != NR_IN) || (nlhs != NR_OUT) ) { 
-        mexErrMsgTxt("Usage: [logZ,q,md] = dai(psi,method,opts)\n\n"
+    // Check for proper number of arguments
+    if( ((nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT)) || ((nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT)) ) {
+        mexErrMsgTxt("Usage: [logZ,q,md,qv,qf] = dai(psi,method,opts)\n\n"
         "\n"
         "INPUT:  psi        = linear cell array containing the factors \n"
         "                     psi{i} should be a structure with a Member field\n"
@@ -71,7 +75,9 @@ void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
         "\n"
         "OUTPUT: logZ       = approximation of the logarithm of the partition sum.\n"
         "        q          = linear cell array containing all final beliefs.\n"
-        "        md         = maxdiff (final linf-dist between new and old single node beliefs).\n");
+        "        md         = maxdiff (final linf-dist between new and old single node beliefs).\n"
+        "        qv         = linear cell array containing all variable beliefs.\n"
+        "        qf         = linear cell array containing all factor beliefs.\n");
     } 
     
     char *method;
@@ -119,6 +125,21 @@ void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
     MD_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
     *(mxGetPr(MD_OUT)) = maxdiff;
 
+    if( nlhs >= 4 ) {
+        vector<Factor> qv;
+        qv.reserve( fg.nrVars() );
+        for( size_t i = 0; i < fg.nrVars(); i++ )
+            qv.push_back( obj->belief( fg.var(i) ) );
+        QV_OUT = Factors2mx( qv );
+    }
+
+    if( nlhs >= 5 ) {
+        vector<Factor> qf;
+        qf.reserve( fg.nrFactors() );
+        for( size_t I = 0; I < fg.nrFactors(); I++ )
+            qf.push_back( obj->belief( fg.factor(I).vars() ) );
+        QF_OUT = Factors2mx( qf );
+    }
     
     return;
 }