Finished release 0.2.4
[libdai.git] / src / matlab / dai.cpp
index 766c89e..5ef330d 100644 (file)
@@ -4,7 +4,7 @@
  *  2, or (at your option) any later version. libDAI is distributed without any
  *  warranty. See the file COPYING for more details.
  *
- *  Copyright (C) 2006-2009  Joris Mooij  [joris dot mooij at libdai dot org]
+ *  Copyright (C) 2006-2010  Joris Mooij  [joris dot mooij at libdai dot org]
  *  Copyright (C) 2006-2007  Radboud University Nijmegen, The Netherlands
  */
 
@@ -14,6 +14,7 @@
 #include "mex.h"
 #include <dai/alldai.h>
 #include <dai/bp.h>
+#include <dai/jtree.h>
 
 
 using namespace std;
@@ -59,7 +60,7 @@ void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
         "        md         = maxdiff (final linf-dist between new and old single node beliefs).\n"
         "        qv         = linear cell array containing all variable beliefs.\n"
         "        qf         = linear cell array containing all factor beliefs.\n"
-        "        qmap       = (V,1) array containing the MAP labeling (only for BP).\n");
+        "        qmap       = (V,1) array containing the MAP labeling (only for BP,JTree).\n");
     }
 
     char *method;
@@ -124,14 +125,20 @@ void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
     }
 
     if( nlhs >= 6 ) {
-        BP* obj_bp = dynamic_cast<BP *>(obj);
-        if (obj_bp == 0) {
-            mexErrMsgTxt("MAP state assignment works only for BP.\n");
+        std::vector<std::size_t> map_state;
+        if( obj->identify() == "BP" ) {
+            BP* obj_bp = dynamic_cast<BP *>(obj);
+            DAI_ASSERT( obj_bp != 0 );
+            map_state = obj_bp->findMaximum();
+        } else if( obj->identify() == "JTREE" ) {
+            JTree* obj_jtree = dynamic_cast<JTree *>(obj);
+            DAI_ASSERT( obj_jtree != 0 );
+            map_state = obj_jtree->findMaximum();
+        } else {
+            mexErrMsgTxt("MAP state assignment works only for BP, JTree.\n");
             delete obj;
-
             return;
         }
-        std::vector<std::size_t> map_state = obj_bp->findMaximum();
         QMAP_OUT = mxCreateNumericMatrix(map_state.size(), 1, mxUINT32_CLASS, mxREAL);
         uint32_T* qmap_p = reinterpret_cast<uint32_T *>(mxGetPr(QMAP_OUT));
         for (size_t n = 0; n < map_state.size(); ++n)