Merged mf.* from SVN head (which implements damping)...
[libdai.git] / matlab / matlab.cpp
1 /* Copyright (C) 2006-2008 Joris Mooij [j dot mooij at science dot ru dot nl]
2 Radboud University Nijmegen, The Netherlands
3
4 This file is part of libDAI.
5
6 libDAI is free software; you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 2 of the License, or
9 (at your option) any later version.
10
11 libDAI is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with libDAI; if not, write to the Free Software
18 Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21
22 #include <iostream>
23 #include <dai/matlab/matlab.h>
24
25
26 namespace dai {
27
28
29 using namespace std;
30
31
32 /* Convert vector<Factor> structure to a cell vector of CPTAB-like structs */
33 mxArray *Factors2mx(const vector<Factor> &Ps) {
34 size_t nr = Ps.size();
35
36 mxArray *psi = mxCreateCellMatrix(nr,1);
37
38 const char *fieldnames[2];
39 fieldnames[0] = "Member";
40 fieldnames[1] = "P";
41
42 size_t I_ind = 0;
43 for( vector<Factor>::const_iterator I = Ps.begin(); I != Ps.end(); I++, I_ind++ ) {
44 mxArray *Bi = mxCreateStructMatrix(1,1,2,fieldnames);
45
46 mxArray *BiMember = mxCreateDoubleMatrix(1,I->vars().size(),mxREAL);
47 double *BiMember_data = mxGetPr(BiMember);
48 size_t i = 0;
49 vector<mwSize> dims;
50 for( VarSet::const_iterator j = I->vars().begin(); j != I->vars().end(); j++,i++ ) {
51 BiMember_data[i] = j->label();
52 dims.push_back( j->states() );
53 }
54
55 mxArray *BiP = mxCreateNumericArray(I->vars().size(), &(*(dims.begin())), mxDOUBLE_CLASS, mxREAL);
56 double *BiP_data = mxGetPr(BiP);
57 for( size_t j = 0; j < I->states(); j++ )
58 BiP_data[j] = (*I)[j];
59
60 mxSetField(Bi,0,"Member",BiMember);
61 mxSetField(Bi,0,"P",BiP);
62
63 mxSetCell(psi, I_ind, Bi);
64 }
65 return( psi );
66 }
67
68
69 /* Convert cell vector of CPTAB-like structs to vector<Factor> */
70 vector<Factor> mx2Factors(const mxArray *psi, long verbose) {
71 set<Var> vars;
72 vector<Factor> factors;
73
74 int n1 = mxGetM(psi);
75 int n2 = mxGetN(psi);
76 if( n2 != 1 && n1 != 1 )
77 mexErrMsgTxt("psi should be a Nx1 or 1xN cell matrix.");
78 size_t nr_f = n1;
79 if( n1 == 1 )
80 nr_f = n2;
81
82 // interpret psi, linear cell array of cptabs
83 for( size_t cellind = 0; cellind < nr_f; cellind++ ) {
84 if( verbose >= 3 )
85 cout << "reading factor " << cellind << ": " << endl;
86 mxArray *cell = mxGetCell(psi, cellind);
87 mxArray *mx_member = mxGetField(cell, 0, "Member");
88 size_t nr_mem = mxGetN(mx_member);
89 double *members = mxGetPr(mx_member);
90 const mwSize *dims = mxGetDimensions(mxGetField(cell,0,"P"));
91 double *factordata = mxGetPr(mxGetField(cell, 0, "P"));
92
93 // add variables
94 VarSet factorvars;
95 vector<long> labels(nr_mem,0);
96 if( verbose >= 3 )
97 cout << " vars: ";
98 for( size_t mi = 0; mi < nr_mem; mi++ ) {
99 labels[mi] = (long)members[mi];
100 if( verbose >= 3 )
101 cout << labels[mi] << "(" << dims[mi] << ") ";
102 vars.insert( Var(labels[mi], dims[mi]) );
103 factorvars |= Var(labels[mi], dims[mi]);
104 }
105 factors.push_back(Factor(factorvars));
106
107 // calculate permutation matrix
108 vector<size_t> perm(nr_mem,0);
109 VarSet::iterator j = factorvars.begin();
110 for( size_t mi = 0; mi < nr_mem; mi++,j++ ) {
111 long gezocht = j->label();
112 vector<long>::iterator piet = find(labels.begin(),labels.end(),gezocht);
113 perm[mi] = piet - labels.begin();
114 }
115
116 if( verbose >= 3 ) {
117 cout << endl << " perm: ";
118 for( vector<size_t>::iterator r=perm.begin(); r!=perm.end(); r++ )
119 cout << *r << " ";
120 cout << endl;
121 }
122
123 // read Factor
124 vector<size_t> di(nr_mem,0);
125 size_t prod = 1;
126 for( size_t k = 0; k < nr_mem; k++ ) {
127 di[k] = dims[k];
128 prod *= dims[k];
129 }
130 Permute permindex( di, perm );
131 for( size_t li = 0; li < prod; li++ )
132 factors.back()[permindex.convert_linear_index(li)] = factordata[li];
133 }
134
135 if( verbose >= 3 ) {
136 for(vector<Factor>::const_iterator I=factors.begin(); I!=factors.end(); I++ )
137 cout << *I << endl;
138 }
139
140 return( factors );
141 }
142
143
144 /* Convert CPTAB-like struct to Factor */
145 Factor mx2Factor(const mxArray *psi) {
146 mxArray *mx_member = mxGetField(psi, 0, "Member");
147 size_t nr_mem = mxGetN(mx_member);
148 double *members = mxGetPr(mx_member);
149 const mwSize *dims = mxGetDimensions(mxGetField(psi,0,"P"));
150 double *factordata = mxGetPr(mxGetField(psi, 0, "P"));
151
152 // add variables
153 VarSet vars;
154 vector<long> labels(nr_mem,0);
155 for( size_t mi = 0; mi < nr_mem; mi++ ) {
156 labels[mi] = (long)members[mi];
157 vars |= Var(labels[mi], dims[mi]);
158 }
159 Factor factor(vars);
160
161 // calculate permutation matrix
162 vector<size_t> perm(nr_mem,0);
163 VarSet::iterator j = vars.begin();
164 for( size_t mi = 0; mi < nr_mem; mi++,j++ ) {
165 long gezocht = j->label();
166 vector<long>::iterator piet = find(labels.begin(),labels.end(),gezocht);
167 perm[mi] = piet - labels.begin();
168 }
169
170 // read Factor
171 vector<size_t> di(nr_mem,0);
172 size_t prod = 1;
173 for( size_t k = 0; k < nr_mem; k++ ) {
174 di[k] = dims[k];
175 prod *= dims[k];
176 }
177 Permute permindex( di, perm );
178 for( size_t li = 0; li < prod; li++ )
179 factor[permindex.convert_linear_index(li)] = factordata[li];
180
181 return( factor );
182 }
183
184
185 }