+ rewrote C interface for SWIG/Python
[qpalma.git] / QPalmaDP / penalty_info.cpp
1 #include <assert.h>
2 #include "config.h"
3 //#include "features/CharFeatures.h"
4 //#include "features/StringFeatures.h"
5
6 #include <stdio.h>
7 #include <string.h>
8
9 #include "io.h"
10
11
12 #include "fill_matrix.h"
13 #include "penalty_info.h"
14
15 void init_penalty_struct(struct penalty_struct &PEN)
16 {
17 PEN.limits=NULL ;
18 PEN.penalties=NULL ;
19 PEN.id=-1 ;
20 PEN.next_pen=NULL ;
21 PEN.transform = T_LINEAR ;
22 PEN.name = NULL ;
23 PEN.max_len=0 ;
24 PEN.min_len=0 ;
25 PEN.cache=NULL ;
26 PEN.use_svm=0 ;
27 }
28
29 void init_penalty_struct_cache(struct penalty_struct &PEN)
30 {
31 if (PEN.cache || PEN.use_svm)
32 return ;
33
34 REAL* cache=new REAL[PEN.max_len+1] ;
35 if (cache)
36 {
37 for (INT i=0; i<=PEN.max_len; i++)
38 cache[i] = lookup_penalty(&PEN, i, 0, false) ;
39 PEN.cache = cache ;
40 }
41 }
42
43 void delete_penalty_struct_palma(struct penalty_struct &PEN)
44 {
45 if (PEN.id!=-1)
46 {
47 delete[] PEN.limits ;
48 delete[] PEN.penalties ;
49 delete[] PEN.name ;
50 delete[] PEN.cache ;
51 }
52 }
53
54 void delete_penalty_struct_array(struct penalty_struct *PEN, INT len)
55 {
56 for (int i=0; i<len; i++)
57 delete_penalty_struct_palma(PEN[i]) ;
58 delete[] PEN ;
59 }
60
61
62 //struct penalty_struct * read_penalty_struct_from_cell(const mxArray * mx_penalty_info, int &P)
63 //{
64 // P = mxGetN(mx_penalty_info) ;
65 //
66 // struct penalty_struct * PEN = new struct penalty_struct[P] ;
67 // for (INT i=0; i<P; i++)
68 // init_penalty_struct(PEN[i]) ;
69 //
70 // for (INT i=0; i<P; i++)
71 // {
72 // const mxArray* mx_elem = mxGetCell(mx_penalty_info, i) ;
73 // if (mx_elem==NULL || !mxIsStruct(mx_elem))
74 // {
75 // CIO::message(M_ERROR, "empty cell element\n") ;
76 // delete_penalty_struct_array(PEN,P) ;
77 // return NULL ;
78 // } ;
79 // const mxArray* mx_id_field = mxGetField(mx_elem, 0, "id") ;
80 // if (mx_id_field==NULL || !mxIsNumeric(mx_id_field) ||
81 // mxGetN(mx_id_field)!=1 || mxGetM(mx_id_field)!=1)
82 // {
83 // CIO::message(M_ERROR, "missing id field\n") ;
84 // delete_penalty_struct_array(PEN,P) ;
85 // return NULL ;
86 // }
87 // const mxArray* mx_limits_field = mxGetField(mx_elem, 0, "limits") ;
88 // if (mx_limits_field==NULL || !mxIsNumeric(mx_limits_field) ||
89 // mxGetM(mx_limits_field)!=1)
90 // {
91 // CIO::message(M_ERROR, "missing limits field\n") ;
92 // delete_penalty_struct_array(PEN,P) ;
93 // return NULL ;
94 // }
95 // INT len = mxGetN(mx_limits_field) ;
96 //
97 // const mxArray* mx_penalties_field = mxGetField(mx_elem, 0, "penalties") ;
98 // if (mx_penalties_field==NULL || !mxIsNumeric(mx_penalties_field) ||
99 // mxGetM(mx_penalties_field)!=1 || mxGetN(mx_penalties_field)!=len)
100 // {
101 // CIO::message(M_ERROR, "missing penalties field\n") ;
102 // delete_penalty_struct_array(PEN,P) ;
103 // return NULL ;
104 // }
105 // const mxArray* mx_transform_field = mxGetField(mx_elem, 0, "transform") ;
106 // if (mx_transform_field==NULL || !mxIsChar(mx_transform_field))
107 // {
108 // CIO::message(M_ERROR, "missing transform field\n") ;
109 // delete_penalty_struct_array(PEN,P) ;
110 // return NULL ;
111 // }
112 // const mxArray* mx_name_field = mxGetField(mx_elem, 0, "name") ;
113 // if (mx_name_field==NULL || !mxIsChar(mx_name_field))
114 // {
115 // CIO::message(M_ERROR, "missing name field\n") ;
116 // delete_penalty_struct_array(PEN,P) ;
117 // return NULL ;
118 // }
119 // const mxArray* mx_max_len_field = mxGetField(mx_elem, 0, "max_len") ;
120 // if (mx_max_len_field==NULL || !mxIsNumeric(mx_max_len_field) ||
121 // mxGetM(mx_max_len_field)!=1 || mxGetN(mx_max_len_field)!=1)
122 // {
123 // CIO::message(M_ERROR, "missing max_len field\n") ;
124 // delete_penalty_struct_array(PEN,P) ;
125 // return NULL ;
126 // }
127 // const mxArray* mx_min_len_field = mxGetField(mx_elem, 0, "min_len") ;
128 // if (mx_min_len_field==NULL || !mxIsNumeric(mx_min_len_field) ||
129 // mxGetM(mx_min_len_field)!=1 || mxGetN(mx_min_len_field)!=1)
130 // {
131 // CIO::message(M_ERROR, "missing min_len field\n") ;
132 // delete_penalty_struct_array(PEN,P) ;
133 // return NULL ;
134 // }
135 // const mxArray* mx_use_svm_field = mxGetField(mx_elem, 0, "use_svm") ;
136 // if (mx_use_svm_field==NULL || !mxIsNumeric(mx_use_svm_field) ||
137 // mxGetM(mx_use_svm_field)!=1 || mxGetN(mx_use_svm_field)!=1)
138 // {
139 // CIO::message(M_ERROR, "missing use_svm field\n") ;
140 // delete_penalty_struct_array(PEN,P) ;
141 // return NULL ;
142 // }
143 // INT use_svm = (INT) mxGetScalar(mx_use_svm_field) ;
144 // //fprintf(stderr, "use_svm_field=%i\n", use_svm) ;
145 //
146 // const mxArray* mx_next_id_field = mxGetField(mx_elem, 0, "next_id") ;
147 // if (mx_next_id_field==NULL || !mxIsNumeric(mx_next_id_field) ||
148 // mxGetM(mx_next_id_field)!=1 || mxGetN(mx_next_id_field)!=1)
149 // {
150 // CIO::message(M_ERROR, "missing next_id field\n") ;
151 // delete_penalty_struct_array(PEN,P) ;
152 // return NULL ;
153 // }
154 // INT next_id = (INT) mxGetScalar(mx_next_id_field)-1 ;
155 //
156 // INT id = (INT) mxGetScalar(mx_id_field)-1 ;
157 // if (i<0 || i>P-1)
158 // {
159 // CIO::message(M_ERROR, "id out of range\n") ;
160 // delete_penalty_struct_array(PEN,P) ;
161 // return NULL ;
162 // }
163 // INT max_len = (INT) mxGetScalar(mx_max_len_field) ;
164 // if (max_len<0 || max_len>1024*1024*100)
165 // {
166 // CIO::message(M_ERROR, "max_len out of range\n") ;
167 // delete_penalty_struct_array(PEN,P) ;
168 // return NULL ;
169 // }
170 // PEN[id].max_len = max_len ;
171 //
172 // INT min_len = (INT) mxGetScalar(mx_min_len_field) ;
173 // if (min_len<0 || min_len>1024*1024*100)
174 // {
175 // CIO::message(M_ERROR, "min_len out of range\n") ;
176 // delete_penalty_struct_array(PEN,P) ;
177 // return NULL ;
178 // }
179 // PEN[id].min_len = min_len ;
180 //
181 // if (PEN[id].id!=-1)
182 // {
183 // CIO::message(M_ERROR, "penalty id already used\n") ;
184 // delete_penalty_struct_array(PEN,P) ;
185 // return NULL ;
186 // }
187 // PEN[id].id=id ;
188 // if (next_id>=0)
189 // PEN[id].next_pen=&PEN[next_id] ;
190 // //fprintf(stderr,"id=%i, next_id=%i\n", id, next_id) ;
191 //
192 // assert(next_id!=id) ;
193 // PEN[id].use_svm=use_svm ;
194 // PEN[id].limits = new REAL[len] ;
195 // PEN[id].penalties = new REAL[len] ;
196 // double * limits = mxGetPr(mx_limits_field) ;
197 // double * penalties = mxGetPr(mx_penalties_field) ;
198 //
199 // for (INT i=0; i<len; i++)
200 // {
201 // PEN[id].limits[i]=limits[i] ;
202 // PEN[id].penalties[i]=penalties[i] ;
203 // }
204 // PEN[id].len = len ;
205 //
206 // char *transform_str = mxArrayToString(mx_transform_field) ;
207 // char *name_str = mxArrayToString(mx_name_field) ;
208 //
209 // if (strcmp(transform_str, "log")==0)
210 // PEN[id].transform = T_LOG ;
211 // else if (strcmp(transform_str, "log(+1)")==0)
212 // PEN[id].transform = T_LOG_PLUS1 ;
213 // else if (strcmp(transform_str, "log(+3)")==0)
214 // PEN[id].transform = T_LOG_PLUS3 ;
215 // else if (strcmp(transform_str, "(+3)")==0)
216 // PEN[id].transform = T_LINEAR_PLUS3 ;
217 // else if (strcmp(transform_str, "")==0)
218 // PEN[id].transform = T_LINEAR ;
219 // else
220 // {
221 // delete_penalty_struct_array(PEN,P) ;
222 // mxFree(transform_str) ;
223 // return NULL ;
224 // }
225 // PEN[id].name = new char[strlen(name_str)+1] ;
226 // strcpy(PEN[id].name, name_str) ;
227 //
228 // init_penalty_struct_cache(PEN[id]) ;
229 //
230 // mxFree(transform_str) ;
231 // mxFree(name_str) ;
232 // }
233 // return PEN ;
234 //}
235
236 REAL lookup_penalty_svm(const struct penalty_struct *PEN, INT p_value, REAL *d_values)
237 {
238 if (PEN==NULL)
239 return 0 ;
240 assert(PEN->use_svm>0) ;
241 REAL d_value=d_values[PEN->use_svm-1] ;
242 //fprintf(stderr,"transform=%i, d_value=%1.2f\n", (INT)PEN->transform, d_value) ;
243
244 switch (PEN->transform)
245 {
246 case T_LINEAR:
247 break ;
248 case T_LOG:
249 d_value = log(d_value) ;
250 break ;
251 case T_LOG_PLUS1:
252 d_value = log(d_value+1) ;
253 break ;
254 case T_LOG_PLUS3:
255 d_value = log(d_value+3) ;
256 break ;
257 case T_LINEAR_PLUS3:
258 d_value = d_value+3 ;
259 break ;
260 default:
261 CIO::message(M_ERROR, "unknown transform\n") ;
262 break ;
263 }
264
265 INT idx = 0 ;
266 REAL ret ;
267 for (INT i=0; i<PEN->len; i++)
268 if (PEN->limits[i]<=d_value)
269 idx++ ;
270
271 if (idx==0)
272 ret=PEN->penalties[0] ;
273 else if (idx==PEN->len)
274 ret=PEN->penalties[PEN->len-1] ;
275 else
276 {
277 ret = (PEN->penalties[idx]*(d_value-PEN->limits[idx-1]) + PEN->penalties[idx-1]*
278 (PEN->limits[idx]-d_value)) / (PEN->limits[idx]-PEN->limits[idx-1]) ;
279 }
280
281 //fprintf(stderr,"ret=%1.2f\n", ret) ;
282
283 if (PEN->next_pen)
284 ret+=lookup_penalty(PEN->next_pen, p_value, d_values);
285
286 //fprintf(stderr,"ret=%1.2f\n", ret) ;
287
288 return ret ;
289 }
290
291 REAL lookup_penalty(const struct penalty_struct *PEN, INT p_value,
292 REAL* svm_values, bool follow_next)
293 {
294 if (PEN==NULL)
295 return 0 ;
296 if (PEN->use_svm)
297 return lookup_penalty_svm(PEN, p_value, svm_values) ;
298
299 if ((p_value<PEN->min_len) || (p_value>PEN->max_len))
300 return -CMath::INFTY ;
301
302 if (PEN->cache!=NULL && (p_value>=0) && (p_value<=PEN->max_len))
303 {
304 REAL ret=PEN->cache[p_value] ;
305 if (PEN->next_pen && follow_next)
306 ret+=lookup_penalty(PEN->next_pen, p_value, svm_values);
307 return ret ;
308 }
309
310 REAL d_value = (REAL) p_value ;
311 switch (PEN->transform)
312 {
313 case T_LINEAR:
314 break ;
315 case T_LOG:
316 d_value = log(d_value) ;
317 break ;
318 case T_LOG_PLUS1:
319 d_value = log(d_value+1) ;
320 break ;
321 case T_LOG_PLUS3:
322 d_value = log(d_value+3) ;
323 break ;
324 case T_LINEAR_PLUS3:
325 d_value = d_value+3 ;
326 break ;
327 default:
328 CIO::message(M_ERROR, "unknown transform\n") ;
329 break ;
330 }
331
332 INT idx = 0 ;
333 REAL ret ;
334 for (INT i=0; i<PEN->len; i++)
335 if (PEN->limits[i]<=d_value)
336 idx++ ;
337
338 if (idx==0)
339 ret=PEN->penalties[0] ;
340 else if (idx==PEN->len)
341 ret=PEN->penalties[PEN->len-1] ;
342 else
343 {
344 ret = (PEN->penalties[idx]*(d_value-PEN->limits[idx-1]) + PEN->penalties[idx-1]*
345 (PEN->limits[idx]-d_value)) / (PEN->limits[idx]-PEN->limits[idx-1]) ;
346 }
347 //if (p_value>=30 && p_value<150)
348 //fprintf(stderr, "%s %i(%i) -> %1.2f\n", PEN->name, p_value, idx, ret) ;
349
350 if (PEN->next_pen && follow_next)
351 ret+=lookup_penalty(PEN->next_pen, p_value, svm_values);
352
353 return ret ;
354 }