Finished release 0.2.4
[libdai.git] / src / matlab / matlab.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * libDAI is licensed under the terms of the GNU General Public License version
4 * 2, or (at your option) any later version. libDAI is distributed without any
5 * warranty. See the file COPYING for more details.
6 *
7 * Copyright (C) 2006-2009 Joris Mooij [joris dot mooij at libdai dot org]
8 * Copyright (C) 2006-2007 Radboud University Nijmegen, The Netherlands
9 */
10
11
12 #include <iostream>
13 #include <dai/matlab/matlab.h>
14
15
16 namespace dai {
17
18
19 using namespace std;
20
21
22 /* Convert vector<Factor> structure to a cell vector of CPTAB-like structs */
23 mxArray *Factors2mx(const vector<Factor> &Ps) {
24 size_t nr = Ps.size();
25
26 mxArray *psi = mxCreateCellMatrix(nr,1);
27
28 const char *fieldnames[2];
29 fieldnames[0] = "Member";
30 fieldnames[1] = "P";
31
32 size_t I_ind = 0;
33 for( vector<Factor>::const_iterator I = Ps.begin(); I != Ps.end(); I++, I_ind++ ) {
34 mxArray *Bi = mxCreateStructMatrix(1,1,2,fieldnames);
35
36 mxArray *BiMember = mxCreateDoubleMatrix(1,I->vars().size(),mxREAL);
37 double *BiMember_data = mxGetPr(BiMember);
38 size_t i = 0;
39 vector<mwSize> dims;
40 for( VarSet::const_iterator j = I->vars().begin(); j != I->vars().end(); j++,i++ ) {
41 BiMember_data[i] = j->label();
42 dims.push_back( j->states() );
43 }
44
45 mxArray *BiP = mxCreateNumericArray(I->vars().size(), &(*(dims.begin())), mxDOUBLE_CLASS, mxREAL);
46 double *BiP_data = mxGetPr(BiP);
47 for( size_t j = 0; j < I->states(); j++ )
48 BiP_data[j] = (*I)[j];
49
50 mxSetField(Bi,0,"Member",BiMember);
51 mxSetField(Bi,0,"P",BiP);
52
53 mxSetCell(psi, I_ind, Bi);
54 }
55 return( psi );
56 }
57
58
59 /* Convert cell vector of CPTAB-like structs to vector<Factor> */
60 vector<Factor> mx2Factors(const mxArray *psi, long verbose) {
61 set<Var> vars;
62 vector<Factor> factors;
63
64 int n1 = mxGetM(psi);
65 int n2 = mxGetN(psi);
66 if( n2 != 1 && n1 != 1 )
67 mexErrMsgTxt("psi should be a Nx1 or 1xN cell matrix.");
68 size_t nr_f = n1;
69 if( n1 == 1 )
70 nr_f = n2;
71
72 // interpret psi, linear cell array of cptabs
73 for( size_t cellind = 0; cellind < nr_f; cellind++ ) {
74 if( verbose >= 3 )
75 cerr << "reading factor " << cellind << ": " << endl;
76 mxArray *cell = mxGetCell(psi, cellind);
77 mxArray *mx_member = mxGetField(cell, 0, "Member");
78 size_t nr_mem = mxGetN(mx_member);
79 double *members = mxGetPr(mx_member);
80 const mwSize *dims = mxGetDimensions(mxGetField(cell,0,"P"));
81 double *factordata = mxGetPr(mxGetField(cell, 0, "P"));
82
83 // add variables
84 VarSet factorvars;
85 vector<long> labels(nr_mem,0);
86 if( verbose >= 3 )
87 cerr << " vars: ";
88 for( size_t mi = 0; mi < nr_mem; mi++ ) {
89 labels[mi] = (long)members[mi];
90 if( verbose >= 3 )
91 cerr << labels[mi] << "(" << dims[mi] << ") ";
92 vars.insert( Var(labels[mi], dims[mi]) );
93 factorvars |= Var(labels[mi], dims[mi]);
94 }
95 factors.push_back(Factor(factorvars));
96
97 // calculate permutation matrix
98 vector<size_t> perm(nr_mem,0);
99 VarSet::iterator j = factorvars.begin();
100 for( size_t mi = 0; mi < nr_mem; mi++,j++ ) {
101 long gezocht = j->label();
102 vector<long>::iterator piet = find(labels.begin(),labels.end(),gezocht);
103 perm[mi] = piet - labels.begin();
104 }
105
106 if( verbose >= 3 ) {
107 cerr << endl << " perm: ";
108 for( vector<size_t>::iterator r=perm.begin(); r!=perm.end(); r++ )
109 cerr << *r << " ";
110 cerr << endl;
111 }
112
113 // read Factor
114 vector<size_t> di(nr_mem,0);
115 size_t prod = 1;
116 for( size_t k = 0; k < nr_mem; k++ ) {
117 di[k] = dims[k];
118 prod *= dims[k];
119 }
120 Permute permindex( di, perm );
121 for( size_t li = 0; li < prod; li++ )
122 factors.back()[permindex.convertLinearIndex(li)] = factordata[li];
123 }
124
125 if( verbose >= 3 ) {
126 for(vector<Factor>::const_iterator I=factors.begin(); I!=factors.end(); I++ )
127 cerr << *I << endl;
128 }
129
130 return( factors );
131 }
132
133
134 /* Convert CPTAB-like struct to Factor */
135 Factor mx2Factor(const mxArray *psi) {
136 mxArray *mx_member = mxGetField(psi, 0, "Member");
137 size_t nr_mem = mxGetN(mx_member);
138 double *members = mxGetPr(mx_member);
139 const mwSize *dims = mxGetDimensions(mxGetField(psi,0,"P"));
140 double *factordata = mxGetPr(mxGetField(psi, 0, "P"));
141
142 // add variables
143 VarSet vars;
144 vector<long> labels(nr_mem,0);
145 for( size_t mi = 0; mi < nr_mem; mi++ ) {
146 labels[mi] = (long)members[mi];
147 vars |= Var(labels[mi], dims[mi]);
148 }
149 Factor factor(vars);
150
151 // calculate permutation matrix
152 vector<size_t> perm(nr_mem,0);
153 VarSet::iterator j = vars.begin();
154 for( size_t mi = 0; mi < nr_mem; mi++,j++ ) {
155 long gezocht = j->label();
156 vector<long>::iterator piet = find(labels.begin(),labels.end(),gezocht);
157 perm[mi] = piet - labels.begin();
158 }
159
160 // read Factor
161 vector<size_t> di(nr_mem,0);
162 size_t prod = 1;
163 for( size_t k = 0; k < nr_mem; k++ ) {
164 di[k] = dims[k];
165 prod *= dims[k];
166 }
167 Permute permindex( di, perm );
168 for( size_t li = 0; li < prod; li++ )
169 factor[permindex.convertLinearIndex(li)] = factordata[li];
170
171 return( factor );
172 }
173
174
175 }