/*  Copyright (C) 2006-2008  Joris Mooij  [joris dot mooij at tuebingen dot mpg dot de]
    Radboud University Nijmegen, The Netherlands /
    Max Planck Institute for Biological Cybernetics, Germany

    This file is part of libDAI.

    libDAI is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    libDAI is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with libDAI; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*/
24 #include <dai/matlab/matlab.h>
33 /* Convert vector<Factor> structure to a cell vector of CPTAB-like structs */
34 mxArray
*Factors2mx(const vector
<Factor
> &Ps
) {
35 size_t nr
= Ps
.size();
37 mxArray
*psi
= mxCreateCellMatrix(nr
,1);
39 const char *fieldnames
[2];
40 fieldnames
[0] = "Member";
44 for( vector
<Factor
>::const_iterator I
= Ps
.begin(); I
!= Ps
.end(); I
++, I_ind
++ ) {
45 mxArray
*Bi
= mxCreateStructMatrix(1,1,2,fieldnames
);
47 mxArray
*BiMember
= mxCreateDoubleMatrix(1,I
->vars().size(),mxREAL
);
48 double *BiMember_data
= mxGetPr(BiMember
);
51 for( VarSet::const_iterator j
= I
->vars().begin(); j
!= I
->vars().end(); j
++,i
++ ) {
52 BiMember_data
[i
] = j
->label();
53 dims
.push_back( j
->states() );
56 mxArray
*BiP
= mxCreateNumericArray(I
->vars().size(), &(*(dims
.begin())), mxDOUBLE_CLASS
, mxREAL
);
57 double *BiP_data
= mxGetPr(BiP
);
58 for( size_t j
= 0; j
< I
->states(); j
++ )
59 BiP_data
[j
] = (*I
)[j
];
61 mxSetField(Bi
,0,"Member",BiMember
);
62 mxSetField(Bi
,0,"P",BiP
);
64 mxSetCell(psi
, I_ind
, Bi
);
70 /* Convert cell vector of CPTAB-like structs to vector<Factor> */
71 vector
<Factor
> mx2Factors(const mxArray
*psi
, long verbose
) {
73 vector
<Factor
> factors
;
77 if( n2
!= 1 && n1
!= 1 )
78 mexErrMsgTxt("psi should be a Nx1 or 1xN cell matrix.");
83 // interpret psi, linear cell array of cptabs
84 for( size_t cellind
= 0; cellind
< nr_f
; cellind
++ ) {
86 cerr
<< "reading factor " << cellind
<< ": " << endl
;
87 mxArray
*cell
= mxGetCell(psi
, cellind
);
88 mxArray
*mx_member
= mxGetField(cell
, 0, "Member");
89 size_t nr_mem
= mxGetN(mx_member
);
90 double *members
= mxGetPr(mx_member
);
91 const mwSize
*dims
= mxGetDimensions(mxGetField(cell
,0,"P"));
92 double *factordata
= mxGetPr(mxGetField(cell
, 0, "P"));
96 vector
<long> labels(nr_mem
,0);
99 for( size_t mi
= 0; mi
< nr_mem
; mi
++ ) {
100 labels
[mi
] = (long)members
[mi
];
102 cerr
<< labels
[mi
] << "(" << dims
[mi
] << ") ";
103 vars
.insert( Var(labels
[mi
], dims
[mi
]) );
104 factorvars
|= Var(labels
[mi
], dims
[mi
]);
106 factors
.push_back(Factor(factorvars
));
108 // calculate permutation matrix
109 vector
<size_t> perm(nr_mem
,0);
110 VarSet::iterator j
= factorvars
.begin();
111 for( size_t mi
= 0; mi
< nr_mem
; mi
++,j
++ ) {
112 long gezocht
= j
->label();
113 vector
<long>::iterator piet
= find(labels
.begin(),labels
.end(),gezocht
);
114 perm
[mi
] = piet
- labels
.begin();
118 cerr
<< endl
<< " perm: ";
119 for( vector
<size_t>::iterator r
=perm
.begin(); r
!=perm
.end(); r
++ )
125 vector
<size_t> di(nr_mem
,0);
127 for( size_t k
= 0; k
< nr_mem
; k
++ ) {
131 Permute
permindex( di
, perm
);
132 for( size_t li
= 0; li
< prod
; li
++ )
133 factors
.back()[permindex
.convert_linear_index(li
)] = factordata
[li
];
137 for(vector
<Factor
>::const_iterator I
=factors
.begin(); I
!=factors
.end(); I
++ )
145 /* Convert CPTAB-like struct to Factor */
146 Factor
mx2Factor(const mxArray
*psi
) {
147 mxArray
*mx_member
= mxGetField(psi
, 0, "Member");
148 size_t nr_mem
= mxGetN(mx_member
);
149 double *members
= mxGetPr(mx_member
);
150 const mwSize
*dims
= mxGetDimensions(mxGetField(psi
,0,"P"));
151 double *factordata
= mxGetPr(mxGetField(psi
, 0, "P"));
155 vector
<long> labels(nr_mem
,0);
156 for( size_t mi
= 0; mi
< nr_mem
; mi
++ ) {
157 labels
[mi
] = (long)members
[mi
];
158 vars
|= Var(labels
[mi
], dims
[mi
]);
162 // calculate permutation matrix
163 vector
<size_t> perm(nr_mem
,0);
164 VarSet::iterator j
= vars
.begin();
165 for( size_t mi
= 0; mi
< nr_mem
; mi
++,j
++ ) {
166 long gezocht
= j
->label();
167 vector
<long>::iterator piet
= find(labels
.begin(),labels
.end(),gezocht
);
168 perm
[mi
] = piet
- labels
.begin();
172 vector
<size_t> di(nr_mem
,0);
174 for( size_t k
= 0; k
< nr_mem
; k
++ ) {
178 Permute
permindex( di
, perm
);
179 for( size_t li
= 0; li
< prod
; li
++ )
180 factor
[permindex
.convert_linear_index(li
)] = factordata
[li
];