Merge branch 'joris'
[libdai.git] / src / matlab / matlab.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [joris dot mooij at tuebingen dot mpg dot de]
2 Radboud University Nijmegen, The Netherlands /
3 Max Planck Institute for Biological Cybernetics, Germany
4
5 This file is part of libDAI.
6
7 libDAI is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 libDAI is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with libDAI; if not, write to the Free Software
19 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 */
21
22
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>
#include <dai/matlab/matlab.h>
25
26
27 namespace dai {
28
29
30 using namespace std;
31
32
33 /* Convert vector<Factor> structure to a cell vector of CPTAB-like structs */
34 mxArray *Factors2mx(const vector<Factor> &Ps) {
35 size_t nr = Ps.size();
36
37 mxArray *psi = mxCreateCellMatrix(nr,1);
38
39 const char *fieldnames[2];
40 fieldnames[0] = "Member";
41 fieldnames[1] = "P";
42
43 size_t I_ind = 0;
44 for( vector<Factor>::const_iterator I = Ps.begin(); I != Ps.end(); I++, I_ind++ ) {
45 mxArray *Bi = mxCreateStructMatrix(1,1,2,fieldnames);
46
47 mxArray *BiMember = mxCreateDoubleMatrix(1,I->vars().size(),mxREAL);
48 double *BiMember_data = mxGetPr(BiMember);
49 size_t i = 0;
50 vector<mwSize> dims;
51 for( VarSet::const_iterator j = I->vars().begin(); j != I->vars().end(); j++,i++ ) {
52 BiMember_data[i] = j->label();
53 dims.push_back( j->states() );
54 }
55
56 mxArray *BiP = mxCreateNumericArray(I->vars().size(), &(*(dims.begin())), mxDOUBLE_CLASS, mxREAL);
57 double *BiP_data = mxGetPr(BiP);
58 for( size_t j = 0; j < I->states(); j++ )
59 BiP_data[j] = (*I)[j];
60
61 mxSetField(Bi,0,"Member",BiMember);
62 mxSetField(Bi,0,"P",BiP);
63
64 mxSetCell(psi, I_ind, Bi);
65 }
66 return( psi );
67 }
68
69
70 /* Convert cell vector of CPTAB-like structs to vector<Factor> */
71 vector<Factor> mx2Factors(const mxArray *psi, long verbose) {
72 set<Var> vars;
73 vector<Factor> factors;
74
75 int n1 = mxGetM(psi);
76 int n2 = mxGetN(psi);
77 if( n2 != 1 && n1 != 1 )
78 mexErrMsgTxt("psi should be a Nx1 or 1xN cell matrix.");
79 size_t nr_f = n1;
80 if( n1 == 1 )
81 nr_f = n2;
82
83 // interpret psi, linear cell array of cptabs
84 for( size_t cellind = 0; cellind < nr_f; cellind++ ) {
85 if( verbose >= 3 )
86 cout << "reading factor " << cellind << ": " << endl;
87 mxArray *cell = mxGetCell(psi, cellind);
88 mxArray *mx_member = mxGetField(cell, 0, "Member");
89 size_t nr_mem = mxGetN(mx_member);
90 double *members = mxGetPr(mx_member);
91 const mwSize *dims = mxGetDimensions(mxGetField(cell,0,"P"));
92 double *factordata = mxGetPr(mxGetField(cell, 0, "P"));
93
94 // add variables
95 VarSet factorvars;
96 vector<long> labels(nr_mem,0);
97 if( verbose >= 3 )
98 cout << " vars: ";
99 for( size_t mi = 0; mi < nr_mem; mi++ ) {
100 labels[mi] = (long)members[mi];
101 if( verbose >= 3 )
102 cout << labels[mi] << "(" << dims[mi] << ") ";
103 vars.insert( Var(labels[mi], dims[mi]) );
104 factorvars |= Var(labels[mi], dims[mi]);
105 }
106 factors.push_back(Factor(factorvars));
107
108 // calculate permutation matrix
109 vector<size_t> perm(nr_mem,0);
110 VarSet::iterator j = factorvars.begin();
111 for( size_t mi = 0; mi < nr_mem; mi++,j++ ) {
112 long gezocht = j->label();
113 vector<long>::iterator piet = find(labels.begin(),labels.end(),gezocht);
114 perm[mi] = piet - labels.begin();
115 }
116
117 if( verbose >= 3 ) {
118 cout << endl << " perm: ";
119 for( vector<size_t>::iterator r=perm.begin(); r!=perm.end(); r++ )
120 cout << *r << " ";
121 cout << endl;
122 }
123
124 // read Factor
125 vector<size_t> di(nr_mem,0);
126 size_t prod = 1;
127 for( size_t k = 0; k < nr_mem; k++ ) {
128 di[k] = dims[k];
129 prod *= dims[k];
130 }
131 Permute permindex( di, perm );
132 for( size_t li = 0; li < prod; li++ )
133 factors.back()[permindex.convert_linear_index(li)] = factordata[li];
134 }
135
136 if( verbose >= 3 ) {
137 for(vector<Factor>::const_iterator I=factors.begin(); I!=factors.end(); I++ )
138 cout << *I << endl;
139 }
140
141 return( factors );
142 }
143
144
145 /* Convert CPTAB-like struct to Factor */
146 Factor mx2Factor(const mxArray *psi) {
147 mxArray *mx_member = mxGetField(psi, 0, "Member");
148 size_t nr_mem = mxGetN(mx_member);
149 double *members = mxGetPr(mx_member);
150 const mwSize *dims = mxGetDimensions(mxGetField(psi,0,"P"));
151 double *factordata = mxGetPr(mxGetField(psi, 0, "P"));
152
153 // add variables
154 VarSet vars;
155 vector<long> labels(nr_mem,0);
156 for( size_t mi = 0; mi < nr_mem; mi++ ) {
157 labels[mi] = (long)members[mi];
158 vars |= Var(labels[mi], dims[mi]);
159 }
160 Factor factor(vars);
161
162 // calculate permutation matrix
163 vector<size_t> perm(nr_mem,0);
164 VarSet::iterator j = vars.begin();
165 for( size_t mi = 0; mi < nr_mem; mi++,j++ ) {
166 long gezocht = j->label();
167 vector<long>::iterator piet = find(labels.begin(),labels.end(),gezocht);
168 perm[mi] = piet - labels.begin();
169 }
170
171 // read Factor
172 vector<size_t> di(nr_mem,0);
173 size_t prod = 1;
174 for( size_t k = 0; k < nr_mem; k++ ) {
175 di[k] = dims[k];
176 prod *= dims[k];
177 }
178 Permute permindex( di, perm );
179 for( size_t li = 0; li < prod; li++ )
180 factor[permindex.convert_linear_index(li)] = factordata[li];
181
182 return( factor );
183 }
184
185
186 }