Merge branch 'joris'
[libdai.git] / src / matlab / dai_potstrength.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
23 /*=================================================================*
24 * *
25 * This is a MEX-file for MATLAB. *
26 * *
27 * N = dai_potstrength(psi,i,j); *
28 * *
29 *=================================================================*/
30
31
32 #include <iostream>
33 #include "mex.h"
34 #include <dai/matlab/matlab.h>
35 #include <dai/factor.h>
36
37
38 using namespace std;
39 using namespace dai;
40
41
42 /* Input Arguments */
43
44 #define PSI_IN prhs[0]
45 #define I_IN prhs[1]
46 #define J_IN prhs[2]
47 #define NR_IN 3
48
49
50 /* Output Arguments */
51
52 #define N_OUT plhs[0]
53 #define NR_OUT 1
54
55
56 void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] )
57 {
58 long ilabel, jlabel;
59
60 // Check for proper number of arguments
61 if ((nrhs != NR_IN) || (nlhs != NR_OUT)) {
62 mexErrMsgTxt("Usage: N = dai_potstrength(psi,i,j);\n\n"
63 "\n"
64 "INPUT: psi = structure with a Member field and a P field, like a CPTAB.\n"
65 " i = label of a variable in psi.\n"
66 " j = label of another variable in psi.\n"
67 "\n"
68 "OUTPUT: N = strength of psi in direction i->j.\n");
69 }
70
71 // Get input parameters
72 Factor psi = mx2Factor(PSI_IN);
73 ilabel = (long)*mxGetPr(I_IN);
74 jlabel = (long)*mxGetPr(J_IN);
75
76 // Find variable in psi with label ilabel
77 Var i;
78 for( VarSet::const_iterator n = psi.vars().begin(); n != psi.vars().end(); n++ )
79 if( n->label() == ilabel ) {
80 i = *n;
81 break;
82 }
83 assert( i.label() == ilabel );
84
85 // Find variable in psi with label jlabel
86 Var j;
87 for( VarSet::const_iterator n = psi.vars().begin(); n != psi.vars().end(); n++ )
88 if( n->label() == jlabel ) {
89 j = *n;
90 break;
91 }
92 assert( j.label() == jlabel );
93
94 // Calculate N(psi,i,j);
95 double N = psi.strength( i, j );
96
97 // Hand over result to MATLAB
98 N_OUT = mxCreateDoubleMatrix(1,1,mxREAL);
99 *(mxGetPr(N_OUT)) = N;
100
101 return;
102 }