Added special MatLab interface for junction tree algorithm (dai_jtree)
authorJoris Mooij <j.mooij@cs.ru.nl>
Mon, 17 Sep 2012 10:10:21 +0000 (12:10 +0200)
committerJoris Mooij <j.mooij@cs.ru.nl>
Mon, 17 Sep 2012 10:10:21 +0000 (12:10 +0200)
ChangeLog
Makefile
matlab/dai_jtree.m [new file with mode: 0644]
src/matlab/dai_jtree.cpp [new file with mode: 0644]

index 6d2d675..c2c0b22 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,5 +1,6 @@
 git master
 ----------
+* Added special MatLab interface for junction tree algorithm (dai_jtree)
 * Added VC10 build files, kindly provided by Sameh Khamis
 * Fixed several Win64 related bugs (found by Sameh Khamis): 
   - define NAN
index 3f911d0..576488d 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -129,7 +129,7 @@ ifdef WITH_CIMG
 endif
 examples : $(EXAMPLES)
 
-matlabs : matlab/dai$(ME) matlab/dai_readfg$(ME) matlab/dai_writefg$(ME) matlab/dai_potstrength$(ME)
+matlabs : matlab/dai$(ME) matlab/dai_readfg$(ME) matlab/dai_writefg$(ME) matlab/dai_potstrength$(ME) matlab/dai_jtree$(ME)
 
 unittests : tests/unit/var_test$(EE) tests/unit/smallset_test$(EE) tests/unit/varset_test$(EE) tests/unit/graph_test$(EE) tests/unit/dag_test$(EE) tests/unit/bipgraph_test$(EE) tests/unit/weightedgraph_test$(EE) tests/unit/enum_test$(EE) tests/unit/enum_test$(EE) tests/unit/util_test$(EE) tests/unit/exceptions_test$(EE) tests/unit/properties_test$(EE) tests/unit/index_test$(EE) tests/unit/prob_test$(EE) tests/unit/factor_test$(EE) tests/unit/factorgraph_test$(EE) tests/unit/clustergraph_test$(EE) tests/unit/regiongraph_test$(EE) tests/unit/daialg_test$(EE) tests/unit/alldai_test$(EE)
        @echo 'Running unit tests...'
@@ -250,6 +250,9 @@ matlab/dai_writefg$(ME) : $(SRC)/matlab/dai_writefg.cpp $(HEADERS) $(SRC)/matlab
 matlab/dai_potstrength$(ME) : $(SRC)/matlab/dai_potstrength.cpp $(HEADERS) $(SRC)/matlab/matlab.cpp $(SRC)/exceptions.cpp
        $(MEX) -output $@ $< $(SRC)/matlab/matlab.cpp $(SRC)/exceptions.cpp
 
+matlab/dai_jtree$(ME) : $(SRC)/matlab/dai_jtree.cpp $(HEADERS) $(SOURCES) $(SRC)/matlab/matlab.cpp
+       $(MEX) -output $@ $< $(SRC)/matlab/matlab.cpp $(SOURCES)
+
 
 # UTILS
 ########
