30eac37da1c953c515f5cac86024dba5aaf4c546
[libdai.git] / matlab / matlab.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include "matlab.h"
24
25
26 using namespace std;
27
28
29 /* Convert vector<Factor> structure to a cell vector of CPTAB-like structs */
30 mxArray *Factors2mx(const vector<Factor> &Ps) {
31 size_t nr = Ps.size();
32
33 mxArray *psi = mxCreateCellMatrix(nr,1);
34
35 const char *fieldnames[2];
36 fieldnames[0] = "Member";
37 fieldnames[1] = "P";
38
39 size_t I_ind = 0;
40 for( vector<Factor>::const_iterator I = Ps.begin(); I != Ps.end(); I++, I_ind++ ) {
41 mxArray *Bi = mxCreateStructMatrix(1,1,2,fieldnames);
42
43 mxArray *BiMember = mxCreateDoubleMatrix(1,I->vars().size(),mxREAL);
44 double *BiMember_data = mxGetPr(BiMember);
45 size_t i = 0;
46 vector<mwSize> dims;
47 for( VarSet::iterator j = I->vars().begin(); j != I->vars().end(); j++,i++ ) {
48 BiMember_data[i] = j->label();
49 dims.push_back( j->states() );
50 }
51
52 // mxArray *BiP = mxCreateDoubleMatrix(I->states(),1,mxREAL);
53 mxArray *BiP = mxCreateNumericArray(I->vars().size(), &(*(dims.begin())), mxDOUBLE_CLASS, mxREAL);
54 double *BiP_data = mxGetPr(BiP);
55 for( size_t j = 0; j < I->states(); j++ )
56 BiP_data[j] = (*I)[j];
57
58 mxSetField(Bi,0,"Member",BiMember);
59 mxSetField(Bi,0,"P",BiP);
60
61 mxSetCell(psi, I_ind, Bi);
62 }
63 return( psi );
64 }
65
66
67 /* Convert cell vector of CPTAB-like structs to vector<Factor> */
68 vector<Factor> mx2Factors(const mxArray *psi, long verbose) {
69 set<Var> vars;
70 vector<Factor> factors;
71
72 int n1 = mxGetM(psi);
73 int n2 = mxGetN(psi);
74 if( n2 != 1 && n1 != 1 )
75 mexErrMsgTxt("psi should be a Nx1 or 1xN cell matrix.");
76 size_t nr_f = n1;
77 if( n1 == 1 )
78 nr_f = n2;
79
80 // interpret psi, linear cell array of cptabs
81 for( size_t cellind = 0; cellind < nr_f; cellind++ ) {
82 if( verbose >= 3 )
83 cout << "reading factor " << cellind << ": " << endl;
84 mxArray *cell = mxGetCell(psi, cellind);
85 mxArray *mx_member = mxGetField(cell, 0, "Member");
86 size_t nr_mem = mxGetN(mx_member);
87 double *members = mxGetPr(mx_member);
88 const mwSize *dims = mxGetDimensions(mxGetField(cell,0,"P"));
89 double *factordata = mxGetPr(mxGetField(cell, 0, "P"));
90
91 // add variables
92 VarSet factorvars;
93 vector<long> labels(nr_mem,0);
94 if( verbose >= 3 )
95 cout << " vars: ";
96 for( size_t mi = 0; mi < nr_mem; mi++ ) {
97 labels[mi] = (long)members[mi];
98 if( verbose >= 3 )
99 cout << labels[mi] << "(" << dims[mi] << ") ";
100 vars.insert( Var(labels[mi], dims[mi]) );
101 factorvars.insert( Var(labels[mi], dims[mi]) );
102 }
103 factors.push_back(Factor(factorvars));
104
105 // calculate permutation matrix
106 vector<size_t> perm(nr_mem,0);
107 VarSet::iterator j = factorvars.begin();
108 for( size_t mi = 0; mi < nr_mem; mi++,j++ ) {
109 long gezocht = j->label();
110 vector<long>::iterator piet = find(labels.begin(),labels.end(),gezocht);
111 perm[mi] = piet - labels.begin();
112 }
113
114 if( verbose >= 3 ) {
115 cout << endl << " perm: ";
116 for( vector<size_t>::iterator r=perm.begin(); r!=perm.end(); r++ )
117 cout << *r << " ";
118 cout << endl;
119 }
120
121 // read Factor
122 vector<size_t> di(nr_mem,0);
123 size_t prod = 1;
124 for( size_t k = 0; k < nr_mem; k++ ) {
125 di[k] = dims[k];
126 prod *= dims[k];
127 }
128 Permute permindex( di, perm );
129 for( size_t li = 0; li < prod; li++ )
130 factors.back()[permindex.convert_linear_index(li)] = factordata[li];
131 }
132
133 if( verbose >= 3 ) {
134 for(vector<Factor>::const_iterator I=factors.begin(); I!=factors.end(); I++ )
135 cout << *I << endl;
136 }
137
138 return( factors );
139 }
140
141
142 /* Convert CPTAB-like struct to Factor */
143 Factor mx2Factor(const mxArray *psi) {
144 mxArray *mx_member = mxGetField(psi, 0, "Member");
145 size_t nr_mem = mxGetN(mx_member);
146 double *members = mxGetPr(mx_member);
147 const mwSize *dims = mxGetDimensions(mxGetField(psi,0,"P"));
148 double *factordata = mxGetPr(mxGetField(psi, 0, "P"));
149
150 // add variables
151 VarSet vars;
152 vector<long> labels(nr_mem,0);
153 for( size_t mi = 0; mi < nr_mem; mi++ ) {
154 labels[mi] = (long)members[mi];
155 vars.insert( Var(labels[mi], dims[mi]) );
156 }
157 Factor factor(vars);
158
159 // calculate permutation matrix
160 vector<size_t> perm(nr_mem,0);
161 VarSet::iterator j = vars.begin();
162 for( size_t mi = 0; mi < nr_mem; mi++,j++ ) {
163 long gezocht = j->label();
164 vector<long>::iterator piet = find(labels.begin(),labels.end(),gezocht);
165 perm[mi] = piet - labels.begin();
166 }
167
168 // read Factor
169 vector<size_t> di(nr_mem,0);
170 size_t prod = 1;
171 for( size_t k = 0; k < nr_mem; k++ ) {
172 di[k] = dims[k];
173 prod *= dims[k];
174 }
175 Permute permindex( di, perm );
176 for( size_t li = 0; li < prod; li++ )
177 factor[permindex.convert_linear_index(li)] = factordata[li];
178
179 return( factor );
180 }