Merge branch 'pletscher'
[libdai.git] / src / matlab / dai_potstrength.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include "mex.h"
14 #include <dai/matlab/matlab.h>
15 #include <dai/factor.h>
16
17
18 using namespace std;
19 using namespace dai;
20
21
22 /* Input Arguments */
23
24 #define PSI_IN prhs[0]
25 #define I_IN prhs[1]
26 #define J_IN prhs[2]
27 #define NR_IN 3
28
29
30 /* Output Arguments */
31
32 #define N_OUT plhs[0]
33 #define NR_OUT 1
34
35
36 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ) {
37 long ilabel, jlabel;
38
39 // Check for proper number of arguments
40 if ((nrhs != NR_IN) || (nlhs != NR_OUT)) {
41 mexErrMsgTxt("Usage: N = dai_potstrength(psi,i,j);\n\n"
42 "\n"
43 "INPUT: psi = structure with a Member field and a P field, like a CPTAB.\n"
44 " i = label of a variable in psi.\n"
45 " j = label of another variable in psi.\n"
46 "\n"
47 "OUTPUT: N = strength of psi in direction i->j.\n");
48 }
49
50 // Get input parameters
51 Factor psi = mx2Factor(PSI_IN);
52 ilabel = (long)*mxGetPr(I_IN);
53 jlabel = (long)*mxGetPr(J_IN);
54
55 // Find variable in psi with label ilabel
56 Var i;
57 for( VarSet::const_iterator n = psi.vars().begin(); n != psi.vars().end(); n++ )
58 if( n->label() == ilabel ) {
59 i = *n;
60 break;
61 }
62 DAI_ASSERT( i.label() == ilabel );
63
64 // Find variable in psi with label jlabel
65 Var j;
66 for( VarSet::const_iterator n = psi.vars().begin(); n != psi.vars().end(); n++ )
67 if( n->label() == jlabel ) {
68 j = *n;
69 break;
70 }
71 DAI_ASSERT( j.label() == jlabel );
72
73 // Calculate N(psi,i,j);
74 double N = psi.strength( i, j );
75
76 // Hand over result to MATLAB
77 N_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
78 *(mxGetPr(N_OUT)) = N;
79
80 return;
81 }