diff --git a/matlab/dai_jtree.m b/matlab/dai_jtree.m
new file mode 100644 (file)
index 0000000..ceb9c11
--- /dev/null
@@ -0,0 +1,15 @@
+% [logZ,q,qv,qf,qmap,margs] = dai_jtree(psi,varsets,opts)
+%
+% INPUT:  psi        = linear cell array containing the factors
+%                      (psi{i} should be a structure with a Member field
+%                      and a P field).
+%         varsets    = linear cell array containing varsets for which marginals
+%                      are requested.
+%         opts       = string of options.
+% 
+% OUTPUT: logZ       = logarithm of the partition sum.
+%         q          = linear cell array containing all calculated marginals.
+%         qv         = linear cell array containing all variable marginals.
+%         qf         = linear cell array containing all factor marginals.
+%         qmap       = linear array containing the MAP state.
+%         margs      = linear cell array containing all requested marginals.
diff --git a/src/matlab/dai_jtree.cpp b/src/matlab/dai_jtree.cpp
new file mode 100644 (file)
index 0000000..6e59648
--- /dev/null
@@ -0,0 +1,231 @@
+/*  This file is part of libDAI - http://www.libdai.org/
+ *
+ *  Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
+ */
+
+
+#include <iostream>
+#include <dai/matlab/matlab.h>
+#include "mex.h"
+#include <dai/jtree.h>
+
+
+using namespace std;
+using namespace dai;
+
+
+/* Convert cell vector of Matlab sets to vector<VarSet> */
+vector<VarSet> mx2VarSets(const mxArray *vs, const FactorGraph &fg, long verbose, vector<Permute> &perms) {
+    vector<VarSet> varsets;
+
+    int n1 = mxGetM(vs);
+    int n2 = mxGetN(vs);
+    if( n2 != 1 && n1 != 1 )
+        mexErrMsgTxt("varsets should be a Nx1 or 1xN cell matrix.");
+    size_t nr_vs = n1;
+    if( n1 == 1 )
+        nr_vs = n2;
+
+    // interpret vs, linear cell array of varsets
+    varsets.reserve( nr_vs );
+    perms.clear();
+    perms.reserve( nr_vs );
+    for( size_t cellind = 0; cellind < nr_vs; cellind++ ) {
+        if( verbose >= 3 )
+            cerr << "reading varset " << cellind << ": " << endl;
+        mxArray *cell = mxGetCell(vs, cellind);
+        if( verbose >= 3 )
+            cerr << "  got cell " << endl;
+        size_t nr_mem = mxGetN(cell);
+        if( verbose >= 3 )
+            cerr << "  number members: " << nr_mem << endl;
+        double *members = mxGetPr(cell);
+        if( verbose >= 3 )
+            cerr << "  got them! " << endl;
+
+        // add variables
+        VarSet vsvars;
+        if( verbose >= 3 )
+            cerr << "  vars: ";
+        vector<long> labels(nr_mem,0);
+        vector<size_t> dims(nr_mem,0);
+        for( size_t mi = 0; mi < nr_mem; mi++ ) {
+            labels[mi] = (long)members[mi];
+            dims[mi] = fg.var(labels[mi]).states();
+            if( verbose >= 3 )
+                cerr << labels[mi] << " ";
+            vsvars.insert( fg.var(labels[mi]) );
+        }
+        if( verbose >= 3 )
+            cerr << endl;
+        DAI_ASSERT( nr_mem == vsvars.size() );
+        varsets.push_back(vsvars);
+
+        // calculate permutation matrix
+        vector<size_t> perm(nr_mem,0);
+        VarSet::iterator j = vsvars.begin();
+        for( size_t mi = 0; mi < nr_mem; mi++,j++ ) {
+            long gezocht = j->label();
+            vector<long>::iterator piet = find(labels.begin(),labels.end(),gezocht);
+            perm[mi] = piet - labels.begin();
+        }
+        if( verbose >= 3 ) {
+            cerr << endl << "  perm: ";
+            for( vector<size_t>::iterator r=perm.begin(); r!=perm.end(); r++ )
+                cerr << *r << " ";
+            cerr << endl;
+        }
+        // create Permute object
+        vector<size_t> di(nr_mem,0);
+        size_t prod = 1;
+        for( size_t k = 0; k < nr_mem; k++ ) {
+            di[k] = dims[k];
+            prod *= dims[k];
+        }
+        Permute permindex( di, perm );
+        perms.push_back( permindex );
+    }
+
+    if( verbose >= 3 ) {
+        for(vector<VarSet>::const_iterator I=varsets.begin(); I!=varsets.end(); I++ )
+            cerr << *I << endl;
+    }
+
+    return( varsets );
+}
+
+
+/* Input Arguments */
+
+#define PSI_IN          prhs[0]
+#define VARSETS_IN      prhs[1]
+#define OPTS_IN         prhs[2]
+#define NR_IN           3
+#define NR_IN_OPT       0
+
+
+/* Output Arguments */
+
+#define LOGZ_OUT        plhs[0]
+#define Q_OUT           plhs[1]
+#define QV_OUT          plhs[2]
+#define QF_OUT          plhs[3]
+#define QMAP_OUT        plhs[4]
+#define MARGS_OUT       plhs[5]
+#define NR_OUT          3
+#define NR_OUT_OPT      3
+
+
+void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
+    // Check for proper number of arguments
+    if( ((nrhs < NR_IN) || (nrhs > NR_IN + NR_IN_OPT)) || ((nlhs < NR_OUT) || (nlhs > NR_OUT + NR_OUT_OPT)) ) {
+        mexErrMsgTxt("Usage: [logZ,q,qv,qf,qmap,margs] = dai_jtree(psi,varsets,opts)\n\n"
+        "\n"
+        "INPUT:  psi        = linear cell array containing the factors\n"
+        "                     (psi{i} should be a structure with a Member field\n"
+        "                     and a P field).\n"
+        "        varsets    = linear cell array containing varsets for which marginals\n"
+        "                     are requested.\n"
+        "        opts       = string of options.\n"
+        "\n"
+        "OUTPUT: logZ       = logarithm of the partition sum.\n"
+        "        q          = linear cell array containing all calculated marginals.\n"
+        "        qv         = linear cell array containing all variable marginals.\n"
+        "        qf         = linear cell array containing all factor marginals.\n"
+        "        qmap       = linear array containing the MAP state.\n"
+        "        margs      = linear cell array containing all requested marginals.\n");
+    }
+
+    // Get psi and construct factorgraph
+    vector<Factor> factors = mx2Factors(PSI_IN, 0);
+    FactorGraph fg(factors);
+
+    // Get varsets
+    vector<Permute> perms;
+    vector<VarSet> varsets = mx2VarSets(VARSETS_IN,fg,0,perms);
+
+    // Get options string
+    char *opts;
+    size_t buflen = mxGetN( OPTS_IN ) + 1;
+    opts = (char *)mxCalloc( buflen, sizeof(char) );
+    mxGetString( OPTS_IN, opts, buflen );
+    // Convert to options object props
+    stringstream ss;
+    ss << opts;
+    PropertySet props;
+    ss >> props;
+
+    // Construct InfAlg object, init and run
+    JTree jt = JTree( fg, props );
+    jt.init();
+    jt.run();
+
+    // Save logZ
+       double logZ = NAN;
+    logZ = jt.logZ();
+
+    // Hand over results to MATLAB
+    LOGZ_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
+    *(mxGetPr(LOGZ_OUT)) = logZ;
+
+    Q_OUT = Factors2mx(jt.beliefs());
+
+    if( nlhs >= 3 ) {
+        vector<Factor> qv;
+        qv.reserve( fg.nrVars() );
+        for( size_t i = 0; i < fg.nrVars(); i++ )
+            qv.push_back( jt.belief( fg.var(i) ) );
+        QV_OUT = Factors2mx( qv );
+    }
+
+    if( nlhs >= 4 ) {
+        vector<Factor> qf;
+        qf.reserve( fg.nrFactors() );
+        for( size_t I = 0; I < fg.nrFactors(); I++ )
+            qf.push_back( jt.belief( fg.factor(I).vars() ) );
+        QF_OUT = Factors2mx( qf );
+    }
+
+    if( nlhs >= 5 ) {
+        std::vector<std::size_t> map_state;
+        bool supported = true;
+        try {
+            map_state = jt.findMaximum();
+        } catch( Exception &e ) {
+            if( e.getCode() == Exception::NOT_IMPLEMENTED )
+                supported = false;
+            else
+                throw;
+        }
+        if( supported ) {
+            QMAP_OUT = mxCreateNumericMatrix(map_state.size(), 1, mxUINT32_CLASS, mxREAL);
+            uint32_T* qmap_p = reinterpret_cast<uint32_T *>(mxGetPr(QMAP_OUT));
+            for (size_t n = 0; n < map_state.size(); ++n)
+                qmap_p[n] = map_state[n];
+        } else {
+            mexErrMsgTxt("Calculating a MAP state is not supported by this inference algorithm.");
+        }
+    }
+
+    if( nlhs >= 6 ) {
+        vector<Factor> margs;
+        margs.reserve( varsets.size() );
+        for( size_t s = 0; s < varsets.size(); s++ ) {
+            Factor marg;
+            jt.init();
+            jt.run();
+            marg = jt.calcMarginal( varsets[s] );
+
+            // permute entries of marg
+            Factor margperm = marg;
+            for( size_t li = 0; li < marg.nrStates(); li++ )
+                margperm.set( li, marg[perms[s].convertLinearIndex(li)] );
+            margs.push_back( margperm );
+        }
+        MARGS_OUT = Factors2mx( margs );
+    }
+
+    return;
+}