+ parameter transfer to python layer works
[qpalma.git] / QPalmaDP / penalty_info.cpp
1 #include <assert.h>
2 #include "config.h"
3 //#include "features/CharFeatures.h"
4 //#include "features/StringFeatures.h"
5
6 #include <stdio.h>
7 #include <string.h>
8
9 #include "io.h"
10
11
12 #include "fill_matrix.h"
13 #include "penalty_info.h"
14
15 void init_penalty_struct(struct penalty_struct &PEN)
16 {
17 PEN.limits=NULL ;
18 PEN.penalties=NULL ;
19 PEN.id=-1 ;
20 PEN.next_pen=NULL ;
21 PEN.transform = T_LINEAR ;
22 PEN.name = NULL ;
23 PEN.max_len=0 ;
24 PEN.min_len=0 ;
25 PEN.cache=NULL ;
26 PEN.use_svm=0 ;
27 }
28
29 void init_penalty_struct_cache(struct penalty_struct &PEN)
30 {
31 if (PEN.cache || PEN.use_svm)
32 return ;
33
34 REAL* cache=new REAL[PEN.max_len+1] ;
35 if (cache)
36 {
37 for (INT i=0; i<=PEN.max_len; i++)
38 cache[i] = lookup_penalty(&PEN, i, 0, false) ;
39 PEN.cache = cache ;
40 }
41 }
42
43 void delete_penalty_struct_palma(struct penalty_struct &PEN)
44 {
45 if (PEN.id!=-1)
46 {
47 delete[] PEN.limits ;
48 delete[] PEN.penalties ;
49 delete[] PEN.name ;
50 delete[] PEN.cache ;
51 }
52 }
53
54 void delete_penalty_struct_array(struct penalty_struct *PEN, INT len)
55 {
56 for (int i=0; i<len; i++)
57 delete_penalty_struct_palma(PEN[i]) ;
58 delete[] PEN ;
59 }
60
61 void copy_penalty_struct(struct penalty_struct *old, struct penalty_struct *newp) {
62 newp->len = old->len;
63 newp->limits = old->limits;
64 newp->penalties = old->penalties;
65 newp->max_len = old->max_len;
66 newp->min_len = old->min_len;
67 newp->cache = old->cache;
68 newp->transform = old->transform;
69 newp->id = old->id;
70 newp->next_pen = old->next_pen;
71 newp->name = old->name;
72 newp->use_svm = old->use_svm;
73 }
74
75
76 //struct penalty_struct * read_penalty_struct_from_cell(const mxArray * mx_penalty_info, int &P)
77 //{
78 // P = mxGetN(mx_penalty_info) ;
79 //
80 // struct penalty_struct * PEN = new struct penalty_struct[P] ;
81 // for (INT i=0; i<P; i++)
82 // init_penalty_struct(PEN[i]) ;
83 //
84 // for (INT i=0; i<P; i++)
85 // {
86 // const mxArray* mx_elem = mxGetCell(mx_penalty_info, i) ;
87 // if (mx_elem==NULL || !mxIsStruct(mx_elem))
88 // {
89 // CIO::message(M_ERROR, "empty cell element\n") ;
90 // delete_penalty_struct_array(PEN,P) ;
91 // return NULL ;
92 // } ;
93 // const mxArray* mx_id_field = mxGetField(mx_elem, 0, "id") ;
94 // if (mx_id_field==NULL || !mxIsNumeric(mx_id_field) ||
95 // mxGetN(mx_id_field)!=1 || mxGetM(mx_id_field)!=1)
96 // {
97 // CIO::message(M_ERROR, "missing id field\n") ;
98 // delete_penalty_struct_array(PEN,P) ;
99 // return NULL ;
100 // }
101 // const mxArray* mx_limits_field = mxGetField(mx_elem, 0, "limits") ;
102 // if (mx_limits_field==NULL || !mxIsNumeric(mx_limits_field) ||
103 // mxGetM(mx_limits_field)!=1)
104 // {
105 // CIO::message(M_ERROR, "missing limits field\n") ;
106 // delete_penalty_struct_array(PEN,P) ;
107 // return NULL ;
108 // }
109 // INT len = mxGetN(mx_limits_field) ;
110 //
111 // const mxArray* mx_penalties_field = mxGetField(mx_elem, 0, "penalties") ;
112 // if (mx_penalties_field==NULL || !mxIsNumeric(mx_penalties_field) ||
113 // mxGetM(mx_penalties_field)!=1 || mxGetN(mx_penalties_field)!=len)
114 // {
115 // CIO::message(M_ERROR, "missing penalties field\n") ;
116 // delete_penalty_struct_array(PEN,P) ;
117 // return NULL ;
118 // }
119 // const mxArray* mx_transform_field = mxGetField(mx_elem, 0, "transform") ;
120 // if (mx_transform_field==NULL || !mxIsChar(mx_transform_field))
121 // {
122 // CIO::message(M_ERROR, "missing transform field\n") ;
123 // delete_penalty_struct_array(PEN,P) ;
124 // return NULL ;
125 // }
126 // const mxArray* mx_name_field = mxGetField(mx_elem, 0, "name") ;
127 // if (mx_name_field==NULL || !mxIsChar(mx_name_field))
128 // {
129 // CIO::message(M_ERROR, "missing name field\n") ;
130 // delete_penalty_struct_array(PEN,P) ;
131 // return NULL ;
132 // }
133 // const mxArray* mx_max_len_field = mxGetField(mx_elem, 0, "max_len") ;
134 // if (mx_max_len_field==NULL || !mxIsNumeric(mx_max_len_field) ||
135 // mxGetM(mx_max_len_field)!=1 || mxGetN(mx_max_len_field)!=1)
136 // {
137 // CIO::message(M_ERROR, "missing max_len field\n") ;
138 // delete_penalty_struct_array(PEN,P) ;
139 // return NULL ;
140 // }
141 // const mxArray* mx_min_len_field = mxGetField(mx_elem, 0, "min_len") ;
142 // if (mx_min_len_field==NULL || !mxIsNumeric(mx_min_len_field) ||
143 // mxGetM(mx_min_len_field)!=1 || mxGetN(mx_min_len_field)!=1)
144 // {
145 // CIO::message(M_ERROR, "missing min_len field\n") ;
146 // delete_penalty_struct_array(PEN,P) ;
147 // return NULL ;
148 // }
149 // const mxArray* mx_use_svm_field = mxGetField(mx_elem, 0, "use_svm") ;
150 // if (mx_use_svm_field==NULL || !mxIsNumeric(mx_use_svm_field) ||
151 // mxGetM(mx_use_svm_field)!=1 || mxGetN(mx_use_svm_field)!=1)
152 // {
153 // CIO::message(M_ERROR, "missing use_svm field\n") ;
154 // delete_penalty_struct_array(PEN,P) ;
155 // return NULL ;
156 // }
157 // INT use_svm = (INT) mxGetScalar(mx_use_svm_field) ;
158 // //fprintf(stderr, "use_svm_field=%i\n", use_svm) ;
159 //
160 // const mxArray* mx_next_id_field = mxGetField(mx_elem, 0, "next_id") ;
161 // if (mx_next_id_field==NULL || !mxIsNumeric(mx_next_id_field) ||
162 // mxGetM(mx_next_id_field)!=1 || mxGetN(mx_next_id_field)!=1)
163 // {
164 // CIO::message(M_ERROR, "missing next_id field\n") ;
165 // delete_penalty_struct_array(PEN,P) ;
166 // return NULL ;
167 // }
168 // INT next_id = (INT) mxGetScalar(mx_next_id_field)-1 ;
169 //
170 // INT id = (INT) mxGetScalar(mx_id_field)-1 ;
171 // if (i<0 || i>P-1)
172 // {
173 // CIO::message(M_ERROR, "id out of range\n") ;
174 // delete_penalty_struct_array(PEN,P) ;
175 // return NULL ;
176 // }
177 // INT max_len = (INT) mxGetScalar(mx_max_len_field) ;
178 // if (max_len<0 || max_len>1024*1024*100)
179 // {
180 // CIO::message(M_ERROR, "max_len out of range\n") ;
181 // delete_penalty_struct_array(PEN,P) ;
182 // return NULL ;
183 // }
184 // PEN[id].max_len = max_len ;
185 //
186 // INT min_len = (INT) mxGetScalar(mx_min_len_field) ;
187 // if (min_len<0 || min_len>1024*1024*100)
188 // {
189 // CIO::message(M_ERROR, "min_len out of range\n") ;
190 // delete_penalty_struct_array(PEN,P) ;
191 // return NULL ;
192 // }
193 // PEN[id].min_len = min_len ;
194 //
195 // if (PEN[id].id!=-1)
196 // {
197 // CIO::message(M_ERROR, "penalty id already used\n") ;
198 // delete_penalty_struct_array(PEN,P) ;
199 // return NULL ;
200 // }
201 // PEN[id].id=id ;
202 // if (next_id>=0)
203 // PEN[id].next_pen=&PEN[next_id] ;
204 // //fprintf(stderr,"id=%i, next_id=%i\n", id, next_id) ;
205 //
206 // assert(next_id!=id) ;
207 // PEN[id].use_svm=use_svm ;
208 // PEN[id].limits = new REAL[len] ;
209 // PEN[id].penalties = new REAL[len] ;
210 // double * limits = mxGetPr(mx_limits_field) ;
211 // double * penalties = mxGetPr(mx_penalties_field) ;
212 //
213 // for (INT i=0; i<len; i++)
214 // {
215 // PEN[id].limits[i]=limits[i] ;
216 // PEN[id].penalties[i]=penalties[i] ;
217 // }
218 // PEN[id].len = len ;
219 //
220 // char *transform_str = mxArrayToString(mx_transform_field) ;
221 // char *name_str = mxArrayToString(mx_name_field) ;
222 //
223 // if (strcmp(transform_str, "log")==0)
224 // PEN[id].transform = T_LOG ;
225 // else if (strcmp(transform_str, "log(+1)")==0)
226 // PEN[id].transform = T_LOG_PLUS1 ;
227 // else if (strcmp(transform_str, "log(+3)")==0)
228 // PEN[id].transform = T_LOG_PLUS3 ;
229 // else if (strcmp(transform_str, "(+3)")==0)
230 // PEN[id].transform = T_LINEAR_PLUS3 ;
231 // else if (strcmp(transform_str, "")==0)
232 // PEN[id].transform = T_LINEAR ;
233 // else
234 // {
235 // delete_penalty_struct_array(PEN,P) ;
236 // mxFree(transform_str) ;
237 // return NULL ;
238 // }
239 // PEN[id].name = new char[strlen(name_str)+1] ;
240 // strcpy(PEN[id].name, name_str) ;
241 //
242 // init_penalty_struct_cache(PEN[id]) ;
243 //
244 // mxFree(transform_str) ;
245 // mxFree(name_str) ;
246 // }
247 // return PEN ;
248 //}
249
250 REAL lookup_penalty_svm(const struct penalty_struct *PEN, INT p_value, REAL *d_values)
251 {
252 if (PEN==NULL)
253 return 0 ;
254 assert(PEN->use_svm>0) ;
255 REAL d_value=d_values[PEN->use_svm-1] ;
256 //fprintf(stderr,"transform=%i, d_value=%1.2f\n", (INT)PEN->transform, d_value) ;
257
258 switch (PEN->transform)
259 {
260 case T_LINEAR:
261 break ;
262 case T_LOG:
263 d_value = log(d_value) ;
264 break ;
265 case T_LOG_PLUS1:
266 d_value = log(d_value+1) ;
267 break ;
268 case T_LOG_PLUS3:
269 d_value = log(d_value+3) ;
270 break ;
271 case T_LINEAR_PLUS3:
272 d_value = d_value+3 ;
273 break ;
274 default:
275 CIO::message(M_ERROR, "unknown transform\n") ;
276 break ;
277 }
278
279 INT idx = 0 ;
280 REAL ret ;
281 for (INT i=0; i<PEN->len; i++)
282 if (PEN->limits[i]<=d_value)
283 idx++ ;
284
285 if (idx==0)
286 ret=PEN->penalties[0] ;
287 else if (idx==PEN->len)
288 ret=PEN->penalties[PEN->len-1] ;
289 else
290 {
291 ret = (PEN->penalties[idx]*(d_value-PEN->limits[idx-1]) + PEN->penalties[idx-1]*
292 (PEN->limits[idx]-d_value)) / (PEN->limits[idx]-PEN->limits[idx-1]) ;
293 }
294
295 //fprintf(stderr,"ret=%1.2f\n", ret) ;
296
297 if (PEN->next_pen)
298 ret+=lookup_penalty(PEN->next_pen, p_value, d_values);
299
300 //fprintf(stderr,"ret=%1.2f\n", ret) ;
301
302 return ret ;
303 }
304
305 REAL lookup_penalty(const struct penalty_struct *PEN, INT p_value,
306 REAL* svm_values, bool follow_next)
307 {
308 if (PEN==NULL)
309 return 0 ;
310 if (PEN->use_svm)
311 return lookup_penalty_svm(PEN, p_value, svm_values) ;
312
313 if ((p_value<PEN->min_len) || (p_value>PEN->max_len))
314 return -CMath::INFTY ;
315
316 if (PEN->cache!=NULL && (p_value>=0) && (p_value<=PEN->max_len))
317 {
318 REAL ret=PEN->cache[p_value] ;
319 if (PEN->next_pen && follow_next)
320 ret+=lookup_penalty(PEN->next_pen, p_value, svm_values);
321 return ret ;
322 }
323
324 REAL d_value = (REAL) p_value ;
325 switch (PEN->transform)
326 {
327 case T_LINEAR:
328 break ;
329 case T_LOG:
330 d_value = log(d_value) ;
331 break ;
332 case T_LOG_PLUS1:
333 d_value = log(d_value+1) ;
334 break ;
335 case T_LOG_PLUS3:
336 d_value = log(d_value+3) ;
337 break ;
338 case T_LINEAR_PLUS3:
339 d_value = d_value+3 ;
340 break ;
341 default:
342 CIO::message(M_ERROR, "unknown transform\n") ;
343 break ;
344 }
345
346 INT idx = 0 ;
347 REAL ret ;
348 for (INT i=0; i<PEN->len; i++)
349 if (PEN->limits[i]<=d_value)
350 idx++ ;
351
352 if (idx==0)
353 ret=PEN->penalties[0] ;
354 else if (idx==PEN->len)
355 ret=PEN->penalties[PEN->len-1] ;
356 else
357 {
358 ret = (PEN->penalties[idx]*(d_value-PEN->limits[idx-1]) + PEN->penalties[idx-1]*
359 (PEN->limits[idx]-d_value)) / (PEN->limits[idx]-PEN->limits[idx-1]) ;
360 }
361 //if (p_value>=30 && p_value<150)
362 //fprintf(stderr, "%s %i(%i) -> %1.2f\n", PEN->name, p_value, idx, ret) ;
363
364 if (PEN->next_pen && follow_next)
365 ret+=lookup_penalty(PEN->next_pen, p_value, svm_values);
366
367 return ret ;
368 }