New git HEAD version
[libdai.git] / src / matlab / matlab.cpp
1 /* This file is part of libDAI - http://www.libdai.org/
2 *
3 * Copyright (c) 2006-2011, The libDAI authors. All rights reserved.
4 *
5 * Use of this source code is governed by a BSD-style license that can be found in the LICENSE file.
6 */
7
8
9 #include <iostream>
10 #include <dai/matlab/matlab.h>
11
12
13 namespace dai {
14
15
16 using namespace std;
17
18
19 /* Convert vector<Factor> structure to a cell vector of CPTAB-like structs */
20 mxArray *Factors2mx(const vector<Factor> &Ps) {
21 size_t nr = Ps.size();
22
23 mxArray *psi = mxCreateCellMatrix(nr,1);
24
25 const char *fieldnames[2];
26 fieldnames[0] = "Member";
27 fieldnames[1] = "P";
28
29 size_t I_ind = 0;
30 for( vector<Factor>::const_iterator I = Ps.begin(); I != Ps.end(); I++, I_ind++ ) {
31 mxArray *Bi = mxCreateStructMatrix(1,1,2,fieldnames);
32
33 mxArray *BiMember = mxCreateDoubleMatrix(1,I->vars().size(),mxREAL);
34 double *BiMember_data = mxGetPr(BiMember);
35 size_t i = 0;
36 vector<mwSize> dims;
37 for( VarSet::const_iterator j = I->vars().begin(); j != I->vars().end(); j++,i++ ) {
38 BiMember_data[i] = j->label();
39 dims.push_back( j->states() );
40 }
41 while( dims.size() <= 2 )
42 dims.push_back( 1 );
43
44 mxArray *BiP = mxCreateNumericArray(dims.size(), &(*(dims.begin())), mxDOUBLE_CLASS, mxREAL);
45 double *BiP_data = mxGetPr(BiP);
46 for( size_t j = 0; j < I->nrStates(); j++ )
47 BiP_data[j] = (*I)[j];
48
49 mxSetField(Bi,0,"Member",BiMember);
50 mxSetField(Bi,0,"P",BiP);
51
52 mxSetCell(psi, I_ind, Bi);
53 }
54 return( psi );
55 }
56
57
58 /* Convert cell vector of CPTAB-like structs to vector<Factor> */
59 vector<Factor> mx2Factors(const mxArray *psi, long verbose) {
60 set<Var> vars;
61 vector<Factor> factors;
62
63 int n1 = mxGetM(psi);
64 int n2 = mxGetN(psi);
65 if( n2 != 1 && n1 != 1 )
66 mexErrMsgTxt("psi should be a Nx1 or 1xN cell matrix.");
67 size_t nr_f = n1;
68 if( n1 == 1 )
69 nr_f = n2;
70
71 // interpret psi, linear cell array of cptabs
72 for( size_t cellind = 0; cellind < nr_f; cellind++ ) {
73 if( verbose >= 3 )
74 cerr << "reading factor " << cellind << ": " << endl;
75 mxArray *cell = mxGetCell(psi, cellind);
76 mxArray *mx_member = mxGetField(cell, 0, "Member");
77 size_t nr_mem = mxGetN(mx_member);
78 double *members = mxGetPr(mx_member);
79 const mwSize *dims = mxGetDimensions(mxGetField(cell,0,"P"));
80 double *factordata = mxGetPr(mxGetField(cell, 0, "P"));
81
82 // add variables
83 VarSet factorvars;
84 vector<long> labels(nr_mem,0);
85 if( verbose >= 3 )
86 cerr << " vars: ";
87 for( size_t mi = 0; mi < nr_mem; mi++ ) {
88 labels[mi] = (long)members[mi];
89 if( verbose >= 3 )
90 cerr << labels[mi] << "(" << dims[mi] << ") ";
91 vars.insert( Var(labels[mi], dims[mi]) );
92 factorvars |= Var(labels[mi], dims[mi]);
93 }
94 factors.push_back(Factor(factorvars));
95
96 // calculate permutation matrix
97 vector<size_t> perm(nr_mem,0);
98 VarSet::iterator j = factorvars.begin();
99 for( size_t mi = 0; mi < nr_mem; mi++,j++ ) {
100 long gezocht = j->label();
101 vector<long>::iterator piet = find(labels.begin(),labels.end(),gezocht);
102 perm[mi] = piet - labels.begin();
103 }
104
105 if( verbose >= 3 ) {
106 cerr << endl << " perm: ";
107 for( vector<size_t>::iterator r=perm.begin(); r!=perm.end(); r++ )
108 cerr << *r << " ";
109 cerr << endl;
110 }
111
112 // read Factor
113 vector<size_t> di(nr_mem,0);
114 size_t prod = 1;
115 for( size_t k = 0; k < nr_mem; k++ ) {
116 di[k] = dims[k];
117 prod *= dims[k];
118 }
119 Permute permindex( di, perm );
120 for( size_t li = 0; li < prod; li++ )
121 factors.back().set( permindex.convertLinearIndex(li), factordata[li] );
122 }
123
124 if( verbose >= 3 ) {
125 for(vector<Factor>::const_iterator I=factors.begin(); I!=factors.end(); I++ )
126 cerr << *I << endl;
127 }
128
129 return( factors );
130 }
131
132
133 /* Convert CPTAB-like struct to Factor */
134 Factor mx2Factor(const mxArray *psi) {
135 mxArray *mx_member = mxGetField(psi, 0, "Member");
136 size_t nr_mem = mxGetN(mx_member);
137 double *members = mxGetPr(mx_member);
138 const mwSize *dims = mxGetDimensions(mxGetField(psi,0,"P"));
139 double *factordata = mxGetPr(mxGetField(psi, 0, "P"));
140
141 // add variables
142 VarSet vars;
143 vector<long> labels(nr_mem,0);
144 for( size_t mi = 0; mi < nr_mem; mi++ ) {
145 labels[mi] = (long)members[mi];
146 vars |= Var(labels[mi], dims[mi]);
147 }
148 Factor factor(vars);
149
150 // calculate permutation matrix
151 vector<size_t> perm(nr_mem,0);
152 VarSet::iterator j = vars.begin();
153 for( size_t mi = 0; mi < nr_mem; mi++,j++ ) {
154 long gezocht = j->label();
155 vector<long>::iterator piet = find(labels.begin(),labels.end(),gezocht);
156 perm[mi] = piet - labels.begin();
157 }
158
159 // read Factor
160 vector<size_t> di(nr_mem,0);
161 size_t prod = 1;
162 for( size_t k = 0; k < nr_mem; k++ ) {
163 di[k] = dims[k];
164 prod *= dims[k];
165 }
166 Permute permindex( di, perm );
167 for( size_t li = 0; li < prod; li++ )
168 factor.set( permindex.convertLinearIndex(li), factordata[li] );
169
170 return( factor );
171 }
172
173
174 }