Finished release 0.2.4
[libdai.git] / src / matlab / dai.cpp
index 5eb9ae9..5ef330d 100644 (file)
@@ -1,38 +1,20 @@
-/*  Copyright (C) 2006-2008  Joris Mooij  [joris dot mooij at tuebingen dot mpg dot de]
-    Radboud University Nijmegen, The Netherlands /
-    Max Planck Institute for Biological Cybernetics, Germany
-
-    This file is part of libDAI.
-
-    libDAI is free software; you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation; either version 2 of the License, or
-    (at your option) any later version.
-
-    libDAI is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with libDAI; if not, write to the Free Software
-    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
-*/
-
-
-/*=================================================================*
- *                                                                 * 
- * This is a MEX-file for MATLAB.                                  *
- *                                                                 * 
- *   [logZ,q,md,qv,qf] = dai(psi,method,opts);                     *
- *                                                                 * 
- *=================================================================*/
+/*  This file is part of libDAI - http://www.libdai.org/
+ *
+ *  libDAI is licensed under the terms of the GNU General Public License version
+ *  2, or (at your option) any later version. libDAI is distributed without any
+ *  warranty. See the file COPYING for more details.
+ *
+ *  Copyright (C) 2006-2010  Joris Mooij  [joris dot mooij at libdai dot org]
+ *  Copyright (C) 2006-2007  Radboud University Nijmegen, The Netherlands
+ */
 
 
 #include <iostream>
 #include <dai/matlab/matlab.h>
 #include "mex.h"
 #include <dai/alldai.h>
+#include <dai/bp.h>
+#include <dai/jtree.h>
 
 
 using namespace std;
@@ -55,17 +37,17 @@ using namespace dai;
 #define MD_OUT          plhs[2]
 #define QV_OUT          plhs[3]
 #define QF_OUT          plhs[4]
+#define QMAP_OUT        plhs[5]
 #define NR_OUT          3
-#define NR_OUT_OPT      2
+#define NR_OUT_OPT      3
 
 
-void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
-{ 
+void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
     size_t buflen;
-    
+
     // Check for proper number of arguments
     if( ((nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT)) || ((nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT)) ) {
-        mexErrMsgTxt("Usage: [logZ,q,md,qv,qf] = dai(psi,method,opts)\n\n"
+        mexErrMsgTxt("Usage: [logZ,q,md,qv,qf,qmap] = dai(psi,method,opts)\n\n"
         "\n"
         "INPUT:  psi        = linear cell array containing the factors \n"
         "                     psi{i} should be a structure with a Member field\n"
@@ -77,9 +59,10 @@ void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
         "        q          = linear cell array containing all final beliefs.\n"
         "        md         = maxdiff (final linf-dist between new and old single node beliefs).\n"
         "        qv         = linear cell array containing all variable beliefs.\n"
-        "        qf         = linear cell array containing all factor beliefs.\n");
-    } 
-    
+        "        qf         = linear cell array containing all factor beliefs.\n"
+        "        qmap       = (V,1) array containing the MAP labeling (only for BP,JTree).\n");
+    }
+
     char *method;
     char *opts;
 
@@ -102,7 +85,7 @@ void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
     ss << opts;
     PropertySet props;
     ss >> props;
-    
+
     // Construct InfAlg object, init and run
     InfAlg *obj = newInfAlg( method, fg, props );
     obj->init();
@@ -119,9 +102,9 @@ void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
     // Hand over results to MATLAB
     LOGZ_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
     *(mxGetPr(LOGZ_OUT)) = logZ;
-    
+
     Q_OUT = Factors2mx(obj->beliefs());
-    
+
     MD_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
     *(mxGetPr(MD_OUT)) = maxdiff;
 
@@ -140,7 +123,28 @@ void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
             qf.push_back( obj->belief( fg.factor(I).vars() ) );
         QF_OUT = Factors2mx( qf );
     }
+
+    if( nlhs >= 6 ) {
+        std::vector<std::size_t> map_state;
+        if( obj->identify() == "BP" ) {
+            BP* obj_bp = dynamic_cast<BP *>(obj);
+            DAI_ASSERT( obj_bp != 0 );
+            map_state = obj_bp->findMaximum();
+        } else if( obj->identify() == "JTREE" ) {
+            JTree* obj_jtree = dynamic_cast<JTree *>(obj);
+            DAI_ASSERT( obj_jtree != 0 );
+            map_state = obj_jtree->findMaximum();
+        } else {
+            mexErrMsgTxt("MAP state assignment works only for BP, JTree.\n");
+            delete obj;
+            return;
+        }
+        QMAP_OUT = mxCreateNumericMatrix(map_state.size(), 1, mxUINT32_CLASS, mxREAL);
+        uint32_T* qmap_p = reinterpret_cast<uint32_T *>(mxGetPr(QMAP_OUT));
+        for (size_t n = 0; n < map_state.size(); ++n)
+            qmap_p[n] = map_state[n];
+    }
     delete obj;
-    
+
     return;
 }